Remove duplicate code

oobabooga 2023-05-10 02:07:22 -03:00
parent cd36b8f739
commit dfd9ba3e90
2 changed files with 20 additions and 22 deletions

models/config.yaml

@@ -6,11 +6,12 @@
   mode: 'chat'
   skip_special_tokens: true
   custom_stopping_strings: ''
-.*llama:
+.*(llama|alpac|vicuna|guanaco|koala|llava|wizardlm|metharme|pygmalion-7b):
   model_type: 'llama'
-.*gptq(?!u|arl|v2):
-  wbits: 4
-  groupsize: 128
+.*(opt-|opt_|opt1|opt3|optfor|galactica|galpaca|pygmalion-350m):
+  model_type: 'opt'
+.*(gpt-j|gptj|gpt4all-j|malion-6b|pygway|pygmalion-6b):
+  model_type: 'gptj'
 .*(4bit|int4):
   wbits: 4
 .*(3bit|int3):
@@ -27,8 +28,6 @@
   wbits: 6
 .*(-5bit|_5bit|int5-):
   wbits: 5
-.*gptqv2:
-  groupsize: 'None'
 .*(-gr32-|-32g-|groupsize32):
   groupsize: 32
 .*(-gr64-|-64g-|groupsize64):
@@ -37,6 +36,11 @@
   groupsize: 128
 .*(gr1024|1024g|groupsize1024):
   groupsize: 1024
+.*gptq(?!u|arl|v2):
+  wbits: 4
+  groupsize: 128
+.*gptqv2:
+  groupsize: 'None'
 .*(oasst|stablelm-7b-sft-v7-epoch-3):
   mode: 'instruct'
   instruction_template: 'Open Assistant'
@@ -131,8 +135,4 @@
   instruction_template: 'INCITE-Chat'
 .*incite.*instruct:
   mode: 'instruct'
   instruction_template: 'INCITE-Instruct'
-.*pygmalion-7b:
-  model_type: 'llama'
-.*metharme-7b:
-  model_type: 'llama'
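
The entries in this file map regular expressions over model names to per-model settings, and this commit folds the opt/gptj/llama keyword lists that previously lived in the GPTQ loader into those patterns. As a rough illustration of how such a file can be consumed, the sketch below resolves settings for a model name, assuming every matching pattern contributes its key/value pairs in file order; load_model_settings and the hard-coded path are illustrative names only, not the project's actual API (the real lookup is get_model_specific_settings, imported in the second file below).

# Minimal sketch (illustrative only): resolve settings for a model name from a
# YAML file whose top-level keys are regular expressions, as in the diff above.
# load_model_settings and the default path are assumptions, not the project's API.
import re
import yaml

def load_model_settings(model_name, config_path='models/config.yaml'):
    with open(config_path) as f:
        patterns = yaml.safe_load(f)

    settings = {}
    for pattern, values in patterns.items():
        # Each top-level key is a regex; every matching entry contributes its
        # settings, applied in the order they appear in the file.
        if re.match(pattern.lower(), model_name.lower()):
            settings.update(values)
    return settings

# Example: a name containing "vicuna" and "4bit" would pick up
# model_type 'llama' and wbits 4 from the entries shown above.
print(load_model_settings('vicuna-13b-4bit-128g'))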

modules/GPTQ_loader.py

@@ -10,6 +10,7 @@ import transformers
 from transformers import AutoConfig, AutoModelForCausalLM
 import modules.shared as shared
+from server import get_model_specific_settings
 sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
@@ -53,6 +54,7 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     torch.set_default_dtype(torch.float)
     if eval:
         model = model.eval()
     layers = find_layers(model)
     for name in exclude_layers:
         if name in layers:
@@ -78,7 +80,6 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     quant.make_quant_linear(model, layers, wbits, groupsize)
     del layers
     if checkpoint.endswith('.safetensors'):
         from safetensors.torch import load_file as safe_load
         model.load_state_dict(safe_load(checkpoint), strict=False)
@@ -88,6 +89,7 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     if is_triton:
         if shared.args.quant_attn:
             quant.make_quant_attn(model)
         if eval and shared.args.fused_mlp:
             quant.make_fused_mlp(model)
@@ -141,19 +143,15 @@ def find_quantized_model_file(model_name):
 # The function that loads the model in modules/models.py
 def load_quantized(model_name):
     # Find the model type
     if not shared.args.model_type:
-        name = model_name.lower()
-        if any((k in name for k in ['opt-', 'opt_', 'opt1', 'opt3', 'optfor', 'galactica', 'galpaca', 'pygmalion-350m'])):
-            model_type = 'opt'
-        elif any((k in name for k in ['gpt-j', 'gptj', 'gpt4all-j', 'malion-6b', 'pygway', 'pygmalion-6b'])):
-            model_type = 'gptj'
-        elif any((k in name for k in ['llama', 'alpac', 'vicuna', 'guanaco', 'koala', 'llava', 'wizardlm', 'metharme'])):
-            model_type = 'llama'
+        settings = get_model_specific_settings(model_name)
+        if 'model_type' in settings and settings['model_type'] != 'None':
+            model_type = settings['model_type']
         else:
-            logging.error("Can't determine model type from model name. Please specify it manually using --model_type argument")
-            exit()
+            logging.error("The model could not be loaded because its type could not be inferred from its name.")
+            logging.error("Please specify the type manually using the --model_type argument.")
+            return
     else:
         model_type = shared.args.model_type.lower()
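
Taken together, the loader no longer maintains its own keyword lists: it asks get_model_specific_settings for a model_type derived from the YAML patterns above and otherwise falls back to the --model_type flag, and an unknown model now logs an error and returns instead of calling exit(). A condensed sketch of that resolution order follows; resolve_model_type and the cli_model_type parameter are illustrative names, not part of the diff.

# Illustrative condensation of the new model-type resolution (not a verbatim copy of the diff).
import logging

from server import get_model_specific_settings  # the import added in this commit

def resolve_model_type(model_name, cli_model_type=None):
    # An explicit --model_type always wins.
    if cli_model_type:
        return cli_model_type.lower()

    # Otherwise derive the type from the regex entries in the settings file.
    settings = get_model_specific_settings(model_name)
    if 'model_type' in settings and settings['model_type'] != 'None':
        return settings['model_type']

    # Unknown model: report and let the caller decide, rather than exiting the process.
    logging.error("The model could not be loaded because its type could not be inferred from its name.")
    logging.error("Please specify the type manually using the --model_type argument.")
    return None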