https://github.com/nomic-ai/gpt4all.git
Add save/restore to chatgpt chats and allow serialize/deserialize from disk.
parent 0cd509d530
commit f931de21c5
chat.cpp
@@ -258,6 +258,7 @@ bool Chat::deserialize(QDataStream &stream, int version)
     // unfortunately, we cannot deserialize these
     if (version < 2 && m_savedModelName.contains("gpt4all-j"))
         return false;
+    m_llmodel->setModelName(m_savedModelName);
     if (!m_llmodel->deserialize(stream, version))
         return false;
     if (!m_chatModel->deserialize(stream, version))
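The guard in Chat::deserialize is the usual pattern for a versioned QDataStream format: when the stream predates the fields the current code expects, bail out rather than read garbage. A minimal standalone sketch of that pattern, using a hypothetical record format that is not the gpt4all one:

    #include <QDataStream>
    #include <QString>

    // Hypothetical two-version record: version 2 added the 'name' field.
    struct Record {
        qint32 value = 0;
        QString name;

        bool deserialize(QDataStream &stream, int version)
        {
            stream >> value;
            if (version < 2)
                return false; // stream too old: it has no 'name' to read
            stream >> name;
            return stream.status() == QDataStream::Ok;
        }
    };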
chatgpt.cpp
@@ -46,6 +46,7 @@ bool ChatGPT::isModelLoaded() const
     return true;
 }
 
+// All three of the state virtual functions are handled custom inside of chatllm save/restore
 size_t ChatGPT::stateSize() const
 {
     return 0;
@@ -53,11 +54,13 @@ size_t ChatGPT::stateSize() const
 
 size_t ChatGPT::saveState(uint8_t *dest) const
 {
+    Q_UNUSED(dest);
     return 0;
 }
 
 size_t ChatGPT::restoreState(const uint8_t *src)
 {
+    Q_UNUSED(src);
     return 0;
 }
 
@@ -141,8 +144,8 @@ void ChatGPT::handleFinished()
     bool ok;
     int code = response.toInt(&ok);
     if (!ok || code != 200) {
-        qWarning() << QString("\nERROR: ChatGPT responded with error code \"%1-%2%3\"\n")
-            .arg(code).arg(reply->errorString()).arg(reply->readAll()).toStdString();
+        qWarning() << QString("ERROR: ChatGPT responded with error code \"%1-%2\"")
+            .arg(code).arg(reply->errorString()).toStdString();
     }
     reply->deleteLater();
 }
@@ -190,8 +193,11 @@ void ChatGPT::handleReadyRead()
             const QString content = delta.value("content").toString();
             Q_ASSERT(m_ctx);
             Q_ASSERT(m_responseCallback);
-            m_responseCallback(0, content.toStdString());
             m_currentResponse += content;
+            if (!m_responseCallback(0, content.toStdString())) {
+                reply->abort();
+                return;
+            }
         }
     }
 
@@ -201,6 +207,6 @@ void ChatGPT::handleErrorOccurred(QNetworkReply::NetworkError code)
     if (!reply)
         return;
 
-    qWarning() << QString("\nERROR: ChatGPT responded with error code \"%1-%2%3\"\n")
-        .arg(code).arg(reply->errorString()).arg(reply->readAll()).toStdString();
+    qWarning() << QString("ERROR: ChatGPT responded with error code \"%1-%2\"")
+        .arg(code).arg(reply->errorString()).toStdString();
 }
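The handleReadyRead change turns the response callback into a cancellation point: the network reply is aborted as soon as the callback returns false. A standalone sketch of that streaming-callback pattern, with hypothetical names rather than the gpt4all API:

    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <string>
    #include <vector>

    // Consumer returns false to request cancellation, as in
    // 'if (!m_responseCallback(...)) reply->abort();' above.
    using ResponseCallback = std::function<bool(int32_t, const std::string &)>;

    void streamPieces(const std::vector<std::string> &pieces, const ResponseCallback &cb)
    {
        for (const auto &piece : pieces) {
            if (!cb(0, piece)) {
                std::cout << "\n[stream aborted by consumer]\n";
                return; // analogous to reply->abort(); return;
            }
        }
    }

    int main()
    {
        int budget = 2; // accept only two pieces, then cancel
        streamPieces({"Hel", "lo ", "wor", "ld"}, [&](int32_t, const std::string &p) {
            std::cout << p;
            return --budget > 0;
        });
        return 0;
    }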
chatgpt.h
@@ -30,6 +30,9 @@ public:
     void setModelName(const QString &modelName) { m_modelName = modelName; }
     void setAPIKey(const QString &apiKey) { m_apiKey = apiKey; }
 
+    QList<QString> context() const { return m_context; }
+    void setContext(const QList<QString> &context) { m_context = context; }
+
 protected:
     void recalculateContext(PromptContext &promptCtx,
         std::function<bool(bool)> recalculate) override {}
chatlistmodel.cpp
@@ -38,6 +38,19 @@ void ChatListModel::setShouldSaveChats(bool b)
     emit shouldSaveChatsChanged();
 }
 
+bool ChatListModel::shouldSaveChatGPTChats() const
+{
+    return m_shouldSaveChatGPTChats;
+}
+
+void ChatListModel::setShouldSaveChatGPTChats(bool b)
+{
+    if (m_shouldSaveChatGPTChats == b)
+        return;
+    m_shouldSaveChatGPTChats = b;
+    emit shouldSaveChatGPTChatsChanged();
+}
+
 void ChatListModel::removeChatFile(Chat *chat) const
 {
     Q_ASSERT(chat != m_serverChat);
@@ -52,15 +65,17 @@ void ChatListModel::removeChatFile(Chat *chat) const
 
 void ChatListModel::saveChats() const
 {
-    if (!m_shouldSaveChats)
-        return;
-
     QElapsedTimer timer;
     timer.start();
     const QString savePath = Download::globalInstance()->downloadLocalModelsPath();
     for (Chat *chat : m_chats) {
         if (chat == m_serverChat)
             continue;
+        const bool isChatGPT = chat->modelName().startsWith("chatgpt-");
+        if (!isChatGPT && !m_shouldSaveChats)
+            continue;
+        if (isChatGPT && !m_shouldSaveChatGPTChats)
+            continue;
         QString fileName = "gpt4all-" + chat->id() + ".chat";
         QFile file(savePath + "/" + fileName);
         bool success = file.open(QIODevice::WriteOnly);
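The rewritten saveChats loop replaces the single early-out with a per-chat policy: the server chat is never written, ChatGPT chats are gated by m_shouldSaveChatGPTChats, and all other chats by m_shouldSaveChats. The same policy as a standalone predicate, a hypothetical helper that is not part of the patch:

    #include <QString>

    // Distills the gating logic inside ChatListModel::saveChats().
    bool shouldPersist(const QString &modelName, bool isServerChat,
                       bool shouldSaveChats, bool shouldSaveChatGPTChats)
    {
        if (isServerChat)
            return false; // the server chat is never persisted
        const bool isChatGPT = modelName.startsWith("chatgpt-");
        return isChatGPT ? shouldSaveChatGPTChats : shouldSaveChats;
    }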
chatlistmodel.h
@@ -20,6 +20,7 @@ class ChatListModel : public QAbstractListModel
     Q_PROPERTY(int count READ count NOTIFY countChanged)
     Q_PROPERTY(Chat *currentChat READ currentChat WRITE setCurrentChat NOTIFY currentChatChanged)
     Q_PROPERTY(bool shouldSaveChats READ shouldSaveChats WRITE setShouldSaveChats NOTIFY shouldSaveChatsChanged)
+    Q_PROPERTY(bool shouldSaveChatGPTChats READ shouldSaveChatGPTChats WRITE setShouldSaveChatGPTChats NOTIFY shouldSaveChatGPTChatsChanged)
 
 public:
     explicit ChatListModel(QObject *parent = nullptr);
@@ -62,6 +63,9 @@ public:
     bool shouldSaveChats() const;
     void setShouldSaveChats(bool b);
 
+    bool shouldSaveChatGPTChats() const;
+    void setShouldSaveChatGPTChats(bool b);
+
     Q_INVOKABLE void addChat()
     {
         // Don't add a new chat if we already have one
@@ -199,6 +203,7 @@ Q_SIGNALS:
     void countChanged();
     void currentChatChanged();
     void shouldSaveChatsChanged();
+    void shouldSaveChatGPTChatsChanged();
 
 private Q_SLOTS:
     void newChatCountChanged()
@@ -240,6 +245,7 @@ private Q_SLOTS:
 
 private:
     bool m_shouldSaveChats;
+    bool m_shouldSaveChatGPTChats;
     Chat* m_newChat;
     Chat* m_dummyChat;
     Chat* m_serverChat;
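The header additions follow the standard Qt READ/WRITE/NOTIFY property pattern, which is what lets the QML further down bind LLM.chatListModel.shouldSaveChatGPTChats. A minimal self-contained sketch with a hypothetical class and property name:

    #include <QObject>

    class Prefs : public QObject
    {
        Q_OBJECT
        // Property visible to QML; the NOTIFY signal keeps bindings live.
        Q_PROPERTY(bool saveRemoteChats READ saveRemoteChats WRITE setSaveRemoteChats NOTIFY saveRemoteChatsChanged)

    public:
        bool saveRemoteChats() const { return m_saveRemoteChats; }
        void setSaveRemoteChats(bool b)
        {
            if (m_saveRemoteChats == b)
                return; // guard so the signal fires only on real changes
            m_saveRemoteChats = b;
            emit saveRemoteChatsChanged();
        }

    Q_SIGNALS:
        void saveRemoteChatsChanged();

    private:
        bool m_saveRemoteChats = false;
    };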
chatllm.cpp
@@ -611,6 +611,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version)
         stream >> compressed;
         m_state = qUncompress(compressed);
     } else {
+
         stream >> m_state;
     }
 #if defined(DEBUG)
@@ -624,6 +625,15 @@ void ChatLLM::saveState()
     if (!isModelLoaded())
         return;
 
+    if (m_isChatGPT) {
+        m_state.clear();
+        QDataStream stream(&m_state, QIODeviceBase::WriteOnly);
+        stream.setVersion(QDataStream::Qt_6_5);
+        ChatGPT *chatGPT = static_cast<ChatGPT*>(m_modelInfo.model);
+        stream << chatGPT->context();
+        return;
+    }
+
     const size_t stateSize = m_modelInfo.model->stateSize();
     m_state.resize(stateSize);
 #if defined(DEBUG)
@@ -637,6 +647,18 @@ void ChatLLM::restoreState()
     if (!isModelLoaded() || m_state.isEmpty())
         return;
 
+    if (m_isChatGPT) {
+        QDataStream stream(&m_state, QIODeviceBase::ReadOnly);
+        stream.setVersion(QDataStream::Qt_6_5);
+        ChatGPT *chatGPT = static_cast<ChatGPT*>(m_modelInfo.model);
+        QList<QString> context;
+        stream >> context;
+        chatGPT->setContext(context);
+        m_state.clear();
+        m_state.resize(0);
+        return;
+    }
+
 #if defined(DEBUG)
     qDebug() << "restoreState" << m_chat->id() << "size:" << m_state.size();
 #endif
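Unlike a local model, a ChatGPT chat has no weights or KV cache to snapshot, which is why the ChatGPT state virtuals earlier return 0; saveState/restoreState instead stream the message context (a QList<QString>) into m_state and back. A standalone sketch of that QDataStream round-trip, assuming Qt 6.5 and QtCore only, with illustrative variable names:

    #include <QByteArray>
    #include <QDataStream>
    #include <QDebug>
    #include <QList>
    #include <QString>

    int main()
    {
        const QList<QString> context { "Hello", "Hi! How can I help?" };

        // Save: stream the list into an in-memory byte array (like m_state).
        QByteArray state;
        {
            QDataStream out(&state, QIODeviceBase::WriteOnly);
            out.setVersion(QDataStream::Qt_6_5); // pin the wire format, as the patch does
            out << context;
        }

        // Restore: read the same bytes back into a fresh list.
        QList<QString> restored;
        {
            QDataStream in(&state, QIODeviceBase::ReadOnly);
            in.setVersion(QDataStream::Qt_6_5); // must match the version used when saving
            in >> restored;
        }

        qDebug() << (restored == context); // true
        return 0;
    }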
qml/SettingsDialog.qml
@@ -40,6 +40,7 @@ Dialog {
     property int defaultRepeatPenaltyTokens: 64
     property int defaultThreadCount: 0
     property bool defaultSaveChats: false
+    property bool defaultSaveChatGPTChats: true
     property bool defaultServerChat: false
     property string defaultPromptTemplate: "### Human:
 %1
@@ -57,6 +58,7 @@ Dialog {
     property alias repeatPenaltyTokens: settings.repeatPenaltyTokens
     property alias threadCount: settings.threadCount
     property alias saveChats: settings.saveChats
+    property alias saveChatGPTChats: settings.saveChatGPTChats
     property alias serverChat: settings.serverChat
     property alias modelPath: settings.modelPath
     property alias userDefaultModel: settings.userDefaultModel
@@ -70,6 +72,7 @@ Dialog {
         property int promptBatchSize: settingsDialog.defaultPromptBatchSize
         property int threadCount: settingsDialog.defaultThreadCount
         property bool saveChats: settingsDialog.defaultSaveChats
+        property bool saveChatGPTChats: settingsDialog.defaultSaveChatGPTChats
         property bool serverChat: settingsDialog.defaultServerChat
         property real repeatPenalty: settingsDialog.defaultRepeatPenalty
         property int repeatPenaltyTokens: settingsDialog.defaultRepeatPenaltyTokens
@@ -94,12 +97,14 @@ Dialog {
         settings.modelPath = settingsDialog.defaultModelPath
         settings.threadCount = defaultThreadCount
         settings.saveChats = defaultSaveChats
+        settings.saveChatGPTChats = defaultSaveChatGPTChats
         settings.serverChat = defaultServerChat
         settings.userDefaultModel = defaultUserDefaultModel
         Download.downloadLocalModelsPath = settings.modelPath
         LLM.threadCount = settings.threadCount
         LLM.serverEnabled = settings.serverChat
         LLM.chatListModel.shouldSaveChats = settings.saveChats
+        LLM.chatListModel.shouldSaveChatGPTChats = settings.saveChatGPTChats
         settings.sync()
     }
 
@@ -107,6 +112,7 @@ Dialog {
         LLM.threadCount = settings.threadCount
         LLM.serverEnabled = settings.serverChat
         LLM.chatListModel.shouldSaveChats = settings.saveChats
+        LLM.chatListModel.shouldSaveChatGPTChats = settings.saveChatGPTChats
         Download.downloadLocalModelsPath = settings.modelPath
     }
 
@@ -803,16 +809,65 @@ Dialog {
             }
         }
         Label {
-            id: serverChatLabel
-            text: qsTr("Enable web server:")
+            id: saveChatGPTChatsLabel
+            text: qsTr("Save ChatGPT chats to disk:")
             color: theme.textColor
             Layout.row: 5
             Layout.column: 0
         }
         CheckBox {
-            id: serverChatBox
+            id: saveChatGPTChatsBox
             Layout.row: 5
             Layout.column: 1
+            checked: settingsDialog.saveChatGPTChats
+            onClicked: {
+                settingsDialog.saveChatGPTChats = saveChatGPTChatsBox.checked
+                LLM.chatListModel.shouldSaveChatGPTChats = saveChatGPTChatsBox.checked
+                settings.sync()
+            }
+
+            background: Rectangle {
+                color: "transparent"
+            }
+
+            indicator: Rectangle {
+                implicitWidth: 26
+                implicitHeight: 26
+                x: saveChatGPTChatsBox.leftPadding
+                y: parent.height / 2 - height / 2
+                border.color: theme.dialogBorder
+                color: "transparent"
+
+                Rectangle {
+                    width: 14
+                    height: 14
+                    x: 6
+                    y: 6
+                    color: theme.textColor
+                    visible: saveChatGPTChatsBox.checked
+                }
+            }
+
+            contentItem: Text {
+                text: saveChatGPTChatsBox.text
+                font: saveChatGPTChatsBox.font
+                opacity: enabled ? 1.0 : 0.3
+                color: theme.textColor
+                verticalAlignment: Text.AlignVCenter
+                leftPadding: saveChatGPTChatsBox.indicator.width + saveChatGPTChatsBox.spacing
+            }
+        }
+        Label {
+            id: serverChatLabel
+            text: qsTr("Enable web server:")
+            color: theme.textColor
+            Layout.row: 6
+            Layout.column: 0
+        }
+        CheckBox {
+            id: serverChatBox
+            Layout.row: 6
+            Layout.column: 1
             checked: settings.serverChat
             onClicked: {
                 settingsDialog.serverChat = serverChatBox.checked
@@ -855,7 +910,7 @@ Dialog {
             }
         }
         Button {
-            Layout.row: 6
+            Layout.row: 7
             Layout.column: 1
             Layout.fillWidth: true
             padding: 10