From b973b91d730cc759dc7bcf755ad85e95deb81128 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 25 Sep 2023 10:28:35 -0700
Subject: [PATCH] Automatically filter by loader (closes #4072)

---
 modules/ui.py            |  1 +
 modules/ui_model_menu.py | 17 ++++++++++-------
 server.py                |  1 +
 3 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/modules/ui.py b/modules/ui.py
index 1b5ab955..afb8a1ef 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -44,6 +44,7 @@ else:
 def list_model_elements():
     elements = [
         'loader',
+        'filter_by_loader',
         'cpu_memory',
         'auto_devices',
         'disk',
diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py
index 8a0180c5..b57e11f4 100644
--- a/modules/ui_model_menu.py
+++ b/modules/ui_model_menu.py
@@ -149,24 +149,27 @@ def create_event_handlers():
         ui.apply_interface_values, gradio('interface_state'), gradio(ui.list_interface_input_elements()), show_progress=False).then(
         update_model_parameters, gradio('interface_state'), None).then(
         load_model_wrapper, gradio('model_menu', 'loader', 'autoload_model'), gradio('model_status'), show_progress=False).success(
-        update_truncation_length, gradio('truncation_length', 'interface_state'), gradio('truncation_length'))
+        update_truncation_length, gradio('truncation_length', 'interface_state'), gradio('truncation_length')).then(
+        lambda x: x, gradio('loader'), gradio('filter_by_loader'))
 
     shared.gradio['load_model'].click(
         ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
         update_model_parameters, gradio('interface_state'), None).then(
         partial(load_model_wrapper, autoload=True), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=False).success(
-        update_truncation_length, gradio('truncation_length', 'interface_state'), gradio('truncation_length'))
-
-    shared.gradio['unload_model'].click(
-        unload_model, None, None).then(
-        lambda: "Model unloaded", None, gradio('model_status'))
+        update_truncation_length, gradio('truncation_length', 'interface_state'), gradio('truncation_length')).then(
+        lambda x: x, gradio('loader'), gradio('filter_by_loader'))
 
     shared.gradio['reload_model'].click(
         unload_model, None, None).then(
         ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
         update_model_parameters, gradio('interface_state'), None).then(
         partial(load_model_wrapper, autoload=True), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=False).success(
-        update_truncation_length, gradio('truncation_length', 'interface_state'), gradio('truncation_length'))
+        update_truncation_length, gradio('truncation_length', 'interface_state'), gradio('truncation_length')).then(
+        lambda x: x, gradio('loader'), gradio('filter_by_loader'))
+
+    shared.gradio['unload_model'].click(
+        unload_model, None, None).then(
+        lambda: "Model unloaded", None, gradio('model_status'))
 
     shared.gradio['save_model_settings'].click(
         ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
diff --git a/server.py b/server.py
index 6f7b03ec..f07df378 100644
--- a/server.py
+++ b/server.py
@@ -77,6 +77,7 @@ def create_interface():
         'instruction_template': shared.settings['instruction_template'],
         'prompt_menu-default': shared.settings['prompt-default'],
         'prompt_menu-notebook': shared.settings['prompt-notebook'],
+        'filter_by_loader': shared.args.loader or 'All'
     })
 
     if Path("cache/pfp_character.png").exists():