From dfd9ba3e90a8f9306ade30a0076b7f5841cf0df9 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Wed, 10 May 2023 02:07:22 -0300
Subject: [PATCH] Remove duplicate code

---
 models/config.yaml     | 22 +++++++++++-----------
 modules/GPTQ_loader.py | 20 +++++++++----------
 2 files changed, 20 insertions(+), 22 deletions(-)

diff --git a/models/config.yaml b/models/config.yaml
index 806206f2..f5c9d508 100644
--- a/models/config.yaml
+++ b/models/config.yaml
@@ -6,11 +6,12 @@
   mode: 'chat'
   skip_special_tokens: true
   custom_stopping_strings: ''
-.*llama:
+.*(llama|alpac|vicuna|guanaco|koala|llava|wizardlm|metharme|pygmalion-7b):
   model_type: 'llama'
-.*gptq(?!u|arl|v2):
-  wbits: 4
-  groupsize: 128
+.*(opt-|opt_|opt1|opt3|optfor|galactica|galpaca|pygmalion-350m):
+  model_type: 'opt'
+.*(gpt-j|gptj|gpt4all-j|malion-6b|pygway|pygmalion-6b):
+  model_type: 'gptj'
 .*(4bit|int4):
   wbits: 4
 .*(3bit|int3):
@@ -27,8 +28,6 @@
   wbits: 6
 .*(-5bit|_5bit|int5-):
   wbits: 5
-.*gptqv2:
-  groupsize: 'None'
 .*(-gr32-|-32g-|groupsize32):
   groupsize: 32
 .*(-gr64-|-64g-|groupsize64):
@@ -37,6 +36,11 @@
   groupsize: 128
 .*(gr1024|1024g|groupsize1024):
   groupsize: 1024
+.*gptq(?!u|arl|v2):
+  wbits: 4
+  groupsize: 128
+.*gptqv2:
+  groupsize: 'None'
 .*(oasst|stablelm-7b-sft-v7-epoch-3):
   mode: 'instruct'
   instruction_template: 'Open Assistant'
@@ -131,8 +135,4 @@
   instruction_template: 'INCITE-Chat'
 .*incite.*instruct:
   mode: 'instruct'
-  instruction_template: 'INCITE-Instruct'
-.*pygmalion-7b:
-  model_type: 'llama'
-.*metharme-7b:
-  model_type: 'llama'
\ No newline at end of file
+  instruction_template: 'INCITE-Instruct'
\ No newline at end of file
diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py
index 5faec390..32381eff 100644
--- a/modules/GPTQ_loader.py
+++ b/modules/GPTQ_loader.py
@@ -10,6 +10,7 @@ import transformers
 from transformers import AutoConfig, AutoModelForCausalLM
 
 import modules.shared as shared
+from server import get_model_specific_settings
 
 sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
 
@@ -53,6 +54,7 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     torch.set_default_dtype(torch.float)
     if eval:
         model = model.eval()
+
     layers = find_layers(model)
     for name in exclude_layers:
         if name in layers:
@@ -78,7 +80,6 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
         quant.make_quant_linear(model, layers, wbits, groupsize)
 
     del layers
-
     if checkpoint.endswith('.safetensors'):
         from safetensors.torch import load_file as safe_load
         model.load_state_dict(safe_load(checkpoint), strict=False)
@@ -88,6 +89,7 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     if is_triton:
         if shared.args.quant_attn:
             quant.make_quant_attn(model)
+
         if eval and shared.args.fused_mlp:
             quant.make_fused_mlp(model)
 
@@ -141,19 +143,15 @@ def find_quantized_model_file(model_name):
 # The function that loads the model in modules/models.py
 def load_quantized(model_name):
 
-    # Find the model type
     if not shared.args.model_type:
-        name = model_name.lower()
-        if any((k in name for k in ['opt-', 'opt_', 'opt1', 'opt3', 'optfor', 'galactica', 'galpaca', 'pygmalion-350m'])):
-            model_type = 'opt'
-        elif any((k in name for k in ['gpt-j', 'gptj', 'gpt4all-j', 'malion-6b', 'pygway', 'pygmalion-6b'])):
-            model_type = 'gptj'
-        elif any((k in name for k in ['llama', 'alpac', 'vicuna', 'guanaco', 'koala', 'llava', 'wizardlm', 'metharme'])):
-            model_type = 'llama'
+        settings = get_model_specific_settings(model_name)
+        if 'model_type' in settings and settings['model_type'] != 'None':
+            model_type = settings['model_type']
        else:
-            logging.error("Can't determine model type from model name. Please specify it manually using --model_type argument")
-            exit()
+            logging.error("The model could not be loaded because its type could not be inferred from its name.")
+            logging.error("Please specify the type manually using the --model_type argument.")
+            return
     else:
         model_type = shared.args.model_type.lower()
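
For context, the hardcoded model-name lists removed from load_quantized() are replaced by a call to get_model_specific_settings(), which resolves settings from the regex keys in models/config.yaml. The standalone function below is a minimal sketch of that kind of lookup, not the patch's actual implementation (the real helper lives in server.py); it assumes PyYAML and the config layout shown in the diff above.

import re

import yaml  # assumption: PyYAML is installed


def get_model_specific_settings(model_name, config_path='models/config.yaml'):
    """Return merged settings from all config.yaml entries whose regex key matches the model name.

    Illustrative re-implementation only; behavior is inferred from the config layout.
    """
    with open(config_path) as f:
        config = yaml.safe_load(f)

    settings = {}
    for pattern, values in config.items():
        # Keys such as '.*(llama|alpac|vicuna|...)' act as case-insensitive
        # regexes; entries later in the file override earlier matches.
        if re.match(pattern.lower(), model_name.lower()):
            settings.update(values)
    return settings


if __name__ == '__main__':
    # Hypothetical model folder name; prints something like
    # {'model_type': 'llama', 'wbits': 4, 'groupsize': 128, ...}
    print(get_model_specific_settings('vicuna-7B-GPTQ-4bit'))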