Don't use a local event loop, which can lead to recursion and crashes.

Adam Treat 2023-07-11 10:08:03 -04:00
parent 8467e69f24
commit 4f9e489093
2 changed files with 105 additions and 44 deletions
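Some context on the bug class this commit removes: QEventLoop::exec() dispatches every event queued for the calling thread while it blocks, so anything that calls back into ChatGPT::prompt() from one of those events (a queued signal, a timer, a UI action) nests a second event loop inside the first. A minimal sketch of that re-entrancy; submit() and the timers are hypothetical stand-ins, not code from this commit:

#include <QCoreApplication>
#include <QDebug>
#include <QEventLoop>
#include <QTimer>

// Hypothetical stand-in for a blocking call implemented with a local loop.
// Each invocation blocks in its own QEventLoop; any queued event that calls
// submit() again nests a second exec() inside the first.
void submit()
{
    static int depth = 0;
    qDebug() << "submit() entered, depth" << ++depth;

    QEventLoop loop;
    QTimer::singleShot(100, &loop, &QEventLoop::quit); // fake "request finished"
    loop.exec(); // dispatches ALL pending events for this thread, not just ours

    qDebug() << "submit() returning, depth" << depth--; // inner call returns first
}

int main(int argc, char *argv[])
{
    QCoreApplication app(argc, argv);
    QTimer::singleShot(0, &app, [] { submit(); }); // arrives while the first
    submit();                                      // call is blocked in exec()
    return 0;
}

The replacement below blocks in QThread::wait() instead, which suspends the calling thread without dispatching any of its events, so re-entry is impossible.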

--- a/chatgpt.cpp
+++ b/chatgpt.cpp

@@ -15,7 +15,6 @@
 ChatGPT::ChatGPT()
     : QObject(nullptr)
     , m_modelName("gpt-3.5-turbo")
-    , m_ctx(nullptr)
     , m_responseCallback(nullptr)
 {
 }
@@ -84,9 +83,6 @@ void ChatGPT::prompt(const std::string &prompt,
         return;
     }
 
-    m_ctx = &promptCtx;
-    m_responseCallback = responseCallback;
-
     // FIXME: We don't set the max_tokens on purpose because in order to do so safely without encountering
     // an error we need to be able to count the tokens in our prompt. The only way to do this is to use
     // the OpenAI tiktoken library or to implement our own tokenization function that matches precisely
@@ -118,37 +114,64 @@ void ChatGPT::prompt(const std::string &prompt,
 #if defined(DEBUG)
     qDebug() << "ChatGPT::prompt begin network request" << qPrintable(doc.toJson());
 #endif
 
-    QEventLoop loop;
-    QUrl openaiUrl("https://api.openai.com/v1/chat/completions");
-    const QString authorization = QString("Bearer %1").arg(m_apiKey);
-    QNetworkRequest request(openaiUrl);
-    request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
-    request.setRawHeader("Authorization", authorization.toUtf8());
-    QNetworkReply *reply = m_networkManager.post(request, doc.toJson(QJsonDocument::Compact));
-    connect(reply, &QNetworkReply::finished, &loop, &QEventLoop::quit);
-    connect(reply, &QNetworkReply::finished, this, &ChatGPT::handleFinished);
-    connect(reply, &QNetworkReply::readyRead, this, &ChatGPT::handleReadyRead);
-    connect(reply, &QNetworkReply::errorOccurred, this, &ChatGPT::handleErrorOccurred);
-    loop.exec();
+    m_responseCallback = responseCallback;
+
+    // The following code sets up a worker thread and object to perform the actual api request to
+    // chatgpt and then blocks until it is finished
+    QThread workerThread;
+    ChatGPTWorker worker(this);
+    worker.moveToThread(&workerThread);
+    connect(&worker, &ChatGPTWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection);
+    connect(this, &ChatGPT::request, &worker, &ChatGPTWorker::request, Qt::QueuedConnection);
+    workerThread.start();
+    emit request(m_apiKey, &promptCtx, doc.toJson(QJsonDocument::Compact));
+    workerThread.wait();
+
+    promptCtx.n_past += 1;
+    m_context.append(QString::fromStdString(prompt));
+    m_context.append(worker.currentResponse());
+    m_responseCallback = nullptr;
 
 #if defined(DEBUG)
     qDebug() << "ChatGPT::prompt end network request";
 #endif
-
-    if (m_ctx)
-        m_ctx->n_past += 1;
-    m_context.append(QString::fromStdString(prompt));
-    m_context.append(m_currentResponse);
-    m_ctx = nullptr;
-    m_responseCallback = nullptr;
-    m_currentResponse = QString();
 }
 
-void ChatGPT::handleFinished()
+bool ChatGPT::callResponse(int32_t token, const std::string& string)
+{
+    Q_ASSERT(m_responseCallback);
+    if (!m_responseCallback) {
+        std::cerr << "ChatGPT ERROR: no response callback!\n";
+        return false;
+    }
+    return m_responseCallback(token, string);
+}
+
+void ChatGPTWorker::request(const QString &apiKey,
+    LLModel::PromptContext *promptCtx,
+    const QByteArray &array)
+{
+    m_ctx = promptCtx;
+
+    QUrl openaiUrl("https://api.openai.com/v1/chat/completions");
+    const QString authorization = QString("Bearer %1").arg(apiKey);
+    QNetworkRequest request(openaiUrl);
+    request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
+    request.setRawHeader("Authorization", authorization.toUtf8());
+    m_networkManager = new QNetworkAccessManager(this);
+    QNetworkReply *reply = m_networkManager->post(request, array);
+    connect(reply, &QNetworkReply::finished, this, &ChatGPTWorker::handleFinished);
+    connect(reply, &QNetworkReply::readyRead, this, &ChatGPTWorker::handleReadyRead);
+    connect(reply, &QNetworkReply::errorOccurred, this, &ChatGPTWorker::handleErrorOccurred);
+}
+
+void ChatGPTWorker::handleFinished()
 {
     QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
-    if (!reply)
+    if (!reply) {
+        emit finished();
         return;
+    }
 
     QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute);
     Q_ASSERT(response.isValid());
@@ -159,21 +182,25 @@ void ChatGPT::handleFinished()
             .arg(code).arg(reply->errorString()).toStdString();
     }
     reply->deleteLater();
+    emit finished();
 }
 
-void ChatGPT::handleReadyRead()
+void ChatGPTWorker::handleReadyRead()
 {
     QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
-    if (!reply)
+    if (!reply) {
+        emit finished();
         return;
+    }
 
     QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute);
     Q_ASSERT(response.isValid());
     bool ok;
     int code = response.toInt(&ok);
     if (!ok || code != 200) {
-        m_responseCallback(-1, QString("\nERROR: 2 ChatGPT responded with error code \"%1-%2\" %3\n")
+        m_chat->callResponse(-1, QString("\nERROR: 2 ChatGPT responded with error code \"%1-%2\" %3\n")
             .arg(code).arg(reply->errorString()).arg(qPrintable(reply->readAll())).toStdString());
+        emit finished();
         return;
     }
@@ -192,7 +219,7 @@ void ChatGPT::handleReadyRead()
         QJsonParseError err;
         const QJsonDocument document = QJsonDocument::fromJson(jsonData.toUtf8(), &err);
         if (err.error != QJsonParseError::NoError) {
-            m_responseCallback(-1, QString("\nERROR: ChatGPT responded with invalid json \"%1\"\n")
+            m_chat->callResponse(-1, QString("\nERROR: ChatGPT responded with invalid json \"%1\"\n")
                 .arg(err.errorString()).toStdString());
             continue;
         }
@@ -203,21 +230,24 @@ void ChatGPT::handleReadyRead()
         const QJsonObject delta = choice.value("delta").toObject();
         const QString content = delta.value("content").toString();
-        Q_ASSERT(m_ctx);
-        Q_ASSERT(m_responseCallback);
         m_currentResponse += content;
-        if (!m_responseCallback(0, content.toStdString())) {
+        if (!m_chat->callResponse(0, content.toStdString())) {
             reply->abort();
+            emit finished();
             return;
         }
     }
 }
 
-void ChatGPT::handleErrorOccurred(QNetworkReply::NetworkError code)
+void ChatGPTWorker::handleErrorOccurred(QNetworkReply::NetworkError code)
 {
     QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
-    if (!reply)
+    if (!reply) {
+        emit finished();
         return;
+    }
 
     qWarning() << QString("ERROR: ChatGPT responded with error code \"%1-%2\"")
                       .arg(code).arg(reply->errorString()).toStdString();
+    emit finished();
 }

--- a/chatgpt.h
+++ b/chatgpt.h

@@ -5,9 +5,41 @@
 #include <QNetworkReply>
 #include <QNetworkRequest>
 #include <QNetworkAccessManager>
+#include <QThread>
 #include "../gpt4all-backend/llmodel.h"
 
-class ChatGPTPrivate;
+class ChatGPT;
+
+class ChatGPTWorker : public QObject {
+    Q_OBJECT
+public:
+    ChatGPTWorker(ChatGPT *chatGPT)
+        : QObject(nullptr)
+        , m_ctx(nullptr)
+        , m_networkManager(nullptr)
+        , m_chat(chatGPT) {}
+    virtual ~ChatGPTWorker() {}
+
+    QString currentResponse() const { return m_currentResponse; }
+
+    void request(const QString &apiKey,
+        LLModel::PromptContext *promptCtx,
+        const QByteArray &array);
+
+Q_SIGNALS:
+    void finished();
+
+private Q_SLOTS:
+    void handleFinished();
+    void handleReadyRead();
+    void handleErrorOccurred(QNetworkReply::NetworkError code);
+
+private:
+    ChatGPT *m_chat;
+    LLModel::PromptContext *m_ctx;
+    QNetworkAccessManager *m_networkManager;
+    QString m_currentResponse;
+};
+
 class ChatGPT : public QObject, public LLModel {
     Q_OBJECT
 public:
@@ -35,6 +67,13 @@ public:
     QList<QString> context() const { return m_context; }
     void setContext(const QList<QString> &context) { m_context = context; }
 
+    bool callResponse(int32_t token, const std::string& string);
+
+Q_SIGNALS:
+    void request(const QString &apiKey,
+                 LLModel::PromptContext *ctx,
+                 const QByteArray &array);
+
 protected:
     // We have to implement these as they are pure virtual in base class, but we don't actually use
     // them as they are only called from the default implementation of 'prompt' which we override and
@@ -46,19 +85,11 @@ protected:
     int32_t contextLength() const override { return -1; }
     const std::vector<Token>& endTokens() const override { static const std::vector<Token> fres; return fres; }
 
-private Q_SLOTS:
-    void handleFinished();
-    void handleReadyRead();
-    void handleErrorOccurred(QNetworkReply::NetworkError code);
-
 private:
-    PromptContext *m_ctx;
     std::function<bool(int32_t, const std::string&)> m_responseCallback;
     QString m_modelName;
     QString m_apiKey;
     QList<QString> m_context;
-    QString m_currentResponse;
-    QNetworkAccessManager m_networkManager;
 };
 
 #endif // CHATGPT_H
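For reference, a condensed, self-contained sketch of the blocking-worker pattern the new code adopts. blockingGet and its lambda body are hypothetical stand-ins for ChatGPT::prompt and ChatGPTWorker (the commit instead routes the call through the queued request signal and streams the body via readyRead); it assumes Qt 5.10+ for the functor overload of QMetaObject::invokeMethod and an already-running QCoreApplication:

#include <QNetworkAccessManager>
#include <QNetworkReply>
#include <QNetworkRequest>
#include <QObject>
#include <QThread>
#include <QUrl>

// Hypothetical helper: blocks the calling thread until an HTTP GET completes,
// without spinning the caller's event loop. The request and all reply slots
// run on workerThread's event loop instead.
QByteArray blockingGet(const QUrl &url)
{
    QThread workerThread;
    QObject worker; // stand-in for ChatGPTWorker
    worker.moveToThread(&workerThread);
    workerThread.start();

    QByteArray result;
    // Queued to the worker thread, because that is where `worker` now lives.
    QMetaObject::invokeMethod(&worker, [&]() {
        auto *nam = new QNetworkAccessManager(&worker); // worker-thread affinity
        QNetworkReply *reply = nam->get(QNetworkRequest(url));
        QObject::connect(reply, &QNetworkReply::finished, reply, [&result, reply, &workerThread]() {
            result = reply->readAll();
            reply->deleteLater();
            workerThread.quit(); // thread-safe; ends the worker's event loop
        });
    });

    workerThread.wait(); // blocks here; no events are dispatched on this thread
    return result;
}

The same reasoning explains the Qt::DirectConnection on the finished-to-quit connection in the commit: the workerThread object lives on the caller's thread, which is blocked in wait() and cannot process a queued delivery, so quit() must be invoked directly from the emitting worker thread (QThread::quit() is thread-safe).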