diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py
index 3b062ea3..5c947762 100644
--- a/modules/GPTQ_loader.py
+++ b/modules/GPTQ_loader.py
@@ -65,8 +65,12 @@ def load_quantized(model_name):
     else:
         model_type = shared.args.model_type.lower()
 
-    if model_type == 'llama' and shared.args.pre_layer:
-        load_quant = llama_inference_offload.load_quant
+    if shared.args.pre_layer:
+        if model_type == 'llama':
+            load_quant = llama_inference_offload.load_quant
+        else:
+            print("Warning: ignoring --pre_layer because it only works for llama model type.")
+            load_quant = _load_quant
     elif model_type in ('llama', 'opt', 'gptj'):
         load_quant = _load_quant
     else:
@@ -107,7 +111,7 @@ def load_quantized(model_name):
         exit()
 
     # qwopqwop200's offload
-    if shared.args.pre_layer:
+    if model_type == 'llama' and shared.args.pre_layer:
         model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer)
     else:
         threshold = False if model_type == 'gptj' else 128
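
For reviewers, a minimal standalone sketch of the dispatch behavior this patch introduces: a non-llama model with --pre_layer now warns and falls back to the generic loader instead of erroring, and the offload loader's extra pre_layer argument is only passed on the llama path. The Args dataclass, pick_loader, and the stub loaders here are hypothetical stand-ins for shared.args, llama_inference_offload.load_quant, and _load_quant; the real logic lives in modules/GPTQ_loader.py.

from dataclasses import dataclass


@dataclass
class Args:
    model_type: str
    pre_layer: int = 0  # 0 means --pre_layer was not passed


def llama_offload_load_quant():
    # Stand-in for llama_inference_offload.load_quant (CPU offloading path).
    return "llama offload loader"


def generic_load_quant():
    # Stand-in for _load_quant (no pre_layer support).
    return "generic loader"


def pick_loader(args: Args):
    """Choose a quantized-model loader, ignoring --pre_layer for non-llama types."""
    model_type = args.model_type.lower()
    if args.pre_layer:
        if model_type == 'llama':
            return llama_offload_load_quant
        # Fall back with a warning instead of failing, mirroring the patch.
        print("Warning: ignoring --pre_layer because it only works for llama model type.")
        return generic_load_quant
    if model_type in ('llama', 'opt', 'gptj'):
        return generic_load_quant
    raise ValueError(f"Unknown model type: {model_type}")


if __name__ == "__main__":
    print(pick_loader(Args(model_type='opt', pre_layer=20))())    # warns, generic loader
    print(pick_loader(Args(model_type='llama', pre_layer=20))())  # offload loader

Note that the second hunk is what keeps the call sites consistent: since a non-llama model now reaches the generic loader even when --pre_layer is set, the pre_layer-specific call signature must be gated on model_type == 'llama' as well, not on the flag alone.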