diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py index 5c945def..30f41590 100644 --- a/modules/ui_model_menu.py +++ b/modules/ui_model_menu.py @@ -145,12 +145,14 @@ def create_event_handlers(): apply_model_settings_to_state, gradio('model_menu', 'interface_state'), gradio('interface_state')).then( ui.apply_interface_values, gradio('interface_state'), gradio(ui.list_interface_input_elements()), show_progress=False).then( update_model_parameters, gradio('interface_state'), None).then( - load_model_wrapper, gradio('model_menu', 'loader', 'autoload_model'), gradio('model_status'), show_progress=False) + load_model_wrapper, gradio('model_menu', 'loader', 'autoload_model'), gradio('model_status'), show_progress=False).success( + update_truncation_length, gradio('truncation_length', 'interface_state'), gradio('truncation_length')) shared.gradio['load_model'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( update_model_parameters, gradio('interface_state'), None).then( - partial(load_model_wrapper, autoload=True), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=False) + partial(load_model_wrapper, autoload=True), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=False).success( + update_truncation_length, gradio('truncation_length', 'interface_state'), gradio('truncation_length')) shared.gradio['unload_model'].click( unload_model, None, None).then( @@ -160,7 +162,8 @@ def create_event_handlers(): unload_model, None, None).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( update_model_parameters, gradio('interface_state'), None).then( - partial(load_model_wrapper, autoload=True), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=False) + partial(load_model_wrapper, autoload=True), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=False).success( + update_truncation_length, gradio('truncation_length', 'interface_state'), gradio('truncation_length')) shared.gradio['save_model_settings'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( @@ -235,3 +238,12 @@ def download_model_wrapper(repo_id, progress=gr.Progress()): except: progress(1.0) yield traceback.format_exc().replace('\n', '\n\n') + + +def update_truncation_length(current_length, state): + if state['loader'] in ['ExLlama', 'ExLlama_HF']: + return state['max_seq_len'] + elif state['loader'] in ['llama.cpp', 'llamacpp_HF', 'ctransformers']: + return state['n_ctx'] + else: + return current_length diff --git a/modules/ui_parameters.py b/modules/ui_parameters.py index a0f95158..b5ce5ac9 100644 --- a/modules/ui_parameters.py +++ b/modules/ui_parameters.py @@ -113,7 +113,7 @@ def create_ui(default_preset): with gr.Box(): with gr.Row(): with gr.Column(): - shared.gradio['truncation_length'] = gr.Slider(value=shared.settings['truncation_length'], minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=256, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.') + shared.gradio['truncation_length'] = gr.Slider(value=get_truncation_length(), minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=256, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.') shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas.', placeholder='"\\n", "\\nYou:"') with gr.Column(): shared.gradio['auto_max_new_tokens'] = gr.Checkbox(value=shared.settings['auto_max_new_tokens'], label='auto_max_new_tokens', info='Expand max_new_tokens to the available context length.') @@ -129,3 +129,12 @@ def create_ui(default_preset): def create_event_handlers(): shared.gradio['filter_by_loader'].change(loaders.blacklist_samplers, gradio('filter_by_loader'), gradio(loaders.list_all_samplers()), show_progress=False) shared.gradio['preset_menu'].change(presets.load_preset_for_ui, gradio('preset_menu', 'interface_state'), gradio('interface_state') + gradio(presets.presets_params())) + + +def get_truncation_length(): + if shared.args.max_seq_len != shared.args_defaults.max_seq_len: + return shared.args.max_seq_len + if shared.args.n_ctx != shared.args_defaults.n_ctx: + return shared.args.n_ctx + else: + return shared.settings['truncation_length']