Make --trust-remote-code work for all models (#1772)

This commit is contained in:
Mylo 2023-05-04 07:01:28 +02:00 committed by GitHub
parent 0e6d17304a
commit bd531c2dc2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -57,7 +57,7 @@ def find_model_type(model_name):
elif any((k in model_name_lower for k in ['gpt4chan', 'gpt-4chan'])): elif any((k in model_name_lower for k in ['gpt4chan', 'gpt-4chan'])):
return 'gpt4chan' return 'gpt4chan'
else: else:
config = AutoConfig.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}')) config = AutoConfig.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), trust_remote_code=shared.args.trust_remote_code)
# Not a "catch all", but fairly accurate # Not a "catch all", but fairly accurate
if config.to_dict().get("is_encoder_decoder", False): if config.to_dict().get("is_encoder_decoder", False):
return 'HF_seq2seq' return 'HF_seq2seq'
@ -70,15 +70,13 @@ def load_model(model_name):
t0 = time.time() t0 = time.time()
shared.model_type = find_model_type(model_name) shared.model_type = find_model_type(model_name)
trust_remote_code = shared.args.trust_remote_code
if shared.model_type == 'chatglm': if shared.model_type == 'chatglm':
LoaderClass = AutoModel LoaderClass = AutoModel
trust_remote_code = shared.args.trust_remote_code
elif shared.model_type == 'HF_seq2seq': elif shared.model_type == 'HF_seq2seq':
LoaderClass = AutoModelForSeq2SeqLM LoaderClass = AutoModelForSeq2SeqLM
trust_remote_code = False
else: else:
LoaderClass = AutoModelForCausalLM LoaderClass = AutoModelForCausalLM
trust_remote_code = False
# Load the model in simple 16-bit mode by default # Load the model in simple 16-bit mode by default
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.model_type in ['rwkv', 'llamacpp']]): if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.model_type in ['rwkv', 'llamacpp']]):