diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py
index 5c947762..abfa33af 100644
--- a/modules/GPTQ_loader.py
+++ b/modules/GPTQ_loader.py
@@ -65,13 +65,11 @@ def load_quantized(model_name):
     else:
         model_type = shared.args.model_type.lower()
 
-    if shared.args.pre_layer:
-        if model_type == 'llama':
-            load_quant = llama_inference_offload.load_quant
-        else:
-            print("Warning: ignoring --pre_layer because it only works for llama model type.")
-            load_quant = _load_quant
+    if shared.args.pre_layer and model_type == 'llama':
+        load_quant = llama_inference_offload.load_quant
     elif model_type in ('llama', 'opt', 'gptj'):
+        if shared.args.pre_layer:
+            print("Warning: ignoring --pre_layer because it only works for llama model type.")
         load_quant = _load_quant
     else:
         print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
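
For context, a minimal standalone sketch of the dispatch after this patch follows. pick_loader, pre_layer, and model_type are hypothetical stand-ins for shared.args.pre_layer, the lowercased model type, and the module-level loaders; the returned strings merely label which loader the real code would assign to load_quant. The notable behavior change is that an unknown model type combined with --pre_layer now reaches the error branch instead of silently falling back to _load_quant.

# Sketch of the post-patch loader selection in load_quantized().
# pick_loader() is a hypothetical stand-in; the real code assigns
# llama_inference_offload.load_quant or _load_quant to load_quant.
def pick_loader(pre_layer, model_type):
    if pre_layer and model_type == 'llama':
        # Layer offloading is only implemented for llama models.
        return 'llama_inference_offload.load_quant'
    elif model_type in ('llama', 'opt', 'gptj'):
        if pre_layer:
            # --pre_layer was set but is unusable here; warn, then use the generic loader.
            print("Warning: ignoring --pre_layer because it only works for llama model type.")
        return '_load_quant'
    else:
        # Before the patch, --pre_layer with an unknown model type skipped this check.
        print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
        return None

assert pick_loader(True, 'opt') == '_load_quant'  # warns, then falls back
assert pick_loader(True, 'foo') is None           # now an error, not a silent fallback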