Add new C++ version of the chat model. Getting ready for chat history.

This commit is contained in:
Adam Treat 2023-04-30 20:28:07 -04:00
parent 65d4b8a886
commit d1e3198b65
7 changed files with 287 additions and 30 deletions

View File

@ -58,6 +58,7 @@ set (CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
qt_add_executable(chat qt_add_executable(chat
main.cpp main.cpp
chat.h chat.cpp chatmodel.h
download.h download.cpp download.h download.cpp
network.h network.cpp network.h network.cpp
llm.h llm.cpp llm.h llm.cpp

1
chat.cpp Normal file
View File

@ -0,0 +1 @@
#include "chat.h"

42
chat.h Normal file
View File

@ -0,0 +1,42 @@
#ifndef CHAT_H
#define CHAT_H
#include <QObject>
#include <QtQml>
#include "chatmodel.h"
#include "network.h"
class Chat : public QObject
{
Q_OBJECT
Q_PROPERTY(QString id READ id NOTIFY idChanged)
Q_PROPERTY(QString name READ name NOTIFY nameChanged)
Q_PROPERTY(ChatModel *chatModel READ chatModel NOTIFY chatModelChanged)
QML_ELEMENT
QML_UNCREATABLE("Only creatable from c++!")
public:
explicit Chat(QObject *parent = nullptr) : QObject(parent)
{
m_id = Network::globalInstance()->generateUniqueId();
m_name = tr("New Chat");
m_chatModel = new ChatModel(this);
}
QString id() const { return m_id; }
QString name() const { return m_name; }
ChatModel *chatModel() { return m_chatModel; }
Q_SIGNALS:
void idChanged();
void nameChanged();
void chatModelChanged();
private:
QString m_id;
QString m_name;
ChatModel *m_chatModel;
};
#endif // CHAT_H

210
chatmodel.h Normal file
View File

@ -0,0 +1,210 @@
#ifndef CHATMODEL_H
#define CHATMODEL_H
#include <QAbstractListModel>
#include <QtQml>
struct ChatItem
{
Q_GADGET
Q_PROPERTY(int id MEMBER id)
Q_PROPERTY(QString name MEMBER name)
Q_PROPERTY(QString value MEMBER value)
Q_PROPERTY(QString prompt MEMBER prompt)
Q_PROPERTY(QString newResponse MEMBER newResponse)
Q_PROPERTY(bool currentResponse MEMBER currentResponse)
Q_PROPERTY(bool stopped MEMBER stopped)
Q_PROPERTY(bool thumbsUpState MEMBER thumbsUpState)
Q_PROPERTY(bool thumbsDownState MEMBER thumbsDownState)
public:
int id = 0;
QString name;
QString value;
QString prompt;
QString newResponse;
bool currentResponse = false;
bool stopped = false;
bool thumbsUpState = false;
bool thumbsDownState = false;
};
Q_DECLARE_METATYPE(ChatItem)
class ChatModel : public QAbstractListModel
{
Q_OBJECT
Q_PROPERTY(int count READ count NOTIFY countChanged)
public:
explicit ChatModel(QObject *parent = nullptr) : QAbstractListModel(parent) {}
enum Roles {
IdRole = Qt::UserRole + 1,
NameRole,
ValueRole,
PromptRole,
NewResponseRole,
CurrentResponseRole,
StoppedRole,
ThumbsUpStateRole,
ThumbsDownStateRole
};
int rowCount(const QModelIndex &parent = QModelIndex()) const override
{
Q_UNUSED(parent)
return m_chatItems.size();
}
QVariant data(const QModelIndex &index, int role = Qt::DisplayRole) const override
{
if (!index.isValid() || index.row() < 0 || index.row() >= m_chatItems.size())
return QVariant();
const ChatItem &item = m_chatItems.at(index.row());
switch (role) {
case IdRole:
return item.id;
case NameRole:
return item.name;
case ValueRole:
return item.value;
case PromptRole:
return item.prompt;
case NewResponseRole:
return item.newResponse;
case CurrentResponseRole:
return item.currentResponse;
case StoppedRole:
return item.stopped;
case ThumbsUpStateRole:
return item.thumbsUpState;
case ThumbsDownStateRole:
return item.thumbsDownState;
}
return QVariant();
}
QHash<int, QByteArray> roleNames() const override
{
QHash<int, QByteArray> roles;
roles[IdRole] = "id";
roles[NameRole] = "name";
roles[ValueRole] = "value";
roles[PromptRole] = "prompt";
roles[NewResponseRole] = "newResponse";
roles[CurrentResponseRole] = "currentResponse";
roles[StoppedRole] = "stopped";
roles[ThumbsUpStateRole] = "thumbsUpState";
roles[ThumbsDownStateRole] = "thumbsDownState";
return roles;
}
Q_INVOKABLE void appendPrompt(const QString &name, const QString &value)
{
ChatItem item;
item.name = name;
item.value = value;
beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size());
m_chatItems.append(item);
endInsertRows();
emit countChanged();
}
Q_INVOKABLE void appendResponse(const QString &name, const QString &prompt)
{
ChatItem item;
item.id = m_chatItems.count(); // This is only relevant for responses
item.name = name;
item.prompt = prompt;
item.currentResponse = true;
beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size());
m_chatItems.append(item);
endInsertRows();
emit countChanged();
}
Q_INVOKABLE ChatItem get(int index)
{
if (index < 0 || index >= m_chatItems.size()) return ChatItem();
return m_chatItems.at(index);
}
Q_INVOKABLE void updateCurrentResponse(int index, bool b)
{
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index];
if (item.currentResponse != b) {
item.currentResponse = b;
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {CurrentResponseRole});
}
}
Q_INVOKABLE void updateStopped(int index, bool b)
{
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index];
if (item.stopped != b) {
item.stopped = b;
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {StoppedRole});
}
}
Q_INVOKABLE void updateValue(int index, const QString &value)
{
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index];
if (item.value != value) {
item.value = value;
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ValueRole});
}
}
Q_INVOKABLE void updateThumbsUpState(int index, bool b)
{
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index];
if (item.thumbsUpState != b) {
item.thumbsUpState = b;
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsUpStateRole});
}
}
Q_INVOKABLE void updateThumbsDownState(int index, bool b)
{
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index];
if (item.thumbsDownState != b) {
item.thumbsDownState = b;
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsDownStateRole});
}
}
Q_INVOKABLE void updateNewResponse(int index, const QString &newResponse)
{
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index];
if (item.newResponse != newResponse) {
item.newResponse = newResponse;
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {NewResponseRole});
}
}
int count() const { return m_chatItems.size(); }
Q_SIGNALS:
void countChanged();
private:
QList<ChatItem> m_chatItems;
};
#endif // CHATMODEL_H

View File

@ -1,6 +1,8 @@
#include "llm.h" #include "llm.h"
#include "download.h" #include "download.h"
#include "network.h" #include "network.h"
#include "llmodel/gptj.h"
#include "llmodel/llamamodel.h"
#include <QCoreApplication> #include <QCoreApplication>
#include <QDir> #include <QDir>
@ -345,6 +347,7 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
LLM::LLM() LLM::LLM()
: QObject{nullptr} : QObject{nullptr}
, m_currentChat(new Chat)
, m_llmodel(new LLMObject) , m_llmodel(new LLMObject)
, m_responseInProgress(false) , m_responseInProgress(false)
{ {

11
llm.h
View File

@ -3,8 +3,9 @@
#include <QObject> #include <QObject>
#include <QThread> #include <QThread>
#include "llmodel/gptj.h"
#include "llmodel/llamamodel.h" #include "chat.h"
#include "llmodel/llmodel.h"
class LLMObject : public QObject class LLMObject : public QObject
{ {
@ -24,6 +25,7 @@ public:
void regenerateResponse(); void regenerateResponse();
void resetResponse(); void resetResponse();
void resetContext(); void resetContext();
void stopGenerating() { m_stopGenerating = true; } void stopGenerating() { m_stopGenerating = true; }
void setThreadCount(int32_t n_threads); void setThreadCount(int32_t n_threads);
int32_t threadCount(); int32_t threadCount();
@ -83,6 +85,7 @@ class LLM : public QObject
Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged) Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged) Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged) Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
Q_PROPERTY(Chat *currentChat READ currentChat NOTIFY currentChatChanged)
public: public:
@ -111,6 +114,8 @@ public:
bool isRecalc() const; bool isRecalc() const;
Chat *currentChat() const { return m_currentChat; }
Q_SIGNALS: Q_SIGNALS:
void isModelLoadedChanged(); void isModelLoadedChanged();
void responseChanged(); void responseChanged();
@ -126,12 +131,14 @@ Q_SIGNALS:
void threadCountChanged(); void threadCountChanged();
void setThreadCountRequested(int32_t threadCount); void setThreadCountRequested(int32_t threadCount);
void recalcChanged(); void recalcChanged();
void currentChatChanged();
private Q_SLOTS: private Q_SLOTS:
void responseStarted(); void responseStarted();
void responseStopped(); void responseStopped();
private: private:
Chat *m_currentChat;
LLMObject *m_llmodel; LLMObject *m_llmodel;
int32_t m_desiredThreadCount; int32_t m_desiredThreadCount;
bool m_responseInProgress; bool m_responseInProgress;

View File

@ -19,6 +19,7 @@ Window {
} }
property string chatId: Network.generateUniqueId() property string chatId: Network.generateUniqueId()
property var chatModel: LLM.currentChat.chatModel
color: theme.textColor color: theme.textColor
@ -666,10 +667,6 @@ Window {
anchors.bottomMargin: 30 anchors.bottomMargin: 30
ScrollBar.vertical.policy: ScrollBar.AlwaysOn ScrollBar.vertical.policy: ScrollBar.AlwaysOn
ListModel {
id: chatModel
}
Rectangle { Rectangle {
anchors.fill: parent anchors.fill: parent
color: theme.backgroundLighter color: theme.backgroundLighter
@ -750,9 +747,9 @@ Window {
if (thumbsDownState && !thumbsUpState && !responseHasChanged) if (thumbsDownState && !thumbsUpState && !responseHasChanged)
return return
newResponse = response chatModel.updateNewResponse(index, response)
thumbsDownState = true chatModel.updateThumbsUpState(index, false)
thumbsUpState = false chatModel.updateThumbsDownState(index, true)
Network.sendConversation(chatId, getConversationJson()); Network.sendConversation(chatId, getConversationJson());
} }
} }
@ -782,9 +779,9 @@ Window {
if (thumbsUpState && !thumbsDownState) if (thumbsUpState && !thumbsDownState)
return return
newResponse = "" chatModel.updateNewResponse(index, "")
thumbsUpState = true chatModel.updateThumbsUpState(index, true)
thumbsDownState = false chatModel.updateThumbsDownState(index, false)
Network.sendConversation(chatId, getConversationJson()); Network.sendConversation(chatId, getConversationJson());
} }
} }
@ -862,8 +859,8 @@ Window {
} }
leftPadding: 50 leftPadding: 50
onClicked: { onClicked: {
if (chatModel.count) var index = Math.max(0, chatModel.count - 1);
var listElement = chatModel.get(chatModel.count - 1) var listElement = chatModel.get(index);
if (LLM.responseInProgress) { if (LLM.responseInProgress) {
listElement.stopped = true listElement.stopped = true
@ -872,12 +869,12 @@ Window {
LLM.regenerateResponse() LLM.regenerateResponse()
if (chatModel.count) { if (chatModel.count) {
if (listElement.name === qsTr("Response: ")) { if (listElement.name === qsTr("Response: ")) {
listElement.currentResponse = true chatModel.updateCurrentResponse(index, true);
listElement.stopped = false chatModel.updateStopped(index, false);
listElement.value = LLM.response chatModel.updateValue(index, LLM.response);
listElement.thumbsUpState = false chatModel.updateThumbsUpState(index, false);
listElement.thumbsDownState = false chatModel.updateThumbsDownState(index, false);
listElement.newResponse = "" chatModel.updateNewResponse(index, "");
LLM.prompt(listElement.prompt, settingsDialog.promptTemplate, LLM.prompt(listElement.prompt, settingsDialog.promptTemplate,
settingsDialog.maxLength, settingsDialog.maxLength,
settingsDialog.topK, settingsDialog.topP, settingsDialog.topK, settingsDialog.topP,
@ -949,18 +946,14 @@ Window {
LLM.stopGenerating() LLM.stopGenerating()
if (chatModel.count) { if (chatModel.count) {
var listElement = chatModel.get(chatModel.count - 1) var index = Math.max(0, chatModel.count - 1);
listElement.currentResponse = false var listElement = chatModel.get(index);
listElement.value = LLM.response chatModel.updateCurrentResponse(index, false);
chatModel.updateValue(index, LLM.response);
} }
var prompt = textInput.text + "\n" var prompt = textInput.text + "\n"
chatModel.append({"name": qsTr("Prompt: "), "currentResponse": false, chatModel.appendPrompt(qsTr("Prompt: "), textInput.text);
"value": textInput.text}) chatModel.appendResponse(qsTr("Response: "), prompt);
chatModel.append({"id": chatModel.count, "name": qsTr("Response: "),
"currentResponse": true, "value": "", "stopped": false,
"thumbsUpState": false, "thumbsDownState": false,
"newResponse": "",
"prompt": prompt})
LLM.resetResponse() LLM.resetResponse()
LLM.prompt(prompt, settingsDialog.promptTemplate, LLM.prompt(prompt, settingsDialog.promptTemplate,
settingsDialog.maxLength, settingsDialog.maxLength,