gpt4all-backend: Add llmodel create and destroy functions (#554)

* Add llmodel create and destroy functions

* Fix capitalization

* Fix capitalization

* Fix capitalization

* Update CMakeLists.txt

---------

Co-authored-by: kuvaus <kuvaus@users.noreply.github.com>
This commit is contained in:
kuvaus 2023-05-16 18:36:46 +03:00 committed by GitHub
parent 507e913faf
commit 3cb6dd7a66
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 42 additions and 1 deletions

View File

@ -17,7 +17,7 @@ include_directories("${CMAKE_CURRENT_BINARY_DIR}")
set(LLMODEL_VERSION_MAJOR 0)
set(LLMODEL_VERSION_MINOR 1)
set(LLMODEL_VERSION_PATCH 0)
set(LLMODEL_VERSION_PATCH 1)
set(LLMODEL_VERSION "${LLMODEL_VERSION_MAJOR}.${LLMODEL_VERSION_MINOR}.${LLMODEL_VERSION_PATCH}")
project(llmodel VERSION ${LLMODEL_VERSION} LANGUAGES CXX C)

View File

@ -51,6 +51,31 @@ void llmodel_llama_destroy(llmodel_model llama)
delete wrapper;
}
llmodel_model llmodel_model_create(const char *model_path) {
uint32_t magic;
llmodel_model model;
FILE *f = fopen(model_path, "rb");
fread(&magic, sizeof(magic), 1, f);
if (magic == 0x67676d6c) { model = llmodel_gptj_create(); }
if (magic == 0x67676a74) { model = llmodel_llama_create(); }
if (magic == 0x67676d6d) { model = llmodel_mpt_create(); }
else {fprintf(stderr, "Invalid model file\n");}
fclose(f);
return model;
}
void llmodel_model_destroy(llmodel_model model) {
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
const std::type_info &modelTypeInfo = typeid(*wrapper->llModel);
if (modelTypeInfo == typeid(GPTJ)) { llmodel_gptj_destroy(model); }
if (modelTypeInfo == typeid(LLamaModel)) { llmodel_llama_destroy(model); }
if (modelTypeInfo == typeid(MPT)) { llmodel_mpt_destroy(model); }
}
bool llmodel_loadModel(llmodel_model model, const char *model_path)
{
LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);

View File

@ -95,6 +95,22 @@ llmodel_model llmodel_llama_create();
*/
void llmodel_llama_destroy(llmodel_model llama);
/**
* Create a llmodel instance.
* Recognises correct model type from file at model_path
* @param model_path A string representing the path to the model file.
* @return A pointer to the llmodel_model instance.
*/
llmodel_model llmodel_model_create(const char *model_path);
/**
* Destroy a llmodel instance.
* Recognises correct model type using type info
* @param model a pointer to a llmodel_model instance.
*/
void llmodel_model_destroy(llmodel_model model);
/**
* Load a model from a file.
* @param model A pointer to the llmodel_model instance.