chatllm: clean up API

Some functions did not need to be public or did not need to exist at
all.

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Jared Van Bortel 2024-08-07 16:08:43 -04:00
parent 595501fcde
commit f1f60d6ef8
3 changed files with 33 additions and 38 deletions
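The theme of the diff below is narrowing ChatLLM's public surface: helpers and accessors that only ChatLLM itself and its Server subclass need (per the "// used by Server class" comment) become protected, and declarations that were never needed are dropped. A minimal sketch of that pattern, using hypothetical names rather than the real gpt4all classes:

    // Hypothetical example: internals reachable only from a trusted subclass,
    // mirroring the move of modelInfo()/response()/acquireModel() to protected.
    class ChatWorker
    {
    public:
        void setModel(int id) { m_modelId = id; }  // still part of the public API

    protected:
        // no longer public: code outside the hierarchy cannot depend on these
        int  model() const { return m_modelId; }
        void acquire() { /* grab shared resources */ }
        void reset()   { /* release them */ }

    private:
        int m_modelId = 0;
    };

    class ServerWorker : public ChatWorker
    {
    public:
        int currentModel() const { return model(); }  // derived class still has access
    };

    int main()
    {
        ServerWorker s;
        s.setModel(42);
        // s.model();  // would not compile: 'model' is protected
        return s.currentModel() == 42 ? 0 : 1;
    }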


@@ -161,7 +161,7 @@ Q_SIGNALS:
 
 private Q_SLOTS:
     void handleResponseChanged(const QString &response);
-    void handleModelLoadingPercentageChanged(float);
+    void handleModelLoadingPercentageChanged(float loadingPercentage);
     void promptProcessing();
     void generatingQuestions();
     void responseStopped(qint64 promptResponseMs);


@@ -653,23 +653,20 @@ QString ChatLLM::response() const
     return QString::fromStdString(remove_leading_whitespace(m_response));
 }
 
-ModelInfo ChatLLM::modelInfo() const
-{
-    return m_modelInfo;
-}
-
 void ChatLLM::setModelInfo(const ModelInfo &modelInfo)
 {
     m_modelInfo = modelInfo;
     emit modelInfoChanged(modelInfo);
 }
 
-void ChatLLM::acquireModel() {
+void ChatLLM::acquireModel()
+{
     m_llModelInfo = LLModelStore::globalInstance()->acquireModel();
     emit loadedModelInfoChanged();
 }
 
-void ChatLLM::resetModel() {
+void ChatLLM::resetModel()
+{
     m_llModelInfo = {};
     emit loadedModelInfoChanged();
 }


@@ -1,5 +1,4 @@
-#ifndef CHATLLM_H
-#define CHATLLM_H
+#pragma once
 
 #include "database.h" // IWYU pragma: keep
 #include "modellist.h"
@@ -27,6 +26,8 @@
 
 using namespace Qt::Literals::StringLiterals;
 
+class Chat;
+class ChatLLM;
 class QDataStream;
 
 // NOTE: values serialized to disk, do not change or reuse
@@ -37,8 +38,6 @@ enum LLModelType {
     BERT_ = 3, // no longer used
 };
 
-class ChatLLM;
-
 struct LLModelInfo {
     std::unique_ptr<LLModel> model;
     QFileInfo fileInfo;
@@ -90,7 +89,6 @@ private:
     quint32 m_tokens;
 };
 
-class Chat;
 class ChatLLM : public QObject
 {
     Q_OBJECT
@@ -98,35 +96,28 @@ class ChatLLM : public QObject
     Q_PROPERTY(QString deviceBackend READ deviceBackend NOTIFY loadedModelInfoChanged)
     Q_PROPERTY(QString device READ device NOTIFY loadedModelInfoChanged)
     Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY loadedModelInfoChanged)
+
 public:
     ChatLLM(Chat *parent, bool isServer = false);
     virtual ~ChatLLM();
 
     void destroy();
     static void destroyStore();
-    bool isModelLoaded() const;
     void regenerateResponse();
     void resetResponse();
     void resetContext();
 
     void stopGenerating() { m_stopGenerating = true; }
 
-    bool shouldBeLoaded() const { return m_shouldBeLoaded; }
     void setShouldBeLoaded(bool b);
     void requestTrySwitchContext();
     void setForceUnloadModel(bool b) { m_forceUnloadModel = b; }
     void setMarkedForDeletion(bool b) { m_markedForDeletion = b; }
 
-    QString response() const;
-
-    ModelInfo modelInfo() const;
     void setModelInfo(const ModelInfo &info);
 
     bool restoringFromText() const { return m_restoringFromText; }
 
-    void acquireModel();
-    void resetModel();
-
     QString deviceBackend() const
     {
         auto *lcppmodel = dynamic_cast<LlamaCppBackend *>(m_llModelInfo.model.get());
@@ -150,8 +141,6 @@ public:
         return m_llModelInfo.fallbackReason.value_or(u""_s);
     }
 
-    QString generatedName() const { return QString::fromStdString(m_nameResponse); }
-
     bool serialize(QDataStream &stream, int version, bool serializeKV);
     bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV);
     void setStateFromText(const QVector<QPair<QString, QString>> &stateFromText) { m_stateFromText = stateFromText; }
@@ -159,20 +148,10 @@ public:
 public Q_SLOTS:
     bool prompt(const QList<QString> &collectionList, const QString &prompt);
     bool loadDefaultModel();
-    void trySwitchContextOfLoadedModel(const ModelInfo &modelInfo);
     bool loadModel(const ModelInfo &modelInfo);
     void modelChangeRequested(const ModelInfo &modelInfo);
-    void unloadModel();
-    void reloadModel();
     void generateName();
-    void generateQuestions(qint64 elapsed);
-    void handleChatIdChanged(const QString &id);
-    void handleShouldBeLoadedChanged();
-    void handleThreadStarted();
-    void handleForceMetalChanged(bool forceMetal);
-    void handleDeviceChanged();
     void processSystemPrompt();
-    void processRestoreStateFromText();
 
 Q_SIGNALS:
     void restoringFromTextChanged();
@@ -195,10 +174,13 @@ Q_SIGNALS:
     void reportSpeed(const QString &speed);
     void reportDevice(const QString &device);
     void reportFallbackReason(const QString &fallbackReason);
-    void databaseResultsChanged(const QList<ResultInfo>&);
+    void databaseResultsChanged(const QList<ResultInfo> &results);
     void modelInfoChanged(const ModelInfo &modelInfo);
 
 protected:
+    bool isModelLoaded() const;
+    void acquireModel();
+    void resetModel();
     bool promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
                         int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
                         int32_t repeat_penalty_tokens);
@@ -215,14 +197,32 @@ protected:
     void saveState();
     void restoreState();
 
+    // used by Server class
+    ModelInfo modelInfo() const { return m_modelInfo; }
+    QString response() const;
+    QString generatedName() const { return QString::fromStdString(m_nameResponse); }
+
+protected Q_SLOTS:
+    void trySwitchContextOfLoadedModel(const ModelInfo &modelInfo);
+    void unloadModel();
+    void reloadModel();
+    void generateQuestions(qint64 elapsed);
+    void handleChatIdChanged(const QString &id);
+    void handleShouldBeLoadedChanged();
+    void handleThreadStarted();
+    void handleForceMetalChanged(bool forceMetal);
+    void handleDeviceChanged();
+    void processRestoreStateFromText();
+
+private:
+    bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps);
+
 protected:
     LLModel::PromptContext m_ctx;
     quint32 m_promptTokens;
     quint32 m_promptResponseTokens;
 
 private:
-    bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps);
-
     std::string m_response;
     std::string m_nameResponse;
     QString m_questionResponse;
@@ -248,5 +248,3 @@ private:
     bool m_pristineLoadedState = false;
     QVector<QPair<QString, QString>> m_stateFromText;
 };
-
-#endif // CHATLLM_H
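The handler slots moved into protected Q_SLOTS keep working as long as the connections targeting them are made inside the class itself (where protected members are accessible) or through the meta-object system; only direct calls from unrelated code stop compiling. A minimal sketch under that assumption, with a hypothetical Worker class rather than the real ChatLLM:

    #include <QCoreApplication>
    #include <QObject>

    class Worker : public QObject
    {
        Q_OBJECT

    public:
        explicit Worker(QObject *parent = nullptr)
            : QObject(parent)
        {
            // Connection made inside the class, where the protected slot is accessible.
            connect(this, &Worker::shouldBeLoadedChanged, this, &Worker::handleShouldBeLoadedChanged);
        }

        void setShouldBeLoaded(bool b)
        {
            m_shouldBeLoaded = b;
            emit shouldBeLoadedChanged();
        }

    Q_SIGNALS:
        void shouldBeLoadedChanged();

    protected Q_SLOTS:
        void handleShouldBeLoadedChanged()
        {
            qInfo("shouldBeLoaded is now %d", int(m_shouldBeLoaded));
        }

    private:
        bool m_shouldBeLoaded = false;
    };

    int main(int argc, char *argv[])
    {
        QCoreApplication app(argc, argv);
        Worker w;
        w.setShouldBeLoaded(true);          // fires the protected slot through the connection
        // w.handleShouldBeLoadedChanged(); // would not compile outside the class
        return 0;
    }

    // Note: a single-file automoc build (CMake AUTOMOC / qmake) would also need
    // the generated moc output, e.g. #include "main.moc" after the class.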