Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2024-10-01 01:06:10 -04:00)

Commit ae42805d49 ("updated bindings code for updated C api")
Parent: f0be66a221
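In short: the Python bindings drop the per-architecture C entry points (llmodel_gptj_create, llmodel_llama_create, llmodel_mpt_create and their destroy counterparts) in favor of the unified llmodel_model_create2 / llmodel_model_destroy API, which selects the implementation from the model file itself. The filename- and type-based dispatch helpers and the GPTJModel/LlamaModel/MPTModel subclasses become unnecessary and are removed.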
gpt4all.py:

@@ -30,20 +30,12 @@ class GPT4All():
             model_name: Name of GPT4All or custom model. Including ".bin" file extension is optional but encouraged.
             model_path: Path to directory containing model file or, if file does not exist, where to download model.
                 Default is None, in which case models will be stored in `~/.cache/gpt4all/`.
-            model_type: Model architecture to use - currently, options are 'llama', 'gptj', or 'mpt'. Only required if model
-                is custom. Note that these models still must be built from llama.cpp or GPTJ ggml architecture.
-                Default is None.
+            model_type: Model architecture. This argument currently does not have any functionality and is just used as
+                descriptive identifier for user. Default is None.
             allow_download: Allow API to download models from gpt4all.io. Default is True.
         """
-        self.model = None
-
-        # Model type provided for when model is custom
-        if model_type:
-            self.model = GPT4All.get_model_from_type(model_type)
-        # Else get model from gpt4all model filenames
-        else:
-            self.model = GPT4All.get_model_from_name(model_name)
+        self.model_type = model_type
+        self.model = pyllmodel.LLModel()
 
         # Retrieve model and download if allowed
         model_dest = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download)
         self.model.load_model(model_dest)
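With this change, model_type no longer affects which backend is constructed; pyllmodel.LLModel() handles every architecture. A minimal usage sketch under the new constructor (assuming the gpt4all Python package is installed; the file is fetched to ~/.cache/gpt4all/ when allow_download is True):

from gpt4all import GPT4All

# model_type is now purely descriptive metadata
model = GPT4All("ggml-gpt4all-j-v1.3-groovy.bin", model_type="gptj")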
@@ -265,61 +257,6 @@ class GPT4All():
 
         return full_prompt
 
-    @staticmethod
-    def get_model_from_type(model_type: str) -> pyllmodel.LLModel:
-        # This needs to be updated for each new model type
-        # TODO: Might be worth converting model_type to enum
-
-        if model_type == "gptj":
-            return pyllmodel.GPTJModel()
-        elif model_type == "llama":
-            return pyllmodel.LlamaModel()
-        elif model_type == "mpt":
-            return pyllmodel.MPTModel()
-        else:
-            raise ValueError(f"No corresponding model for model_type: {model_type}")
-
-    @staticmethod
-    def get_model_from_name(model_name: str) -> pyllmodel.LLModel:
-        # This needs to be updated for each new model
-
-        # NOTE: We are doing this preprocessing a lot, maybe there's a better way to organize
-        model_name = append_bin_suffix_if_missing(model_name)
-
-        GPTJ_MODELS = [
-            "ggml-gpt4all-j-v1.3-groovy.bin",
-            "ggml-gpt4all-j-v1.2-jazzy.bin",
-            "ggml-gpt4all-j-v1.1-breezy.bin",
-            "ggml-gpt4all-j.bin"
-        ]
-
-        LLAMA_MODELS = [
-            "ggml-gpt4all-l13b-snoozy.bin",
-            "ggml-vicuna-7b-1.1-q4_2.bin",
-            "ggml-vicuna-13b-1.1-q4_2.bin",
-            "ggml-wizardLM-7B.q4_2.bin",
-            "ggml-stable-vicuna-13B.q4_2.bin",
-            "ggml-nous-gpt4-vicuna-13b.bin"
-        ]
-
-        MPT_MODELS = [
-            "ggml-mpt-7b-base.bin",
-            "ggml-mpt-7b-chat.bin",
-            "ggml-mpt-7b-instruct.bin"
-        ]
-
-        if model_name in GPTJ_MODELS:
-            return pyllmodel.GPTJModel()
-        elif model_name in LLAMA_MODELS:
-            return pyllmodel.LlamaModel()
-        elif model_name in MPT_MODELS:
-            return pyllmodel.MPTModel()
-
-        err_msg = (f"No corresponding model for provided filename {model_name}.\n"
-                   f"If this is a custom model, make sure to specify a valid model_type.\n")
-        raise ValueError(err_msg)
-
-
 def append_bin_suffix_if_missing(model_name):
     if not model_name.endswith(".bin"):
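With both dispatch helpers gone, the bindings no longer need a hand-maintained filename-to-architecture table for each new model family. The equivalent of all the deleted branches is now the single generic path from the first hunk, sketched here with names taken from this diff:

model = pyllmodel.LLModel()
model.load_model(model_dest)  # the C backend infers the architecture from the file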
pyllmodel.py:

@@ -54,19 +54,9 @@ def load_llmodel_library():
 
 llmodel, llama = load_llmodel_library()
 
-# Define C function signatures using ctypes
-llmodel.llmodel_gptj_create.restype = ctypes.c_void_p
-llmodel.llmodel_gptj_destroy.argtypes = [ctypes.c_void_p]
-llmodel.llmodel_llama_create.restype = ctypes.c_void_p
-llmodel.llmodel_llama_destroy.argtypes = [ctypes.c_void_p]
-llmodel.llmodel_mpt_create.restype = ctypes.c_void_p
-llmodel.llmodel_mpt_destroy.argtypes = [ctypes.c_void_p]
-
-llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
-llmodel.llmodel_loadModel.restype = ctypes.c_bool
-llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p]
-llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool
+class LLModelError(ctypes.Structure):
+    _fields_ = [("message", ctypes.c_char_p),
+                ("code", ctypes.c_int32)]
 
 class LLModelPromptContext(ctypes.Structure):
     _fields_ = [("logits", ctypes.POINTER(ctypes.c_float)),
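llmodel_model_create2, declared in the next hunk, takes a pointer to this LLModelError struct as its third argument. The bindings in this commit pass None there, but a caller that wants the C-side diagnostics could plausibly do something like the following sketch (the error-checking flow is an assumption, not shown in the commit):

err = LLModelError()
model = llmodel.llmodel_model_create2(model_path_enc, b"auto", ctypes.byref(err))
if not model:
    # hypothetical: surface the backend's message and code
    raise ValueError(f"model creation failed ({err.code}): {err.message.decode('utf-8')}")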
@@ -84,6 +74,16 @@ class LLModelPromptContext(ctypes.Structure):
                 ("repeat_last_n", ctypes.c_int32),
                 ("context_erase", ctypes.c_float)]
 
+# Define C function signatures using ctypes
+
+llmodel.llmodel_model_create2.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.POINTER(LLModelError)]
+llmodel.llmodel_model_create2.restype = ctypes.c_void_p
+
+llmodel.llmodel_loadModel.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
+llmodel.llmodel_loadModel.restype = ctypes.c_bool
+llmodel.llmodel_isModelLoaded.argtypes = [ctypes.c_void_p]
+llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool
+
 PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32)
 ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p)
 RecalculateCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_bool)
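Declaring argtypes and restype explicitly is what makes these calls safe: ctypes assumes a C int return by default, which can truncate 64-bit pointers such as the handle returned by llmodel_model_create2. The CFUNCTYPE types wrap Python callables for the C side; a sketch of a streaming handler built on the ResponseCallback type above (the False-stops-generation behavior is my reading of the API, not stated in this diff):

def _on_response(token_id, response):
    # response arrives as bytes from the C library
    print(response.decode("utf-8"), end="", flush=True)
    return True  # returning False is expected to stop generation

response_cb = ResponseCallback(_on_response)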
@@ -113,18 +113,17 @@ class LLModel:
     ----------
     model: llmodel_model
         Ctype pointer to underlying model
-    model_type : str
-        Model architecture identifier
+    model_name: str
+        Model name.
     """
 
-    model_type: str = None
-
     def __init__(self):
         self.model = None
         self.model_name = None
 
     def __del__(self):
-        pass
+        if self.model is not None and llmodel is not None:
+            llmodel.llmodel_model_destroy(self.model)
 
     def load_model(self, model_path: str) -> bool:
         """
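The destructor now frees the C handle through llmodel_model_destroy. This hunk does not declare that function's ctypes signature, so either it is declared elsewhere in the module or the ctypes defaults are being relied on; one would expect a declaration along these lines (an assumption, not part of the commit):

llmodel.llmodel_model_destroy.argtypes = [ctypes.c_void_p]  # assumed signature
llmodel.llmodel_model_destroy.restype = None

The `llmodel is not None` guard matters during interpreter shutdown, when module globals may already have been cleared by the time __del__ runs.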
@@ -139,7 +138,10 @@ class LLModel:
         -------
         True if model loaded successfully, False otherwise
         """
-        llmodel.llmodel_loadModel(self.model, model_path.encode('utf-8'))
+        model_path_enc = model_path.encode("utf-8")
+        build_var = "auto".encode("utf-8")
+        self.model = llmodel.llmodel_model_create2(model_path_enc, build_var, None)
+        llmodel.llmodel_loadModel(self.model, model_path_enc)
         filename = os.path.basename(model_path)
         self.model_name = os.path.splitext(filename)[0]
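load_model now both creates and loads the model in one step, but the return value of llmodel_model_create2 is used without being checked. A slightly more defensive sketch of the same body (same names as the diff; the failure checks are assumptions):

model_path_enc = model_path.encode("utf-8")
self.model = llmodel.llmodel_model_create2(model_path_enc, b"auto", None)
if not self.model:
    raise ValueError(f"Unable to create model from {model_path}")
if not llmodel.llmodel_loadModel(self.model, model_path_enc):
    raise ValueError(f"Unable to load weights from {model_path}")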
@@ -148,7 +150,6 @@ class LLModel:
         else:
             return False
 
-
     def set_thread_count(self, n_threads):
         if not llmodel.llmodel_isModelLoaded(self.model):
             raise Exception("Model not loaded")

@@ -159,7 +160,6 @@ class LLModel:
             raise Exception("Model not loaded")
         return llmodel.llmodel_threadCount(self.model)
 
-
     def generate(self,
                  prompt: str,
                  logits_size: int = 0,
@@ -246,45 +246,3 @@ class LLModel:
     @staticmethod
     def _recalculate_callback(is_recalculating):
         return is_recalculating
-
-
-class GPTJModel(LLModel):
-
-    model_type = "gptj"
-
-    def __init__(self):
-        super().__init__()
-        self.model = llmodel.llmodel_gptj_create()
-
-    def __del__(self):
-        if self.model is not None and llmodel is not None:
-            llmodel.llmodel_gptj_destroy(self.model)
-        super().__del__()
-
-
-class LlamaModel(LLModel):
-
-    model_type = "llama"
-
-    def __init__(self):
-        super().__init__()
-        self.model = llmodel.llmodel_llama_create()
-
-    def __del__(self):
-        if self.model is not None and llmodel is not None:
-            llmodel.llmodel_llama_destroy(self.model)
-        super().__del__()
-
-
-class MPTModel(LLModel):
-
-    model_type = "mpt"
-
-    def __init__(self):
-        super().__init__()
-        self.model = llmodel.llmodel_mpt_create()
-
-    def __del__(self):
-        if self.model is not None and llmodel is not None:
-            llmodel.llmodel_mpt_destroy(self.model)
-        super().__del__()
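These subclass removals are the counterpart of the LLModel.__del__ change earlier: with one generic handle type, a single llmodel_model_destroy replaces the three per-architecture destroy functions, and the matching llmodel_*_create declarations deleted in the pyllmodel.py hunk above lose their last callers.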