fix for issue #2475: Streaming api deadlock (#3048)

This commit is contained in:
Chris Rude 2023-07-08 19:21:20 -07:00 committed by GitHub
parent 5ac4e4da8b
commit 70b088843d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 112 additions and 60 deletions

View File

@ -4,7 +4,7 @@ from threading import Thread
from websockets.server import serve
from extensions.api.util import build_parameters, try_start_cloudflared
from extensions.api.util import build_parameters, try_start_cloudflared, with_api_lock
from modules import shared
from modules.chat import generate_chat_reply
from modules.text_generation import generate_reply
@ -12,10 +12,8 @@ from modules.text_generation import generate_reply
PATH = '/api/v1/stream'
async def _handle_connection(websocket, path):
if path == '/api/v1/stream':
async for message in websocket:
@with_api_lock
async def _handle_stream_message(websocket, message):
message = json.loads(message)
prompt = message['prompt']
@ -50,8 +48,9 @@ async def _handle_connection(websocket, path):
'message_num': message_num
}))
elif path == '/api/v1/chat-stream':
async for message in websocket:
@with_api_lock
async def _handle_chat_stream_message(websocket, message):
body = json.loads(message)
user_input = body['user_input']
@ -79,6 +78,17 @@ async def _handle_connection(websocket, path):
'message_num': message_num
}))
async def _handle_connection(websocket, path):
if path == '/api/v1/stream':
async for message in websocket:
await _handle_stream_message(websocket, message)
elif path == '/api/v1/chat-stream':
async for message in websocket:
await _handle_chat_stream_message(websocket, message)
else:
print(f'Streaming api: unknown path: {path}')
return

View File

@ -1,3 +1,6 @@
import asyncio
import functools
import threading
import time
import traceback
from threading import Thread
@ -8,6 +11,13 @@ from modules.chat import load_character_memoized
from modules.presets import load_preset_memoized
# We use a thread local to store the asyncio lock, so that each thread
# has its own lock. This isn't strictly necessary, but it makes it
# such that if we can support multiple worker threads in the future,
# thus handling multiple requests in parallel.
api_tls = threading.local()
def build_parameters(body, chat=False):
generate_params = {
@ -97,3 +107,35 @@ def _start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Call
time.sleep(3)
raise Exception('Could not start cloudflared.')
def _get_api_lock(tls) -> asyncio.Lock:
"""
The streaming and blocking API implementations each run on their own
thread, and multiplex requests using asyncio. If multiple outstanding
requests are received at once, we will try to acquire the shared lock
shared.generation_lock multiple times in succession in the same thread,
which will cause a deadlock.
To avoid this, we use this wrapper function to block on an asyncio
lock, and then try and grab the shared lock only while holding
the asyncio lock.
"""
if not hasattr(tls, "asyncio_lock"):
tls.asyncio_lock = asyncio.Lock()
return tls.asyncio_lock
def with_api_lock(func):
"""
This decorator should be added to all streaming API methods which
require access to the shared.generation_lock. It ensures that the
tls.asyncio_lock is acquired before the method is called, and
released afterwards.
"""
@functools.wraps(func)
async def api_wrapper(*args, **kwargs):
async with _get_api_lock(api_tls):
return await func(*args, **kwargs)
return api_wrapper