From 9f77ed1b987969b077d36438303183d3e2609f5a Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sun, 19 May 2024 23:29:39 -0300
Subject: [PATCH] --idle-timeout flag to unload the model if unused for N
 minutes (#6026)

---
 modules/chat.py            |  8 --------
 modules/logits.py          | 22 ++++++++++++++++++++--
 modules/models.py          | 22 ++++++++++++++++++++++
 modules/shared.py          |  2 ++
 modules/text_generation.py |  7 ++++++-
 server.py                  |  9 +++++++--
 6 files changed, 57 insertions(+), 13 deletions(-)

diff --git a/modules/chat.py b/modules/chat.py
index 4e0bde1c..43f5466b 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -308,9 +308,6 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
         'internal': output['internal']
     }

-    if shared.model_name == 'None' or shared.model is None:
-        raise ValueError("No model is loaded! Select one in the Model tab.")
-
     # Generate the prompt
     kwargs = {
         '_continue': _continue,
@@ -355,11 +352,6 @@ def impersonate_wrapper(text, state):

     static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])

-    if shared.model_name == 'None' or shared.model is None:
-        logger.error("No model is loaded! Select one in the Model tab.")
-        yield '', static_output
-        return
-
     prompt = generate_chat_prompt('', state, impersonate=True)
     stopping_strings = get_stopping_strings(state)

diff --git a/modules/logits.py b/modules/logits.py
index f2fd233b..3e793bd0 100644
--- a/modules/logits.py
+++ b/modules/logits.py
@@ -1,14 +1,32 @@
+import time
+
 import torch
 from transformers import is_torch_npu_available, is_torch_xpu_available

-from modules import sampler_hijack, shared
+from modules import models, sampler_hijack, shared
 from modules.logging_colors import logger
+from modules.models import load_model
 from modules.text_generation import generate_reply

 global_scores = None


-def get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return_dict=False):
+def get_next_logits(*args, **kwargs):
+    if shared.args.idle_timeout > 0 and shared.model is None and shared.previous_model_name not in [None, 'None']:
+        shared.model, shared.tokenizer = load_model(shared.previous_model_name)
+
+    shared.generation_lock.acquire()
+    try:
+        result = _get_next_logits(*args, **kwargs)
+    except:
+        result = None
+
+    models.last_generation_time = time.time()
+    shared.generation_lock.release()
+    return result
+
+
+def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return_dict=False):
     if shared.model is None:
         logger.error("No model is loaded! Select one in the Model tab.")
         return 'Error: No model is loaded! Select one in the Model tab.', previous
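The public entry points now share one shape: lazily reload the model that the idle timer unloaded, serialize the call behind the global generation lock, and stamp the time of last use. Below is a minimal, self-contained sketch of that wrapper pattern; all names here are illustrative stand-ins for the state the patch keeps in modules/shared.py and modules/models.py, not the real modules.

    import threading
    import time

    # Stand-ins for the shared globals; values are only for illustration.
    generation_lock = threading.Lock()
    model = None                          # currently unloaded
    previous_model_name = 'some-model'    # hypothetical name remembered at unload
    idle_timeout = 10                     # minutes; 0 disables the feature
    last_generation_time = time.time()


    def load_model(name):
        # Placeholder loader; the real one returns (model, tokenizer).
        print(f'Reloading "{name}"')
        return object(), object()


    def _get_next_logits(prompt):
        return f'logits for {prompt!r}'   # stand-in for the real computation


    def get_next_logits(*args, **kwargs):
        global model, tokenizer, last_generation_time
        # Transparently reload the model that the idle watchdog unloaded.
        if idle_timeout > 0 and model is None and previous_model_name not in [None, 'None']:
            model, tokenizer = load_model(previous_model_name)

        with generation_lock:             # one generation-like call at a time
            try:
                result = _get_next_logits(*args, **kwargs)
            except Exception:
                result = None             # the patch swallows errors here (with a bare except)
            last_generation_time = time.time()  # record the last use
        return result


    print(get_next_logits('hello'))

Because every generation path funnels through this timestamp update, the watchdog below only needs to compare one number against the clock.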
diff --git a/modules/models.py b/modules/models.py
index cac66393..b03e1c9d 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -61,6 +61,9 @@ if shared.args.deepspeed:

 sampler_hijack.hijack_samplers()

+last_generation_time = time.time()
+
+
 def load_model(model_name, loader=None):
     logger.info(f"Loading \"{model_name}\"")
     t0 = time.time()
@@ -428,6 +431,7 @@ def clear_torch_cache():

 def unload_model():
     shared.model = shared.tokenizer = None
+    shared.previous_model_name = shared.model_name
     shared.model_name = 'None'
     shared.lora_names = []
     shared.model_dirty_from_training = False
@@ -437,3 +441,21 @@ def reload_model():
     unload_model()
     shared.model, shared.tokenizer = load_model(shared.model_name)
+
+
+def unload_model_if_idle():
+    global last_generation_time
+
+    logger.info(f"Setting a timeout of {shared.args.idle_timeout} minutes to unload the model in case of inactivity.")
+
+    while True:
+        shared.generation_lock.acquire()
+        try:
+            if time.time() - last_generation_time > shared.args.idle_timeout * 60:
+                if shared.model is not None:
+                    logger.info("Unloading the model due to inactivity.")
+                    unload_model()
+        finally:
+            shared.generation_lock.release()
+
+        time.sleep(60)
diff --git a/modules/shared.py b/modules/shared.py
index a3ce584c..645ba701 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -13,6 +13,7 @@ from modules.logging_colors import logger
 model = None
 tokenizer = None
 model_name = 'None'
+previous_model_name = 'None'
 is_seq2seq = False
 model_dirty_from_training = False
 lora_names = []
@@ -84,6 +85,7 @@ group.add_argument('--settings', type=str, help='Load the default interface sett
 group.add_argument('--extensions', type=str, nargs='+', help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
 group.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
 group.add_argument('--chat-buttons', action='store_true', help='Show buttons on the chat tab instead of a hover menu.')
+group.add_argument('--idle-timeout', type=int, default=0, help='Unload model after this many minutes of inactivity. It will be automatically reloaded when you try to use it again.')

 # Model loader
 group = parser.add_argument_group('Model loader')
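The watchdog itself is a simple polling loop: once a minute it takes the generation lock, so it can never unload mid-generation, and unloads the model only when the last recorded use is older than the timeout. The feature is enabled by launching with, e.g., "python server.py --idle-timeout 30". A self-contained sketch of the same loop, again with stand-in globals in place of the real shared module:

    import threading
    import time

    # Stand-ins for the state the patch keeps in modules/shared.py and
    # modules/models.py; the one-minute timeout is only for demonstration.
    generation_lock = threading.Lock()
    idle_timeout = 1                      # minutes
    model = object()                      # pretend a model is loaded
    last_generation_time = time.time()


    def unload_model():
        global model
        model = None
        print('Model unloaded.')


    def unload_model_if_idle():
        # Wake once a minute; unload only when idle longer than the timeout.
        while True:
            with generation_lock:         # never unload mid-generation
                if model is not None and time.time() - last_generation_time > idle_timeout * 60:
                    unload_model()
            time.sleep(60)


    # A daemon thread dies with the main process, so it cannot block shutdown.
    threading.Thread(target=unload_model_if_idle, daemon=True).start()

Polling once a minute keeps the design trivial at the cost of up to a minute of slack past the configured timeout, which is harmless for this use case.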
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 79067f84..afcdaddb 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -16,6 +16,7 @@ from transformers import (
 )

 import modules.shared as shared
+from modules import models
 from modules.cache_utils import process_llamacpp_cache
 from modules.callbacks import (
     Iteratorize,
@@ -27,15 +28,19 @@ from modules.grammar.grammar_utils import initialize_grammar
 from modules.grammar.logits_process import GrammarConstrainedLogitsProcessor
 from modules.html_generator import generate_basic_html
 from modules.logging_colors import logger
-from modules.models import clear_torch_cache
+from modules.models import clear_torch_cache, load_model


 def generate_reply(*args, **kwargs):
+    if shared.args.idle_timeout > 0 and shared.model is None and shared.previous_model_name not in [None, 'None']:
+        shared.model, shared.tokenizer = load_model(shared.previous_model_name)
+
     shared.generation_lock.acquire()
     try:
         for result in _generate_reply(*args, **kwargs):
             yield result
     finally:
+        models.last_generation_time = time.time()
         shared.generation_lock.release()
diff --git a/server.py b/server.py
index 4b5185be..04a5b16d 100644
--- a/server.py
+++ b/server.py
@@ -32,7 +32,7 @@ import sys
 import time
 from functools import partial
 from pathlib import Path
-from threading import Lock
+from threading import Lock, Thread

 import yaml

@@ -52,7 +52,7 @@ from modules import (
 )
 from modules.extensions import apply_extensions
 from modules.LoRA import add_lora_to_model
-from modules.models import load_model
+from modules.models import load_model, unload_model_if_idle
 from modules.models_settings import (
     get_fallback_settings,
     get_model_metadata,
@@ -245,6 +245,11 @@ if __name__ == "__main__":

     shared.generation_lock = Lock()

+    if shared.args.idle_timeout > 0:
+        timer_thread = Thread(target=unload_model_if_idle)
+        timer_thread.daemon = True
+        timer_thread.start()
+
     if shared.args.nowebui:
         # Start the API in standalone mode
         shared.args.extensions = [x for x in shared.args.extensions if x != 'gallery']
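One subtlety in generate_reply is that it is a generator: the lock is acquired on the first next() call and held until the caller exhausts or abandons the iterator, and the finally block guarantees that the last-use timestamp is stamped and the lock released even on early exit. A minimal sketch of that pattern, under the same stand-in names as the earlier examples:

    import threading
    import time

    generation_lock = threading.Lock()
    last_generation_time = time.time()


    def _generate_reply(prompt):
        # Stand-in for the real token-by-token generator.
        for token in prompt.split():
            yield token


    def generate_reply(*args, **kwargs):
        global last_generation_time
        generation_lock.acquire()
        try:
            for result in _generate_reply(*args, **kwargs):
                yield result
        finally:
            # Runs on normal exhaustion and on GeneratorExit alike, so the
            # timestamp and the lock are always handled.
            last_generation_time = time.time()
            generation_lock.release()


    print(list(generate_reply('one two three')))

Updating the timestamp inside finally, rather than before the loop, means a long streaming generation still counts as activity right up to its last token.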