Beginning of multi-user support (#2262)

Adds a lock to generate_reply
flurb18 2023-05-24 08:38:20 -04:00 committed by GitHub
parent 7dc87984a2
commit d37a28730d
3 changed files with 14 additions and 1 deletion
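The change follows one pattern: every call to generate_reply is funneled through a single threading.Lock, so when several users or API clients request text at the same time, only one generation runs on the model and the others wait their turn. Below is a minimal standalone sketch of that pattern, separate from the webui code; fake_model_stream and worker are illustrative names only.

import threading
import time

generation_lock = threading.Lock()  # one lock shared by every caller

def fake_model_stream(prompt):
    # Stand-in for the real streaming generator.
    for i in range(3):
        time.sleep(0.1)
        yield f"{prompt} token{i}"

def generate_reply(prompt):
    # Hold the lock for the entire streamed generation.
    generation_lock.acquire()
    try:
        for chunk in fake_model_stream(prompt):
            yield chunk
    finally:
        generation_lock.release()

def worker(name):
    for chunk in generate_reply(name):
        print(chunk)

threads = [threading.Thread(target=worker, args=(f"user{n}",)) for n in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()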

modules/shared.py

@@ -6,6 +6,7 @@ import yaml
 from modules.logging_colors import logger
+generation_lock = None
 model = None
 tokenizer = None
 model_name = "None"

modules/text_generation.py

@@ -1,6 +1,7 @@
 import ast
 import random
 import re
+import threading
 import time
 import traceback
@@ -17,6 +18,15 @@ from modules.logging_colors import logger
 from modules.models import clear_torch_cache, local_rank
+def generate_reply(*args, **kwargs):
+    shared.generation_lock.acquire()
+    try:
+        for result in _generate_reply(*args, **kwargs):
+            yield result
+    finally:
+        shared.generation_lock.release()
 def get_max_prompt_length(state):
     max_length = state['truncation_length'] - state['max_new_tokens']
     if shared.soft_prompt:
@@ -154,7 +164,7 @@ def generate_reply_wrapper(question, state, eos_token=None, stopping_strings=None
         yield formatted_outputs(reply, shared.model_name)
-def generate_reply(question, state, eos_token=None, stopping_strings=None, is_chat=False):
+def _generate_reply(question, state, eos_token=None, stopping_strings=None, is_chat=False):
     state = apply_extensions('state', state)
     generate_func = apply_extensions('custom_generate_reply')
     if generate_func is None:
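Because the lock is released in a finally block inside a generator, it is also released if the consumer abandons the stream early: closing the generator (explicitly, or when it is garbage collected) raises GeneratorExit at the suspended yield, and the finally clause runs. A small illustration of that behaviour, with a plain threading.Lock standing in for shared.generation_lock:

import threading

lock = threading.Lock()

def locked_stream():
    lock.acquire()
    try:
        for i in range(100):
            yield i
    finally:
        lock.release()

gen = locked_stream()
print(next(gen))      # 0 - the lock is now held
print(lock.locked())  # True
gen.close()           # abandon the stream early
print(lock.locked())  # False - finally ran and released the lock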

server.py

@@ -38,6 +38,7 @@ import zipfile
 from datetime import datetime
 from functools import partial
 from pathlib import Path
+from threading import Lock
 import psutil
 import torch
@@ -1075,6 +1076,7 @@ if __name__ == "__main__":
         'instruction_template': shared.settings['instruction_template']
     })
+    shared.generation_lock = Lock()
     # Launch the web UI
     create_interface()
     while True:
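The lock itself is created once at startup, right before create_interface() launches the UI, and stored on the shared module (which declares generation_lock = None as a placeholder) so that modules.text_generation can reach it without importing server.py. A rough sketch of that wiring collapsed into one file; the SharedState class is a stand-in for modules/shared.py and is not part of the commit:

import threading

class SharedState:
    # Placeholder filled in at startup, mirroring shared.generation_lock = None.
    generation_lock = None

shared = SharedState()

def generate_reply(prompt):
    # Same effect as the acquire/try/finally in the commit, for a non-streaming call.
    with shared.generation_lock:
        return f"reply to {prompt!r}"

if __name__ == "__main__":
    # Mirror of server.py: create the lock once, before anything can call generate_reply().
    shared.generation_lock = threading.Lock()
    print(generate_reply("hello"))

Assigning the lock before the interface starts guarantees that no request handler can observe generation_lock while it is still None.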