Complete revamp of model loading to allow for more discreet control by

the user of the models loading behavior.

Signed-off-by: Adam Treat <treat.adam@gmail.com>
This commit is contained in:
Adam Treat 2024-02-07 09:37:59 -05:00 committed by AT
parent f2024a1f9e
commit d948a4f2ee
14 changed files with 506 additions and 175 deletions

View File

@ -180,6 +180,9 @@ bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
d_ptr->model_params.use_mlock = params.use_mlock; d_ptr->model_params.use_mlock = params.use_mlock;
#endif #endif
d_ptr->model_params.progress_callback = &LLModel::staticProgressCallback;
d_ptr->model_params.progress_callback_user_data = this;
#ifdef GGML_USE_METAL #ifdef GGML_USE_METAL
if (llama_verbose()) { if (llama_verbose()) {
std::cerr << "llama.cpp: using Metal" << std::endl; std::cerr << "llama.cpp: using Metal" << std::endl;

View File

@ -74,6 +74,8 @@ public:
int32_t n_last_batch_tokens = 0; int32_t n_last_batch_tokens = 0;
}; };
using ProgressCallback = std::function<bool(float progress)>;
explicit LLModel() {} explicit LLModel() {}
virtual ~LLModel() {} virtual ~LLModel() {}
@ -125,6 +127,8 @@ public:
virtual bool hasGPUDevice() { return false; } virtual bool hasGPUDevice() { return false; }
virtual bool usingGPUDevice() { return false; } virtual bool usingGPUDevice() { return false; }
void setProgressCallback(ProgressCallback callback) { m_progressCallback = callback; }
protected: protected:
// These are pure virtual because subclasses need to implement as the default implementation of // These are pure virtual because subclasses need to implement as the default implementation of
// 'prompt' above calls these functions // 'prompt' above calls these functions
@ -153,6 +157,15 @@ protected:
const Implementation *m_implementation = nullptr; const Implementation *m_implementation = nullptr;
ProgressCallback m_progressCallback;
static bool staticProgressCallback(float progress, void* ctx)
{
LLModel* model = static_cast<LLModel*>(ctx);
if (model && model->m_progressCallback)
return model->m_progressCallback(progress);
return true;
}
private: private:
friend class LLMImplementation; friend class LLMImplementation;
}; };

View File

@ -109,6 +109,7 @@ qt_add_qml_module(chat
qml/ModelSettings.qml qml/ModelSettings.qml
qml/ApplicationSettings.qml qml/ApplicationSettings.qml
qml/LocalDocsSettings.qml qml/LocalDocsSettings.qml
qml/SwitchModelDialog.qml
qml/MySettingsTab.qml qml/MySettingsTab.qml
qml/MySettingsStack.qml qml/MySettingsStack.qml
qml/MySettingsDestructiveButton.qml qml/MySettingsDestructiveButton.qml
@ -123,6 +124,7 @@ qt_add_qml_module(chat
qml/MyTextField.qml qml/MyTextField.qml
qml/MyCheckBox.qml qml/MyCheckBox.qml
qml/MyBusyIndicator.qml qml/MyBusyIndicator.qml
qml/MyMiniButton.qml
qml/MyToolButton.qml qml/MyToolButton.qml
RESOURCES RESOURCES
icons/send_message.svg icons/send_message.svg
@ -133,6 +135,7 @@ qt_add_qml_module(chat
icons/db.svg icons/db.svg
icons/download.svg icons/download.svg
icons/settings.svg icons/settings.svg
icons/eject.svg
icons/edit.svg icons/edit.svg
icons/image.svg icons/image.svg
icons/trash.svg icons/trash.svg

View File

@ -23,14 +23,10 @@ Chat::Chat(bool isServer, QObject *parent)
, m_id(Network::globalInstance()->generateUniqueId()) , m_id(Network::globalInstance()->generateUniqueId())
, m_name(tr("Server Chat")) , m_name(tr("Server Chat"))
, m_chatModel(new ChatModel(this)) , m_chatModel(new ChatModel(this))
, m_responseInProgress(false)
, m_responseState(Chat::ResponseStopped) , m_responseState(Chat::ResponseStopped)
, m_creationDate(QDateTime::currentSecsSinceEpoch()) , m_creationDate(QDateTime::currentSecsSinceEpoch())
, m_llmodel(new Server(this)) , m_llmodel(new Server(this))
, m_isServer(true) , m_isServer(true)
, m_shouldDeleteLater(false)
, m_isModelLoaded(false)
, m_shouldLoadModelWhenInstalled(false)
, m_collectionModel(new LocalDocsCollectionsModel(this)) , m_collectionModel(new LocalDocsCollectionsModel(this))
{ {
connectLLM(); connectLLM();
@ -45,7 +41,7 @@ Chat::~Chat()
void Chat::connectLLM() void Chat::connectLLM()
{ {
// Should be in different threads // Should be in different threads
connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::handleModelLoadedChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::modelLoadingPercentageChanged, this, &Chat::handleModelLoadingPercentageChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
@ -57,6 +53,7 @@ void Chat::connectLLM()
connect(m_llmodel, &ChatLLM::reportFallbackReason, this, &Chat::handleFallbackReasonChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::reportFallbackReason, this, &Chat::handleFallbackReasonChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::trySwitchContextOfLoadedModelCompleted, this, &Chat::trySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection);
connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection); connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection);
connect(this, &Chat::modelChangeRequested, m_llmodel, &ChatLLM::modelChangeRequested, Qt::QueuedConnection); connect(this, &Chat::modelChangeRequested, m_llmodel, &ChatLLM::modelChangeRequested, Qt::QueuedConnection);
@ -69,8 +66,6 @@ void Chat::connectLLM()
connect(this, &Chat::processSystemPromptRequested, m_llmodel, &ChatLLM::processSystemPrompt, Qt::QueuedConnection); connect(this, &Chat::processSystemPromptRequested, m_llmodel, &ChatLLM::processSystemPrompt, Qt::QueuedConnection);
connect(this, &Chat::collectionListChanged, m_collectionModel, &LocalDocsCollectionsModel::setCollections); connect(this, &Chat::collectionListChanged, m_collectionModel, &LocalDocsCollectionsModel::setCollections);
connect(ModelList::globalInstance()->installedModels(), &InstalledModels::countChanged,
this, &Chat::handleModelInstalled, Qt::QueuedConnection);
} }
void Chat::reset() void Chat::reset()
@ -101,7 +96,12 @@ void Chat::processSystemPrompt()
bool Chat::isModelLoaded() const bool Chat::isModelLoaded() const
{ {
return m_isModelLoaded; return m_modelLoadingPercentage == 1.0f;
}
float Chat::modelLoadingPercentage() const
{
return m_modelLoadingPercentage;
} }
void Chat::resetResponseState() void Chat::resetResponseState()
@ -158,15 +158,17 @@ void Chat::handleResponseChanged(const QString &response)
emit responseChanged(); emit responseChanged();
} }
void Chat::handleModelLoadedChanged(bool loaded) void Chat::handleModelLoadingPercentageChanged(float loadingPercentage)
{ {
if (m_shouldDeleteLater) if (m_shouldDeleteLater)
deleteLater(); deleteLater();
if (loaded == m_isModelLoaded) if (loadingPercentage == m_modelLoadingPercentage)
return; return;
m_isModelLoaded = loaded; m_modelLoadingPercentage = loadingPercentage;
emit modelLoadingPercentageChanged();
if (m_modelLoadingPercentage == 1.0f || m_modelLoadingPercentage == 0.0f)
emit isModelLoadedChanged(); emit isModelLoadedChanged();
} }
@ -238,10 +240,10 @@ ModelInfo Chat::modelInfo() const
void Chat::setModelInfo(const ModelInfo &modelInfo) void Chat::setModelInfo(const ModelInfo &modelInfo)
{ {
if (m_modelInfo == modelInfo) if (m_modelInfo == modelInfo && isModelLoaded())
return; return;
m_isModelLoaded = false; m_modelLoadingPercentage = std::numeric_limits<float>::min();
emit isModelLoadedChanged(); emit isModelLoadedChanged();
m_modelLoadingError = QString(); m_modelLoadingError = QString();
emit modelLoadingErrorChanged(); emit modelLoadingErrorChanged();
@ -291,21 +293,26 @@ void Chat::unloadModel()
void Chat::reloadModel() void Chat::reloadModel()
{ {
// If the installed model list is empty, then we mark a special flag and monitor for when a model
// is installed
if (!ModelList::globalInstance()->installedModels()->count()) {
m_shouldLoadModelWhenInstalled = true;
return;
}
m_llmodel->setShouldBeLoaded(true); m_llmodel->setShouldBeLoaded(true);
} }
void Chat::handleModelInstalled() void Chat::forceUnloadModel()
{ {
if (!m_shouldLoadModelWhenInstalled) stopGenerating();
return; m_llmodel->setForceUnloadModel(true);
m_shouldLoadModelWhenInstalled = false; m_llmodel->setShouldBeLoaded(false);
reloadModel(); }
void Chat::forceReloadModel()
{
m_llmodel->setForceUnloadModel(true);
m_llmodel->setShouldBeLoaded(true);
}
void Chat::trySwitchContextOfLoadedModel()
{
emit trySwitchContextOfLoadedModelAttempted();
m_llmodel->setShouldTrySwitchContext(true);
} }
void Chat::generatedNameChanged(const QString &name) void Chat::generatedNameChanged(const QString &name)

View File

@ -17,6 +17,7 @@ class Chat : public QObject
Q_PROPERTY(QString name READ name WRITE setName NOTIFY nameChanged) Q_PROPERTY(QString name READ name WRITE setName NOTIFY nameChanged)
Q_PROPERTY(ChatModel *chatModel READ chatModel NOTIFY chatModelChanged) Q_PROPERTY(ChatModel *chatModel READ chatModel NOTIFY chatModelChanged)
Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged) Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
Q_PROPERTY(float modelLoadingPercentage READ modelLoadingPercentage NOTIFY modelLoadingPercentageChanged)
Q_PROPERTY(QString response READ response NOTIFY responseChanged) Q_PROPERTY(QString response READ response NOTIFY responseChanged)
Q_PROPERTY(ModelInfo modelInfo READ modelInfo WRITE setModelInfo NOTIFY modelInfoChanged) Q_PROPERTY(ModelInfo modelInfo READ modelInfo WRITE setModelInfo NOTIFY modelInfoChanged)
Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged) Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
@ -61,6 +62,7 @@ public:
Q_INVOKABLE void reset(); Q_INVOKABLE void reset();
Q_INVOKABLE void processSystemPrompt(); Q_INVOKABLE void processSystemPrompt();
Q_INVOKABLE bool isModelLoaded() const; Q_INVOKABLE bool isModelLoaded() const;
Q_INVOKABLE float modelLoadingPercentage() const;
Q_INVOKABLE void prompt(const QString &prompt); Q_INVOKABLE void prompt(const QString &prompt);
Q_INVOKABLE void regenerateResponse(); Q_INVOKABLE void regenerateResponse();
Q_INVOKABLE void stopGenerating(); Q_INVOKABLE void stopGenerating();
@ -75,8 +77,11 @@ public:
void setModelInfo(const ModelInfo &modelInfo); void setModelInfo(const ModelInfo &modelInfo);
bool isRecalc() const; bool isRecalc() const;
void unloadModel(); Q_INVOKABLE void unloadModel();
void reloadModel(); Q_INVOKABLE void reloadModel();
Q_INVOKABLE void forceUnloadModel();
Q_INVOKABLE void forceReloadModel();
Q_INVOKABLE void trySwitchContextOfLoadedModel();
void unloadAndDeleteLater(); void unloadAndDeleteLater();
qint64 creationDate() const { return m_creationDate; } qint64 creationDate() const { return m_creationDate; }
@ -106,6 +111,7 @@ Q_SIGNALS:
void nameChanged(); void nameChanged();
void chatModelChanged(); void chatModelChanged();
void isModelLoadedChanged(); void isModelLoadedChanged();
void modelLoadingPercentageChanged();
void responseChanged(); void responseChanged();
void responseInProgressChanged(); void responseInProgressChanged();
void responseStateChanged(); void responseStateChanged();
@ -127,10 +133,12 @@ Q_SIGNALS:
void deviceChanged(); void deviceChanged();
void fallbackReasonChanged(); void fallbackReasonChanged();
void collectionModelChanged(); void collectionModelChanged();
void trySwitchContextOfLoadedModelAttempted();
void trySwitchContextOfLoadedModelCompleted(bool);
private Q_SLOTS: private Q_SLOTS:
void handleResponseChanged(const QString &response); void handleResponseChanged(const QString &response);
void handleModelLoadedChanged(bool); void handleModelLoadingPercentageChanged(float);
void promptProcessing(); void promptProcessing();
void responseStopped(); void responseStopped();
void generatedNameChanged(const QString &name); void generatedNameChanged(const QString &name);
@ -141,7 +149,6 @@ private Q_SLOTS:
void handleFallbackReasonChanged(const QString &device); void handleFallbackReasonChanged(const QString &device);
void handleDatabaseResultsChanged(const QList<ResultInfo> &results); void handleDatabaseResultsChanged(const QList<ResultInfo> &results);
void handleModelInfoChanged(const ModelInfo &modelInfo); void handleModelInfoChanged(const ModelInfo &modelInfo);
void handleModelInstalled();
private: private:
QString m_id; QString m_id;
@ -163,8 +170,7 @@ private:
QList<ResultInfo> m_databaseResults; QList<ResultInfo> m_databaseResults;
bool m_isServer = false; bool m_isServer = false;
bool m_shouldDeleteLater = false; bool m_shouldDeleteLater = false;
bool m_isModelLoaded = false; float m_modelLoadingPercentage = 0.0f;
bool m_shouldLoadModelWhenInstalled = false;
LocalDocsCollectionsModel *m_collectionModel; LocalDocsCollectionsModel *m_collectionModel;
}; };

View File

@ -179,9 +179,9 @@ public:
if (m_currentChat && m_currentChat != m_serverChat) if (m_currentChat && m_currentChat != m_serverChat)
m_currentChat->unloadModel(); m_currentChat->unloadModel();
m_currentChat = chat; m_currentChat = chat;
if (!m_currentChat->isModelLoaded() && m_currentChat != m_serverChat)
m_currentChat->reloadModel();
emit currentChatChanged(); emit currentChatChanged();
if (!m_currentChat->isModelLoaded() && m_currentChat != m_serverChat)
m_currentChat->trySwitchContextOfLoadedModel();
} }
Q_INVOKABLE Chat* get(int index) Q_INVOKABLE Chat* get(int index)

View File

@ -62,7 +62,9 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
, m_promptResponseTokens(0) , m_promptResponseTokens(0)
, m_promptTokens(0) , m_promptTokens(0)
, m_isRecalc(false) , m_isRecalc(false)
, m_shouldBeLoaded(true) , m_shouldBeLoaded(false)
, m_forceUnloadModel(false)
, m_shouldTrySwitchContext(false)
, m_stopGenerating(false) , m_stopGenerating(false)
, m_timer(nullptr) , m_timer(nullptr)
, m_isServer(isServer) , m_isServer(isServer)
@ -76,6 +78,8 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded); connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded);
connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged, connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
Qt::QueuedConnection); // explicitly queued Qt::QueuedConnection); // explicitly queued
connect(this, &ChatLLM::shouldTrySwitchContextChanged, this, &ChatLLM::handleShouldTrySwitchContextChanged,
Qt::QueuedConnection); // explicitly queued
connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged); connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted); connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
connect(MySettings::globalInstance(), &MySettings::forceMetalChanged, this, &ChatLLM::handleForceMetalChanged); connect(MySettings::globalInstance(), &MySettings::forceMetalChanged, this, &ChatLLM::handleForceMetalChanged);
@ -143,6 +147,54 @@ bool ChatLLM::loadDefaultModel()
return loadModel(defaultModel); return loadModel(defaultModel);
} }
bool ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
{
// We're trying to see if the store already has the model fully loaded that we wish to use
// and if so we just acquire it from the store and switch the context and return true. If the
// store doesn't have it or we're already loaded or in any other case just return false.
// If we're already loaded or a server or we're reloading to change the variant/device or the
// modelInfo is empty, then this should fail
if (isModelLoaded() || m_isServer || m_reloadingToChangeVariant || modelInfo.name().isEmpty()) {
m_shouldTrySwitchContext = false;
emit trySwitchContextOfLoadedModelCompleted(false);
return false;
}
QString filePath = modelInfo.dirpath + modelInfo.filename();
QFileInfo fileInfo(filePath);
m_llModelInfo = LLModelStore::globalInstance()->acquireModel();
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
// The store gave us no already loaded model, the wrong type of model, then give it back to the
// store and fail
if (!m_llModelInfo.model || m_llModelInfo.fileInfo != fileInfo) {
LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
m_llModelInfo = LLModelInfo();
m_shouldTrySwitchContext = false;
emit trySwitchContextOfLoadedModelCompleted(false);
return false;
}
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
// We should be loaded and now we are
m_shouldBeLoaded = true;
m_shouldTrySwitchContext = false;
// Restore, signal and process
restoreState();
emit modelLoadingPercentageChanged(1.0f);
emit trySwitchContextOfLoadedModelCompleted(true);
processSystemPrompt();
return true;
}
bool ChatLLM::loadModel(const ModelInfo &modelInfo) bool ChatLLM::loadModel(const ModelInfo &modelInfo)
{ {
// This is a complicated method because N different possible threads are interested in the outcome // This is a complicated method because N different possible threads are interested in the outcome
@ -170,7 +222,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
#endif #endif
delete m_llModelInfo.model; delete m_llModelInfo.model;
m_llModelInfo.model = nullptr; m_llModelInfo.model = nullptr;
emit isModelLoadedChanged(false); emit modelLoadingPercentageChanged(std::numeric_limits<float>::min());
} else if (!m_isServer) { } else if (!m_isServer) {
// This is a blocking call that tries to retrieve the model we need from the model store. // This is a blocking call that tries to retrieve the model we need from the model store.
// If it succeeds, then we just have to restore state. If the store has never had a model // If it succeeds, then we just have to restore state. If the store has never had a model
@ -188,7 +240,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
#endif #endif
LLModelStore::globalInstance()->releaseModel(m_llModelInfo); LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
m_llModelInfo = LLModelInfo(); m_llModelInfo = LLModelInfo();
emit isModelLoadedChanged(false); emit modelLoadingPercentageChanged(0.0f);
return false; return false;
} }
@ -198,7 +250,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model; qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model;
#endif #endif
restoreState(); restoreState();
emit isModelLoadedChanged(true); emit modelLoadingPercentageChanged(1.0f);
setModelInfo(modelInfo); setModelInfo(modelInfo);
Q_ASSERT(!m_modelInfo.filename().isEmpty()); Q_ASSERT(!m_modelInfo.filename().isEmpty());
if (m_modelInfo.filename().isEmpty()) if (m_modelInfo.filename().isEmpty())
@ -261,6 +313,12 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx); m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);
if (m_llModelInfo.model) { if (m_llModelInfo.model) {
m_llModelInfo.model->setProgressCallback([this](float progress) -> bool {
emit modelLoadingPercentageChanged(progress);
return m_shouldBeLoaded;
});
// Update the settings that a model is being loaded and update the device list // Update the settings that a model is being loaded and update the device list
MySettings::globalInstance()->setAttemptModelLoad(filePath); MySettings::globalInstance()->setAttemptModelLoad(filePath);
@ -354,7 +412,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
qDebug() << "modelLoadedChanged" << m_llmThread.objectName(); qDebug() << "modelLoadedChanged" << m_llmThread.objectName();
fflush(stdout); fflush(stdout);
#endif #endif
emit isModelLoadedChanged(isModelLoaded()); emit modelLoadingPercentageChanged(isModelLoaded() ? 1.0f : 0.0f);
static bool isFirstLoad = true; static bool isFirstLoad = true;
if (isFirstLoad) { if (isFirstLoad) {
@ -456,6 +514,7 @@ void ChatLLM::setModelInfo(const ModelInfo &modelInfo)
void ChatLLM::modelChangeRequested(const ModelInfo &modelInfo) void ChatLLM::modelChangeRequested(const ModelInfo &modelInfo)
{ {
m_shouldBeLoaded = true;
loadModel(modelInfo); loadModel(modelInfo);
} }
@ -598,6 +657,12 @@ void ChatLLM::setShouldBeLoaded(bool b)
emit shouldBeLoadedChanged(); emit shouldBeLoadedChanged();
} }
void ChatLLM::setShouldTrySwitchContext(bool b)
{
m_shouldTrySwitchContext = b; // atomic
emit shouldTrySwitchContextChanged();
}
void ChatLLM::handleShouldBeLoadedChanged() void ChatLLM::handleShouldBeLoadedChanged()
{ {
if (m_shouldBeLoaded) if (m_shouldBeLoaded)
@ -606,10 +671,10 @@ void ChatLLM::handleShouldBeLoadedChanged()
unloadModel(); unloadModel();
} }
void ChatLLM::forceUnloadModel() void ChatLLM::handleShouldTrySwitchContextChanged()
{ {
m_shouldBeLoaded = false; // atomic if (m_shouldTrySwitchContext)
unloadModel(); trySwitchContextOfLoadedModel(modelInfo());
} }
void ChatLLM::unloadModel() void ChatLLM::unloadModel()
@ -617,17 +682,27 @@ void ChatLLM::unloadModel()
if (!isModelLoaded() || m_isServer) if (!isModelLoaded() || m_isServer)
return; return;
emit modelLoadingPercentageChanged(0.0f);
saveState(); saveState();
#if defined(DEBUG_MODEL_LOADING) #if defined(DEBUG_MODEL_LOADING)
qDebug() << "unloadModel" << m_llmThread.objectName() << m_llModelInfo.model; qDebug() << "unloadModel" << m_llmThread.objectName() << m_llModelInfo.model;
#endif #endif
if (m_forceUnloadModel) {
delete m_llModelInfo.model;
m_llModelInfo.model = nullptr;
m_forceUnloadModel = false;
}
LLModelStore::globalInstance()->releaseModel(m_llModelInfo); LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
m_llModelInfo = LLModelInfo(); m_llModelInfo = LLModelInfo();
emit isModelLoadedChanged(false);
} }
void ChatLLM::reloadModel() void ChatLLM::reloadModel()
{ {
if (isModelLoaded() && m_forceUnloadModel)
unloadModel(); // we unload first if we are forcing an unload
if (isModelLoaded() || m_isServer) if (isModelLoaded() || m_isServer)
return; return;

View File

@ -81,6 +81,8 @@ public:
bool shouldBeLoaded() const { return m_shouldBeLoaded; } bool shouldBeLoaded() const { return m_shouldBeLoaded; }
void setShouldBeLoaded(bool b); void setShouldBeLoaded(bool b);
void setShouldTrySwitchContext(bool b);
void setForceUnloadModel(bool b) { m_forceUnloadModel = b; }
QString response() const; QString response() const;
@ -98,14 +100,15 @@ public:
public Q_SLOTS: public Q_SLOTS:
bool prompt(const QList<QString> &collectionList, const QString &prompt); bool prompt(const QList<QString> &collectionList, const QString &prompt);
bool loadDefaultModel(); bool loadDefaultModel();
bool trySwitchContextOfLoadedModel(const ModelInfo &modelInfo);
bool loadModel(const ModelInfo &modelInfo); bool loadModel(const ModelInfo &modelInfo);
void modelChangeRequested(const ModelInfo &modelInfo); void modelChangeRequested(const ModelInfo &modelInfo);
void forceUnloadModel();
void unloadModel(); void unloadModel();
void reloadModel(); void reloadModel();
void generateName(); void generateName();
void handleChatIdChanged(const QString &id); void handleChatIdChanged(const QString &id);
void handleShouldBeLoadedChanged(); void handleShouldBeLoadedChanged();
void handleShouldTrySwitchContextChanged();
void handleThreadStarted(); void handleThreadStarted();
void handleForceMetalChanged(bool forceMetal); void handleForceMetalChanged(bool forceMetal);
void handleDeviceChanged(); void handleDeviceChanged();
@ -114,7 +117,7 @@ public Q_SLOTS:
Q_SIGNALS: Q_SIGNALS:
void recalcChanged(); void recalcChanged();
void isModelLoadedChanged(bool); void modelLoadingPercentageChanged(float);
void modelLoadingError(const QString &error); void modelLoadingError(const QString &error);
void responseChanged(const QString &response); void responseChanged(const QString &response);
void promptProcessing(); void promptProcessing();
@ -125,6 +128,8 @@ Q_SIGNALS:
void stateChanged(); void stateChanged();
void threadStarted(); void threadStarted();
void shouldBeLoadedChanged(); void shouldBeLoadedChanged();
void shouldTrySwitchContextChanged();
void trySwitchContextOfLoadedModelCompleted(bool);
void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results); void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
void reportSpeed(const QString &speed); void reportSpeed(const QString &speed);
void reportDevice(const QString &device); void reportDevice(const QString &device);
@ -167,7 +172,9 @@ private:
QThread m_llmThread; QThread m_llmThread;
std::atomic<bool> m_stopGenerating; std::atomic<bool> m_stopGenerating;
std::atomic<bool> m_shouldBeLoaded; std::atomic<bool> m_shouldBeLoaded;
std::atomic<bool> m_shouldTrySwitchContext;
std::atomic<bool> m_isRecalc; std::atomic<bool> m_isRecalc;
std::atomic<bool> m_forceUnloadModel;
bool m_isServer; bool m_isServer;
bool m_forceMetal; bool m_forceMetal;
bool m_reloadingToChangeVariant; bool m_reloadingToChangeVariant;

View File

@ -0,0 +1,6 @@
<svg xmlns="http://www.w3.org/2000/svg" fill="#7d7d8e" viewBox="0 0 448 512"><path d="M448 384v64c0 17.673-14.327 32-32 32H32c-17.673 0-32-14.327-32-32v-64c0-17.673 14.327-32 32-32h384c17.673 0 32 14.327 32 32zM48.053 320h351.886c41.651 0 63.581-49.674 35.383-80.435L259.383 47.558c-19.014-20.743-51.751-20.744-70.767 0L12.67 239.565C-15.475 270.268 6.324 320 48.053 320z"/></svg>
<!--
Font Awesome Free 5.2.0 by @fontawesome - https://fontawesome.com
License - https://fontawesome.com/license (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License)
-->

After

Width:  |  Height:  |  Size: 557 B

View File

@ -126,6 +126,10 @@ Window {
} }
} }
function currentModelName() {
return ModelList.modelInfo(currentChat.modelInfo.id).name;
}
PopupDialog { PopupDialog {
id: errorCompatHardware id: errorCompatHardware
anchors.centerIn: parent anchors.centerIn: parent
@ -282,6 +286,18 @@ Window {
} }
} }
SwitchModelDialog {
id: switchModelDialog
anchors.centerIn: parent
width: Math.min(1024, window.width - (window.width * .2))
height: Math.min(600, window.height - (window.height * .2))
Item {
Accessible.role: Accessible.Dialog
Accessible.name: qsTr("Switch model dialog")
Accessible.description: qsTr("Warn the user if they switch models, then context will be erased")
}
}
Rectangle { Rectangle {
id: header id: header
anchors.left: parent.left anchors.left: parent.left
@ -292,7 +308,9 @@ Window {
Item { Item {
anchors.centerIn: parent anchors.centerIn: parent
height: childrenRect.height height: childrenRect.height
visible: currentChat.isModelLoaded || currentChat.modelLoadingError !== "" || currentChat.isServer visible: true
|| currentChat.modelLoadingError !== ""
|| currentChat.isServer
Label { Label {
id: modelLabel id: modelLabel
@ -306,46 +324,92 @@ Window {
horizontalAlignment: TextInput.AlignRight horizontalAlignment: TextInput.AlignRight
} }
MyComboBox { RowLayout {
id: comboBox id: comboLayout
implicitWidth: 375
width: window.width >= 750 ? implicitWidth : implicitWidth - ((750 - window.width))
anchors.top: modelLabel.top anchors.top: modelLabel.top
anchors.bottom: modelLabel.bottom anchors.bottom: modelLabel.bottom
anchors.horizontalCenter: parent.horizontalCenter anchors.horizontalCenter: parent.horizontalCenter
anchors.horizontalCenterOffset: window.width >= 950 ? 0 : Math.max(-((950 - window.width) / 2), -99.5) anchors.horizontalCenterOffset: window.width >= 950 ? 0 : Math.max(-((950 - window.width) / 2), -99.5)
spacing: 20
MyComboBox {
id: comboBox
Layout.fillWidth: true
Layout.fillHeight: true
implicitWidth: 575
width: window.width >= 750 ? implicitWidth : implicitWidth - ((750 - window.width))
enabled: !currentChat.isServer enabled: !currentChat.isServer
model: ModelList.installedModels model: ModelList.installedModels
valueRole: "id" valueRole: "id"
textRole: "name" textRole: "name"
property string currentModelName: "" property bool isCurrentlyLoading: false
function updateCurrentModelName() { property real modelLoadingPercentage: 0.0
var info = ModelList.modelInfo(currentChat.modelInfo.id); property bool trySwitchContextInProgress: false
comboBox.currentModelName = info.name;
function changeModel(index) {
comboBox.modelLoadingPercentage = 0.0;
comboBox.isCurrentlyLoading = true;
currentChat.stopGenerating()
currentChat.reset();
currentChat.modelInfo = ModelList.modelInfo(comboBox.valueAt(index))
} }
Connections { Connections {
target: currentChat target: currentChat
function onModelInfoChanged() { function onModelLoadingPercentageChanged() {
comboBox.updateCurrentModelName(); comboBox.modelLoadingPercentage = currentChat.modelLoadingPercentage;
comboBox.isCurrentlyLoading = currentChat.modelLoadingPercentage !== 0.0
&& currentChat.modelLoadingPercentage !== 1.0;
}
function onTrySwitchContextOfLoadedModelAttempted() {
comboBox.trySwitchContextInProgress = true;
}
function onTrySwitchContextOfLoadedModelCompleted() {
comboBox.trySwitchContextInProgress = false;
} }
} }
Connections { Connections {
target: window target: switchModelDialog
function onCurrentChatChanged() { function onAccepted() {
comboBox.updateCurrentModelName(); comboBox.changeModel(switchModelDialog.index)
} }
} }
background: ProgressBar {
id: modelProgress
value: comboBox.modelLoadingPercentage
background: Rectangle { background: Rectangle {
color: theme.mainComboBackground color: theme.mainComboBackground
radius: 10 radius: 10
} }
contentItem: Item {
Rectangle {
visible: comboBox.isCurrentlyLoading
anchors.bottom: parent.bottom
width: modelProgress.visualPosition * parent.width
height: 10
radius: 2
color: theme.progressForeground
}
}
}
contentItem: Text { contentItem: Text {
anchors.horizontalCenter: parent.horizontalCenter anchors.horizontalCenter: parent.horizontalCenter
leftPadding: 10 leftPadding: 10
rightPadding: 20 rightPadding: 20
text: currentChat.modelLoadingError !== "" text: {
? qsTr("Model loading error...") if (currentChat.modelLoadingError !== "")
: comboBox.currentModelName return qsTr("Model loading error...")
if (comboBox.trySwitchContextInProgress)
return qsTr("Switching context...")
if (currentModelName() === "")
return qsTr("Choose a model...")
if (currentChat.modelLoadingPercentage === 0.0)
return qsTr("Reload \u00B7 ") + currentModelName()
if (comboBox.isCurrentlyLoading)
return qsTr("Loading \u00B7 ") + currentModelName()
return currentModelName()
}
font.pixelSize: theme.fontSizeLarger font.pixelSize: theme.fontSizeLarger
color: theme.white color: theme.white
verticalAlignment: Text.AlignVCenter verticalAlignment: Text.AlignVCenter
@ -353,6 +417,7 @@ Window {
elide: Text.ElideRight elide: Text.ElideRight
} }
delegate: ItemDelegate { delegate: ItemDelegate {
id: comboItemDelegate
width: comboBox.width width: comboBox.width
contentItem: Text { contentItem: Text {
text: name text: name
@ -369,39 +434,58 @@ Window {
highlighted: comboBox.highlightedIndex === index highlighted: comboBox.highlightedIndex === index
} }
Accessible.role: Accessible.ComboBox Accessible.role: Accessible.ComboBox
Accessible.name: comboBox.currentModelName Accessible.name: currentModelName()
Accessible.description: qsTr("The top item is the current model") Accessible.description: qsTr("The top item is the current model")
onActivated: function (index) { onActivated: function (index) {
currentChat.stopGenerating() var newInfo = ModelList.modelInfo(comboBox.valueAt(index));
currentChat.reset(); if (currentModelName() !== ""
currentChat.modelInfo = ModelList.modelInfo(comboBox.valueAt(index)) && newInfo !== currentChat.modelInfo
} && chatModel.count !== 0) {
switchModelDialog.index = index;
switchModelDialog.open();
} else {
comboBox.changeModel(index);
} }
} }
Item { MyMiniButton {
anchors.centerIn: parent id: ejectButton
visible: ModelList.installedModels.count visible: currentChat.isModelLoaded
&& !currentChat.isModelLoaded z: 500
&& currentChat.modelLoadingError === "" anchors.right: parent.right
&& !currentChat.isServer anchors.rightMargin: 50
width: childrenRect.width
height: childrenRect.height
Row {
spacing: 5
MyBusyIndicator {
anchors.verticalCenter: parent.verticalCenter anchors.verticalCenter: parent.verticalCenter
running: parent.visible source: "qrc:/gpt4all/icons/eject.svg"
Accessible.role: Accessible.Animation backgroundColor: theme.gray300
Accessible.name: qsTr("Busy indicator") backgroundColorHovered: theme.iconBackgroundLight
Accessible.description: qsTr("loading model...") onClicked: {
currentChat.forceUnloadModel();
}
ToolTip.text: qsTr("Eject the currently loaded model")
ToolTip.visible: hovered
} }
Label { MyMiniButton {
id: reloadButton
visible: currentChat.modelLoadingError === ""
&& !comboBox.trySwitchContextInProgress
&& (currentChat.isModelLoaded || currentModelName() !== "")
z: 500
anchors.right: ejectButton.visible ? ejectButton.left : parent.right
anchors.rightMargin: ejectButton.visible ? 10 : 50
anchors.verticalCenter: parent.verticalCenter anchors.verticalCenter: parent.verticalCenter
text: qsTr("Loading model...") source: "qrc:/gpt4all/icons/regenerate.svg"
font.pixelSize: theme.fontSizeLarge backgroundColor: theme.gray300
color: theme.oppositeTextColor backgroundColorHovered: theme.iconBackgroundLight
onClicked: {
if (currentChat.isModelLoaded)
currentChat.forceReloadModel();
else
currentChat.reloadModel();
}
ToolTip.text: qsTr("Reload the currently loaded model")
ToolTip.visible: hovered
}
} }
} }
} }
@ -790,9 +874,9 @@ Window {
Rectangle { Rectangle {
id: homePage id: homePage
color: "transparent"//theme.green200 color: "transparent"
anchors.fill: parent anchors.fill: parent
visible: (ModelList.installedModels.count === 0 || chatModel.count === 0) && !currentChat.isServer visible: !currentChat.isModelLoaded && (ModelList.installedModels.count === 0 || currentModelName() === "") && !currentChat.isServer
ColumnLayout { ColumnLayout {
anchors.centerIn: parent anchors.centerIn: parent
@ -1138,10 +1222,14 @@ Window {
} }
} }
RowLayout {
anchors.bottom: textInputView.top
anchors.horizontalCenter: textInputView.horizontalCenter
anchors.bottomMargin: 20
spacing: 10
MyButton { MyButton {
id: myButton
visible: chatModel.count && !currentChat.isServer
textColor: theme.textColor textColor: theme.textColor
visible: chatModel.count && !currentChat.isServer && currentChat.isModelLoaded
Image { Image {
anchors.verticalCenter: parent.verticalCenter anchors.verticalCenter: parent.verticalCenter
anchors.left: parent.left anchors.left: parent.left
@ -1170,20 +1258,50 @@ Window {
} }
} }
} }
background: Rectangle {
border.color: theme.conversationButtonBorder borderWidth: 1
border.width: 2 backgroundColor: theme.conversationButtonBackground
radius: 10 backgroundColorHovered: theme.conversationButtonBackgroundHovered
color: myButton.hovered ? theme.conversationButtonBackgroundHovered : theme.conversationButtonBackground backgroundRadius: 5
}
anchors.bottom: textInputView.top
anchors.horizontalCenter: textInputView.horizontalCenter
anchors.bottomMargin: 20
padding: 15 padding: 15
topPadding: 4
bottomPadding: 4
text: currentChat.responseInProgress ? qsTr("Stop generating") : qsTr("Regenerate response") text: currentChat.responseInProgress ? qsTr("Stop generating") : qsTr("Regenerate response")
fontPixelSize: theme.fontSizeSmaller
Accessible.description: qsTr("Controls generation of the response") Accessible.description: qsTr("Controls generation of the response")
} }
MyButton {
textColor: theme.textColor
visible: chatModel.count
&& !currentChat.isServer
&& !currentChat.isModelLoaded
&& currentChat.modelLoadingPercentage === 0.0
&& currentChat.modelInfo.name !== ""
Image {
anchors.verticalCenter: parent.verticalCenter
anchors.left: parent.left
anchors.leftMargin: 15
source: "qrc:/gpt4all/icons/regenerate.svg"
}
leftPadding: 50
onClicked: {
currentChat.reloadModel();
}
borderWidth: 1
backgroundColor: theme.conversationButtonBackground
backgroundColorHovered: theme.conversationButtonBackgroundHovered
backgroundRadius: 5
padding: 15
topPadding: 4
bottomPadding: 4
text: qsTr("Reload \u00B7 ") + currentChat.modelInfo.name
fontPixelSize: theme.fontSizeSmaller
Accessible.description: qsTr("Reloads the model")
}
}
Text { Text {
id: device id: device
anchors.bottom: textInputView.top anchors.bottom: textInputView.top
@ -1224,7 +1342,7 @@ Window {
rightPadding: 40 rightPadding: 40
enabled: currentChat.isModelLoaded && !currentChat.isServer enabled: currentChat.isModelLoaded && !currentChat.isServer
font.pixelSize: theme.fontSizeLarger font.pixelSize: theme.fontSizeLarger
placeholderText: qsTr("Send a message...") placeholderText: currentChat.isModelLoaded ? qsTr("Send a message...") : qsTr("Load a model to continue...")
Accessible.role: Accessible.EditableText Accessible.role: Accessible.EditableText
Accessible.name: placeholderText Accessible.name: placeholderText
Accessible.description: qsTr("Send messages/prompts to the model") Accessible.description: qsTr("Send messages/prompts to the model")

View File

@ -13,6 +13,7 @@ Button {
property color mutedTextColor: theme.oppositeMutedTextColor property color mutedTextColor: theme.oppositeMutedTextColor
property color backgroundColor: theme.buttonBackground property color backgroundColor: theme.buttonBackground
property color backgroundColorHovered: theme.buttonBackgroundHovered property color backgroundColorHovered: theme.buttonBackgroundHovered
property real backgroundRadius: 10
property real borderWidth: MySettings.chatTheme === "LegacyDark" ? 1 : 0 property real borderWidth: MySettings.chatTheme === "LegacyDark" ? 1 : 0
property color borderColor: theme.buttonBorder property color borderColor: theme.buttonBorder
property real fontPixelSize: theme.fontSizeLarge property real fontPixelSize: theme.fontSizeLarge
@ -25,7 +26,7 @@ Button {
Accessible.name: text Accessible.name: text
} }
background: Rectangle { background: Rectangle {
radius: 10 radius: myButton.backgroundRadius
border.width: myButton.borderWidth border.width: myButton.borderWidth
border.color: myButton.borderColor border.color: myButton.borderColor
color: myButton.hovered ? backgroundColorHovered : backgroundColor color: myButton.hovered ? backgroundColorHovered : backgroundColor

View File

@ -0,0 +1,47 @@
import QtCore
import QtQuick
import QtQuick.Controls
import QtQuick.Controls.Basic
import Qt5Compat.GraphicalEffects
Button {
id: myButton
padding: 0
property color backgroundColor: theme.iconBackgroundDark
property color backgroundColorHovered: theme.iconBackgroundHovered
property alias source: image.source
property alias fillMode: image.fillMode
width: 30
height: 30
contentItem: Text {
text: myButton.text
horizontalAlignment: Text.AlignHCenter
color: myButton.enabled ? theme.textColor : theme.mutedTextColor
font.pixelSize: theme.fontSizeLarge
Accessible.role: Accessible.Button
Accessible.name: text
}
background: Item {
anchors.fill: parent
Rectangle {
anchors.fill: parent
color: "transparent"
}
Image {
id: image
anchors.centerIn: parent
mipmap: true
width: 20
height: 20
}
ColorOverlay {
anchors.fill: image
source: image
color: myButton.hovered ? backgroundColorHovered : backgroundColor
}
}
Accessible.role: Accessible.Button
Accessible.name: text
ToolTip.delay: Qt.styleHints.mousePressAndHoldInterval
}

View File

@ -0,0 +1,44 @@
import QtCore
import QtQuick
import QtQuick.Controls
import QtQuick.Controls.Basic
import QtQuick.Layouts
import llm
import mysettings
MyDialog {
id: switchModelDialog
anchors.centerIn: parent
modal: true
padding: 20
property int index: -1
Theme {
id: theme
}
Column {
id: column
spacing: 20
}
footer: DialogButtonBox {
id: dialogBox
padding: 20
alignment: Qt.AlignRight
spacing: 10
MySettingsButton {
text: qsTr("Continue")
Accessible.description: qsTr("Continue with model loading")
DialogButtonBox.buttonRole: DialogButtonBox.AcceptRole
}
MySettingsButton {
text: qsTr("Cancel")
Accessible.description: qsTr("Cancel")
DialogButtonBox.buttonRole: DialogButtonBox.RejectRole
}
background: Rectangle {
color: "transparent"
}
}
}

View File

@ -555,6 +555,7 @@ QtObject {
property real fontSizeFixedSmall: 16 property real fontSizeFixedSmall: 16
property real fontSize: Qt.application.font.pixelSize property real fontSize: Qt.application.font.pixelSize
property real fontSizeSmaller: fontSizeLarge - 4
property real fontSizeSmall: fontSizeLarge - 2 property real fontSizeSmall: fontSizeLarge - 2
property real fontSizeLarge: MySettings.fontSize === "Small" ? property real fontSizeLarge: MySettings.fontSize === "Small" ?
fontSize : MySettings.fontSize === "Medium" ? fontSize : MySettings.fontSize === "Medium" ?