mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-10-01 01:06:10 -04:00
Refactor to handle errors in tool calling better and add source comments.
Signed-off-by: Adam Treat <treat.adam@gmail.com>
This commit is contained in:
parent
7c7558eed3
commit
fffd9f341a
@ -815,11 +815,17 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
|
||||
m_ctx.n_predict = old_n_predict; // now we are ready for a response
|
||||
}
|
||||
|
||||
m_checkToolCall = !isToolCallResponse; // We can't handle recursive tool calls right now
|
||||
// We can't handle recursive tool calls right now otherwise we always try to check if we have a
|
||||
// tool call
|
||||
m_checkToolCall = !isToolCallResponse;
|
||||
|
||||
m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc,
|
||||
/*allowContextShift*/ true, m_ctx);
|
||||
|
||||
// After the response has been handled reset this state
|
||||
m_checkToolCall = false;
|
||||
m_maybeToolCall = false;
|
||||
|
||||
#if defined(DEBUG)
|
||||
printf("\n");
|
||||
fflush(stdout);
|
||||
@ -827,15 +833,66 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
|
||||
m_timer->stop();
|
||||
qint64 elapsed = totalTime.elapsed();
|
||||
std::string trimmed = trim_whitespace(m_response);
|
||||
|
||||
// If we found a tool call, then deal with it
|
||||
if (m_foundToolCall) {
|
||||
m_foundToolCall = false;
|
||||
|
||||
const QString toolCall = QString::fromStdString(trimmed);
|
||||
const QString toolTemplate = MySettings::globalInstance()->modelToolTemplate(m_modelInfo);
|
||||
if (toolTemplate.isEmpty()) {
|
||||
qWarning() << "ERROR: No valid tool template for this model" << toolCall;
|
||||
return handleFailedToolCall(trimmed, elapsed);
|
||||
}
|
||||
|
||||
QJsonParseError err;
|
||||
const QJsonDocument toolCallDoc = QJsonDocument::fromJson(toolCall.toUtf8(), &err);
|
||||
|
||||
if (toolCallDoc.isNull() || err.error != QJsonParseError::NoError || !toolCallDoc.isObject()) {
|
||||
qWarning() << "ERROR: The tool call had null or invalid json " << toolCall;
|
||||
return handleFailedToolCall(trimmed, elapsed);
|
||||
}
|
||||
|
||||
QJsonObject rootObject = toolCallDoc.object();
|
||||
if (!rootObject.contains("name") || !rootObject.contains("arguments")) {
|
||||
qWarning() << "ERROR: The tool call did not have required name and argument objects " << toolCall;
|
||||
return handleFailedToolCall(trimmed, elapsed);
|
||||
}
|
||||
|
||||
const QString tool = toolCallDoc["name"].toString();
|
||||
const QJsonObject args = toolCallDoc["arguments"].toObject();
|
||||
|
||||
// FIXME: In the future this will try to match the tool call to a list of tools that are supported
|
||||
// according to MySettings, but for now only brave search is supported
|
||||
if (tool != "brave_search" || !args.contains("query")) {
|
||||
qWarning() << "ERROR: Could not find the tool and correct arguments for " << toolCall;
|
||||
return handleFailedToolCall(trimmed, elapsed);
|
||||
}
|
||||
|
||||
const QString query = args["query"].toString();
|
||||
|
||||
// FIXME: This has to handle errors of the tool call
|
||||
emit toolCalled(tr("searching web..."));
|
||||
const QString apiKey = MySettings::globalInstance()->braveSearchAPIKey();
|
||||
Q_ASSERT(apiKey != "");
|
||||
BraveSearch brave;
|
||||
const QPair<QString, QList<SourceExcerpt>> braveResponse = brave.search(apiKey, query, 2 /*topK*/,
|
||||
2000 /*msecs to timeout*/);
|
||||
emit sourceExcerptsChanged(braveResponse.second);
|
||||
|
||||
// Erase the context of the tool call
|
||||
m_ctx.n_past = std::max(0, m_ctx.n_past);
|
||||
m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, m_ctx.tokens.end());
|
||||
m_promptResponseTokens = 0;
|
||||
m_promptTokens = 0;
|
||||
m_response = std::string();
|
||||
return toolCallInternal(QString::fromStdString(trimmed), n_predict, top_k, top_p, min_p, temp,
|
||||
n_batch, repeat_penalty, repeat_penalty_tokens);
|
||||
|
||||
// This is a recursive call but isToolCallResponse is checked above to arrest infinite recursive
|
||||
// tool calls
|
||||
return promptInternal(QList<QString>()/*collectionList*/, braveResponse.first, toolTemplate,
|
||||
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens,
|
||||
true /*isToolCallResponse*/);
|
||||
|
||||
} else {
|
||||
if (trimmed != m_response) {
|
||||
m_response = trimmed;
|
||||
@ -847,65 +904,19 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
|
||||
generateQuestions(elapsed);
|
||||
else
|
||||
emit responseStopped(elapsed);
|
||||
m_pristineLoadedState = false;
|
||||
return true;
|
||||
}
|
||||
|
||||
m_pristineLoadedState = false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ChatLLM::toolCallInternal(const QString &toolCall, int32_t n_predict, int32_t top_k, float top_p,
|
||||
float min_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens)
|
||||
bool ChatLLM::handleFailedToolCall(const std::string &response, qint64 elapsed)
|
||||
{
|
||||
QString toolTemplate = MySettings::globalInstance()->modelToolTemplate(m_modelInfo);
|
||||
if (toolTemplate.isEmpty()) {
|
||||
// FIXME: Not sure what to do here. The model attempted a tool call, but there is no way for
|
||||
// us to process it. We should probably not even attempt further generation and just show an
|
||||
// error in the chat somehow?
|
||||
qWarning() << "WARNING: The model attempted a toolcall, but there is no valid tool template for this model" << toolCall;
|
||||
return promptInternal(QList<QString>()/*collectionList*/, QString() /*prompt*/,
|
||||
MySettings::globalInstance()->modelPromptTemplate(m_modelInfo),
|
||||
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/);
|
||||
}
|
||||
|
||||
QJsonParseError err;
|
||||
QJsonDocument toolCallDoc = QJsonDocument::fromJson(toolCall.toUtf8(), &err);
|
||||
|
||||
if (toolCallDoc.isNull() || err.error != QJsonParseError::NoError || !toolCallDoc.isObject()) {
|
||||
qWarning() << "WARNING: The tool call had null or invalid json " << toolCall;
|
||||
return promptInternal(QList<QString>()/*collectionList*/, QString() /*prompt*/, toolTemplate,
|
||||
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/);
|
||||
}
|
||||
|
||||
QJsonObject rootObject = toolCallDoc.object();
|
||||
if (!rootObject.contains("name") || !rootObject.contains("arguments")) {
|
||||
qWarning() << "WARNING: The tool call did not have required name and argument objects " << toolCall;
|
||||
return promptInternal(QList<QString>()/*collectionList*/, QString() /*prompt*/, toolTemplate,
|
||||
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/);
|
||||
}
|
||||
|
||||
const QString tool = toolCallDoc["name"].toString();
|
||||
const QJsonObject args = toolCallDoc["arguments"].toObject();
|
||||
|
||||
if (tool != "brave_search" || !args.contains("query")) {
|
||||
qWarning() << "WARNING: Could not find the tool and correct arguments for " << toolCall;
|
||||
return promptInternal(QList<QString>()/*collectionList*/, QString() /*prompt*/, toolTemplate,
|
||||
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/);
|
||||
}
|
||||
|
||||
const QString query = args["query"].toString();
|
||||
|
||||
emit toolCalled(tr("searching web..."));
|
||||
|
||||
const QString apiKey = MySettings::globalInstance()->braveSearchAPIKey();
|
||||
Q_ASSERT(apiKey != "");
|
||||
|
||||
BraveSearch brave;
|
||||
const QPair<QString, QList<SourceExcerpt>> braveResponse = brave.search(apiKey, query, 2 /*topK*/, 2000 /*msecs to timeout*/);
|
||||
|
||||
emit sourceExcerptsChanged(braveResponse.second);
|
||||
|
||||
return promptInternal(QList<QString>()/*collectionList*/, braveResponse.first, toolTemplate,
|
||||
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens, true /*isToolCallResponse*/);
|
||||
// Restore the strings that we excluded previously when detecting the tool call
|
||||
m_response = "<tool_call>" + response + "</tool_call>";
|
||||
emit responseChanged(QString::fromStdString(m_response));
|
||||
emit responseStopped(elapsed);
|
||||
m_pristineLoadedState = false;
|
||||
return true;
|
||||
}
|
||||
|
||||
void ChatLLM::setShouldBeLoaded(bool b)
|
||||
|
@ -200,8 +200,7 @@ protected:
|
||||
bool promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
|
||||
int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
|
||||
int32_t repeat_penalty_tokens, bool isToolCallResponse = false);
|
||||
bool toolCallInternal(const QString &toolcall, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
|
||||
int32_t repeat_penalty_tokens);
|
||||
bool handleFailedToolCall(const std::string &toolCall, qint64 elapsed);
|
||||
bool handlePrompt(int32_t token);
|
||||
bool handleResponse(int32_t token, const std::string &response);
|
||||
bool handleNamePrompt(int32_t token);
|
||||
|
Loading…
Reference in New Issue
Block a user