From f1f60d6ef8bb4b873bb42c8f66258aef6db01b2a Mon Sep 17 00:00:00 2001 From: Jared Van Bortel Date: Wed, 7 Aug 2024 16:08:43 -0400 Subject: [PATCH] chatllm: clean up API Some functions did not need to be public or did not need to exist at all. Signed-off-by: Jared Van Bortel --- gpt4all-chat/chat.h | 2 +- gpt4all-chat/chatllm.cpp | 11 +++----- gpt4all-chat/chatllm.h | 58 +++++++++++++++++++--------------------- 3 files changed, 33 insertions(+), 38 deletions(-) diff --git a/gpt4all-chat/chat.h b/gpt4all-chat/chat.h index 065c624e..cb1b1ccc 100644 --- a/gpt4all-chat/chat.h +++ b/gpt4all-chat/chat.h @@ -161,7 +161,7 @@ Q_SIGNALS: private Q_SLOTS: void handleResponseChanged(const QString &response); - void handleModelLoadingPercentageChanged(float); + void handleModelLoadingPercentageChanged(float loadingPercentage); void promptProcessing(); void generatingQuestions(); void responseStopped(qint64 promptResponseMs); diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index b386d0ce..6b6c0f02 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -653,23 +653,20 @@ QString ChatLLM::response() const return QString::fromStdString(remove_leading_whitespace(m_response)); } -ModelInfo ChatLLM::modelInfo() const -{ - return m_modelInfo; -} - void ChatLLM::setModelInfo(const ModelInfo &modelInfo) { m_modelInfo = modelInfo; emit modelInfoChanged(modelInfo); } -void ChatLLM::acquireModel() { +void ChatLLM::acquireModel() +{ m_llModelInfo = LLModelStore::globalInstance()->acquireModel(); emit loadedModelInfoChanged(); } -void ChatLLM::resetModel() { +void ChatLLM::resetModel() +{ m_llModelInfo = {}; emit loadedModelInfoChanged(); } diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h index 68f4e0f1..18ccb897 100644 --- a/gpt4all-chat/chatllm.h +++ b/gpt4all-chat/chatllm.h @@ -1,5 +1,4 @@ -#ifndef CHATLLM_H -#define CHATLLM_H +#pragma once #include "database.h" // IWYU pragma: keep #include "modellist.h" @@ -27,6 +26,8 @@ using namespace Qt::Literals::StringLiterals; +class Chat; +class ChatLLM; class QDataStream; // NOTE: values serialized to disk, do not change or reuse @@ -37,8 +38,6 @@ enum LLModelType { BERT_ = 3, // no longer used }; -class ChatLLM; - struct LLModelInfo { std::unique_ptr model; QFileInfo fileInfo; @@ -90,7 +89,6 @@ private: quint32 m_tokens; }; -class Chat; class ChatLLM : public QObject { Q_OBJECT @@ -98,35 +96,28 @@ class ChatLLM : public QObject Q_PROPERTY(QString deviceBackend READ deviceBackend NOTIFY loadedModelInfoChanged) Q_PROPERTY(QString device READ device NOTIFY loadedModelInfoChanged) Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY loadedModelInfoChanged) + public: ChatLLM(Chat *parent, bool isServer = false); virtual ~ChatLLM(); void destroy(); static void destroyStore(); - bool isModelLoaded() const; void regenerateResponse(); void resetResponse(); void resetContext(); void stopGenerating() { m_stopGenerating = true; } - bool shouldBeLoaded() const { return m_shouldBeLoaded; } void setShouldBeLoaded(bool b); void requestTrySwitchContext(); void setForceUnloadModel(bool b) { m_forceUnloadModel = b; } void setMarkedForDeletion(bool b) { m_markedForDeletion = b; } - QString response() const; - - ModelInfo modelInfo() const; void setModelInfo(const ModelInfo &info); bool restoringFromText() const { return m_restoringFromText; } - void acquireModel(); - void resetModel(); - QString deviceBackend() const { auto *lcppmodel = dynamic_cast(m_llModelInfo.model.get()); @@ -150,8 +141,6 @@ public: return m_llModelInfo.fallbackReason.value_or(u""_s); } - QString generatedName() const { return QString::fromStdString(m_nameResponse); } - bool serialize(QDataStream &stream, int version, bool serializeKV); bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV); void setStateFromText(const QVector> &stateFromText) { m_stateFromText = stateFromText; } @@ -159,20 +148,10 @@ public: public Q_SLOTS: bool prompt(const QList &collectionList, const QString &prompt); bool loadDefaultModel(); - void trySwitchContextOfLoadedModel(const ModelInfo &modelInfo); bool loadModel(const ModelInfo &modelInfo); void modelChangeRequested(const ModelInfo &modelInfo); - void unloadModel(); - void reloadModel(); void generateName(); - void generateQuestions(qint64 elapsed); - void handleChatIdChanged(const QString &id); - void handleShouldBeLoadedChanged(); - void handleThreadStarted(); - void handleForceMetalChanged(bool forceMetal); - void handleDeviceChanged(); void processSystemPrompt(); - void processRestoreStateFromText(); Q_SIGNALS: void restoringFromTextChanged(); @@ -195,10 +174,13 @@ Q_SIGNALS: void reportSpeed(const QString &speed); void reportDevice(const QString &device); void reportFallbackReason(const QString &fallbackReason); - void databaseResultsChanged(const QList&); + void databaseResultsChanged(const QList &results); void modelInfoChanged(const ModelInfo &modelInfo); protected: + bool isModelLoaded() const; + void acquireModel(); + void resetModel(); bool promptInternal(const QList &collectionList, const QString &prompt, const QString &promptTemplate, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens); @@ -215,14 +197,32 @@ protected: void saveState(); void restoreState(); + // used by Server class + ModelInfo modelInfo() const { return m_modelInfo; } + QString response() const; + QString generatedName() const { return QString::fromStdString(m_nameResponse); } + +protected Q_SLOTS: + void trySwitchContextOfLoadedModel(const ModelInfo &modelInfo); + void unloadModel(); + void reloadModel(); + void generateQuestions(qint64 elapsed); + void handleChatIdChanged(const QString &id); + void handleShouldBeLoadedChanged(); + void handleThreadStarted(); + void handleForceMetalChanged(bool forceMetal); + void handleDeviceChanged(); + void processRestoreStateFromText(); + +private: + bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps); + protected: LLModel::PromptContext m_ctx; quint32 m_promptTokens; quint32 m_promptResponseTokens; private: - bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps); - std::string m_response; std::string m_nameResponse; QString m_questionResponse; @@ -248,5 +248,3 @@ private: bool m_pristineLoadedState = false; QVector> m_stateFromText; }; - -#endif // CHATLLM_H