From 3d6cb5ed63daf77c970b36f716f5219cccaef06e Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Wed, 5 Apr 2023 01:21:40 -0300
Subject: [PATCH] Minor rewrite

---
 modules/GPTQ_loader.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py
index 5c947762..abfa33af 100644
--- a/modules/GPTQ_loader.py
+++ b/modules/GPTQ_loader.py
@@ -65,13 +65,11 @@ def load_quantized(model_name):
     else:
         model_type = shared.args.model_type.lower()
 
-    if shared.args.pre_layer:
-        if model_type == 'llama':
-            load_quant = llama_inference_offload.load_quant
-        else:
-            print("Warning: ignoring --pre_layer because it only works for llama model type.")
-            load_quant = _load_quant
+    if shared.args.pre_layer and model_type == 'llama':
+        load_quant = llama_inference_offload.load_quant
     elif model_type in ('llama', 'opt', 'gptj'):
+        if shared.args.pre_layer:
+            print("Warning: ignoring --pre_layer because it only works for llama model type.")
         load_quant = _load_quant
     else:
         print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
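
For reference, here is a minimal standalone sketch of the loader-selection logic as it reads after this patch: `--pre_layer` selects the llama offloading loader only when the model type is actually llama, and is otherwise warned about and ignored. The names `select_load_quant`, `llama_load_quant`, and `generic_load_quant` are hypothetical stand-ins for the real `GPTQ_loader.load_quantized` internals (`llama_inference_offload.load_quant` and `_load_quant`), stubbed so the sketch runs on its own.

```python
from types import SimpleNamespace


def llama_load_quant(*args, **kwargs):
    # Stand-in for llama_inference_offload.load_quant (supports --pre_layer
    # CPU offloading, llama only).
    return "llama offload loader"


def generic_load_quant(*args, **kwargs):
    # Stand-in for GPTQ_loader._load_quant (no offloading support).
    return "generic loader"


def select_load_quant(args):
    """Pick a quantized-model loader the same way the patched code does."""
    model_type = args.model_type.lower()

    if args.pre_layer and model_type == 'llama':
        # --pre_layer with a llama model: use the offloading loader.
        return llama_load_quant
    elif model_type in ('llama', 'opt', 'gptj'):
        if args.pre_layer:
            # --pre_layer with a supported non-llama model: warn, then
            # fall through to the generic loader.
            print("Warning: ignoring --pre_layer because it only works for llama model type.")
        return generic_load_quant
    else:
        raise ValueError("Unknown pre-quantized model type specified. "
                         "Only 'llama', 'opt' and 'gptj' are supported")


# Example: an opt model with --pre_layer set takes the warning path.
args = SimpleNamespace(model_type='OPT', pre_layer=20)
loader = select_load_quant(args)
```

The net effect of the rewrite is purely structural: the same three outcomes (offload loader, generic loader with warning, generic loader, or an error) are preserved, but the `--pre_layer` check no longer nests the model-type dispatch inside it, so the warning-and-fall-through case lives next to the branch that actually handles non-llama models.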