From 12f943e9668c4b3b08e71c2e5de5605a7615a47e Mon Sep 17 00:00:00 2001
From: Adam Treat
Date: Tue, 3 Oct 2023 12:42:31 -0400
Subject: [PATCH] Fix regenerate button to be deterministic and bump the llama
 version to latest we have for gguf.

---
 gpt4all-backend/llama.cpp-mainline |  2 +-
 gpt4all-backend/llamamodel.cpp     | 11 +----------
 gpt4all-backend/llmodel_shared.cpp |  5 ++---
 gpt4all-chat/chatllm.cpp           |  2 +-
 4 files changed, 5 insertions(+), 15 deletions(-)

diff --git a/gpt4all-backend/llama.cpp-mainline b/gpt4all-backend/llama.cpp-mainline
index 37a0be31..70a6537c 160000
--- a/gpt4all-backend/llama.cpp-mainline
+++ b/gpt4all-backend/llama.cpp-mainline
@@ -1 +1 @@
-Subproject commit 37a0be313d21f8b61184a3adcaac123353128238
+Subproject commit 70a6537c4aae9951ba7fff740135ca7dbe14d0f1
diff --git a/gpt4all-backend/llamamodel.cpp b/gpt4all-backend/llamamodel.cpp
index 980a53dc..d19c9d97 100644
--- a/gpt4all-backend/llamamodel.cpp
+++ b/gpt4all-backend/llamamodel.cpp
@@ -249,16 +249,7 @@ LLModel::Token LLamaModel::sampleToken(PromptContext &promptCtx) const
 
 bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const
 {
-    // When we recalculate context we could have erased the original BOS token... we need to replace it
-    const bool useBOS = ctx.n_past == 0 && (ctx.tokens.empty() || ctx.tokens.front() != llama_token_bos(d_ptr->ctx));
-    if (useBOS) {
-        std::vector<int32_t> myTokens;
-        myTokens.push_back(llama_token_bos(d_ptr->ctx));
-        myTokens.insert(myTokens.end(), tokens.begin(), tokens.end());
-        ctx.n_past += 1;
-        return llama_eval(d_ptr->ctx, myTokens.data(), myTokens.size(), ctx.n_past, d_ptr->n_threads) == 0;
-    } else
-        return llama_eval(d_ptr->ctx, tokens.data(), tokens.size(), ctx.n_past, d_ptr->n_threads) == 0;
+    return llama_eval(d_ptr->ctx, tokens.data(), tokens.size(), ctx.n_past, d_ptr->n_threads) == 0;
 }
 
 int32_t LLamaModel::contextLength() const
diff --git a/gpt4all-backend/llmodel_shared.cpp b/gpt4all-backend/llmodel_shared.cpp
index f3022f7e..74e69786 100644
--- a/gpt4all-backend/llmodel_shared.cpp
+++ b/gpt4all-backend/llmodel_shared.cpp
@@ -92,10 +92,10 @@ void LLModel::prompt(const std::string &prompt,
             if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx)
                 promptCtx.tokens.erase(promptCtx.tokens.begin());
             promptCtx.tokens.push_back(batch.at(t));
+            promptCtx.n_past += 1;
             if (!promptCallback(batch.at(t)))
                 return;
         }
-        promptCtx.n_past += batch.size();
         i = batch_end;
     }
 
@@ -126,8 +126,6 @@ void LLModel::prompt(const std::string &prompt,
             return;
         }
 
-        promptCtx.n_past += 1;
-
         // display text
         for (const auto token : endTokens()) {
             if (id == token) return;
@@ -162,6 +160,7 @@ void LLModel::prompt(const std::string &prompt,
         if (int32_t(promptCtx.tokens.size()) == promptCtx.n_ctx)
             promptCtx.tokens.erase(promptCtx.tokens.begin());
         promptCtx.tokens.push_back(t);
+        promptCtx.n_past += 1;
        //TODO: Conversion to std::string can be avoided here...
         if (!responseCallback(t, std::string(tokenToString(t))))
             return;
diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp
index 3212d51c..b0d1b6f1 100644
--- a/gpt4all-chat/chatllm.cpp
+++ b/gpt4all-chat/chatllm.cpp
@@ -371,7 +371,7 @@ void ChatLLM::regenerateResponse()
     else
         m_ctx.n_past -= m_promptResponseTokens;
     m_ctx.n_past = std::max(0, m_ctx.n_past);
-    m_ctx.tokens.erase(m_ctx.tokens.end() -= m_promptResponseTokens, m_ctx.tokens.end());
+    m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, m_ctx.tokens.end());
     m_promptResponseTokens = 0;
     m_promptTokens = 0;
     m_response = std::string();
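
The common thread in these hunks is that promptCtx.n_past is now incremented in lockstep with promptCtx.tokens (one bump per token appended), instead of being advanced on separate paths (+1 for the BOS token re-inserted in evalTokens, += batch.size() after the prompt batch loop, and a lone += 1 in the generation loop). Once the two counters move together for every evaluated token, the rollback in ChatLLM::regenerateResponse() can deterministically erase exactly m_promptResponseTokens entries. Below is a minimal, self-contained sketch of that invariant only; it is not the gpt4all sources, and Ctx, appendToken, and rollback are hypothetical stand-ins for PromptContext and the ChatLLM bookkeeping.

// Sketch (assumed names, not the gpt4all API): keep n_past and the token
// history in lockstep, then roll back a fixed number of tokens the way the
// patched regenerateResponse() does.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

struct Ctx {                      // simplified stand-in for PromptContext
    int32_t n_past = 0;
    std::vector<int32_t> tokens;  // evaluated token history
};

// Mirrors the patched bookkeeping: every token pushed also bumps n_past by 1.
void appendToken(Ctx &ctx, int32_t tok, int32_t n_ctx) {
    if (int32_t(ctx.tokens.size()) == n_ctx)
        ctx.tokens.erase(ctx.tokens.begin());
    ctx.tokens.push_back(tok);
    ctx.n_past += 1;
}

// Mirrors the patched rollback: plain `end() - n`, not `end() -= n`.
void rollback(Ctx &ctx, int32_t promptResponseTokens) {
    ctx.n_past = std::max(0, ctx.n_past - promptResponseTokens);
    ctx.tokens.erase(ctx.tokens.end() - promptResponseTokens, ctx.tokens.end());
}

int main() {
    Ctx ctx;
    for (int32_t t = 0; t < 8; ++t)
        appendToken(ctx, t, /*n_ctx=*/2048);
    rollback(ctx, 3);             // drop the last prompt + response
    assert(ctx.n_past == 5);
    assert(ctx.tokens.size() == 5u);
    return 0;
}

Compiled with, e.g., g++ -std=c++17, this runs cleanly: after rolling back three tokens both n_past and tokens.size() sit at five, which is the consistency the per-token increments above are meant to guarantee.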