Mirror of https://github.com/nomic-ai/gpt4all.git, synced 2024-10-01 01:06:10 -04:00
Ollama WIP
This commit is contained in:
parent c13b33fb4d
commit 912d08c8ae
670
gpt4all-chat/ollama_model.cpp
Normal file
@@ -0,0 +1,670 @@
#include "ollama_model.h"

#include "chat.h"
#include "chatapi.h"
#include "localdocs.h"
#include "mysettings.h"
#include "network.h"

#include <QDataStream>
#include <QDebug>
#include <QElapsedTimer>
#include <QFile>
#include <QFileInfo>
#include <QGlobalStatic>
#include <QIODevice>
#include <QJsonDocument>
#include <QJsonObject>
#include <QMutex>
#include <QMutexLocker>
#include <QRegularExpression>
#include <QSet>
#include <QStringList>
#include <QVariant>
#include <QWaitCondition>
#include <Qt>
#include <QtLogging>

#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstddef>
#include <functional>
#include <limits>
#include <optional>
#include <stdexcept>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

using namespace Qt::Literals::StringLiterals;


#define OLLAMA_INTERNAL_STATE_VERSION 0

OllamaModel::OllamaModel()
    : m_shouldBeLoaded(false)
    , m_forceUnloadModel(false)
    , m_markedForDeletion(false)
    , m_stopGenerating(false)
    , m_timer(new TokenTimer(this))
    , m_processedSystemPrompt(false)
{
    connect(this, &OllamaModel::shouldBeLoadedChanged, this, &OllamaModel::handleShouldBeLoadedChanged);
    connect(this, &OllamaModel::trySwitchContextRequested, this, &OllamaModel::trySwitchContextOfLoadedModel);
    connect(m_timer, &TokenTimer::report, this, &OllamaModel::reportSpeed);

    // The following are blocking operations and will block the llm thread
    connect(this, &OllamaModel::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB,
            Qt::BlockingQueuedConnection);
}

OllamaModel::~OllamaModel()
{
    destroy();
}

void OllamaModel::destroy()
{
    // TODO(jared): cancel pending network requests
}

void OllamaModel::destroyStore()
{
    LLModelStore::globalInstance()->destroy();
}

bool OllamaModel::loadDefaultModel()
{
    ModelInfo defaultModel = ModelList::globalInstance()->defaultModelInfo();
    if (defaultModel.filename().isEmpty()) {
        emit modelLoadingError(u"Could not find any model to load"_s);
        return false;
    }
    return loadModel(defaultModel);
}

void OllamaModel::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
{
    // no-op: we require the model to be explicitly loaded for now.
}

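// Loads modelInfo, reusing the currently acquired backend when possible: an already
// acquired model is reset in place, otherwise we block on LLModelStore to acquire one,
// restore its serialized state if it matches the requested file, and only fall back to
// constructing a new backend (still TODO for Ollama) when it does not.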
bool OllamaModel::loadModel(const ModelInfo &modelInfo)
{
    // We're already loaded with this model
    if (isModelLoaded() && this->modelInfo() == modelInfo)
        return true;

    // reset status
    emit modelLoadingPercentageChanged(std::numeric_limits<float>::min()); // small non-zero positive value
    emit modelLoadingError("");

    QString filePath = modelInfo.dirpath + modelInfo.filename();
    QFileInfo fileInfo(filePath);

    // We have a live model, but it isn't the one we want
    bool alreadyAcquired = isModelLoaded();
    if (alreadyAcquired) {
        resetContext();
        m_llModelInfo.resetModel(this);
    } else {
        // This is a blocking call that tries to retrieve the model we need from the model store.
        // If it succeeds, then we just have to restore state. If the store has never had a model
        // returned to it, then the modelInfo.model pointer should be null which will happen on startup
        acquireModel();
        // At this point it is possible that while we were blocked waiting to acquire the model from the
        // store, that our state was changed to not be loaded. If this is the case, release the model
        // back into the store and quit loading
        if (!m_shouldBeLoaded) {
            LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
            emit modelLoadingPercentageChanged(0.0f);
            return false;
        }

        // Check if the store just gave us exactly the model we were looking for
        if (m_llModelInfo.model && m_llModelInfo.fileInfo == fileInfo) {
            restoreState();
            emit modelLoadingPercentageChanged(1.0f);
            setModelInfo(modelInfo);
            Q_ASSERT(!m_modelInfo.filename().isEmpty());
            if (m_modelInfo.filename().isEmpty())
                emit modelLoadingError(u"Modelinfo is left null for %1"_s.arg(modelInfo.filename()));
            else
                processSystemPrompt();
            return true;
        } else {
            // Release the memory since we have to switch to a different model.
            m_llModelInfo.resetModel(this);
        }
    }

    // Guarantee we've released the previous model's memory
    Q_ASSERT(!m_llModelInfo.model);

    // Store the file info in the modelInfo in case we have an error loading
    m_llModelInfo.fileInfo = fileInfo;

    if (fileInfo.exists()) {
        QVariantMap modelLoadProps;

        // TODO(jared): load the model here
#if 0
        if (modelInfo.isOnline) {
            QString apiKey;
            QString requestUrl;
            QString modelName;
            {
                QFile file(filePath);
                bool success = file.open(QIODeviceBase::ReadOnly);
                (void)success;
                Q_ASSERT(success);
                QJsonDocument doc = QJsonDocument::fromJson(file.readAll());
                QJsonObject obj = doc.object();
                apiKey = obj["apiKey"].toString();
                modelName = obj["modelName"].toString();
                if (modelInfo.isCompatibleApi) {
                    QString baseUrl(obj["baseUrl"].toString());
                    QUrl apiUrl(QUrl::fromUserInput(baseUrl));
                    if (!Network::isHttpUrlValid(apiUrl))
                        return false;

                    QString currentPath(apiUrl.path());
                    QString suffixPath("%1/chat/completions");
                    apiUrl.setPath(suffixPath.arg(currentPath));
                    requestUrl = apiUrl.toString();
                } else {
                    requestUrl = modelInfo.url();
                }
            }
            ChatAPI *model = new ChatAPI();
            model->setModelName(modelName);
            model->setRequestURL(requestUrl);
            model->setAPIKey(apiKey);
            m_llModelInfo.resetModel(this, model);
        } else if (!loadNewModel(modelInfo, modelLoadProps)) {
            return false; // m_shouldBeLoaded became false
        }
#endif

        restoreState();
        emit modelLoadingPercentageChanged(isModelLoaded() ? 1.0f : 0.0f);
        emit loadedModelInfoChanged();

        modelLoadProps.insert("model", modelInfo.filename());
        Network::globalInstance()->trackChatEvent("model_load", modelLoadProps);
    } else {
        LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo)); // release back into the store
        resetModel();
        emit modelLoadingError(u"Could not find file for model %1"_s.arg(modelInfo.filename()));
    }

    if (m_llModelInfo.model) {
        setModelInfo(modelInfo);
        processSystemPrompt();
    }
    return bool(m_llModelInfo.model);
}

bool OllamaModel::isModelLoaded() const
{
    return m_llModelInfo.model && m_llModelInfo.model->isModelLoaded();
}

// FIXME(jared): we don't actually have to re-decode the prompt to generate a new response
void OllamaModel::regenerateResponse()
{
    m_ctx.n_past = std::max(0, m_ctx.n_past - m_promptResponseTokens);
    m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, m_ctx.tokens.end());
    m_promptResponseTokens = 0;
    m_promptTokens = 0;
    m_response = std::string();
    emit responseChanged(QString::fromStdString(m_response));
}

void OllamaModel::resetResponse()
{
    m_promptTokens = 0;
    m_promptResponseTokens = 0;
    m_response = std::string();
    emit responseChanged(QString::fromStdString(m_response));
}

void OllamaModel::resetContext()
{
    resetResponse();
    m_processedSystemPrompt = false;
    m_ctx = ModelBackend::PromptContext();
}

QString OllamaModel::response() const
{
    return QString::fromStdString(remove_leading_whitespace(m_response));
}

void OllamaModel::setModelInfo(const ModelInfo &modelInfo)
{
    m_modelInfo = modelInfo;
    emit modelInfoChanged(modelInfo);
}

void OllamaModel::acquireModel()
{
    m_llModelInfo = LLModelStore::globalInstance()->acquireModel();
    emit loadedModelInfoChanged();
}

void OllamaModel::resetModel()
{
    m_llModelInfo = {};
    emit loadedModelInfoChanged();
}

void OllamaModel::modelChangeRequested(const ModelInfo &modelInfo)
{
    m_shouldBeLoaded = true;
    loadModel(modelInfo);
}

bool OllamaModel::handlePrompt(int32_t token)
{
    // m_promptResponseTokens is related to last prompt/response not
    // the entire context window which we can reset on regenerate prompt
    ++m_promptTokens;
    ++m_promptResponseTokens;
    m_timer->start();
    return !m_stopGenerating;
}

bool OllamaModel::handleResponse(int32_t token, const std::string &response)
{
    // check for error
    if (token < 0) {
        m_response.append(response);
        emit responseChanged(QString::fromStdString(remove_leading_whitespace(m_response)));
        return false;
    }

    // m_promptResponseTokens is related to last prompt/response not
    // the entire context window which we can reset on regenerate prompt
    ++m_promptResponseTokens;
    m_timer->inc();
    Q_ASSERT(!response.empty());
    m_response.append(response);
    emit responseChanged(QString::fromStdString(remove_leading_whitespace(m_response)));
    return !m_stopGenerating;
}

bool OllamaModel::prompt(const QList<QString> &collectionList, const QString &prompt)
{
    if (!m_processedSystemPrompt)
        processSystemPrompt();
    const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
    const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
    const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
    const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
    const float min_p = MySettings::globalInstance()->modelMinP(m_modelInfo);
    const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
    const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
    const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
    const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo);
    return promptInternal(collectionList, prompt, promptTemplate, n_predict, top_k, top_p, min_p, temp, n_batch,
        repeat_penalty, repeat_penalty_tokens);
}

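// Core prompt path: optionally retrieves LocalDocs excerpts (a blocking queued call into the
// Database), decodes that context with n_predict temporarily set to 0 so no tokens are
// generated for it, then runs the real prompt and, depending on the suggestion mode,
// follows up with generated questions.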
bool OllamaModel::promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
    int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
    int32_t repeat_penalty_tokens)
{
    if (!isModelLoaded())
        return false;

    QList<ResultInfo> databaseResults;
    const int retrievalSize = MySettings::globalInstance()->localDocsRetrievalSize();
    if (!collectionList.isEmpty()) {
        emit requestRetrieveFromDB(collectionList, prompt, retrievalSize, &databaseResults); // blocks
        emit databaseResultsChanged(databaseResults);
    }

    // Augment the prompt template with the results if any
    QString docsContext;
    if (!databaseResults.isEmpty()) {
        QStringList results;
        for (const ResultInfo &info : databaseResults)
            results << u"Collection: %1\nPath: %2\nExcerpt: %3"_s.arg(info.collection, info.path, info.text);

        // FIXME(jared): use a Jinja prompt template instead of hardcoded Alpaca-style localdocs template
        docsContext = u"### Context:\n%1\n\n"_s.arg(results.join("\n\n"));
    }

    int n_threads = MySettings::globalInstance()->threadCount();

    m_stopGenerating = false;
    auto promptFunc = std::bind(&OllamaModel::handlePrompt, this, std::placeholders::_1);
    auto responseFunc = std::bind(&OllamaModel::handleResponse, this, std::placeholders::_1,
        std::placeholders::_2);
    emit promptProcessing();
    m_ctx.n_predict = n_predict;
    m_ctx.top_k = top_k;
    m_ctx.top_p = top_p;
    m_ctx.min_p = min_p;
    m_ctx.temp = temp;
    m_ctx.n_batch = n_batch;
    m_ctx.repeat_penalty = repeat_penalty;
    m_ctx.repeat_last_n = repeat_penalty_tokens;

    QElapsedTimer totalTime;
    totalTime.start();
    m_timer->start();
    if (!docsContext.isEmpty()) {
        auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode localdocs context without a response
        m_llModelInfo.model->prompt(docsContext.toStdString(), "%1", promptFunc, responseFunc,
            /*allowContextShift*/ true, m_ctx);
        m_ctx.n_predict = old_n_predict; // now we are ready for a response
    }
    m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc,
        /*allowContextShift*/ true, m_ctx);

    m_timer->stop();
    qint64 elapsed = totalTime.elapsed();
    std::string trimmed = trim_whitespace(m_response);
    if (trimmed != m_response) {
        m_response = trimmed;
        emit responseChanged(QString::fromStdString(m_response));
    }

    SuggestionMode mode = MySettings::globalInstance()->suggestionMode();
    if (mode == SuggestionMode::On || (!databaseResults.isEmpty() && mode == SuggestionMode::LocalDocsOnly))
        generateQuestions(elapsed);
    else
        emit responseStopped(elapsed);

    return true;
}

void OllamaModel::setShouldBeLoaded(bool value, bool forceUnload)
{
    m_shouldBeLoaded = value; // atomic
    emit shouldBeLoadedChanged(forceUnload);
}

void OllamaModel::requestTrySwitchContext()
{
    m_shouldBeLoaded = true; // atomic
    emit trySwitchContextRequested(modelInfo());
}

void OllamaModel::handleShouldBeLoadedChanged()
{
    if (m_shouldBeLoaded)
        reloadModel();
    else
        unloadModel();
}

void OllamaModel::unloadModel()
{
    if (!isModelLoaded())
        return;

    if (!m_forceUnloadModel || !m_shouldBeLoaded)
        emit modelLoadingPercentageChanged(0.0f);
    else
        emit modelLoadingPercentageChanged(std::numeric_limits<float>::min()); // small non-zero positive value

    if (!m_markedForDeletion)
        saveState();

    if (m_forceUnloadModel) {
        m_llModelInfo.resetModel(this);
        m_forceUnloadModel = false;
    }

    LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
}

void OllamaModel::reloadModel()
{
    if (isModelLoaded() && m_forceUnloadModel)
        unloadModel(); // we unload first if we are forcing an unload

    if (isModelLoaded())
        return;

    const ModelInfo m = modelInfo();
    if (m.name().isEmpty())
        loadDefaultModel();
    else
        loadModel(m);
}

void OllamaModel::generateName()
{
    Q_ASSERT(isModelLoaded());
    if (!isModelLoaded())
        return;

    const QString chatNamePrompt = MySettings::globalInstance()->modelChatNamePrompt(m_modelInfo);
    if (chatNamePrompt.trimmed().isEmpty()) {
        qWarning() << "OllamaModel: not generating chat name because prompt is empty";
        return;
    }

    auto promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
    auto promptFunc = std::bind(&OllamaModel::handleNamePrompt, this, std::placeholders::_1);
    auto responseFunc = std::bind(&OllamaModel::handleNameResponse, this, std::placeholders::_1, std::placeholders::_2);
    ModelBackend::PromptContext ctx = m_ctx;
    m_llModelInfo.model->prompt(chatNamePrompt.toStdString(), promptTemplate.toStdString(),
        promptFunc, responseFunc, /*allowContextShift*/ false, ctx);
    std::string trimmed = trim_whitespace(m_nameResponse);
    if (trimmed != m_nameResponse) {
        m_nameResponse = trimmed;
        emit generatedNameChanged(QString::fromStdString(m_nameResponse));
    }
}

bool OllamaModel::handleNamePrompt(int32_t token)
{
    Q_UNUSED(token);
    return !m_stopGenerating;
}

bool OllamaModel::handleNameResponse(int32_t token, const std::string &response)
{
    Q_UNUSED(token);

    m_nameResponse.append(response);
    emit generatedNameChanged(QString::fromStdString(m_nameResponse));
    QString gen = QString::fromStdString(m_nameResponse).simplified();
    QStringList words = gen.split(' ', Qt::SkipEmptyParts);
    return words.size() <= 3;
}

bool OllamaModel::handleQuestionPrompt(int32_t token)
{
    Q_UNUSED(token);
    return !m_stopGenerating;
}

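// Streams follow-up-question tokens into a buffer, emits each complete question that the
// regular expression below recognizes, and drops the matched portion so partial sentences
// are kept for the next call.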
bool OllamaModel::handleQuestionResponse(int32_t token, const std::string &response)
{
    Q_UNUSED(token);

    // add token to buffer
    m_questionResponse.append(response);

    // match whole question sentences
    // FIXME: This only works with responses from the model in English, which is not ideal for a
    // multi-language model.
    static const QRegularExpression reQuestion(R"(\b(What|Where|How|Why|When|Who|Which|Whose|Whom)\b[^?]*\?)");

    // extract all questions from response
    int lastMatchEnd = -1;
    for (const auto &match : reQuestion.globalMatch(m_questionResponse)) {
        lastMatchEnd = match.capturedEnd();
        emit generatedQuestionFinished(match.captured());
    }

    // remove processed input from buffer
    if (lastMatchEnd != -1)
        m_questionResponse.erase(m_questionResponse.cbegin(), m_questionResponse.cbegin() + lastMatchEnd);

    return true;
}

void OllamaModel::generateQuestions(qint64 elapsed)
{
    Q_ASSERT(isModelLoaded());
    if (!isModelLoaded()) {
        emit responseStopped(elapsed);
        return;
    }

    const std::string suggestedFollowUpPrompt = MySettings::globalInstance()->modelSuggestedFollowUpPrompt(m_modelInfo).toStdString();
    if (QString::fromStdString(suggestedFollowUpPrompt).trimmed().isEmpty()) {
        emit responseStopped(elapsed);
        return;
    }

    emit generatingQuestions();
    m_questionResponse.clear();
    auto promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
    auto promptFunc = std::bind(&OllamaModel::handleQuestionPrompt, this, std::placeholders::_1);
    auto responseFunc = std::bind(&OllamaModel::handleQuestionResponse, this, std::placeholders::_1, std::placeholders::_2);
    ModelBackend::PromptContext ctx = m_ctx;
    QElapsedTimer totalTime;
    totalTime.start();
    m_llModelInfo.model->prompt(suggestedFollowUpPrompt, promptTemplate.toStdString(), promptFunc, responseFunc,
        /*allowContextShift*/ false, ctx);
    elapsed += totalTime.elapsed();
    emit responseStopped(elapsed);
}


bool OllamaModel::handleSystemPrompt(int32_t token)
{
    Q_UNUSED(token);
    return !m_stopGenerating;
}

// This function serializes the cached model state to disk.
// We also want to serialize n_ctx, and read it at load time.
bool OllamaModel::serialize(QDataStream &stream, int version, bool serializeKV)
{
    Q_UNUSED(serializeKV);

    if (version < 10)
        throw std::out_of_range("ollama not available until chat version 10, attempted to serialize version " + std::to_string(version));

    stream << OLLAMA_INTERNAL_STATE_VERSION;

    stream << response();
    stream << generatedName();
    // TODO(jared): do not save/restore m_promptResponseTokens, compute the appropriate value instead
    stream << m_promptResponseTokens;

    stream << m_ctx.n_ctx;
    saveState();
    QByteArray compressed = qCompress(m_state);
    stream << compressed;

    return stream.status() == QDataStream::Ok;
}

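// Reads the fields back in the exact order serialize() wrote them: internal state version,
// response text, generated chat name, prompt/response token count, n_ctx, and finally the
// qCompress()ed state blob that restoreState() consumes.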
bool OllamaModel::deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV)
{
    Q_UNUSED(deserializeKV);
    Q_UNUSED(discardKV);

    Q_ASSERT(version >= 10);

    int internalStateVersion;
    stream >> internalStateVersion; // for future use

    QString response;
    stream >> response;
    m_response = response.toStdString();
    QString nameResponse;
    stream >> nameResponse;
    m_nameResponse = nameResponse.toStdString();
    stream >> m_promptResponseTokens;

    uint32_t n_ctx;
    stream >> n_ctx;
    m_ctx.n_ctx = n_ctx;

    QByteArray compressed;
    stream >> compressed;
    m_state = qUncompress(compressed);

    return stream.status() == QDataStream::Ok;
}

void OllamaModel::saveState()
{
    if (!isModelLoaded())
        return;

    // m_llModelType == LLModelType::API_
    m_state.clear();
    QDataStream stream(&m_state, QIODeviceBase::WriteOnly);
    stream.setVersion(QDataStream::Qt_6_4);
    ChatAPI *chatAPI = static_cast<ChatAPI *>(m_llModelInfo.model.get());
    stream << chatAPI->context();
    // end API
}

void OllamaModel::restoreState()
{
    if (!isModelLoaded())
        return;

    // m_llModelType == LLModelType::API_
    QDataStream stream(&m_state, QIODeviceBase::ReadOnly);
    stream.setVersion(QDataStream::Qt_6_4);
    ChatAPI *chatAPI = static_cast<ChatAPI *>(m_llModelInfo.model.get());
    QList<QString> context;
    stream >> context;
    chatAPI->setContext(context);
    m_state.clear();
    m_state.squeeze();
    // end API
}

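// Decodes the configured system prompt into a fresh context, using the sampling settings
// from MySettings but with n_predict forced to 0 so the model ingests the prompt without
// producing any output.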
void OllamaModel::processSystemPrompt()
{
    Q_ASSERT(isModelLoaded());
    if (!isModelLoaded() || m_processedSystemPrompt || m_restoreStateFromText)
        return;

    const std::string systemPrompt = MySettings::globalInstance()->modelSystemPrompt(m_modelInfo).toStdString();
    if (QString::fromStdString(systemPrompt).trimmed().isEmpty()) {
        m_processedSystemPrompt = true;
        return;
    }

    // Start with a whole new context
    m_stopGenerating = false;
    m_ctx = ModelBackend::PromptContext();

    auto promptFunc = std::bind(&OllamaModel::handleSystemPrompt, this, std::placeholders::_1);

    const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
    const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
    const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
    const float min_p = MySettings::globalInstance()->modelMinP(m_modelInfo);
    const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
    const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
    const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
    const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo);
    int n_threads = MySettings::globalInstance()->threadCount();
    m_ctx.n_predict = n_predict;
    m_ctx.top_k = top_k;
    m_ctx.top_p = top_p;
    m_ctx.min_p = min_p;
    m_ctx.temp = temp;
    m_ctx.n_batch = n_batch;
    m_ctx.repeat_penalty = repeat_penalty;
    m_ctx.repeat_last_n = repeat_penalty_tokens;

    auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode system prompt without a response
    // use "%1%2" and not "%1" to avoid implicit whitespace
    m_llModelInfo.model->prompt(systemPrompt, "%1%2", promptFunc, nullptr, /*allowContextShift*/ true, m_ctx, true);
    m_ctx.n_predict = old_n_predict;

    m_processedSystemPrompt = m_stopGenerating == false;
}
51
gpt4all-chat/ollama_model.h
Normal file
@@ -0,0 +1,51 @@
#pragma once

#include "database.h" // IWYU pragma: keep
#include "llmodel.h"
#include "modellist.h" // IWYU pragma: keep

#include <QList>
#include <QObject>
#include <QPair>
#include <QString>
#include <QVector>

class Chat;
class QDataStream;


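// Chat model backend for the in-progress Ollama support ("Ollama WIP"): prompt handling,
// LocalDocs context injection, and state save/restore are implemented in ollama_model.cpp,
// while the actual model-loading network requests are still marked TODO there.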
class OllamaModel : public LLModel
{
    Q_OBJECT

public:
    OllamaModel();
    ~OllamaModel() override;

    void regenerateResponse() override;
    void resetResponse() override;
    void resetContext() override;

    void stopGenerating() override;

    void setShouldBeLoaded(bool b) override;
    void requestTrySwitchContext() override;
    void setForceUnloadModel(bool b) override;
    void setMarkedForDeletion(bool b) override;

    void setModelInfo(const ModelInfo &info) override;

    bool restoringFromText() const override;

    bool serialize(QDataStream &stream, int version, bool serializeKV) override;
    bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV) override;
    void setStateFromText(const QVector<QPair<QString, QString>> &stateFromText) override;

public Q_SLOTS:
    bool prompt(const QList<QString> &collectionList, const QString &prompt) override;
    bool loadDefaultModel() override;
    bool loadModel(const ModelInfo &modelInfo) override;
    void modelChangeRequested(const ModelInfo &modelInfo) override;
    void generateName() override;
    void processSystemPrompt() override;
};
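The loadModel() TODO leaves open how this backend will reach a running Ollama instance. As a rough, hypothetical sketch only (not part of this commit, and the helper name below is illustrative), a Qt-based request against Ollama's documented REST endpoint could look like this:

#include <QJsonDocument>
#include <QJsonObject>
#include <QNetworkAccessManager>
#include <QNetworkReply>
#include <QNetworkRequest>
#include <QUrl>

// Hypothetical helper, not part of this commit: POST a generation request to a local
// Ollama server (default port 11434) and return the streaming reply for the caller to
// consume incrementally.
static QNetworkReply *ollamaGenerate(QNetworkAccessManager &nam, const QString &model, const QString &prompt)
{
    QNetworkRequest req(QUrl(QStringLiteral("http://localhost:11434/api/generate")));
    req.setHeader(QNetworkRequest::ContentTypeHeader, QStringLiteral("application/json"));

    QJsonObject body;
    body["model"] = model;   // e.g. "llama3"
    body["prompt"] = prompt;
    body["stream"] = true;   // Ollama streams one JSON object per generated chunk
    return nam.post(req, QJsonDocument(body).toJson(QJsonDocument::Compact));
}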