#include "chat.h" #include "llm.h" #include "network.h" #include "download.h" Chat::Chat(QObject *parent) : QObject(parent) , m_id(Network::globalInstance()->generateUniqueId()) , m_name(tr("New Chat")) , m_chatModel(new ChatModel(this)) , m_responseInProgress(false) , m_creationDate(QDateTime::currentSecsSinceEpoch()) , m_llmodel(new ChatLLM(this)) { // Should be in same thread connect(Download::globalInstance(), &Download::modelListChanged, this, &Chat::modelListChanged, Qt::DirectConnection); connect(this, &Chat::modelNameChanged, this, &Chat::modelListChanged, Qt::DirectConnection); // Should be in different threads connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::responseStarted, this, &Chat::responseStarted, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::modelNameChanged, this, &Chat::handleModelNameChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::modelLoadingError, this, &Chat::modelLoadingError, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection); connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection); connect(this, &Chat::modelNameChangeRequested, m_llmodel, &ChatLLM::modelNameChangeRequested, Qt::QueuedConnection); connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &ChatLLM::loadDefaultModel, Qt::QueuedConnection); connect(this, &Chat::loadModelRequested, m_llmodel, &ChatLLM::loadModel, Qt::QueuedConnection); connect(this, &Chat::unloadModelRequested, m_llmodel, &ChatLLM::unloadModel, Qt::QueuedConnection); connect(this, &Chat::reloadModelRequested, m_llmodel, &ChatLLM::reloadModel, Qt::QueuedConnection); connect(this, &Chat::generateNameRequested, m_llmodel, &ChatLLM::generateName, Qt::QueuedConnection); // The following are blocking operations and will block the gui thread, therefore must be fast // to respond to connect(this, &Chat::regenerateResponseRequested, m_llmodel, &ChatLLM::regenerateResponse, Qt::BlockingQueuedConnection); connect(this, &Chat::resetResponseRequested, m_llmodel, &ChatLLM::resetResponse, Qt::BlockingQueuedConnection); connect(this, &Chat::resetContextRequested, m_llmodel, &ChatLLM::resetContext, Qt::BlockingQueuedConnection); } void Chat::reset() { stopGenerating(); // Erase our current on disk representation as we're completely resetting the chat along with id LLM::globalInstance()->chatListModel()->removeChatFile(this); emit resetContextRequested(); // blocking queued connection m_id = Network::globalInstance()->generateUniqueId(); emit idChanged(); // NOTE: We deliberately do no reset the name or creation date to indictate that this was originally // an older chat that was reset for another purpose. Resetting this data will lead to the chat // name label changing back to 'New Chat' and showing up in the chat model list as a 'New Chat' // further down in the list. This might surprise the user. In the future, we me might get rid of // the "reset context" button in the UI. Right now, by changing the model in the combobox dropdown // we effectively do a reset context. We *have* to do this right now when switching between different // types of models. The only way to get rid of that would be a very long recalculate where we rebuild // the context if we switch between different types of models. Probably the right way to fix this // is to allow switching models but throwing up a dialog warning users if we switch between types // of models that a long recalculation will ensue. m_chatModel->clear(); } bool Chat::isModelLoaded() const { return m_llmodel->isModelLoaded(); } void Chat::prompt(const QString &prompt, const QString &prompt_template, int32_t n_predict, int32_t top_k, float top_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens) { emit promptRequested(prompt, prompt_template, n_predict, top_k, top_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, LLM::globalInstance()->threadCount()); } void Chat::regenerateResponse() { emit regenerateResponseRequested(); // blocking queued connection } void Chat::stopGenerating() { m_llmodel->stopGenerating(); } QString Chat::response() const { return m_llmodel->response(); } void Chat::handleResponseChanged() { const int index = m_chatModel->count() - 1; m_chatModel->updateValue(index, response()); emit responseChanged(); } void Chat::responseStarted() { m_responseInProgress = true; emit responseInProgressChanged(); } void Chat::responseStopped() { m_responseInProgress = false; emit responseInProgressChanged(); if (m_llmodel->generatedName().isEmpty()) { Network::globalInstance()->sendChatStarted(); emit generateNameRequested(); } } QString Chat::modelName() const { return m_llmodel->modelName(); } void Chat::setModelName(const QString &modelName) { // doesn't block but will unload old model and load new one which the gui can see through changes // to the isModelLoaded property emit modelNameChangeRequested(modelName); } void Chat::newPromptResponsePair(const QString &prompt) { m_chatModel->appendPrompt(tr("Prompt: "), prompt); m_chatModel->appendResponse(tr("Response: "), prompt); emit resetResponseRequested(); // blocking queued connection } bool Chat::isRecalc() const { return m_llmodel->isRecalc(); } void Chat::loadDefaultModel() { emit loadDefaultModelRequested(); } void Chat::loadModel(const QString &modelName) { emit loadModelRequested(modelName); } void Chat::unloadModel() { stopGenerating(); emit unloadModelRequested(); } void Chat::reloadModel() { emit reloadModelRequested(m_savedModelName); } void Chat::generatedNameChanged() { // Only use the first three words maximum and remove newlines and extra spaces QString gen = m_llmodel->generatedName().simplified(); QStringList words = gen.split(' ', Qt::SkipEmptyParts); int wordCount = qMin(3, words.size()); m_name = words.mid(0, wordCount).join(' '); emit nameChanged(); } void Chat::handleRecalculating() { Network::globalInstance()->sendRecalculatingContext(m_chatModel->count()); emit recalcChanged(); } void Chat::handleModelNameChanged() { m_savedModelName = modelName(); emit modelNameChanged(); } bool Chat::serialize(QDataStream &stream, int version) const { stream << m_creationDate; stream << m_id; stream << m_name; stream << m_userName; stream << m_savedModelName; if (!m_llmodel->serialize(stream, version)) return false; if (!m_chatModel->serialize(stream, version)) return false; return stream.status() == QDataStream::Ok; } bool Chat::deserialize(QDataStream &stream, int version) { stream >> m_creationDate; stream >> m_id; emit idChanged(); stream >> m_name; stream >> m_userName; emit nameChanged(); stream >> m_savedModelName; // Prior to version 2 gptj models had a bug that fixed the kv_cache to F32 instead of F16 so // unfortunately, we cannot deserialize these if (version < 2 && m_savedModelName.contains("gpt4all-j")) return false; if (!m_llmodel->deserialize(stream, version)) return false; if (!m_chatModel->deserialize(stream, version)) return false; emit chatModelChanged(); return stream.status() == QDataStream::Ok; } QList Chat::modelList() const { // Build a model list from exepath and from the localpath QList list; QString exePath = QCoreApplication::applicationDirPath() + QDir::separator(); QString localPath = Download::globalInstance()->downloadLocalModelsPath(); { QDir dir(exePath); dir.setNameFilters(QStringList() << "ggml-*.bin"); QStringList fileNames = dir.entryList(); for (QString f : fileNames) { QString filePath = exePath + f; QFileInfo info(filePath); QString name = info.completeBaseName().remove(0, 5); if (info.exists()) { if (name == modelName()) list.prepend(name); else list.append(name); } } } if (localPath != exePath) { QDir dir(localPath); dir.setNameFilters(QStringList() << "ggml-*.bin"); QStringList fileNames = dir.entryList(); for (QString f : fileNames) { QString filePath = localPath + f; QFileInfo info(filePath); QString name = info.completeBaseName().remove(0, 5); if (info.exists() && !list.contains(name)) { // don't allow duplicates if (name == modelName()) list.prepend(name); else list.append(name); } } } if (list.isEmpty()) { if (exePath != localPath) { qWarning() << "ERROR: Could not find any applicable models in" << exePath << "nor" << localPath; } else { qWarning() << "ERROR: Could not find any applicable models in" << exePath; } return QList(); } return list; }