Start working on more thread safety and model load error handling.

Adam Treat, 2023-06-19 19:51:28 -04:00 (committed by AT)
parent d5f56d3308
commit 7d2ce06029
6 changed files with 115 additions and 56 deletions

chat.cpp

@@ -43,9 +43,16 @@ Chat::~Chat()
 void Chat::connectLLM()
 {
+    const QString exePath = QCoreApplication::applicationDirPath() + QDir::separator();
+    const QString localPath = Download::globalInstance()->downloadLocalModelsPath();
+    m_watcher = new QFileSystemWatcher(this);
+    m_watcher->addPath(exePath);
+    m_watcher->addPath(localPath);
+
     // Should be in same thread
-    connect(Download::globalInstance(), &Download::modelListChanged, this, &Chat::modelListChanged, Qt::DirectConnection);
-    connect(this, &Chat::modelNameChanged, this, &Chat::modelListChanged, Qt::DirectConnection);
+    connect(Download::globalInstance(), &Download::modelListChanged, this, &Chat::handleModelListChanged, Qt::DirectConnection);
+    connect(this, &Chat::modelNameChanged, this, &Chat::handleModelListChanged, Qt::DirectConnection);
+    connect(m_watcher, &QFileSystemWatcher::directoryChanged, this, &Chat::handleModelListChanged);

     // Should be in different threads
     connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, Qt::QueuedConnection);
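The watcher block above is the heart of the new model-list refresh: any file added to or removed from the exe directory or the download directory fires directoryChanged, which funnels into handleModelListChanged. A minimal, self-contained sketch of the same pattern (the handler body is hypothetical, not the project's code):

#include <QCoreApplication>
#include <QDir>
#include <QFileSystemWatcher>
#include <QDebug>

int main(int argc, char *argv[])
{
    QCoreApplication app(argc, argv);

    QFileSystemWatcher watcher;
    // Watch a directory; directoryChanged fires when files are added,
    // removed, or renamed inside it (not when file contents change).
    watcher.addPath(QCoreApplication::applicationDirPath() + QDir::separator());

    QObject::connect(&watcher, &QFileSystemWatcher::directoryChanged,
                     [](const QString &path) {
        // In the real app this is where the model list would be rescanned.
        qDebug() << "directory changed:" << path;
    });

    return app.exec();
}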
@@ -71,6 +78,8 @@ void Chat::connectLLM()
     // to respond to
     connect(this, &Chat::resetResponseRequested, m_llmodel, &ChatLLM::resetResponse, Qt::BlockingQueuedConnection);
     connect(this, &Chat::resetContextRequested, m_llmodel, &ChatLLM::resetContext, Qt::BlockingQueuedConnection);
+
+    emit defaultModelChanged(modelList().first());
 }

 void Chat::reset()
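The connection types chosen in connectLLM() carry the thread-safety story: DirectConnection for same-thread notifications, QueuedConnection across threads, and BlockingQueuedConnection where the caller must wait for the worker. A compact sketch of the semantics (hypothetical Sender/Receiver classes; the triple connect exists only to contrast the flavors):

#include <QObject>

class Sender : public QObject {
    Q_OBJECT
signals:
    void changed();
};

class Receiver : public QObject {
    Q_OBJECT
public slots:
    void onChanged() { /* ... */ }
};

void wire(Sender *s, Receiver *r)
{
    // Runs onChanged() immediately, in the emitting thread.
    QObject::connect(s, &Sender::changed, r, &Receiver::onChanged, Qt::DirectConnection);
    // Posts an event; onChanged() runs later in the receiver's thread.
    QObject::connect(s, &Sender::changed, r, &Receiver::onChanged, Qt::QueuedConnection);
    // Like queued, but the emitter blocks until the slot returns.
    // Deadlocks if sender and receiver share a thread; use only across threads.
    QObject::connect(s, &Sender::changed, r, &Receiver::onChanged, Qt::BlockingQueuedConnection);
}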
@@ -80,7 +89,7 @@ void Chat::reset()
     LLM::globalInstance()->chatListModel()->removeChatFile(this);
     emit resetContextRequested(); // blocking queued connection
     m_id = Network::globalInstance()->generateUniqueId();
-    emit idChanged();
+    emit idChanged(m_id);
     // NOTE: We deliberately do not reset the name or creation date to indicate that this was originally
     // an older chat that was reset for another purpose. Resetting this data will lead to the chat
     // name label changing back to 'New Chat' and showing up in the chat model list as a 'New Chat'
@@ -118,6 +127,7 @@ void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t
 {
     resetResponseState();
     emit promptRequested(
+        m_collections,
         prompt,
         prompt_template,
         n_predict,
@@ -275,6 +285,7 @@ bool Chat::isRecalc() const
 void Chat::loadDefaultModel()
 {
+    emit defaultModelChanged(modelList().first());
     m_modelLoadingError = QString();
     emit modelLoadingErrorChanged();
     emit loadDefaultModelRequested();
@@ -369,7 +380,7 @@ bool Chat::deserialize(QDataStream &stream, int version)
 {
     stream >> m_creationDate;
     stream >> m_id;
-    emit idChanged();
+    emit idChanged(m_id);
     stream >> m_name;
     stream >> m_userName;
     emit nameChanged();
@@ -380,7 +391,7 @@ bool Chat::deserialize(QDataStream &stream, int version)
         return false;
     if (version > 2) {
         stream >> m_collections;
-        emit collectionListChanged();
+        emit collectionListChanged(m_collections);
     }
     m_llmodel->setModelName(m_savedModelName);
     if (!m_llmodel->deserialize(stream, version))
@@ -475,6 +486,19 @@ QList<QString> Chat::modelList() const
     return list;
 }

+void Chat::handleModelListChanged()
+{
+    emit modelListChanged();
+    emit defaultModelChanged(modelList().first());
+}
+
+void Chat::handleDownloadLocalModelsPathChanged()
+{
+    emit modelListChanged();
+    emit defaultModelChanged(modelList().first());
+    m_watcher->addPath(Download::globalInstance()->downloadLocalModelsPath());
+}
+
 QList<QString> Chat::collectionList() const
 {
     return m_collections;
@@ -491,7 +515,7 @@ void Chat::addCollection(const QString &collection)
         return;

     m_collections.append(collection);
-    emit collectionListChanged();
+    emit collectionListChanged(m_collections);
 }

 void Chat::removeCollection(const QString &collection)
@@ -500,5 +524,5 @@ void Chat::removeCollection(const QString &collection)
         return;

     m_collections.removeAll(collection);
-    emit collectionListChanged();
+    emit collectionListChanged(m_collections);
 }
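A recurring change in this file: signals such as idChanged and collectionListChanged now carry their new value. With queued cross-thread delivery the receiver gets a snapshot taken at emit time instead of having to call back into Chat from another thread. A hedged sketch of the receiving side (hypothetical Worker class):

#include <QList>
#include <QObject>
#include <QString>

// Because the values travel inside the queued signal, a worker in another
// thread keeps its own copies and never reaches back into the Chat object
// (which would race with the GUI thread).
class Worker : public QObject {
    Q_OBJECT
public slots:
    void handleIdChanged(const QString &id) { m_id = id; }
    void handleCollectionListChanged(const QList<QString> &collections) { m_collections = collections; }
private:
    QString m_id;                 // worker-local copy, no locking required
    QList<QString> m_collections; // ditto
};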

chat.h

@@ -98,16 +98,16 @@ public Q_SLOTS:
     void serverNewPromptResponsePair(const QString &prompt);

 Q_SIGNALS:
-    void idChanged();
+    void idChanged(const QString &id);
     void nameChanged();
     void chatModelChanged();
     void isModelLoadedChanged();
     void responseChanged();
     void responseInProgressChanged();
     void responseStateChanged();
-    void promptRequested(const QString &prompt, const QString &prompt_template, int32_t n_predict,
-        int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens,
-        int32_t n_threads);
+    void promptRequested(const QList<QString> &collectionList, const QString &prompt, const QString &prompt_template,
+        int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty,
+        int32_t repeat_penalty_tokens, int32_t n_threads);
     void regenerateResponseRequested();
     void resetResponseRequested();
     void resetContextRequested();
@@ -120,8 +120,9 @@ Q_SIGNALS:
     void modelListChanged();
     void modelLoadingErrorChanged();
     void isServerChanged();
-    void collectionListChanged();
+    void collectionListChanged(const QList<QString> &collectionList);
     void tokenSpeedChanged();
+    void defaultModelChanged(const QString &defaultModel);

 private Q_SLOTS:
     void handleResponseChanged();
@@ -134,6 +135,8 @@ private Q_SLOTS:
     void handleModelLoadingError(const QString &error);
     void handleTokenSpeedChanged(const QString &tokenSpeed);
     void handleDatabaseResultsChanged(const QList<ResultInfo> &results);
+    void handleModelListChanged();
+    void handleDownloadLocalModelsPathChanged();

 private:
     QString m_id;
@@ -151,6 +154,7 @@ private:
     QList<ResultInfo> m_databaseResults;
     bool m_isServer;
     bool m_shouldDeleteLater;
+    QFileSystemWatcher *m_watcher;
 };

 #endif // CHAT_H

chatllm.cpp

@@ -94,7 +94,6 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
     , m_isRecalc(false)
     , m_shouldBeLoaded(true)
     , m_stopGenerating(false)
-    , m_chat(parent)
     , m_timer(nullptr)
     , m_isServer(isServer)
     , m_isChatGPT(false)
@@ -104,14 +103,15 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
     connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded);
     connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
         Qt::QueuedConnection); // explicitly queued
-    connect(m_chat, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
+    connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
+    connect(parent, &Chat::defaultModelChanged, this, &ChatLLM::handleDefaultModelChanged);
     connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);

     // The following are blocking operations and will block the llm thread
     connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB,
         Qt::BlockingQueuedConnection);

-    m_llmThread.setObjectName(m_chat->id());
+    m_llmThread.setObjectName(parent->id());
     m_llmThread.start();
 }
@@ -137,14 +137,11 @@ void ChatLLM::handleThreadStarted()

 bool ChatLLM::loadDefaultModel()
 {
-    const QList<QString> models = m_chat->modelList();
-    if (models.isEmpty()) {
-        // try again when we get a list of models
-        connect(Download::globalInstance(), &Download::modelListChanged, this,
-            &ChatLLM::loadDefaultModel, Qt::SingleShotConnection);
+    if (m_defaultModel.isEmpty()) {
+        emit modelLoadingError(QString("Could not find default model to load"));
         return false;
     }
-    return loadModel(models.first());
+    return loadModel(m_defaultModel);
 }

 bool ChatLLM::loadModel(const QString &modelName)
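The old loadDefaultModel reached into m_chat->modelList() from the LLM thread and silently re-armed itself with a single-shot connection; the new version consults a worker-local m_defaultModel and fails fast with a user-visible error signal. The shape of that fail-fast pattern, as a hedged sketch with hypothetical names:

#include <QObject>
#include <QString>

class Loader : public QObject {
    Q_OBJECT
public:
    bool loadDefault()
    {
        if (m_defaultModel.isEmpty()) {
            // Fail fast: surface the problem instead of silently retrying.
            emit loadingError(QStringLiteral("no default model configured"));
            return false;
        }
        return load(m_defaultModel);
    }
signals:
    void loadingError(const QString &message);
private:
    bool load(const QString &name); // hypothetical; does the real work
    QString m_defaultModel;         // kept up to date via a queued slot
};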
@@ -170,7 +167,7 @@ bool ChatLLM::loadModel(const QString &modelName)
     if (alreadyAcquired) {
         resetContext();
 #if defined(DEBUG_MODEL_LOADING)
-        qDebug() << "already acquired model deleted" << m_chat->id() << m_modelInfo.model;
+        qDebug() << "already acquired model deleted" << m_llmThread.objectName() << m_modelInfo.model;
 #endif
         delete m_modelInfo.model;
         m_modelInfo.model = nullptr;
@@ -181,14 +178,14 @@ bool ChatLLM::loadModel(const QString &modelName)
         // returned to it, then the modelInfo.model pointer should be null which will happen on startup
         m_modelInfo = LLModelStore::globalInstance()->acquireModel();
 #if defined(DEBUG_MODEL_LOADING)
-        qDebug() << "acquired model from store" << m_chat->id() << m_modelInfo.model;
+        qDebug() << "acquired model from store" << m_llmThread.objectName() << m_modelInfo.model;
 #endif
         // At this point it is possible that while we were blocked waiting to acquire the model from the
         // store, that our state was changed to not be loaded. If this is the case, release the model
         // back into the store and quit loading
         if (!m_shouldBeLoaded) {
 #if defined(DEBUG_MODEL_LOADING)
-            qDebug() << "no longer need model" << m_chat->id() << m_modelInfo.model;
+            qDebug() << "no longer need model" << m_llmThread.objectName() << m_modelInfo.model;
 #endif
             LLModelStore::globalInstance()->releaseModel(m_modelInfo);
             m_modelInfo = LLModelInfo();
@@ -199,7 +196,7 @@ bool ChatLLM::loadModel(const QString &modelName)
         // Check if the store just gave us exactly the model we were looking for
         if (m_modelInfo.model && m_modelInfo.fileInfo == fileInfo) {
 #if defined(DEBUG_MODEL_LOADING)
-            qDebug() << "store had our model" << m_chat->id() << m_modelInfo.model;
+            qDebug() << "store had our model" << m_llmThread.objectName() << m_modelInfo.model;
 #endif
             restoreState();
             emit isModelLoadedChanged();
@@ -207,7 +204,7 @@ bool ChatLLM::loadModel(const QString &modelName)
         } else {
             // Release the memory since we have to switch to a different model.
 #if defined(DEBUG_MODEL_LOADING)
-            qDebug() << "deleting model" << m_chat->id() << m_modelInfo.model;
+            qDebug() << "deleting model" << m_llmThread.objectName() << m_modelInfo.model;
 #endif
             delete m_modelInfo.model;
             m_modelInfo.model = nullptr;
@@ -239,13 +236,28 @@ bool ChatLLM::loadModel(const QString &modelName)
         } else {
             m_modelInfo.model = LLModel::construct(filePath.toStdString());
             if (m_modelInfo.model) {
-                m_modelInfo.model->loadModel(filePath.toStdString());
-                switch (m_modelInfo.model->implementation().modelType[0]) {
-                case 'L': m_modelType = LLModelType::LLAMA_; break;
-                case 'G': m_modelType = LLModelType::GPTJ_; break;
-                case 'M': m_modelType = LLModelType::MPT_; break;
-                case 'R': m_modelType = LLModelType::REPLIT_; break;
-                default: delete std::exchange(m_modelInfo.model, nullptr);
+                bool success = m_modelInfo.model->loadModel(filePath.toStdString());
+                if (!success) {
+                    delete std::exchange(m_modelInfo.model, nullptr);
+                    if (!m_isServer)
+                        LLModelStore::globalInstance()->releaseModel(m_modelInfo); // release back into the store
+                    m_modelInfo = LLModelInfo();
+                    emit modelLoadingError(QString("Could not load model due to invalid model file for %1").arg(modelName));
+                } else {
+                    switch (m_modelInfo.model->implementation().modelType[0]) {
+                    case 'L': m_modelType = LLModelType::LLAMA_; break;
+                    case 'G': m_modelType = LLModelType::GPTJ_; break;
+                    case 'M': m_modelType = LLModelType::MPT_; break;
+                    case 'R': m_modelType = LLModelType::REPLIT_; break;
+                    default:
+                        {
+                            delete std::exchange(m_modelInfo.model, nullptr);
+                            if (!m_isServer)
+                                LLModelStore::globalInstance()->releaseModel(m_modelInfo); // release back into the store
+                            m_modelInfo = LLModelInfo();
+                            emit modelLoadingError(QString("Could not determine model type for %1").arg(modelName));
+                        }
+                    }
+                }
             }
         } else {
             if (!m_isServer)
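Both new error paths free the model with delete std::exchange(m_modelInfo.model, nullptr), which deletes the object and nulls the pointer in a single expression so no dangling pointer survives the branch. A minimal illustration (Model is a hypothetical stand-in type):

#include <utility> // std::exchange

struct Model { }; // hypothetical stand-in for the real model type

void destroy(Model *&p)
{
    // std::exchange stores nullptr into p and returns the old value, so the
    // object is freed and the pointer is nulled in one expression.
    delete std::exchange(p, nullptr);
}

int main()
{
    Model *m = new Model;
    destroy(m); // m is nullptr afterwards
    destroy(m); // deleting nullptr is a no-op, so a double call is harmless
}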
@@ -255,11 +267,11 @@ bool ChatLLM::loadModel(const QString &modelName)
             }
         }
 #if defined(DEBUG_MODEL_LOADING)
-        qDebug() << "new model" << m_chat->id() << m_modelInfo.model;
+        qDebug() << "new model" << m_llmThread.objectName() << m_modelInfo.model;
 #endif
         restoreState();
 #if defined(DEBUG)
-        qDebug() << "modelLoadedChanged" << m_chat->id();
+        qDebug() << "modelLoadedChanged" << m_llmThread.objectName();
         fflush(stdout);
 #endif
         emit isModelLoadedChanged();
@@ -368,7 +380,7 @@ bool ChatLLM::handlePrompt(int32_t token)
     // m_promptResponseTokens is related to last prompt/response not
     // the entire context window which we can reset on regenerate prompt
 #if defined(DEBUG)
-    qDebug() << "prompt process" << m_chat->id() << token;
+    qDebug() << "prompt process" << m_llmThread.objectName() << token;
 #endif
     ++m_promptTokens;
     ++m_promptResponseTokens;
@@ -409,7 +421,7 @@ bool ChatLLM::handleRecalculate(bool isRecalc)
     return !m_stopGenerating;
 }

-bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k,
+bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k,
     float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, int n_threads)
 {
     if (!isModelLoaded())
@@ -417,7 +429,7 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3

     QList<ResultInfo> databaseResults;
     const int retrievalSize = LocalDocs::globalInstance()->retrievalSize();
-    emit requestRetrieveFromDB(m_chat->collectionList(), prompt, retrievalSize, &databaseResults); // blocks
+    emit requestRetrieveFromDB(collectionList, prompt, retrievalSize, &databaseResults); // blocks
     emit databaseResultsChanged(databaseResults);

     // Augment the prompt template with the results if any
@@ -468,7 +480,7 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
 void ChatLLM::setShouldBeLoaded(bool b)
 {
 #if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "setShouldBeLoaded" << m_chat->id() << b << m_modelInfo.model;
+    qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_modelInfo.model;
 #endif
     m_shouldBeLoaded = b; // atomic
     emit shouldBeLoadedChanged();
@@ -495,7 +507,7 @@ void ChatLLM::unloadModel()
     saveState();

 #if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "unloadModel" << m_chat->id() << m_modelInfo.model;
+    qDebug() << "unloadModel" << m_llmThread.objectName() << m_modelInfo.model;
 #endif
     LLModelStore::globalInstance()->releaseModel(m_modelInfo);
     m_modelInfo = LLModelInfo();
@@ -508,7 +520,7 @@ void ChatLLM::reloadModel()
         return;

 #if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "reloadModel" << m_chat->id() << m_modelInfo.model;
+    qDebug() << "reloadModel" << m_llmThread.objectName() << m_modelInfo.model;
 #endif
     if (m_modelName.isEmpty()) {
         loadDefaultModel();
@@ -547,9 +559,14 @@ void ChatLLM::generateName()
     }
 }

-void ChatLLM::handleChatIdChanged()
+void ChatLLM::handleChatIdChanged(const QString &id)
 {
-    m_llmThread.setObjectName(m_chat->id());
+    m_llmThread.setObjectName(id);
+}
+
+void ChatLLM::handleDefaultModelChanged(const QString &defaultModel)
+{
+    m_defaultModel = defaultModel;
 }

 bool ChatLLM::handleNamePrompt(int32_t token)
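Together with the removal of the m_chat member, these two slots complete the decoupling: the worker no longer holds a pointer it might dereference from the wrong thread, only copies that arrive through signal delivery. The chat id doubles as the worker thread's objectName, which is why all of the debug logging above switched from m_chat->id() to m_llmThread.objectName(). A sketch of the pattern with hypothetical names:

#include <QObject>
#include <QString>
#include <QThread>

// Hypothetical worker: instead of holding a pointer to the GUI-side object,
// it mirrors the few values it needs, updated only via slots.
class LlmWorker : public QObject {
    Q_OBJECT
public slots:
    void handleChatIdChanged(const QString &id) { m_thread.setObjectName(id); }
    void handleDefaultModelChanged(const QString &model) { m_defaultModel = model; }
private:
    QString m_defaultModel; // worker-local copy
    QThread m_thread;       // its objectName doubles as the chat id for logs
};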
@@ -605,7 +622,7 @@ bool ChatLLM::serialize(QDataStream &stream, int version)
     QByteArray compressed = qCompress(m_state);
     stream << compressed;
 #if defined(DEBUG)
-    qDebug() << "serialize" << m_chat->id() << m_state.size();
+    qDebug() << "serialize" << m_llmThread.objectName() << m_state.size();
 #endif
     return stream.status() == QDataStream::Ok;
 }
@@ -645,7 +662,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version)
         stream >> m_state;
     }
 #if defined(DEBUG)
-    qDebug() << "deserialize" << m_chat->id();
+    qDebug() << "deserialize" << m_llmThread.objectName();
 #endif
     return stream.status() == QDataStream::Ok;
 }
@@ -667,7 +684,7 @@ void ChatLLM::saveState()
     const size_t stateSize = m_modelInfo.model->stateSize();
     m_state.resize(stateSize);
 #if defined(DEBUG)
-    qDebug() << "saveState" << m_chat->id() << "size:" << m_state.size();
+    qDebug() << "saveState" << m_llmThread.objectName() << "size:" << m_state.size();
 #endif
     m_modelInfo.model->saveState(static_cast<uint8_t*>(reinterpret_cast<void*>(m_state.data())));
 }
@@ -690,7 +707,7 @@ void ChatLLM::restoreState()
     }
 #if defined(DEBUG)
-    qDebug() << "restoreState" << m_chat->id() << "size:" << m_state.size();
+    qDebug() << "restoreState" << m_llmThread.objectName() << "size:" << m_state.size();
 #endif
     m_modelInfo.model->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
     m_state.clear();

chatllm.h

@@ -100,9 +100,9 @@ public:
     bool deserialize(QDataStream &stream, int version);

 public Q_SLOTS:
-    bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict,
-        int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens,
-        int32_t n_threads);
+    bool prompt(const QList<QString> &collectionList, const QString &prompt, const QString &prompt_template,
+        int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty,
+        int32_t repeat_penalty_tokens, int32_t n_threads);
     bool loadDefaultModel();
     bool loadModel(const QString &modelName);
     void modelNameChangeRequested(const QString &modelName);
@@ -110,7 +110,8 @@ public Q_SLOTS:
     void unloadModel();
     void reloadModel();
     void generateName();
-    void handleChatIdChanged();
+    void handleChatIdChanged(const QString &id);
+    void handleDefaultModelChanged(const QString &defaultModel);
     void handleShouldBeLoadedChanged();
     void handleThreadStarted();
@@ -143,15 +144,24 @@ protected:
     void restoreState();

 protected:
+    // The following are all accessed by multiple threads and are thus guarded with thread protection
+    // mechanisms
     LLModel::PromptContext m_ctx;
     quint32 m_promptTokens;
     quint32 m_promptResponseTokens;
-    LLModelInfo m_modelInfo;
-    LLModelType m_modelType;
+
+private:
+    // The following are all accessed by multiple threads and are thus guarded with thread protection
+    // mechanisms
     std::string m_response;
     std::string m_nameResponse;
+    LLModelInfo m_modelInfo;
+    LLModelType m_modelType;
     QString m_modelName;
-    Chat *m_chat;
+    bool m_isChatGPT;
+
+    // The following are only accessed by this thread
+    QString m_defaultModel;
     TokenTimer *m_timer;
     QByteArray m_state;
     QThread m_llmThread;
@@ -159,7 +169,6 @@ protected:
     std::atomic<bool> m_shouldBeLoaded;
     std::atomic<bool> m_isRecalc;
     bool m_isServer;
-    bool m_isChatGPT;
 };

 #endif // CHATLLM_H

server.cpp

@@ -72,6 +72,7 @@ Server::Server(Chat *chat)
 {
     connect(this, &Server::threadStarted, this, &Server::start);
     connect(this, &Server::databaseResultsChanged, this, &Server::handleDatabaseResultsChanged);
+    connect(chat, &Chat::collectionListChanged, this, &Server::handleCollectionListChanged, Qt::QueuedConnection);
 }

 Server::~Server()
@@ -315,7 +316,9 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
     int responseTokens = 0;
     QList<QPair<QString, QList<ResultInfo>>> responses;
     for (int i = 0; i < n; ++i) {
-        if (!prompt(actualPrompt,
+        if (!prompt(
+            m_collections,
+            actualPrompt,
             promptTemplate,
             max_tokens /*n_predict*/,
             top_k,

server.h

@@ -23,11 +23,13 @@ Q_SIGNALS:
 private Q_SLOTS:
     QHttpServerResponse handleCompletionRequest(const QHttpServerRequest &request, bool isChat);
     void handleDatabaseResultsChanged(const QList<ResultInfo> &results) { m_databaseResults = results; }
+    void handleCollectionListChanged(const QList<QString> &collectionList) { m_collections = collectionList; }

 private:
     Chat *m_chat;
     QHttpServer *m_server;
     QList<ResultInfo> m_databaseResults;
+    QList<QString> m_collections;
 };

 #endif // SERVER_H
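Server follows the same recipe as ChatLLM: it subscribes to collectionListChanged with Qt::QueuedConnection and keeps its own m_collections copy, so handleCompletionRequest can pass the list straight to prompt() without touching Chat from the server thread. Queued delivery is what makes the slot run in the receiver's thread; a runnable sketch of that thread hop (hypothetical Echo class, not part of the project):

#include <QCoreApplication>
#include <QDebug>
#include <QObject>
#include <QThread>
#include <QTimer>

class Echo : public QObject {
    Q_OBJECT
public slots:
    void report() { qDebug() << "slot thread:" << QThread::currentThread(); }
};

int main(int argc, char *argv[])
{
    QCoreApplication app(argc, argv);
    qDebug() << "main thread:" << QThread::currentThread();

    QThread worker;
    Echo echo;
    echo.moveToThread(&worker); // Echo's slots now execute in `worker`
    worker.start();

    // Queued: the call is posted to worker's event loop, so report()
    // executes there; the two printed thread pointers differ.
    QMetaObject::invokeMethod(&echo, &Echo::report, Qt::QueuedConnection);

    QTimer::singleShot(100, &app, [&] { worker.quit(); worker.wait(); app.quit(); });
    return app.exec();
}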