mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-09-19 23:35:41 +00:00
Start working on more thread safety and model load error handling.
This commit is contained in:
parent
d5f56d3308
commit
7d2ce06029
@ -43,9 +43,16 @@ Chat::~Chat()
|
|||||||
|
|
||||||
void Chat::connectLLM()
|
void Chat::connectLLM()
|
||||||
{
|
{
|
||||||
|
const QString exePath = QCoreApplication::applicationDirPath() + QDir::separator();
|
||||||
|
const QString localPath = Download::globalInstance()->downloadLocalModelsPath();
|
||||||
|
m_watcher = new QFileSystemWatcher(this);
|
||||||
|
m_watcher->addPath(exePath);
|
||||||
|
m_watcher->addPath(localPath);
|
||||||
|
|
||||||
// Should be in same thread
|
// Should be in same thread
|
||||||
connect(Download::globalInstance(), &Download::modelListChanged, this, &Chat::modelListChanged, Qt::DirectConnection);
|
connect(Download::globalInstance(), &Download::modelListChanged, this, &Chat::handleModelListChanged, Qt::DirectConnection);
|
||||||
connect(this, &Chat::modelNameChanged, this, &Chat::modelListChanged, Qt::DirectConnection);
|
connect(this, &Chat::modelNameChanged, this, &Chat::handleModelListChanged, Qt::DirectConnection);
|
||||||
|
connect(m_watcher, &QFileSystemWatcher::directoryChanged, this, &Chat::handleModelListChanged);
|
||||||
|
|
||||||
// Should be in different threads
|
// Should be in different threads
|
||||||
connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, Qt::QueuedConnection);
|
connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, Qt::QueuedConnection);
|
||||||
@ -71,6 +78,8 @@ void Chat::connectLLM()
|
|||||||
// to respond to
|
// to respond to
|
||||||
connect(this, &Chat::resetResponseRequested, m_llmodel, &ChatLLM::resetResponse, Qt::BlockingQueuedConnection);
|
connect(this, &Chat::resetResponseRequested, m_llmodel, &ChatLLM::resetResponse, Qt::BlockingQueuedConnection);
|
||||||
connect(this, &Chat::resetContextRequested, m_llmodel, &ChatLLM::resetContext, Qt::BlockingQueuedConnection);
|
connect(this, &Chat::resetContextRequested, m_llmodel, &ChatLLM::resetContext, Qt::BlockingQueuedConnection);
|
||||||
|
|
||||||
|
emit defaultModelChanged(modelList().first());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Chat::reset()
|
void Chat::reset()
|
||||||
@ -80,7 +89,7 @@ void Chat::reset()
|
|||||||
LLM::globalInstance()->chatListModel()->removeChatFile(this);
|
LLM::globalInstance()->chatListModel()->removeChatFile(this);
|
||||||
emit resetContextRequested(); // blocking queued connection
|
emit resetContextRequested(); // blocking queued connection
|
||||||
m_id = Network::globalInstance()->generateUniqueId();
|
m_id = Network::globalInstance()->generateUniqueId();
|
||||||
emit idChanged();
|
emit idChanged(m_id);
|
||||||
// NOTE: We deliberately do no reset the name or creation date to indictate that this was originally
|
// NOTE: We deliberately do no reset the name or creation date to indictate that this was originally
|
||||||
// an older chat that was reset for another purpose. Resetting this data will lead to the chat
|
// an older chat that was reset for another purpose. Resetting this data will lead to the chat
|
||||||
// name label changing back to 'New Chat' and showing up in the chat model list as a 'New Chat'
|
// name label changing back to 'New Chat' and showing up in the chat model list as a 'New Chat'
|
||||||
@ -118,6 +127,7 @@ void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t
|
|||||||
{
|
{
|
||||||
resetResponseState();
|
resetResponseState();
|
||||||
emit promptRequested(
|
emit promptRequested(
|
||||||
|
m_collections,
|
||||||
prompt,
|
prompt,
|
||||||
prompt_template,
|
prompt_template,
|
||||||
n_predict,
|
n_predict,
|
||||||
@ -275,6 +285,7 @@ bool Chat::isRecalc() const
|
|||||||
|
|
||||||
void Chat::loadDefaultModel()
|
void Chat::loadDefaultModel()
|
||||||
{
|
{
|
||||||
|
emit defaultModelChanged(modelList().first());
|
||||||
m_modelLoadingError = QString();
|
m_modelLoadingError = QString();
|
||||||
emit modelLoadingErrorChanged();
|
emit modelLoadingErrorChanged();
|
||||||
emit loadDefaultModelRequested();
|
emit loadDefaultModelRequested();
|
||||||
@ -369,7 +380,7 @@ bool Chat::deserialize(QDataStream &stream, int version)
|
|||||||
{
|
{
|
||||||
stream >> m_creationDate;
|
stream >> m_creationDate;
|
||||||
stream >> m_id;
|
stream >> m_id;
|
||||||
emit idChanged();
|
emit idChanged(m_id);
|
||||||
stream >> m_name;
|
stream >> m_name;
|
||||||
stream >> m_userName;
|
stream >> m_userName;
|
||||||
emit nameChanged();
|
emit nameChanged();
|
||||||
@ -380,7 +391,7 @@ bool Chat::deserialize(QDataStream &stream, int version)
|
|||||||
return false;
|
return false;
|
||||||
if (version > 2) {
|
if (version > 2) {
|
||||||
stream >> m_collections;
|
stream >> m_collections;
|
||||||
emit collectionListChanged();
|
emit collectionListChanged(m_collections);
|
||||||
}
|
}
|
||||||
m_llmodel->setModelName(m_savedModelName);
|
m_llmodel->setModelName(m_savedModelName);
|
||||||
if (!m_llmodel->deserialize(stream, version))
|
if (!m_llmodel->deserialize(stream, version))
|
||||||
@ -475,6 +486,19 @@ QList<QString> Chat::modelList() const
|
|||||||
return list;
|
return list;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Chat::handleModelListChanged()
|
||||||
|
{
|
||||||
|
emit modelListChanged();
|
||||||
|
emit defaultModelChanged(modelList().first());
|
||||||
|
}
|
||||||
|
|
||||||
|
void Chat::handleDownloadLocalModelsPathChanged()
|
||||||
|
{
|
||||||
|
emit modelListChanged();
|
||||||
|
emit defaultModelChanged(modelList().first());
|
||||||
|
m_watcher->addPath(Download::globalInstance()->downloadLocalModelsPath());
|
||||||
|
}
|
||||||
|
|
||||||
QList<QString> Chat::collectionList() const
|
QList<QString> Chat::collectionList() const
|
||||||
{
|
{
|
||||||
return m_collections;
|
return m_collections;
|
||||||
@ -491,7 +515,7 @@ void Chat::addCollection(const QString &collection)
|
|||||||
return;
|
return;
|
||||||
|
|
||||||
m_collections.append(collection);
|
m_collections.append(collection);
|
||||||
emit collectionListChanged();
|
emit collectionListChanged(m_collections);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Chat::removeCollection(const QString &collection)
|
void Chat::removeCollection(const QString &collection)
|
||||||
@ -500,5 +524,5 @@ void Chat::removeCollection(const QString &collection)
|
|||||||
return;
|
return;
|
||||||
|
|
||||||
m_collections.removeAll(collection);
|
m_collections.removeAll(collection);
|
||||||
emit collectionListChanged();
|
emit collectionListChanged(m_collections);
|
||||||
}
|
}
|
||||||
|
@ -98,16 +98,16 @@ public Q_SLOTS:
|
|||||||
void serverNewPromptResponsePair(const QString &prompt);
|
void serverNewPromptResponsePair(const QString &prompt);
|
||||||
|
|
||||||
Q_SIGNALS:
|
Q_SIGNALS:
|
||||||
void idChanged();
|
void idChanged(const QString &id);
|
||||||
void nameChanged();
|
void nameChanged();
|
||||||
void chatModelChanged();
|
void chatModelChanged();
|
||||||
void isModelLoadedChanged();
|
void isModelLoadedChanged();
|
||||||
void responseChanged();
|
void responseChanged();
|
||||||
void responseInProgressChanged();
|
void responseInProgressChanged();
|
||||||
void responseStateChanged();
|
void responseStateChanged();
|
||||||
void promptRequested(const QString &prompt, const QString &prompt_template, int32_t n_predict,
|
void promptRequested(const QList<QString> &collectionList, const QString &prompt, const QString &prompt_template,
|
||||||
int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens,
|
int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty,
|
||||||
int32_t n_threads);
|
int32_t repeat_penalty_tokens, int32_t n_threads);
|
||||||
void regenerateResponseRequested();
|
void regenerateResponseRequested();
|
||||||
void resetResponseRequested();
|
void resetResponseRequested();
|
||||||
void resetContextRequested();
|
void resetContextRequested();
|
||||||
@ -120,8 +120,9 @@ Q_SIGNALS:
|
|||||||
void modelListChanged();
|
void modelListChanged();
|
||||||
void modelLoadingErrorChanged();
|
void modelLoadingErrorChanged();
|
||||||
void isServerChanged();
|
void isServerChanged();
|
||||||
void collectionListChanged();
|
void collectionListChanged(const QList<QString> &collectionList);
|
||||||
void tokenSpeedChanged();
|
void tokenSpeedChanged();
|
||||||
|
void defaultModelChanged(const QString &defaultModel);
|
||||||
|
|
||||||
private Q_SLOTS:
|
private Q_SLOTS:
|
||||||
void handleResponseChanged();
|
void handleResponseChanged();
|
||||||
@ -134,6 +135,8 @@ private Q_SLOTS:
|
|||||||
void handleModelLoadingError(const QString &error);
|
void handleModelLoadingError(const QString &error);
|
||||||
void handleTokenSpeedChanged(const QString &tokenSpeed);
|
void handleTokenSpeedChanged(const QString &tokenSpeed);
|
||||||
void handleDatabaseResultsChanged(const QList<ResultInfo> &results);
|
void handleDatabaseResultsChanged(const QList<ResultInfo> &results);
|
||||||
|
void handleModelListChanged();
|
||||||
|
void handleDownloadLocalModelsPathChanged();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
QString m_id;
|
QString m_id;
|
||||||
@ -151,6 +154,7 @@ private:
|
|||||||
QList<ResultInfo> m_databaseResults;
|
QList<ResultInfo> m_databaseResults;
|
||||||
bool m_isServer;
|
bool m_isServer;
|
||||||
bool m_shouldDeleteLater;
|
bool m_shouldDeleteLater;
|
||||||
|
QFileSystemWatcher *m_watcher;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // CHAT_H
|
#endif // CHAT_H
|
||||||
|
@ -94,7 +94,6 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
|
|||||||
, m_isRecalc(false)
|
, m_isRecalc(false)
|
||||||
, m_shouldBeLoaded(true)
|
, m_shouldBeLoaded(true)
|
||||||
, m_stopGenerating(false)
|
, m_stopGenerating(false)
|
||||||
, m_chat(parent)
|
|
||||||
, m_timer(nullptr)
|
, m_timer(nullptr)
|
||||||
, m_isServer(isServer)
|
, m_isServer(isServer)
|
||||||
, m_isChatGPT(false)
|
, m_isChatGPT(false)
|
||||||
@ -104,14 +103,15 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
|
|||||||
connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded);
|
connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded);
|
||||||
connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
|
connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
|
||||||
Qt::QueuedConnection); // explicitly queued
|
Qt::QueuedConnection); // explicitly queued
|
||||||
connect(m_chat, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
|
connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
|
||||||
|
connect(parent, &Chat::defaultModelChanged, this, &ChatLLM::handleDefaultModelChanged);
|
||||||
connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
|
connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
|
||||||
|
|
||||||
// The following are blocking operations and will block the llm thread
|
// The following are blocking operations and will block the llm thread
|
||||||
connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB,
|
connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB,
|
||||||
Qt::BlockingQueuedConnection);
|
Qt::BlockingQueuedConnection);
|
||||||
|
|
||||||
m_llmThread.setObjectName(m_chat->id());
|
m_llmThread.setObjectName(parent->id());
|
||||||
m_llmThread.start();
|
m_llmThread.start();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -137,14 +137,11 @@ void ChatLLM::handleThreadStarted()
|
|||||||
|
|
||||||
bool ChatLLM::loadDefaultModel()
|
bool ChatLLM::loadDefaultModel()
|
||||||
{
|
{
|
||||||
const QList<QString> models = m_chat->modelList();
|
if (m_defaultModel.isEmpty()) {
|
||||||
if (models.isEmpty()) {
|
emit modelLoadingError(QString("Could not find default model to load"));
|
||||||
// try again when we get a list of models
|
|
||||||
connect(Download::globalInstance(), &Download::modelListChanged, this,
|
|
||||||
&ChatLLM::loadDefaultModel, Qt::SingleShotConnection);
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return loadModel(models.first());
|
return loadModel(m_defaultModel);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ChatLLM::loadModel(const QString &modelName)
|
bool ChatLLM::loadModel(const QString &modelName)
|
||||||
@ -170,7 +167,7 @@ bool ChatLLM::loadModel(const QString &modelName)
|
|||||||
if (alreadyAcquired) {
|
if (alreadyAcquired) {
|
||||||
resetContext();
|
resetContext();
|
||||||
#if defined(DEBUG_MODEL_LOADING)
|
#if defined(DEBUG_MODEL_LOADING)
|
||||||
qDebug() << "already acquired model deleted" << m_chat->id() << m_modelInfo.model;
|
qDebug() << "already acquired model deleted" << m_llmThread.objectName() << m_modelInfo.model;
|
||||||
#endif
|
#endif
|
||||||
delete m_modelInfo.model;
|
delete m_modelInfo.model;
|
||||||
m_modelInfo.model = nullptr;
|
m_modelInfo.model = nullptr;
|
||||||
@ -181,14 +178,14 @@ bool ChatLLM::loadModel(const QString &modelName)
|
|||||||
// returned to it, then the modelInfo.model pointer should be null which will happen on startup
|
// returned to it, then the modelInfo.model pointer should be null which will happen on startup
|
||||||
m_modelInfo = LLModelStore::globalInstance()->acquireModel();
|
m_modelInfo = LLModelStore::globalInstance()->acquireModel();
|
||||||
#if defined(DEBUG_MODEL_LOADING)
|
#if defined(DEBUG_MODEL_LOADING)
|
||||||
qDebug() << "acquired model from store" << m_chat->id() << m_modelInfo.model;
|
qDebug() << "acquired model from store" << m_llmThread.objectName() << m_modelInfo.model;
|
||||||
#endif
|
#endif
|
||||||
// At this point it is possible that while we were blocked waiting to acquire the model from the
|
// At this point it is possible that while we were blocked waiting to acquire the model from the
|
||||||
// store, that our state was changed to not be loaded. If this is the case, release the model
|
// store, that our state was changed to not be loaded. If this is the case, release the model
|
||||||
// back into the store and quit loading
|
// back into the store and quit loading
|
||||||
if (!m_shouldBeLoaded) {
|
if (!m_shouldBeLoaded) {
|
||||||
#if defined(DEBUG_MODEL_LOADING)
|
#if defined(DEBUG_MODEL_LOADING)
|
||||||
qDebug() << "no longer need model" << m_chat->id() << m_modelInfo.model;
|
qDebug() << "no longer need model" << m_llmThread.objectName() << m_modelInfo.model;
|
||||||
#endif
|
#endif
|
||||||
LLModelStore::globalInstance()->releaseModel(m_modelInfo);
|
LLModelStore::globalInstance()->releaseModel(m_modelInfo);
|
||||||
m_modelInfo = LLModelInfo();
|
m_modelInfo = LLModelInfo();
|
||||||
@ -199,7 +196,7 @@ bool ChatLLM::loadModel(const QString &modelName)
|
|||||||
// Check if the store just gave us exactly the model we were looking for
|
// Check if the store just gave us exactly the model we were looking for
|
||||||
if (m_modelInfo.model && m_modelInfo.fileInfo == fileInfo) {
|
if (m_modelInfo.model && m_modelInfo.fileInfo == fileInfo) {
|
||||||
#if defined(DEBUG_MODEL_LOADING)
|
#if defined(DEBUG_MODEL_LOADING)
|
||||||
qDebug() << "store had our model" << m_chat->id() << m_modelInfo.model;
|
qDebug() << "store had our model" << m_llmThread.objectName() << m_modelInfo.model;
|
||||||
#endif
|
#endif
|
||||||
restoreState();
|
restoreState();
|
||||||
emit isModelLoadedChanged();
|
emit isModelLoadedChanged();
|
||||||
@ -207,7 +204,7 @@ bool ChatLLM::loadModel(const QString &modelName)
|
|||||||
} else {
|
} else {
|
||||||
// Release the memory since we have to switch to a different model.
|
// Release the memory since we have to switch to a different model.
|
||||||
#if defined(DEBUG_MODEL_LOADING)
|
#if defined(DEBUG_MODEL_LOADING)
|
||||||
qDebug() << "deleting model" << m_chat->id() << m_modelInfo.model;
|
qDebug() << "deleting model" << m_llmThread.objectName() << m_modelInfo.model;
|
||||||
#endif
|
#endif
|
||||||
delete m_modelInfo.model;
|
delete m_modelInfo.model;
|
||||||
m_modelInfo.model = nullptr;
|
m_modelInfo.model = nullptr;
|
||||||
@ -239,13 +236,28 @@ bool ChatLLM::loadModel(const QString &modelName)
|
|||||||
} else {
|
} else {
|
||||||
m_modelInfo.model = LLModel::construct(filePath.toStdString());
|
m_modelInfo.model = LLModel::construct(filePath.toStdString());
|
||||||
if (m_modelInfo.model) {
|
if (m_modelInfo.model) {
|
||||||
m_modelInfo.model->loadModel(filePath.toStdString());
|
bool success = m_modelInfo.model->loadModel(filePath.toStdString());
|
||||||
switch (m_modelInfo.model->implementation().modelType[0]) {
|
if (!success) {
|
||||||
case 'L': m_modelType = LLModelType::LLAMA_; break;
|
delete std::exchange(m_modelInfo.model, nullptr);
|
||||||
case 'G': m_modelType = LLModelType::GPTJ_; break;
|
if (!m_isServer)
|
||||||
case 'M': m_modelType = LLModelType::MPT_; break;
|
LLModelStore::globalInstance()->releaseModel(m_modelInfo); // release back into the store
|
||||||
case 'R': m_modelType = LLModelType::REPLIT_; break;
|
m_modelInfo = LLModelInfo();
|
||||||
default: delete std::exchange(m_modelInfo.model, nullptr);
|
emit modelLoadingError(QString("Could not load model due to invalid model file for %1").arg(modelName));
|
||||||
|
} else {
|
||||||
|
switch (m_modelInfo.model->implementation().modelType[0]) {
|
||||||
|
case 'L': m_modelType = LLModelType::LLAMA_; break;
|
||||||
|
case 'G': m_modelType = LLModelType::GPTJ_; break;
|
||||||
|
case 'M': m_modelType = LLModelType::MPT_; break;
|
||||||
|
case 'R': m_modelType = LLModelType::REPLIT_; break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
delete std::exchange(m_modelInfo.model, nullptr);
|
||||||
|
if (!m_isServer)
|
||||||
|
LLModelStore::globalInstance()->releaseModel(m_modelInfo); // release back into the store
|
||||||
|
m_modelInfo = LLModelInfo();
|
||||||
|
emit modelLoadingError(QString("Could not determine model type for %1").arg(modelName));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (!m_isServer)
|
if (!m_isServer)
|
||||||
@ -255,11 +267,11 @@ bool ChatLLM::loadModel(const QString &modelName)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
#if defined(DEBUG_MODEL_LOADING)
|
#if defined(DEBUG_MODEL_LOADING)
|
||||||
qDebug() << "new model" << m_chat->id() << m_modelInfo.model;
|
qDebug() << "new model" << m_llmThread.objectName() << m_modelInfo.model;
|
||||||
#endif
|
#endif
|
||||||
restoreState();
|
restoreState();
|
||||||
#if defined(DEBUG)
|
#if defined(DEBUG)
|
||||||
qDebug() << "modelLoadedChanged" << m_chat->id();
|
qDebug() << "modelLoadedChanged" << m_llmThread.objectName();
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
#endif
|
#endif
|
||||||
emit isModelLoadedChanged();
|
emit isModelLoadedChanged();
|
||||||
@ -368,7 +380,7 @@ bool ChatLLM::handlePrompt(int32_t token)
|
|||||||
// m_promptResponseTokens is related to last prompt/response not
|
// m_promptResponseTokens is related to last prompt/response not
|
||||||
// the entire context window which we can reset on regenerate prompt
|
// the entire context window which we can reset on regenerate prompt
|
||||||
#if defined(DEBUG)
|
#if defined(DEBUG)
|
||||||
qDebug() << "prompt process" << m_chat->id() << token;
|
qDebug() << "prompt process" << m_llmThread.objectName() << token;
|
||||||
#endif
|
#endif
|
||||||
++m_promptTokens;
|
++m_promptTokens;
|
||||||
++m_promptResponseTokens;
|
++m_promptResponseTokens;
|
||||||
@ -409,7 +421,7 @@ bool ChatLLM::handleRecalculate(bool isRecalc)
|
|||||||
return !m_stopGenerating;
|
return !m_stopGenerating;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k,
|
bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k,
|
||||||
float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, int n_threads)
|
float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, int n_threads)
|
||||||
{
|
{
|
||||||
if (!isModelLoaded())
|
if (!isModelLoaded())
|
||||||
@ -417,7 +429,7 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
|
|||||||
|
|
||||||
QList<ResultInfo> databaseResults;
|
QList<ResultInfo> databaseResults;
|
||||||
const int retrievalSize = LocalDocs::globalInstance()->retrievalSize();
|
const int retrievalSize = LocalDocs::globalInstance()->retrievalSize();
|
||||||
emit requestRetrieveFromDB(m_chat->collectionList(), prompt, retrievalSize, &databaseResults); // blocks
|
emit requestRetrieveFromDB(collectionList, prompt, retrievalSize, &databaseResults); // blocks
|
||||||
emit databaseResultsChanged(databaseResults);
|
emit databaseResultsChanged(databaseResults);
|
||||||
|
|
||||||
// Augment the prompt template with the results if any
|
// Augment the prompt template with the results if any
|
||||||
@ -468,7 +480,7 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
|
|||||||
void ChatLLM::setShouldBeLoaded(bool b)
|
void ChatLLM::setShouldBeLoaded(bool b)
|
||||||
{
|
{
|
||||||
#if defined(DEBUG_MODEL_LOADING)
|
#if defined(DEBUG_MODEL_LOADING)
|
||||||
qDebug() << "setShouldBeLoaded" << m_chat->id() << b << m_modelInfo.model;
|
qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_modelInfo.model;
|
||||||
#endif
|
#endif
|
||||||
m_shouldBeLoaded = b; // atomic
|
m_shouldBeLoaded = b; // atomic
|
||||||
emit shouldBeLoadedChanged();
|
emit shouldBeLoadedChanged();
|
||||||
@ -495,7 +507,7 @@ void ChatLLM::unloadModel()
|
|||||||
|
|
||||||
saveState();
|
saveState();
|
||||||
#if defined(DEBUG_MODEL_LOADING)
|
#if defined(DEBUG_MODEL_LOADING)
|
||||||
qDebug() << "unloadModel" << m_chat->id() << m_modelInfo.model;
|
qDebug() << "unloadModel" << m_llmThread.objectName() << m_modelInfo.model;
|
||||||
#endif
|
#endif
|
||||||
LLModelStore::globalInstance()->releaseModel(m_modelInfo);
|
LLModelStore::globalInstance()->releaseModel(m_modelInfo);
|
||||||
m_modelInfo = LLModelInfo();
|
m_modelInfo = LLModelInfo();
|
||||||
@ -508,7 +520,7 @@ void ChatLLM::reloadModel()
|
|||||||
return;
|
return;
|
||||||
|
|
||||||
#if defined(DEBUG_MODEL_LOADING)
|
#if defined(DEBUG_MODEL_LOADING)
|
||||||
qDebug() << "reloadModel" << m_chat->id() << m_modelInfo.model;
|
qDebug() << "reloadModel" << m_llmThread.objectName() << m_modelInfo.model;
|
||||||
#endif
|
#endif
|
||||||
if (m_modelName.isEmpty()) {
|
if (m_modelName.isEmpty()) {
|
||||||
loadDefaultModel();
|
loadDefaultModel();
|
||||||
@ -547,9 +559,14 @@ void ChatLLM::generateName()
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ChatLLM::handleChatIdChanged()
|
void ChatLLM::handleChatIdChanged(const QString &id)
|
||||||
{
|
{
|
||||||
m_llmThread.setObjectName(m_chat->id());
|
m_llmThread.setObjectName(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ChatLLM::handleDefaultModelChanged(const QString &defaultModel)
|
||||||
|
{
|
||||||
|
m_defaultModel = defaultModel;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ChatLLM::handleNamePrompt(int32_t token)
|
bool ChatLLM::handleNamePrompt(int32_t token)
|
||||||
@ -605,7 +622,7 @@ bool ChatLLM::serialize(QDataStream &stream, int version)
|
|||||||
QByteArray compressed = qCompress(m_state);
|
QByteArray compressed = qCompress(m_state);
|
||||||
stream << compressed;
|
stream << compressed;
|
||||||
#if defined(DEBUG)
|
#if defined(DEBUG)
|
||||||
qDebug() << "serialize" << m_chat->id() << m_state.size();
|
qDebug() << "serialize" << m_llmThread.objectName() << m_state.size();
|
||||||
#endif
|
#endif
|
||||||
return stream.status() == QDataStream::Ok;
|
return stream.status() == QDataStream::Ok;
|
||||||
}
|
}
|
||||||
@ -645,7 +662,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version)
|
|||||||
stream >> m_state;
|
stream >> m_state;
|
||||||
}
|
}
|
||||||
#if defined(DEBUG)
|
#if defined(DEBUG)
|
||||||
qDebug() << "deserialize" << m_chat->id();
|
qDebug() << "deserialize" << m_llmThread.objectName();
|
||||||
#endif
|
#endif
|
||||||
return stream.status() == QDataStream::Ok;
|
return stream.status() == QDataStream::Ok;
|
||||||
}
|
}
|
||||||
@ -667,7 +684,7 @@ void ChatLLM::saveState()
|
|||||||
const size_t stateSize = m_modelInfo.model->stateSize();
|
const size_t stateSize = m_modelInfo.model->stateSize();
|
||||||
m_state.resize(stateSize);
|
m_state.resize(stateSize);
|
||||||
#if defined(DEBUG)
|
#if defined(DEBUG)
|
||||||
qDebug() << "saveState" << m_chat->id() << "size:" << m_state.size();
|
qDebug() << "saveState" << m_llmThread.objectName() << "size:" << m_state.size();
|
||||||
#endif
|
#endif
|
||||||
m_modelInfo.model->saveState(static_cast<uint8_t*>(reinterpret_cast<void*>(m_state.data())));
|
m_modelInfo.model->saveState(static_cast<uint8_t*>(reinterpret_cast<void*>(m_state.data())));
|
||||||
}
|
}
|
||||||
@ -690,7 +707,7 @@ void ChatLLM::restoreState()
|
|||||||
}
|
}
|
||||||
|
|
||||||
#if defined(DEBUG)
|
#if defined(DEBUG)
|
||||||
qDebug() << "restoreState" << m_chat->id() << "size:" << m_state.size();
|
qDebug() << "restoreState" << m_llmThread.objectName() << "size:" << m_state.size();
|
||||||
#endif
|
#endif
|
||||||
m_modelInfo.model->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
|
m_modelInfo.model->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
|
||||||
m_state.clear();
|
m_state.clear();
|
||||||
|
@ -100,9 +100,9 @@ public:
|
|||||||
bool deserialize(QDataStream &stream, int version);
|
bool deserialize(QDataStream &stream, int version);
|
||||||
|
|
||||||
public Q_SLOTS:
|
public Q_SLOTS:
|
||||||
bool prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict,
|
bool prompt(const QList<QString> &collectionList, const QString &prompt, const QString &prompt_template,
|
||||||
int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens,
|
int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty,
|
||||||
int32_t n_threads);
|
int32_t repeat_penalty_tokens, int32_t n_threads);
|
||||||
bool loadDefaultModel();
|
bool loadDefaultModel();
|
||||||
bool loadModel(const QString &modelName);
|
bool loadModel(const QString &modelName);
|
||||||
void modelNameChangeRequested(const QString &modelName);
|
void modelNameChangeRequested(const QString &modelName);
|
||||||
@ -110,7 +110,8 @@ public Q_SLOTS:
|
|||||||
void unloadModel();
|
void unloadModel();
|
||||||
void reloadModel();
|
void reloadModel();
|
||||||
void generateName();
|
void generateName();
|
||||||
void handleChatIdChanged();
|
void handleChatIdChanged(const QString &id);
|
||||||
|
void handleDefaultModelChanged(const QString &defaultModel);
|
||||||
void handleShouldBeLoadedChanged();
|
void handleShouldBeLoadedChanged();
|
||||||
void handleThreadStarted();
|
void handleThreadStarted();
|
||||||
|
|
||||||
@ -143,15 +144,24 @@ protected:
|
|||||||
void restoreState();
|
void restoreState();
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
// The following are all accessed by multiple threads and are thus guarded with thread protection
|
||||||
|
// mechanisms
|
||||||
LLModel::PromptContext m_ctx;
|
LLModel::PromptContext m_ctx;
|
||||||
quint32 m_promptTokens;
|
quint32 m_promptTokens;
|
||||||
quint32 m_promptResponseTokens;
|
quint32 m_promptResponseTokens;
|
||||||
LLModelInfo m_modelInfo;
|
|
||||||
LLModelType m_modelType;
|
private:
|
||||||
|
// The following are all accessed by multiple threads and are thus guarded with thread protection
|
||||||
|
// mechanisms
|
||||||
std::string m_response;
|
std::string m_response;
|
||||||
std::string m_nameResponse;
|
std::string m_nameResponse;
|
||||||
|
LLModelInfo m_modelInfo;
|
||||||
|
LLModelType m_modelType;
|
||||||
QString m_modelName;
|
QString m_modelName;
|
||||||
Chat *m_chat;
|
bool m_isChatGPT;
|
||||||
|
|
||||||
|
// The following are only accessed by this thread
|
||||||
|
QString m_defaultModel;
|
||||||
TokenTimer *m_timer;
|
TokenTimer *m_timer;
|
||||||
QByteArray m_state;
|
QByteArray m_state;
|
||||||
QThread m_llmThread;
|
QThread m_llmThread;
|
||||||
@ -159,7 +169,6 @@ protected:
|
|||||||
std::atomic<bool> m_shouldBeLoaded;
|
std::atomic<bool> m_shouldBeLoaded;
|
||||||
std::atomic<bool> m_isRecalc;
|
std::atomic<bool> m_isRecalc;
|
||||||
bool m_isServer;
|
bool m_isServer;
|
||||||
bool m_isChatGPT;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // CHATLLM_H
|
#endif // CHATLLM_H
|
||||||
|
@ -72,6 +72,7 @@ Server::Server(Chat *chat)
|
|||||||
{
|
{
|
||||||
connect(this, &Server::threadStarted, this, &Server::start);
|
connect(this, &Server::threadStarted, this, &Server::start);
|
||||||
connect(this, &Server::databaseResultsChanged, this, &Server::handleDatabaseResultsChanged);
|
connect(this, &Server::databaseResultsChanged, this, &Server::handleDatabaseResultsChanged);
|
||||||
|
connect(chat, &Chat::collectionListChanged, this, &Server::handleCollectionListChanged, Qt::QueuedConnection);
|
||||||
}
|
}
|
||||||
|
|
||||||
Server::~Server()
|
Server::~Server()
|
||||||
@ -315,7 +316,9 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
|
|||||||
int responseTokens = 0;
|
int responseTokens = 0;
|
||||||
QList<QPair<QString, QList<ResultInfo>>> responses;
|
QList<QPair<QString, QList<ResultInfo>>> responses;
|
||||||
for (int i = 0; i < n; ++i) {
|
for (int i = 0; i < n; ++i) {
|
||||||
if (!prompt(actualPrompt,
|
if (!prompt(
|
||||||
|
m_collections,
|
||||||
|
actualPrompt,
|
||||||
promptTemplate,
|
promptTemplate,
|
||||||
max_tokens /*n_predict*/,
|
max_tokens /*n_predict*/,
|
||||||
top_k,
|
top_k,
|
||||||
|
@ -23,11 +23,13 @@ Q_SIGNALS:
|
|||||||
private Q_SLOTS:
|
private Q_SLOTS:
|
||||||
QHttpServerResponse handleCompletionRequest(const QHttpServerRequest &request, bool isChat);
|
QHttpServerResponse handleCompletionRequest(const QHttpServerRequest &request, bool isChat);
|
||||||
void handleDatabaseResultsChanged(const QList<ResultInfo> &results) { m_databaseResults = results; }
|
void handleDatabaseResultsChanged(const QList<ResultInfo> &results) { m_databaseResults = results; }
|
||||||
|
void handleCollectionListChanged(const QList<QString> &collectionList) { m_collections = collectionList; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Chat *m_chat;
|
Chat *m_chat;
|
||||||
QHttpServer *m_server;
|
QHttpServer *m_server;
|
||||||
QList<ResultInfo> m_databaseResults;
|
QList<ResultInfo> m_databaseResults;
|
||||||
|
QList<QString> m_collections;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // SERVER_H
|
#endif // SERVER_H
|
||||||
|
Loading…
Reference in New Issue
Block a user