chatllm: clean up API

Some functions did not need to be public or did not need to exist at
all.

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Author: Jared Van Bortel <jared@nomic.ai>
Date:   2024-08-07 16:08:43 -04:00
Parent: 595501fcde
Commit: f1f60d6ef8

3 changed files with 33 additions and 38 deletions
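
In effect, ChatLLM's model-management entry points (isModelLoaded, acquireModel, resetModel, modelInfo, response, generatedName, and most of the slots) move from public to protected, so they stay reachable from subclasses — per the "// used by Server class" comment, the Server class — while disappearing from the public API. A minimal sketch of that pattern, with hypothetical stand-in names rather than the actual gpt4all sources:

    #include <cassert>

    // Sketch of the visibility cleanup; ChatLLM/Server stand in for the
    // real classes, and m_loaded is a hypothetical member.
    class ChatLLM {
    protected:
        // formerly public; now visible to subclasses only
        bool isModelLoaded() const { return m_loaded; }

    private:
        bool m_loaded = false;
    };

    class Server : public ChatLLM {
    public:
        // subclasses keep access to the protected API
        bool ready() const { return isModelLoaded(); }
    };

    int main() {
        Server s;
        assert(!s.ready());
        // ChatLLM llm;
        // llm.isModelLoaded(); // error: 'isModelLoaded' is protected
    }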

chat.h

@@ -161,7 +161,7 @@ Q_SIGNALS:
 private Q_SLOTS:
     void handleResponseChanged(const QString &response);
-    void handleModelLoadingPercentageChanged(float);
+    void handleModelLoadingPercentageChanged(float loadingPercentage);
     void promptProcessing();
     void generatingQuestions();
     void responseStopped(qint64 promptResponseMs);
chatllm.cpp

@@ -653,23 +653,20 @@ QString ChatLLM::response() const
     return QString::fromStdString(remove_leading_whitespace(m_response));
 }
 
-ModelInfo ChatLLM::modelInfo() const
-{
-    return m_modelInfo;
-}
-
 void ChatLLM::setModelInfo(const ModelInfo &modelInfo)
 {
     m_modelInfo = modelInfo;
     emit modelInfoChanged(modelInfo);
 }
 
-void ChatLLM::acquireModel() {
+void ChatLLM::acquireModel()
+{
     m_llModelInfo = LLModelStore::globalInstance()->acquireModel();
     emit loadedModelInfoChanged();
 }
 
-void ChatLLM::resetModel() {
+void ChatLLM::resetModel()
+{
     m_llModelInfo = {};
     emit loadedModelInfoChanged();
 }

chatllm.h

@@ -1,5 +1,4 @@
-#ifndef CHATLLM_H
-#define CHATLLM_H
+#pragma once
 
 #include "database.h" // IWYU pragma: keep
 #include "modellist.h"
@@ -27,6 +26,8 @@
 using namespace Qt::Literals::StringLiterals;
 
 class Chat;
+class ChatLLM;
+class QDataStream;
 
 // NOTE: values serialized to disk, do not change or reuse
 enum LLModelType {
@@ -37,8 +38,6 @@ enum LLModelType {
     BERT_ = 3, // no longer used
 };
 
-class ChatLLM;
-
 struct LLModelInfo {
     std::unique_ptr<LLModel> model;
     QFileInfo fileInfo;
@@ -90,7 +89,6 @@ private:
     quint32 m_tokens;
 };
 
-class Chat;
 class ChatLLM : public QObject
 {
     Q_OBJECT
@@ -98,35 +96,28 @@ class ChatLLM : public QObject
     Q_PROPERTY(QString deviceBackend READ deviceBackend NOTIFY loadedModelInfoChanged)
     Q_PROPERTY(QString device READ device NOTIFY loadedModelInfoChanged)
     Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY loadedModelInfoChanged)
 public:
     ChatLLM(Chat *parent, bool isServer = false);
     virtual ~ChatLLM();
 
     void destroy();
     static void destroyStore();
-    bool isModelLoaded() const;
     void regenerateResponse();
     void resetResponse();
     void resetContext();
 
     void stopGenerating() { m_stopGenerating = true; }
 
     bool shouldBeLoaded() const { return m_shouldBeLoaded; }
     void setShouldBeLoaded(bool b);
     void requestTrySwitchContext();
     void setForceUnloadModel(bool b) { m_forceUnloadModel = b; }
     void setMarkedForDeletion(bool b) { m_markedForDeletion = b; }
 
-    QString response() const;
-
-    ModelInfo modelInfo() const;
     void setModelInfo(const ModelInfo &info);
 
     bool restoringFromText() const { return m_restoringFromText; }
 
-    void acquireModel();
-    void resetModel();
-
     QString deviceBackend() const
     {
         auto *lcppmodel = dynamic_cast<LlamaCppBackend *>(m_llModelInfo.model.get());
@@ -150,8 +141,6 @@ public:
         return m_llModelInfo.fallbackReason.value_or(u""_s);
     }
 
-    QString generatedName() const { return QString::fromStdString(m_nameResponse); }
-
     bool serialize(QDataStream &stream, int version, bool serializeKV);
     bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV);
     void setStateFromText(const QVector<QPair<QString, QString>> &stateFromText) { m_stateFromText = stateFromText; }
@@ -159,20 +148,10 @@ public:
 public Q_SLOTS:
     bool prompt(const QList<QString> &collectionList, const QString &prompt);
     bool loadDefaultModel();
-    void trySwitchContextOfLoadedModel(const ModelInfo &modelInfo);
     bool loadModel(const ModelInfo &modelInfo);
     void modelChangeRequested(const ModelInfo &modelInfo);
-    void unloadModel();
-    void reloadModel();
     void generateName();
-    void generateQuestions(qint64 elapsed);
-    void handleChatIdChanged(const QString &id);
-    void handleShouldBeLoadedChanged();
-    void handleThreadStarted();
-    void handleForceMetalChanged(bool forceMetal);
-    void handleDeviceChanged();
     void processSystemPrompt();
-    void processRestoreStateFromText();
 
 Q_SIGNALS:
     void restoringFromTextChanged();
@@ -195,10 +174,13 @@ Q_SIGNALS:
     void reportSpeed(const QString &speed);
     void reportDevice(const QString &device);
     void reportFallbackReason(const QString &fallbackReason);
-    void databaseResultsChanged(const QList<ResultInfo>&);
+    void databaseResultsChanged(const QList<ResultInfo> &results);
     void modelInfoChanged(const ModelInfo &modelInfo);
 
 protected:
+    bool isModelLoaded() const;
+    void acquireModel();
+    void resetModel();
     bool promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
         int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
         int32_t repeat_penalty_tokens);
@@ -215,14 +197,32 @@ protected:
     void saveState();
     void restoreState();
 
     // used by Server class
+    ModelInfo modelInfo() const { return m_modelInfo; }
+    QString response() const;
+    QString generatedName() const { return QString::fromStdString(m_nameResponse); }
+
+protected Q_SLOTS:
+    void trySwitchContextOfLoadedModel(const ModelInfo &modelInfo);
+    void unloadModel();
+    void reloadModel();
+    void generateQuestions(qint64 elapsed);
+    void handleChatIdChanged(const QString &id);
+    void handleShouldBeLoadedChanged();
+    void handleThreadStarted();
+    void handleForceMetalChanged(bool forceMetal);
+    void handleDeviceChanged();
+    void processRestoreStateFromText();
+
+private:
+    bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps);
+
+protected:
     LLModel::PromptContext m_ctx;
     quint32 m_promptTokens;
     quint32 m_promptResponseTokens;
 
 private:
-    bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps);
-
     std::string m_response;
     std::string m_nameResponse;
     QString m_questionResponse;
@@ -248,5 +248,3 @@ private:
     bool m_pristineLoadedState = false;
     QVector<QPair<QString, QString>> m_stateFromText;
 };
-
-#endif // CHATLLM_H