From 7d2ce060291b002b3e6d033d362f0cc5864b8f87 Mon Sep 17 00:00:00 2001
From: Adam Treat
Date: Mon, 19 Jun 2023 19:51:28 -0400
Subject: [PATCH] Start working on more thread safety and model load error handling.

---
 gpt4all-chat/chat.cpp    | 38 ++++++++++++++----
 gpt4all-chat/chat.h      | 14 ++++---
 gpt4all-chat/chatllm.cpp | 87 ++++++++++++++++++++++++----------------
 gpt4all-chat/chatllm.h   | 25 ++++++++----
 gpt4all-chat/server.cpp  |  5 ++-
 gpt4all-chat/server.h    |  2 +
 6 files changed, 115 insertions(+), 56 deletions(-)

diff --git a/gpt4all-chat/chat.cpp b/gpt4all-chat/chat.cpp
index 9cbeccca..cde4c99c 100644
--- a/gpt4all-chat/chat.cpp
+++ b/gpt4all-chat/chat.cpp
@@ -43,9 +43,16 @@ Chat::~Chat()
 
 void Chat::connectLLM()
 {
+    const QString exePath = QCoreApplication::applicationDirPath() + QDir::separator();
+    const QString localPath = Download::globalInstance()->downloadLocalModelsPath();
+    m_watcher = new QFileSystemWatcher(this);
+    m_watcher->addPath(exePath);
+    m_watcher->addPath(localPath);
+
     // Should be in same thread
-    connect(Download::globalInstance(), &Download::modelListChanged, this, &Chat::modelListChanged, Qt::DirectConnection);
-    connect(this, &Chat::modelNameChanged, this, &Chat::modelListChanged, Qt::DirectConnection);
+    connect(Download::globalInstance(), &Download::modelListChanged, this, &Chat::handleModelListChanged, Qt::DirectConnection);
+    connect(this, &Chat::modelNameChanged, this, &Chat::handleModelListChanged, Qt::DirectConnection);
+    connect(m_watcher, &QFileSystemWatcher::directoryChanged, this, &Chat::handleModelListChanged);
 
     // Should be in different threads
     connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, Qt::QueuedConnection);
@@ -71,6 +78,8 @@ void Chat::connectLLM()
     // to respond to
     connect(this, &Chat::resetResponseRequested, m_llmodel, &ChatLLM::resetResponse, Qt::BlockingQueuedConnection);
     connect(this, &Chat::resetContextRequested, m_llmodel, &ChatLLM::resetContext, Qt::BlockingQueuedConnection);
+
+    emit defaultModelChanged(modelList().first());
 }
 
 void Chat::reset()
@@ -80,7 +89,7 @@ void Chat::reset()
     LLM::globalInstance()->chatListModel()->removeChatFile(this);
     emit resetContextRequested(); // blocking queued connection
     m_id = Network::globalInstance()->generateUniqueId();
-    emit idChanged();
+    emit idChanged(m_id);
     // NOTE: We deliberately do no reset the name or creation date to indictate that this was originally
     // an older chat that was reset for another purpose. Resetting this data will lead to the chat
     // name label changing back to 'New Chat' and showing up in the chat model list as a 'New Chat'
@@ -118,6 +127,7 @@ void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t
 {
     resetResponseState();
     emit promptRequested(
+        m_collections,
         prompt,
         prompt_template,
         n_predict,
@@ -275,6 +285,7 @@ bool Chat::isRecalc() const
 
 void Chat::loadDefaultModel()
 {
+    emit defaultModelChanged(modelList().first());
     m_modelLoadingError = QString();
     emit modelLoadingErrorChanged();
     emit loadDefaultModelRequested();
@@ -369,7 +380,7 @@ bool Chat::deserialize(QDataStream &stream, int version)
 {
     stream >> m_creationDate;
     stream >> m_id;
-    emit idChanged();
+    emit idChanged(m_id);
     stream >> m_name;
     stream >> m_userName;
     emit nameChanged();
@@ -380,7 +391,7 @@ bool Chat::deserialize(QDataStream &stream, int version)
         return false;
     if (version > 2) {
         stream >> m_collections;
-        emit collectionListChanged();
+        emit collectionListChanged(m_collections);
     }
     m_llmodel->setModelName(m_savedModelName);
     if (!m_llmodel->deserialize(stream, version))
@@ -475,6 +486,19 @@ QList Chat::modelList() const
     return list;
 }
 
+void Chat::handleModelListChanged()
+{
+    emit modelListChanged();
+    emit defaultModelChanged(modelList().first());
+}
+
+void Chat::handleDownloadLocalModelsPathChanged()
+{
+    emit modelListChanged();
+    emit defaultModelChanged(modelList().first());
+    m_watcher->addPath(Download::globalInstance()->downloadLocalModelsPath());
+}
+
 QList Chat::collectionList() const
 {
     return m_collections;
@@ -491,7 +515,7 @@ void Chat::addCollection(const QString &collection)
         return;
 
     m_collections.append(collection);
-    emit collectionListChanged();
+    emit collectionListChanged(m_collections);
 }
 
 void Chat::removeCollection(const QString &collection)
@@ -500,5 +524,5 @@ void Chat::removeCollection(const QString &collection)
         return;
 
     m_collections.removeAll(collection);
-    emit collectionListChanged();
+    emit collectionListChanged(m_collections);
 }
diff --git a/gpt4all-chat/chat.h b/gpt4all-chat/chat.h
index dc514274..6067ac01 100644
--- a/gpt4all-chat/chat.h
+++ b/gpt4all-chat/chat.h
@@ -98,16 +98,16 @@ public Q_SLOTS:
     void serverNewPromptResponsePair(const QString &prompt);
 
 Q_SIGNALS:
-    void idChanged();
+    void idChanged(const QString &id);
     void nameChanged();
     void chatModelChanged();
     void isModelLoadedChanged();
     void responseChanged();
     void responseInProgressChanged();
     void responseStateChanged();
-    void promptRequested(const QString &prompt, const QString &prompt_template, int32_t n_predict,
-        int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens,
-        int32_t n_threads);
+    void promptRequested(const QList &collectionList, const QString &prompt, const QString &prompt_template,
+        int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty,
+        int32_t repeat_penalty_tokens, int32_t n_threads);
     void regenerateResponseRequested();
     void resetResponseRequested();
     void resetContextRequested();
@@ -120,8 +120,9 @@ Q_SIGNALS:
     void modelListChanged();
     void modelLoadingErrorChanged();
     void isServerChanged();
-    void collectionListChanged();
+    void collectionListChanged(const QList &collectionList);
     void tokenSpeedChanged();
+    void defaultModelChanged(const QString &defaultModel);
 
 private Q_SLOTS:
     void handleResponseChanged();
@@ -134,6 +135,8 @@ private Q_SLOTS:
     void handleModelLoadingError(const QString &error);
     void handleTokenSpeedChanged(const QString &tokenSpeed);
     void handleDatabaseResultsChanged(const QList &results);
+    void handleModelListChanged();
+    void handleDownloadLocalModelsPathChanged();
 
 private:
     QString m_id;
@@ -151,6 +154,7 @@ private:
     QList m_databaseResults;
     bool m_isServer;
     bool m_shouldDeleteLater;
+    QFileSystemWatcher *m_watcher;
 };
 
 #endif // CHAT_H
diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp
index cbe8c1bf..d3a400d5 100644
--- a/gpt4all-chat/chatllm.cpp
+++ b/gpt4all-chat/chatllm.cpp
@@ -94,7 +94,6 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
     , m_isRecalc(false)
     , m_shouldBeLoaded(true)
     , m_stopGenerating(false)
-    , m_chat(parent)
     , m_timer(nullptr)
     , m_isServer(isServer)
     , m_isChatGPT(false)
@@ -104,14 +103,15 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
 
     connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded);
     connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
         Qt::QueuedConnection); // explicitly queued
-    connect(m_chat, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
+    connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
+    connect(parent, &Chat::defaultModelChanged, this, &ChatLLM::handleDefaultModelChanged);
     connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
 
     // The following are blocking operations and will block the llm thread
     connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB,
         Qt::BlockingQueuedConnection);
 
-    m_llmThread.setObjectName(m_chat->id());
+    m_llmThread.setObjectName(parent->id());
     m_llmThread.start();
 }
@@ -137,14 +137,11 @@ void ChatLLM::handleThreadStarted()
 
 bool ChatLLM::loadDefaultModel()
 {
-    const QList models = m_chat->modelList();
-    if (models.isEmpty()) {
-        // try again when we get a list of models
-        connect(Download::globalInstance(), &Download::modelListChanged, this,
-            &ChatLLM::loadDefaultModel, Qt::SingleShotConnection);
+    if (m_defaultModel.isEmpty()) {
+        emit modelLoadingError(QString("Could not find default model to load"));
         return false;
     }
-    return loadModel(models.first());
+    return loadModel(m_defaultModel);
 }
 
 bool ChatLLM::loadModel(const QString &modelName)
@@ -170,7 +167,7 @@ bool ChatLLM::loadModel(const QString &modelName)
     if (alreadyAcquired) {
         resetContext();
 #if defined(DEBUG_MODEL_LOADING)
-        qDebug() << "already acquired model deleted" << m_chat->id() << m_modelInfo.model;
+        qDebug() << "already acquired model deleted" << m_llmThread.objectName() << m_modelInfo.model;
 #endif
         delete m_modelInfo.model;
         m_modelInfo.model = nullptr;
@@ -181,14 +178,14 @@ bool ChatLLM::loadModel(const QString &modelName)
         // returned to it, then the modelInfo.model pointer should be null which will happen on startup
         m_modelInfo = LLModelStore::globalInstance()->acquireModel();
 #if defined(DEBUG_MODEL_LOADING)
-        qDebug() << "acquired model from store" << m_chat->id() << m_modelInfo.model;
+        qDebug() << "acquired model from store" << m_llmThread.objectName() << m_modelInfo.model;
 #endif
         // At this point it is possible that while we were blocked waiting to acquire the model from the
         // store, that our state was changed to not be loaded. If this is the case, release the model
         // back into the store and quit loading
         if (!m_shouldBeLoaded) {
 #if defined(DEBUG_MODEL_LOADING)
-            qDebug() << "no longer need model" << m_chat->id() << m_modelInfo.model;
+            qDebug() << "no longer need model" << m_llmThread.objectName() << m_modelInfo.model;
 #endif
             LLModelStore::globalInstance()->releaseModel(m_modelInfo);
             m_modelInfo = LLModelInfo();
@@ -199,7 +196,7 @@ bool ChatLLM::loadModel(const QString &modelName)
         // Check if the store just gave us exactly the model we were looking for
         if (m_modelInfo.model && m_modelInfo.fileInfo == fileInfo) {
 #if defined(DEBUG_MODEL_LOADING)
-            qDebug() << "store had our model" << m_chat->id() << m_modelInfo.model;
+            qDebug() << "store had our model" << m_llmThread.objectName() << m_modelInfo.model;
 #endif
             restoreState();
             emit isModelLoadedChanged();
@@ -207,7 +204,7 @@ bool ChatLLM::loadModel(const QString &modelName)
         } else {
             // Release the memory since we have to switch to a different model.
 #if defined(DEBUG_MODEL_LOADING)
-            qDebug() << "deleting model" << m_chat->id() << m_modelInfo.model;
+            qDebug() << "deleting model" << m_llmThread.objectName() << m_modelInfo.model;
 #endif
             delete m_modelInfo.model;
             m_modelInfo.model = nullptr;
@@ -239,13 +236,28 @@ bool ChatLLM::loadModel(const QString &modelName)
         } else {
             m_modelInfo.model = LLModel::construct(filePath.toStdString());
             if (m_modelInfo.model) {
-                m_modelInfo.model->loadModel(filePath.toStdString());
-                switch (m_modelInfo.model->implementation().modelType[0]) {
-                case 'L': m_modelType = LLModelType::LLAMA_; break;
-                case 'G': m_modelType = LLModelType::GPTJ_; break;
-                case 'M': m_modelType = LLModelType::MPT_; break;
-                case 'R': m_modelType = LLModelType::REPLIT_; break;
-                default: delete std::exchange(m_modelInfo.model, nullptr);
+                bool success = m_modelInfo.model->loadModel(filePath.toStdString());
+                if (!success) {
+                    delete std::exchange(m_modelInfo.model, nullptr);
+                    if (!m_isServer)
+                        LLModelStore::globalInstance()->releaseModel(m_modelInfo); // release back into the store
+                    m_modelInfo = LLModelInfo();
+                    emit modelLoadingError(QString("Could not load model due to invalid model file for %1").arg(modelName));
+                } else {
+                    switch (m_modelInfo.model->implementation().modelType[0]) {
+                    case 'L': m_modelType = LLModelType::LLAMA_; break;
+                    case 'G': m_modelType = LLModelType::GPTJ_; break;
+                    case 'M': m_modelType = LLModelType::MPT_; break;
+                    case 'R': m_modelType = LLModelType::REPLIT_; break;
+                    default:
+                        {
+                            delete std::exchange(m_modelInfo.model, nullptr);
+                            if (!m_isServer)
+                                LLModelStore::globalInstance()->releaseModel(m_modelInfo); // release back into the store
+                            m_modelInfo = LLModelInfo();
+                            emit modelLoadingError(QString("Could not determine model type for %1").arg(modelName));
+                        }
+                    }
+                }
             } else {
                 if (!m_isServer)
@@ -255,11 +267,11 @@ bool ChatLLM::loadModel(const QString &modelName)
             }
         }
 #if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "new model" << m_chat->id() << m_modelInfo.model;
+    qDebug() << "new model" << m_llmThread.objectName() << m_modelInfo.model;
 #endif
     restoreState();
 #if defined(DEBUG)
-    qDebug() << "modelLoadedChanged" << m_chat->id();
+    qDebug() << "modelLoadedChanged" << m_llmThread.objectName();
     fflush(stdout);
 #endif
     emit isModelLoadedChanged();
@@ -368,7 +380,7 @@ bool ChatLLM::handlePrompt(int32_t token)
     // m_promptResponseTokens is related to last prompt/response not
     // the entire context window which we can reset on regenerate prompt
 #if defined(DEBUG)
-    qDebug() << "prompt process" << m_chat->id() << token;
+    qDebug() << "prompt process" << m_llmThread.objectName() << token;
 #endif
     ++m_promptTokens;
     ++m_promptResponseTokens;
@@ -409,7 +421,7 @@ bool ChatLLM::handleRecalculate(bool isRecalc)
     return !m_stopGenerating;
 }
 
-bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k,
+bool ChatLLM::prompt(const QList &collectionList, const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k,
     float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, int n_threads)
 {
     if (!isModelLoaded())
@@ -417,7 +429,7 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
 
     QList databaseResults;
     const int retrievalSize = LocalDocs::globalInstance()->retrievalSize();
-    emit requestRetrieveFromDB(m_chat->collectionList(), prompt, retrievalSize, &databaseResults); // blocks
+    emit requestRetrieveFromDB(collectionList, prompt, retrievalSize, &databaseResults); // blocks
     emit databaseResultsChanged(databaseResults);
 
     // Augment the prompt template with the results if any
@@ -468,7 +480,7 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
 void ChatLLM::setShouldBeLoaded(bool b)
 {
 #if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "setShouldBeLoaded" << m_chat->id() << b << m_modelInfo.model;
+    qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_modelInfo.model;
 #endif
     m_shouldBeLoaded = b; // atomic
     emit shouldBeLoadedChanged();
@@ -495,7 +507,7 @@ void ChatLLM::unloadModel()
 
     saveState();
 #if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "unloadModel" << m_chat->id() << m_modelInfo.model;
+    qDebug() << "unloadModel" << m_llmThread.objectName() << m_modelInfo.model;
 #endif
     LLModelStore::globalInstance()->releaseModel(m_modelInfo);
     m_modelInfo = LLModelInfo();
@@ -508,7 +520,7 @@ void ChatLLM::reloadModel()
         return;
 
 #if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "reloadModel" << m_chat->id() << m_modelInfo.model;
+    qDebug() << "reloadModel" << m_llmThread.objectName() << m_modelInfo.model;
 #endif
     if (m_modelName.isEmpty()) {
         loadDefaultModel();
@@ -547,9 +559,14 @@ void ChatLLM::generateName()
     }
 }
 
-void ChatLLM::handleChatIdChanged()
+void ChatLLM::handleChatIdChanged(const QString &id)
 {
-    m_llmThread.setObjectName(m_chat->id());
+    m_llmThread.setObjectName(id);
+}
+
+void ChatLLM::handleDefaultModelChanged(const QString &defaultModel)
+{
+    m_defaultModel = defaultModel;
 }
 
 bool ChatLLM::handleNamePrompt(int32_t token)
@@ -605,7 +622,7 @@ bool ChatLLM::serialize(QDataStream &stream, int version)
     QByteArray compressed = qCompress(m_state);
     stream << compressed;
 #if defined(DEBUG)
-    qDebug() << "serialize" << m_chat->id() << m_state.size();
+    qDebug() << "serialize" << m_llmThread.objectName() << m_state.size();
 #endif
     return stream.status() == QDataStream::Ok;
 }
@@ -645,7 +662,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version)
         stream >> m_state;
     }
 #if defined(DEBUG)
-    qDebug() << "deserialize" << m_chat->id();
+    qDebug() << "deserialize" << m_llmThread.objectName();
 #endif
     return stream.status() == QDataStream::Ok;
 }
@@ -667,7 +684,7 @@ void ChatLLM::saveState()
     const size_t stateSize = m_modelInfo.model->stateSize();
     m_state.resize(stateSize);
 #if defined(DEBUG)
-    qDebug() << "saveState" << m_chat->id() << "size:" << m_state.size();
+    qDebug() << "saveState" << m_llmThread.objectName() << "size:" << m_state.size();
 #endif
     m_modelInfo.model->saveState(static_cast(reinterpret_cast(m_state.data())));
 }
@@ -690,7 +707,7 @@ void ChatLLM::restoreState()
     }
 
 #if defined(DEBUG)
-    qDebug() << "restoreState" << m_chat->id() << "size:" << m_state.size();
+    qDebug() << "restoreState" << m_llmThread.objectName() << "size:" << m_state.size();
 #endif
     m_modelInfo.model->restoreState(static_cast(reinterpret_cast(m_state.data())));
     m_state.clear();
diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h
index 736d29e4..ad13677d 100644
--- a/gpt4all-chat/chatllm.h
+++ b/gpt4all-chat/chatllm.h
@@ -100,9 +100,9 @@ public:
     bool deserialize(QDataStream &stream, int version);
 
 public Q_SLOTS:
-    bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict,
-        int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens,
-        int32_t n_threads);
+    bool prompt(const QList &collectionList, const QString &prompt, const QString &prompt_template,
+        int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty,
+        int32_t repeat_penalty_tokens, int32_t n_threads);
     bool loadDefaultModel();
     bool loadModel(const QString &modelName);
     void modelNameChangeRequested(const QString &modelName);
@@ -110,7 +110,8 @@ public Q_SLOTS:
     void unloadModel();
     void reloadModel();
     void generateName();
-    void handleChatIdChanged();
+    void handleChatIdChanged(const QString &id);
+    void handleDefaultModelChanged(const QString &defaultModel);
     void handleShouldBeLoadedChanged();
     void handleThreadStarted();
 
@@ -143,15 +144,24 @@ protected:
     void restoreState();
 
 protected:
+    // The following are all accessed by multiple threads and are thus guarded with thread protection
+    // mechanisms
     LLModel::PromptContext m_ctx;
     quint32 m_promptTokens;
     quint32 m_promptResponseTokens;
-    LLModelInfo m_modelInfo;
-    LLModelType m_modelType;
+
+private:
+    // The following are all accessed by multiple threads and are thus guarded with thread protection
+    // mechanisms
     std::string m_response;
     std::string m_nameResponse;
+    LLModelInfo m_modelInfo;
+    LLModelType m_modelType;
     QString m_modelName;
-    Chat *m_chat;
+    bool m_isChatGPT;
+
+    // The following are only accessed by this thread
+    QString m_defaultModel;
     TokenTimer *m_timer;
     QByteArray m_state;
     QThread m_llmThread;
@@ -159,7 +169,6 @@ protected:
     std::atomic m_shouldBeLoaded;
     std::atomic m_isRecalc;
     bool m_isServer;
-    bool m_isChatGPT;
 };
 
 #endif // CHATLLM_H
diff --git a/gpt4all-chat/server.cpp b/gpt4all-chat/server.cpp
index 868a86e9..35ca8b82 100644
--- a/gpt4all-chat/server.cpp
+++ b/gpt4all-chat/server.cpp
@@ -72,6 +72,7 @@ Server::Server(Chat *chat)
 {
     connect(this, &Server::threadStarted, this, &Server::start);
     connect(this, &Server::databaseResultsChanged, this, &Server::handleDatabaseResultsChanged);
+    connect(chat, &Chat::collectionListChanged, this, &Server::handleCollectionListChanged, Qt::QueuedConnection);
 }
 
 Server::~Server()
@@ -315,7 +316,9 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
     int responseTokens = 0;
     QList>> responses;
     for (int i = 0; i < n; ++i) {
-        if (!prompt(actualPrompt,
+        if (!prompt(
+                m_collections,
+                actualPrompt,
                 promptTemplate,
                 max_tokens /*n_predict*/,
                 top_k,
diff --git a/gpt4all-chat/server.h b/gpt4all-chat/server.h
index ac6f1f75..0250d9bd 100644
--- a/gpt4all-chat/server.h
+++ b/gpt4all-chat/server.h
@@ -23,11 +23,13 @@ Q_SIGNALS:
 private Q_SLOTS:
     QHttpServerResponse handleCompletionRequest(const QHttpServerRequest &request, bool isChat);
     void handleDatabaseResultsChanged(const QList &results) { m_databaseResults = results; }
+    void handleCollectionListChanged(const QList &collectionList) { m_collections = collectionList; }
 
 private:
     Chat *m_chat;
     QHttpServer *m_server;
     QList m_databaseResults;
+    QList m_collections;
 };
 
 #endif // SERVER_H
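
The patterns this patch relies on are easier to see in isolation, so a few stand-alone sketches follow; none of them are part of the diff above. First, the model-list watcher added to Chat::connectLLM(): a QFileSystemWatcher is pointed at the application directory and the local download directory, and its directoryChanged signal is routed into handleModelListChanged(), which re-emits modelListChanged() and defaultModelChanged(modelList().first()) so the UI and ChatLLM pick up models that appear or disappear on disk. The sketch below shows the same QFileSystemWatcher wiring on its own; the watched directory, the "ggml-*.bin" name filter, and the rescan helper are assumptions for the example, not code from the patch.

    // modelwatch.cpp -- minimal sketch; build against QtCore only.
    #include <QCoreApplication>
    #include <QDebug>
    #include <QDir>
    #include <QFileSystemWatcher>
    #include <QStringList>

    int main(int argc, char *argv[])
    {
        QCoreApplication app(argc, argv);

        // Watch a single directory given on the command line (the patch watches the
        // application directory and the configured download directory instead).
        const QString modelDir = argc > 1 ? QString::fromLocal8Bit(argv[1]) : QDir::currentPath();

        // Hypothetical name filter; the real model list is assembled elsewhere in the app.
        auto rescan = [modelDir]() {
            return QDir(modelDir).entryList({QStringLiteral("ggml-*.bin")}, QDir::Files, QDir::Name);
        };
        qDebug() << "initial model list:" << rescan();

        QFileSystemWatcher watcher;
        watcher.addPath(modelDir);

        // directoryChanged fires when entries are added to or removed from the watched
        // directory, which is what lets the UI notice newly downloaded or deleted models.
        QObject::connect(&watcher, &QFileSystemWatcher::directoryChanged,
                         [rescan](const QString &path) {
            qDebug() << "directory changed:" << path << "-> model list:" << rescan();
        });

        return app.exec();
    }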
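
Second, the signal signatures: idChanged, collectionListChanged, and the new defaultModelChanged now carry their values as arguments. ChatLLM and Server live on other threads, and with a queued connection each receiver gets its own copy of the payload and caches it in a member that only its thread touches (m_defaultModel in ChatLLM, m_collections in Server), instead of calling back into the Chat object from the wrong thread the way m_chat->id() and m_chat->collectionList() used to. The sketch below shows that pattern with the same slot and member names as the Server change; the build setup (a single payload.cpp compiled with CMake AUTOMOC) and the timer-driven shutdown are assumptions for the example.

    // payload.cpp -- sketch of "carry the data in the signal"; assumes CMake AUTOMOC, QtCore only.
    #include <QCoreApplication>
    #include <QDebug>
    #include <QList>
    #include <QObject>
    #include <QString>
    #include <QThread>
    #include <QTimer>

    class Chat : public QObject
    {
        Q_OBJECT
    signals:
        // The signal carries the list, so receivers in other threads get their own copy.
        void collectionListChanged(const QList<QString> &collectionList);
    };

    class Server : public QObject
    {
        Q_OBJECT
    public slots:
        void handleCollectionListChanged(const QList<QString> &collectionList)
        {
            // Runs in the Server's own thread; m_collections is only ever touched here.
            m_collections = collectionList;
            qDebug() << "server thread" << QThread::currentThread() << "cached" << m_collections;
        }
    private:
        QList<QString> m_collections;
    };

    int main(int argc, char *argv[])
    {
        QCoreApplication app(argc, argv);

        Chat chat;
        Server server;
        QThread serverThread;
        server.moveToThread(&serverThread);
        serverThread.start();

        // Queued connection: the argument is copied into an event queued for the server thread.
        QObject::connect(&chat, &Chat::collectionListChanged,
                         &server, &Server::handleCollectionListChanged, Qt::QueuedConnection);

        emit chat.collectionListChanged({QStringLiteral("localdocs"), QStringLiteral("notes")});

        // Let the queued call run, then shut down.
        QTimer::singleShot(100, [&] { serverThread.quit(); serverThread.wait(); app.quit(); });
        return app.exec();
    }

    #include "payload.moc"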
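
Third, the load-error path in ChatLLM::loadModel(): the return value of loadModel() is now checked, and on failure the half-constructed model is deleted via std::exchange, the slot is released back into the LLModelStore, and the failure is surfaced through the new modelLoadingError signal; the unrecognized-model-type default case gets the same treatment. The sketch below reproduces only that control flow in plain C++; the Model struct and its methods are stand-ins, not the real LLModel API.

    // loaderror.cpp -- control-flow sketch only; "Model" is a stand-in for LLModel.
    #include <iostream>
    #include <memory>
    #include <string>

    struct Model {
        // Stand-in: pretend only ".bin" files are valid model files.
        bool loadModel(const std::string &path)
        {
            return path.size() >= 4 && path.compare(path.size() - 4, 4, ".bin") == 0;
        }
        char modelType() const { return 'L'; }   // 'L', 'G', 'M' or 'R' in the patch's switch
    };

    // Returns nullptr and fills `error` on failure, mirroring emit modelLoadingError(...).
    std::unique_ptr<Model> loadOrReport(const std::string &path, std::string &error)
    {
        auto model = std::make_unique<Model>();
        if (!model->loadModel(path)) {
            // The patch's equivalent: delete std::exchange(m_modelInfo.model, nullptr);
            // plus releasing the LLModelStore slot before reporting the error.
            error = "Could not load model due to invalid model file for " + path;
            return nullptr;
        }
        switch (model->modelType()) {
        case 'L': case 'G': case 'M': case 'R':
            return model;
        default:
            error = "Could not determine model type for " + path;
            return nullptr;
        }
    }

    int main()
    {
        for (const std::string path : {"ggml-model.bin", "not-a-model.txt"}) {
            std::string error;
            if (auto model = loadOrReport(path, error))
                std::cout << path << ": loaded\n";
            else
                std::cout << path << ": " << error << "\n";
        }
        return 0;
    }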
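
Finally, the member regrouping in chatllm.h annotates which members are shared across threads and which belong to the LLM thread alone; flags such as m_shouldBeLoaded and m_isRecalc stay std::atomic, and handleRecalculate() returns !m_stopGenerating so a long recalculation or generation stops as soon as another thread flips the flag. The plain C++ sketch below shows that stop-flag pattern by itself; the token loop and timings are invented for the example.

    // stopflag.cpp -- std::atomic<bool> stop flag shared between a GUI-like thread and a worker.
    #include <atomic>
    #include <chrono>
    #include <iostream>
    #include <thread>

    int main()
    {
        std::atomic<bool> stopGenerating{false};

        std::thread worker([&stopGenerating] {
            int tokens = 0;
            // Checked between tokens, like the callbacks that return !m_stopGenerating
            // so the model loop can abort early.
            while (!stopGenerating.load()) {
                ++tokens;
                std::this_thread::sleep_for(std::chrono::milliseconds(10));
            }
            std::cout << "stopped after " << tokens << " tokens\n";
        });

        // The "GUI" thread presses stop a little later.
        std::this_thread::sleep_for(std::chrono::milliseconds(100));
        stopGenerating.store(true);
        worker.join();
        return 0;
    }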