From 1917b1527503d7efbce3d33aa7df9a216aaf36fe Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=CE=A6=CF=86?= <42910943+Brawlence@users.noreply.github.com>
Date: Tue, 21 Mar 2023 13:15:42 +0300
Subject: [PATCH 1/3] Unload and reload models on request

---
 server.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/server.py b/server.py
index cdf7aa93..1309c17d 100644
--- a/server.py
+++ b/server.py
@@ -63,6 +63,18 @@ def load_model_wrapper(selected_model):
 
     return selected_model
 
+def reload_model():
+    if not shared.args.cpu:
+        gc.collect()
+        torch.cuda.empty_cache()
+    shared.model, shared.tokenizer = load_model(shared.model_name)
+
+def unload_model():
+    shared.model = shared.tokenizer = None
+    if not shared.args.cpu:
+        gc.collect()
+        torch.cuda.empty_cache()
+
 def load_lora_wrapper(selected_lora):
     shared.lora_name = selected_lora
     default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
@@ -126,6 +138,9 @@ def create_model_and_preset_menus():
         with gr.Row():
             shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
             ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
+    with gr.Row():
+        shared.gradio['unload_model'] = gr.Button(value='Unload model to free VRAM', elem_id="unload_model")
+        shared.gradio['reload_model'] = gr.Button(value='Reload the model into VRAM', elem_id="reload_model")
 
 def create_settings_menus(default_preset):
     generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
@@ -180,6 +195,8 @@ def create_settings_menus(default_preset):
             shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
 
     shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
+    shared.gradio['unload_model'].click(fn=unload_model,inputs=[],outputs=[])
+    shared.gradio['reload_model'].click(fn=reload_model,inputs=[],outputs=[])
     shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['preset_menu_mirror', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
     shared.gradio['preset_menu_mirror'].change(load_preset_values, [shared.gradio['preset_menu_mirror']], [shared.gradio[k] for k in ['preset_menu', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
     shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu'], shared.gradio['textbox']], show_progress=True)
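
The pair of functions added above relies on a common PyTorch idiom: drop the Python references to the weights, run the garbage collector, then call torch.cuda.empty_cache() to hand the allocator's cached blocks back to the driver. The standalone sketch below demonstrates the effect; free_vram() is a hypothetical helper name for illustration, and it guards on torch.cuda.is_available() rather than on shared.args.cpu as the patch does.

import gc

import torch


def free_vram():
    # empty_cache() can only release cached blocks that no live tensor still
    # occupies, so unreachable Python objects must be collected first.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


if __name__ == '__main__':
    if torch.cuda.is_available():
        x = torch.empty(1024, 1024, device='cuda')  # ~4 MiB of float32
        print(torch.cuda.memory_reserved())  # allocator now holds cached blocks
        del x                                # tensor gone, blocks stay cached
        print(torch.cuda.memory_reserved())  # unchanged
        free_vram()
        print(torch.cuda.memory_reserved())  # blocks handed back to the driver
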
From 483d173d23309f77d197951ad9f21632955fd13a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=CE=A6=CF=86?= <42910943+Brawlence@users.noreply.github.com>
Date: Tue, 21 Mar 2023 20:19:38 +0300
Subject: [PATCH 2/3] Code reuse + indication

Now shows the message in the console when unloading weights. Also
reload_model() calls unload_model() first to free the memory so that
multiple reloads won't overfill it.
---
 server.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/server.py b/server.py
index 1309c17d..4c3497c9 100644
--- a/server.py
+++ b/server.py
@@ -64,9 +64,7 @@ def load_model_wrapper(selected_model):
     return selected_model
 
 def reload_model():
-    if not shared.args.cpu:
-        gc.collect()
-        torch.cuda.empty_cache()
+    unload_model()
     shared.model, shared.tokenizer = load_model(shared.model_name)
 
 def unload_model():
@@ -74,6 +72,7 @@ def unload_model():
     shared.model = shared.tokenizer = None
     if not shared.args.cpu:
         gc.collect()
         torch.cuda.empty_cache()
+    print("Model weights unloaded.")
 
 def load_lora_wrapper(selected_lora):
     shared.lora_name = selected_lora
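
The ordering in this patch matters: reload_model() frees the old weights before load_model() allocates the new ones, so peak VRAM stays at roughly one model instead of two. Below is a minimal self-contained sketch of that pattern; the module-level globals and fake_load() are hypothetical stand-ins for shared.model, shared.tokenizer, and the real load_model(), not server.py code.

import gc

import torch

# Hypothetical stand-ins for shared.model / shared.tokenizer.
model = tokenizer = None


def fake_load(name):
    # Placeholder for the real loader, which would build a transformers model.
    return f"{name}-weights", f"{name}-tokenizer"


def unload_model():
    global model, tokenizer
    model = tokenizer = None    # the essential step: drop the last references
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("Model weights unloaded.")


def reload_model(name):
    global model, tokenizer
    unload_model()              # free the old weights before loading new ones
    model, tokenizer = fake_load(name)


reload_model('example')         # repeated reloads never hold two models at once
reload_model('example')
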
From 95c97e1747f277e62db997da73556a94904c1f9c Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 26 Mar 2023 23:47:29 -0300
Subject: [PATCH 3/3] Unload the model using the "Remove all" button

---
 server.py | 27 ++++++++------------------
 1 file changed, 8 insertions(+), 19 deletions(-)

diff --git a/server.py b/server.py
index 3e31377c..db83b4f3 100644
--- a/server.py
+++ b/server.py
@@ -50,26 +50,20 @@ def get_available_softprompts():
 def get_available_loras():
     return ['None'] + sorted([item.name for item in list(Path('loras/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
 
+def unload_model():
+    shared.model = shared.tokenizer = None
+    clear_torch_cache()
+
 def load_model_wrapper(selected_model):
     if selected_model != shared.model_name:
         shared.model_name = selected_model
-        shared.model = shared.tokenizer = None
-        clear_torch_cache()
-        shared.model, shared.tokenizer = load_model(shared.model_name)
+
+        unload_model()
+        if selected_model != '':
+            shared.model, shared.tokenizer = load_model(shared.model_name)
 
     return selected_model
 
-def reload_model():
-    unload_model()
-    shared.model, shared.tokenizer = load_model(shared.model_name)
-
-def unload_model():
-    shared.model = shared.tokenizer = None
-    if not shared.args.cpu:
-        gc.collect()
-        torch.cuda.empty_cache()
-    print("Model weights unloaded.")
-
 def load_lora_wrapper(selected_lora):
     add_lora_to_model(selected_lora)
     default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
@@ -128,9 +122,6 @@ def create_model_and_preset_menus():
         with gr.Row():
             shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
             ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
-    with gr.Row():
-        shared.gradio['unload_model'] = gr.Button(value='Unload model to free VRAM', elem_id="unload_model")
-        shared.gradio['reload_model'] = gr.Button(value='Reload the model into VRAM', elem_id="reload_model")
 
 def create_settings_menus(default_preset):
     generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
@@ -185,8 +176,6 @@ def create_settings_menus(default_preset):
             shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
 
     shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
-    shared.gradio['unload_model'].click(fn=unload_model,inputs=[],outputs=[])
-    shared.gradio['reload_model'].click(fn=reload_model,inputs=[],outputs=[])
     shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['preset_menu_mirror', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
     shared.gradio['preset_menu_mirror'].change(load_preset_values, [shared.gradio['preset_menu_mirror']], [shared.gradio[k] for k in ['preset_menu', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
     shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu'], shared.gradio['textbox']], show_progress=True)
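
After this final patch the dedicated buttons are gone: unload_model() shrinks to a reference drop plus the existing clear_torch_cache() helper, and clearing the model menu (the "Remove all" action, which yields selected_model == '') unloads without loading anything back. The sketch below mirrors that final control flow under stated assumptions: the globals and fake_load() are hypothetical stand-ins for the shared state and the real load_model().

import gc

import torch

# Hypothetical stand-ins for the module-level state in server.py.
model = tokenizer = None
model_name = None


def clear_torch_cache():
    # Collect unreachable objects, then release the allocator's cached blocks.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def fake_load(name):
    # Placeholder for the real load_model().
    return f"{name}-weights", f"{name}-tokenizer"


def load_model_wrapper(selected_model):
    global model, tokenizer, model_name
    if selected_model != model_name:
        model_name = selected_model
        model = tokenizer = None    # unload_model(), inlined
        clear_torch_cache()
        if selected_model != '':    # the empty selection means "just unload"
            model, tokenizer = fake_load(selected_model)
    return selected_model


load_model_wrapper('gpt-neo')   # loads
load_model_wrapper('')          # "Remove all": unloads, no reload happens
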