mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-10-01 01:06:10 -04:00
updated bindings code for updated C api
This commit is contained in:
parent
f0be66a221
commit
ae42805d49
@ -30,20 +30,12 @@ class GPT4All():
|
||||
model_name: Name of GPT4All or custom model. Including ".bin" file extension is optional but encouraged.
|
||||
model_path: Path to directory containing model file or, if file does not exist, where to download model.
|
||||
Default is None, in which case models will be stored in `~/.cache/gpt4all/`.
|
||||
model_type: Model architecture to use - currently, options are 'llama', 'gptj', or 'mpt'. Only required if model
|
||||
is custom. Note that these models still must be built from llama.cpp or GPTJ ggml architecture.
|
||||
Default is None.
|
||||
model_type: Model architecture. This argument currently does not have any functionality and is just used as
|
||||
descriptive identifier for user. Default is None.
|
||||
allow_download: Allow API to download models from gpt4all.io. Default is True.
|
||||
"""
|
||||
self.model = None
|
||||
|
||||
# Model type provided for when model is custom
|
||||
if model_type:
|
||||
self.model = GPT4All.get_model_from_type(model_type)
|
||||
# Else get model from gpt4all model filenames
|
||||
else:
|
||||
self.model = GPT4All.get_model_from_name(model_name)
|
||||
|
||||
self.model_type = model_type
|
||||
self.model = pyllmodel.LLModel()
|
||||
# Retrieve model and download if allowed
|
||||
model_dest = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download)
|
||||
self.model.load_model(model_dest)
|
||||
@ -265,61 +257,6 @@ class GPT4All():
|
||||
|
||||
return full_prompt
|
||||
|
||||
@staticmethod
|
||||
def get_model_from_type(model_type: str) -> pyllmodel.LLModel:
|
||||
# This needs to be updated for each new model type
|
||||
# TODO: Might be worth converting model_type to enum
|
||||
|
||||
if model_type == "gptj":
|
||||
return pyllmodel.GPTJModel()
|
||||
elif model_type == "llama":
|
||||
return pyllmodel.LlamaModel()
|
||||
elif model_type == "mpt":
|
||||
return pyllmodel.MPTModel()
|
||||
else:
|
||||
raise ValueError(f"No corresponding model for model_type: {model_type}")
|
||||
|
||||
@staticmethod
|
||||
def get_model_from_name(model_name: str) -> pyllmodel.LLModel:
|
||||
# This needs to be updated for each new model
|
||||
|
||||
# NOTE: We are doing this preprocessing a lot, maybe there's a better way to organize
|
||||
model_name = append_bin_suffix_if_missing(model_name)
|
||||
|
||||
GPTJ_MODELS = [
|
||||
"ggml-gpt4all-j-v1.3-groovy.bin",
|
||||
"ggml-gpt4all-j-v1.2-jazzy.bin",
|
||||
"ggml-gpt4all-j-v1.1-breezy.bin",
|
||||
"ggml-gpt4all-j.bin"
|
||||
]
|
||||
|
||||
LLAMA_MODELS = [
|
||||
"ggml-gpt4all-l13b-snoozy.bin",
|
||||
"ggml-vicuna-7b-1.1-q4_2.bin",
|
||||
"ggml-vicuna-13b-1.1-q4_2.bin",
|
||||
"ggml-wizardLM-7B.q4_2.bin",
|
||||
"ggml-stable-vicuna-13B.q4_2.bin",
|
||||
"ggml-nous-gpt4-vicuna-13b.bin"
|
||||
]
|
||||
|
||||
MPT_MODELS = [
|
||||
"ggml-mpt-7b-base.bin",
|
||||
"ggml-mpt-7b-chat.bin",
|
||||
"ggml-mpt-7b-instruct.bin"
|
||||
]
|
||||
|
||||
if model_name in GPTJ_MODELS:
|
||||
return pyllmodel.GPTJModel()
|
||||
elif model_name in LLAMA_MODELS:
|
||||
return pyllmodel.LlamaModel()
|
||||
elif model_name in MPT_MODELS:
|
||||
return pyllmodel.MPTModel()
|
||||
|
||||
err_msg = (f"No corresponding model for provided filename {model_name}.\n"
|
||||
f"If this is a custom model, make sure to specify a valid model_type.\n")
|
||||
|
||||
raise ValueError(err_msg)
|
||||
|
||||
|
||||
def append_bin_suffix_if_missing(model_name):
|
||||
if not model_name.endswith(".bin"):
|
||||
|
@ -54,19 +54,9 @@ def load_llmodel_library():
|
||||
|
||||
llmodel, llama = load_llmodel_library()
|
||||
|
||||
# Define C function signatures using ctypes
|
||||
llmodel.llmodel_gptj_create.restype = ctypes.c_void_p
|
||||
llmodel.llmodel_gptj_destroy.argtypes = [ctypes.c_void_p]
|
||||
llmodel.llmodel_llama_create.restype = ctypes.c_void_p
|
||||
llmodel.llmodel_llama_destroy.argtypes = [ctypes.c_void_p]
|
||||
llmodel.llmodel_mpt_create.restype = ctypes.c_void_p
|
||||
llmodel.llmodel_mpt_destroy.argtypes = [ctypes.c_void_p]
|
||||
|
||||
|
||||
llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
|
||||
llmodel.llmodel_loadModel.restype = ctypes.c_bool
|
||||
llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p]
|
||||
llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool
|
||||
class LLModelError(ctypes.Structure):
|
||||
_fields_ = [("message", ctypes.c_char_p),
|
||||
("code", ctypes.c_int32)]
|
||||
|
||||
class LLModelPromptContext(ctypes.Structure):
|
||||
_fields_ = [("logits", ctypes.POINTER(ctypes.c_float)),
|
||||
@ -84,6 +74,16 @@ class LLModelPromptContext(ctypes.Structure):
|
||||
("repeat_last_n", ctypes.c_int32),
|
||||
("context_erase", ctypes.c_float)]
|
||||
|
||||
# Define C function signatures using ctypes
|
||||
|
||||
llmodel.llmodel_model_create2.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.POINTER(LLModelError)]
|
||||
llmodel.llmodel_model_create2.restype = ctypes.c_void_p
|
||||
|
||||
llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
|
||||
llmodel.llmodel_loadModel.restype = ctypes.c_bool
|
||||
llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p]
|
||||
llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool
|
||||
|
||||
PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32)
|
||||
ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p)
|
||||
RecalculateCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_bool)
|
||||
@ -113,18 +113,17 @@ class LLModel:
|
||||
----------
|
||||
model: llmodel_model
|
||||
Ctype pointer to underlying model
|
||||
model_type : str
|
||||
Model architecture identifier
|
||||
model_name: str
|
||||
Model name.
|
||||
"""
|
||||
|
||||
model_type: str = None
|
||||
|
||||
def __init__(self):
|
||||
self.model = None
|
||||
self.model_name = None
|
||||
|
||||
def __del__(self):
|
||||
pass
|
||||
if self.model is not None and llmodel is not None:
|
||||
llmodel.llmodel_model_destroy(self.model)
|
||||
|
||||
def load_model(self, model_path: str) -> bool:
|
||||
"""
|
||||
@ -139,7 +138,10 @@ class LLModel:
|
||||
-------
|
||||
True if model loaded successfully, False otherwise
|
||||
"""
|
||||
llmodel.llmodel_loadModel(self.model, model_path.encode('utf-8'))
|
||||
model_path_enc = model_path.encode("utf-8")
|
||||
build_var = "auto".encode("utf-8")
|
||||
self.model = llmodel.llmodel_model_create2(model_path_enc, build_var, None)
|
||||
llmodel.llmodel_loadModel(self.model, model_path_enc)
|
||||
filename = os.path.basename(model_path)
|
||||
self.model_name = os.path.splitext(filename)[0]
|
||||
|
||||
@ -148,7 +150,6 @@ class LLModel:
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def set_thread_count(self, n_threads):
|
||||
if not llmodel.llmodel_isModelLoaded(self.model):
|
||||
raise Exception("Model not loaded")
|
||||
@ -159,7 +160,6 @@ class LLModel:
|
||||
raise Exception("Model not loaded")
|
||||
return llmodel.llmodel_threadCount(self.model)
|
||||
|
||||
|
||||
def generate(self,
|
||||
prompt: str,
|
||||
logits_size: int = 0,
|
||||
@ -246,45 +246,3 @@ class LLModel:
|
||||
@staticmethod
|
||||
def _recalculate_callback(is_recalculating):
|
||||
return is_recalculating
|
||||
|
||||
|
||||
class GPTJModel(LLModel):
|
||||
|
||||
model_type = "gptj"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = llmodel.llmodel_gptj_create()
|
||||
|
||||
def __del__(self):
|
||||
if self.model is not None and llmodel is not None:
|
||||
llmodel.llmodel_gptj_destroy(self.model)
|
||||
super().__del__()
|
||||
|
||||
|
||||
class LlamaModel(LLModel):
|
||||
|
||||
model_type = "llama"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = llmodel.llmodel_llama_create()
|
||||
|
||||
def __del__(self):
|
||||
if self.model is not None and llmodel is not None:
|
||||
llmodel.llmodel_llama_destroy(self.model)
|
||||
super().__del__()
|
||||
|
||||
|
||||
class MPTModel(LLModel):
|
||||
|
||||
model_type = "mpt"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = llmodel.llmodel_mpt_create()
|
||||
|
||||
def __del__(self):
|
||||
if self.model is not None and llmodel is not None:
|
||||
llmodel.llmodel_mpt_destroy(self.model)
|
||||
super().__del__()
|
||||
|
Loading…
Reference in New Issue
Block a user