trust_remote_code=shared.args.trust_remote_code (#1891)

This commit is contained in:
camenduru 2023-05-07 23:42:44 +03:00 committed by GitHub
parent b3bbda22d1
commit ba65a48ec8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -34,7 +34,7 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
def noop(*args, **kwargs):
pass
config = AutoConfig.from_pretrained(model)
config = AutoConfig.from_pretrained(model, trust_remote_code=shared.args.trust_remote_code)
torch.nn.init.kaiming_uniform_ = noop
torch.nn.init.uniform_ = noop
torch.nn.init.normal_ = noop
@ -42,7 +42,7 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
torch.set_default_dtype(torch.half)
transformers.modeling_utils._init_weights = False
torch.set_default_dtype(torch.half)
model = AutoModelForCausalLM.from_config(config)
model = AutoModelForCausalLM.from_config(config, trust_remote_code=shared.args.trust_remote_code)
torch.set_default_dtype(torch.float)
if eval:
model = model.eval()