mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-10-01 01:06:10 -04:00
Fix gptj to have lower memory requirements for kv cache and add versioning to the internal state to smoothly handle such a fix in the future.
This commit is contained in:
parent
ccbd16cf18
commit
8c4b8f215f
6
chat.cpp
6
chat.cpp
@ -202,6 +202,12 @@ bool Chat::deserialize(QDataStream &stream, int version)
|
||||
stream >> m_userName;
|
||||
emit nameChanged();
|
||||
stream >> m_savedModelName;
|
||||
|
||||
// Prior to version 2 gptj models had a bug that fixed the kv_cache to F32 instead of F16 so
|
||||
// unfortunately, we cannot deserialize these
|
||||
if (version < 2 && m_savedModelName.contains("gpt4all-j"))
|
||||
return false;
|
||||
|
||||
if (!m_llmodel->deserialize(stream, version))
|
||||
return false;
|
||||
if (!m_chatModel->deserialize(stream, version))
|
||||
|
@ -5,7 +5,7 @@
|
||||
#include <QDataStream>
|
||||
|
||||
#define CHAT_FORMAT_MAGIC 0xF5D553CC
|
||||
#define CHAT_FORMAT_VERSION 1
|
||||
#define CHAT_FORMAT_VERSION 2
|
||||
|
||||
ChatListModel::ChatListModel(QObject *parent)
|
||||
: QAbstractListModel(parent)
|
||||
|
21
chatllm.cpp
21
chatllm.cpp
@ -16,6 +16,10 @@
|
||||
|
||||
//#define DEBUG
|
||||
|
||||
#define MPT_INTERNAL_STATE_VERSION 0
|
||||
#define GPTJ_INTERNAL_STATE_VERSION 0
|
||||
#define LLAMA_INTERNAL_STATE_VERSION 0
|
||||
|
||||
static QString modelFilePath(const QString &modelName)
|
||||
{
|
||||
QString appPath = QCoreApplication::applicationDirPath()
|
||||
@ -96,12 +100,15 @@ bool ChatLLM::loadModel(const QString &modelName)
|
||||
isGPTJ = magic == 0x67676d6c;
|
||||
isMPT = magic == 0x67676d6d;
|
||||
if (isGPTJ) {
|
||||
m_modelType = ModelType::GPTJ_;
|
||||
m_llmodel = new GPTJ;
|
||||
m_llmodel->loadModel(filePath.toStdString());
|
||||
} else if (isMPT) {
|
||||
m_modelType = ModelType::MPT_;
|
||||
m_llmodel = new MPT;
|
||||
m_llmodel->loadModel(filePath.toStdString());
|
||||
} else {
|
||||
m_modelType = ModelType::LLAMA_;
|
||||
m_llmodel = new LLamaModel;
|
||||
m_llmodel->loadModel(filePath.toStdString());
|
||||
}
|
||||
@ -380,6 +387,15 @@ bool ChatLLM::handleNameRecalculate(bool isRecalc)
|
||||
|
||||
bool ChatLLM::serialize(QDataStream &stream, int version)
|
||||
{
|
||||
if (version > 1) {
|
||||
stream << m_modelType;
|
||||
switch (m_modelType) {
|
||||
case MPT_: stream << MPT_INTERNAL_STATE_VERSION; break;
|
||||
case GPTJ_: stream << GPTJ_INTERNAL_STATE_VERSION; break;
|
||||
case LLAMA_: stream << LLAMA_INTERNAL_STATE_VERSION; break;
|
||||
default: Q_UNREACHABLE();
|
||||
}
|
||||
}
|
||||
stream << response();
|
||||
stream << generatedName();
|
||||
stream << m_promptResponseTokens;
|
||||
@ -400,6 +416,11 @@ bool ChatLLM::serialize(QDataStream &stream, int version)
|
||||
|
||||
bool ChatLLM::deserialize(QDataStream &stream, int version)
|
||||
{
|
||||
if (version > 1) {
|
||||
int internalStateVersion;
|
||||
stream >> m_modelType;
|
||||
stream >> internalStateVersion; // for future use
|
||||
}
|
||||
QString response;
|
||||
stream >> response;
|
||||
m_response = response.toStdString();
|
||||
|
@ -17,6 +17,12 @@ class ChatLLM : public QObject
|
||||
Q_PROPERTY(QString generatedName READ generatedName NOTIFY generatedNameChanged)
|
||||
|
||||
public:
|
||||
enum ModelType {
|
||||
MPT_,
|
||||
GPTJ_,
|
||||
LLAMA_
|
||||
};
|
||||
|
||||
ChatLLM(Chat *parent);
|
||||
|
||||
bool isModelLoaded() const;
|
||||
@ -82,6 +88,7 @@ private:
|
||||
quint32 m_promptResponseTokens;
|
||||
quint32 m_responseLogits;
|
||||
QString m_modelName;
|
||||
ModelType m_modelType;
|
||||
Chat *m_chat;
|
||||
QByteArray m_state;
|
||||
QThread m_llmThread;
|
||||
|
@ -352,7 +352,7 @@ bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & m
|
||||
const int n_mem = n_layer*n_ctx;
|
||||
const int n_elements = n_embd*n_mem;
|
||||
|
||||
if (!kv_cache_init(hparams, model.kv_self, GGML_TYPE_F32, model.hparams.n_ctx)) {
|
||||
if (!kv_cache_init(hparams, model.kv_self, GGML_TYPE_F16, model.hparams.n_ctx)) {
|
||||
fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
|
||||
ggml_free(ctx);
|
||||
return false;
|
||||
|
Loading…
Reference in New Issue
Block a user