mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-10-01 01:06:10 -04:00
Don't use a local event loop which can lead to recursion and crashes.
This commit is contained in:
parent
8467e69f24
commit
4f9e489093
@ -15,7 +15,6 @@
|
|||||||
ChatGPT::ChatGPT()
|
ChatGPT::ChatGPT()
|
||||||
: QObject(nullptr)
|
: QObject(nullptr)
|
||||||
, m_modelName("gpt-3.5-turbo")
|
, m_modelName("gpt-3.5-turbo")
|
||||||
, m_ctx(nullptr)
|
|
||||||
, m_responseCallback(nullptr)
|
, m_responseCallback(nullptr)
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
@ -84,9 +83,6 @@ void ChatGPT::prompt(const std::string &prompt,
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
m_ctx = &promptCtx;
|
|
||||||
m_responseCallback = responseCallback;
|
|
||||||
|
|
||||||
// FIXME: We don't set the max_tokens on purpose because in order to do so safely without encountering
|
// FIXME: We don't set the max_tokens on purpose because in order to do so safely without encountering
|
||||||
// an error we need to be able to count the tokens in our prompt. The only way to do this is to use
|
// an error we need to be able to count the tokens in our prompt. The only way to do this is to use
|
||||||
// the OpenAI tiktoken library or to implement our own tokenization function that matches precisely
|
// the OpenAI tiktoken library or to implement our own tokenization function that matches precisely
|
||||||
@ -118,37 +114,64 @@ void ChatGPT::prompt(const std::string &prompt,
|
|||||||
qDebug() << "ChatGPT::prompt begin network request" << qPrintable(doc.toJson());
|
qDebug() << "ChatGPT::prompt begin network request" << qPrintable(doc.toJson());
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
QEventLoop loop;
|
m_responseCallback = responseCallback;
|
||||||
QUrl openaiUrl("https://api.openai.com/v1/chat/completions");
|
|
||||||
const QString authorization = QString("Bearer %1").arg(m_apiKey);
|
// The following code sets up a worker thread and object to perform the actual api request to
|
||||||
QNetworkRequest request(openaiUrl);
|
// chatgpt and then blocks until it is finished
|
||||||
request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
|
QThread workerThread;
|
||||||
request.setRawHeader("Authorization", authorization.toUtf8());
|
ChatGPTWorker worker(this);
|
||||||
QNetworkReply *reply = m_networkManager.post(request, doc.toJson(QJsonDocument::Compact));
|
worker.moveToThread(&workerThread);
|
||||||
connect(reply, &QNetworkReply::finished, &loop, &QEventLoop::quit);
|
connect(&worker, &ChatGPTWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection);
|
||||||
connect(reply, &QNetworkReply::finished, this, &ChatGPT::handleFinished);
|
connect(this, &ChatGPT::request, &worker, &ChatGPTWorker::request, Qt::QueuedConnection);
|
||||||
connect(reply, &QNetworkReply::readyRead, this, &ChatGPT::handleReadyRead);
|
workerThread.start();
|
||||||
connect(reply, &QNetworkReply::errorOccurred, this, &ChatGPT::handleErrorOccurred);
|
emit request(m_apiKey, &promptCtx, doc.toJson(QJsonDocument::Compact));
|
||||||
loop.exec();
|
workerThread.wait();
|
||||||
|
|
||||||
|
promptCtx.n_past += 1;
|
||||||
|
m_context.append(QString::fromStdString(prompt));
|
||||||
|
m_context.append(worker.currentResponse());
|
||||||
|
m_responseCallback = nullptr;
|
||||||
|
|
||||||
#if defined(DEBUG)
|
#if defined(DEBUG)
|
||||||
qDebug() << "ChatGPT::prompt end network request";
|
qDebug() << "ChatGPT::prompt end network request";
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
if (m_ctx)
|
|
||||||
m_ctx->n_past += 1;
|
|
||||||
m_context.append(QString::fromStdString(prompt));
|
|
||||||
m_context.append(m_currentResponse);
|
|
||||||
|
|
||||||
m_ctx = nullptr;
|
|
||||||
m_responseCallback = nullptr;
|
|
||||||
m_currentResponse = QString();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ChatGPT::handleFinished()
|
bool ChatGPT::callResponse(int32_t token, const std::string& string)
|
||||||
|
{
|
||||||
|
Q_ASSERT(m_responseCallback);
|
||||||
|
if (!m_responseCallback) {
|
||||||
|
std::cerr << "ChatGPT ERROR: no response callback!\n";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return m_responseCallback(token, string);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ChatGPTWorker::request(const QString &apiKey,
|
||||||
|
LLModel::PromptContext *promptCtx,
|
||||||
|
const QByteArray &array)
|
||||||
|
{
|
||||||
|
m_ctx = promptCtx;
|
||||||
|
|
||||||
|
QUrl openaiUrl("https://api.openai.com/v1/chat/completions");
|
||||||
|
const QString authorization = QString("Bearer %1").arg(apiKey);
|
||||||
|
QNetworkRequest request(openaiUrl);
|
||||||
|
request.setHeader(QNetworkRequest::ContentTypeHeader, "application/json");
|
||||||
|
request.setRawHeader("Authorization", authorization.toUtf8());
|
||||||
|
m_networkManager = new QNetworkAccessManager(this);
|
||||||
|
QNetworkReply *reply = m_networkManager->post(request, array);
|
||||||
|
connect(reply, &QNetworkReply::finished, this, &ChatGPTWorker::handleFinished);
|
||||||
|
connect(reply, &QNetworkReply::readyRead, this, &ChatGPTWorker::handleReadyRead);
|
||||||
|
connect(reply, &QNetworkReply::errorOccurred, this, &ChatGPTWorker::handleErrorOccurred);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ChatGPTWorker::handleFinished()
|
||||||
{
|
{
|
||||||
QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
|
QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
|
||||||
if (!reply)
|
if (!reply) {
|
||||||
|
emit finished();
|
||||||
return;
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute);
|
QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute);
|
||||||
Q_ASSERT(response.isValid());
|
Q_ASSERT(response.isValid());
|
||||||
@ -159,21 +182,25 @@ void ChatGPT::handleFinished()
|
|||||||
.arg(code).arg(reply->errorString()).toStdString();
|
.arg(code).arg(reply->errorString()).toStdString();
|
||||||
}
|
}
|
||||||
reply->deleteLater();
|
reply->deleteLater();
|
||||||
|
emit finished();
|
||||||
}
|
}
|
||||||
|
|
||||||
void ChatGPT::handleReadyRead()
|
void ChatGPTWorker::handleReadyRead()
|
||||||
{
|
{
|
||||||
QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
|
QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
|
||||||
if (!reply)
|
if (!reply) {
|
||||||
|
emit finished();
|
||||||
return;
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute);
|
QVariant response = reply->attribute(QNetworkRequest::HttpStatusCodeAttribute);
|
||||||
Q_ASSERT(response.isValid());
|
Q_ASSERT(response.isValid());
|
||||||
bool ok;
|
bool ok;
|
||||||
int code = response.toInt(&ok);
|
int code = response.toInt(&ok);
|
||||||
if (!ok || code != 200) {
|
if (!ok || code != 200) {
|
||||||
m_responseCallback(-1, QString("\nERROR: 2 ChatGPT responded with error code \"%1-%2\" %3\n")
|
m_chat->callResponse(-1, QString("\nERROR: 2 ChatGPT responded with error code \"%1-%2\" %3\n")
|
||||||
.arg(code).arg(reply->errorString()).arg(qPrintable(reply->readAll())).toStdString());
|
.arg(code).arg(reply->errorString()).arg(qPrintable(reply->readAll())).toStdString());
|
||||||
|
emit finished();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -192,7 +219,7 @@ void ChatGPT::handleReadyRead()
|
|||||||
QJsonParseError err;
|
QJsonParseError err;
|
||||||
const QJsonDocument document = QJsonDocument::fromJson(jsonData.toUtf8(), &err);
|
const QJsonDocument document = QJsonDocument::fromJson(jsonData.toUtf8(), &err);
|
||||||
if (err.error != QJsonParseError::NoError) {
|
if (err.error != QJsonParseError::NoError) {
|
||||||
m_responseCallback(-1, QString("\nERROR: ChatGPT responded with invalid json \"%1\"\n")
|
m_chat->callResponse(-1, QString("\nERROR: ChatGPT responded with invalid json \"%1\"\n")
|
||||||
.arg(err.errorString()).toStdString());
|
.arg(err.errorString()).toStdString());
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -203,21 +230,24 @@ void ChatGPT::handleReadyRead()
|
|||||||
const QJsonObject delta = choice.value("delta").toObject();
|
const QJsonObject delta = choice.value("delta").toObject();
|
||||||
const QString content = delta.value("content").toString();
|
const QString content = delta.value("content").toString();
|
||||||
Q_ASSERT(m_ctx);
|
Q_ASSERT(m_ctx);
|
||||||
Q_ASSERT(m_responseCallback);
|
|
||||||
m_currentResponse += content;
|
m_currentResponse += content;
|
||||||
if (!m_responseCallback(0, content.toStdString())) {
|
if (!m_chat->callResponse(0, content.toStdString())) {
|
||||||
reply->abort();
|
reply->abort();
|
||||||
|
emit finished();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ChatGPT::handleErrorOccurred(QNetworkReply::NetworkError code)
|
void ChatGPTWorker::handleErrorOccurred(QNetworkReply::NetworkError code)
|
||||||
{
|
{
|
||||||
QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
|
QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
|
||||||
if (!reply)
|
if (!reply) {
|
||||||
|
emit finished();
|
||||||
return;
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
qWarning() << QString("ERROR: ChatGPT responded with error code \"%1-%2\"")
|
qWarning() << QString("ERROR: ChatGPT responded with error code \"%1-%2\"")
|
||||||
.arg(code).arg(reply->errorString()).toStdString();
|
.arg(code).arg(reply->errorString()).toStdString();
|
||||||
|
emit finished();
|
||||||
}
|
}
|
||||||
|
@ -5,9 +5,41 @@
|
|||||||
#include <QNetworkReply>
|
#include <QNetworkReply>
|
||||||
#include <QNetworkRequest>
|
#include <QNetworkRequest>
|
||||||
#include <QNetworkAccessManager>
|
#include <QNetworkAccessManager>
|
||||||
|
#include <QThread>
|
||||||
#include "../gpt4all-backend/llmodel.h"
|
#include "../gpt4all-backend/llmodel.h"
|
||||||
|
|
||||||
class ChatGPTPrivate;
|
class ChatGPT;
|
||||||
|
class ChatGPTWorker : public QObject {
|
||||||
|
Q_OBJECT
|
||||||
|
public:
|
||||||
|
ChatGPTWorker(ChatGPT *chatGPT)
|
||||||
|
: QObject(nullptr)
|
||||||
|
, m_ctx(nullptr)
|
||||||
|
, m_networkManager(nullptr)
|
||||||
|
, m_chat(chatGPT) {}
|
||||||
|
virtual ~ChatGPTWorker() {}
|
||||||
|
|
||||||
|
QString currentResponse() const { return m_currentResponse; }
|
||||||
|
|
||||||
|
void request(const QString &apiKey,
|
||||||
|
LLModel::PromptContext *promptCtx,
|
||||||
|
const QByteArray &array);
|
||||||
|
|
||||||
|
Q_SIGNALS:
|
||||||
|
void finished();
|
||||||
|
|
||||||
|
private Q_SLOTS:
|
||||||
|
void handleFinished();
|
||||||
|
void handleReadyRead();
|
||||||
|
void handleErrorOccurred(QNetworkReply::NetworkError code);
|
||||||
|
|
||||||
|
private:
|
||||||
|
ChatGPT *m_chat;
|
||||||
|
LLModel::PromptContext *m_ctx;
|
||||||
|
QNetworkAccessManager *m_networkManager;
|
||||||
|
QString m_currentResponse;
|
||||||
|
};
|
||||||
|
|
||||||
class ChatGPT : public QObject, public LLModel {
|
class ChatGPT : public QObject, public LLModel {
|
||||||
Q_OBJECT
|
Q_OBJECT
|
||||||
public:
|
public:
|
||||||
@ -35,6 +67,13 @@ public:
|
|||||||
QList<QString> context() const { return m_context; }
|
QList<QString> context() const { return m_context; }
|
||||||
void setContext(const QList<QString> &context) { m_context = context; }
|
void setContext(const QList<QString> &context) { m_context = context; }
|
||||||
|
|
||||||
|
bool callResponse(int32_t token, const std::string& string);
|
||||||
|
|
||||||
|
Q_SIGNALS:
|
||||||
|
void request(const QString &apiKey,
|
||||||
|
LLModel::PromptContext *ctx,
|
||||||
|
const QByteArray &array);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// We have to implement these as they are pure virtual in base class, but we don't actually use
|
// We have to implement these as they are pure virtual in base class, but we don't actually use
|
||||||
// them as they are only called from the default implementation of 'prompt' which we override and
|
// them as they are only called from the default implementation of 'prompt' which we override and
|
||||||
@ -46,19 +85,11 @@ protected:
|
|||||||
int32_t contextLength() const override { return -1; }
|
int32_t contextLength() const override { return -1; }
|
||||||
const std::vector<Token>& endTokens() const override { static const std::vector<Token> fres; return fres; }
|
const std::vector<Token>& endTokens() const override { static const std::vector<Token> fres; return fres; }
|
||||||
|
|
||||||
private Q_SLOTS:
|
|
||||||
void handleFinished();
|
|
||||||
void handleReadyRead();
|
|
||||||
void handleErrorOccurred(QNetworkReply::NetworkError code);
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
PromptContext *m_ctx;
|
|
||||||
std::function<bool(int32_t, const std::string&)> m_responseCallback;
|
std::function<bool(int32_t, const std::string&)> m_responseCallback;
|
||||||
QString m_modelName;
|
QString m_modelName;
|
||||||
QString m_apiKey;
|
QString m_apiKey;
|
||||||
QList<QString> m_context;
|
QList<QString> m_context;
|
||||||
QString m_currentResponse;
|
|
||||||
QNetworkAccessManager m_networkManager;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // CHATGPT_H
|
#endif // CHATGPT_H
|
||||||
|
Loading…
Reference in New Issue
Block a user