From 5a7d40f60420cc005d93234d1b1378fc1dcbe887 Mon Sep 17 00:00:00 2001
From: Adam Treat
Date: Thu, 27 Apr 2023 11:16:51 -0400
Subject: [PATCH] Move the saving of the tokens to the impl and not the
 callbacks responsibility.

---
 llm.cpp                |  9 ---------
 llmodel/gptj.cpp       | 15 +++++++++++++--
 llmodel/llamamodel.cpp | 14 ++++++++++++--
 3 files changed, 25 insertions(+), 13 deletions(-)

diff --git a/llm.cpp b/llm.cpp
index a218d40d..bdc36669 100644
--- a/llm.cpp
+++ b/llm.cpp
@@ -265,10 +265,6 @@ QList LLMObject::modelList() const
 
 bool LLMObject::handlePrompt(int32_t token)
 {
-    if (s_ctx.tokens.size() == s_ctx.n_ctx)
-        s_ctx.tokens.erase(s_ctx.tokens.begin());
-    s_ctx.tokens.push_back(token);
-
     // m_promptResponseTokens and m_responseLogits are related to last prompt/response not
     // the entire context window which we can reset on regenerate prompt
     ++m_promptResponseTokens;
@@ -289,11 +285,6 @@ bool LLMObject::handleResponse(int32_t token, const std::string &response)
         return false;
     }
 
-    // Save the token to our prompt ctxt
-    if (s_ctx.tokens.size() == s_ctx.n_ctx)
-        s_ctx.tokens.erase(s_ctx.tokens.begin());
-    s_ctx.tokens.push_back(token);
-
     // m_promptResponseTokens and m_responseLogits are related to last prompt/response not
     // the entire context window which we can reset on regenerate prompt
     ++m_promptResponseTokens;
diff --git a/llmodel/gptj.cpp b/llmodel/gptj.cpp
index 36eeaf27..c3ee6585 100644
--- a/llmodel/gptj.cpp
+++ b/llmodel/gptj.cpp
@@ -753,9 +753,13 @@ void GPTJ::prompt(const std::string &prompt,
         }
 
         size_t tokens = batch_end - i;
-        for (size_t t = 0; t < tokens; ++t)
+        for (size_t t = 0; t < tokens; ++t) {
+            if (promptCtx.tokens.size() == promptCtx.n_ctx)
+                promptCtx.tokens.erase(promptCtx.tokens.begin());
+            promptCtx.tokens.push_back(batch.at(t));
             if (!promptCallback(batch.at(t)))
                 return;
+        }
         promptCtx.n_past += batch.size();
         i = batch_end;
     }
@@ -806,7 +810,14 @@ void GPTJ::prompt(const std::string &prompt,
         promptCtx.n_past += 1;
 
         // display text
         ++totalPredictions;
-        if (id == 50256 /*end of text*/ || !responseCallback(id, d_ptr->vocab.id_to_token[id]))
+
+        if (id == 50256 /*end of text*/)
+            goto stop_generating;
+
+        if (promptCtx.tokens.size() == promptCtx.n_ctx)
+            promptCtx.tokens.erase(promptCtx.tokens.begin());
+        promptCtx.tokens.push_back(id);
+        if (!responseCallback(id, d_ptr->vocab.id_to_token[id]))
             goto stop_generating;
     }
diff --git a/llmodel/llamamodel.cpp b/llmodel/llamamodel.cpp
index 89c230fc..c1638c10 100644
--- a/llmodel/llamamodel.cpp
+++ b/llmodel/llamamodel.cpp
@@ -139,9 +139,13 @@ void LLamaModel::prompt(const std::string &prompt,
         }
 
         size_t tokens = batch_end - i;
-        for (size_t t = 0; t < tokens; ++t)
+        for (size_t t = 0; t < tokens; ++t) {
+            if (promptCtx.tokens.size() == promptCtx.n_ctx)
+                promptCtx.tokens.erase(promptCtx.tokens.begin());
+            promptCtx.tokens.push_back(batch.at(t));
             if (!promptCallback(batch.at(t)))
                 return;
+        }
         promptCtx.n_past += batch.size();
         i = batch_end;
     }
@@ -174,7 +178,13 @@ void LLamaModel::prompt(const std::string &prompt,
         promptCtx.n_past += 1;
         // display text
         ++totalPredictions;
-        if (id == llama_token_eos() || !responseCallback(id, llama_token_to_str(d_ptr->ctx, id)))
+        if (id == llama_token_eos())
+            return;
+
+        if (promptCtx.tokens.size() == promptCtx.n_ctx)
+            promptCtx.tokens.erase(promptCtx.tokens.begin());
+        promptCtx.tokens.push_back(id);
+        if (!responseCallback(id, llama_token_to_str(d_ptr->ctx, id)))
             return;
     }
 }
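
Note on the pattern: after this change the same bookkeeping appears in the prompt and response paths of both gptj.cpp and llamamodel.cpp, namely evict the oldest token once the history reaches n_ctx, then append the new one. The sketch below is not part of the patch; it only illustrates that pattern with a hypothetical appendToken helper and a PromptContextSketch stand-in for the real prompt-context type.

// Illustrative only: a hypothetical helper capturing the token-saving pattern above.
#include <cstdint>
#include <vector>

// Stand-in for the real prompt context; only the fields used here.
struct PromptContextSketch {
    std::vector<int32_t> tokens; // token history kept by the model implementation
    int32_t n_ctx = 2048;        // size of the context window (assumed default)
};

// Append a token to the context, evicting the oldest one once n_ctx is reached.
inline void appendToken(PromptContextSketch &ctx, int32_t token)
{
    if (ctx.tokens.size() == static_cast<size_t>(ctx.n_ctx))
        ctx.tokens.erase(ctx.tokens.begin()); // drop the oldest token
    ctx.tokens.push_back(token);              // record the newest token
}

With the saving handled inside the model implementations, the handlePrompt/handleResponse callbacks in llm.cpp are left to count prompt/response tokens and forward the response text only.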