PATH = '/api/v1/stream'


def _close_generator(generator):
    """Close *generator* if it exposes a ``close()`` method; else do nothing."""
    close = getattr(generator, 'close', None)
    if callable(close):
        close()


async def _send_event(websocket, event, message_num, **fields):
    """Send one JSON frame: ``event``, ``message_num``, then any extra fields."""
    await websocket.send(json.dumps({
        'event': event,
        'message_num': message_num,
        **fields
    }))


@with_api_lock
async def _handle_stream_message(websocket, message):
    """
    Handle one request received on the text streaming endpoint.

    ``message`` is decoded as JSON and must contain a ``'prompt'`` key. The
    decoded body is passed to ``build_parameters``; its ``'stopping_strings'``
    entry is popped and handed to ``generate_reply`` separately. Every value
    produced by the generator is treated as the cumulative reply so far, and
    only the not-yet-sent suffix is emitted as a ``text_stream`` frame
    (``'text'`` field). A final ``stream_end`` frame carries the number of
    ``text_stream`` frames that were sent.

    The ``with_api_lock`` decorator makes sure only one message is processed
    at a time on this thread's event loop.

    Raises:
        json.JSONDecodeError: if ``message`` is not valid JSON.
        KeyError: if ``'prompt'`` is missing from the body, or
            ``'stopping_strings'`` is missing from the built parameters.
    """
    body = json.loads(message)

    prompt = body['prompt']
    generate_params = build_parameters(body)
    stopping_strings = generate_params.pop('stopping_strings')
    generate_params['stream'] = True

    generator = generate_reply(
        prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)

    # As we stream, only send the new text.
    skip_index = 0
    message_num = 0

    try:
        for a in generator:
            to_send = a[skip_index:]
            if chr(0xfffd) in to_send:  # partial unicode character, don't send it yet.
                continue

            await _send_event(websocket, 'text_stream', message_num, text=to_send)

            # Yield to the event loop between frames.
            await asyncio.sleep(0)
            skip_index += len(to_send)
            message_num += 1
    finally:
        # If sending fails mid-stream (e.g. the client went away) the generator
        # would otherwise stay suspended until garbage collection.
        # NOTE(review): presumably it holds shared.generation_lock while
        # suspended -- confirm in modules.text_generation.
        _close_generator(generator)

    await _send_event(websocket, 'stream_end', message_num)


@with_api_lock
async def _handle_chat_stream_message(websocket, message):
    """
    Handle one request received on the chat streaming endpoint.

    ``message`` is decoded as JSON and must contain a ``'user_input'`` key;
    ``'regenerate'`` and ``'_continue'`` are optional and default to ``False``.
    Every value produced by ``generate_chat_reply`` is sent unchanged as the
    ``'history'`` field of a ``text_stream`` frame, followed by one
    ``stream_end`` frame carrying the number of ``text_stream`` frames sent.

    The ``with_api_lock`` decorator makes sure only one message is processed
    at a time on this thread's event loop.

    Raises:
        json.JSONDecodeError: if ``message`` is not valid JSON.
        KeyError: if ``'user_input'`` is missing from the body.
    """
    body = json.loads(message)

    user_input = body['user_input']
    generate_params = build_parameters(body, chat=True)
    generate_params['stream'] = True
    regenerate = body.get('regenerate', False)
    _continue = body.get('_continue', False)

    generator = generate_chat_reply(
        user_input, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)

    message_num = 0
    try:
        for a in generator:
            await _send_event(websocket, 'text_stream', message_num, history=a)

            # Yield to the event loop between frames.
            await asyncio.sleep(0)
            message_num += 1
    finally:
        # See _handle_stream_message: never leave the generator suspended.
        _close_generator(generator)

    await _send_event(websocket, 'stream_end', message_num)


async def _handle_connection(websocket, path):
    """
    Serve one websocket connection, dispatching on the request ``path``.

    Every message received on ``/api/v1/stream`` is handled by
    ``_handle_stream_message`` and every message on ``/api/v1/chat-stream`` by
    ``_handle_chat_stream_message``, one message at a time. Any other path is
    reported on stdout and the connection is not served.

    NOTE(review): the visible diff hunk ends at the unknown-path branch;
    confirm that nothing else follows it in the full file.
    """
    if path == '/api/v1/stream':
        handler = _handle_stream_message
    elif path == '/api/v1/chat-stream':
        handler = _handle_chat_stream_message
    else:
        print(f'Streaming api: unknown path: {path}')
        return

    async for message in websocket:
        await handler(websocket, message)
# Per-thread storage for the API request lock. The `asyncio_lock` and
# `asyncio_lock_loop` attributes are filled in lazily by `_get_api_lock`, so
# each thread (and therefore each event loop) gets a lock of its own. That is
# not strictly necessary with a single worker thread, but it keeps this code
# correct if multiple worker threads handle requests in parallel in the future.
api_tls = threading.local()


def _get_api_lock(tls) -> asyncio.Lock:
    """
    Return the asyncio lock that serialises API requests for the calling
    thread, creating it on first use.

    The streaming and blocking API implementations each run on their own
    thread, and multiplex requests using asyncio. If multiple outstanding
    requests are received at once, we will try to acquire the shared lock
    shared.generation_lock multiple times in succession in the same thread,
    which will cause a deadlock.

    To avoid this, callers block on this asyncio lock first, and only try to
    grab the shared lock while holding the asyncio lock.

    The lock is cached on ``tls`` together with the event loop it was first
    used from. An asyncio lock cannot be shared between event loops, so when
    the thread is found running a different loop a fresh lock is created
    instead of reusing the stale one.

    Args:
        tls: object used as per-thread attribute storage (``api_tls``).

    Returns:
        The lock for the calling thread and its currently running event loop.
    """
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # Not called from inside a running event loop.
        loop = None

    lock = getattr(tls, "asyncio_lock", None)
    bound_loop = getattr(tls, "asyncio_lock_loop", None)

    # A lock recorded for another loop would raise RuntimeError as soon as a
    # second request had to wait on it, so replace it.
    is_stale = (
        loop is not None
        and bound_loop is not None
        and bound_loop is not loop
    )
    if lock is None or is_stale:
        lock = asyncio.Lock()
        tls.asyncio_lock = lock
        bound_loop = None

    if bound_loop is None and loop is not None:
        # Remember which loop this lock belongs to (kept alive until replaced).
        tls.asyncio_lock_loop = loop

    return lock


def with_api_lock(func):
    """
    Decorate a coroutine function so that calls to it are serialised.

    This decorator should be added to all streaming API methods which
    require access to the shared.generation_lock. The returned wrapper
    acquires the per-thread asyncio lock (see ``_get_api_lock``) before
    awaiting ``func`` and releases it afterwards, including when ``func``
    raises.

    Args:
        func: callable returning an awaitable, typically an ``async def``.

    Returns:
        An ``async`` wrapper with the same name and docstring as ``func``
        that forwards all arguments and returns ``func``'s result.
    """
    @functools.wraps(func)
    async def api_wrapper(*args, **kwargs):
        async with _get_api_lock(api_tls):
            return await func(*args, **kwargs)
    return api_wrapper