Infinite context window through trimming.

This commit is contained in:
Adam Treat 2023-04-25 11:20:51 -04:00
parent a79bc4233c
commit cf8a4dd868
9 changed files with 187 additions and 29 deletions

View File

@ -635,6 +635,7 @@ struct GPTJPrivate {
gpt_vocab vocab;
gptj_model model;
int64_t n_threads = 0;
size_t mem_per_token = 0;
std::mt19937 rng;
};
@ -662,6 +663,7 @@ bool GPTJ::loadModel(const std::string &modelPath, std::istream &fin) {
d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
d_ptr->modelLoaded = true;
fflush(stdout);
return true;
}
@ -685,6 +687,7 @@ bool GPTJ::isModelLoaded() const
void GPTJ::prompt(const std::string &prompt,
std::function<bool(int32_t, const std::string&)> response,
std::function<bool(bool)> recalculate,
PromptContext &promptCtx) {
if (!isModelLoaded()) {
@ -711,9 +714,9 @@ void GPTJ::prompt(const std::string &prompt,
static bool initialized = false;
static std::vector<gpt_vocab::id> p_instruct;
static std::vector<gpt_vocab::id> r_instruct;
size_t mem_per_token = 0;
if (!initialized) {
gptj_eval(d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, promptCtx.logits, mem_per_token);
gptj_eval(d_ptr->model, d_ptr->n_threads, 0, { 0, 1, 2, 3 }, promptCtx.logits,
d_ptr->mem_per_token);
initialized = true;
}
@ -726,12 +729,17 @@ void GPTJ::prompt(const std::string &prompt,
// Check if the context has run out...
if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
// FIXME: will produce gibberish after this
promptCtx.n_past = std::min(promptCtx.n_past, int(promptCtx.n_ctx - batch.size()));
std::cerr << "GPT-J WARNING: reached the end of the context window!\n";
const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
// Erase the first percentage of context from the tokens...
std::cerr << "GPTJ: reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculate);
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
}
if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits, mem_per_token)) {
if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
d_ptr->mem_per_token)) {
std::cerr << "GPT-J ERROR: Failed to process prompt\n";
return;
}
@ -770,13 +778,18 @@ void GPTJ::prompt(const std::string &prompt,
// Check if the context has run out...
if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
// FIXME: will produce gibberish after this
promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx - 1);
std::cerr << "GPT-J WARNING: reached the end of the context window!\n";
const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
// Erase the first percentage of context from the tokens...
std::cerr << "GPTJ: reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculate);
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
}
const int64_t t_start_predict_us = ggml_time_us();
if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits, mem_per_token)) {
if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits,
d_ptr->mem_per_token)) {
std::cerr << "GPT-J ERROR: Failed to predict next token\n";
return;
}
@ -807,3 +820,29 @@ stop_generating:
return;
}
void GPTJ::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate)
{
size_t i = 0;
promptCtx.n_past = 0;
while (i < promptCtx.tokens.size()) {
size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
std::vector<gpt_vocab::id> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
if (!gptj_eval(d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
d_ptr->mem_per_token)) {
std::cerr << "GPTJ ERROR: Failed to process prompt\n";
goto stop_generating;
}
promptCtx.n_past += batch.size();
if (!recalculate(true))
goto stop_generating;
i = batch_end;
}
assert(promptCtx.n_past == promptCtx.tokens.size());
stop_generating:
recalculate(false);
}

5
gptj.h
View File

@ -17,10 +17,15 @@ public:
bool isModelLoaded() const override;
void prompt(const std::string &prompt,
std::function<bool(int32_t, const std::string&)> response,
std::function<bool(bool)> recalculate,
PromptContext &ctx) override;
void setThreadCount(int32_t n_threads) override;
int32_t threadCount() override;
protected:
void recalculateContext(PromptContext &promptCtx,
std::function<bool(bool)> recalculate) override;
private:
GPTJPrivate *d_ptr;
};

View File

@ -58,6 +58,7 @@ bool LLamaModel::loadModel(const std::string &modelPath)
d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
d_ptr->modelLoaded = true;
fflush(stderr);
return true;
}
@ -80,6 +81,7 @@ bool LLamaModel::isModelLoaded() const
void LLamaModel::prompt(const std::string &prompt,
std::function<bool(int32_t, const std::string&)> response,
std::function<bool(bool)> recalculate,
PromptContext &promptCtx) {
if (!isModelLoaded()) {
@ -119,9 +121,13 @@ void LLamaModel::prompt(const std::string &prompt,
// Check if the context has run out...
if (promptCtx.n_past + batch.size() > promptCtx.n_ctx) {
// FIXME: will produce gibberish after this
promptCtx.n_past = std::min(promptCtx.n_past, int(promptCtx.n_ctx - batch.size()));
std::cerr << "LLAMA WARNING: reached the end of the context window!\n";
const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
// Erase the first percentage of context from the tokens...
std::cerr << "LLAMA: reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculate);
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
}
if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) {
@ -149,9 +155,13 @@ void LLamaModel::prompt(const std::string &prompt,
// Check if the context has run out...
if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
// FIXME: will produce gibberish after this
promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx - 1);
std::cerr << "LLAMA WARNING: reached the end of the context window!\n";
const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
// Erase the first percentage of context from the tokens...
std::cerr << "LLAMA: reached the end of the context window so resizing\n";
promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
promptCtx.n_past = promptCtx.tokens.size();
recalculateContext(promptCtx, recalculate);
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
}
if (llama_eval(d_ptr->ctx, &id, 1, promptCtx.n_past, d_ptr->n_threads)) {
@ -166,3 +176,28 @@ void LLamaModel::prompt(const std::string &prompt,
return;
}
}
void LLamaModel::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate)
{
size_t i = 0;
promptCtx.n_past = 0;
while (i < promptCtx.tokens.size()) {
size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
std::vector<llama_token> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) {
std::cerr << "LLAMA ERROR: Failed to process prompt\n";
goto stop_generating;
}
promptCtx.n_past += batch.size();
if (!recalculate(true))
goto stop_generating;
i = batch_end;
}
assert(promptCtx.n_past == promptCtx.tokens.size());
stop_generating:
recalculate(false);
}

View File

@ -17,10 +17,15 @@ public:
bool isModelLoaded() const override;
void prompt(const std::string &prompt,
std::function<bool(int32_t, const std::string&)> response,
std::function<bool(bool)> recalculate,
PromptContext &ctx) override;
void setThreadCount(int32_t n_threads) override;
int32_t threadCount() override;
protected:
void recalculateContext(PromptContext &promptCtx,
std::function<bool(bool)> recalculate) override;
private:
LLamaPrivate *d_ptr;
};

22
llm.cpp
View File

@ -39,6 +39,7 @@ LLMObject::LLMObject()
, m_llmodel(nullptr)
, m_responseTokens(0)
, m_responseLogits(0)
, m_isRecalc(false)
{
moveToThread(&m_llmThread);
connect(&m_llmThread, &QThread::started, this, &LLMObject::loadModel);
@ -271,6 +272,15 @@ bool LLMObject::handleResponse(int32_t token, const std::string &response)
return !m_stopGenerating;
}
bool LLMObject::handleRecalculate(bool isRecalc)
{
if (m_isRecalc != isRecalc) {
m_isRecalc = isRecalc;
emit recalcChanged();
}
return !m_stopGenerating;
}
bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch)
{
@ -280,7 +290,9 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
QString instructPrompt = prompt_template.arg(prompt);
m_stopGenerating = false;
auto func = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1, std::placeholders::_2);
auto responseFunc = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1,
std::placeholders::_2);
auto recalcFunc = std::bind(&LLMObject::handleRecalculate, this, std::placeholders::_1);
emit responseStarted();
qint32 logitsBefore = s_ctx.logits.size();
s_ctx.n_predict = n_predict;
@ -288,7 +300,7 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
s_ctx.top_p = top_p;
s_ctx.temp = temp;
s_ctx.n_batch = n_batch;
m_llmodel->prompt(instructPrompt.toStdString(), func, s_ctx);
m_llmodel->prompt(instructPrompt.toStdString(), responseFunc, recalcFunc, s_ctx);
m_responseLogits += s_ctx.logits.size() - logitsBefore;
std::string trimmed = trim_whitespace(m_response);
if (trimmed != m_response) {
@ -314,7 +326,7 @@ LLM::LLM()
connect(m_llmodel, &LLMObject::modelListChanged, this, &LLM::modelListChanged, Qt::QueuedConnection);
connect(m_llmodel, &LLMObject::threadCountChanged, this, &LLM::threadCountChanged, Qt::QueuedConnection);
connect(m_llmodel, &LLMObject::threadCountChanged, this, &LLM::syncThreadCount, Qt::QueuedConnection);
connect(m_llmodel, &LLMObject::recalcChanged, this, &LLM::recalcChanged, Qt::QueuedConnection);
connect(this, &LLM::promptRequested, m_llmodel, &LLMObject::prompt, Qt::QueuedConnection);
connect(this, &LLM::modelNameChangeRequested, m_llmodel, &LLMObject::modelNameChangeRequested, Qt::QueuedConnection);
@ -428,3 +440,7 @@ bool LLM::checkForUpdates() const
return QProcess::startDetached(fileName);
}
bool LLM::isRecalc() const
{
return m_llmodel->isRecalc();
}

11
llm.h
View File

@ -14,6 +14,7 @@ class LLMObject : public QObject
Q_PROPERTY(QString response READ response NOTIFY responseChanged)
Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
public:
@ -33,6 +34,8 @@ public:
QList<QString> modelList() const;
void setModelName(const QString &modelName);
bool isRecalc() const { return m_isRecalc; }
public Q_SLOTS:
bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch);
@ -47,10 +50,12 @@ Q_SIGNALS:
void modelNameChanged();
void modelListChanged();
void threadCountChanged();
void recalcChanged();
private:
bool loadModelPrivate(const QString &modelName);
bool handleResponse(int32_t token, const std::string &response);
bool handleRecalculate(bool isRecalc);
private:
LLModel *m_llmodel;
@ -60,6 +65,7 @@ private:
QString m_modelName;
QThread m_llmThread;
std::atomic<bool> m_stopGenerating;
bool m_isRecalc;
};
class LLM : public QObject
@ -71,6 +77,8 @@ class LLM : public QObject
Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
public:
static LLM *globalInstance();
@ -96,6 +104,8 @@ public:
Q_INVOKABLE bool checkForUpdates() const;
bool isRecalc() const;
Q_SIGNALS:
void isModelLoadedChanged();
void responseChanged();
@ -110,6 +120,7 @@ Q_SIGNALS:
void modelListChanged();
void threadCountChanged();
void setThreadCountRequested(int32_t threadCount);
void recalcChanged();
private Q_SLOTS:
void responseStarted();

View File

@ -25,13 +25,19 @@ public:
int32_t n_batch = 9;
float repeat_penalty = 1.10f;
int32_t repeat_last_n = 64; // last n tokens to penalize
float contextErase = 0.75f; // percent of context to erase if we exceed the context
// window
};
virtual void prompt(const std::string &prompt,
std::function<bool(int32_t, const std::string&)> response,
std::function<bool(bool)> recalculate,
PromptContext &ctx) = 0;
virtual void setThreadCount(int32_t n_threads) {}
virtual int32_t threadCount() { return 1; }
protected:
virtual void recalculateContext(PromptContext &promptCtx,
std::function<bool(bool)> recalculate) = 0;
};
#endif // LLMODEL_H

View File

@ -288,6 +288,24 @@ Window {
text: qsTr("Connection to datalake failed.")
}
PopupDialog {
id: recalcPopup
anchors.centerIn: parent
shouldTimeOut: false
shouldShowBusy: true
text: qsTr("Recalculating context.")
Connections {
target: LLM
function onRecalcChanged() {
if (LLM.isRecalc)
recalcPopup.open()
else
recalcPopup.close()
}
}
}
Button {
id: copyButton
anchors.right: settingsButton.left

View File

@ -7,23 +7,45 @@ import QtQuick.Layouts
Dialog {
id: popupDialog
anchors.centerIn: parent
modal: false
opacity: 0.9
padding: 20
property alias text: textField.text
property bool shouldTimeOut: true
property bool shouldShowBusy: false
modal: shouldShowBusy
closePolicy: shouldShowBusy ? Popup.NoAutoClose : (Popup.CloseOnEscape | Popup.CloseOnPressOutside)
Theme {
id: theme
}
Text {
id: textField
horizontalAlignment: Text.AlignJustify
color: theme.textColor
Accessible.role: Accessible.HelpBalloon
Accessible.name: text
Accessible.description: qsTr("Reveals a shortlived help balloon")
Row {
anchors.centerIn: parent
width: childrenRect.width
height: childrenRect.height
spacing: 20
Text {
id: textField
anchors.verticalCenter: busyIndicator.verticalCenter
horizontalAlignment: Text.AlignJustify
color: theme.textColor
Accessible.role: Accessible.HelpBalloon
Accessible.name: text
Accessible.description: qsTr("Reveals a shortlived help balloon")
}
BusyIndicator {
id: busyIndicator
visible: shouldShowBusy
running: shouldShowBusy
Accessible.role: Accessible.Animation
Accessible.name: qsTr("Busy indicator")
Accessible.description: qsTr("Displayed when the popup is showing busy")
}
}
background: Rectangle {
anchors.fill: parent
color: theme.backgroundDarkest
@ -37,7 +59,8 @@ Dialog {
}
onOpened: {
timer.start()
if (shouldTimeOut)
timer.start()
}
Timer {