diff --git a/gptj.cpp b/gptj.cpp
index aa7db13e..34aa16f9 100644
--- a/gptj.cpp
+++ b/gptj.cpp
@@ -707,9 +707,11 @@ void GPTJ::prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
-        PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f,
-        float temp = 0.9f, int32_t n_batch = 9) override;
+        PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 50400, float top_p = 1.0f,
+        float temp = 0.0f, int32_t n_batch = 9) override;
 
 private:
     GPTJPrivate *d_ptr;
diff --git a/llm.cpp b/llm.cpp
index 38ea62e3..b44d9ca3 100644
--- a/llm.cpp
+++ b/llm.cpp
@@ -19,6 +19,7 @@ static LLModel::PromptContext s_ctx;
 LLMObject::LLMObject()
     : QObject{nullptr}
     , m_llmodel(new GPTJ)
+    , m_responseTokens(0)
 {
     moveToThread(&m_llmThread);
     connect(&m_llmThread, &QThread::started, this, &LLMObject::loadModel);
@@ -64,6 +65,9 @@ bool LLMObject::isModelLoaded() const
 
 void LLMObject::resetResponse()
 {
+    s_ctx.n_past -= m_responseTokens;
+    s_ctx.logits.erase(s_ctx.logits.end() - m_responseTokens, s_ctx.logits.end());
+    m_responseTokens = 0;
     m_response = std::string();
     emit responseChanged();
 }
@@ -89,6 +93,7 @@ bool LLMObject::handleResponse(const std::string &response)
     printf("%s", response.c_str());
     fflush(stdout);
 #endif
+    ++m_responseTokens;
     if (!response.empty()) {
         m_response.append(response);
         emit responseChanged();
diff --git a/llm.h b/llm.h
index 3740723d..d47ab148 100644
--- a/llm.h
+++ b/llm.h
@@ -41,6 +41,7 @@ private:
 private:
     LLModel *m_llmodel;
     std::string m_response;
+    quint32 m_responseTokens;
    QString m_modelName;
     QThread m_llmThread;
     std::atomic<bool> m_stopGenerating;
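For context, the llm.cpp hunks above count how many tokens the current response has produced (m_responseTokens) so that resetResponse() can rewind the shared PromptContext rather than clearing only the displayed string. Below is a minimal standalone C++ sketch of that rollback idea. The type and class names here (PromptContext, ResponseTracker, onToken, reset) are illustrative stand-ins, not the gpt4all API, and the sketch assumes one logits entry is appended per generated token, which is what the patch's erase arithmetic implies.

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the model's evaluation state (not the real API).
struct PromptContext {
    std::vector<float> logits; // assumed: one entry appended per generated token
    int32_t n_past = 0;        // number of tokens the model has already evaluated
};

class ResponseTracker {
public:
    // Call once per generated token, mirroring handleResponse() in the patch.
    void onToken() { ++m_responseTokens; }

    // Rewind the context past the current response, mirroring resetResponse():
    // subtract the response's tokens from n_past and erase their logits so the
    // next generation continues from the end of the prompt.
    void reset(PromptContext &ctx) {
        assert(ctx.n_past >= static_cast<int32_t>(m_responseTokens));
        assert(ctx.logits.size() >= m_responseTokens);
        ctx.n_past -= static_cast<int32_t>(m_responseTokens);
        ctx.logits.erase(ctx.logits.end() - m_responseTokens, ctx.logits.end());
        m_responseTokens = 0;
    }

private:
    uint32_t m_responseTokens = 0;
};

The invariant to preserve is that reset() leaves n_past and logits exactly as they were at the end of the prompt, so a regenerated response re-evaluates only its own new tokens instead of reprocessing the prompt.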