Refactor the brave search and introduce an abstraction for tool calls.

Signed-off-by: Adam Treat <treat.adam@gmail.com>
This commit is contained in:
Adam Treat 2024-07-31 14:54:38 -04:00
parent fffd9f341a
commit dfe3e951d4
10 changed files with 363 additions and 137 deletions

View File

@ -121,9 +121,10 @@ qt_add_executable(chat
modellist.h modellist.cpp
mysettings.h mysettings.cpp
network.h network.cpp
sourceexcerpt.h
sourceexcerpt.h sourceexcerpt.cpp
server.h server.cpp
logger.h logger.cpp
tool.h tool.cpp
${APP_ICON_RESOURCE}
${CHAT_EXE_RESOURCES}
)

View File

@ -16,15 +16,19 @@
using namespace Qt::Literals::StringLiterals;
QPair<QString, QList<SourceExcerpt>> BraveSearch::search(const QString &apiKey, const QString &query, int topK, unsigned long timeout)
QString BraveSearch::run(const QJsonObject &parameters, qint64 timeout)
{
const QString apiKey = parameters["apiKey"].toString();
const QString query = parameters["query"].toString();
const int count = parameters["count"].toInt();
QThread workerThread;
BraveAPIWorker worker;
worker.moveToThread(&workerThread);
connect(&worker, &BraveAPIWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection);
connect(this, &BraveSearch::request, &worker, &BraveAPIWorker::request, Qt::QueuedConnection);
connect(&workerThread, &QThread::started, [&worker, apiKey, query, count]() {
worker.request(apiKey, query, count);
});
workerThread.start();
emit request(apiKey, query, topK);
workerThread.wait(timeout);
workerThread.quit();
workerThread.wait();
@ -34,19 +38,25 @@ QPair<QString, QList<SourceExcerpt>> BraveSearch::search(const QString &apiKey,
void BraveAPIWorker::request(const QString &apiKey, const QString &query, int topK)
{
m_topK = topK;
// Documentation on the brave web search:
// https://api.search.brave.com/app/documentation/web-search/get-started
QUrl jsonUrl("https://api.search.brave.com/res/v1/web/search");
// Documentation on the query options:
//https://api.search.brave.com/app/documentation/web-search/query
QUrlQuery urlQuery;
urlQuery.addQueryItem("q", query);
urlQuery.addQueryItem("count", QString::number(topK));
urlQuery.addQueryItem("result_filter", "web");
urlQuery.addQueryItem("extra_snippets", "true");
jsonUrl.setQuery(urlQuery);
QNetworkRequest request(jsonUrl);
QSslConfiguration conf = request.sslConfiguration();
conf.setPeerVerifyMode(QSslSocket::VerifyNone);
request.setSslConfiguration(conf);
request.setRawHeader("X-Subscription-Token", apiKey.toUtf8());
// request.setRawHeader("Accept-Encoding", "gzip");
request.setRawHeader("Accept", "application/json");
m_networkManager = new QNetworkAccessManager(this);
QNetworkReply *reply = m_networkManager->get(request);
connect(qGuiApp, &QCoreApplication::aboutToQuit, reply, &QNetworkReply::abort);
@ -54,154 +64,71 @@ void BraveAPIWorker::request(const QString &apiKey, const QString &query, int to
connect(reply, &QNetworkReply::errorOccurred, this, &BraveAPIWorker::handleErrorOccurred);
}
static QPair<QString, QList<SourceExcerpt>> cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK = 1)
static QString cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK = 1)
{
// This parses the response from brave and formats it in json that conforms to the de facto
// standard in SourceExcerpt::fromJson(...)
QJsonParseError err;
QJsonDocument document = QJsonDocument::fromJson(jsonResponse, &err);
if (err.error != QJsonParseError::NoError) {
qWarning() << "ERROR: Couldn't parse: " << jsonResponse << err.errorString();
return QPair<QString, QList<SourceExcerpt>>();
qWarning() << "ERROR: Couldn't parse brave response: " << jsonResponse << err.errorString();
return QString();
}
QString query;
QJsonObject searchResponse = document.object();
QJsonObject cleanResponse;
QString query;
QJsonArray cleanArray;
QList<SourceExcerpt> infos;
if (searchResponse.contains("query")) {
QJsonObject queryObj = searchResponse["query"].toObject();
if (queryObj.contains("original")) {
if (queryObj.contains("original"))
query = queryObj["original"].toString();
}
}
if (searchResponse.contains("mixed")) {
QJsonObject mixedResults = searchResponse["mixed"].toObject();
QJsonArray mainResults = mixedResults["main"].toArray();
QJsonObject resultsObject = searchResponse["web"].toObject();
QJsonArray resultsArray = resultsObject["results"].toArray();
for (int i = 0; i < std::min(mainResults.size(), topK); ++i) {
for (int i = 0; i < std::min(mainResults.size(), resultsArray.size()); ++i) {
QJsonObject m = mainResults[i].toObject();
QString r_type = m["type"].toString();
int idx = m["index"].toInt();
QJsonObject resultsObject = searchResponse[r_type].toObject();
QJsonArray resultsArray = resultsObject["results"].toArray();
Q_ASSERT(r_type == "web");
const int idx = m["index"].toInt();
QJsonValue cleaned;
SourceExcerpt info;
if (r_type == "web") {
// For web data - add a single output from the search
QJsonObject resultObj = resultsArray[idx].toObject();
QStringList selectedKeys = {"type", "title", "url", "description", "date", "extra_snippets"};
QJsonObject cleanedObj;
for (const auto& key : selectedKeys) {
if (resultObj.contains(key)) {
cleanedObj.insert(key, resultObj[key]);
}
}
QJsonObject resultObj = resultsArray[idx].toObject();
QStringList selectedKeys = {"type", "title", "url", "description"};
QJsonObject result;
for (const auto& key : selectedKeys)
if (resultObj.contains(key))
result.insert(key, resultObj[key]);
QStringList textKeys = {"description", "extra_snippets"};
QJsonObject textObj;
for (const auto& key : textKeys) {
if (resultObj.contains(key)) {
textObj.insert(key, resultObj[key]);
}
}
if (resultObj.contains("page_age"))
result.insert("date", resultObj["page_age"]);
QJsonDocument textObjDoc(textObj);
info.date = resultObj["date"].toString();
info.text = textObjDoc.toJson(QJsonDocument::Indented);
info.url = resultObj["url"].toString();
QJsonObject meta_url = resultObj["meta_url"].toObject();
info.favicon = meta_url["favicon"].toString();
info.title = resultObj["title"].toString();
cleaned = cleanedObj;
} else if (r_type == "faq") {
// For faq data - take a list of all the questions & answers
QStringList selectedKeys = {"type", "question", "answer", "title", "url"};
QJsonArray cleanedArray;
for (const auto& q : resultsArray) {
QJsonObject qObj = q.toObject();
QJsonObject cleanedObj;
for (const auto& key : selectedKeys) {
if (qObj.contains(key)) {
cleanedObj.insert(key, qObj[key]);
}
}
cleanedArray.append(cleanedObj);
QJsonArray excerpts;
if (resultObj.contains("extra_snippets")) {
QJsonArray snippets = resultObj["extra_snippets"].toArray();
for (int i = 0; i < snippets.size(); ++i) {
QString snippet = snippets[i].toString();
QJsonObject excerpt;
excerpt.insert("text", snippet);
excerpts.append(excerpt);
}
cleaned = cleanedArray;
} else if (r_type == "infobox") {
QJsonObject resultObj = resultsArray[idx].toObject();
QStringList selectedKeys = {"type", "title", "url", "description", "long_desc"};
QJsonObject cleanedObj;
for (const auto& key : selectedKeys) {
if (resultObj.contains(key)) {
cleanedObj.insert(key, resultObj[key]);
}
}
cleaned = cleanedObj;
} else if (r_type == "videos") {
QStringList selectedKeys = {"type", "url", "title", "description", "date"};
QJsonArray cleanedArray;
for (const auto& q : resultsArray) {
QJsonObject qObj = q.toObject();
QJsonObject cleanedObj;
for (const auto& key : selectedKeys) {
if (qObj.contains(key)) {
cleanedObj.insert(key, qObj[key]);
}
}
cleanedArray.append(cleanedObj);
}
cleaned = cleanedArray;
} else if (r_type == "locations") {
QStringList selectedKeys = {"type", "title", "url", "description", "coordinates", "postal_address", "contact", "rating", "distance", "zoom_level"};
QJsonArray cleanedArray;
for (const auto& q : resultsArray) {
QJsonObject qObj = q.toObject();
QJsonObject cleanedObj;
for (const auto& key : selectedKeys) {
if (qObj.contains(key)) {
cleanedObj.insert(key, qObj[key]);
}
}
cleanedArray.append(cleanedObj);
}
cleaned = cleanedArray;
} else if (r_type == "news") {
QStringList selectedKeys = {"type", "title", "url", "description"};
QJsonArray cleanedArray;
for (const auto& q : resultsArray) {
QJsonObject qObj = q.toObject();
QJsonObject cleanedObj;
for (const auto& key : selectedKeys) {
if (qObj.contains(key)) {
cleanedObj.insert(key, qObj[key]);
}
}
cleanedArray.append(cleanedObj);
}
cleaned = cleanedArray;
} else {
cleaned = QJsonValue();
}
infos.append(info);
cleanArray.append(cleaned);
result.insert("excerpts", excerpts);
cleanArray.append(QJsonValue(result));
}
}
cleanResponse.insert("query", query);
cleanResponse.insert("top_k", cleanArray);
cleanResponse.insert("results", cleanArray);
QJsonDocument cleanedDoc(cleanResponse);
// qDebug().noquote() << document.toJson(QJsonDocument::Indented);
// qDebug().noquote() << cleanedDoc.toJson(QJsonDocument::Indented);
return qMakePair(cleanedDoc.toJson(QJsonDocument::Indented), infos);
return cleanedDoc.toJson(QJsonDocument::Compact);
}
void BraveAPIWorker::handleFinished()

View File

@ -2,6 +2,7 @@
#define BRAVESEARCH_H
#include "sourceexcerpt.h"
#include "tool.h"
#include <QObject>
#include <QString>
@ -17,7 +18,7 @@ public:
, m_topK(1) {}
virtual ~BraveAPIWorker() {}
QPair<QString, QList<SourceExcerpt>> response() const { return m_response; }
QString response() const { return m_response; }
public Q_SLOTS:
void request(const QString &apiKey, const QString &query, int topK);
@ -31,21 +32,17 @@ private Q_SLOTS:
private:
QNetworkAccessManager *m_networkManager;
QPair<QString, QList<SourceExcerpt>> m_response;
QString m_response;
int m_topK;
};
class BraveSearch : public QObject {
class BraveSearch : public Tool {
Q_OBJECT
public:
BraveSearch()
: QObject(nullptr) {}
BraveSearch() : Tool() {}
virtual ~BraveSearch() {}
QPair<QString, QList<SourceExcerpt>> search(const QString &apiKey, const QString &query, int topK, unsigned long timeout = 2000);
Q_SIGNALS:
void request(const QString &apiKey, const QString &query, int topK);
QString run(const QJsonObject &parameters, qint64 timeout = 2000) override;
};
#endif // BRAVESEARCH_H

View File

@ -871,14 +871,26 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
const QString query = args["query"].toString();
// FIXME: This has to handle errors of the tool call
emit toolCalled(tr("searching web..."));
const QString apiKey = MySettings::globalInstance()->braveSearchAPIKey();
Q_ASSERT(apiKey != "");
BraveSearch brave;
const QPair<QString, QList<SourceExcerpt>> braveResponse = brave.search(apiKey, query, 2 /*topK*/,
2000 /*msecs to timeout*/);
emit sourceExcerptsChanged(braveResponse.second);
QJsonObject parameters;
parameters.insert("apiKey", apiKey);
parameters.insert("query", query);
parameters.insert("count", 2);
// FIXME: This has to handle errors of the tool call
const QString braveResponse = brave.run(parameters, 2000 /*msecs to timeout*/);
QString parseError;
QList<SourceExcerpt> sourceExcerpts = SourceExcerpt::fromJson(braveResponse, parseError);
if (!parseError.isEmpty()) {
qWarning() << "ERROR: Could not parse source excerpts for brave response" << parseError;
} else if (!sourceExcerpts.isEmpty()) {
emit sourceExcerptsChanged(sourceExcerpts);
}
// Erase the context of the tool call
m_ctx.n_past = std::max(0, m_ctx.n_past);
@ -889,7 +901,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
// This is a recursive call but isToolCallResponse is checked above to arrest infinite recursive
// tool calls
return promptInternal(QList<QString>()/*collectionList*/, braveResponse.first, toolTemplate,
return promptInternal(QList<QString>()/*collectionList*/, braveResponse, toolTemplate,
n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens,
true /*isToolCallResponse*/);

View File

@ -1133,7 +1133,14 @@ Rectangle {
sourceSize.width: 24
sourceSize.height: 24
mipmap: true
source: consolidatedSources[0].url === "" ? "qrc:/gpt4all/icons/db.svg" : "qrc:/gpt4all/icons/globe.svg"
source: {
if (typeof consolidatedSources === 'undefined'
|| typeof consolidatedSources[0] === 'undefined'
|| consolidatedSources[0].url === "")
return "qrc:/gpt4all/icons/db.svg";
else
return "qrc:/gpt4all/icons/globe.svg";
}
}
ColorOverlay {

View File

@ -0,0 +1,93 @@
#include "sourceexcerpt.h"
#include <QJsonArray>
#include <QJsonDocument>
#include <QJsonObject>
#include <QJsonValue>
// Parses a JSON document of the form
//   { "results": [ { "date": ..., "excerpts": [ { "text": ... }, ... ], ... }, ... ] }
// into a flat list holding one SourceExcerpt per entry of every "excerpts" array.
//
// json:        the serialized JSON document.
// errorString: set to a description of the first problem found; left untouched on success.
//
// Returns the parsed excerpts, or an empty list as soon as any problem is found (results
// parsed before the failure are discarded).
//
// Each result must carry a "date" key and a non-empty "excerpts" array whose entries are
// objects with a "text" key. The remaining result keys (collection, path, file, url,
// favicon, title, author, description) and excerpt keys (page, from, to) are optional.
QList<SourceExcerpt> SourceExcerpt::fromJson(const QString &json, QString &errorString)
{
    QJsonParseError err;
    const QJsonDocument document = QJsonDocument::fromJson(json.toUtf8(), &err);
    if (err.error != QJsonParseError::NoError) {
        errorString = err.errorString();
        return QList<SourceExcerpt>();
    }

    // The input is produced outside this function, so a missing "results" key is reported
    // through errorString rather than asserted on.
    const QJsonObject jsonObject = document.object();
    if (!jsonObject.contains("results")) {
        errorString = "json does not contain results array";
        return QList<SourceExcerpt>();
    }

    QList<SourceExcerpt> excerpts;
    const QJsonArray results = jsonObject["results"].toArray();
    for (qsizetype r = 0; r < results.size(); ++r) {
        const QJsonObject result = results[r].toObject();
        if (!result.contains("date")) {
            errorString = "result does not contain required date field";
            return QList<SourceExcerpt>();
        }
        if (!result.contains("excerpts") || !result["excerpts"].isArray()) {
            errorString = "result does not contain required excerpts array";
            return QList<SourceExcerpt>();
        }
        const QJsonArray textExcerpts = result["excerpts"].toArray();
        if (textExcerpts.isEmpty()) {
            errorString = "result excerpts array is empty";
            return QList<SourceExcerpt>();
        }

        // Fields shared by every excerpt that comes from this result.
        SourceExcerpt source;
        source.date = result["date"].toString();
        if (result.contains("collection"))
            source.collection = result["collection"].toString();
        if (result.contains("path"))
            source.path = result["path"].toString();
        if (result.contains("file"))
            source.file = result["file"].toString();
        if (result.contains("url"))
            source.url = result["url"].toString();
        if (result.contains("favicon"))
            source.favicon = result["favicon"].toString();
        if (result.contains("title"))
            source.title = result["title"].toString();
        if (result.contains("author"))
            source.author = result["author"].toString();
        if (result.contains("description"))
            source.description = result["description"].toString();

        for (qsizetype e = 0; e < textExcerpts.size(); ++e) {
            if (!textExcerpts[e].isObject()) {
                errorString = "result excerpt is not an object";
                return QList<SourceExcerpt>();
            }
            const QJsonObject excerptObj = textExcerpts[e].toObject();
            if (!excerptObj.contains("text")) {
                errorString = "result excerpt is does not have text field";
                return QList<SourceExcerpt>();
            }

            // Start from the shared fields, then fill in the per-excerpt ones.
            SourceExcerpt excerpt = source;
            excerpt.text = excerptObj["text"].toString();
            if (excerptObj.contains("page"))
                excerpt.page = excerptObj["page"].toInt();
            if (excerptObj.contains("from"))
                excerpt.from = excerptObj["from"].toInt();
            if (excerptObj.contains("to"))
                excerpt.to = excerptObj["to"].toInt();
            excerpts.append(excerpt);
        }
    }
    return excerpts;
}

View File

@ -19,6 +19,7 @@ struct SourceExcerpt {
Q_PROPERTY(QString favicon MEMBER favicon)
Q_PROPERTY(QString title MEMBER title)
Q_PROPERTY(QString author MEMBER author)
Q_PROPERTY(QString description MEMBER description)
Q_PROPERTY(int page MEMBER page)
Q_PROPERTY(int from MEMBER from)
Q_PROPERTY(int to MEMBER to)
@ -34,6 +35,7 @@ public:
QString favicon; // [Optional] The favicon
QString title; // [Optional] The title of the document
QString author; // [Optional] The author of the document
QString description;// [Optional] The description of the source
int page = -1; // [Optional] The page where the text was found
int from = -1; // [Optional] The line number where the text begins
int to = -1; // [Optional] The line number where the text ends
@ -65,12 +67,15 @@ public:
result.insert("favicon", favicon);
result.insert("title", title);
result.insert("author", author);
result.insert("description", description);
result.insert("page", page);
result.insert("from", from);
result.insert("to", to);
return result;
}
static QList<SourceExcerpt> fromJson(const QString &json, QString &errorString);
bool operator==(const SourceExcerpt &other) const {
return date == other.date &&
text == other.text &&
@ -81,6 +86,7 @@ public:
favicon == other.favicon &&
title == other.title &&
author == other.author &&
description == other.description &&
page == other.page &&
from == other.from &&
to == other.to;

1
gpt4all-chat/tool.cpp Normal file
View File

@ -0,0 +1 @@
#include "tool.h"

87
gpt4all-chat/tool.h Normal file
View File

@ -0,0 +1,87 @@
#ifndef TOOL_H
#define TOOL_H
#include "sourceexcerpt.h"
#include <QObject>
#include <QJsonObject>
using namespace Qt::Literals::StringLiterals;
// Enumerations used by the tool-call abstraction. Q_NAMESPACE and Q_ENUM_NS register the
// namespace and the enum with the Qt meta-object system.
namespace ToolEnums {
Q_NAMESPACE
// How a tool is reached. The enumerators carry explicit numeric values.
enum class ConnectionType {
Builtin = 0, // A built-in tool with bespoke connection type
Local = 1, // Starts a local process and communicates via stdin/stdout/stderr
LocalServer = 2, // Connects to an existing local process and communicates via stdin/stdout/stderr
Remote = 3, // Starts a remote process and communicates via some networking protocol TBD
RemoteServer = 4 // Connects to an existing remote process and communicates via some networking protocol TBD
};
Q_ENUM_NS(ConnectionType)
}
using namespace ToolEnums;
// Value type describing a tool: its name, a description, a JSON object describing its
// parameters, whether it is enabled and how it is connected. Exposed as a Q_GADGET so the
// members are readable as properties.
struct ToolInfo {
    Q_GADGET
    Q_PROPERTY(QString name MEMBER name)
    Q_PROPERTY(QString description MEMBER description)
    Q_PROPERTY(QJsonObject parameters MEMBER parameters)
    Q_PROPERTY(bool isEnabled MEMBER isEnabled)
    Q_PROPERTY(ConnectionType connectionType MEMBER connectionType)

public:
    QString name;
    QString description;
    QJsonObject parameters;
    // Explicit initializers so that a default-constructed ToolInfo never exposes
    // indeterminate values through the properties above.
    bool isEnabled = false;
    ConnectionType connectionType = ConnectionType::Builtin;

    // FIXME: Should we go with essentially the OpenAI/ollama consensus for these tool
    // info files? If you install a tool in GPT4All should it need to meet the spec for these:
    // https://platform.openai.com/docs/api-reference/runs/createRun#runs-createrun-tools
    // https://github.com/ollama/ollama/blob/main/docs/api.md#chat-request-with-tools

    // Serializes name, description and parameters. isEnabled and connectionType are not
    // part of the emitted object.
    QJsonObject toJson() const
    {
        QJsonObject result;
        result.insert("name", name);
        result.insert("description", description);
        result.insert("parameters", parameters);
        return result;
    }

    // NOTE(review): only declared here; no definition is visible alongside this header,
    // so confirm one exists before calling it.
    static ToolInfo fromJson(const QString &json);

    // Two ToolInfo values compare equal when their names match; no other member is
    // considered.
    bool operator==(const ToolInfo &other) const {
        return name == other.name;
    }
    bool operator!=(const ToolInfo &other) const {
        return !(*this == other);
    }
};
Q_DECLARE_METATYPE(ToolInfo)
class Tool : public QObject {
Q_OBJECT
public:
Tool() : QObject(nullptr) {}
virtual ~Tool() {}
// FIXME: How to handle errors?
virtual QString run(const QJsonObject &parameters, qint64 timeout = 2000) = 0;
};
//class BuiltinTool : public Tool {
// Q_OBJECT
//public:
// BuiltinTool() : Tool() {}
// virtual QString run(const QJsonObject &parameters, qint64 timeout = 2000);
//};
//class LocalTool : public Tool {
// Q_OBJECT
//public:
// LocalTool() : Tool() {}
// virtual QString run(const QJsonObject &parameters, qint64 timeout = 2000);
//};
#endif // TOOL_H

95
gpt4all-chat/toolinfo.h Normal file
View File

@ -0,0 +1,95 @@
#ifndef SOURCEEXCERT_H
#define SOURCEEXCERT_H
#include <QObject>
#include <QJsonObject>
#include <QFileInfo>
#include <QUrl>
using namespace Qt::Literals::StringLiterals;
// Value type describing one excerpt of source material together with where it came from.
// Exposed as a Q_GADGET so the members are readable as properties.
//
// NOTE(review): this header is named toolinfo.h but declares SourceExcerpt under the
// SOURCEEXCERT_H include guard, without the description member used elsewhere in the
// project. It looks like a stale copy of sourceexcerpt.h; confirm whether it should exist.
struct SourceExcerpt {
    Q_GADGET
    Q_PROPERTY(QString date MEMBER date)
    Q_PROPERTY(QString text MEMBER text)
    Q_PROPERTY(QString collection MEMBER collection)
    Q_PROPERTY(QString path MEMBER path)
    Q_PROPERTY(QString file MEMBER file)
    Q_PROPERTY(QString url MEMBER url)
    Q_PROPERTY(QString favicon MEMBER favicon)
    Q_PROPERTY(QString title MEMBER title)
    Q_PROPERTY(QString author MEMBER author)
    Q_PROPERTY(int page MEMBER page)
    Q_PROPERTY(int from MEMBER from)
    Q_PROPERTY(int to MEMBER to)
    Q_PROPERTY(QString fileUri READ fileUri STORED false)

public:
    QString date;       // [Required] The creation or the last modification date whichever is latest
    QString text;       // [Required] The text actually used in the augmented context
    QString collection; // [Optional] The name of the collection
    QString path;       // [Optional] The full path
    QString file;       // [Optional] The name of the file, but not the full path
    QString url;        // [Optional] The name of the remote url
    QString favicon;    // [Optional] The favicon
    QString title;      // [Optional] The title of the document
    QString author;     // [Optional] The author of the document
    int page = -1;      // [Optional] The page where the text was found
    int from = -1;      // [Optional] The line number where the text begins
    int to = -1;        // [Optional] The line number where the text ends

    // Builds a file:// URI from path: the path is percent-encoded (leaving the characters
    // in s_exclude as they are) and given a leading '/' if it does not already have one.
    QString fileUri() const {
        // QUrl reserved chars that are not UNSAFE_PATH according to glib/gconvert.c
        static const QByteArray s_exclude = "!$&'()*+,/:=@~"_ba;

        Q_ASSERT(!QFileInfo(path).isRelative());
#ifdef Q_OS_WINDOWS
        Q_ASSERT(!path.contains('\\')); // Qt normally uses forward slash as path separator
#endif

        auto escaped = QString::fromUtf8(QUrl::toPercentEncoding(path, s_exclude));
        // startsWith() is well defined on an empty string, unlike front(); this matters
        // when the asserts above are compiled out and path is empty.
        if (!escaped.startsWith(QLatin1Char('/')))
            escaped = '/' + escaped;
        return u"file://"_s + escaped;
    }

    // Serializes every member (but not the derived fileUri) into a JSON object.
    QJsonObject toJson() const
    {
        QJsonObject result;
        result.insert("date", date);
        result.insert("text", text);
        result.insert("collection", collection);
        result.insert("path", path);
        result.insert("file", file);
        result.insert("url", url);
        result.insert("favicon", favicon);
        result.insert("title", title);
        result.insert("author", author);
        result.insert("page", page);
        result.insert("from", from);
        result.insert("to", to);
        return result;
    }

    // Member-wise equality over all stored members.
    bool operator==(const SourceExcerpt &other) const {
        return date == other.date &&
               text == other.text &&
               collection == other.collection &&
               path == other.path &&
               file == other.file &&
               url == other.url &&
               favicon == other.favicon &&
               title == other.title &&
               author == other.author &&
               page == other.page &&
               from == other.from &&
               to == other.to;
    }
    bool operator!=(const SourceExcerpt &other) const {
        return !(*this == other);
    }
};
Q_DECLARE_METATYPE(SourceExcerpt)
#endif // SOURCEEXCERT_H