diff --git a/gpt4all-chat/chatgpt.cpp b/gpt4all-chat/chatgpt.cpp
index d9f6114e..2b72604d 100644
--- a/gpt4all-chat/chatgpt.cpp
+++ b/gpt4all-chat/chatgpt.cpp
@@ -15,7 +15,6 @@
 ChatGPT::ChatGPT()
     : QObject(nullptr)
     , m_modelName("gpt-3.5-turbo")
-    , m_ctx(nullptr)
     , m_responseCallback(nullptr)
 {
 }
@@ -84,9 +83,6 @@ void ChatGPT::prompt(const std::string &prompt,
         return;
     }
 
-    m_ctx = &promptCtx;
-    m_responseCallback = responseCallback;
-
     // FIXME: We don't set the max_tokens on purpose because in order to do so safely without encountering
     // an error we need to be able to count the tokens in our prompt. The only way to do this is to use
     // the OpenAI tiktokken library or to implement our own tokenization function that matches precisely
@@ -118,37 +114,64 @@ void ChatGPT::prompt(const std::string &prompt,
     qDebug() << "ChatGPT::prompt begin network request" << qPrintable(doc.toJson());
 #endif
 
-    QEventLoop loop;
-    QUrl openaiUrl("https://api.openai.com/v1/chat/completions");
-    const QString authorization = QString("Bearer %1").arg(m_apiKey);
-    QNetworkRequest request(openaiUrl);
-    request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
-    request.setRawHeader("Authorization", authorization.toUtf8());
-    QNetworkReply *reply = m_networkManager.post(request, doc.toJson(QJsonDocument::Compact));
-    connect(reply, &QNetworkReply::finished, &loop, &QEventLoop::quit);
-    connect(reply, &QNetworkReply::finished, this, &ChatGPT::handleFinished);
-    connect(reply, &QNetworkReply::readyRead, this, &ChatGPT::handleReadyRead);
-    connect(reply, &QNetworkReply::errorOccurred, this, &ChatGPT::handleErrorOccurred);
-    loop.exec();
+    m_responseCallback = responseCallback;
+
+    // The following code sets up a worker thread and object to perform the actual api request to
+    // chatgpt and then blocks until it is finished
+    QThread workerThread;
+    ChatGPTWorker worker(this);
+    worker.moveToThread(&workerThread);
+    connect(&worker, &ChatGPTWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection);
+    connect(this, &ChatGPT::request, &worker, &ChatGPTWorker::request, Qt::QueuedConnection);
+    workerThread.start();
+    emit request(m_apiKey, &promptCtx, doc.toJson(QJsonDocument::Compact));
+    workerThread.wait();
+
+    promptCtx.n_past += 1;
+    m_context.append(QString::fromStdString(prompt));
+    m_context.append(worker.currentResponse());
+    m_responseCallback = nullptr;
+
 #if defined(DEBUG)
     qDebug() << "ChatGPT::prompt end network request";
 #endif
-
-    if (m_ctx)
-        m_ctx->n_past += 1;
-    m_context.append(QString::fromStdString(prompt));
-    m_context.append(m_currentResponse);
-
-    m_ctx = nullptr;
-    m_responseCallback = nullptr;
-    m_currentResponse = QString();
 }
 
-void ChatGPT::handleFinished()
+bool ChatGPT::callResponse(int32_t token, const std::string& string)
+{
+    Q_ASSERT(m_responseCallback);
+    if (!m_responseCallback) {
+        std::cerr << "ChatGPT ERROR: no response callback!\n";
+        return false;
+    }
+    return m_responseCallback(token, string);
+}
+
+void ChatGPTWorker::request(const QString &apiKey,
+    LLModel::PromptContext *promptCtx,
+    const QByteArray &array)
+{
+    m_ctx = promptCtx;
+
+    QUrl openaiUrl("https://api.openai.com/v1/chat/completions");
+    const QString authorization = QString("Bearer %1").arg(apiKey);
+    QNetworkRequest request(openaiUrl);
+    request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
+    request.setRawHeader("Authorization", authorization.toUtf8());
+    m_networkManager = new QNetworkAccessManager(this);
+    QNetworkReply *reply = m_networkManager->post(request, array);
+    connect(reply, &QNetworkReply::finished, this, &ChatGPTWorker::handleFinished);
+    connect(reply, &QNetworkReply::readyRead, this, &ChatGPTWorker::handleReadyRead);
+    connect(reply, &QNetworkReply::errorOccurred, this, &ChatGPTWorker::handleErrorOccurred);
+}
+
+void ChatGPTWorker::handleFinished()
 {
     QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
-    if (!reply)
+    if (!reply) {
+        emit finished();
         return;
+    }
 
     QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute);
     Q_ASSERT(response.isValid());
@@ -159,21 +182,25 @@ void ChatGPT::handleFinished()
             .arg(code).arg(reply->errorString()).toStdString();
     }
     reply->deleteLater();
+    emit finished();
 }
 
-void ChatGPT::handleReadyRead()
+void ChatGPTWorker::handleReadyRead()
 {
     QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
-    if (!reply)
+    if (!reply) {
+        emit finished();
         return;
+    }
 
     QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute);
     Q_ASSERT(response.isValid());
     bool ok;
     int code = response.toInt(&ok);
     if (!ok || code != 200) {
-        m_responseCallback(-1, QString("\nERROR: 2 ChatGPT responded with error code \"%1-%2\" %3\n")
+        m_chat->callResponse(-1, QString("\nERROR: 2 ChatGPT responded with error code \"%1-%2\" %3\n")
             .arg(code).arg(reply->errorString()).arg(qPrintable(reply->readAll())).toStdString());
+        emit finished();
         return;
     }
 
@@ -192,7 +219,7 @@ void ChatGPT::handleReadyRead()
         QJsonParseError err;
         const QJsonDocument document = QJsonDocument::fromJson(jsonData.toUtf8(), &err);
         if (err.error != QJsonParseError::NoError) {
-            m_responseCallback(-1, QString("\nERROR: ChatGPT responded with invalid json \"%1\"\n")
+            m_chat->callResponse(-1, QString("\nERROR: ChatGPT responded with invalid json \"%1\"\n")
                 .arg(err.errorString()).toStdString());
             continue;
         }
@@ -203,21 +230,24 @@ void ChatGPT::handleReadyRead()
         const QJsonObject delta = choice.value("delta").toObject();
         const QString content = delta.value("content").toString();
         Q_ASSERT(m_ctx);
-        Q_ASSERT(m_responseCallback);
         m_currentResponse += content;
-        if (!m_responseCallback(0, content.toStdString())) {
+        if (!m_chat->callResponse(0, content.toStdString())) {
             reply->abort();
+            emit finished();
             return;
         }
     }
 }
 
-void ChatGPT::handleErrorOccurred(QNetworkReply::NetworkError code)
+void ChatGPTWorker::handleErrorOccurred(QNetworkReply::NetworkError code)
 {
     QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
-    if (!reply)
+    if (!reply) {
+        emit finished();
         return;
+    }
 
     qWarning() << QString("ERROR: ChatGPT responded with error code \"%1-%2\"")
         .arg(code).arg(reply->errorString()).toStdString();
+    emit finished();
 }
diff --git a/gpt4all-chat/chatgpt.h b/gpt4all-chat/chatgpt.h
index af06a4bb..b1f32298 100644
--- a/gpt4all-chat/chatgpt.h
+++ b/gpt4all-chat/chatgpt.h
@@ -5,9 +5,41 @@
 #include <QNetworkReply>
 #include <QNetworkRequest>
 #include <QNetworkAccessManager>
+#include <QThread>
 #include "../gpt4all-backend/llmodel.h"
 
-class ChatGPTPrivate;
+class ChatGPT;
+class ChatGPTWorker : public QObject {
+    Q_OBJECT
+public:
+    ChatGPTWorker(ChatGPT *chatGPT)
+        : QObject(nullptr)
+        , m_ctx(nullptr)
+        , m_networkManager(nullptr)
+        , m_chat(chatGPT) {}
+    virtual ~ChatGPTWorker() {}
+
+    QString currentResponse() const { return m_currentResponse; }
+
+    void request(const QString &apiKey,
+        LLModel::PromptContext *promptCtx,
+        const QByteArray &array);
+
+Q_SIGNALS:
+    void finished();
+
+private Q_SLOTS:
+    void handleFinished();
+    void handleReadyRead();
+    void handleErrorOccurred(QNetworkReply::NetworkError code);
+
+private:
+    ChatGPT *m_chat;
+    LLModel::PromptContext *m_ctx;
+    QNetworkAccessManager *m_networkManager;
+    QString m_currentResponse;
+};
+
 class ChatGPT : public QObject, public LLModel {
     Q_OBJECT
 public:
@@ -35,6 +67,13 @@ public:
     QList<QString> context() const { return m_context; }
     void setContext(const QList<QString> &context) { m_context = context; }
 
+    bool callResponse(int32_t token, const std::string& string);
+
+Q_SIGNALS:
+    void request(const QString &apiKey,
+        LLModel::PromptContext *ctx,
+        const QByteArray &array);
+
 protected:
     // We have to implement these as they are pure virtual in base class, but we don't actually use
     // them as they are only called from the default implementation of 'prompt' which we override and
@@ -46,19 +85,11 @@ protected:
     int32_t contextLength() const override { return -1; }
     const std::vector<Token>& endTokens() const { static const std::vector<Token> fres; return fres; }
 
-private Q_SLOTS:
-    void handleFinished();
-    void handleReadyRead();
-    void handleErrorOccurred(QNetworkReply::NetworkError code);
-
 private:
-    PromptContext *m_ctx;
     std::function<bool(int32_t, const std::string&)> m_responseCallback;
     QString m_modelName;
     QString m_apiKey;
     QList<QString> m_context;
-    QString m_currentResponse;
-    QNetworkAccessManager m_networkManager;
 };
 
 #endif // CHATGPT_H