Add new C++ version of the chat model. Getting ready for chat history.

This commit is contained in:
Adam Treat 2023-04-30 20:28:07 -04:00
parent 65d4b8a886
commit d1e3198b65
7 changed files with 287 additions and 30 deletions

View File

@ -58,6 +58,7 @@ set (CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
qt_add_executable(chat
main.cpp
chat.h chat.cpp chatmodel.h
download.h download.cpp
network.h network.cpp
llm.h llm.cpp

1
chat.cpp Normal file
View File

@ -0,0 +1 @@
#include "chat.h"

42
chat.h Normal file
View File

@ -0,0 +1,42 @@
#ifndef CHAT_H
#define CHAT_H
#include <QObject>
#include <QtQml>
#include "chatmodel.h"
#include "network.h"
class Chat : public QObject
{
Q_OBJECT
Q_PROPERTY(QString id READ id NOTIFY idChanged)
Q_PROPERTY(QString name READ name NOTIFY nameChanged)
Q_PROPERTY(ChatModel *chatModel READ chatModel NOTIFY chatModelChanged)
QML_ELEMENT
QML_UNCREATABLE("Only creatable from c++!")
public:
explicit Chat(QObject *parent = nullptr) : QObject(parent)
{
m_id = Network::globalInstance()->generateUniqueId();
m_name = tr("New Chat");
m_chatModel = new ChatModel(this);
}
QString id() const { return m_id; }
QString name() const { return m_name; }
ChatModel *chatModel() { return m_chatModel; }
Q_SIGNALS:
void idChanged();
void nameChanged();
void chatModelChanged();
private:
QString m_id;
QString m_name;
ChatModel *m_chatModel;
};
#endif // CHAT_H

210
chatmodel.h Normal file
View File

@ -0,0 +1,210 @@
#ifndef CHATMODEL_H
#define CHATMODEL_H
#include <QAbstractListModel>
#include <QtQml>
struct ChatItem
{
Q_GADGET
Q_PROPERTY(int id MEMBER id)
Q_PROPERTY(QString name MEMBER name)
Q_PROPERTY(QString value MEMBER value)
Q_PROPERTY(QString prompt MEMBER prompt)
Q_PROPERTY(QString newResponse MEMBER newResponse)
Q_PROPERTY(bool currentResponse MEMBER currentResponse)
Q_PROPERTY(bool stopped MEMBER stopped)
Q_PROPERTY(bool thumbsUpState MEMBER thumbsUpState)
Q_PROPERTY(bool thumbsDownState MEMBER thumbsDownState)
public:
int id = 0;
QString name;
QString value;
QString prompt;
QString newResponse;
bool currentResponse = false;
bool stopped = false;
bool thumbsUpState = false;
bool thumbsDownState = false;
};
Q_DECLARE_METATYPE(ChatItem)
class ChatModel : public QAbstractListModel
{
Q_OBJECT
Q_PROPERTY(int count READ count NOTIFY countChanged)
public:
explicit ChatModel(QObject *parent = nullptr) : QAbstractListModel(parent) {}
enum Roles {
IdRole = Qt::UserRole + 1,
NameRole,
ValueRole,
PromptRole,
NewResponseRole,
CurrentResponseRole,
StoppedRole,
ThumbsUpStateRole,
ThumbsDownStateRole
};
int rowCount(const QModelIndex &parent = QModelIndex()) const override
{
Q_UNUSED(parent)
return m_chatItems.size();
}
QVariant data(const QModelIndex &index, int role = Qt::DisplayRole) const override
{
if (!index.isValid() || index.row() < 0 || index.row() >= m_chatItems.size())
return QVariant();
const ChatItem &item = m_chatItems.at(index.row());
switch (role) {
case IdRole:
return item.id;
case NameRole:
return item.name;
case ValueRole:
return item.value;
case PromptRole:
return item.prompt;
case NewResponseRole:
return item.newResponse;
case CurrentResponseRole:
return item.currentResponse;
case StoppedRole:
return item.stopped;
case ThumbsUpStateRole:
return item.thumbsUpState;
case ThumbsDownStateRole:
return item.thumbsDownState;
}
return QVariant();
}
QHash<int, QByteArray> roleNames() const override
{
QHash<int, QByteArray> roles;
roles[IdRole] = "id";
roles[NameRole] = "name";
roles[ValueRole] = "value";
roles[PromptRole] = "prompt";
roles[NewResponseRole] = "newResponse";
roles[CurrentResponseRole] = "currentResponse";
roles[StoppedRole] = "stopped";
roles[ThumbsUpStateRole] = "thumbsUpState";
roles[ThumbsDownStateRole] = "thumbsDownState";
return roles;
}
Q_INVOKABLE void appendPrompt(const QString &name, const QString &value)
{
ChatItem item;
item.name = name;
item.value = value;
beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size());
m_chatItems.append(item);
endInsertRows();
emit countChanged();
}
Q_INVOKABLE void appendResponse(const QString &name, const QString &prompt)
{
ChatItem item;
item.id = m_chatItems.count(); // This is only relevant for responses
item.name = name;
item.prompt = prompt;
item.currentResponse = true;
beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size());
m_chatItems.append(item);
endInsertRows();
emit countChanged();
}
Q_INVOKABLE ChatItem get(int index)
{
if (index < 0 || index >= m_chatItems.size()) return ChatItem();
return m_chatItems.at(index);
}
Q_INVOKABLE void updateCurrentResponse(int index, bool b)
{
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index];
if (item.currentResponse != b) {
item.currentResponse = b;
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {CurrentResponseRole});
}
}
Q_INVOKABLE void updateStopped(int index, bool b)
{
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index];
if (item.stopped != b) {
item.stopped = b;
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {StoppedRole});
}
}
Q_INVOKABLE void updateValue(int index, const QString &value)
{
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index];
if (item.value != value) {
item.value = value;
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ValueRole});
}
}
Q_INVOKABLE void updateThumbsUpState(int index, bool b)
{
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index];
if (item.thumbsUpState != b) {
item.thumbsUpState = b;
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsUpStateRole});
}
}
Q_INVOKABLE void updateThumbsDownState(int index, bool b)
{
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index];
if (item.thumbsDownState != b) {
item.thumbsDownState = b;
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsDownStateRole});
}
}
Q_INVOKABLE void updateNewResponse(int index, const QString &newResponse)
{
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index];
if (item.newResponse != newResponse) {
item.newResponse = newResponse;
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {NewResponseRole});
}
}
int count() const { return m_chatItems.size(); }
Q_SIGNALS:
void countChanged();
private:
QList<ChatItem> m_chatItems;
};
#endif // CHATMODEL_H

View File

@ -1,6 +1,8 @@
#include "llm.h"
#include "download.h"
#include "network.h"
#include "llmodel/gptj.h"
#include "llmodel/llamamodel.h"
#include <QCoreApplication>
#include <QDir>
@ -345,6 +347,7 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
LLM::LLM()
: QObject{nullptr}
, m_currentChat(new Chat)
, m_llmodel(new LLMObject)
, m_responseInProgress(false)
{

11
llm.h
View File

@ -3,8 +3,9 @@
#include <QObject>
#include <QThread>
#include "llmodel/gptj.h"
#include "llmodel/llamamodel.h"
#include "chat.h"
#include "llmodel/llmodel.h"
class LLMObject : public QObject
{
@ -24,6 +25,7 @@ public:
void regenerateResponse();
void resetResponse();
void resetContext();
void stopGenerating() { m_stopGenerating = true; }
void setThreadCount(int32_t n_threads);
int32_t threadCount();
@ -83,6 +85,7 @@ class LLM : public QObject
Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
Q_PROPERTY(Chat *currentChat READ currentChat NOTIFY currentChatChanged)
public:
@ -111,6 +114,8 @@ public:
bool isRecalc() const;
Chat *currentChat() const { return m_currentChat; }
Q_SIGNALS:
void isModelLoadedChanged();
void responseChanged();
@ -126,12 +131,14 @@ Q_SIGNALS:
void threadCountChanged();
void setThreadCountRequested(int32_t threadCount);
void recalcChanged();
void currentChatChanged();
private Q_SLOTS:
void responseStarted();
void responseStopped();
private:
Chat *m_currentChat;
LLMObject *m_llmodel;
int32_t m_desiredThreadCount;
bool m_responseInProgress;

View File

@ -19,6 +19,7 @@ Window {
}
property string chatId: Network.generateUniqueId()
property var chatModel: LLM.currentChat.chatModel
color: theme.textColor
@ -666,10 +667,6 @@ Window {
anchors.bottomMargin: 30
ScrollBar.vertical.policy: ScrollBar.AlwaysOn
ListModel {
id: chatModel
}
Rectangle {
anchors.fill: parent
color: theme.backgroundLighter
@ -750,9 +747,9 @@ Window {
if (thumbsDownState && !thumbsUpState && !responseHasChanged)
return
newResponse = response
thumbsDownState = true
thumbsUpState = false
chatModel.updateNewResponse(index, response)
chatModel.updateThumbsUpState(index, false)
chatModel.updateThumbsDownState(index, true)
Network.sendConversation(chatId, getConversationJson());
}
}
@ -782,9 +779,9 @@ Window {
if (thumbsUpState && !thumbsDownState)
return
newResponse = ""
thumbsUpState = true
thumbsDownState = false
chatModel.updateNewResponse(index, "")
chatModel.updateThumbsUpState(index, true)
chatModel.updateThumbsDownState(index, false)
Network.sendConversation(chatId, getConversationJson());
}
}
@ -862,8 +859,8 @@ Window {
}
leftPadding: 50
onClicked: {
if (chatModel.count)
var listElement = chatModel.get(chatModel.count - 1)
var index = Math.max(0, chatModel.count - 1);
var listElement = chatModel.get(index);
if (LLM.responseInProgress) {
listElement.stopped = true
@ -872,12 +869,12 @@ Window {
LLM.regenerateResponse()
if (chatModel.count) {
if (listElement.name === qsTr("Response: ")) {
listElement.currentResponse = true
listElement.stopped = false
listElement.value = LLM.response
listElement.thumbsUpState = false
listElement.thumbsDownState = false
listElement.newResponse = ""
chatModel.updateCurrentResponse(index, true);
chatModel.updateStopped(index, false);
chatModel.updateValue(index, LLM.response);
chatModel.updateThumbsUpState(index, false);
chatModel.updateThumbsDownState(index, false);
chatModel.updateNewResponse(index, "");
LLM.prompt(listElement.prompt, settingsDialog.promptTemplate,
settingsDialog.maxLength,
settingsDialog.topK, settingsDialog.topP,
@ -949,18 +946,14 @@ Window {
LLM.stopGenerating()
if (chatModel.count) {
var listElement = chatModel.get(chatModel.count - 1)
listElement.currentResponse = false
listElement.value = LLM.response
var index = Math.max(0, chatModel.count - 1);
var listElement = chatModel.get(index);
chatModel.updateCurrentResponse(index, false);
chatModel.updateValue(index, LLM.response);
}
var prompt = textInput.text + "\n"
chatModel.append({"name": qsTr("Prompt: "), "currentResponse": false,
"value": textInput.text})
chatModel.append({"id": chatModel.count, "name": qsTr("Response: "),
"currentResponse": true, "value": "", "stopped": false,
"thumbsUpState": false, "thumbsDownState": false,
"newResponse": "",
"prompt": prompt})
chatModel.appendPrompt(qsTr("Prompt: "), textInput.text);
chatModel.appendResponse(qsTr("Response: "), prompt);
LLM.resetResponse()
LLM.prompt(prompt, settingsDialog.promptTemplate,
settingsDialog.maxLength,