From 1f749d7633a67f3272a7dff088b2e2c7e415fb45 Mon Sep 17 00:00:00 2001
From: Adam Treat
Date: Sat, 8 Jul 2023 10:04:38 -0400
Subject: [PATCH] Clean up backend code a bit and hide impl. details.

---
 gpt4all-backend/llmodel.cpp        | 64 ++++++++++++++++--------------
 gpt4all-backend/llmodel.h          | 51 +++++++++++-------------
 gpt4all-backend/llmodel_c.cpp      |  6 +--
 gpt4all-backend/llmodel_shared.cpp | 12 +++---
 gpt4all-chat/chatllm.cpp           |  8 ++--
 gpt4all-chat/llm.cpp               |  2 +-
 6 files changed, 72 insertions(+), 71 deletions(-)

diff --git a/gpt4all-backend/llmodel.cpp b/gpt4all-backend/llmodel.cpp
index 5dd33535..d9300f04 100644
--- a/gpt4all-backend/llmodel.cpp
+++ b/gpt4all-backend/llmodel.cpp
@@ -41,41 +41,42 @@ static bool requires_avxonly() {
 #endif
 }
 
-LLImplementation::LLImplementation(Dlhandle &&dlhandle_) : dlhandle(new Dlhandle(std::move(dlhandle_))) {
-    auto get_model_type = dlhandle->get<const char *()>("get_model_type");
+LLMImplementation::LLMImplementation(Dlhandle &&dlhandle_)
+    : m_dlhandle(new Dlhandle(std::move(dlhandle_))) {
+    auto get_model_type = m_dlhandle->get<const char *()>("get_model_type");
     assert(get_model_type);
-    modelType = get_model_type();
-    auto get_build_variant = dlhandle->get<const char *()>("get_build_variant");
+    m_modelType = get_model_type();
+    auto get_build_variant = m_dlhandle->get<const char *()>("get_build_variant");
     assert(get_build_variant);
-    buildVariant = get_build_variant();
-    magicMatch = dlhandle->get<bool(std::ifstream&)>("magic_match");
-    assert(magicMatch);
-    construct_ = dlhandle->get<LLModel *()>("construct");
-    assert(construct_);
+    m_buildVariant = get_build_variant();
+    m_magicMatch = m_dlhandle->get<bool(std::ifstream&)>("magic_match");
+    assert(m_magicMatch);
+    m_construct = m_dlhandle->get<LLModel *()>("construct");
+    assert(m_construct);
 }
 
-LLImplementation::LLImplementation(LLImplementation &&o)
-    : construct_(o.construct_)
-    , modelType(o.modelType)
-    , buildVariant(o.buildVariant)
-    , magicMatch(o.magicMatch)
-    , dlhandle(o.dlhandle) {
-    o.dlhandle = nullptr;
+LLMImplementation::LLMImplementation(LLMImplementation &&o)
+    : m_magicMatch(o.m_magicMatch)
+    , m_construct(o.m_construct)
+    , m_modelType(o.m_modelType)
+    , m_buildVariant(o.m_buildVariant)
+    , m_dlhandle(o.m_dlhandle) {
+    o.m_dlhandle = nullptr;
 }
 
-LLImplementation::~LLImplementation() {
-    if (dlhandle) delete dlhandle;
+LLMImplementation::~LLMImplementation() {
+    if (m_dlhandle) delete m_dlhandle;
 }
 
-bool LLImplementation::isImplementation(const Dlhandle &dl) {
+bool LLMImplementation::isImplementation(const Dlhandle &dl) {
     return dl.get("is_g4a_backend_model_implementation");
 }
 
-const std::vector<LLImplementation> &LLModel::implementationList() {
+const std::vector<LLMImplementation> &LLMImplementation::implementationList() {
     // NOTE: allocated on heap so we leak intentionally on exit so we have a chance to clean up the
     // individual models without the cleanup of the static list interfering
-    static auto* libs = new std::vector<LLImplementation>([] () {
-        std::vector<LLImplementation> fres;
+    static auto* libs = new std::vector<LLMImplementation>([] () {
+        std::vector<LLMImplementation> fres;
 
         auto search_in_directory = [&](const std::string& paths) {
             std::stringstream ss(paths);
@@ -90,10 +91,10 @@ const std::vector<LLImplementation> &LLModel::implementationList() {
                 // Add to list if model implementation
                 try {
                     Dlhandle dl(p.string());
-                    if (!LLImplementation::isImplementation(dl)) {
+                    if (!LLMImplementation::isImplementation(dl)) {
                         continue;
                     }
-                    fres.emplace_back(LLImplementation(std::move(dl)));
+                    fres.emplace_back(LLMImplementation(std::move(dl)));
                 } catch (...) {}
            }
        }
@@ -107,17 +108,17 @@ const std::vector<LLImplementation> &LLModel::implementationList() {
     return *libs;
 }
 
-const LLImplementation* LLModel::implementation(std::ifstream& f, const std::string& buildVariant) {
+const LLMImplementation* LLMImplementation::implementation(std::ifstream& f, const std::string& buildVariant) {
     for (const auto& i : implementationList()) {
         f.seekg(0);
-        if (!i.magicMatch(f)) continue;
-        if (buildVariant != i.buildVariant) continue;
+        if (!i.m_magicMatch(f)) continue;
+        if (buildVariant != i.m_buildVariant) continue;
         return &i;
     }
     return nullptr;
 }
 
-LLModel *LLModel::construct(const std::string &modelPath, std::string buildVariant) {
+LLModel *LLMImplementation::construct(const std::string &modelPath, std::string buildVariant) {
     if (!has_at_least_minimal_hardware())
         return nullptr;
 
@@ -126,7 +127,7 @@ LLModel *LLModel::construct(const std::string &modelPath, std::string buildVaria
     std::ifstream f(modelPath, std::ios::binary);
     if (!f) return nullptr;
     // Get correct implementation
-    const LLImplementation* impl = nullptr;
+    const LLMImplementation* impl = nullptr;
 
 #if defined(__APPLE__) && defined(__arm64__) // FIXME: See if metal works for intel macs
     if (buildVariant == "auto") {
@@ -160,14 +161,17 @@ LLModel *LLModel::construct(const std::string &modelPath, std::string buildVaria
         if (!impl) return nullptr;
     }
     f.close();
+
     // Construct and return llmodel implementation
-    return impl->construct();
+    auto fres = impl->m_construct();
+    fres->m_implementation = impl;
+    return fres;
 }
 
-void LLModel::setImplementationsSearchPath(const std::string& path) {
+void LLMImplementation::setImplementationsSearchPath(const std::string& path) {
     s_implementations_search_path = path;
 }
 
-const std::string& LLModel::implementationsSearchPath() {
+const std::string& LLMImplementation::implementationsSearchPath() {
     return s_implementations_search_path;
 }
 
diff --git a/gpt4all-backend/llmodel.h b/gpt4all-backend/llmodel.h
index 920bc350..a5820174 100644
--- a/gpt4all-backend/llmodel.h
+++ b/gpt4all-backend/llmodel.h
@@ -12,7 +12,7 @@
 #define LLMODEL_MAX_PROMPT_BATCH 128
 
 class Dlhandle;
-class LLImplementation;
+class LLMImplementation;
 class LLModel {
 public:
     using Token = int32_t;
@@ -51,17 +51,10 @@ public:
     virtual void setThreadCount(int32_t /*n_threads*/) {}
     virtual int32_t threadCount() const { return 1; }
 
-    const LLImplementation& implementation() const {
+    const LLMImplementation& implementation() const {
         return *m_implementation;
     }
 
-    static const std::vector<LLImplementation>& implementationList();
-    static const LLImplementation *implementation(std::ifstream& f, const std::string& buildVariant);
-    static LLModel *construct(const std::string &modelPath, std::string buildVariant = "auto");
-
-    static void setImplementationsSearchPath(const std::string& path);
-    static const std::string& implementationsSearchPath();
-
 protected:
     // These are pure virtual because subclasses need to implement as the default implementation of
     // 'prompt' above calls these functions
@@ -76,33 +69,37 @@ protected:
     // shared by all base classes so it isn't virtual
     void recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate);
 
-    const LLImplementation *m_implementation = nullptr;
+    const LLMImplementation *m_implementation = nullptr;
 
 private:
-    friend class LLImplementation;
+    friend class LLMImplementation;
 };
 
-class LLImplementation {
-    LLModel *(*construct_)();
-
+class LLMImplementation {
 public:
-    LLImplementation(Dlhandle&&);
-    LLImplementation(const LLImplementation&) = delete;
-    LLImplementation(LLImplementation&&);
-    ~LLImplementation();
+    LLMImplementation(Dlhandle&&);
+    LLMImplementation(const LLMImplementation&) = delete;
+    LLMImplementation(LLMImplementation&&);
+    ~LLMImplementation();
+
+    std::string_view modelType() const { return m_modelType; }
+    std::string_view buildVariant() const { return m_buildVariant; }
 
     static bool isImplementation(const Dlhandle&);
+    static const std::vector<LLMImplementation>& implementationList();
+    static const LLMImplementation *implementation(std::ifstream& f, const std::string& buildVariant);
+    static LLModel *construct(const std::string &modelPath, std::string buildVariant = "auto");
+    static void setImplementationsSearchPath(const std::string& path);
+    static const std::string& implementationsSearchPath();
 
-    std::string_view modelType, buildVariant;
-    bool (*magicMatch)(std::ifstream& f);
-    Dlhandle *dlhandle;
+private:
+    bool (*m_magicMatch)(std::ifstream& f);
+    LLModel *(*m_construct)();
 
-    // The only way an implementation should be constructed
-    LLModel *construct() const {
-        auto fres = construct_();
-        fres->m_implementation = this;
-        return fres;
-    }
+private:
+    std::string_view m_modelType;
+    std::string_view m_buildVariant;
+    Dlhandle *m_dlhandle;
 };
 
 #endif // LLMODEL_H
diff --git a/gpt4all-backend/llmodel_c.cpp b/gpt4all-backend/llmodel_c.cpp
index 15e5e891..2364e4fa 100644
--- a/gpt4all-backend/llmodel_c.cpp
+++ b/gpt4all-backend/llmodel_c.cpp
@@ -29,7 +29,7 @@ llmodel_model llmodel_model_create2(const char *model_path, const char *build_va
     int error_code = 0;
 
     try {
-        wrapper->llModel = LLModel::construct(model_path, build_variant);
+        wrapper->llModel = LLMImplementation::construct(model_path, build_variant);
     } catch (const std::exception& e) {
         error_code = EINVAL;
         last_error_message = e.what();
@@ -180,10 +180,10 @@ int32_t llmodel_threadCount(llmodel_model model)
 
 void llmodel_set_implementation_search_path(const char *path)
 {
-    LLModel::setImplementationsSearchPath(path);
+    LLMImplementation::setImplementationsSearchPath(path);
 }
 
 const char *llmodel_get_implementation_search_path()
 {
-    return LLModel::implementationsSearchPath().c_str();
+    return LLMImplementation::implementationsSearchPath().c_str();
 }
diff --git a/gpt4all-backend/llmodel_shared.cpp b/gpt4all-backend/llmodel_shared.cpp
index cd4ace04..881ea5ec 100644
--- a/gpt4all-backend/llmodel_shared.cpp
+++ b/gpt4all-backend/llmodel_shared.cpp
@@ -33,7 +33,7 @@ void LLModel::prompt(const std::string &prompt,
                      PromptContext &promptCtx)
 {
     if (!isModelLoaded()) {
-        std::cerr << implementation().modelType << " ERROR: prompt won't work with an unloaded model!\n";
+        std::cerr << implementation().modelType() << " ERROR: prompt won't work with an unloaded model!\n";
         return;
     }
 
@@ -45,7 +45,7 @@ void LLModel::prompt(const std::string &prompt,
     if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
         responseCallback(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed.");
-        std::cerr << implementation().modelType << " ERROR: The prompt is" << embd_inp.size() <<
+        std::cerr << implementation().modelType() << " ERROR: The prompt is" << embd_inp.size() <<
             "tokens and the context window is" << promptCtx.n_ctx << "!\n";
         return;
     }
 
@@ -64,7 +64,7 @@ void LLModel::prompt(const std::string &prompt,
         if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) {
             const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
             // Erase the first percentage of context from the tokens...
-            std::cerr << implementation().modelType << ": reached the end of the context window so resizing\n";
+            std::cerr << implementation().modelType() << ": reached the end of the context window so resizing\n";
             promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
             promptCtx.n_past = promptCtx.tokens.size();
             recalculateContext(promptCtx, recalculateCallback);
@@ -72,7 +72,7 @@ void LLModel::prompt(const std::string &prompt,
         }
 
         if (!evalTokens(promptCtx, batch)) {
-            std::cerr << implementation().modelType << " ERROR: Failed to process prompt\n";
+            std::cerr << implementation().modelType() << " ERROR: Failed to process prompt\n";
             return;
         }
 
@@ -103,7 +103,7 @@ void LLModel::prompt(const std::string &prompt,
         if (promptCtx.n_past + 1 > promptCtx.n_ctx) {
             const int32_t erasePoint = promptCtx.n_ctx * promptCtx.contextErase;
             // Erase the first percentage of context from the tokens...
-            std::cerr << implementation().modelType << ": reached the end of the context window so resizing\n";
+            std::cerr << implementation().modelType() << ": reached the end of the context window so resizing\n";
             promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
             promptCtx.n_past = promptCtx.tokens.size();
             recalculateContext(promptCtx, recalculateCallback);
@@ -111,7 +111,7 @@ void LLModel::prompt(const std::string &prompt,
         }
 
         if (!evalTokens(promptCtx, { id })) {
-            std::cerr << implementation().modelType << " ERROR: Failed to predict next token\n";
+            std::cerr << implementation().modelType() << " ERROR: Failed to predict next token\n";
             return;
         }
 
diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp
index 181b8452..fa11cdbb 100644
--- a/gpt4all-chat/chatllm.cpp
+++ b/gpt4all-chat/chatllm.cpp
@@ -240,11 +240,11 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
 
 #if defined(Q_OS_MAC) && defined(__arm__)
         if (m_forceMetal)
-            m_llModelInfo.model = LLModel::construct(filePath.toStdString(), "metal");
+            m_llModelInfo.model = LLMImplementation::construct(filePath.toStdString(), "metal");
         else
-            m_llModelInfo.model = LLModel::construct(filePath.toStdString(), "auto");
+            m_llModelInfo.model = LLMImplementation::construct(filePath.toStdString(), "auto");
 #else
-        m_llModelInfo.model = LLModel::construct(filePath.toStdString(), "auto");
+        m_llModelInfo.model = LLMImplementation::construct(filePath.toStdString(), "auto");
 #endif
 
         if (m_llModelInfo.model) {
@@ -258,7 +258,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
                 m_llModelInfo = LLModelInfo();
                 emit modelLoadingError(QString("Could not load model due to invalid model file for %1").arg(modelInfo.filename()));
             } else {
-                switch (m_llModelInfo.model->implementation().modelType[0]) {
+                switch (m_llModelInfo.model->implementation().modelType()[0]) {
                 case 'L': m_llModelType = LLModelType::LLAMA_; break;
                 case 'G': m_llModelType = LLModelType::GPTJ_; break;
                 case 'M': m_llModelType = LLModelType::MPT_; break;
diff --git a/gpt4all-chat/llm.cpp b/gpt4all-chat/llm.cpp
index f831ea47..ff62d43e 100644
--- a/gpt4all-chat/llm.cpp
+++ b/gpt4all-chat/llm.cpp
@@ -34,7 +34,7 @@ LLM::LLM()
     if (directoryExists(frameworksDir))
         llmodelSearchPaths += ";" + frameworksDir;
 #endif
-    LLModel::setImplementationsSearchPath(llmodelSearchPaths.toStdString());
+    LLMImplementation::setImplementationsSearchPath(llmodelSearchPaths.toStdString());
 #if defined(__x86_64__)
 #ifndef _MSC_VER