Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-10-01 01:26:03 -04:00)
Automatically set bf16 & use_eager_attention for Gemma-2
This commit is contained in:
parent 8074fba18d
commit 907137a13d
@@ -9,6 +9,8 @@ from modules import chat, loaders, metadata_gguf, shared, ui
 
 def get_fallback_settings():
     return {
+        'bf16': False,
+        'use_eager_attention': False,
         'wbits': 'None',
         'groupsize': 'None',
         'desc_act': False,
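For orientation: get_fallback_settings() supplies the neutral defaults that the UI starts from, and the per-model values detected later (including the two new keys) take precedence. A minimal sketch of that relationship, assuming a plain dict.update() style merge rather than the project's actual state-handling code:

def get_fallback_settings():
    # Same defaults as the hunk above, trimmed to the keys shown there
    return {
        'bf16': False,
        'use_eager_attention': False,
        'wbits': 'None',
        'groupsize': 'None',
        'desc_act': False,
    }

# Hypothetical per-model overrides, e.g. as detected from a Gemma-2 config.json
detected = {'bf16': True, 'use_eager_attention': True}

settings = get_fallback_settings()
settings.update(detected)  # detected values win over the fallbacks
print(settings['bf16'], settings['use_eager_attention'])  # True True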
@@ -97,10 +99,18 @@ def get_model_metadata(model):
             elif 'attn_config' in metadata and 'rope_theta' in metadata['attn_config']:
                 model_settings['rope_freq_base'] = metadata['attn_config']['rope_theta']
 
-            if 'rope_scaling' in metadata and type(metadata['rope_scaling']) is dict and all(key in metadata['rope_scaling'] for key in ('type', 'factor')):
+            if 'rope_scaling' in metadata and isinstance(metadata['rope_scaling'], dict) and all(key in metadata['rope_scaling'] for key in ('type', 'factor')):
                 if metadata['rope_scaling']['type'] == 'linear':
                     model_settings['compress_pos_emb'] = metadata['rope_scaling']['factor']
 
+            # For Gemma-2
+            if 'torch_dtype' in metadata and metadata['torch_dtype'] == 'bfloat16':
+                model_settings['bf16'] = True
+
+            # For Gemma-2
+            if 'architectures' in metadata and isinstance(metadata['architectures'], list) and 'Gemma2ForCausalLM' in metadata['architectures']:
+                model_settings['use_eager_attention'] = True
+
             # Read GPTQ metadata for old GPTQ loaders
             if 'quantization_config' in metadata and metadata['quantization_config'].get('quant_method', '') != 'exl2':
                 if 'bits' in metadata['quantization_config']:
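In practice, the two added checks read the model's config.json: Gemma-2 checkpoints declare "torch_dtype": "bfloat16" and list "Gemma2ForCausalLM" under "architectures", so bf16 and eager attention get enabled without any manual toggling. A self-contained sketch of that detection, using an illustrative config rather than a real checkpoint's file:

import json

# Illustrative config.json contents for a Gemma-2 style model (assumed values)
config_json = '''
{
  "architectures": ["Gemma2ForCausalLM"],
  "torch_dtype": "bfloat16"
}
'''

metadata = json.loads(config_json)
model_settings = {}

# Same conditions as the added lines above
if 'torch_dtype' in metadata and metadata['torch_dtype'] == 'bfloat16':
    model_settings['bf16'] = True

if 'architectures' in metadata and isinstance(metadata['architectures'], list) and 'Gemma2ForCausalLM' in metadata['architectures']:
    model_settings['use_eager_attention'] = True

print(model_settings)  # {'bf16': True, 'use_eager_attention': True}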
@@ -133,7 +143,7 @@ def get_model_metadata(model):
             for k in ['eos_token', 'bos_token']:
                 if k in metadata:
                     value = metadata[k]
-                    if type(value) is dict:
+                    if isinstance(value, dict):
                         value = value['content']
 
                     template = template.replace(k, "'{}'".format(value))
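The switch to isinstance here does not change behavior; it matters because tokenizer_config.json may store eos_token/bos_token either as a plain string or as an added-token object with a 'content' field. A small sketch covering both cases, with made-up token values:

# Made-up tokenizer_config.json fragment: one string token, one dict token
metadata = {
    'bos_token': '<bos>',
    'eos_token': {'content': '<eos>', 'lstrip': False},
}
template = "{{ bos_token }} ... {{ eos_token }}"

for k in ['eos_token', 'bos_token']:
    if k in metadata:
        value = metadata[k]
        if isinstance(value, dict):
            value = value['content']

        template = template.replace(k, "'{}'".format(value))

print(template)  # {{ '<bos>' }} ... {{ '<eos>' }}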
@@ -168,7 +178,7 @@ def infer_loader(model_name, model_settings):
     path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
     if not path_to_model.exists():
         loader = None
-    elif (path_to_model / 'quantize_config.json').exists() or ('wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0):
+    elif (path_to_model / 'quantize_config.json').exists() or ('wbits' in model_settings and isinstance(model_settings['wbits'], int) and model_settings['wbits'] > 0):
         loader = 'ExLlamav2_HF'
     elif (path_to_model / 'quant_config.json').exists() or re.match(r'.*-awq', model_name.lower()):
         loader = 'AutoAWQ'
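The isinstance change in infer_loader is likewise cosmetic: wbits is either the string 'None' (from the fallback settings) or an int read from GPTQ metadata, so the branch selects ExLlamav2_HF only for genuinely quantized models. A reduced sketch of just that test, with the on-disk file checks omitted:

def pick_loader(model_settings):
    # Hypothetical helper illustrating only the wbits branch above;
    # the real infer_loader also inspects files inside the model folder.
    wbits = model_settings.get('wbits')
    if isinstance(wbits, int) and wbits > 0:
        return 'ExLlamav2_HF'
    return None

print(pick_loader({'wbits': 4}))       # ExLlamav2_HF
print(pick_loader({'wbits': 'None'}))  # None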