Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2024-10-01 01:06:10 -04:00)
First attempt at providing a persistent chat list experience.

Limitations:
1) Context is not restored for gpt-j models.
2) When you switch between different model types in an existing chat, the context and the whole conversation are lost.
3) The settings are not chat- or conversation-specific.
4) The persisted chat files are very large because of how much data the llama.cpp backend saves; we need to investigate how to shrink this.
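On disk, each chat becomes its own file: ChatListModel::saveChats() (added in this commit) writes every chat with QDataStream to gpt4all-<chat id>.chat next to the QSettings file, and restoreChats() reads them back at startup. A minimal sketch of the save path, simplified from this diff (error reporting trimmed; the free-function form is only for illustration):

    // Write each chat to <settings dir>/gpt4all-<id>.chat
    void saveChatsTo(const QString &settingsDir, const QList<Chat *> &chats)
    {
        for (Chat *chat : chats) {
            QFile file(settingsDir + "/gpt4all-" + chat->id() + ".chat");
            if (!file.open(QIODevice::WriteOnly))
                continue;                  // skip chats we cannot write
            QDataStream out(&file);
            if (!chat->serialize(out))     // chat metadata + ChatLLM state + ChatModel items
                file.remove();             // drop partially written files
            file.close();
        }
    }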
Commit: f291853e51 (parent: 081d32bd97)
@@ -60,7 +60,7 @@ qt_add_executable(chat
    main.cpp
    chat.h chat.cpp
    chatllm.h chatllm.cpp
    chatmodel.h chatlistmodel.h
    chatmodel.h chatlistmodel.h chatlistmodel.cpp
    download.h download.cpp
    network.h network.cpp
    llm.h llm.cpp

chat.cpp (182 lines changed)
@ -1,32 +1,37 @@
|
||||
#include "chat.h"
|
||||
#include "llm.h"
|
||||
#include "network.h"
|
||||
#include "download.h"
|
||||
|
||||
Chat::Chat(QObject *parent)
|
||||
: QObject(parent)
|
||||
, m_llmodel(new ChatLLM)
|
||||
, m_id(Network::globalInstance()->generateUniqueId())
|
||||
, m_name(tr("New Chat"))
|
||||
, m_chatModel(new ChatModel(this))
|
||||
, m_responseInProgress(false)
|
||||
, m_desiredThreadCount(std::min(4, (int32_t) std::thread::hardware_concurrency()))
|
||||
, m_creationDate(QDateTime::currentSecsSinceEpoch())
|
||||
, m_llmodel(new ChatLLM(this))
|
||||
{
|
||||
// Should be in same thread
|
||||
connect(Download::globalInstance(), &Download::modelListChanged, this, &Chat::modelListChanged, Qt::DirectConnection);
|
||||
connect(this, &Chat::modelNameChanged, this, &Chat::modelListChanged, Qt::DirectConnection);
|
||||
|
||||
// Should be in different threads
|
||||
connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::responseChanged, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &ChatLLM::responseStarted, this, &Chat::responseStarted, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &ChatLLM::modelNameChanged, this, &Chat::modelNameChanged, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &ChatLLM::threadCountChanged, this, &Chat::threadCountChanged, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &ChatLLM::threadCountChanged, this, &Chat::syncThreadCount, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::recalcChanged, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &ChatLLM::modelNameChanged, this, &Chat::handleModelNameChanged, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
|
||||
|
||||
connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection);
|
||||
connect(this, &Chat::modelNameChangeRequested, m_llmodel, &ChatLLM::modelNameChangeRequested, Qt::QueuedConnection);
|
||||
connect(this, &Chat::unloadRequested, m_llmodel, &ChatLLM::unload, Qt::QueuedConnection);
|
||||
connect(this, &Chat::reloadRequested, m_llmodel, &ChatLLM::reload, Qt::QueuedConnection);
|
||||
connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &ChatLLM::loadDefaultModel, Qt::QueuedConnection);
|
||||
connect(this, &Chat::loadModelRequested, m_llmodel, &ChatLLM::loadModel, Qt::QueuedConnection);
|
||||
connect(this, &Chat::unloadModelRequested, m_llmodel, &ChatLLM::unloadModel, Qt::QueuedConnection);
|
||||
connect(this, &Chat::reloadModelRequested, m_llmodel, &ChatLLM::reloadModel, Qt::QueuedConnection);
|
||||
connect(this, &Chat::generateNameRequested, m_llmodel, &ChatLLM::generateName, Qt::QueuedConnection);
|
||||
connect(this, &Chat::setThreadCountRequested, m_llmodel, &ChatLLM::setThreadCount, Qt::QueuedConnection);
|
||||
|
||||
// The following are blocking operations and will block the gui thread, therefore must be fast
|
||||
// to respond to
|
||||
@@ -38,9 +43,21 @@ Chat::Chat(QObject *parent)
void Chat::reset()
{
    stopGenerating();
    // Erase our current on-disk representation as we're completely resetting the chat along with id
    LLM::globalInstance()->chatListModel()->removeChatFile(this);
    emit resetContextRequested(); // blocking queued connection
    m_id = Network::globalInstance()->generateUniqueId();
    emit idChanged();
    // NOTE: We deliberately do not reset the name or creation date, to indicate that this was originally
    // an older chat that was reset for another purpose. Resetting this data would lead to the chat
    // name label changing back to 'New Chat' and showing up in the chat model list as a 'New Chat'
    // further down in the list. This might surprise the user. In the future, we might get rid of
    // the "reset context" button in the UI. Right now, changing the model in the combobox dropdown
    // effectively does a reset context. We *have* to do this right now when switching between different
    // types of models. The only way to get rid of that would be a very long recalculation where we rebuild
    // the context if we switch between different types of models. Probably the right way to fix this
    // is to allow switching models but throw up a dialog warning users that a long recalculation will
    // ensue when switching between types of models.
    m_chatModel->clear();
}
@ -49,10 +66,12 @@ bool Chat::isModelLoaded() const
|
||||
return m_llmodel->isModelLoaded();
|
||||
}
|
||||
|
||||
void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
|
||||
float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens)
|
||||
void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict,
|
||||
int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty,
|
||||
int32_t repeat_penalty_tokens)
|
||||
{
|
||||
emit promptRequested(prompt, prompt_template, n_predict, top_k, top_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens);
|
||||
emit promptRequested(prompt, prompt_template, n_predict, top_k, top_p, temp, n_batch,
|
||||
repeat_penalty, repeat_penalty_tokens, LLM::globalInstance()->threadCount());
|
||||
}
|
||||
|
||||
void Chat::regenerateResponse()
|
||||
@ -70,6 +89,13 @@ QString Chat::response() const
|
||||
return m_llmodel->response();
|
||||
}
|
||||
|
||||
void Chat::handleResponseChanged()
|
||||
{
|
||||
const int index = m_chatModel->count() - 1;
|
||||
m_chatModel->updateValue(index, response());
|
||||
emit responseChanged();
|
||||
}
|
||||
|
||||
void Chat::responseStarted()
|
||||
{
|
||||
m_responseInProgress = true;
|
||||
@ -98,21 +124,6 @@ void Chat::setModelName(const QString &modelName)
|
||||
emit modelNameChangeRequested(modelName);
|
||||
}
|
||||
|
||||
void Chat::syncThreadCount() {
|
||||
emit setThreadCountRequested(m_desiredThreadCount);
|
||||
}
|
||||
|
||||
void Chat::setThreadCount(int32_t n_threads) {
|
||||
if (n_threads <= 0)
|
||||
n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
m_desiredThreadCount = n_threads;
|
||||
syncThreadCount();
|
||||
}
|
||||
|
||||
int32_t Chat::threadCount() {
|
||||
return m_llmodel->threadCount();
|
||||
}
|
||||
|
||||
void Chat::newPromptResponsePair(const QString &prompt)
|
||||
{
|
||||
m_chatModel->appendPrompt(tr("Prompt: "), prompt);
|
||||
@ -125,16 +136,25 @@ bool Chat::isRecalc() const
|
||||
return m_llmodel->isRecalc();
|
||||
}
|
||||
|
||||
void Chat::unload()
|
||||
void Chat::loadDefaultModel()
|
||||
{
|
||||
m_savedModelName = m_llmodel->modelName();
|
||||
stopGenerating();
|
||||
emit unloadRequested();
|
||||
emit loadDefaultModelRequested();
|
||||
}
|
||||
|
||||
void Chat::reload()
|
||||
void Chat::loadModel(const QString &modelName)
|
||||
{
|
||||
emit reloadRequested(m_savedModelName);
|
||||
emit loadModelRequested(modelName);
|
||||
}
|
||||
|
||||
void Chat::unloadModel()
|
||||
{
|
||||
stopGenerating();
|
||||
emit unloadModelRequested();
|
||||
}
|
||||
|
||||
void Chat::reloadModel()
|
||||
{
|
||||
emit reloadModelRequested(m_savedModelName);
|
||||
}
|
||||
|
||||
void Chat::generatedNameChanged()
|
||||
@ -150,4 +170,98 @@ void Chat::generatedNameChanged()
|
||||
void Chat::handleRecalculating()
|
||||
{
|
||||
Network::globalInstance()->sendRecalculatingContext(m_chatModel->count());
|
||||
emit recalcChanged();
|
||||
}
|
||||
|
||||
void Chat::handleModelNameChanged()
|
||||
{
|
||||
m_savedModelName = modelName();
|
||||
emit modelNameChanged();
|
||||
}
|
||||
|
||||
bool Chat::serialize(QDataStream &stream) const
|
||||
{
|
||||
stream << m_creationDate;
|
||||
stream << m_id;
|
||||
stream << m_name;
|
||||
stream << m_userName;
|
||||
stream << m_savedModelName;
|
||||
if (!m_llmodel->serialize(stream))
|
||||
return false;
|
||||
if (!m_chatModel->serialize(stream))
|
||||
return false;
|
||||
return stream.status() == QDataStream::Ok;
|
||||
}
|
||||
|
||||
bool Chat::deserialize(QDataStream &stream)
|
||||
{
|
||||
stream >> m_creationDate;
|
||||
stream >> m_id;
|
||||
emit idChanged();
|
||||
stream >> m_name;
|
||||
stream >> m_userName;
|
||||
emit nameChanged();
|
||||
stream >> m_savedModelName;
|
||||
if (!m_llmodel->deserialize(stream))
|
||||
return false;
|
||||
if (!m_chatModel->deserialize(stream))
|
||||
return false;
|
||||
emit chatModelChanged();
|
||||
return stream.status() == QDataStream::Ok;
|
||||
}
|
||||
|
||||
QList<QString> Chat::modelList() const
|
||||
{
|
||||
// Build a model list from exepath and from the localpath
|
||||
QList<QString> list;
|
||||
|
||||
QString exePath = QCoreApplication::applicationDirPath() + QDir::separator();
|
||||
QString localPath = Download::globalInstance()->downloadLocalModelsPath();
|
||||
|
||||
{
|
||||
QDir dir(exePath);
|
||||
dir.setNameFilters(QStringList() << "ggml-*.bin");
|
||||
QStringList fileNames = dir.entryList();
|
||||
for (QString f : fileNames) {
|
||||
QString filePath = exePath + f;
|
||||
QFileInfo info(filePath);
|
||||
QString name = info.completeBaseName().remove(0, 5);
|
||||
if (info.exists()) {
|
||||
if (name == modelName())
|
||||
list.prepend(name);
|
||||
else
|
||||
list.append(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (localPath != exePath) {
|
||||
QDir dir(localPath);
|
||||
dir.setNameFilters(QStringList() << "ggml-*.bin");
|
||||
QStringList fileNames = dir.entryList();
|
||||
for (QString f : fileNames) {
|
||||
QString filePath = localPath + f;
|
||||
QFileInfo info(filePath);
|
||||
QString name = info.completeBaseName().remove(0, 5);
|
||||
if (info.exists() && !list.contains(name)) { // don't allow duplicates
|
||||
if (name == modelName())
|
||||
list.prepend(name);
|
||||
else
|
||||
list.append(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (list.isEmpty()) {
|
||||
if (exePath != localPath) {
|
||||
qWarning() << "ERROR: Could not find any applicable models in"
|
||||
<< exePath << "nor" << localPath;
|
||||
} else {
|
||||
qWarning() << "ERROR: Could not find any applicable models in"
|
||||
<< exePath;
|
||||
}
|
||||
return QList<QString>();
|
||||
}
|
||||
|
||||
return list;
|
||||
}
|
||||
|
chat.h (42 lines changed)
@ -3,6 +3,7 @@
|
||||
|
||||
#include <QObject>
|
||||
#include <QtQml>
|
||||
#include <QDataStream>
|
||||
|
||||
#include "chatllm.h"
|
||||
#include "chatmodel.h"
|
||||
@ -17,8 +18,8 @@ class Chat : public QObject
|
||||
Q_PROPERTY(QString response READ response NOTIFY responseChanged)
|
||||
Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
|
||||
Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
|
||||
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
|
||||
Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
|
||||
Q_PROPERTY(QList<QString> modelList READ modelList NOTIFY modelListChanged)
|
||||
QML_ELEMENT
|
||||
QML_UNCREATABLE("Only creatable from c++!")
|
||||
|
||||
@ -36,13 +37,10 @@ public:
|
||||
|
||||
Q_INVOKABLE void reset();
|
||||
Q_INVOKABLE bool isModelLoaded() const;
|
||||
Q_INVOKABLE void prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
|
||||
float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens);
|
||||
Q_INVOKABLE void prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict,
|
||||
int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens);
|
||||
Q_INVOKABLE void regenerateResponse();
|
||||
Q_INVOKABLE void stopGenerating();
|
||||
Q_INVOKABLE void syncThreadCount();
|
||||
Q_INVOKABLE void setThreadCount(int32_t n_threads);
|
||||
Q_INVOKABLE int32_t threadCount();
|
||||
Q_INVOKABLE void newPromptResponsePair(const QString &prompt);
|
||||
|
||||
QString response() const;
|
||||
@ -51,8 +49,16 @@ public:
|
||||
void setModelName(const QString &modelName);
|
||||
bool isRecalc() const;
|
||||
|
||||
void unload();
|
||||
void reload();
|
||||
void loadDefaultModel();
|
||||
void loadModel(const QString &modelName);
|
||||
void unloadModel();
|
||||
void reloadModel();
|
||||
|
||||
qint64 creationDate() const { return m_creationDate; }
|
||||
bool serialize(QDataStream &stream) const;
|
||||
bool deserialize(QDataStream &stream);
|
||||
|
||||
QList<QString> modelList() const;
|
||||
|
||||
Q_SIGNALS:
|
||||
void idChanged();
|
||||
@ -61,35 +67,39 @@ Q_SIGNALS:
|
||||
void isModelLoadedChanged();
|
||||
void responseChanged();
|
||||
void responseInProgressChanged();
|
||||
void promptRequested(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
|
||||
float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens);
|
||||
void promptRequested(const QString &prompt, const QString &prompt_template, int32_t n_predict,
|
||||
int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens,
|
||||
int32_t n_threads);
|
||||
void regenerateResponseRequested();
|
||||
void resetResponseRequested();
|
||||
void resetContextRequested();
|
||||
void modelNameChangeRequested(const QString &modelName);
|
||||
void modelNameChanged();
|
||||
void threadCountChanged();
|
||||
void setThreadCountRequested(int32_t threadCount);
|
||||
void recalcChanged();
|
||||
void unloadRequested();
|
||||
void reloadRequested(const QString &modelName);
|
||||
void loadDefaultModelRequested();
|
||||
void loadModelRequested(const QString &modelName);
|
||||
void unloadModelRequested();
|
||||
void reloadModelRequested(const QString &modelName);
|
||||
void generateNameRequested();
|
||||
void modelListChanged();
|
||||
|
||||
private Q_SLOTS:
|
||||
void handleResponseChanged();
|
||||
void responseStarted();
|
||||
void responseStopped();
|
||||
void generatedNameChanged();
|
||||
void handleRecalculating();
|
||||
void handleModelNameChanged();
|
||||
|
||||
private:
|
||||
ChatLLM *m_llmodel;
|
||||
QString m_id;
|
||||
QString m_name;
|
||||
QString m_userName;
|
||||
QString m_savedModelName;
|
||||
ChatModel *m_chatModel;
|
||||
bool m_responseInProgress;
|
||||
int32_t m_desiredThreadCount;
|
||||
qint64 m_creationDate;
|
||||
ChatLLM *m_llmodel;
|
||||
};
|
||||
|
||||
#endif // CHAT_H
|
||||
|
chatlistmodel.cpp (new file, 72 lines)
@@ -0,0 +1,72 @@
#include "chatlistmodel.h"

#include <QFile>
#include <QDataStream>

void ChatListModel::removeChatFile(Chat *chat) const
{
    QSettings settings;
    QFileInfo settingsInfo(settings.fileName());
    QString settingsPath = settingsInfo.absolutePath();
    QFile file(settingsPath + "/gpt4all-" + chat->id() + ".chat");
    if (!file.exists())
        return;
    bool success = file.remove();
    if (!success)
        qWarning() << "ERROR: Couldn't remove chat file:" << file.fileName();
}

void ChatListModel::saveChats() const
{
    QSettings settings;
    QFileInfo settingsInfo(settings.fileName());
    QString settingsPath = settingsInfo.absolutePath();
    for (Chat *chat : m_chats) {
        QFile file(settingsPath + "/gpt4all-" + chat->id() + ".chat");
        bool success = file.open(QIODevice::WriteOnly);
        if (!success) {
            qWarning() << "ERROR: Couldn't save chat to file:" << file.fileName();
            continue;
        }
        QDataStream out(&file);
        if (!chat->serialize(out)) {
            qWarning() << "ERROR: Couldn't serialize chat to file:" << file.fileName();
            file.remove();
        }
        file.close();
    }
}

void ChatListModel::restoreChats()
{
    QSettings settings;
    QFileInfo settingsInfo(settings.fileName());
    QString settingsPath = settingsInfo.absolutePath();
    QDir dir(settingsPath);
    dir.setNameFilters(QStringList() << "gpt4all-*.chat");
    QStringList fileNames = dir.entryList();
    beginResetModel();
    for (QString f : fileNames) {
        QString filePath = settingsPath + "/" + f;
        QFile file(filePath);
        bool success = file.open(QIODevice::ReadOnly);
        if (!success) {
            qWarning() << "ERROR: Couldn't restore chat from file:" << file.fileName();
            continue;
        }
        QDataStream in(&file);
        Chat *chat = new Chat(this);
        if (!chat->deserialize(in)) {
            qWarning() << "ERROR: Couldn't deserialize chat from file:" << file.fileName();
            file.remove();
        } else {
            connect(chat, &Chat::nameChanged, this, &ChatListModel::nameChanged);
            m_chats.append(chat);
        }
        file.close();
    }
    std::sort(m_chats.begin(), m_chats.end(), [](const Chat* a, const Chat* b) {
        return a->creationDate() > b->creationDate();
    });
    endResetModel();
}
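For context on how these files come back at startup: LLM::LLM() (later in this diff) restores the saved chats once, then either reuses an empty first chat as the "New Chat" slot or makes the most recent chat current. A condensed sketch of that flow, assuming the ChatListModel API shown in this commit:

    ChatListModel *chats = chatListModel();
    if (!chats->count())
        chats->restoreChats();             // reads every gpt4all-*.chat, sorted newest first
    if (chats->count()) {
        Chat *first = chats->get(0);
        if (first->chatModel()->count() < 2)
            chats->setNewChat(first);      // an empty restored chat doubles as the "New Chat"
        else
            chats->setCurrentChat(first);
    } else {
        chats->addChat();                  // nothing on disk: start with a fresh chat
    }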
@ -55,7 +55,7 @@ public:
|
||||
|
||||
Q_INVOKABLE void addChat()
|
||||
{
|
||||
// Don't add a new chat if the current chat is empty
|
||||
// Don't add a new chat if we already have one
|
||||
if (m_newChat)
|
||||
return;
|
||||
|
||||
@ -73,13 +73,29 @@ public:
|
||||
setCurrentChat(m_newChat);
|
||||
}
|
||||
|
||||
void setNewChat(Chat* chat)
|
||||
{
|
||||
// Don't add a new chat if we already have one
|
||||
if (m_newChat)
|
||||
return;
|
||||
|
||||
m_newChat = chat;
|
||||
connect(m_newChat->chatModel(), &ChatModel::countChanged,
|
||||
this, &ChatListModel::newChatCountChanged);
|
||||
connect(m_newChat, &Chat::nameChanged,
|
||||
this, &ChatListModel::nameChanged);
|
||||
setCurrentChat(m_newChat);
|
||||
}
|
||||
|
||||
Q_INVOKABLE void removeChat(Chat* chat)
|
||||
{
|
||||
if (!m_chats.contains(chat)) {
|
||||
qDebug() << "WARNING: Removing chat failed with id" << chat->id();
|
||||
qWarning() << "WARNING: Removing chat failed with id" << chat->id();
|
||||
return;
|
||||
}
|
||||
|
||||
removeChatFile(chat);
|
||||
|
||||
emit disconnectChat(chat);
|
||||
if (chat == m_newChat) {
|
||||
m_newChat->disconnect(this);
|
||||
@ -115,20 +131,20 @@ public:
|
||||
void setCurrentChat(Chat *chat)
|
||||
{
|
||||
if (!m_chats.contains(chat)) {
|
||||
qDebug() << "ERROR: Setting current chat failed with id" << chat->id();
|
||||
qWarning() << "ERROR: Setting current chat failed with id" << chat->id();
|
||||
return;
|
||||
}
|
||||
|
||||
if (m_currentChat) {
|
||||
if (m_currentChat->isModelLoaded())
|
||||
m_currentChat->unload();
|
||||
m_currentChat->unloadModel();
|
||||
emit disconnect(m_currentChat);
|
||||
}
|
||||
|
||||
emit connectChat(chat);
|
||||
m_currentChat = chat;
|
||||
if (!m_currentChat->isModelLoaded())
|
||||
m_currentChat->reload();
|
||||
m_currentChat->reloadModel();
|
||||
emit currentChatChanged();
|
||||
}
|
||||
|
||||
@ -138,9 +154,12 @@ public:
|
||||
return m_chats.at(index);
|
||||
}
|
||||
|
||||
|
||||
int count() const { return m_chats.size(); }
|
||||
|
||||
void removeChatFile(Chat *chat) const;
|
||||
void saveChats() const;
|
||||
void restoreChats();
|
||||
|
||||
Q_SIGNALS:
|
||||
void countChanged();
|
||||
void connectChat(Chat*);
|
||||
|
chatllm.cpp (119 lines changed)
@ -1,7 +1,7 @@
|
||||
#include "chatllm.h"
|
||||
#include "chat.h"
|
||||
#include "download.h"
|
||||
#include "network.h"
|
||||
#include "llm.h"
|
||||
#include "llmodel/gptj.h"
|
||||
#include "llmodel/llamamodel.h"
|
||||
|
||||
@ -32,28 +32,29 @@ static QString modelFilePath(const QString &modelName)
|
||||
return QString();
|
||||
}
|
||||
|
||||
ChatLLM::ChatLLM()
|
||||
ChatLLM::ChatLLM(Chat *parent)
|
||||
: QObject{nullptr}
|
||||
, m_llmodel(nullptr)
|
||||
, m_promptResponseTokens(0)
|
||||
, m_responseLogits(0)
|
||||
, m_isRecalc(false)
|
||||
, m_chat(parent)
|
||||
{
|
||||
moveToThread(&m_llmThread);
|
||||
connect(&m_llmThread, &QThread::started, this, &ChatLLM::loadModel);
|
||||
connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup);
|
||||
connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded);
|
||||
m_llmThread.setObjectName("llm thread"); // FIXME: Should identify these with chat name
|
||||
connect(m_chat, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
|
||||
m_llmThread.setObjectName(m_chat->id());
|
||||
m_llmThread.start();
|
||||
}
|
||||
|
||||
bool ChatLLM::loadModel()
|
||||
bool ChatLLM::loadDefaultModel()
|
||||
{
|
||||
const QList<QString> models = LLM::globalInstance()->modelList();
|
||||
const QList<QString> models = m_chat->modelList();
|
||||
if (models.isEmpty()) {
|
||||
// try again when we get a list of models
|
||||
connect(Download::globalInstance(), &Download::modelListChanged, this,
|
||||
&ChatLLM::loadModel, Qt::SingleShotConnection);
|
||||
&ChatLLM::loadDefaultModel, Qt::SingleShotConnection);
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -62,10 +63,10 @@ bool ChatLLM::loadModel()
|
||||
QString defaultModel = settings.value("defaultModel", "gpt4all-j-v1.3-groovy").toString();
|
||||
if (defaultModel.isEmpty() || !models.contains(defaultModel))
|
||||
defaultModel = models.first();
|
||||
return loadModelPrivate(defaultModel);
|
||||
return loadModel(defaultModel);
|
||||
}
|
||||
|
||||
bool ChatLLM::loadModelPrivate(const QString &modelName)
|
||||
bool ChatLLM::loadModel(const QString &modelName)
|
||||
{
|
||||
if (isModelLoaded() && m_modelName == modelName)
|
||||
return true;
|
||||
@ -100,12 +101,13 @@ bool ChatLLM::loadModelPrivate(const QString &modelName)
|
||||
}
|
||||
|
||||
emit isModelLoadedChanged();
|
||||
emit threadCountChanged();
|
||||
|
||||
if (isFirstLoad)
|
||||
emit sendStartup();
|
||||
else
|
||||
emit sendModelLoaded();
|
||||
} else {
|
||||
qWarning() << "ERROR: Could not find model at" << filePath;
|
||||
}
|
||||
|
||||
if (m_llmodel)
|
||||
@ -114,19 +116,6 @@ bool ChatLLM::loadModelPrivate(const QString &modelName)
|
||||
return m_llmodel;
|
||||
}
|
||||
|
||||
void ChatLLM::setThreadCount(int32_t n_threads) {
|
||||
if (m_llmodel && m_llmodel->threadCount() != n_threads) {
|
||||
m_llmodel->setThreadCount(n_threads);
|
||||
emit threadCountChanged();
|
||||
}
|
||||
}
|
||||
|
||||
int32_t ChatLLM::threadCount() {
|
||||
if (!m_llmodel)
|
||||
return 1;
|
||||
return m_llmodel->threadCount();
|
||||
}
|
||||
|
||||
bool ChatLLM::isModelLoaded() const
|
||||
{
|
||||
return m_llmodel && m_llmodel->isModelLoaded();
|
||||
@ -203,7 +192,7 @@ void ChatLLM::setModelName(const QString &modelName)
|
||||
|
||||
void ChatLLM::modelNameChangeRequested(const QString &modelName)
|
||||
{
|
||||
if (!loadModelPrivate(modelName))
|
||||
if (!loadModel(modelName))
|
||||
qWarning() << "ERROR: Could not load model" << modelName;
|
||||
}
|
||||
|
||||
@ -247,8 +236,8 @@ bool ChatLLM::handleRecalculate(bool isRecalc)
|
||||
return !m_stopGenerating;
|
||||
}
|
||||
|
||||
bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
|
||||
float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens)
|
||||
bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k,
|
||||
float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, int n_threads)
|
||||
{
|
||||
if (!isModelLoaded())
|
||||
return false;
|
||||
@ -269,6 +258,7 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
|
||||
m_ctx.n_batch = n_batch;
|
||||
m_ctx.repeat_penalty = repeat_penalty;
|
||||
m_ctx.repeat_last_n = repeat_penalty_tokens;
|
||||
m_llmodel->setThreadCount(n_threads);
|
||||
#if defined(DEBUG)
|
||||
printf("%s", qPrintable(instructPrompt));
|
||||
fflush(stdout);
|
||||
@ -288,19 +278,22 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
|
||||
return true;
|
||||
}
|
||||
|
||||
void ChatLLM::unload()
|
||||
void ChatLLM::unloadModel()
|
||||
{
|
||||
saveState();
|
||||
delete m_llmodel;
|
||||
m_llmodel = nullptr;
|
||||
emit isModelLoadedChanged();
|
||||
}
|
||||
|
||||
void ChatLLM::reload(const QString &modelName)
|
||||
void ChatLLM::reloadModel(const QString &modelName)
|
||||
{
|
||||
if (modelName.isEmpty())
|
||||
loadModel();
|
||||
else
|
||||
loadModelPrivate(modelName);
|
||||
if (modelName.isEmpty()) {
|
||||
loadDefaultModel();
|
||||
} else {
|
||||
loadModel(modelName);
|
||||
}
|
||||
restoreState();
|
||||
}
|
||||
|
||||
void ChatLLM::generateName()
|
||||
@ -333,6 +326,11 @@ void ChatLLM::generateName()
|
||||
}
|
||||
}
|
||||
|
||||
void ChatLLM::handleChatIdChanged()
|
||||
{
|
||||
m_llmThread.setObjectName(m_chat->id());
|
||||
}
|
||||
|
||||
bool ChatLLM::handleNamePrompt(int32_t token)
|
||||
{
|
||||
Q_UNUSED(token);
|
||||
@ -354,3 +352,60 @@ bool ChatLLM::handleNameRecalculate(bool isRecalc)
|
||||
Q_UNREACHABLE();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ChatLLM::serialize(QDataStream &stream)
{
    stream << response();
    stream << generatedName();
    stream << m_promptResponseTokens;
    stream << m_responseLogits;
    stream << m_ctx.n_past;
    stream << quint64(m_ctx.logits.size());
    stream.writeRawData(reinterpret_cast<const char*>(m_ctx.logits.data()), m_ctx.logits.size() * sizeof(float));
    stream << quint64(m_ctx.tokens.size());
    stream.writeRawData(reinterpret_cast<const char*>(m_ctx.tokens.data()), m_ctx.tokens.size() * sizeof(int));
    saveState();
    stream << m_state;
    return stream.status() == QDataStream::Ok;
}

bool ChatLLM::deserialize(QDataStream &stream)
{
    QString response;
    stream >> response;
    m_response = response.toStdString();
    QString nameResponse;
    stream >> nameResponse;
    m_nameResponse = nameResponse.toStdString();
    stream >> m_promptResponseTokens;
    stream >> m_responseLogits;
    stream >> m_ctx.n_past;
    quint64 logitsSize;
    stream >> logitsSize;
    m_ctx.logits.resize(logitsSize);
    stream.readRawData(reinterpret_cast<char*>(m_ctx.logits.data()), logitsSize * sizeof(float));
    quint64 tokensSize;
    stream >> tokensSize;
    m_ctx.tokens.resize(tokensSize);
    stream.readRawData(reinterpret_cast<char*>(m_ctx.tokens.data()), tokensSize * sizeof(int));
    stream >> m_state;
    return stream.status() == QDataStream::Ok;
}

void ChatLLM::saveState()
{
    if (!isModelLoaded())
        return;

    const size_t stateSize = m_llmodel->stateSize();
    m_state.resize(stateSize);
    m_llmodel->saveState(static_cast<uint8_t*>(reinterpret_cast<void*>(m_state.data())));
}

void ChatLLM::restoreState()
{
    if (!isModelLoaded())
        return;

    m_llmodel->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
}
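This serialization path is what limitation 4 in the commit message refers to: besides the conversation text, every chat carries its raw logits, its token window, and the full backend state blob. A back-of-the-envelope helper (my own illustration, not code from this diff) for what the ChatLLM portion of one .chat file weighs:

    // Rough size of the ChatLLM part of one .chat file (names are illustrative).
    size_t approxChatLLMBytes(size_t logitCount, size_t tokenCount, size_t backendStateSize)
    {
        return logitCount * sizeof(float)     // m_ctx.logits, written with writeRawData
             + tokenCount * sizeof(int32_t)   // m_ctx.tokens, written with writeRawData
             + backendStateSize;              // m_state from LLModel::saveState(); dominant for llama.cpp
    }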
|
chatllm.h (29 lines changed)
@ -6,18 +6,18 @@
|
||||
|
||||
#include "llmodel/llmodel.h"
|
||||
|
||||
class Chat;
|
||||
class ChatLLM : public QObject
|
||||
{
|
||||
Q_OBJECT
|
||||
Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
|
||||
Q_PROPERTY(QString response READ response NOTIFY responseChanged)
|
||||
Q_PROPERTY(QString modelName READ modelName WRITE setModelName NOTIFY modelNameChanged)
|
||||
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
|
||||
Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
|
||||
Q_PROPERTY(QString generatedName READ generatedName NOTIFY generatedNameChanged)
|
||||
|
||||
public:
|
||||
ChatLLM();
|
||||
ChatLLM(Chat *parent);
|
||||
|
||||
bool isModelLoaded() const;
|
||||
void regenerateResponse();
|
||||
@ -25,8 +25,6 @@ public:
|
||||
void resetContext();
|
||||
|
||||
void stopGenerating() { m_stopGenerating = true; }
|
||||
void setThreadCount(int32_t n_threads);
|
||||
int32_t threadCount();
|
||||
|
||||
QString response() const;
|
||||
QString modelName() const;
|
||||
@ -37,14 +35,20 @@ public:
|
||||
|
||||
QString generatedName() const { return QString::fromStdString(m_nameResponse); }
|
||||
|
||||
bool serialize(QDataStream &stream);
|
||||
bool deserialize(QDataStream &stream);
|
||||
|
||||
public Q_SLOTS:
|
||||
bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p,
|
||||
float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens);
|
||||
bool loadModel();
|
||||
bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict,
|
||||
int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens,
|
||||
int32_t n_threads);
|
||||
bool loadDefaultModel();
|
||||
bool loadModel(const QString &modelName);
|
||||
void modelNameChangeRequested(const QString &modelName);
|
||||
void unload();
|
||||
void reload(const QString &modelName);
|
||||
void unloadModel();
|
||||
void reloadModel(const QString &modelName);
|
||||
void generateName();
|
||||
void handleChatIdChanged();
|
||||
|
||||
Q_SIGNALS:
|
||||
void isModelLoadedChanged();
|
||||
@ -52,22 +56,23 @@ Q_SIGNALS:
|
||||
void responseStarted();
|
||||
void responseStopped();
|
||||
void modelNameChanged();
|
||||
void threadCountChanged();
|
||||
void recalcChanged();
|
||||
void sendStartup();
|
||||
void sendModelLoaded();
|
||||
void sendResetContext();
|
||||
void generatedNameChanged();
|
||||
void stateChanged();
|
||||
|
||||
private:
|
||||
void resetContextPrivate();
|
||||
bool loadModelPrivate(const QString &modelName);
|
||||
bool handlePrompt(int32_t token);
|
||||
bool handleResponse(int32_t token, const std::string &response);
|
||||
bool handleRecalculate(bool isRecalc);
|
||||
bool handleNamePrompt(int32_t token);
|
||||
bool handleNameResponse(int32_t token, const std::string &response);
|
||||
bool handleNameRecalculate(bool isRecalc);
|
||||
void saveState();
|
||||
void restoreState();
|
||||
|
||||
private:
|
||||
LLModel::PromptContext m_ctx;
|
||||
@ -77,6 +82,8 @@ private:
|
||||
quint32 m_promptResponseTokens;
|
||||
quint32 m_responseLogits;
|
||||
QString m_modelName;
|
||||
Chat *m_chat;
|
||||
QByteArray m_state;
|
||||
QThread m_llmThread;
|
||||
std::atomic<bool> m_stopGenerating;
|
||||
bool m_isRecalc;
|
||||
|
chatmodel.h (41 lines changed)
@ -3,6 +3,7 @@
|
||||
|
||||
#include <QAbstractListModel>
|
||||
#include <QtQml>
|
||||
#include <QDataStream>
|
||||
|
||||
struct ChatItem
|
||||
{
|
||||
@ -209,6 +210,46 @@ public:
|
||||
|
||||
int count() const { return m_chatItems.size(); }
|
||||
|
||||
bool serialize(QDataStream &stream) const
|
||||
{
|
||||
stream << count();
|
||||
for (auto c : m_chatItems) {
|
||||
stream << c.id;
|
||||
stream << c.name;
|
||||
stream << c.value;
|
||||
stream << c.prompt;
|
||||
stream << c.newResponse;
|
||||
stream << c.currentResponse;
|
||||
stream << c.stopped;
|
||||
stream << c.thumbsUpState;
|
||||
stream << c.thumbsDownState;
|
||||
}
|
||||
return stream.status() == QDataStream::Ok;
|
||||
}
|
||||
|
||||
bool deserialize(QDataStream &stream)
|
||||
{
|
||||
int size;
|
||||
stream >> size;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
ChatItem c;
|
||||
stream >> c.id;
|
||||
stream >> c.name;
|
||||
stream >> c.value;
|
||||
stream >> c.prompt;
|
||||
stream >> c.newResponse;
|
||||
stream >> c.currentResponse;
|
||||
stream >> c.stopped;
|
||||
stream >> c.thumbsUpState;
|
||||
stream >> c.thumbsDownState;
|
||||
beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size());
|
||||
m_chatItems.append(c);
|
||||
endInsertRows();
|
||||
}
|
||||
emit countChanged();
|
||||
return stream.status() == QDataStream::Ok;
|
||||
}
|
||||
|
||||
Q_SIGNALS:
|
||||
void countChanged();
|
||||
|
||||
|
llm.cpp (96 lines changed)
@ -20,77 +20,22 @@ LLM *LLM::globalInstance()
|
||||
LLM::LLM()
|
||||
: QObject{nullptr}
|
||||
, m_chatListModel(new ChatListModel(this))
|
||||
, m_threadCount(std::min(4, (int32_t) std::thread::hardware_concurrency()))
|
||||
{
|
||||
// Should be in the same thread
|
||||
connect(Download::globalInstance(), &Download::modelListChanged,
|
||||
this, &LLM::modelListChanged, Qt::DirectConnection);
|
||||
connect(m_chatListModel, &ChatListModel::connectChat,
|
||||
this, &LLM::connectChat, Qt::DirectConnection);
|
||||
connect(m_chatListModel, &ChatListModel::disconnectChat,
|
||||
this, &LLM::disconnectChat, Qt::DirectConnection);
|
||||
connect(QCoreApplication::instance(), &QCoreApplication::aboutToQuit,
|
||||
this, &LLM::aboutToQuit);
|
||||
|
||||
if (!m_chatListModel->count())
|
||||
m_chatListModel->restoreChats();
|
||||
if (m_chatListModel->count()) {
|
||||
Chat *firstChat = m_chatListModel->get(0);
|
||||
if (firstChat->chatModel()->count() < 2)
|
||||
m_chatListModel->setNewChat(firstChat);
|
||||
else
|
||||
m_chatListModel->setCurrentChat(firstChat);
|
||||
} else
|
||||
m_chatListModel->addChat();
|
||||
}
|
||||
|
||||
QList<QString> LLM::modelList() const
|
||||
{
|
||||
Q_ASSERT(m_chatListModel->currentChat());
|
||||
const Chat *currentChat = m_chatListModel->currentChat();
|
||||
// Build a model list from exepath and from the localpath
|
||||
QList<QString> list;
|
||||
|
||||
QString exePath = QCoreApplication::applicationDirPath() + QDir::separator();
|
||||
QString localPath = Download::globalInstance()->downloadLocalModelsPath();
|
||||
|
||||
{
|
||||
QDir dir(exePath);
|
||||
dir.setNameFilters(QStringList() << "ggml-*.bin");
|
||||
QStringList fileNames = dir.entryList();
|
||||
for (QString f : fileNames) {
|
||||
QString filePath = exePath + f;
|
||||
QFileInfo info(filePath);
|
||||
QString name = info.completeBaseName().remove(0, 5);
|
||||
if (info.exists()) {
|
||||
if (name == currentChat->modelName())
|
||||
list.prepend(name);
|
||||
else
|
||||
list.append(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (localPath != exePath) {
|
||||
QDir dir(localPath);
|
||||
dir.setNameFilters(QStringList() << "ggml-*.bin");
|
||||
QStringList fileNames = dir.entryList();
|
||||
for (QString f : fileNames) {
|
||||
QString filePath = localPath + f;
|
||||
QFileInfo info(filePath);
|
||||
QString name = info.completeBaseName().remove(0, 5);
|
||||
if (info.exists() && !list.contains(name)) { // don't allow duplicates
|
||||
if (name == currentChat->modelName())
|
||||
list.prepend(name);
|
||||
else
|
||||
list.append(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (list.isEmpty()) {
|
||||
if (exePath != localPath) {
|
||||
qWarning() << "ERROR: Could not find any applicable models in"
|
||||
<< exePath << "nor" << localPath;
|
||||
} else {
|
||||
qWarning() << "ERROR: Could not find any applicable models in"
|
||||
<< exePath;
|
||||
}
|
||||
return QList<QString>();
|
||||
}
|
||||
|
||||
return list;
|
||||
}
|
||||
|
||||
bool LLM::checkForUpdates() const
|
||||
{
|
||||
Network::globalInstance()->sendCheckForUpdates();
|
||||
@ -113,21 +58,20 @@ bool LLM::checkForUpdates() const
|
||||
return QProcess::startDetached(fileName);
|
||||
}
|
||||
|
||||
bool LLM::isRecalc() const
|
||||
int32_t LLM::threadCount() const
|
||||
{
|
||||
Q_ASSERT(m_chatListModel->currentChat());
|
||||
return m_chatListModel->currentChat()->isRecalc();
|
||||
return m_threadCount;
|
||||
}
|
||||
|
||||
void LLM::connectChat(Chat *chat)
|
||||
void LLM::setThreadCount(int32_t n_threads)
|
||||
{
|
||||
// Should be in the same thread
|
||||
connect(chat, &Chat::modelNameChanged, this, &LLM::modelListChanged, Qt::DirectConnection);
|
||||
connect(chat, &Chat::recalcChanged, this, &LLM::recalcChanged, Qt::DirectConnection);
|
||||
connect(chat, &Chat::responseChanged, this, &LLM::responseChanged, Qt::DirectConnection);
|
||||
if (n_threads <= 0)
|
||||
n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||
m_threadCount = n_threads;
|
||||
emit threadCountChanged();
|
||||
}
|
||||
|
||||
void LLM::disconnectChat(Chat *chat)
|
||||
void LLM::aboutToQuit()
|
||||
{
|
||||
chat->disconnect(this);
|
||||
m_chatListModel->saveChats();
|
||||
}
|
||||
|
llm.h (16 lines changed)
@ -3,37 +3,33 @@
|
||||
|
||||
#include <QObject>
|
||||
|
||||
#include "chat.h"
|
||||
#include "chatlistmodel.h"
|
||||
|
||||
class LLM : public QObject
|
||||
{
|
||||
Q_OBJECT
|
||||
Q_PROPERTY(QList<QString> modelList READ modelList NOTIFY modelListChanged)
|
||||
Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
|
||||
Q_PROPERTY(ChatListModel *chatListModel READ chatListModel NOTIFY chatListModelChanged)
|
||||
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
|
||||
|
||||
public:
|
||||
static LLM *globalInstance();
|
||||
|
||||
QList<QString> modelList() const;
|
||||
bool isRecalc() const;
|
||||
ChatListModel *chatListModel() const { return m_chatListModel; }
|
||||
int32_t threadCount() const;
|
||||
void setThreadCount(int32_t n_threads);
|
||||
|
||||
Q_INVOKABLE bool checkForUpdates() const;
|
||||
|
||||
Q_SIGNALS:
|
||||
void modelListChanged();
|
||||
void recalcChanged();
|
||||
void responseChanged();
|
||||
void chatListModelChanged();
|
||||
void threadCountChanged();
|
||||
|
||||
private Q_SLOTS:
|
||||
void connectChat(Chat*);
|
||||
void disconnectChat(Chat*);
|
||||
void aboutToQuit();
|
||||
|
||||
private:
|
||||
ChatListModel *m_chatListModel;
|
||||
int32_t m_threadCount;
|
||||
|
||||
private:
|
||||
explicit LLM();
|
||||
|
@@ -67,6 +67,7 @@ int32_t LLamaModel::threadCount() {

LLamaModel::~LLamaModel()
{
    llama_free(d_ptr->ctx);
}

bool LLamaModel::isModelLoaded() const
@@ -74,6 +75,21 @@ bool LLamaModel::isModelLoaded() const
    return d_ptr->modelLoaded;
}

size_t LLamaModel::stateSize() const
{
    return llama_get_state_size(d_ptr->ctx);
}

size_t LLamaModel::saveState(uint8_t *dest) const
{
    return llama_copy_state_data(d_ptr->ctx, dest);
}

size_t LLamaModel::restoreState(const uint8_t *src)
{
    return llama_set_state_data(d_ptr->ctx, src);
}

void LLamaModel::prompt(const std::string &prompt,
    std::function<bool(int32_t)> promptCallback,
    std::function<bool(int32_t, const std::string&)> responseCallback,
@@ -14,6 +14,9 @@ public:

    bool loadModel(const std::string &modelPath) override;
    bool isModelLoaded() const override;
    size_t stateSize() const override;
    size_t saveState(uint8_t *dest) const override;
    size_t restoreState(const uint8_t *src) override;
    void prompt(const std::string &prompt,
        std::function<bool(int32_t)> promptCallback,
        std::function<bool(int32_t, const std::string&)> responseCallback,
@@ -12,6 +12,9 @@ public:

    virtual bool loadModel(const std::string &modelPath) = 0;
    virtual bool isModelLoaded() const = 0;
    virtual size_t stateSize() const { return 0; }
    virtual size_t saveState(uint8_t *dest) const { return 0; }
    virtual size_t restoreState(const uint8_t *src) { return 0; }
    struct PromptContext {
        std::vector<float> logits;    // logits of current context
        std::vector<int32_t> tokens;  // current tokens in the context window
@@ -48,6 +48,24 @@ bool llmodel_isModelLoaded(llmodel_model model)
    return wrapper->llModel->isModelLoaded();
}

uint64_t llmodel_get_state_size(llmodel_model model)
{
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
    return wrapper->llModel->stateSize();
}

uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest)
{
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
    return wrapper->llModel->saveState(dest);
}

uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src)
{
    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
    return wrapper->llModel->restoreState(src);
}

// Wrapper functions for the C callbacks
bool prompt_wrapper(int32_t token_id, void *user_data) {
    llmodel_prompt_callback callback = reinterpret_cast<llmodel_prompt_callback>(user_data);
@@ -98,6 +98,32 @@ bool llmodel_loadModel(llmodel_model model, const char *model_path);
 */
bool llmodel_isModelLoaded(llmodel_model model);

/**
 * Get the size of the internal state of the model.
 * NOTE: This state data is specific to the type of model you have created.
 * @param model A pointer to the llmodel_model instance.
 * @return the size in bytes of the internal state of the model
 */
uint64_t llmodel_get_state_size(llmodel_model model);

/**
 * Saves the internal state of the model to the specified destination address.
 * NOTE: This state data is specific to the type of model you have created.
 * @param model A pointer to the llmodel_model instance.
 * @param dest A pointer to the destination.
 * @return the number of bytes copied
 */
uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest);

/**
 * Restores the internal state of the model using data from the specified address.
 * NOTE: This state data is specific to the type of model you have created.
 * @param model A pointer to the llmodel_model instance.
 * @param src A pointer to the source.
 * @return the number of bytes read
 */
uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src);

/**
 * Generate a response using the model.
 * @param model A pointer to the llmodel_model instance.
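A hedged usage sketch of the new C state API (assumes a valid, already-loaded llmodel_model handle; buffer management is up to the caller):

    #include <cstdint>
    #include <vector>

    // Take an in-memory snapshot of the model's internal state...
    std::vector<uint8_t> snapshotState(llmodel_model model)
    {
        std::vector<uint8_t> buf(llmodel_get_state_size(model));
        llmodel_save_state_data(model, buf.data());
        return buf;
    }

    // ...and roll back to it later; the blob is only valid for the same model type.
    void restoreSnapshot(llmodel_model model, const std::vector<uint8_t> &buf)
    {
        llmodel_restore_state_data(model, buf.data());
    }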
main.qml (32 lines changed)
@ -65,7 +65,7 @@ Window {
|
||||
}
|
||||
|
||||
// check for any current models and if not, open download dialog
|
||||
if (LLM.modelList.length === 0 && !firstStartDialog.opened) {
|
||||
if (currentChat.modelList.length === 0 && !firstStartDialog.opened) {
|
||||
downloadNewModels.open();
|
||||
return;
|
||||
}
|
||||
@ -125,7 +125,7 @@ Window {
|
||||
anchors.horizontalCenter: parent.horizontalCenter
|
||||
font.pixelSize: theme.fontSizeLarge
|
||||
spacing: 0
|
||||
model: LLM.modelList
|
||||
model: currentChat.modelList
|
||||
Accessible.role: Accessible.ComboBox
|
||||
Accessible.name: qsTr("ComboBox for displaying/picking the current model")
|
||||
Accessible.description: qsTr("Use this for picking the current model to use; the first item is the current model")
|
||||
@ -367,9 +367,9 @@ Window {
|
||||
text: qsTr("Recalculating context.")
|
||||
|
||||
Connections {
|
||||
target: LLM
|
||||
target: currentChat
|
||||
function onRecalcChanged() {
|
||||
if (LLM.isRecalc)
|
||||
if (currentChat.isRecalc)
|
||||
recalcPopup.open()
|
||||
else
|
||||
recalcPopup.close()
|
||||
@ -422,10 +422,7 @@ Window {
|
||||
var item = chatModel.get(i)
|
||||
var string = item.name;
|
||||
var isResponse = item.name === qsTr("Response: ")
|
||||
if (item.currentResponse)
|
||||
string += currentChat.response
|
||||
else
|
||||
string += chatModel.get(i).value
|
||||
string += chatModel.get(i).value
|
||||
if (isResponse && item.stopped)
|
||||
string += " <stopped>"
|
||||
string += "\n"
|
||||
@ -440,10 +437,7 @@ Window {
|
||||
var item = chatModel.get(i)
|
||||
var isResponse = item.name === qsTr("Response: ")
|
||||
str += "{\"content\": ";
|
||||
if (item.currentResponse)
|
||||
str += JSON.stringify(currentChat.response)
|
||||
else
|
||||
str += JSON.stringify(item.value)
|
||||
str += JSON.stringify(item.value)
|
||||
str += ", \"role\": \"" + (isResponse ? "assistant" : "user") + "\"";
|
||||
if (isResponse && item.thumbsUpState !== item.thumbsDownState)
|
||||
str += ", \"rating\": \"" + (item.thumbsUpState ? "positive" : "negative") + "\"";
|
||||
@ -572,14 +566,14 @@ Window {
|
||||
Accessible.description: qsTr("This is the list of prompt/response pairs comprising the actual conversation with the model")
|
||||
|
||||
delegate: TextArea {
|
||||
text: currentResponse ? currentChat.response : (value ? value : "")
|
||||
text: value
|
||||
width: listView.width
|
||||
color: theme.textColor
|
||||
wrapMode: Text.WordWrap
|
||||
focus: false
|
||||
readOnly: true
|
||||
font.pixelSize: theme.fontSizeLarge
|
||||
cursorVisible: currentResponse ? (currentChat.response !== "" ? currentChat.responseInProgress : false) : false
|
||||
cursorVisible: currentResponse ? currentChat.responseInProgress : false
|
||||
cursorPosition: text.length
|
||||
background: Rectangle {
|
||||
color: name === qsTr("Response: ") ? theme.backgroundLighter : theme.backgroundLight
|
||||
@ -599,8 +593,8 @@ Window {
|
||||
anchors.leftMargin: 90
|
||||
anchors.top: parent.top
|
||||
anchors.topMargin: 5
|
||||
visible: (currentResponse ? true : false) && currentChat.response === "" && currentChat.responseInProgress
|
||||
running: (currentResponse ? true : false) && currentChat.response === "" && currentChat.responseInProgress
|
||||
visible: (currentResponse ? true : false) && value === "" && currentChat.responseInProgress
|
||||
running: (currentResponse ? true : false) && value === "" && currentChat.responseInProgress
|
||||
|
||||
Accessible.role: Accessible.Animation
|
||||
Accessible.name: qsTr("Busy indicator")
|
||||
@ -631,7 +625,7 @@ Window {
|
||||
window.height / 2 - height / 2)
|
||||
x: globalPoint.x
|
||||
y: globalPoint.y
|
||||
property string text: currentResponse ? currentChat.response : (value ? value : "")
|
||||
property string text: value
|
||||
response: newResponse === undefined || newResponse === "" ? text : newResponse
|
||||
onAccepted: {
|
||||
var responseHasChanged = response !== text && response !== newResponse
|
||||
@ -711,7 +705,7 @@ Window {
|
||||
property bool isAutoScrolling: false
|
||||
|
||||
Connections {
|
||||
target: LLM
|
||||
target: currentChat
|
||||
function onResponseChanged() {
|
||||
if (listView.shouldAutoScroll) {
|
||||
listView.isAutoScrolling = true
|
||||
@ -762,7 +756,6 @@ Window {
|
||||
if (listElement.name === qsTr("Response: ")) {
|
||||
chatModel.updateCurrentResponse(index, true);
|
||||
chatModel.updateStopped(index, false);
|
||||
chatModel.updateValue(index, currentChat.response);
|
||||
chatModel.updateThumbsUpState(index, false);
|
||||
chatModel.updateThumbsDownState(index, false);
|
||||
chatModel.updateNewResponse(index, "");
|
||||
@ -840,7 +833,6 @@ Window {
|
||||
var index = Math.max(0, chatModel.count - 1);
|
||||
var listElement = chatModel.get(index);
|
||||
chatModel.updateCurrentResponse(index, false);
|
||||
chatModel.updateValue(index, currentChat.response);
|
||||
}
|
||||
currentChat.newPromptResponsePair(textInput.text);
|
||||
currentChat.prompt(textInput.text, settingsDialog.promptTemplate,
|
||||
|
@@ -458,7 +458,6 @@ void Network::handleIpifyFinished()

void Network::handleMixpanelFinished()
{
    Q_ASSERT(m_usageStatsActive);
    QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
    if (!reply)
        return;
@@ -83,6 +83,7 @@ Drawer {
    opacity: 0.9
    property bool isCurrent: LLM.chatListModel.currentChat === LLM.chatListModel.get(index)
    property bool trashQuestionDisplayed: false
    z: isCurrent ? 199 : 1
    color: index % 2 === 0 ? theme.backgroundLight : theme.backgroundLighter
    border.width: isCurrent
    border.color: chatName.readOnly ? theme.assistantColor : theme.userColor
@@ -112,6 +113,11 @@ Drawer {
        color: "transparent"
    }
    onEditingFinished: {
        // Work around a bug in qml where we're losing focus when the whole window
        // goes out of focus even though this textfield should be marked as not
        // having focus
        if (chatName.readOnly)
            return;
        changeName();
        Network.sendRenameChat()
    }
@@ -188,6 +194,7 @@ Drawer {
    visible: isCurrent && trashQuestionDisplayed
    opacity: 1.0
    radius: 10
    z: 200
    Row {
        spacing: 10
        Button {
@@ -12,7 +12,7 @@ Dialog {
    id: modelDownloaderDialog
    modal: true
    opacity: 0.9
    closePolicy: LLM.modelList.length === 0 ? Popup.NoAutoClose : (Popup.CloseOnEscape | Popup.CloseOnPressOutside)
    closePolicy: LLM.chatListModel.currentChat.modelList.length === 0 ? Popup.NoAutoClose : (Popup.CloseOnEscape | Popup.CloseOnPressOutside)
    background: Rectangle {
        anchors.fill: parent
        anchors.margins: -20