--idle-timeout flag to unload the model if unused for N minutes (#6026)

oobabooga 2024-05-19 23:29:39 -03:00 committed by GitHub
parent 818b4e0354
commit 9f77ed1b98
6 changed files with 57 additions and 13 deletions
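
With this change, the idle watchdog is opt-in at launch. For example, starting the server with "python server.py --idle-timeout 30" (30 being an illustrative value) unloads the model after 30 minutes without a generation and reloads it transparently on the next request.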

modules/chat.py

@@ -308,9 +308,6 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
             'internal': output['internal']
         }
 
-    if shared.model_name == 'None' or shared.model is None:
-        raise ValueError("No model is loaded! Select one in the Model tab.")
-
     # Generate the prompt
     kwargs = {
         '_continue': _continue,
@@ -355,11 +352,6 @@ def impersonate_wrapper(text, state):
     static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
 
-    if shared.model_name == 'None' or shared.model is None:
-        logger.error("No model is loaded! Select one in the Model tab.")
-        yield '', static_output
-        return
-
     prompt = generate_chat_prompt('', state, impersonate=True)
     stopping_strings = get_stopping_strings(state)

modules/logits.py

@@ -1,14 +1,32 @@
+import time
+
 import torch
 from transformers import is_torch_npu_available, is_torch_xpu_available
 
-from modules import sampler_hijack, shared
+from modules import models, sampler_hijack, shared
 from modules.logging_colors import logger
+from modules.models import load_model
 from modules.text_generation import generate_reply
 
 global_scores = None
 
 
-def get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return_dict=False):
+def get_next_logits(*args, **kwargs):
+    if shared.args.idle_timeout > 0 and shared.model is None and shared.previous_model_name not in [None, 'None']:
+        shared.model, shared.tokenizer = load_model(shared.previous_model_name)
+
+    shared.generation_lock.acquire()
+    try:
+        result = _get_next_logits(*args, **kwargs)
+    except:
+        result = None
+
+    models.last_generation_time = time.time()
+    shared.generation_lock.release()
+
+    return result
+
+
+def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return_dict=False):
     if shared.model is None:
         logger.error("No model is loaded! Select one in the Model tab.")
         return 'Error: No model is loaded! Select one in the Model tab.', previous
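
The change above introduces a wrapper pattern that the commit applies to both entry points (get_next_logits here, generate_reply in modules/text_generation.py below): reload the model if the idle watchdog unloaded it, serialize the call with the generation lock, and stamp the time of the last generation on the way out. A minimal standalone sketch of that pattern follows; with_idle_reload and reload_if_unloaded are hypothetical names for illustration, not part of the commit:

import threading
import time

last_generation_time = time.time()
generation_lock = threading.Lock()


def reload_if_unloaded():
    # Placeholder for the commit's check: if the idle watchdog unloaded the
    # model and a previous model name is remembered, load it back.
    pass


def with_idle_reload(fn):
    # Wrap a generation entry point so it cooperates with the idle watchdog.
    def wrapper(*args, **kwargs):
        global last_generation_time
        reload_if_unloaded()
        generation_lock.acquire()
        try:
            return fn(*args, **kwargs)
        finally:
            # Record the finish time before releasing the lock, so the idle
            # clock starts only after the call completes.
            last_generation_time = time.time()
            generation_lock.release()
    return wrapper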

modules/models.py

@@ -61,6 +61,9 @@ if shared.args.deepspeed:
 
 sampler_hijack.hijack_samplers()
 
+last_generation_time = time.time()
+
+
 def load_model(model_name, loader=None):
     logger.info(f"Loading \"{model_name}\"")
     t0 = time.time()
@@ -428,6 +431,7 @@ def clear_torch_cache():
 
 def unload_model():
     shared.model = shared.tokenizer = None
+    shared.previous_model_name = shared.model_name
     shared.model_name = 'None'
     shared.lora_names = []
     shared.model_dirty_from_training = False
@@ -437,3 +441,21 @@ def unload_model():
 def reload_model():
     unload_model()
     shared.model, shared.tokenizer = load_model(shared.model_name)
+
+
+def unload_model_if_idle():
+    global last_generation_time
+
+    logger.info(f"Setting a timeout of {shared.args.idle_timeout} minutes to unload the model in case of inactivity.")
+
+    while True:
+        shared.generation_lock.acquire()
+        try:
+            if time.time() - last_generation_time > shared.args.idle_timeout * 60:
+                if shared.model is not None:
+                    logger.info("Unloading the model for inactivity.")
+                    unload_model()
+        finally:
+            shared.generation_lock.release()
+
+        time.sleep(60)
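
Note that the watchdog checks the idle clock while holding shared.generation_lock, the same lock every generation takes. An in-flight generation therefore blocks the check entirely, and since last_generation_time is refreshed when the lock is released, the timeout is measured from the end of the last generation; a model can never be unloaded mid-generation.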

modules/shared.py

@@ -13,6 +13,7 @@ from modules.logging_colors import logger
 model = None
 tokenizer = None
 model_name = 'None'
+previous_model_name = 'None'
 is_seq2seq = False
 model_dirty_from_training = False
 lora_names = []
@@ -84,6 +85,7 @@ group.add_argument('--settings', type=str, help='Load the default interface sett
 group.add_argument('--extensions', type=str, nargs='+', help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
 group.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
 group.add_argument('--chat-buttons', action='store_true', help='Show buttons on the chat tab instead of a hover menu.')
+group.add_argument('--idle-timeout', type=int, default=0, help='Unload model after this many minutes of inactivity. It will be automatically reloaded when you try to use it again.')
 
 # Model loader
 group = parser.add_argument_group('Model loader')
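
The default of 0 disables the feature entirely, which is why every new code path in this commit is gated on shared.args.idle_timeout > 0.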

modules/text_generation.py

@@ -16,6 +16,7 @@ from transformers import (
 )
 
 import modules.shared as shared
+from modules import models
 from modules.cache_utils import process_llamacpp_cache
 from modules.callbacks import (
     Iteratorize,
@@ -27,15 +28,19 @@ from modules.grammar.grammar_utils import initialize_grammar
 from modules.grammar.logits_process import GrammarConstrainedLogitsProcessor
 from modules.html_generator import generate_basic_html
 from modules.logging_colors import logger
-from modules.models import clear_torch_cache
+from modules.models import clear_torch_cache, load_model
 
 
 def generate_reply(*args, **kwargs):
+    if shared.args.idle_timeout > 0 and shared.model is None and shared.previous_model_name not in [None, 'None']:
+        shared.model, shared.tokenizer = load_model(shared.previous_model_name)
+
     shared.generation_lock.acquire()
     try:
         for result in _generate_reply(*args, **kwargs):
             yield result
 
     finally:
+        models.last_generation_time = time.time()
         shared.generation_lock.release()

server.py

@@ -32,7 +32,7 @@ import sys
 import time
 from functools import partial
 from pathlib import Path
-from threading import Lock
+from threading import Lock, Thread
 
 import yaml
@@ -52,7 +52,7 @@ from modules import (
 )
 from modules.extensions import apply_extensions
 from modules.LoRA import add_lora_to_model
-from modules.models import load_model
+from modules.models import load_model, unload_model_if_idle
 from modules.models_settings import (
     get_fallback_settings,
     get_model_metadata,
@@ -245,6 +245,11 @@ if __name__ == "__main__":
 
     shared.generation_lock = Lock()
 
+    if shared.args.idle_timeout > 0:
+        timer_thread = Thread(target=unload_model_if_idle)
+        timer_thread.daemon = True
+        timer_thread.start()
+
     if shared.args.nowebui:
         # Start the API in standalone mode
         shared.args.extensions = [x for x in shared.args.extensions if x != 'gallery']
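
Because timer_thread.daemon is set to True, the watchdog never blocks interpreter shutdown: the thread is started once here and never joined, sleeping through its once-a-minute loop until the process exits.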