mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-10-01 01:06:10 -04:00
Add new C++ version of the chat model. Getting ready for chat history.
This commit is contained in:
parent
65d4b8a886
commit
d1e3198b65
@ -58,6 +58,7 @@ set (CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
|||||||
|
|
||||||
qt_add_executable(chat
|
qt_add_executable(chat
|
||||||
main.cpp
|
main.cpp
|
||||||
|
chat.h chat.cpp chatmodel.h
|
||||||
download.h download.cpp
|
download.h download.cpp
|
||||||
network.h network.cpp
|
network.h network.cpp
|
||||||
llm.h llm.cpp
|
llm.h llm.cpp
|
||||||
|
42
chat.h
Normal file
42
chat.h
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
#ifndef CHAT_H
|
||||||
|
#define CHAT_H
|
||||||
|
|
||||||
|
#include <QObject>
|
||||||
|
#include <QtQml>
|
||||||
|
|
||||||
|
#include "chatmodel.h"
|
||||||
|
#include "network.h"
|
||||||
|
|
||||||
|
class Chat : public QObject
|
||||||
|
{
|
||||||
|
Q_OBJECT
|
||||||
|
Q_PROPERTY(QString id READ id NOTIFY idChanged)
|
||||||
|
Q_PROPERTY(QString name READ name NOTIFY nameChanged)
|
||||||
|
Q_PROPERTY(ChatModel *chatModel READ chatModel NOTIFY chatModelChanged)
|
||||||
|
QML_ELEMENT
|
||||||
|
QML_UNCREATABLE("Only creatable from c++!")
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit Chat(QObject *parent = nullptr) : QObject(parent)
|
||||||
|
{
|
||||||
|
m_id = Network::globalInstance()->generateUniqueId();
|
||||||
|
m_name = tr("New Chat");
|
||||||
|
m_chatModel = new ChatModel(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
QString id() const { return m_id; }
|
||||||
|
QString name() const { return m_name; }
|
||||||
|
ChatModel *chatModel() { return m_chatModel; }
|
||||||
|
|
||||||
|
Q_SIGNALS:
|
||||||
|
void idChanged();
|
||||||
|
void nameChanged();
|
||||||
|
void chatModelChanged();
|
||||||
|
|
||||||
|
private:
|
||||||
|
QString m_id;
|
||||||
|
QString m_name;
|
||||||
|
ChatModel *m_chatModel;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // CHAT_H
|
210
chatmodel.h
Normal file
210
chatmodel.h
Normal file
@ -0,0 +1,210 @@
|
|||||||
|
#ifndef CHATMODEL_H
|
||||||
|
#define CHATMODEL_H
|
||||||
|
|
||||||
|
#include <QAbstractListModel>
|
||||||
|
#include <QtQml>
|
||||||
|
|
||||||
|
struct ChatItem
|
||||||
|
{
|
||||||
|
Q_GADGET
|
||||||
|
Q_PROPERTY(int id MEMBER id)
|
||||||
|
Q_PROPERTY(QString name MEMBER name)
|
||||||
|
Q_PROPERTY(QString value MEMBER value)
|
||||||
|
Q_PROPERTY(QString prompt MEMBER prompt)
|
||||||
|
Q_PROPERTY(QString newResponse MEMBER newResponse)
|
||||||
|
Q_PROPERTY(bool currentResponse MEMBER currentResponse)
|
||||||
|
Q_PROPERTY(bool stopped MEMBER stopped)
|
||||||
|
Q_PROPERTY(bool thumbsUpState MEMBER thumbsUpState)
|
||||||
|
Q_PROPERTY(bool thumbsDownState MEMBER thumbsDownState)
|
||||||
|
|
||||||
|
public:
|
||||||
|
int id = 0;
|
||||||
|
QString name;
|
||||||
|
QString value;
|
||||||
|
QString prompt;
|
||||||
|
QString newResponse;
|
||||||
|
bool currentResponse = false;
|
||||||
|
bool stopped = false;
|
||||||
|
bool thumbsUpState = false;
|
||||||
|
bool thumbsDownState = false;
|
||||||
|
};
|
||||||
|
Q_DECLARE_METATYPE(ChatItem)
|
||||||
|
|
||||||
|
class ChatModel : public QAbstractListModel
|
||||||
|
{
|
||||||
|
Q_OBJECT
|
||||||
|
Q_PROPERTY(int count READ count NOTIFY countChanged)
|
||||||
|
|
||||||
|
public:
|
||||||
|
explicit ChatModel(QObject *parent = nullptr) : QAbstractListModel(parent) {}
|
||||||
|
|
||||||
|
enum Roles {
|
||||||
|
IdRole = Qt::UserRole + 1,
|
||||||
|
NameRole,
|
||||||
|
ValueRole,
|
||||||
|
PromptRole,
|
||||||
|
NewResponseRole,
|
||||||
|
CurrentResponseRole,
|
||||||
|
StoppedRole,
|
||||||
|
ThumbsUpStateRole,
|
||||||
|
ThumbsDownStateRole
|
||||||
|
};
|
||||||
|
|
||||||
|
int rowCount(const QModelIndex &parent = QModelIndex()) const override
|
||||||
|
{
|
||||||
|
Q_UNUSED(parent)
|
||||||
|
return m_chatItems.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
QVariant data(const QModelIndex &index, int role = Qt::DisplayRole) const override
|
||||||
|
{
|
||||||
|
if (!index.isValid() || index.row() < 0 || index.row() >= m_chatItems.size())
|
||||||
|
return QVariant();
|
||||||
|
|
||||||
|
const ChatItem &item = m_chatItems.at(index.row());
|
||||||
|
switch (role) {
|
||||||
|
case IdRole:
|
||||||
|
return item.id;
|
||||||
|
case NameRole:
|
||||||
|
return item.name;
|
||||||
|
case ValueRole:
|
||||||
|
return item.value;
|
||||||
|
case PromptRole:
|
||||||
|
return item.prompt;
|
||||||
|
case NewResponseRole:
|
||||||
|
return item.newResponse;
|
||||||
|
case CurrentResponseRole:
|
||||||
|
return item.currentResponse;
|
||||||
|
case StoppedRole:
|
||||||
|
return item.stopped;
|
||||||
|
case ThumbsUpStateRole:
|
||||||
|
return item.thumbsUpState;
|
||||||
|
case ThumbsDownStateRole:
|
||||||
|
return item.thumbsDownState;
|
||||||
|
}
|
||||||
|
|
||||||
|
return QVariant();
|
||||||
|
}
|
||||||
|
|
||||||
|
QHash<int, QByteArray> roleNames() const override
|
||||||
|
{
|
||||||
|
QHash<int, QByteArray> roles;
|
||||||
|
roles[IdRole] = "id";
|
||||||
|
roles[NameRole] = "name";
|
||||||
|
roles[ValueRole] = "value";
|
||||||
|
roles[PromptRole] = "prompt";
|
||||||
|
roles[NewResponseRole] = "newResponse";
|
||||||
|
roles[CurrentResponseRole] = "currentResponse";
|
||||||
|
roles[StoppedRole] = "stopped";
|
||||||
|
roles[ThumbsUpStateRole] = "thumbsUpState";
|
||||||
|
roles[ThumbsDownStateRole] = "thumbsDownState";
|
||||||
|
return roles;
|
||||||
|
}
|
||||||
|
|
||||||
|
Q_INVOKABLE void appendPrompt(const QString &name, const QString &value)
|
||||||
|
{
|
||||||
|
ChatItem item;
|
||||||
|
item.name = name;
|
||||||
|
item.value = value;
|
||||||
|
beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size());
|
||||||
|
m_chatItems.append(item);
|
||||||
|
endInsertRows();
|
||||||
|
emit countChanged();
|
||||||
|
}
|
||||||
|
|
||||||
|
Q_INVOKABLE void appendResponse(const QString &name, const QString &prompt)
|
||||||
|
{
|
||||||
|
ChatItem item;
|
||||||
|
item.id = m_chatItems.count(); // This is only relevant for responses
|
||||||
|
item.name = name;
|
||||||
|
item.prompt = prompt;
|
||||||
|
item.currentResponse = true;
|
||||||
|
beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size());
|
||||||
|
m_chatItems.append(item);
|
||||||
|
endInsertRows();
|
||||||
|
emit countChanged();
|
||||||
|
}
|
||||||
|
|
||||||
|
Q_INVOKABLE ChatItem get(int index)
|
||||||
|
{
|
||||||
|
if (index < 0 || index >= m_chatItems.size()) return ChatItem();
|
||||||
|
return m_chatItems.at(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
Q_INVOKABLE void updateCurrentResponse(int index, bool b)
|
||||||
|
{
|
||||||
|
if (index < 0 || index >= m_chatItems.size()) return;
|
||||||
|
|
||||||
|
ChatItem &item = m_chatItems[index];
|
||||||
|
if (item.currentResponse != b) {
|
||||||
|
item.currentResponse = b;
|
||||||
|
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {CurrentResponseRole});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Q_INVOKABLE void updateStopped(int index, bool b)
|
||||||
|
{
|
||||||
|
if (index < 0 || index >= m_chatItems.size()) return;
|
||||||
|
|
||||||
|
ChatItem &item = m_chatItems[index];
|
||||||
|
if (item.stopped != b) {
|
||||||
|
item.stopped = b;
|
||||||
|
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {StoppedRole});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Q_INVOKABLE void updateValue(int index, const QString &value)
|
||||||
|
{
|
||||||
|
if (index < 0 || index >= m_chatItems.size()) return;
|
||||||
|
|
||||||
|
ChatItem &item = m_chatItems[index];
|
||||||
|
if (item.value != value) {
|
||||||
|
item.value = value;
|
||||||
|
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ValueRole});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Q_INVOKABLE void updateThumbsUpState(int index, bool b)
|
||||||
|
{
|
||||||
|
if (index < 0 || index >= m_chatItems.size()) return;
|
||||||
|
|
||||||
|
ChatItem &item = m_chatItems[index];
|
||||||
|
if (item.thumbsUpState != b) {
|
||||||
|
item.thumbsUpState = b;
|
||||||
|
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsUpStateRole});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Q_INVOKABLE void updateThumbsDownState(int index, bool b)
|
||||||
|
{
|
||||||
|
if (index < 0 || index >= m_chatItems.size()) return;
|
||||||
|
|
||||||
|
ChatItem &item = m_chatItems[index];
|
||||||
|
if (item.thumbsDownState != b) {
|
||||||
|
item.thumbsDownState = b;
|
||||||
|
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsDownStateRole});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Q_INVOKABLE void updateNewResponse(int index, const QString &newResponse)
|
||||||
|
{
|
||||||
|
if (index < 0 || index >= m_chatItems.size()) return;
|
||||||
|
|
||||||
|
ChatItem &item = m_chatItems[index];
|
||||||
|
if (item.newResponse != newResponse) {
|
||||||
|
item.newResponse = newResponse;
|
||||||
|
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {NewResponseRole});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int count() const { return m_chatItems.size(); }
|
||||||
|
|
||||||
|
Q_SIGNALS:
|
||||||
|
void countChanged();
|
||||||
|
|
||||||
|
private:
|
||||||
|
|
||||||
|
QList<ChatItem> m_chatItems;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // CHATMODEL_H
|
3
llm.cpp
3
llm.cpp
@ -1,6 +1,8 @@
|
|||||||
#include "llm.h"
|
#include "llm.h"
|
||||||
#include "download.h"
|
#include "download.h"
|
||||||
#include "network.h"
|
#include "network.h"
|
||||||
|
#include "llmodel/gptj.h"
|
||||||
|
#include "llmodel/llamamodel.h"
|
||||||
|
|
||||||
#include <QCoreApplication>
|
#include <QCoreApplication>
|
||||||
#include <QDir>
|
#include <QDir>
|
||||||
@ -345,6 +347,7 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
|
|||||||
|
|
||||||
LLM::LLM()
|
LLM::LLM()
|
||||||
: QObject{nullptr}
|
: QObject{nullptr}
|
||||||
|
, m_currentChat(new Chat)
|
||||||
, m_llmodel(new LLMObject)
|
, m_llmodel(new LLMObject)
|
||||||
, m_responseInProgress(false)
|
, m_responseInProgress(false)
|
||||||
{
|
{
|
||||||
|
11
llm.h
11
llm.h
@ -3,8 +3,9 @@
|
|||||||
|
|
||||||
#include <QObject>
|
#include <QObject>
|
||||||
#include <QThread>
|
#include <QThread>
|
||||||
#include "llmodel/gptj.h"
|
|
||||||
#include "llmodel/llamamodel.h"
|
#include "chat.h"
|
||||||
|
#include "llmodel/llmodel.h"
|
||||||
|
|
||||||
class LLMObject : public QObject
|
class LLMObject : public QObject
|
||||||
{
|
{
|
||||||
@ -24,6 +25,7 @@ public:
|
|||||||
void regenerateResponse();
|
void regenerateResponse();
|
||||||
void resetResponse();
|
void resetResponse();
|
||||||
void resetContext();
|
void resetContext();
|
||||||
|
|
||||||
void stopGenerating() { m_stopGenerating = true; }
|
void stopGenerating() { m_stopGenerating = true; }
|
||||||
void setThreadCount(int32_t n_threads);
|
void setThreadCount(int32_t n_threads);
|
||||||
int32_t threadCount();
|
int32_t threadCount();
|
||||||
@ -83,6 +85,7 @@ class LLM : public QObject
|
|||||||
Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
|
Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged)
|
||||||
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
|
Q_PROPERTY(int32_t threadCount READ threadCount WRITE setThreadCount NOTIFY threadCountChanged)
|
||||||
Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
|
Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged)
|
||||||
|
Q_PROPERTY(Chat *currentChat READ currentChat NOTIFY currentChatChanged)
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
@ -111,6 +114,8 @@ public:
|
|||||||
|
|
||||||
bool isRecalc() const;
|
bool isRecalc() const;
|
||||||
|
|
||||||
|
Chat *currentChat() const { return m_currentChat; }
|
||||||
|
|
||||||
Q_SIGNALS:
|
Q_SIGNALS:
|
||||||
void isModelLoadedChanged();
|
void isModelLoadedChanged();
|
||||||
void responseChanged();
|
void responseChanged();
|
||||||
@ -126,12 +131,14 @@ Q_SIGNALS:
|
|||||||
void threadCountChanged();
|
void threadCountChanged();
|
||||||
void setThreadCountRequested(int32_t threadCount);
|
void setThreadCountRequested(int32_t threadCount);
|
||||||
void recalcChanged();
|
void recalcChanged();
|
||||||
|
void currentChatChanged();
|
||||||
|
|
||||||
private Q_SLOTS:
|
private Q_SLOTS:
|
||||||
void responseStarted();
|
void responseStarted();
|
||||||
void responseStopped();
|
void responseStopped();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
Chat *m_currentChat;
|
||||||
LLMObject *m_llmodel;
|
LLMObject *m_llmodel;
|
||||||
int32_t m_desiredThreadCount;
|
int32_t m_desiredThreadCount;
|
||||||
bool m_responseInProgress;
|
bool m_responseInProgress;
|
||||||
|
49
main.qml
49
main.qml
@ -19,6 +19,7 @@ Window {
|
|||||||
}
|
}
|
||||||
|
|
||||||
property string chatId: Network.generateUniqueId()
|
property string chatId: Network.generateUniqueId()
|
||||||
|
property var chatModel: LLM.currentChat.chatModel
|
||||||
|
|
||||||
color: theme.textColor
|
color: theme.textColor
|
||||||
|
|
||||||
@ -666,10 +667,6 @@ Window {
|
|||||||
anchors.bottomMargin: 30
|
anchors.bottomMargin: 30
|
||||||
ScrollBar.vertical.policy: ScrollBar.AlwaysOn
|
ScrollBar.vertical.policy: ScrollBar.AlwaysOn
|
||||||
|
|
||||||
ListModel {
|
|
||||||
id: chatModel
|
|
||||||
}
|
|
||||||
|
|
||||||
Rectangle {
|
Rectangle {
|
||||||
anchors.fill: parent
|
anchors.fill: parent
|
||||||
color: theme.backgroundLighter
|
color: theme.backgroundLighter
|
||||||
@ -750,9 +747,9 @@ Window {
|
|||||||
if (thumbsDownState && !thumbsUpState && !responseHasChanged)
|
if (thumbsDownState && !thumbsUpState && !responseHasChanged)
|
||||||
return
|
return
|
||||||
|
|
||||||
newResponse = response
|
chatModel.updateNewResponse(index, response)
|
||||||
thumbsDownState = true
|
chatModel.updateThumbsUpState(index, false)
|
||||||
thumbsUpState = false
|
chatModel.updateThumbsDownState(index, true)
|
||||||
Network.sendConversation(chatId, getConversationJson());
|
Network.sendConversation(chatId, getConversationJson());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -782,9 +779,9 @@ Window {
|
|||||||
if (thumbsUpState && !thumbsDownState)
|
if (thumbsUpState && !thumbsDownState)
|
||||||
return
|
return
|
||||||
|
|
||||||
newResponse = ""
|
chatModel.updateNewResponse(index, "")
|
||||||
thumbsUpState = true
|
chatModel.updateThumbsUpState(index, true)
|
||||||
thumbsDownState = false
|
chatModel.updateThumbsDownState(index, false)
|
||||||
Network.sendConversation(chatId, getConversationJson());
|
Network.sendConversation(chatId, getConversationJson());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -862,8 +859,8 @@ Window {
|
|||||||
}
|
}
|
||||||
leftPadding: 50
|
leftPadding: 50
|
||||||
onClicked: {
|
onClicked: {
|
||||||
if (chatModel.count)
|
var index = Math.max(0, chatModel.count - 1);
|
||||||
var listElement = chatModel.get(chatModel.count - 1)
|
var listElement = chatModel.get(index);
|
||||||
|
|
||||||
if (LLM.responseInProgress) {
|
if (LLM.responseInProgress) {
|
||||||
listElement.stopped = true
|
listElement.stopped = true
|
||||||
@ -872,12 +869,12 @@ Window {
|
|||||||
LLM.regenerateResponse()
|
LLM.regenerateResponse()
|
||||||
if (chatModel.count) {
|
if (chatModel.count) {
|
||||||
if (listElement.name === qsTr("Response: ")) {
|
if (listElement.name === qsTr("Response: ")) {
|
||||||
listElement.currentResponse = true
|
chatModel.updateCurrentResponse(index, true);
|
||||||
listElement.stopped = false
|
chatModel.updateStopped(index, false);
|
||||||
listElement.value = LLM.response
|
chatModel.updateValue(index, LLM.response);
|
||||||
listElement.thumbsUpState = false
|
chatModel.updateThumbsUpState(index, false);
|
||||||
listElement.thumbsDownState = false
|
chatModel.updateThumbsDownState(index, false);
|
||||||
listElement.newResponse = ""
|
chatModel.updateNewResponse(index, "");
|
||||||
LLM.prompt(listElement.prompt, settingsDialog.promptTemplate,
|
LLM.prompt(listElement.prompt, settingsDialog.promptTemplate,
|
||||||
settingsDialog.maxLength,
|
settingsDialog.maxLength,
|
||||||
settingsDialog.topK, settingsDialog.topP,
|
settingsDialog.topK, settingsDialog.topP,
|
||||||
@ -949,18 +946,14 @@ Window {
|
|||||||
LLM.stopGenerating()
|
LLM.stopGenerating()
|
||||||
|
|
||||||
if (chatModel.count) {
|
if (chatModel.count) {
|
||||||
var listElement = chatModel.get(chatModel.count - 1)
|
var index = Math.max(0, chatModel.count - 1);
|
||||||
listElement.currentResponse = false
|
var listElement = chatModel.get(index);
|
||||||
listElement.value = LLM.response
|
chatModel.updateCurrentResponse(index, false);
|
||||||
|
chatModel.updateValue(index, LLM.response);
|
||||||
}
|
}
|
||||||
var prompt = textInput.text + "\n"
|
var prompt = textInput.text + "\n"
|
||||||
chatModel.append({"name": qsTr("Prompt: "), "currentResponse": false,
|
chatModel.appendPrompt(qsTr("Prompt: "), textInput.text);
|
||||||
"value": textInput.text})
|
chatModel.appendResponse(qsTr("Response: "), prompt);
|
||||||
chatModel.append({"id": chatModel.count, "name": qsTr("Response: "),
|
|
||||||
"currentResponse": true, "value": "", "stopped": false,
|
|
||||||
"thumbsUpState": false, "thumbsDownState": false,
|
|
||||||
"newResponse": "",
|
|
||||||
"prompt": prompt})
|
|
||||||
LLM.resetResponse()
|
LLM.resetResponse()
|
||||||
LLM.prompt(prompt, settingsDialog.promptTemplate,
|
LLM.prompt(prompt, settingsDialog.promptTemplate,
|
||||||
settingsDialog.maxLength,
|
settingsDialog.maxLength,
|
||||||
|
Loading…
Reference in New Issue
Block a user