Add thread count setting

This commit is contained in:
Aaron Miller 2023-04-18 06:46:03 -07:00 committed by AT
parent e6cb6a2ae3
commit f1b87d0b56
6 changed files with 78 additions and 8 deletions

View File

@ -659,6 +659,14 @@ bool GPTJ::loadModel(const std::string &modelPath, std::istream &fin) {
return true;
}
// Record the desired number of worker threads on the private implementation;
// picked up by subsequent inference calls.
void GPTJ::setThreadCount(int32_t n_threads)
{
    d_ptr->n_threads = n_threads;
}
// Report the thread count currently stored on the private implementation.
int32_t GPTJ::threadCount()
{
    return d_ptr->n_threads;
}
GPTJ::~GPTJ()
{
ggml_free(d_ptr->model.ctx);

4
gptj.h
View File

@ -17,9 +17,11 @@ public:
void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 50400, float top_p = 1.0f,
float temp = 0.0f, int32_t n_batch = 9) override;
void setThreadCount(int32_t n_threads) override;
int32_t threadCount() override;
private:
GPTJPrivate *d_ptr;
};
#endif // GPTJ_H
#endif // GPTJ_H

22
llm.cpp
View File

@ -62,6 +62,7 @@ bool LLMObject::loadModelPrivate(const QString &modelName)
auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
m_llmodel->loadModel(modelName.toStdString(), fin);
emit isModelLoadedChanged();
emit threadCountChanged();
}
if (m_llmodel)
@ -70,6 +71,15 @@ bool LLMObject::loadModelPrivate(const QString &modelName)
return m_llmodel;
}
// Forward the requested thread count to the loaded model and notify
// listeners. Guards against a not-yet-loaded model — the rest of this class
// (e.g. isModelLoaded()) null-checks m_llmodel before dereferencing it, and
// this slot can be invoked via a queued connection before loading completes.
void LLMObject::setThreadCount(int32_t n_threads) {
    if (!m_llmodel)
        return;
    m_llmodel->setThreadCount(n_threads);
    emit threadCountChanged();
}
// Report the loaded model's thread count, or 0 when no model is loaded yet
// (avoids dereferencing a null m_llmodel, consistent with isModelLoaded()).
int32_t LLMObject::threadCount() {
    return m_llmodel ? m_llmodel->threadCount() : 0;
}
bool LLMObject::isModelLoaded() const
{
return m_llmodel && m_llmodel->isModelLoaded();
@ -225,6 +235,9 @@ LLM::LLM()
connect(m_llmodel, &LLMObject::responseStopped, this, &LLM::responseStopped, Qt::QueuedConnection);
connect(m_llmodel, &LLMObject::modelNameChanged, this, &LLM::modelNameChanged, Qt::QueuedConnection);
connect(m_llmodel, &LLMObject::modelListChanged, this, &LLM::modelListChanged, Qt::QueuedConnection);
connect(m_llmodel, &LLMObject::threadCountChanged, this, &LLM::threadCountChanged, Qt::QueuedConnection);
connect(this, &LLM::promptRequested, m_llmodel, &LLMObject::prompt, Qt::QueuedConnection);
connect(this, &LLM::modelNameChangeRequested, m_llmodel, &LLMObject::modelNameChangeRequested, Qt::QueuedConnection);
@ -233,6 +246,7 @@ LLM::LLM()
connect(this, &LLM::regenerateResponseRequested, m_llmodel, &LLMObject::regenerateResponse, Qt::BlockingQueuedConnection);
connect(this, &LLM::resetResponseRequested, m_llmodel, &LLMObject::resetResponse, Qt::BlockingQueuedConnection);
connect(this, &LLM::resetContextRequested, m_llmodel, &LLMObject::resetContext, Qt::BlockingQueuedConnection);
connect(this, &LLM::setThreadCountRequested, m_llmodel, &LLMObject::setThreadCount, Qt::QueuedConnection);
}
bool LLM::isModelLoaded() const
@ -300,6 +314,14 @@ QList<QString> LLM::modelList() const
return m_llmodel->modelList();
}
// Ask the worker object to change its thread count. Delivered through the
// setThreadCountRequested signal (queued connection), so the actual update
// runs on the LLMObject's thread rather than the GUI thread.
void LLM::setThreadCount(int32_t n_threads)
{
    emit setThreadCountRequested(n_threads);
}
// Fetch the thread count from the worker object.
// NOTE(review): unlike setThreadCount(), this reads m_llmodel directly
// rather than going through a queued signal — confirm this cross-thread
// read is safe given the worker lives on its own thread.
int32_t LLM::threadCount()
{
    return m_llmodel->threadCount();
}
bool LLM::checkForUpdates() const
{
#if defined(Q_OS_LINUX)

10
llm.h
View File

@ -12,6 +12,8 @@ class LLMObject : public QObject
Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
Q_PROPERTY(QString response READ response NOTIFY responseChanged)
Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
Q_PROPERTY(QString modelName READ modelName NOTIFY modelNameChanged)
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
public:
@ -22,6 +24,8 @@ public:
void resetResponse();
void resetContext();
void stopGenerating() { m_stopGenerating = true; }
void setThreadCount(int32_t n_threads);
int32_t threadCount();
QString response() const;
QString modelName() const;
@ -42,6 +46,7 @@ Q_SIGNALS:
void responseStopped();
void modelNameChanged();
void modelListChanged();
void threadCountChanged();
private:
bool loadModelPrivate(const QString &modelName);
@ -65,6 +70,7 @@ class LLM : public QObject
Q_PROPERTY(QString response READ response NOTIFY responseChanged)
Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
public:
static LLM *globalInstance();
@ -76,6 +82,8 @@ public:
Q_INVOKABLE void resetResponse();
Q_INVOKABLE void resetContext();
Q_INVOKABLE void stopGenerating();
Q_INVOKABLE void setThreadCount(int32_t n_threads);
Q_INVOKABLE int32_t threadCount();
QString response() const;
bool responseInProgress() const { return m_responseInProgress; }
@ -99,6 +107,8 @@ Q_SIGNALS:
void modelNameChangeRequested(const QString &modelName);
void modelNameChanged();
void modelListChanged();
void threadCountChanged();
void setThreadCountRequested(int32_t threadCount);
private Q_SLOTS:
void responseStarted();

View File

@ -19,6 +19,8 @@ public:
virtual void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f,
float temp = 0.9f, int32_t n_batch = 9) = 0;
virtual void setThreadCount(int32_t n_threads);
virtual int32_t threadCount();
};
#endif // LLMODEL_H
#endif // LLMODEL_H

View File

@ -107,7 +107,6 @@ Window {
property int defaultTopK: 40
property int defaultMaxLength: 4096
property int defaultPromptBatchSize: 9
property string defaultPromptTemplate: "The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.
### Prompt:
%1
@ -141,7 +140,7 @@ Window {
GridLayout {
columns: 2
rowSpacing: 10
rowSpacing: 2
columnSpacing: 10
anchors.fill: parent
@ -278,14 +277,41 @@ Window {
}
Label {
id: promptTemplateLabel
text: qsTr("Prompt Template:")
id: nThreadsLabel
text: qsTr("CPU Threads")
Layout.row: 5
Layout.column: 0
}
Rectangle {
TextField {
text: LLM.threadCount.toString()
ToolTip.text: qsTr("Amount of processing threads to use")
ToolTip.visible: hovered
Layout.row: 5
Layout.column: 1
validator: IntValidator { bottom: 1 }
onAccepted: {
var val = parseInt(text)
if (!isNaN(val)) {
LLM.threadCount = val
focus = false
} else {
text = settingsDialog.nThreads.toString()
}
}
Accessible.role: Accessible.EditableText
Accessible.name: nThreadsLabel.text
Accessible.description: ToolTip.text
}
Label {
id: promptTemplateLabel
text: qsTr("Prompt Template:")
Layout.row: 6
Layout.column: 0
}
Rectangle {
Layout.row: 6
Layout.column: 1
Layout.fillWidth: true
height: 200
color: "transparent"
@ -319,7 +345,7 @@ Window {
}
}
Button {
Layout.row: 6
Layout.row: 7
Layout.column: 1
Layout.fillWidth: true
padding: 15