new settings (model path, repeat penalty) w/ tabs

This commit is contained in:
Aaron Miller 2023-04-25 07:57:40 -07:00 committed by AT
parent cd03c5b7d5
commit 15a979b327
7 changed files with 504 additions and 294 deletions

View File

@ -27,7 +27,9 @@ Download::Download()
&Download::handleHashAndSaveFinished, Qt::QueuedConnection); &Download::handleHashAndSaveFinished, Qt::QueuedConnection);
connect(&m_networkManager, &QNetworkAccessManager::sslErrors, this, connect(&m_networkManager, &QNetworkAccessManager::sslErrors, this,
&Download::handleSslErrors); &Download::handleSslErrors);
connect(this, &Download::downloadLocalModelsPathChanged, this, &Download::updateModelList);
updateModelList(); updateModelList();
m_downloadLocalModelsPath = defaultLocalModelsPath();
} }
QList<ModelInfo> Download::modelList() const QList<ModelInfo> Download::modelList() const
@ -46,7 +48,22 @@ QList<ModelInfo> Download::modelList() const
return values; return values;
} }
QString Download::downloadLocalModelsPath() const QString Download::downloadLocalModelsPath() const {
return m_downloadLocalModelsPath;
}
void Download::setDownloadLocalModelsPath(const QString &modelPath) {
QString filePath = (modelPath.startsWith("file://") ?
QUrl(modelPath).toLocalFile() : modelPath);
QString canonical = QFileInfo(filePath).canonicalFilePath() + QDir::separator();
qDebug() << "Set model path: " << canonical;
if (m_downloadLocalModelsPath != canonical) {
m_downloadLocalModelsPath = canonical;
emit downloadLocalModelsPathChanged();
}
}
QString Download::defaultLocalModelsPath() const
{ {
QString localPath = QStandardPaths::writableLocation(QStandardPaths::AppLocalDataLocation) QString localPath = QStandardPaths::writableLocation(QStandardPaths::AppLocalDataLocation)
+ QDir::separator(); + QDir::separator();

View File

@ -50,6 +50,9 @@ class Download : public QObject
{ {
Q_OBJECT Q_OBJECT
Q_PROPERTY(QList<ModelInfo> modelList READ modelList NOTIFY modelListChanged) Q_PROPERTY(QList<ModelInfo> modelList READ modelList NOTIFY modelListChanged)
Q_PROPERTY(QString downloadLocalModelsPath READ downloadLocalModelsPath
WRITE setDownloadLocalModelsPath
NOTIFY downloadLocalModelsPathChanged)
public: public:
static Download *globalInstance(); static Download *globalInstance();
@ -58,7 +61,9 @@ public:
Q_INVOKABLE void updateModelList(); Q_INVOKABLE void updateModelList();
Q_INVOKABLE void downloadModel(const QString &modelFile); Q_INVOKABLE void downloadModel(const QString &modelFile);
Q_INVOKABLE void cancelDownload(const QString &modelFile); Q_INVOKABLE void cancelDownload(const QString &modelFile);
Q_INVOKABLE QString defaultLocalModelsPath() const;
Q_INVOKABLE QString downloadLocalModelsPath() const; Q_INVOKABLE QString downloadLocalModelsPath() const;
Q_INVOKABLE void setDownloadLocalModelsPath(const QString &modelPath);
private Q_SLOTS: private Q_SLOTS:
void handleSslErrors(QNetworkReply *reply, const QList<QSslError> &errors); void handleSslErrors(QNetworkReply *reply, const QList<QSslError> &errors);
@ -73,6 +78,7 @@ Q_SIGNALS:
void downloadProgress(qint64 bytesReceived, qint64 bytesTotal, const QString &modelFile); void downloadProgress(qint64 bytesReceived, qint64 bytesTotal, const QString &modelFile);
void downloadFinished(const QString &modelFile); void downloadFinished(const QString &modelFile);
void modelListChanged(); void modelListChanged();
void downloadLocalModelsPathChanged();
void requestHashAndSave(const QString &hash, const QString &saveFilePath, void requestHashAndSave(const QString &hash, const QString &saveFilePath,
QTemporaryFile *tempFile, QNetworkReply *modelReply); QTemporaryFile *tempFile, QNetworkReply *modelReply);
@ -83,6 +89,7 @@ private:
QMap<QString, ModelInfo> m_modelMap; QMap<QString, ModelInfo> m_modelMap;
QNetworkAccessManager m_networkManager; QNetworkAccessManager m_networkManager;
QMap<QNetworkReply*, QTemporaryFile*> m_activeDownloads; QMap<QNetworkReply*, QTemporaryFile*> m_activeDownloads;
QString m_downloadLocalModelsPath;
private: private:
explicit Download(); explicit Download();

View File

@ -282,7 +282,7 @@ bool LLMObject::handleRecalculate(bool isRecalc)
} }
bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch) float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens)
{ {
if (!isModelLoaded()) if (!isModelLoaded())
return false; return false;
@ -300,6 +300,8 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
s_ctx.top_p = top_p; s_ctx.top_p = top_p;
s_ctx.temp = temp; s_ctx.temp = temp;
s_ctx.n_batch = n_batch; s_ctx.n_batch = n_batch;
s_ctx.repeat_penalty = repeat_penalty;
s_ctx.repeat_last_n = repeat_penalty_tokens;
m_llmodel->prompt(instructPrompt.toStdString(), responseFunc, recalcFunc, s_ctx); m_llmodel->prompt(instructPrompt.toStdString(), responseFunc, recalcFunc, s_ctx);
m_responseLogits += s_ctx.logits.size() - logitsBefore; m_responseLogits += s_ctx.logits.size() - logitsBefore;
std::string trimmed = trim_whitespace(m_response); std::string trimmed = trim_whitespace(m_response);
@ -345,9 +347,9 @@ bool LLM::isModelLoaded() const
} }
void LLM::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, void LLM::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch) float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens)
{ {
emit promptRequested(prompt, prompt_template, n_predict, top_k, top_p, temp, n_batch); emit promptRequested(prompt, prompt_template, n_predict, top_k, top_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens);
} }
void LLM::regenerateResponse() void LLM::regenerateResponse()

6
llm.h
View File

@ -38,7 +38,7 @@ public:
public Q_SLOTS: public Q_SLOTS:
bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch); float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens);
bool loadModel(); bool loadModel();
void modelNameChangeRequested(const QString &modelName); void modelNameChangeRequested(const QString &modelName);
@ -85,7 +85,7 @@ public:
Q_INVOKABLE bool isModelLoaded() const; Q_INVOKABLE bool isModelLoaded() const;
Q_INVOKABLE void prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, Q_INVOKABLE void prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch); float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens);
Q_INVOKABLE void regenerateResponse(); Q_INVOKABLE void regenerateResponse();
Q_INVOKABLE void resetResponse(); Q_INVOKABLE void resetResponse();
Q_INVOKABLE void resetContext(); Q_INVOKABLE void resetContext();
@ -111,7 +111,7 @@ Q_SIGNALS:
void responseChanged(); void responseChanged();
void responseInProgressChanged(); void responseInProgressChanged();
void promptRequested(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, void promptRequested(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
float temp, int32_t n_batch); float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens);
void regenerateResponseRequested(); void regenerateResponseRequested();
void resetResponseRequested(); void resetResponseRequested();
void resetContextRequested(); void resetContextRequested();

View File

@ -824,7 +824,9 @@ Window {
settingsDialog.maxLength, settingsDialog.maxLength,
settingsDialog.topK, settingsDialog.topP, settingsDialog.topK, settingsDialog.topP,
settingsDialog.temperature, settingsDialog.temperature,
settingsDialog.promptBatchSize) settingsDialog.promptBatchSize,
settingsDialog.repeatPenalty,
settingsDialog.repeatPenaltyTokens)
} }
} }
} }
@ -905,7 +907,9 @@ Window {
settingsDialog.topK, settingsDialog.topK,
settingsDialog.topP, settingsDialog.topP,
settingsDialog.temperature, settingsDialog.temperature,
settingsDialog.promptBatchSize) settingsDialog.promptBatchSize,
settingsDialog.repeatPenalty,
settingsDialog.repeatPenaltyTokens)
textInput.text = "" textInput.text = ""
} }
} }

View File

@ -293,7 +293,7 @@ Dialog {
Label { Label {
Layout.alignment: Qt.AlignLeft Layout.alignment: Qt.AlignLeft
Layout.fillWidth: true Layout.fillWidth: true
text: qsTr("NOTE: models will be downloaded to\n") + Download.downloadLocalModelsPath() text: qsTr("NOTE: models will be downloaded to\n") + Download.downloadLocalModelsPath
wrapMode: Text.WrapAnywhere wrapMode: Text.WrapAnywhere
horizontalAlignment: Text.AlignHCenter horizontalAlignment: Text.AlignHCenter
color: theme.textColor color: theme.textColor

View File

@ -2,6 +2,7 @@ import QtCore
import QtQuick import QtQuick
import QtQuick.Controls import QtQuick.Controls
import QtQuick.Controls.Basic import QtQuick.Controls.Basic
import QtQuick.Dialogs
import QtQuick.Layouts import QtQuick.Layouts
import download import download
import network import network
@ -31,12 +32,15 @@ Dialog {
property int defaultTopK: 40 property int defaultTopK: 40
property int defaultMaxLength: 4096 property int defaultMaxLength: 4096
property int defaultPromptBatchSize: 9 property int defaultPromptBatchSize: 9
property real defaultRepeatPenalty: 1.10
property int defaultRepeatPenaltyTokens: 64
property int defaultThreadCount: 0 property int defaultThreadCount: 0
property string defaultPromptTemplate: "### Instruction: property string defaultPromptTemplate: "### Instruction:
The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response. The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.
### Prompt: ### Prompt:
%1 %1
### Response:\n" ### Response:\n"
property string defaultModelPath: Download.defaultLocalModelsPath()
property alias temperature: settings.temperature property alias temperature: settings.temperature
property alias topP: settings.topP property alias topP: settings.topP
@ -44,7 +48,10 @@ The prompt below is a question to answer, a task to complete, or a conversation
property alias maxLength: settings.maxLength property alias maxLength: settings.maxLength
property alias promptBatchSize: settings.promptBatchSize property alias promptBatchSize: settings.promptBatchSize
property alias promptTemplate: settings.promptTemplate property alias promptTemplate: settings.promptTemplate
property alias repeatPenalty: settings.repeatPenalty
property alias repeatPenaltyTokens: settings.repeatPenaltyTokens
property alias threadCount: settings.threadCount property alias threadCount: settings.threadCount
property alias modelPath: settings.modelPath
Settings { Settings {
id: settings id: settings
@ -54,23 +61,34 @@ The prompt below is a question to answer, a task to complete, or a conversation
property int maxLength: settingsDialog.defaultMaxLength property int maxLength: settingsDialog.defaultMaxLength
property int promptBatchSize: settingsDialog.defaultPromptBatchSize property int promptBatchSize: settingsDialog.defaultPromptBatchSize
property int threadCount: settingsDialog.defaultThreadCount property int threadCount: settingsDialog.defaultThreadCount
property real repeatPenalty: settingsDialog.defaultRepeatPenalty
property int repeatPenaltyTokens: settingsDialog.defaultRepeatPenaltyTokens
property string promptTemplate: settingsDialog.defaultPromptTemplate property string promptTemplate: settingsDialog.defaultPromptTemplate
property string modelPath: settingsDialog.defaultModelPath
} }
function restoreDefaults() {
// Resets every text-generation setting to its default value and persists
// the result through settings.sync().
function restoreGenerationDefaults() {
    settings.temperature = defaultTemperature;
    settings.topP = defaultTopP;
    settings.topK = defaultTopK;
    settings.maxLength = defaultMaxLength;
    settings.promptBatchSize = defaultPromptBatchSize;
    // The repeat-penalty settings are generation settings as well; without
    // these two lines they would keep their edited values after a restore.
    settings.repeatPenalty = defaultRepeatPenalty;
    settings.repeatPenaltyTokens = defaultRepeatPenaltyTokens;
    settings.promptTemplate = defaultPromptTemplate;
    settings.sync()
}
function restoreApplicationDefaults() {
settings.modelPath = settingsDialog.defaultModelPath;
settings.threadCount = defaultThreadCount
Download.downloadLocalModelsPath = settings.modelPath;
LLM.threadCount = settings.threadCount; LLM.threadCount = settings.threadCount;
settings.sync()
} }
Component.onCompleted: { Component.onCompleted: {
LLM.threadCount = settings.threadCount; LLM.threadCount = settings.threadCount;
Download.downloadLocalModelsPath = settings.modelPath;
} }
Component.onDestruction: { Component.onDestruction: {
@ -80,9 +98,33 @@ The prompt below is a question to answer, a task to complete, or a conversation
Item { Item {
Accessible.role: Accessible.Dialog Accessible.role: Accessible.Dialog
Accessible.name: qsTr("Settings dialog") Accessible.name: qsTr("Settings dialog")
Accessible.description: qsTr("Dialog containing various settings for model text generation") Accessible.description: qsTr("Dialog containing various application settings")
}
TabBar {
id: settingsTabBar
width: parent.width
TabButton {
text: qsTr("Generation")
Accessible.role: Accessible.Button
Accessible.name: qsTr("Generation settings")
Accessible.description: qsTr("Settings related to how the model generates text")
} }
TabButton {
text: qsTr("Application")
Accessible.role: Accessible.Button
Accessible.name: qsTr("Application settings")
Accessible.description: qsTr("Settings related to general behavior of the application")
}
}
StackLayout {
anchors.top: settingsTabBar.bottom
anchors.bottom: parent.bottom
width: parent.width
currentIndex: settingsTabBar.currentIndex
Item {
id: generationSettingsTab
GridLayout { GridLayout {
columns: 2 columns: 2
rowSpacing: 2 rowSpacing: 2
@ -265,16 +307,15 @@ The prompt below is a question to answer, a task to complete, or a conversation
Accessible.name: batchSizeLabel.text Accessible.name: batchSizeLabel.text
Accessible.description: ToolTip.text Accessible.description: ToolTip.text
} }
Label { Label {
id: nThreadsLabel id: repeatPenaltyLabel
text: qsTr("CPU Threads") text: qsTr("Repeat Penalty:")
color: theme.textColor color: theme.textColor
Layout.row: 5 Layout.row: 5
Layout.column: 0 Layout.column: 0
} }
TextField { TextField {
text: settingsDialog.threadCount.toString() text: settings.repeatPenalty.toString()
color: theme.textColor color: theme.textColor
background: Rectangle { background: Rectangle {
implicitWidth: 150 implicitWidth: 150
@ -282,23 +323,58 @@ The prompt below is a question to answer, a task to complete, or a conversation
radius: 10 radius: 10
} }
padding: 10 padding: 10
ToolTip.text: qsTr("Amount of processing threads to use, a setting of 0 will use the lesser of 4 or your number of CPU threads") ToolTip.text: qsTr("Amount to penalize repetitiveness of the output")
ToolTip.visible: hovered ToolTip.visible: hovered
Layout.row: 5 Layout.row: 5
Layout.column: 1 Layout.column: 1
validator: DoubleValidator {}
onAccepted: {
var val = parseFloat(text)
if (!isNaN(val)) {
settings.repeatPenalty = val
settings.sync()
focus = false
} else {
text = settings.repeatPenalty.toString()
}
}
Accessible.role: Accessible.EditableText
Accessible.name: repeatPenaltyLabel.text
Accessible.description: ToolTip.text
}
Label {
id: repeatPenaltyTokensLabel
text: qsTr("Repeat Penalty Tokens:")
color: theme.textColor
Layout.row: 6
Layout.column: 0
}
TextField {
text: settings.repeatPenaltyTokens.toString()
color: theme.textColor
background: Rectangle {
implicitWidth: 150
color: theme.backgroundLighter
radius: 10
}
padding: 10
ToolTip.text: qsTr("How far back in output to apply repeat penalty")
ToolTip.visible: hovered
Layout.row: 6
Layout.column: 1
validator: IntValidator { bottom: 1 } validator: IntValidator { bottom: 1 }
onAccepted: { onAccepted: {
var val = parseInt(text) var val = parseInt(text)
if (!isNaN(val)) { if (!isNaN(val)) {
settingsDialog.threadCount = val settings.repeatPenaltyTokens = val
LLM.threadCount = val settings.sync()
focus = false focus = false
} else { } else {
text = settingsDialog.threadCount.toString() text = settings.repeatPenaltyTokens.toString()
} }
} }
Accessible.role: Accessible.EditableText Accessible.role: Accessible.EditableText
Accessible.name: nThreadsLabel.text Accessible.name: repeatPenaltyTokensLabel.text
Accessible.description: ToolTip.text Accessible.description: ToolTip.text
} }
@ -306,11 +382,11 @@ The prompt below is a question to answer, a task to complete, or a conversation
id: promptTemplateLabel id: promptTemplateLabel
text: qsTr("Prompt Template:") text: qsTr("Prompt Template:")
color: theme.textColor color: theme.textColor
Layout.row: 6 Layout.row: 7
Layout.column: 0 Layout.column: 0
} }
Rectangle { Rectangle {
Layout.row: 6 Layout.row: 7
Layout.column: 1 Layout.column: 1
Layout.fillWidth: true Layout.fillWidth: true
height: 200 height: 200
@ -354,7 +430,7 @@ The prompt below is a question to answer, a task to complete, or a conversation
} }
} }
Button { Button {
Layout.row: 7 Layout.row: 8
Layout.column: 1 Layout.column: 1
Layout.fillWidth: true Layout.fillWidth: true
padding: 15 padding: 15
@ -375,8 +451,112 @@ The prompt below is a question to answer, a task to complete, or a conversation
color: theme.backgroundLight color: theme.backgroundLight
} }
onClicked: { onClicked: {
settingsDialog.restoreDefaults() settingsDialog.restoreGenerationDefaults()
} }
} }
} }
}
Item {
id: systemSettingsTab
GridLayout {
columns: 3
rowSpacing: 2
columnSpacing: 10
width: parent.width
anchors.top: parent.top
// Folder picker for the model download directory. On accept, the chosen
// folder is handed to Download, and the value Download then reports is
// what gets written to the persisted settings.
FolderDialog {
    id: modelPathDialog
    // Wrapped in qsTr() like the other user-visible strings in this dialog.
    title: qsTr("Please choose a directory")
    onAccepted: {
        Download.downloadLocalModelsPath = selectedFolder
        // Read the path back from Download rather than storing selectedFolder.
        settings.modelPath = Download.downloadLocalModelsPath
        settings.sync()
    }
}
// Caption for the model path row (grid row 1, first column).
Label {
    id: modelPathLabel
    Layout.row: 1
    Layout.column: 0
    color: theme.textColor
    text: qsTr("Model file path:")
}
// Read-only display of the persisted model path; it is changed through the
// Browse button's folder dialog, not by typing.
TextField {
    id: modelPathDisplayLabel
    Layout.row: 1
    Layout.column: 1
    readOnly: true
    color: theme.textColor
    text: settings.modelPath
}
// Opens the folder dialog used to choose the model path.
Button {
    text: qsTr("Browse")
    Layout.row: 1
    Layout.column: 2
    onClicked: {
        modelPathDialog.open()
    }
}
// Caption for the CPU thread count row (grid row 2, first column).
Label {
    id: nThreadsLabel
    Layout.row: 2
    Layout.column: 0
    color: theme.textColor
    text: qsTr("CPU Threads:")
}
// Editable CPU thread count. An accepted value is written to the settings
// dialog, pushed to LLM.threadCount and persisted with settings.sync().
TextField {
    text: settingsDialog.threadCount.toString()
    color: theme.textColor
    background: Rectangle {
        implicitWidth: 150
        color: theme.backgroundLighter
        radius: 10
    }
    padding: 10
    ToolTip.text: qsTr("Amount of processing threads to use, a setting of 0 will use the lesser of 4 or your number of CPU threads")
    ToolTip.visible: hovered
    Layout.row: 2
    Layout.column: 1
    // 0 is the default thread count and is documented in the tooltip as a
    // valid setting, so the validator has to accept it.
    validator: IntValidator { bottom: 0 }
    onAccepted: {
        var val = parseInt(text)
        if (!isNaN(val)) {
            settingsDialog.threadCount = val
            LLM.threadCount = val
            settings.sync()
            focus = false
        } else {
            // Unparseable input: show the stored value again.
            text = settingsDialog.threadCount.toString()
        }
    }
    Accessible.role: Accessible.EditableText
    Accessible.name: nThreadsLabel.text
    Accessible.description: ToolTip.text
}
// "Restore Defaults" button for the application tab; clicking it calls
// settingsDialog.restoreApplicationDefaults().
Button {
    Layout.fillWidth: true
    Layout.row: 3
    Layout.column: 1
    padding: 15
    onClicked: settingsDialog.restoreApplicationDefaults()
    background: Rectangle {
        color: theme.backgroundLight
        radius: 10
        border.width: 1
        border.color: theme.backgroundLightest
        opacity: .5
    }
    contentItem: Text {
        text: qsTr("Restore Defaults")
        color: theme.textColor
        horizontalAlignment: Text.AlignHCenter
        Accessible.role: Accessible.Button
        Accessible.name: text
        Accessible.description: qsTr("Restores the settings dialog to a default state")
    }
}
}
}
}
} }