Erase the correct number of logits when regenerating, which is not the same

as the number of tokens.
This commit is contained in:
Adam Treat 2023-04-15 09:19:06 -04:00
parent 12bf78bf24
commit 2f3a46c17f
2 changed files with 6 additions and 1 deletions

View File

@ -20,6 +20,7 @@ LLMObject::LLMObject()
: QObject{nullptr} : QObject{nullptr}
, m_llmodel(new GPTJ) , m_llmodel(new GPTJ)
, m_responseTokens(0) , m_responseTokens(0)
, m_responseLogits(0)
{ {
moveToThread(&m_llmThread); moveToThread(&m_llmThread);
connect(&m_llmThread, &QThread::started, this, &LLMObject::loadModel); connect(&m_llmThread, &QThread::started, this, &LLMObject::loadModel);
@ -66,8 +67,9 @@ bool LLMObject::isModelLoaded() const
void LLMObject::resetResponse() void LLMObject::resetResponse()
{ {
s_ctx.n_past -= m_responseTokens; s_ctx.n_past -= m_responseTokens;
s_ctx.logits.erase(s_ctx.logits.end() -= m_responseTokens, s_ctx.logits.end()); s_ctx.logits.erase(s_ctx.logits.end() -= m_responseLogits, s_ctx.logits.end());
m_responseTokens = 0; m_responseTokens = 0;
m_responseLogits = 0;
m_response = std::string(); m_response = std::string();
emit responseChanged(); emit responseChanged();
} }
@ -110,7 +112,9 @@ bool LLMObject::prompt(const QString &prompt)
m_stopGenerating = false; m_stopGenerating = false;
auto func = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1); auto func = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1);
emit responseStarted(); emit responseStarted();
qint32 logitsBefore = s_ctx.logits.size();
m_llmodel->prompt(prompt.toStdString(), func, s_ctx, 4096 /*number of chars to predict*/); m_llmodel->prompt(prompt.toStdString(), func, s_ctx, 4096 /*number of chars to predict*/);
m_responseLogits += s_ctx.logits.size() - logitsBefore;
emit responseStopped(); emit responseStopped();
return true; return true;
} }

1
llm.h
View File

@ -42,6 +42,7 @@ private:
LLModel *m_llmodel; LLModel *m_llmodel;
std::string m_response; std::string m_response;
quint32 m_responseTokens; quint32 m_responseTokens;
quint32 m_responseLogits;
QString m_modelName; QString m_modelName;
QThread m_llmThread; QThread m_llmThread;
std::atomic<bool> m_stopGenerating; std::atomic<bool> m_stopGenerating;