When saving model settings, only save the settings for the current loader

This commit is contained in:
oobabooga 2023-08-01 06:10:09 -07:00
parent ebb4f22028
commit 959feba602
2 changed files with 8 additions and 6 deletions

View File

@ -3,7 +3,7 @@ from pathlib import Path
import yaml
from modules import shared, ui
from modules import loaders, shared, ui
def get_model_settings_from_yamls(model):
@ -126,10 +126,12 @@ def save_model_settings(model, state):
user_config[model_regex] = {}
for k in ui.list_model_elements():
if k == 'loader' or k in loaders.loaders_and_params[state['loader']]:
user_config[model_regex][k] = state[k]
shared.model_config[model_regex][k] = state[k]
output = yaml.dump(user_config, sort_keys=False)
with open(p, 'w') as f:
f.write(yaml.dump(user_config, sort_keys=False))
f.write(output)
yield (f"Settings for {model} saved to {p}")

View File

@ -220,8 +220,8 @@ def create_model_menus():
shared.gradio['n_ctx'] = gr.Slider(minimum=0, maximum=16384, step=256, label="n_ctx", value=shared.args.n_ctx)
shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=32, value=shared.args.threads)
shared.gradio['n_batch'] = gr.Slider(label="n_batch", minimum=1, maximum=2048, value=shared.args.n_batch)
shared.gradio['n_gqa'] = gr.Slider(minimum=0, maximum=16, step=1, label="n_gqa", value=shared.args.n_gqa, info='grouped-query attention. Must be 8 for llama2 70b.')
shared.gradio['rms_norm_eps'] = gr.Slider(minimum=0, maximum=1e-5, step=1e-6, label="rms_norm_eps", value=shared.args.n_gqa, info='5e-6 is a good value for llama2 70b.')
shared.gradio['n_gqa'] = gr.Slider(minimum=0, maximum=16, step=1, label="n_gqa", value=shared.args.n_gqa, info='grouped-query attention. Must be 8 for llama-2 70b.')
shared.gradio['rms_norm_eps'] = gr.Slider(minimum=0, maximum=1e-5, step=1e-6, label="rms_norm_eps", value=shared.args.n_gqa, info='5e-6 is a good value for llama-2 models.')
shared.gradio['wbits'] = gr.Dropdown(label="wbits", choices=["None", 1, 2, 3, 4, 8], value=str(shared.args.wbits) if shared.args.wbits > 0 else "None")
shared.gradio['groupsize'] = gr.Dropdown(label="groupsize", choices=["None", 32, 64, 128, 1024], value=str(shared.args.groupsize) if shared.args.groupsize > 0 else "None")