mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
Add various checks to model loading functions
This commit is contained in:
parent
abd361b3a0
commit
ef10ffc6b4
@ -13,19 +13,18 @@ def load_quantized(model_name):
|
||||
use_safetensors = False
|
||||
|
||||
# Find the model checkpoint
|
||||
found_pts = list(path_to_model.glob("*.pt"))
|
||||
found_safetensors = list(path_to_model.glob("*.safetensors"))
|
||||
if len(found_safetensors) > 0:
|
||||
if len(found_safetensors) > 1:
|
||||
logging.warning('More than one .safetensors model has been found. The last one will be selected. It could be wrong.')
|
||||
for ext in ['.safetensors', '.pt', '.bin']:
|
||||
found = list(path_to_model.glob(f"*{ext}"))
|
||||
if len(found) > 0:
|
||||
if len(found) > 1:
|
||||
logging.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
|
||||
|
||||
use_safetensors = True
|
||||
pt_path = found_safetensors[-1]
|
||||
elif len(found_pts) > 0:
|
||||
if len(found_pts) > 1:
|
||||
logging.warning('More than one .pt model has been found. The last one will be selected. It could be wrong.')
|
||||
pt_path = found[-1]
|
||||
break
|
||||
|
||||
pt_path = found_pts[-1]
|
||||
if pt_path is None:
|
||||
logging.error("The model could not be loaded because its checkpoint file in .bin/.pt/.safetensors format could not be located.")
|
||||
return
|
||||
|
||||
# Define the params for AutoGPTQForCausalLM.from_quantized
|
||||
params = {
|
||||
|
@ -40,10 +40,14 @@ if shared.args.deepspeed:
|
||||
# Some models require special treatment in various parts of the code.
|
||||
# This function detects those models
|
||||
def find_model_type(model_name):
|
||||
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
|
||||
if not path_to_model.exists():
|
||||
return 'None'
|
||||
|
||||
model_name_lower = model_name.lower()
|
||||
if 'rwkv-' in model_name_lower:
|
||||
return 'rwkv'
|
||||
elif len(list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))) > 0:
|
||||
elif len(list(path_to_model.glob('*ggml*.bin'))) > 0:
|
||||
return 'llamacpp'
|
||||
elif re.match('.*ggml.*\.bin', model_name_lower):
|
||||
return 'llamacpp'
|
||||
@ -58,7 +62,7 @@ def find_model_type(model_name):
|
||||
elif any((k in model_name_lower for k in ['gpt4chan', 'gpt-4chan'])):
|
||||
return 'gpt4chan'
|
||||
else:
|
||||
config = AutoConfig.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), trust_remote_code=shared.args.trust_remote_code)
|
||||
config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
|
||||
# Not a "catch all", but fairly accurate
|
||||
if config.to_dict().get("is_encoder_decoder", False):
|
||||
return 'HF_seq2seq'
|
||||
@ -71,11 +75,14 @@ def load_model(model_name):
|
||||
t0 = time.time()
|
||||
|
||||
shared.model_type = find_model_type(model_name)
|
||||
if shared.args.wbits > 0 or shared.args.autogptq:
|
||||
if shared.args.autogptq:
|
||||
load_func = AutoGPTQ_loader
|
||||
else:
|
||||
load_func = GPTQ_loader
|
||||
if shared.model_type == 'None':
|
||||
logging.error('The path to the model does not exist. Exiting.')
|
||||
return None, None
|
||||
|
||||
if shared.args.autogptq:
|
||||
load_func = AutoGPTQ_loader
|
||||
elif shared.args.wbits > 0:
|
||||
load_func = GPTQ_loader
|
||||
elif shared.model_type == 'llamacpp':
|
||||
load_func = llamacpp_loader
|
||||
elif shared.model_type == 'rwkv':
|
||||
@ -101,6 +108,7 @@ def load_model(model_name):
|
||||
|
||||
|
||||
def load_tokenizer(model_name, model):
|
||||
tokenizer = None
|
||||
if shared.model_type == 'gpt4chan' and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
|
||||
elif type(model) is transformers.LlamaForCausalLM:
|
||||
@ -122,7 +130,9 @@ def load_tokenizer(model_name, model):
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), trust_remote_code=shared.args.trust_remote_code)
|
||||
path_to_model = Path(f"{shared.args.model_dir}/{model_name}/")
|
||||
if path_to_model.exists():
|
||||
tokenizer = AutoTokenizer.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user