mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-10-01 01:06:10 -04:00
Add thread count setting
This commit is contained in:
parent
e6cb6a2ae3
commit
f1b87d0b56
8
gptj.cpp
8
gptj.cpp
@ -659,6 +659,14 @@ bool GPTJ::loadModel(const std::string &modelPath, std::istream &fin) {
|
||||
return true;
|
||||
}
|
||||
|
||||
void GPTJ::setThreadCount(int32_t n_threads) {
|
||||
d_ptr->n_threads = n_threads;
|
||||
}
|
||||
|
||||
int32_t GPTJ::threadCount() {
|
||||
return d_ptr->n_threads;
|
||||
}
|
||||
|
||||
GPTJ::~GPTJ()
|
||||
{
|
||||
ggml_free(d_ptr->model.ctx);
|
||||
|
4
gptj.h
4
gptj.h
@ -17,9 +17,11 @@ public:
|
||||
void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
|
||||
PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 50400, float top_p = 1.0f,
|
||||
float temp = 0.0f, int32_t n_batch = 9) override;
|
||||
void setThreadCount(int32_t n_threads) override;
|
||||
int32_t threadCount() override;
|
||||
|
||||
private:
|
||||
GPTJPrivate *d_ptr;
|
||||
};
|
||||
|
||||
#endif // GPTJ_H
|
||||
#endif // GPTJ_H
|
||||
|
22
llm.cpp
22
llm.cpp
@ -62,6 +62,7 @@ bool LLMObject::loadModelPrivate(const QString &modelName)
|
||||
auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
|
||||
m_llmodel->loadModel(modelName.toStdString(), fin);
|
||||
emit isModelLoadedChanged();
|
||||
emit threadCountChanged();
|
||||
}
|
||||
|
||||
if (m_llmodel)
|
||||
@ -70,6 +71,15 @@ bool LLMObject::loadModelPrivate(const QString &modelName)
|
||||
return m_llmodel;
|
||||
}
|
||||
|
||||
void LLMObject::setThreadCount(int32_t n_threads) {
|
||||
m_llmodel->setThreadCount(n_threads);
|
||||
emit threadCountChanged();
|
||||
}
|
||||
|
||||
int32_t LLMObject::threadCount() {
|
||||
return m_llmodel->threadCount();
|
||||
}
|
||||
|
||||
bool LLMObject::isModelLoaded() const
|
||||
{
|
||||
return m_llmodel && m_llmodel->isModelLoaded();
|
||||
@ -225,6 +235,9 @@ LLM::LLM()
|
||||
connect(m_llmodel, &LLMObject::responseStopped, this, &LLM::responseStopped, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LLMObject::modelNameChanged, this, &LLM::modelNameChanged, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LLMObject::modelListChanged, this, &LLM::modelListChanged, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LLMObject::threadCountChanged, this, &LLM::threadCountChanged, Qt::QueuedConnection);
|
||||
|
||||
|
||||
connect(this, &LLM::promptRequested, m_llmodel, &LLMObject::prompt, Qt::QueuedConnection);
|
||||
connect(this, &LLM::modelNameChangeRequested, m_llmodel, &LLMObject::modelNameChangeRequested, Qt::QueuedConnection);
|
||||
|
||||
@ -233,6 +246,7 @@ LLM::LLM()
|
||||
connect(this, &LLM::regenerateResponseRequested, m_llmodel, &LLMObject::regenerateResponse, Qt::BlockingQueuedConnection);
|
||||
connect(this, &LLM::resetResponseRequested, m_llmodel, &LLMObject::resetResponse, Qt::BlockingQueuedConnection);
|
||||
connect(this, &LLM::resetContextRequested, m_llmodel, &LLMObject::resetContext, Qt::BlockingQueuedConnection);
|
||||
connect(this, &LLM::setThreadCountRequested, m_llmodel, &LLMObject::setThreadCount, Qt::QueuedConnection);
|
||||
}
|
||||
|
||||
bool LLM::isModelLoaded() const
|
||||
@ -300,6 +314,14 @@ QList<QString> LLM::modelList() const
|
||||
return m_llmodel->modelList();
|
||||
}
|
||||
|
||||
void LLM::setThreadCount(int32_t n_threads) {
|
||||
emit setThreadCountRequested(n_threads);
|
||||
}
|
||||
|
||||
int32_t LLM::threadCount() {
|
||||
return m_llmodel->threadCount();
|
||||
}
|
||||
|
||||
bool LLM::checkForUpdates() const
|
||||
{
|
||||
#if defined(Q_OS_LINUX)
|
||||
|
10
llm.h
10
llm.h
@ -12,6 +12,8 @@ class LLMObject : public QObject
|
||||
Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
|
||||
Q_PROPERTY(QString response READ response NOTIFY responseChanged)
|
||||
Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
|
||||
Q_PROPERTY(QString modelName READ modelName NOTIFY modelNameChanged)
|
||||
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
|
||||
|
||||
public:
|
||||
|
||||
@ -22,6 +24,8 @@ public:
|
||||
void resetResponse();
|
||||
void resetContext();
|
||||
void stopGenerating() { m_stopGenerating = true; }
|
||||
void setThreadCount(int32_t n_threads);
|
||||
int32_t threadCount();
|
||||
|
||||
QString response() const;
|
||||
QString modelName() const;
|
||||
@ -42,6 +46,7 @@ Q_SIGNALS:
|
||||
void responseStopped();
|
||||
void modelNameChanged();
|
||||
void modelListChanged();
|
||||
void threadCountChanged();
|
||||
|
||||
private:
|
||||
bool loadModelPrivate(const QString &modelName);
|
||||
@ -65,6 +70,7 @@ class LLM : public QObject
|
||||
Q_PROPERTY(QString response READ response NOTIFY responseChanged)
|
||||
Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
|
||||
Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
|
||||
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
|
||||
public:
|
||||
|
||||
static LLM *globalInstance();
|
||||
@ -76,6 +82,8 @@ public:
|
||||
Q_INVOKABLE void resetResponse();
|
||||
Q_INVOKABLE void resetContext();
|
||||
Q_INVOKABLE void stopGenerating();
|
||||
Q_INVOKABLE void setThreadCount(int32_t n_threads);
|
||||
Q_INVOKABLE int32_t threadCount();
|
||||
|
||||
QString response() const;
|
||||
bool responseInProgress() const { return m_responseInProgress; }
|
||||
@ -99,6 +107,8 @@ Q_SIGNALS:
|
||||
void modelNameChangeRequested(const QString &modelName);
|
||||
void modelNameChanged();
|
||||
void modelListChanged();
|
||||
void threadCountChanged();
|
||||
void setThreadCountRequested(int32_t threadCount);
|
||||
|
||||
private Q_SLOTS:
|
||||
void responseStarted();
|
||||
|
@ -19,6 +19,8 @@ public:
|
||||
virtual void prompt(const std::string &prompt, std::function<bool(const std::string&)> response,
|
||||
PromptContext &ctx, int32_t n_predict = 200, int32_t top_k = 40, float top_p = 0.9f,
|
||||
float temp = 0.9f, int32_t n_batch = 9) = 0;
|
||||
virtual void setThreadCount(int32_t n_threads);
|
||||
virtual int32_t threadCount();
|
||||
};
|
||||
|
||||
#endif // LLMODEL_H
|
||||
#endif // LLMODEL_H
|
||||
|
38
main.qml
38
main.qml
@ -107,7 +107,6 @@ Window {
|
||||
property int defaultTopK: 40
|
||||
property int defaultMaxLength: 4096
|
||||
property int defaultPromptBatchSize: 9
|
||||
|
||||
property string defaultPromptTemplate: "The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.
|
||||
### Prompt:
|
||||
%1
|
||||
@ -141,7 +140,7 @@ Window {
|
||||
|
||||
GridLayout {
|
||||
columns: 2
|
||||
rowSpacing: 10
|
||||
rowSpacing: 2
|
||||
columnSpacing: 10
|
||||
anchors.fill: parent
|
||||
|
||||
@ -278,14 +277,41 @@ Window {
|
||||
}
|
||||
|
||||
Label {
|
||||
id: promptTemplateLabel
|
||||
text: qsTr("Prompt Template:")
|
||||
id: nThreadsLabel
|
||||
text: qsTr("CPU Threads")
|
||||
Layout.row: 5
|
||||
Layout.column: 0
|
||||
}
|
||||
Rectangle {
|
||||
TextField {
|
||||
text: LLM.threadCount.toString()
|
||||
ToolTip.text: qsTr("Amount of processing threads to use")
|
||||
ToolTip.visible: hovered
|
||||
Layout.row: 5
|
||||
Layout.column: 1
|
||||
validator: IntValidator { bottom: 1 }
|
||||
onAccepted: {
|
||||
var val = parseInt(text)
|
||||
if (!isNaN(val)) {
|
||||
LLM.threadCount = val
|
||||
focus = false
|
||||
} else {
|
||||
text = settingsDialog.nThreads.toString()
|
||||
}
|
||||
}
|
||||
Accessible.role: Accessible.EditableText
|
||||
Accessible.name: nThreadsLabel.text
|
||||
Accessible.description: ToolTip.text
|
||||
}
|
||||
|
||||
Label {
|
||||
id: promptTemplateLabel
|
||||
text: qsTr("Prompt Template:")
|
||||
Layout.row: 6
|
||||
Layout.column: 0
|
||||
}
|
||||
Rectangle {
|
||||
Layout.row: 6
|
||||
Layout.column: 1
|
||||
Layout.fillWidth: true
|
||||
height: 200
|
||||
color: "transparent"
|
||||
@ -319,7 +345,7 @@ Window {
|
||||
}
|
||||
}
|
||||
Button {
|
||||
Layout.row: 6
|
||||
Layout.row: 7
|
||||
Layout.column: 1
|
||||
Layout.fillWidth: true
|
||||
padding: 15
|
||||
|
Loading…
Reference in New Issue
Block a user