mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
--idle-timeout flag to unload the model if unused for N minutes (#6026)
This commit is contained in:
parent
818b4e0354
commit
9f77ed1b98
@ -308,9 +308,6 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
|
||||
'internal': output['internal']
|
||||
}
|
||||
|
||||
if shared.model_name == 'None' or shared.model is None:
|
||||
raise ValueError("No model is loaded! Select one in the Model tab.")
|
||||
|
||||
# Generate the prompt
|
||||
kwargs = {
|
||||
'_continue': _continue,
|
||||
@ -355,11 +352,6 @@ def impersonate_wrapper(text, state):
|
||||
|
||||
static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
|
||||
|
||||
if shared.model_name == 'None' or shared.model is None:
|
||||
logger.error("No model is loaded! Select one in the Model tab.")
|
||||
yield '', static_output
|
||||
return
|
||||
|
||||
prompt = generate_chat_prompt('', state, impersonate=True)
|
||||
stopping_strings = get_stopping_strings(state)
|
||||
|
||||
|
@ -1,14 +1,32 @@
|
||||
import time
|
||||
|
||||
import torch
|
||||
from transformers import is_torch_npu_available, is_torch_xpu_available
|
||||
|
||||
from modules import sampler_hijack, shared
|
||||
from modules import models, sampler_hijack, shared
|
||||
from modules.logging_colors import logger
|
||||
from modules.models import load_model
|
||||
from modules.text_generation import generate_reply
|
||||
|
||||
global_scores = None
|
||||
|
||||
|
||||
def get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return_dict=False):
|
||||
def get_next_logits(*args, **kwargs):
|
||||
if shared.args.idle_timeout > 0 and shared.model is None and shared.previous_model_name not in [None, 'None']:
|
||||
shared.model, shared.tokenizer = load_model(shared.previous_model_name)
|
||||
|
||||
shared.generation_lock.acquire()
|
||||
try:
|
||||
result = _get_next_logits(*args, **kwargs)
|
||||
except:
|
||||
result = None
|
||||
|
||||
models.last_generation_time = time.time()
|
||||
shared.generation_lock.release()
|
||||
return result
|
||||
|
||||
|
||||
def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return_dict=False):
|
||||
if shared.model is None:
|
||||
logger.error("No model is loaded! Select one in the Model tab.")
|
||||
return 'Error: No model is loaded1 Select one in the Model tab.', previous
|
||||
|
@ -61,6 +61,9 @@ if shared.args.deepspeed:
|
||||
sampler_hijack.hijack_samplers()
|
||||
|
||||
|
||||
last_generation_time = time.time()
|
||||
|
||||
|
||||
def load_model(model_name, loader=None):
|
||||
logger.info(f"Loading \"{model_name}\"")
|
||||
t0 = time.time()
|
||||
@ -428,6 +431,7 @@ def clear_torch_cache():
|
||||
|
||||
def unload_model():
|
||||
shared.model = shared.tokenizer = None
|
||||
shared.previous_model_name = shared.model_name
|
||||
shared.model_name = 'None'
|
||||
shared.lora_names = []
|
||||
shared.model_dirty_from_training = False
|
||||
@ -437,3 +441,21 @@ def unload_model():
|
||||
def reload_model():
|
||||
unload_model()
|
||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||
|
||||
|
||||
def unload_model_if_idle():
|
||||
global last_generation_time
|
||||
|
||||
logger.info(f"Setting a timeout of {shared.args.idle_timeout} minutes to unload the model in case of inactivity.")
|
||||
|
||||
while True:
|
||||
shared.generation_lock.acquire()
|
||||
try:
|
||||
if time.time() - last_generation_time > shared.args.idle_timeout * 60:
|
||||
if shared.model is not None:
|
||||
logger.info("Unloading the model for inactivity.")
|
||||
unload_model()
|
||||
finally:
|
||||
shared.generation_lock.release()
|
||||
|
||||
time.sleep(60)
|
||||
|
@ -13,6 +13,7 @@ from modules.logging_colors import logger
|
||||
model = None
|
||||
tokenizer = None
|
||||
model_name = 'None'
|
||||
previous_model_name = 'None'
|
||||
is_seq2seq = False
|
||||
model_dirty_from_training = False
|
||||
lora_names = []
|
||||
@ -84,6 +85,7 @@ group.add_argument('--settings', type=str, help='Load the default interface sett
|
||||
group.add_argument('--extensions', type=str, nargs='+', help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
|
||||
group.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
|
||||
group.add_argument('--chat-buttons', action='store_true', help='Show buttons on the chat tab instead of a hover menu.')
|
||||
group.add_argument('--idle-timeout', type=int, default=0, help='Unload model after this many minutes of inactivity. It will be automatically reloaded when you try to use it again.')
|
||||
|
||||
# Model loader
|
||||
group = parser.add_argument_group('Model loader')
|
||||
|
@ -16,6 +16,7 @@ from transformers import (
|
||||
)
|
||||
|
||||
import modules.shared as shared
|
||||
from modules import models
|
||||
from modules.cache_utils import process_llamacpp_cache
|
||||
from modules.callbacks import (
|
||||
Iteratorize,
|
||||
@ -27,15 +28,19 @@ from modules.grammar.grammar_utils import initialize_grammar
|
||||
from modules.grammar.logits_process import GrammarConstrainedLogitsProcessor
|
||||
from modules.html_generator import generate_basic_html
|
||||
from modules.logging_colors import logger
|
||||
from modules.models import clear_torch_cache
|
||||
from modules.models import clear_torch_cache, load_model
|
||||
|
||||
|
||||
def generate_reply(*args, **kwargs):
|
||||
if shared.args.idle_timeout > 0 and shared.model is None and shared.previous_model_name not in [None, 'None']:
|
||||
shared.model, shared.tokenizer = load_model(shared.previous_model_name)
|
||||
|
||||
shared.generation_lock.acquire()
|
||||
try:
|
||||
for result in _generate_reply(*args, **kwargs):
|
||||
yield result
|
||||
finally:
|
||||
models.last_generation_time = time.time()
|
||||
shared.generation_lock.release()
|
||||
|
||||
|
||||
|
@ -32,7 +32,7 @@ import sys
|
||||
import time
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from threading import Lock, Thread
|
||||
|
||||
import yaml
|
||||
|
||||
@ -52,7 +52,7 @@ from modules import (
|
||||
)
|
||||
from modules.extensions import apply_extensions
|
||||
from modules.LoRA import add_lora_to_model
|
||||
from modules.models import load_model
|
||||
from modules.models import load_model, unload_model_if_idle
|
||||
from modules.models_settings import (
|
||||
get_fallback_settings,
|
||||
get_model_metadata,
|
||||
@ -245,6 +245,11 @@ if __name__ == "__main__":
|
||||
|
||||
shared.generation_lock = Lock()
|
||||
|
||||
if shared.args.idle_timeout > 0:
|
||||
timer_thread = Thread(target=unload_model_if_idle)
|
||||
timer_thread.daemon = True
|
||||
timer_thread.start()
|
||||
|
||||
if shared.args.nowebui:
|
||||
# Start the API in standalone mode
|
||||
shared.args.extensions = [x for x in shared.args.extensions if x != 'gallery']
|
||||
|
Loading…
Reference in New Issue
Block a user