diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index 06c9223e..b02592d9 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -815,11 +815,17 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString m_ctx.n_predict = old_n_predict; // now we are ready for a response } - m_checkToolCall = !isToolCallResponse; // We can't handle recursive tool calls right now + // We can't handle recursive tool calls right now otherwise we always try to check if we have a + // tool call + m_checkToolCall = !isToolCallResponse; + m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc, /*allowContextShift*/ true, m_ctx); + + // After the response has been handled reset this state m_checkToolCall = false; m_maybeToolCall = false; + #if defined(DEBUG) printf("\n"); fflush(stdout); @@ -827,15 +833,66 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString m_timer->stop(); qint64 elapsed = totalTime.elapsed(); std::string trimmed = trim_whitespace(m_response); + + // If we found a tool call, then deal with it if (m_foundToolCall) { m_foundToolCall = false; + + const QString toolCall = QString::fromStdString(trimmed); + const QString toolTemplate = MySettings::globalInstance()->modelToolTemplate(m_modelInfo); + if (toolTemplate.isEmpty()) { + qWarning() << "ERROR: No valid tool template for this model" << toolCall; + return handleFailedToolCall(trimmed, elapsed); + } + + QJsonParseError err; + const QJsonDocument toolCallDoc = QJsonDocument::fromJson(toolCall.toUtf8(), &err); + + if (toolCallDoc.isNull() || err.error != QJsonParseError::NoError || !toolCallDoc.isObject()) { + qWarning() << "ERROR: The tool call had null or invalid json " << toolCall; + return handleFailedToolCall(trimmed, elapsed); + } + + QJsonObject rootObject = toolCallDoc.object(); + if (!rootObject.contains("name") || !rootObject.contains("arguments")) { + qWarning() << "ERROR: The tool call did not have required name and argument objects " << toolCall; + return handleFailedToolCall(trimmed, elapsed); + } + + const QString tool = toolCallDoc["name"].toString(); + const QJsonObject args = toolCallDoc["arguments"].toObject(); + + // FIXME: In the future this will try to match the tool call to a list of tools that are supported + // according to MySettings, but for now only brave search is supported + if (tool != "brave_search" || !args.contains("query")) { + qWarning() << "ERROR: Could not find the tool and correct arguments for " << toolCall; + return handleFailedToolCall(trimmed, elapsed); + } + + const QString query = args["query"].toString(); + + // FIXME: This has to handle errors of the tool call + emit toolCalled(tr("searching web...")); + const QString apiKey = MySettings::globalInstance()->braveSearchAPIKey(); + Q_ASSERT(apiKey != ""); + BraveSearch brave; + const QPair> braveResponse = brave.search(apiKey, query, 2 /*topK*/, + 2000 /*msecs to timeout*/); + emit sourceExcerptsChanged(braveResponse.second); + + // Erase the context of the tool call m_ctx.n_past = std::max(0, m_ctx.n_past); m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, m_ctx.tokens.end()); m_promptResponseTokens = 0; m_promptTokens = 0; m_response = std::string(); - return toolCallInternal(QString::fromStdString(trimmed), n_predict, top_k, top_p, min_p, temp, - n_batch, repeat_penalty, repeat_penalty_tokens); + + // This is a recursive call but isToolCallResponse is checked above to arrest infinite recursive + // tool calls + return promptInternal(QList()/*collectionList*/, braveResponse.first, toolTemplate, + n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, + true /*isToolCallResponse*/); + } else { if (trimmed != m_response) { m_response = trimmed; @@ -847,65 +904,19 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString generateQuestions(elapsed); else emit responseStopped(elapsed); + m_pristineLoadedState = false; + return true; } - - m_pristineLoadedState = false; - return true; } -bool ChatLLM::toolCallInternal(const QString &toolCall, int32_t n_predict, int32_t top_k, float top_p, - float min_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens) +bool ChatLLM::handleFailedToolCall(const std::string &response, qint64 elapsed) { - QString toolTemplate = MySettings::globalInstance()->modelToolTemplate(m_modelInfo); - if (toolTemplate.isEmpty()) { - // FIXME: Not sure what to do here. The model attempted a tool call, but there is no way for - // us to process it. We should probably not even attempt further generation and just show an - // error in the chat somehow? - qWarning() << "WARNING: The model attempted a toolcall, but there is no valid tool template for this model" << toolCall; - return promptInternal(QList()/*collectionList*/, QString() /*prompt*/, - MySettings::globalInstance()->modelPromptTemplate(m_modelInfo), - n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/); - } - - QJsonParseError err; - QJsonDocument toolCallDoc = QJsonDocument::fromJson(toolCall.toUtf8(), &err); - - if (toolCallDoc.isNull() || err.error != QJsonParseError::NoError || !toolCallDoc.isObject()) { - qWarning() << "WARNING: The tool call had null or invalid json " << toolCall; - return promptInternal(QList()/*collectionList*/, QString() /*prompt*/, toolTemplate, - n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/); - } - - QJsonObject rootObject = toolCallDoc.object(); - if (!rootObject.contains("name") || !rootObject.contains("arguments")) { - qWarning() << "WARNING: The tool call did not have required name and argument objects " << toolCall; - return promptInternal(QList()/*collectionList*/, QString() /*prompt*/, toolTemplate, - n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/); - } - - const QString tool = toolCallDoc["name"].toString(); - const QJsonObject args = toolCallDoc["arguments"].toObject(); - - if (tool != "brave_search" || !args.contains("query")) { - qWarning() << "WARNING: Could not find the tool and correct arguments for " << toolCall; - return promptInternal(QList()/*collectionList*/, QString() /*prompt*/, toolTemplate, - n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/); - } - - const QString query = args["query"].toString(); - - emit toolCalled(tr("searching web...")); - - const QString apiKey = MySettings::globalInstance()->braveSearchAPIKey(); - Q_ASSERT(apiKey != ""); - - BraveSearch brave; - const QPair> braveResponse = brave.search(apiKey, query, 2 /*topK*/, 2000 /*msecs to timeout*/); - - emit sourceExcerptsChanged(braveResponse.second); - - return promptInternal(QList()/*collectionList*/, braveResponse.first, toolTemplate, - n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/); + // Restore the strings that we excluded previously when detecting the tool call + m_response = "" + response + ""; + emit responseChanged(QString::fromStdString(m_response)); + emit responseStopped(elapsed); + m_pristineLoadedState = false; + return true; } void ChatLLM::setShouldBeLoaded(bool b) diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h index d9d47ae9..050c0b7d 100644 --- a/gpt4all-chat/chatllm.h +++ b/gpt4all-chat/chatllm.h @@ -200,8 +200,7 @@ protected: bool promptInternal(const QList &collectionList, const QString &prompt, const QString &promptTemplate, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens, bool isToolCallResponse = false); - bool toolCallInternal(const QString &toolcall, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, - int32_t repeat_penalty_tokens); + bool handleFailedToolCall(const std::string &toolCall, qint64 elapsed); bool handlePrompt(int32_t token); bool handleResponse(int32_t token, const std::string &response); bool handleNamePrompt(int32_t token);