Start working on more thread safety and model load error handling.

Adam Treat 2023-06-19 19:51:28 -04:00 committed by AT
parent d5f56d3308
commit 7d2ce06029
6 changed files with 115 additions and 56 deletions

gpt4all-chat/chat.cpp

@ -43,9 +43,16 @@ Chat::~Chat()
void Chat::connectLLM()
{
const QString exePath = QCoreApplication::applicationDirPath() + QDir::separator();
+ const QString localPath = Download::globalInstance()->downloadLocalModelsPath();
+ m_watcher = new QFileSystemWatcher(this);
+ m_watcher->addPath(exePath);
+ m_watcher->addPath(localPath);
// Should be in same thread
- connect(Download::globalInstance(), &Download::modelListChanged, this, &Chat::modelListChanged, Qt::DirectConnection);
- connect(this, &Chat::modelNameChanged, this, &Chat::modelListChanged, Qt::DirectConnection);
+ connect(Download::globalInstance(), &Download::modelListChanged, this, &Chat::handleModelListChanged, Qt::DirectConnection);
+ connect(this, &Chat::modelNameChanged, this, &Chat::handleModelListChanged, Qt::DirectConnection);
+ connect(m_watcher, &QFileSystemWatcher::directoryChanged, this, &Chat::handleModelListChanged);
// Should be in different threads
connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, Qt::QueuedConnection);
@ -71,6 +78,8 @@ void Chat::connectLLM()
// to respond to
connect(this, &Chat::resetResponseRequested, m_llmodel, &ChatLLM::resetResponse, Qt::BlockingQueuedConnection);
connect(this, &Chat::resetContextRequested, m_llmodel, &ChatLLM::resetContext, Qt::BlockingQueuedConnection);
+ emit defaultModelChanged(modelList().first());
}
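
A note on the connection types above: Qt::DirectConnection invokes the slot immediately on the emitting thread, so it is only safe when sender and receiver live on the same thread (Chat and Download here); Qt::QueuedConnection posts the invocation to the receiver's thread's event loop, which is what makes the Chat/ChatLLM signals safe across threads; and Qt::BlockingQueuedConnection (used for resetResponseRequested and resetContextRequested) additionally blocks the emitter until the slot returns, and deadlocks if used within a single thread. A minimal sketch, with hypothetical Sender/Receiver classes not from this codebase:

#include <QDebug>
#include <QObject>
#include <QThread>

class Sender : public QObject {
    Q_OBJECT
signals:
    void changed();
};

class Receiver : public QObject {
    Q_OBJECT
public slots:
    void onChanged() { qDebug() << "slot ran on" << QThread::currentThread(); }
};

void wire(Sender &sender, Receiver &receiver)
{
    // Same thread: the slot runs synchronously, inside the emit.
    QObject::connect(&sender, &Sender::changed, &receiver, &Receiver::onChanged, Qt::DirectConnection);
    // Cross-thread: the slot runs later, on the receiver's own thread.
    QObject::connect(&sender, &Sender::changed, &receiver, &Receiver::onChanged, Qt::QueuedConnection);
    // Cross-thread and synchronous: the emitter blocks until the slot returns.
    QObject::connect(&sender, &Sender::changed, &receiver, &Receiver::onChanged, Qt::BlockingQueuedConnection);
}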
void Chat::reset()
@ -80,7 +89,7 @@ void Chat::reset()
LLM::globalInstance()->chatListModel()->removeChatFile(this);
emit resetContextRequested(); // blocking queued connection
m_id = Network::globalInstance()->generateUniqueId();
- emit idChanged();
+ emit idChanged(m_id);
// NOTE: We deliberately do not reset the name or creation date to indicate that this was originally
// an older chat that was reset for another purpose. Resetting this data will lead to the chat
// name label changing back to 'New Chat' and showing up in the chat model list as a 'New Chat'
@ -118,6 +127,7 @@ void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t
{
resetResponseState();
emit promptRequested(
+ m_collections,
prompt,
prompt_template,
n_predict,
@ -275,6 +285,7 @@ bool Chat::isRecalc() const
void Chat::loadDefaultModel()
{
+ emit defaultModelChanged(modelList().first());
m_modelLoadingError = QString();
emit modelLoadingErrorChanged();
emit loadDefaultModelRequested();
@ -369,7 +380,7 @@ bool Chat::deserialize(QDataStream &stream, int version)
{
stream >> m_creationDate;
stream >> m_id;
- emit idChanged();
+ emit idChanged(m_id);
stream >> m_name;
stream >> m_userName;
emit nameChanged();
@ -380,7 +391,7 @@ bool Chat::deserialize(QDataStream &stream, int version)
return false;
if (version > 2) {
stream >> m_collections;
- emit collectionListChanged();
+ emit collectionListChanged(m_collections);
}
m_llmodel->setModelName(m_savedModelName);
if (!m_llmodel->deserialize(stream, version))
@ -475,6 +486,19 @@ QList<QString> Chat::modelList() const
return list;
}
+ void Chat::handleModelListChanged()
+ {
+ emit modelListChanged();
+ emit defaultModelChanged(modelList().first());
+ }
+ void Chat::handleDownloadLocalModelsPathChanged()
+ {
+ emit modelListChanged();
+ emit defaultModelChanged(modelList().first());
+ m_watcher->addPath(Download::globalInstance()->downloadLocalModelsPath());
+ }
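
For context on the watcher plumbing: QFileSystemWatcher::directoryChanged fires when a file inside a watched directory is added, removed, or renamed, which is how a model file dropped into the models folder now refreshes the list without a restart. When the configured location changes, addPath() must be called for the new directory; calling it on a path that is already watched simply returns false. A small sketch, with hypothetical directory names:

#include <QDebug>
#include <QFileSystemWatcher>
#include <QObject>

void watchModels(QFileSystemWatcher &watcher)
{
    watcher.addPath("/models");   // hypothetical directory
    QObject::connect(&watcher, &QFileSystemWatcher::directoryChanged,
                     [](const QString &dir) {
        // A file appeared, vanished, or was renamed in `dir`; rescan it.
        qDebug() << "directory changed:" << dir;
    });
    // If the user points the app somewhere else, watch that location too;
    // addPath() returns false for paths already watched, which is harmless.
    watcher.addPath("/new-models");   // hypothetical directory
}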
QList<QString> Chat::collectionList() const
{
return m_collections;
@ -491,7 +515,7 @@ void Chat::addCollection(const QString &collection)
return;
m_collections.append(collection);
- emit collectionListChanged();
+ emit collectionListChanged(m_collections);
}
void Chat::removeCollection(const QString &collection)
@ -500,5 +524,5 @@ void Chat::removeCollection(const QString &collection)
return;
m_collections.removeAll(collection);
- emit collectionListChanged();
+ emit collectionListChanged(m_collections);
}
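
The recurring change in this file is that idChanged and collectionListChanged now carry their data as signal arguments. With queued connections Qt copies the arguments into the posted event, so a receiver on another thread works on its own snapshot instead of calling back into Chat and racing the GUI thread; QList's implicit sharing keeps that copy cheap. A condensed sketch of the pattern, with hypothetical Publisher/Subscriber classes:

#include <QList>
#include <QObject>
#include <QString>

class Publisher : public QObject {
    Q_OBJECT
signals:
    // Queued connections copy this list for the receiving thread.
    void collectionListChanged(const QList<QString> &collectionList);
};

class Subscriber : public QObject {
    Q_OBJECT
public slots:
    // Runs on the subscriber's own thread; stores a private snapshot.
    void handleCollectionListChanged(const QList<QString> &collectionList)
    { m_collections = collectionList; }
private:
    QList<QString> m_collections;
};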

gpt4all-chat/chat.h

@ -98,16 +98,16 @@ public Q_SLOTS:
void serverNewPromptResponsePair(const QString &prompt);
Q_SIGNALS:
- void idChanged();
+ void idChanged(const QString &id);
void nameChanged();
void chatModelChanged();
void isModelLoadedChanged();
void responseChanged();
void responseInProgressChanged();
void responseStateChanged();
- void promptRequested(const QString &prompt, const QString &prompt_template, int32_t n_predict,
- int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens,
- int32_t n_threads);
+ void promptRequested(const QList<QString> &collectionList, const QString &prompt, const QString &prompt_template,
+ int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty,
+ int32_t repeat_penalty_tokens, int32_t n_threads);
void regenerateResponseRequested();
void resetResponseRequested();
void resetContextRequested();
@ -120,8 +120,9 @@ Q_SIGNALS:
void modelListChanged();
void modelLoadingErrorChanged();
void isServerChanged();
- void collectionListChanged();
+ void collectionListChanged(const QList<QString> &collectionList);
void tokenSpeedChanged();
+ void defaultModelChanged(const QString &defaultModel);
private Q_SLOTS:
void handleResponseChanged();
@ -134,6 +135,8 @@ private Q_SLOTS:
void handleModelLoadingError(const QString &error);
void handleTokenSpeedChanged(const QString &tokenSpeed);
void handleDatabaseResultsChanged(const QList<ResultInfo> &results);
+ void handleModelListChanged();
+ void handleDownloadLocalModelsPathChanged();
private:
QString m_id;
@ -151,6 +154,7 @@ private:
QList<ResultInfo> m_databaseResults;
bool m_isServer;
bool m_shouldDeleteLater;
+ QFileSystemWatcher *m_watcher;
};
#endif // CHAT_H

gpt4all-chat/chatllm.cpp

@ -94,7 +94,6 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
, m_isRecalc(false)
, m_shouldBeLoaded(true)
, m_stopGenerating(false)
- , m_chat(parent)
, m_timer(nullptr)
, m_isServer(isServer)
, m_isChatGPT(false)
@ -104,14 +103,15 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded);
connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
Qt::QueuedConnection); // explicitly queued
- connect(m_chat, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
+ connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
+ connect(parent, &Chat::defaultModelChanged, this, &ChatLLM::handleDefaultModelChanged);
connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
// The following are blocking operations and will block the llm thread
connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB,
Qt::BlockingQueuedConnection);
- m_llmThread.setObjectName(m_chat->id());
+ m_llmThread.setObjectName(parent->id());
m_llmThread.start();
}
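
With the m_chat back-pointer removed from ChatLLM, the constructor now reads everything it needs from the parent argument up front, and the thread's object name (kept in sync via handleChatIdChanged) replaces m_chat->id() in all of the qDebug() output below. The underlying worker-object pattern, sketched with a hypothetical Worker class:

#include <QObject>
#include <QThread>

class Worker : public QObject {
    Q_OBJECT
public slots:
    void doWork() { /* runs on the worker thread, not the GUI thread */ }
};

void startWorker()
{
    auto *thread = new QThread;
    thread->setObjectName("chat-1234");   // hypothetical id; visible in debuggers and logs
    auto *worker = new Worker;
    worker->moveToThread(thread);         // queued slot calls now execute on `thread`
    QObject::connect(thread, &QThread::finished, worker, &QObject::deleteLater);
    thread->start();
}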
@ -137,14 +137,11 @@ void ChatLLM::handleThreadStarted()
bool ChatLLM::loadDefaultModel()
{
- const QList<QString> models = m_chat->modelList();
- if (models.isEmpty()) {
- // try again when we get a list of models
- connect(Download::globalInstance(), &Download::modelListChanged, this,
- &ChatLLM::loadDefaultModel, Qt::SingleShotConnection);
+ if (m_defaultModel.isEmpty()) {
+ emit modelLoadingError(QString("Could not find default model to load"));
return false;
}
- return loadModel(models.first());
+ return loadModel(m_defaultModel);
}
bool ChatLLM::loadModel(const QString &modelName)
@ -170,7 +167,7 @@ bool ChatLLM::loadModel(const QString &modelName)
if (alreadyAcquired) {
resetContext();
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "already acquired model deleted" << m_chat->id() << m_modelInfo.model;
qDebug() << "already acquired model deleted" << m_llmThread.objectName() << m_modelInfo.model;
#endif
delete m_modelInfo.model;
m_modelInfo.model = nullptr;
@ -181,14 +178,14 @@ bool ChatLLM::loadModel(const QString &modelName)
// returned to it, then the modelInfo.model pointer should be null which will happen on startup
m_modelInfo = LLModelStore::globalInstance()->acquireModel();
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "acquired model from store" << m_chat->id() << m_modelInfo.model;
qDebug() << "acquired model from store" << m_llmThread.objectName() << m_modelInfo.model;
#endif
// At this point it is possible that while we were blocked waiting to acquire the model from the
// store, that our state was changed to not be loaded. If this is the case, release the model
// back into the store and quit loading
if (!m_shouldBeLoaded) {
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "no longer need model" << m_chat->id() << m_modelInfo.model;
qDebug() << "no longer need model" << m_llmThread.objectName() << m_modelInfo.model;
#endif
LLModelStore::globalInstance()->releaseModel(m_modelInfo);
m_modelInfo = LLModelInfo();
@ -199,7 +196,7 @@ bool ChatLLM::loadModel(const QString &modelName)
// Check if the store just gave us exactly the model we were looking for
if (m_modelInfo.model && m_modelInfo.fileInfo == fileInfo) {
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "store had our model" << m_chat->id() << m_modelInfo.model;
qDebug() << "store had our model" << m_llmThread.objectName() << m_modelInfo.model;
#endif
restoreState();
emit isModelLoadedChanged();
@ -207,7 +204,7 @@ bool ChatLLM::loadModel(const QString &modelName)
} else {
// Release the memory since we have to switch to a different model.
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "deleting model" << m_chat->id() << m_modelInfo.model;
qDebug() << "deleting model" << m_llmThread.objectName() << m_modelInfo.model;
#endif
delete m_modelInfo.model;
m_modelInfo.model = nullptr;
@ -239,13 +236,28 @@ bool ChatLLM::loadModel(const QString &modelName)
} else {
m_modelInfo.model = LLModel::construct(filePath.toStdString());
if (m_modelInfo.model) {
- m_modelInfo.model->loadModel(filePath.toStdString());
+ bool success = m_modelInfo.model->loadModel(filePath.toStdString());
+ if (!success) {
+ delete std::exchange(m_modelInfo.model, nullptr);
+ if (!m_isServer)
+ LLModelStore::globalInstance()->releaseModel(m_modelInfo); // release back into the store
+ m_modelInfo = LLModelInfo();
+ emit modelLoadingError(QString("Could not load model due to invalid model file for %1").arg(modelName));
+ } else {
switch (m_modelInfo.model->implementation().modelType[0]) {
case 'L': m_modelType = LLModelType::LLAMA_; break;
case 'G': m_modelType = LLModelType::GPTJ_; break;
case 'M': m_modelType = LLModelType::MPT_; break;
case 'R': m_modelType = LLModelType::REPLIT_; break;
- default: delete std::exchange(m_modelInfo.model, nullptr);
+ default:
+ {
+ delete std::exchange(m_modelInfo.model, nullptr);
+ if (!m_isServer)
+ LLModelStore::globalInstance()->releaseModel(m_modelInfo); // release back into the store
+ m_modelInfo = LLModelInfo();
+ emit modelLoadingError(QString("Could not determine model type for %1").arg(modelName));
+ }
}
+ }
} else {
if (!m_isServer)
@ -255,11 +267,11 @@ bool ChatLLM::loadModel(const QString &modelName)
}
}
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "new model" << m_chat->id() << m_modelInfo.model;
qDebug() << "new model" << m_llmThread.objectName() << m_modelInfo.model;
#endif
restoreState();
#if defined(DEBUG)
qDebug() << "modelLoadedChanged" << m_chat->id();
qDebug() << "modelLoadedChanged" << m_llmThread.objectName();
fflush(stdout);
#endif
emit isModelLoadedChanged();
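
Both new failure paths above (file fails to load, model type unrecognized) unwind identically: delete the model and null the pointer in one expression, hand the now-empty slot back to the store unless this is the server instance, clear m_modelInfo, and emit modelLoadingError for the UI. The delete std::exchange(p, nullptr) idiom in isolation:

#include <utility>  // std::exchange

struct Owner {
    int *p = new int(42);
    void reset() {
        // std::exchange writes nullptr into p and returns the old pointer,
        // so delete acts on the old value while p is already null; no
        // statement can observe p pointing at freed memory.
        delete std::exchange(p, nullptr);
    }
};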
@ -368,7 +380,7 @@ bool ChatLLM::handlePrompt(int32_t token)
// m_promptResponseTokens is related to last prompt/response not
// the entire context window which we can reset on regenerate prompt
#if defined(DEBUG)
qDebug() << "prompt process" << m_chat->id() << token;
qDebug() << "prompt process" << m_llmThread.objectName() << token;
#endif
++m_promptTokens;
++m_promptResponseTokens;
@ -409,7 +421,7 @@ bool ChatLLM::handleRecalculate(bool isRecalc)
return !m_stopGenerating;
}
- bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k,
+ bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k,
float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, int n_threads)
{
if (!isModelLoaded())
@ -417,7 +429,7 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
QList<ResultInfo> databaseResults;
const int retrievalSize = LocalDocs::globalInstance()->retrievalSize();
- emit requestRetrieveFromDB(m_chat->collectionList(), prompt, retrievalSize, &databaseResults); // blocks
+ emit requestRetrieveFromDB(collectionList, prompt, retrievalSize, &databaseResults); // blocks
emit databaseResultsChanged(databaseResults);
// Augment the prompt template with the results if any
@ -468,7 +480,7 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
void ChatLLM::setShouldBeLoaded(bool b)
{
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "setShouldBeLoaded" << m_chat->id() << b << m_modelInfo.model;
qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_modelInfo.model;
#endif
m_shouldBeLoaded = b; // atomic
emit shouldBeLoadedChanged();
@ -495,7 +507,7 @@ void ChatLLM::unloadModel()
saveState();
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "unloadModel" << m_chat->id() << m_modelInfo.model;
qDebug() << "unloadModel" << m_llmThread.objectName() << m_modelInfo.model;
#endif
LLModelStore::globalInstance()->releaseModel(m_modelInfo);
m_modelInfo = LLModelInfo();
@ -508,7 +520,7 @@ void ChatLLM::reloadModel()
return;
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "reloadModel" << m_chat->id() << m_modelInfo.model;
qDebug() << "reloadModel" << m_llmThread.objectName() << m_modelInfo.model;
#endif
if (m_modelName.isEmpty()) {
loadDefaultModel();
@ -547,9 +559,14 @@ void ChatLLM::generateName()
}
}
- void ChatLLM::handleChatIdChanged()
+ void ChatLLM::handleChatIdChanged(const QString &id)
{
- m_llmThread.setObjectName(m_chat->id());
+ m_llmThread.setObjectName(id);
}
+ void ChatLLM::handleDefaultModelChanged(const QString &defaultModel)
+ {
+ m_defaultModel = defaultModel;
+ }
bool ChatLLM::handleNamePrompt(int32_t token)
@ -605,7 +622,7 @@ bool ChatLLM::serialize(QDataStream &stream, int version)
QByteArray compressed = qCompress(m_state);
stream << compressed;
#if defined(DEBUG)
qDebug() << "serialize" << m_chat->id() << m_state.size();
qDebug() << "serialize" << m_llmThread.objectName() << m_state.size();
#endif
return stream.status() == QDataStream::Ok;
}
@ -645,7 +662,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version)
stream >> m_state;
}
#if defined(DEBUG)
qDebug() << "deserialize" << m_chat->id();
qDebug() << "deserialize" << m_llmThread.objectName();
#endif
return stream.status() == QDataStream::Ok;
}
@ -667,7 +684,7 @@ void ChatLLM::saveState()
const size_t stateSize = m_modelInfo.model->stateSize();
m_state.resize(stateSize);
#if defined(DEBUG)
qDebug() << "saveState" << m_chat->id() << "size:" << m_state.size();
qDebug() << "saveState" << m_llmThread.objectName() << "size:" << m_state.size();
#endif
m_modelInfo.model->saveState(static_cast<uint8_t*>(reinterpret_cast<void*>(m_state.data())));
}
@ -690,7 +707,7 @@ void ChatLLM::restoreState()
}
#if defined(DEBUG)
qDebug() << "restoreState" << m_chat->id() << "size:" << m_state.size();
qDebug() << "restoreState" << m_llmThread.objectName() << "size:" << m_state.size();
#endif
m_modelInfo.model->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
m_state.clear();

gpt4all-chat/chatllm.h

@ -100,9 +100,9 @@ public:
bool deserialize(QDataStream &stream, int version);
public Q_SLOTS:
- bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict,
- int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens,
- int32_t n_threads);
+ bool prompt(const QList<QString> &collectionList, const QString &prompt, const QString &prompt_template,
+ int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty,
+ int32_t repeat_penalty_tokens, int32_t n_threads);
bool loadDefaultModel();
bool loadModel(const QString &modelName);
void modelNameChangeRequested(const QString &modelName);
@ -110,7 +110,8 @@ public Q_SLOTS:
void unloadModel();
void reloadModel();
void generateName();
- void handleChatIdChanged();
+ void handleChatIdChanged(const QString &id);
+ void handleDefaultModelChanged(const QString &defaultModel);
void handleShouldBeLoadedChanged();
void handleThreadStarted();
@ -143,15 +144,24 @@ protected:
void restoreState();
protected:
+ // The following are all accessed by multiple threads and are thus guarded with thread protection
+ // mechanisms
LLModel::PromptContext m_ctx;
quint32 m_promptTokens;
quint32 m_promptResponseTokens;
+ LLModelInfo m_modelInfo;
+ LLModelType m_modelType;
+ private:
+ // The following are all accessed by multiple threads and are thus guarded with thread protection
+ // mechanisms
std::string m_response;
std::string m_nameResponse;
- LLModelInfo m_modelInfo;
- LLModelType m_modelType;
QString m_modelName;
- Chat *m_chat;
+ bool m_isChatGPT;
+ // The following are only accessed by this thread
+ QString m_defaultModel;
TokenTimer *m_timer;
QByteArray m_state;
QThread m_llmThread;
@ -159,7 +169,6 @@ protected:
std::atomic<bool> m_shouldBeLoaded;
std::atomic<bool> m_isRecalc;
bool m_isServer;
- bool m_isChatGPT;
};
#endif // CHATLLM_H
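
The member reshuffle above partitions ChatLLM's state: the protected block and the first private block are shared between the GUI thread and the llm thread (hence the std::atomic<bool> flags among them), while m_defaultModel and the members after it are only ever touched on the llm thread and need no guard. The set-and-poll handshake those atomic flags implement, sketched minimally under the assumption of a per-token generation loop like ChatLLM's:

#include <atomic>

std::atomic<bool> stopGenerating{false};

// GUI thread: flip the flag at any time, no lock required.
void requestStop() { stopGenerating.store(true); }

// llm thread: poll at safe points, e.g. once per generated token.
void generate() {
    while (!stopGenerating.load()) {
        // ... produce one token ...
    }
}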

gpt4all-chat/server.cpp

@ -72,6 +72,7 @@ Server::Server(Chat *chat)
{
connect(this, &Server::threadStarted, this, &Server::start);
connect(this, &Server::databaseResultsChanged, this, &Server::handleDatabaseResultsChanged);
+ connect(chat, &Chat::collectionListChanged, this, &Server::handleCollectionListChanged, Qt::QueuedConnection);
}
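
This connection closes a race: the completion handler, which runs outside the GUI thread, used to reach back into Chat for the collection list while the GUI thread could be mutating it. Now Chat pushes the full list over a queued connection and the handler reads only the server's own copy, m_collections (declared in server.h below). In sketch form, using the names from this commit:

// Chat (GUI thread) publishes; Server caches a snapshot on its own thread.
connect(chat, &Chat::collectionListChanged,
        this, &Server::handleCollectionListChanged, Qt::QueuedConnection);
// Later, inside handleCompletionRequest, no call back into Chat is needed:
//     prompt(m_collections, actualPrompt, promptTemplate, ...);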
Server::~Server()
@ -315,7 +316,9 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
int responseTokens = 0;
QList<QPair<QString, QList<ResultInfo>>> responses;
for (int i = 0; i < n; ++i) {
- if (!prompt(actualPrompt,
+ if (!prompt(
+ m_collections,
+ actualPrompt,
promptTemplate,
max_tokens /*n_predict*/,
top_k,

gpt4all-chat/server.h

@ -23,11 +23,13 @@ Q_SIGNALS:
private Q_SLOTS:
QHttpServerResponse handleCompletionRequest(const QHttpServerRequest &request, bool isChat);
void handleDatabaseResultsChanged(const QList<ResultInfo> &results) { m_databaseResults = results; }
+ void handleCollectionListChanged(const QList<QString> &collectionList) { m_collections = collectionList; }
private:
Chat *m_chat;
QHttpServer *m_server;
QList<ResultInfo> m_databaseResults;
+ QList<QString> m_collections;
};
#endif // SERVER_H