Don't store db results in ChatLLM.

This commit is contained in:
Adam Treat 2023-06-19 18:23:54 -04:00 committed by AT
parent 0cfe225506
commit a3a6a20146
5 changed files with 17 additions and 12 deletions

View File

@ -58,6 +58,7 @@ void Chat::connectLLM()
connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection);
connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection); connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection);
connect(this, &Chat::modelNameChangeRequested, m_llmodel, &ChatLLM::modelNameChangeRequested, Qt::QueuedConnection); connect(this, &Chat::modelNameChangeRequested, m_llmodel, &ChatLLM::modelNameChangeRequested, Qt::QueuedConnection);
@ -177,11 +178,6 @@ void Chat::handleModelLoadedChanged()
deleteLater(); deleteLater();
} }
QList<ResultInfo> Chat::databaseResults() const
{
return m_llmodel->databaseResults();
}
void Chat::promptProcessing() void Chat::promptProcessing()
{ {
m_responseState = !databaseResults().isEmpty() ? Chat::LocalDocsProcessing : Chat::PromptProcessing; m_responseState = !databaseResults().isEmpty() ? Chat::LocalDocsProcessing : Chat::PromptProcessing;
@ -348,6 +344,11 @@ void Chat::handleTokenSpeedChanged(const QString &tokenSpeed)
emit tokenSpeedChanged(); emit tokenSpeedChanged();
} }
void Chat::handleDatabaseResultsChanged(const QList<ResultInfo> &results)
{
m_databaseResults = results;
}
bool Chat::serialize(QDataStream &stream, int version) const bool Chat::serialize(QDataStream &stream, int version) const
{ {
stream << m_creationDate; stream << m_creationDate;

View File

@ -61,7 +61,7 @@ public:
Q_INVOKABLE void stopGenerating(); Q_INVOKABLE void stopGenerating();
Q_INVOKABLE void newPromptResponsePair(const QString &prompt); Q_INVOKABLE void newPromptResponsePair(const QString &prompt);
QList<ResultInfo> databaseResults() const; QList<ResultInfo> databaseResults() const { return m_databaseResults; }
QString response() const; QString response() const;
bool responseInProgress() const { return m_responseInProgress; } bool responseInProgress() const { return m_responseInProgress; }
@ -133,6 +133,7 @@ private Q_SLOTS:
void handleModelNameChanged(); void handleModelNameChanged();
void handleModelLoadingError(const QString &error); void handleModelLoadingError(const QString &error);
void handleTokenSpeedChanged(const QString &tokenSpeed); void handleTokenSpeedChanged(const QString &tokenSpeed);
void handleDatabaseResultsChanged(const QList<ResultInfo> &results);
private: private:
QString m_id; QString m_id;
@ -147,6 +148,7 @@ private:
ResponseState m_responseState; ResponseState m_responseState;
qint64 m_creationDate; qint64 m_creationDate;
ChatLLM *m_llmodel; ChatLLM *m_llmodel;
QList<ResultInfo> m_databaseResults;
bool m_isServer; bool m_isServer;
bool m_shouldDeleteLater; bool m_shouldDeleteLater;
}; };

View File

@ -413,15 +413,16 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
if (!isModelLoaded()) if (!isModelLoaded())
return false; return false;
m_databaseResults.clear(); QList<ResultInfo> databaseResults;
const int retrievalSize = LocalDocs::globalInstance()->retrievalSize(); const int retrievalSize = LocalDocs::globalInstance()->retrievalSize();
emit requestRetrieveFromDB(m_chat->collectionList(), prompt, retrievalSize, &m_databaseResults); // blocks emit requestRetrieveFromDB(m_chat->collectionList(), prompt, retrievalSize, &databaseResults); // blocks
emit databaseResultsChanged(databaseResults);
// Augment the prompt template with the results if any // Augment the prompt template with the results if any
QList<QString> augmentedTemplate; QList<QString> augmentedTemplate;
if (!m_databaseResults.isEmpty()) if (!databaseResults.isEmpty())
augmentedTemplate.append("### Context:"); augmentedTemplate.append("### Context:");
for (const ResultInfo &info : m_databaseResults) for (const ResultInfo &info : databaseResults)
augmentedTemplate.append(info.text); augmentedTemplate.append(info.text);
augmentedTemplate.append(prompt_template); augmentedTemplate.append(prompt_template);

View File

@ -81,7 +81,6 @@ public:
void regenerateResponse(); void regenerateResponse();
void resetResponse(); void resetResponse();
void resetContext(); void resetContext();
QList<ResultInfo> databaseResults() const { return m_databaseResults; }
void stopGenerating() { m_stopGenerating = true; } void stopGenerating() { m_stopGenerating = true; }
@ -131,6 +130,7 @@ Q_SIGNALS:
void shouldBeLoadedChanged(); void shouldBeLoadedChanged();
void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results); void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
void reportSpeed(const QString &speed); void reportSpeed(const QString &speed);
void databaseResultsChanged(const QList<ResultInfo>&);
protected: protected:
bool handlePrompt(int32_t token); bool handlePrompt(int32_t token);
@ -157,7 +157,6 @@ protected:
QThread m_llmThread; QThread m_llmThread;
std::atomic<bool> m_stopGenerating; std::atomic<bool> m_stopGenerating;
std::atomic<bool> m_shouldBeLoaded; std::atomic<bool> m_shouldBeLoaded;
QList<ResultInfo> m_databaseResults;
bool m_isRecalc; bool m_isRecalc;
bool m_isServer; bool m_isServer;
bool m_isChatGPT; bool m_isChatGPT;

View File

@ -22,10 +22,12 @@ Q_SIGNALS:
private Q_SLOTS: private Q_SLOTS:
QHttpServerResponse handleCompletionRequest(const QHttpServerRequest &request, bool isChat); QHttpServerResponse handleCompletionRequest(const QHttpServerRequest &request, bool isChat);
void handleDatabaseResultsChanged(const QList<ResultInfo> &results) { m_databaseResults = results; }
private: private:
Chat *m_chat; Chat *m_chat;
QHttpServer *m_server; QHttpServer *m_server;
QList<ResultInfo> m_databaseResults;
}; };
#endif // SERVER_H #endif // SERVER_H