From 3cb6dd7a66213bb83823c47d69c71ea064e21cf6 Mon Sep 17 00:00:00 2001
From: kuvaus <22169537+kuvaus@users.noreply.github.com>
Date: Tue, 16 May 2023 18:36:46 +0300
Subject: [PATCH] gpt4all-backend: Add llmodel create and destroy functions (#554)

* Add llmodel create and destroy functions

* Fix capitalization

* Fix capitalization

* Fix capitalization

* Update CMakeLists.txt

---------

Co-authored-by: kuvaus <22169537+kuvaus@users.noreply.github.com>
---
 gpt4all-backend/CMakeLists.txt |  2 +-
 gpt4all-backend/llmodel_c.cpp  | 25 +++++++++++++++++++++++++
 gpt4all-backend/llmodel_c.h    | 16 ++++++++++++++++
 3 files changed, 42 insertions(+), 1 deletion(-)

diff --git a/gpt4all-backend/CMakeLists.txt b/gpt4all-backend/CMakeLists.txt
index a7f1c6f0..3756884f 100644
--- a/gpt4all-backend/CMakeLists.txt
+++ b/gpt4all-backend/CMakeLists.txt
@@ -17,7 +17,7 @@ include_directories("${CMAKE_CURRENT_BINARY_DIR}")
 
 set(LLMODEL_VERSION_MAJOR 0)
 set(LLMODEL_VERSION_MINOR 1)
-set(LLMODEL_VERSION_PATCH 0)
+set(LLMODEL_VERSION_PATCH 1)
 set(LLMODEL_VERSION "${LLMODEL_VERSION_MAJOR}.${LLMODEL_VERSION_MINOR}.${LLMODEL_VERSION_PATCH}")
 project(llmodel VERSION ${LLMODEL_VERSION} LANGUAGES CXX C)
 
diff --git a/gpt4all-backend/llmodel_c.cpp b/gpt4all-backend/llmodel_c.cpp
index 4361a900..9d5ac10a 100644
--- a/gpt4all-backend/llmodel_c.cpp
+++ b/gpt4all-backend/llmodel_c.cpp
@@ -51,6 +51,31 @@ void llmodel_llama_destroy(llmodel_model llama)
     delete wrapper;
 }
 
+llmodel_model llmodel_model_create(const char *model_path) {
+
+    uint32_t magic;
+    llmodel_model model = nullptr;
+    FILE *f = fopen(model_path, "rb");
+    fread(&magic, sizeof(magic), 1, f);
+
+    if (magic == 0x67676d6c)      { model = llmodel_gptj_create();  }
+    else if (magic == 0x67676a74) { model = llmodel_llama_create(); }
+    else if (magic == 0x67676d6d) { model = llmodel_mpt_create();   }
+    else { fprintf(stderr, "Invalid model file\n"); }
+    fclose(f);
+    return model;
+}
+
+void llmodel_model_destroy(llmodel_model model) {
+
+    LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper *>(model);
+    const std::type_info &modelTypeInfo = typeid(*wrapper->llModel);
+
+    if (modelTypeInfo == typeid(GPTJ)) { llmodel_gptj_destroy(model); }
+    if (modelTypeInfo == typeid(LLamaModel)) { llmodel_llama_destroy(model); }
+    if (modelTypeInfo == typeid(MPT)) { llmodel_mpt_destroy(model); }
+}
+
 bool llmodel_loadModel(llmodel_model model, const char *model_path)
 {
     LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper *>(model);
diff --git a/gpt4all-backend/llmodel_c.h b/gpt4all-backend/llmodel_c.h
index f45bdd8d..9a3c52a0 100644
--- a/gpt4all-backend/llmodel_c.h
+++ b/gpt4all-backend/llmodel_c.h
@@ -95,6 +95,22 @@ llmodel_model llmodel_llama_create();
  */
 void llmodel_llama_destroy(llmodel_model llama);
 
+/**
+ * Create an llmodel instance.
+ * Recognises the correct model type from the file at model_path.
+ * @param model_path A string representing the path to the model file.
+ * @return A pointer to the llmodel_model instance.
+ */
+llmodel_model llmodel_model_create(const char *model_path);
+
+/**
+ * Destroy an llmodel instance.
+ * Recognises the correct model type using type info.
+ * @param model A pointer to the llmodel_model instance.
+ */
+void llmodel_model_destroy(llmodel_model model);
+
+
 /**
  * Load a model from a file.
  * @param model A pointer to the llmodel_model instance.
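
For reference, a minimal usage sketch of the two new entry points from a C client, combined with the pre-existing llmodel_loadModel seen in the hunk context. The model path is hypothetical, and the NULL check assumes llmodel_model_create returns a null pointer when the file magic is not recognised, as in the reconstruction above.

    #include <stdio.h>
    #include "llmodel_c.h"

    int main(void)
    {
        const char *path = "ggml-gpt4all-j.bin";   /* hypothetical model file */

        /* Reads the file magic and constructs the matching backend (GPT-J, LLaMA or MPT). */
        llmodel_model model = llmodel_model_create(path);
        if (!model) {
            fprintf(stderr, "Unrecognised model file: %s\n", path);
            return 1;
        }

        /* Load the weights into the instance created above. */
        if (!llmodel_loadModel(model, path)) {
            fprintf(stderr, "Failed to load %s\n", path);
            llmodel_model_destroy(model);
            return 1;
        }

        /* ... drive inference through the rest of the llmodel C API ... */

        /* Dispatches to the matching per-backend destroy function via RTTI. */
        llmodel_model_destroy(model);
        return 0;
    }

The point of the pair is that callers no longer need to know the concrete backend: creation is dispatched on the file magic, destruction on the wrapped object's type info.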