Generalized load_quantized

Maya Eary 2023-03-28 20:38:55 +03:00
parent 46f6536fae
commit c8207d474f


@@ -4,13 +4,48 @@ from pathlib import Path
import accelerate
import torch
import transformers
from transformers import AutoConfig, AutoModelForCausalLM
import modules.shared as shared
sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
import llama
import llama_inference_offload
import opt
from quant import make_quant
from modelutils import find_layers
def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head']):
    config = AutoConfig.from_pretrained(model)

    def noop(*args, **kwargs):
        pass

    torch.nn.init.kaiming_uniform_ = noop
    torch.nn.init.uniform_ = noop
    torch.nn.init.normal_ = noop

    torch.set_default_dtype(torch.half)
    transformers.modeling_utils._init_weights = False
    torch.set_default_dtype(torch.half)
    model = AutoModelForCausalLM.from_config(config)
    torch.set_default_dtype(torch.float)
    model = model.eval()
    layers = find_layers(model)
    for name in exclude_layers:
        if name in layers:
            del layers[name]
    make_quant(model, layers, wbits, groupsize, faster=faster_kernel)

    del layers

    print('Loading model ...')
    if checkpoint.endswith('.safetensors'):
        from safetensors.torch import load_file as safe_load
        model.load_state_dict(safe_load(checkpoint))
    else:
        model.load_state_dict(torch.load(checkpoint))
    model.seqlen = 2048
    print('Done.')

    return model
def load_quantized(model_name):
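
The new _load_quant helper replaces the separate llama.load_quant and opt.load_quant entry points with one architecture-agnostic loader. For reference, a minimal sketch of calling it directly; the module path, model directory and checkpoint filename are assumptions made for this example, not part of the commit:

# Hypothetical direct use of _load_quant (module path, model directory and checkpoint name are assumed)
from modules.GPTQ_loader import _load_quant

quant_model = _load_quant(
    model="models/llama-7b",                        # directory containing the HF config
    checkpoint="models/llama-7b-4bit.safetensors",  # pre-quantized GPTQ checkpoint
    wbits=4,
    groupsize=128,
)
quant_model = quant_model.to("cuda")
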
@@ -20,6 +55,8 @@ def load_quantized(model_name):
            model_type = 'llama'
        elif model_name.lower().startswith(('opt', 'galactica')):
            model_type = 'opt'
        elif model_name.lower().startswith(('gpt-j', 'pygmalion-6b')):
            model_type = 'gptj'
        else:
            print("Can't determine model type from model name. Please specify it manually using --model_type "
                  "argument")
@@ -27,15 +64,12 @@ def load_quantized(model_name):
    else:
        model_type = shared.args.model_type.lower()

    if model_type == 'llama':
        if not shared.args.pre_layer:
            load_quant = llama.load_quant
        else:
            load_quant = llama_inference_offload.load_quant
    elif model_type == 'opt':
        load_quant = opt.load_quant
    if model_type == 'llama' and shared.args.pre_layer:
        load_quant = llama_inference_offload.load_quant
    elif model_type in ('llama', 'opt', 'gptj'):
        load_quant = _load_quant
    else:
        print("Unknown pre-quantized model type specified. Only 'llama' and 'opt' are supported")
        print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
        exit()

    # Now we are going to try to locate the quantized model file.
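
Further down in load_quantized (outside the hunks shown above), the selected load_quant callable is invoked once the checkpoint file has been located. A sketch of that call, with the paths and shared.args attributes assumed for illustration; it also suggests why _load_quant keeps the same (model, checkpoint, wbits, groupsize) argument order as the loaders it replaces:

# Sketch of the dispatched call (paths and argument names are assumptions)
path_to_model = Path(f"models/{model_name}")
pt_path = path_to_model / f"{model_name}-4bit.safetensors"

if model_type == 'llama' and shared.args.pre_layer:
    model = load_quant(str(path_to_model), str(pt_path),
                       shared.args.wbits, shared.args.groupsize, shared.args.pre_layer)
else:
    model = load_quant(str(path_to_model), str(pt_path),
                       shared.args.wbits, shared.args.groupsize)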