Don't use a local event loop which can lead to recursion and crashes.

This commit is contained in:
Adam Treat 2023-07-11 10:08:03 -04:00
parent 8467e69f24
commit 4f9e489093
2 changed files with 105 additions and 44 deletions

View File

@ -15,7 +15,6 @@
ChatGPT::ChatGPT() ChatGPT::ChatGPT()
: QObject(nullptr) : QObject(nullptr)
, m_modelName("gpt-3.5-turbo") , m_modelName("gpt-3.5-turbo")
, m_ctx(nullptr)
, m_responseCallback(nullptr) , m_responseCallback(nullptr)
{ {
} }
@ -84,9 +83,6 @@ void ChatGPT::prompt(const std::string &prompt,
return; return;
} }
m_ctx = &promptCtx;
m_responseCallback = responseCallback;
// FIXME: We don't set the max_tokens on purpose because in order to do so safely without encountering // FIXME: We don't set the max_tokens on purpose because in order to do so safely without encountering
// an error we need to be able to count the tokens in our prompt. The only way to do this is to use // an error we need to be able to count the tokens in our prompt. The only way to do this is to use
// the OpenAI tiktokken library or to implement our own tokenization function that matches precisely // the OpenAI tiktokken library or to implement our own tokenization function that matches precisely
@ -118,37 +114,64 @@ void ChatGPT::prompt(const std::string &prompt,
qDebug() << "ChatGPT::prompt begin network request" << qPrintable(doc.toJson()); qDebug() << "ChatGPT::prompt begin network request" << qPrintable(doc.toJson());
#endif #endif
QEventLoop loop; m_responseCallback = responseCallback;
QUrl openaiUrl("https://api.openai.com/v1/chat/completions");
const QString authorization = QString("Bearer %1").arg(m_apiKey); // The following code sets up a worker thread and object to perform the actual api request to
QNetworkRequest request(openaiUrl); // chatgpt and then blocks until it is finished
request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json"); QThread workerThread;
request.setRawHeader("Authorization", authorization.toUtf8()); ChatGPTWorker worker(this);
QNetworkReply *reply = m_networkManager.post(request, doc.toJson(QJsonDocument::Compact)); worker.moveToThread(&workerThread);
connect(reply, &QNetworkReply::finished, &loop, &QEventLoop::quit); connect(&worker, &ChatGPTWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection);
connect(reply, &QNetworkReply::finished, this, &ChatGPT::handleFinished); connect(this, &ChatGPT::request, &worker, &ChatGPTWorker::request, Qt::QueuedConnection);
connect(reply, &QNetworkReply::readyRead, this, &ChatGPT::handleReadyRead); workerThread.start();
connect(reply, &QNetworkReply::errorOccurred, this, &ChatGPT::handleErrorOccurred); emit request(m_apiKey, &promptCtx, doc.toJson(QJsonDocument::Compact));
loop.exec(); workerThread.wait();
promptCtx.n_past += 1;
m_context.append(QString::fromStdString(prompt));
m_context.append(worker.currentResponse());
m_responseCallback = nullptr;
#if defined(DEBUG) #if defined(DEBUG)
qDebug() << "ChatGPT::prompt end network request"; qDebug() << "ChatGPT::prompt end network request";
#endif #endif
if (m_ctx)
m_ctx->n_past += 1;
m_context.append(QString::fromStdString(prompt));
m_context.append(m_currentResponse);
m_ctx = nullptr;
m_responseCallback = nullptr;
m_currentResponse = QString();
} }
void ChatGPT::handleFinished() bool ChatGPT::callResponse(int32_t token, const std::string& string)
{
Q_ASSERT(m_responseCallback);
if (!m_responseCallback) {
std::cerr << "ChatGPT ERROR: no response callback!\n";
return false;
}
return m_responseCallback(token, string);
}
void ChatGPTWorker::request(const QString &apiKey,
LLModel::PromptContext *promptCtx,
const QByteArray &array)
{
m_ctx = promptCtx;
QUrl openaiUrl("https://api.openai.com/v1/chat/completions");
const QString authorization = QString("Bearer %1").arg(apiKey);
QNetworkRequest request(openaiUrl);
request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
request.setRawHeader("Authorization", authorization.toUtf8());
m_networkManager = new QNetworkAccessManager(this);
QNetworkReply *reply = m_networkManager->post(request, array);
connect(reply, &QNetworkReply::finished, this, &ChatGPTWorker::handleFinished);
connect(reply, &QNetworkReply::readyRead, this, &ChatGPTWorker::handleReadyRead);
connect(reply, &QNetworkReply::errorOccurred, this, &ChatGPTWorker::handleErrorOccurred);
}
void ChatGPTWorker::handleFinished()
{ {
QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender()); QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
if (!reply) if (!reply) {
emit finished();
return; return;
}
QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute); QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute);
Q_ASSERT(response.isValid()); Q_ASSERT(response.isValid());
@ -159,21 +182,25 @@ void ChatGPT::handleFinished()
.arg(code).arg(reply->errorString()).toStdString(); .arg(code).arg(reply->errorString()).toStdString();
} }
reply->deleteLater(); reply->deleteLater();
emit finished();
} }
void ChatGPT::handleReadyRead() void ChatGPTWorker::handleReadyRead()
{ {
QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender()); QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
if (!reply) if (!reply) {
emit finished();
return; return;
}
QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute); QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute);
Q_ASSERT(response.isValid()); Q_ASSERT(response.isValid());
bool ok; bool ok;
int code = response.toInt(&ok); int code = response.toInt(&ok);
if (!ok || code != 200) { if (!ok || code != 200) {
m_responseCallback(-1, QString("\nERROR: 2 ChatGPT responded with error code \"%1-%2\" %3\n") m_chat->callResponse(-1, QString("\nERROR: 2 ChatGPT responded with error code \"%1-%2\" %3\n")
.arg(code).arg(reply->errorString()).arg(qPrintable(reply->readAll())).toStdString()); .arg(code).arg(reply->errorString()).arg(qPrintable(reply->readAll())).toStdString());
emit finished();
return; return;
} }
@ -192,7 +219,7 @@ void ChatGPT::handleReadyRead()
QJsonParseError err; QJsonParseError err;
const QJsonDocument document = QJsonDocument::fromJson(jsonData.toUtf8(), &err); const QJsonDocument document = QJsonDocument::fromJson(jsonData.toUtf8(), &err);
if (err.error != QJsonParseError::NoError) { if (err.error != QJsonParseError::NoError) {
m_responseCallback(-1, QString("\nERROR: ChatGPT responded with invalid json \"%1\"\n") m_chat->callResponse(-1, QString("\nERROR: ChatGPT responded with invalid json \"%1\"\n")
.arg(err.errorString()).toStdString()); .arg(err.errorString()).toStdString());
continue; continue;
} }
@ -203,21 +230,24 @@ void ChatGPT::handleReadyRead()
const QJsonObject delta = choice.value("delta").toObject(); const QJsonObject delta = choice.value("delta").toObject();
const QString content = delta.value("content").toString(); const QString content = delta.value("content").toString();
Q_ASSERT(m_ctx); Q_ASSERT(m_ctx);
Q_ASSERT(m_responseCallback);
m_currentResponse += content; m_currentResponse += content;
if (!m_responseCallback(0, content.toStdString())) { if (!m_chat->callResponse(0, content.toStdString())) {
reply->abort(); reply->abort();
emit finished();
return; return;
} }
} }
} }
void ChatGPT::handleErrorOccurred(QNetworkReply::NetworkError code) void ChatGPTWorker::handleErrorOccurred(QNetworkReply::NetworkError code)
{ {
QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender()); QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
if (!reply) if (!reply) {
emit finished();
return; return;
}
qWarning() << QString("ERROR: ChatGPT responded with error code \"%1-%2\"") qWarning() << QString("ERROR: ChatGPT responded with error code \"%1-%2\"")
.arg(code).arg(reply->errorString()).toStdString(); .arg(code).arg(reply->errorString()).toStdString();
emit finished();
} }

View File

@ -5,9 +5,41 @@
#include <QNetworkReply> #include <QNetworkReply>
#include <QNetworkRequest> #include <QNetworkRequest>
#include <QNetworkAccessManager> #include <QNetworkAccessManager>
#include <QThread>
#include "../gpt4all-backend/llmodel.h" #include "../gpt4all-backend/llmodel.h"
class ChatGPTPrivate; class ChatGPT;
class ChatGPTWorker : public QObject {
Q_OBJECT
public:
ChatGPTWorker(ChatGPT *chatGPT)
: QObject(nullptr)
, m_ctx(nullptr)
, m_networkManager(nullptr)
, m_chat(chatGPT) {}
virtual ~ChatGPTWorker() {}
QString currentResponse() const { return m_currentResponse; }
void request(const QString &apiKey,
LLModel::PromptContext *promptCtx,
const QByteArray &array);
Q_SIGNALS:
void finished();
private Q_SLOTS:
void handleFinished();
void handleReadyRead();
void handleErrorOccurred(QNetworkReply::NetworkError code);
private:
ChatGPT *m_chat;
LLModel::PromptContext *m_ctx;
QNetworkAccessManager *m_networkManager;
QString m_currentResponse;
};
class ChatGPT : public QObject, public LLModel { class ChatGPT : public QObject, public LLModel {
Q_OBJECT Q_OBJECT
public: public:
@ -35,6 +67,13 @@ public:
QList<QString> context() const { return m_context; } QList<QString> context() const { return m_context; }
void setContext(const QList<QString> &context) { m_context = context; } void setContext(const QList<QString> &context) { m_context = context; }
bool callResponse(int32_t token, const std::string& string);
Q_SIGNALS:
void request(const QString &apiKey,
LLModel::PromptContext *ctx,
const QByteArray &array);
protected: protected:
// We have to implement these as they are pure virtual in base class, but we don't actually use // We have to implement these as they are pure virtual in base class, but we don't actually use
// them as they are only called from the default implementation of 'prompt' which we override and // them as they are only called from the default implementation of 'prompt' which we override and
@ -46,19 +85,11 @@ protected:
int32_t contextLength() const override { return -1; } int32_t contextLength() const override { return -1; }
const std::vector<Token>& endTokens() const override { static const std::vector<Token> fres; return fres; } const std::vector<Token>& endTokens() const override { static const std::vector<Token> fres; return fres; }
private Q_SLOTS:
void handleFinished();
void handleReadyRead();
void handleErrorOccurred(QNetworkReply::NetworkError code);
private: private:
PromptContext *m_ctx;
std::function<bool(int32_t, const std::string&)> m_responseCallback; std::function<bool(int32_t, const std::string&)> m_responseCallback;
QString m_modelName; QString m_modelName;
QString m_apiKey; QString m_apiKey;
QList<QString> m_context; QList<QString> m_context;
QString m_currentResponse;
QNetworkAccessManager m_networkManager;
}; };
#endif // CHATGPT_H #endif // CHATGPT_H