mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
Automatically set bf16 & use_eager_attention for Gemma-2
This commit is contained in:
parent
8074fba18d
commit
907137a13d
@ -9,6 +9,8 @@ from modules import chat, loaders, metadata_gguf, shared, ui
|
||||
|
||||
def get_fallback_settings():
|
||||
return {
|
||||
'bf16': False,
|
||||
'use_eager_attention': False,
|
||||
'wbits': 'None',
|
||||
'groupsize': 'None',
|
||||
'desc_act': False,
|
||||
@ -97,10 +99,18 @@ def get_model_metadata(model):
|
||||
elif 'attn_config' in metadata and 'rope_theta' in metadata['attn_config']:
|
||||
model_settings['rope_freq_base'] = metadata['attn_config']['rope_theta']
|
||||
|
||||
if 'rope_scaling' in metadata and type(metadata['rope_scaling']) is dict and all(key in metadata['rope_scaling'] for key in ('type', 'factor')):
|
||||
if 'rope_scaling' in metadata and isinstance(metadata['rope_scaling'], dict) and all(key in metadata['rope_scaling'] for key in ('type', 'factor')):
|
||||
if metadata['rope_scaling']['type'] == 'linear':
|
||||
model_settings['compress_pos_emb'] = metadata['rope_scaling']['factor']
|
||||
|
||||
# For Gemma-2
|
||||
if 'torch_dtype' in metadata and metadata['torch_dtype'] == 'bfloat16':
|
||||
model_settings['bf16'] = True
|
||||
|
||||
# For Gemma-2
|
||||
if 'architectures' in metadata and isinstance(metadata['architectures'], list) and 'Gemma2ForCausalLM' in metadata['architectures']:
|
||||
model_settings['use_eager_attention'] = True
|
||||
|
||||
# Read GPTQ metadata for old GPTQ loaders
|
||||
if 'quantization_config' in metadata and metadata['quantization_config'].get('quant_method', '') != 'exl2':
|
||||
if 'bits' in metadata['quantization_config']:
|
||||
@ -133,7 +143,7 @@ def get_model_metadata(model):
|
||||
for k in ['eos_token', 'bos_token']:
|
||||
if k in metadata:
|
||||
value = metadata[k]
|
||||
if type(value) is dict:
|
||||
if isinstance(value, dict):
|
||||
value = value['content']
|
||||
|
||||
template = template.replace(k, "'{}'".format(value))
|
||||
@ -168,7 +178,7 @@ def infer_loader(model_name, model_settings):
|
||||
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
|
||||
if not path_to_model.exists():
|
||||
loader = None
|
||||
elif (path_to_model / 'quantize_config.json').exists() or ('wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0):
|
||||
elif (path_to_model / 'quantize_config.json').exists() or ('wbits' in model_settings and isinstance(model_settings['wbits'], int) and model_settings['wbits'] > 0):
|
||||
loader = 'ExLlamav2_HF'
|
||||
elif (path_to_model / 'quant_config.json').exists() or re.match(r'.*-awq', model_name.lower()):
|
||||
loader = 'AutoAWQ'
|
||||
|
Loading…
Reference in New Issue
Block a user