updated bindings code for updated C API

This commit is contained in:
Richard Guo 2023-06-01 10:45:58 -04:00 committed by AT
parent f0be66a221
commit ae42805d49
2 changed files with 27 additions and 132 deletions

gpt4all/gpt4all.py

@@ -22,7 +22,7 @@ class GPT4All():
model: Pointer to underlying C model.
"""
def __init__(self, model_name: str, model_path: str = None, model_type: str = None, allow_download=True):
"""
Constructor
@@ -30,20 +30,12 @@ class GPT4All():
model_name: Name of GPT4All or custom model. Including ".bin" file extension is optional but encouraged.
model_path: Path to directory containing model file or, if file does not exist, where to download model.
Default is None, in which case models will be stored in `~/.cache/gpt4all/`.
model_type: Model architecture to use - currently, options are 'llama', 'gptj', or 'mpt'. Only required if the model
is custom. Note that these models must still be built on the llama.cpp or GPT-J ggml architecture.
Default is None.
model_type: Model architecture. This argument currently has no functionality and is only used as a
descriptive identifier for the user. Default is None.
allow_download: Allow API to download models from gpt4all.io. Default is True.
"""
self.model = None
# Model type provided for when model is custom
if model_type:
self.model = GPT4All.get_model_from_type(model_type)
# Else get model from gpt4all model filenames
else:
self.model = GPT4All.get_model_from_name(model_name)
self.model_type = model_type
self.model = pyllmodel.LLModel()
# Retrieve model and download if allowed
model_dest = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download)
self.model.load_model(model_dest)
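Reviewer note: after this change the constructor no longer dispatches on model_type; architecture detection happens inside the C library. A minimal usage sketch (the filename is one of the known model names listed below, shown only for illustration):

from gpt4all import GPT4All

# model_type is now purely descriptive and may be omitted
model = GPT4All("ggml-gpt4all-j-v1.3-groovy.bin")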
@@ -265,61 +257,6 @@ class GPT4All():
return full_prompt
@staticmethod
def get_model_from_type(model_type: str) -> pyllmodel.LLModel:
# This needs to be updated for each new model type
# TODO: Might be worth converting model_type to enum
if model_type == "gptj":
return pyllmodel.GPTJModel()
elif model_type == "llama":
return pyllmodel.LlamaModel()
elif model_type == "mpt":
return pyllmodel.MPTModel()
else:
raise ValueError(f"No corresponding model for model_type: {model_type}")
@staticmethod
def get_model_from_name(model_name: str) -> pyllmodel.LLModel:
# This needs to be updated for each new model
# NOTE: We are doing this preprocessing a lot, maybe there's a better way to organize
model_name = append_bin_suffix_if_missing(model_name)
GPTJ_MODELS = [
"ggml-gpt4all-j-v1.3-groovy.bin",
"ggml-gpt4all-j-v1.2-jazzy.bin",
"ggml-gpt4all-j-v1.1-breezy.bin",
"ggml-gpt4all-j.bin"
]
LLAMA_MODELS = [
"ggml-gpt4all-l13b-snoozy.bin",
"ggml-vicuna-7b-1.1-q4_2.bin",
"ggml-vicuna-13b-1.1-q4_2.bin",
"ggml-wizardLM-7B.q4_2.bin",
"ggml-stable-vicuna-13B.q4_2.bin",
"ggml-nous-gpt4-vicuna-13b.bin"
]
MPT_MODELS = [
"ggml-mpt-7b-base.bin",
"ggml-mpt-7b-chat.bin",
"ggml-mpt-7b-instruct.bin"
]
if model_name in GPTJ_MODELS:
return pyllmodel.GPTJModel()
elif model_name in LLAMA_MODELS:
return pyllmodel.LlamaModel()
elif model_name in MPT_MODELS:
return pyllmodel.MPTModel()
err_msg = (f"No corresponding model for provided filename {model_name}.\n"
f"If this is a custom model, make sure to specify a valid model_type.\n")
raise ValueError(err_msg)
def append_bin_suffix_if_missing(model_name):
    if not model_name.endswith(".bin"):
        model_name += ".bin"
    return model_name

gpt4all/pyllmodel.py

@@ -54,19 +54,9 @@ def load_llmodel_library():
llmodel, llama = load_llmodel_library()
# Define C function signatures using ctypes
llmodel.llmodel_gptj_create.restype = ctypes.c_void_p
llmodel.llmodel_gptj_destroy.argtypes = [ctypes.c_void_p]
llmodel.llmodel_llama_create.restype = ctypes.c_void_p
llmodel.llmodel_llama_destroy.argtypes = [ctypes.c_void_p]
llmodel.llmodel_mpt_create.restype = ctypes.c_void_p
llmodel.llmodel_mpt_destroy.argtypes = [ctypes.c_void_p]
llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
llmodel.llmodel_loadModel.restype = ctypes.c_bool
llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p]
llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool
class LLModelError(ctypes.Structure):
_fields_ = [("message", ctypes.c_char_p),
("code", ctypes.c_int32)]
class LLModelPromptContext(ctypes.Structure):
_fields_ = [("logits", ctypes.POINTER(ctypes.c_float)),
@@ -83,7 +73,17 @@ class LLModelPromptContext(ctypes.Structure):
("repeat_penalty", ctypes.c_float),
("repeat_last_n", ctypes.c_int32),
("context_erase", ctypes.c_float)]
# Define C function signatures using ctypes
llmodel.llmodel_model_create2.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.POINTER(LLModelError)]
llmodel.llmodel_model_create2.restype = ctypes.c_void_p
llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
llmodel.llmodel_loadModel.restype = ctypes.c_bool
llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p]
llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool
PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32)
ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p)
RecalculateCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_bool)
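Reviewer note: these CFUNCTYPE factories wrap plain Python callables so the C library can invoke them during generation; each callback returns a bool telling the library whether to continue. A minimal sketch (the callable names are illustrative):

def _print_response(token_id, response):
    # response arrives as UTF-8 bytes from the C side
    print(response.decode("utf-8"), end="")
    return True  # returning False would stop generation

response_cb = ResponseCallback(_print_response)
prompt_cb = PromptCallback(lambda token_id: True)
recalc_cb = RecalculateCallback(lambda is_recalculating: is_recalculating)

The CFUNCTYPE objects must stay referenced for as long as the C side may call them, or they will be garbage collected.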
@@ -113,18 +113,17 @@ class LLModel:
----------
model: llmodel_model
Ctype pointer to underlying model
model_type : str
Model architecture identifier
model_name: str
Model name.
"""
model_type: str = None
def __init__(self):
self.model = None
self.model_name = None
def __del__(self):
pass
if self.model is not None and llmodel is not None:
llmodel.llmodel_model_destroy(self.model)
def load_model(self, model_path: str) -> bool:
"""
@@ -139,7 +138,10 @@ class LLModel:
-------
True if model loaded successfully, False otherwise
"""
llmodel.llmodel_loadModel(self.model, model_path.encode('utf-8'))
model_path_enc = model_path.encode("utf-8")
build_var = "auto".encode("utf-8")
self.model = llmodel.llmodel_model_create2(model_path_enc, build_var, None)
llmodel.llmodel_loadModel(self.model, model_path_enc)
filename = os.path.basename(model_path)
self.model_name = os.path.splitext(filename)[0]
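Reviewer note: the new load path passes None as the LLModelError out-parameter, so a creation failure surfaces only as a null pointer. A hedged sketch of what error capture could look like instead (not part of this commit):

err = LLModelError()
model = llmodel.llmodel_model_create2(model_path_enc, b"auto", ctypes.byref(err))
if model is None:
    # err.message is a C string and may be NULL, so guard before decoding
    detail = err.message.decode("utf-8", "replace") if err.message else "unknown error"
    raise ValueError(f"Unable to instantiate model: {detail}")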
@@ -148,7 +150,6 @@ class LLModel:
else:
return False
def set_thread_count(self, n_threads):
if not llmodel.llmodel_isModelLoaded(self.model):
raise Exception("Model not loaded")
@@ -159,7 +160,6 @@ class LLModel:
raise Exception("Model not loaded")
return llmodel.llmodel_threadCount(self.model)
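Reviewer note: both thread helpers guard against an unloaded model. A short usage sketch (the path is illustrative):

model = LLModel()
model.load_model("/path/to/model.bin")
model.set_thread_count(4)
print(model.thread_count())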
def generate(self,
prompt: str,
logits_size: int = 0,
@@ -246,45 +246,3 @@ class LLModel:
@staticmethod
def _recalculate_callback(is_recalculating):
return is_recalculating
class GPTJModel(LLModel):
model_type = "gptj"
def __init__(self):
super().__init__()
self.model = llmodel.llmodel_gptj_create()
def __del__(self):
if self.model is not None and llmodel is not None:
llmodel.llmodel_gptj_destroy(self.model)
super().__del__()
class LlamaModel(LLModel):
model_type = "llama"
def __init__(self):
super().__init__()
self.model = llmodel.llmodel_llama_create()
def __del__(self):
if self.model is not None and llmodel is not None:
llmodel.llmodel_llama_destroy(self.model)
super().__del__()
class MPTModel(LLModel):
model_type = "mpt"
def __init__(self):
super().__init__()
self.model = llmodel.llmodel_mpt_create()
def __del__(self):
if self.model is not None and llmodel is not None:
llmodel.llmodel_mpt_destroy(self.model)
super().__del__()
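Reviewer note: with these subclasses removed, a single LLModel covers GPT-J, LLaMA, and MPT models alike; the per-architecture dispatch they performed now lives behind llmodel_model_create2 in the C library. A minimal sketch of the replacement usage (the path is illustrative):

model = LLModel()
if model.load_model("/path/to/ggml-model.bin"):  # any supported architecture
    model.generate("Hello")  # generation entry point shown above is unchanged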