trust_remote_code=shared.args.trust_remote_code (#1891)

This commit is contained in:
camenduru 2023-05-07 23:42:44 +03:00 committed by GitHub
parent b3bbda22d1
commit ba65a48ec8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -34,7 +34,7 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
def noop(*args, **kwargs): def noop(*args, **kwargs):
pass pass
config = AutoConfig.from_pretrained(model) config = AutoConfig.from_pretrained(model, trust_remote_code=shared.args.trust_remote_code)
torch.nn.init.kaiming_uniform_ = noop torch.nn.init.kaiming_uniform_ = noop
torch.nn.init.uniform_ = noop torch.nn.init.uniform_ = noop
torch.nn.init.normal_ = noop torch.nn.init.normal_ = noop
@ -42,7 +42,7 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
torch.set_default_dtype(torch.half) torch.set_default_dtype(torch.half)
transformers.modeling_utils._init_weights = False transformers.modeling_utils._init_weights = False
torch.set_default_dtype(torch.half) torch.set_default_dtype(torch.half)
model = AutoModelForCausalLM.from_config(config) model = AutoModelForCausalLM.from_config(config, trust_remote_code=shared.args.trust_remote_code)
torch.set_default_dtype(torch.float) torch.set_default_dtype(torch.float)
if eval: if eval:
model = model.eval() model = model.eval()