mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
parent
5ac4e4da8b
commit
70b088843d
@ -4,7 +4,7 @@ from threading import Thread
|
|||||||
|
|
||||||
from websockets.server import serve
|
from websockets.server import serve
|
||||||
|
|
||||||
from extensions.api.util import build_parameters, try_start_cloudflared
|
from extensions.api.util import build_parameters, try_start_cloudflared, with_api_lock
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.chat import generate_chat_reply
|
from modules.chat import generate_chat_reply
|
||||||
from modules.text_generation import generate_reply
|
from modules.text_generation import generate_reply
|
||||||
@ -12,72 +12,82 @@ from modules.text_generation import generate_reply
|
|||||||
PATH = '/api/v1/stream'
|
PATH = '/api/v1/stream'
|
||||||
|
|
||||||
|
|
||||||
|
@with_api_lock
|
||||||
|
async def _handle_stream_message(websocket, message):
|
||||||
|
message = json.loads(message)
|
||||||
|
|
||||||
|
prompt = message['prompt']
|
||||||
|
generate_params = build_parameters(message)
|
||||||
|
stopping_strings = generate_params.pop('stopping_strings')
|
||||||
|
generate_params['stream'] = True
|
||||||
|
|
||||||
|
generator = generate_reply(
|
||||||
|
prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
|
||||||
|
|
||||||
|
# As we stream, only send the new bytes.
|
||||||
|
skip_index = 0
|
||||||
|
message_num = 0
|
||||||
|
|
||||||
|
for a in generator:
|
||||||
|
to_send = a[skip_index:]
|
||||||
|
if to_send is None or chr(0xfffd) in to_send: # partial unicode character, don't send it yet.
|
||||||
|
continue
|
||||||
|
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
'event': 'text_stream',
|
||||||
|
'message_num': message_num,
|
||||||
|
'text': to_send
|
||||||
|
}))
|
||||||
|
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
skip_index += len(to_send)
|
||||||
|
message_num += 1
|
||||||
|
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
'event': 'stream_end',
|
||||||
|
'message_num': message_num
|
||||||
|
}))
|
||||||
|
|
||||||
|
|
||||||
|
@with_api_lock
|
||||||
|
async def _handle_chat_stream_message(websocket, message):
|
||||||
|
body = json.loads(message)
|
||||||
|
|
||||||
|
user_input = body['user_input']
|
||||||
|
generate_params = build_parameters(body, chat=True)
|
||||||
|
generate_params['stream'] = True
|
||||||
|
regenerate = body.get('regenerate', False)
|
||||||
|
_continue = body.get('_continue', False)
|
||||||
|
|
||||||
|
generator = generate_chat_reply(
|
||||||
|
user_input, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)
|
||||||
|
|
||||||
|
message_num = 0
|
||||||
|
for a in generator:
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
'event': 'text_stream',
|
||||||
|
'message_num': message_num,
|
||||||
|
'history': a
|
||||||
|
}))
|
||||||
|
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
message_num += 1
|
||||||
|
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
'event': 'stream_end',
|
||||||
|
'message_num': message_num
|
||||||
|
}))
|
||||||
|
|
||||||
|
|
||||||
async def _handle_connection(websocket, path):
|
async def _handle_connection(websocket, path):
|
||||||
|
|
||||||
if path == '/api/v1/stream':
|
if path == '/api/v1/stream':
|
||||||
async for message in websocket:
|
async for message in websocket:
|
||||||
message = json.loads(message)
|
await _handle_stream_message(websocket, message)
|
||||||
|
|
||||||
prompt = message['prompt']
|
|
||||||
generate_params = build_parameters(message)
|
|
||||||
stopping_strings = generate_params.pop('stopping_strings')
|
|
||||||
generate_params['stream'] = True
|
|
||||||
|
|
||||||
generator = generate_reply(
|
|
||||||
prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
|
|
||||||
|
|
||||||
# As we stream, only send the new bytes.
|
|
||||||
skip_index = 0
|
|
||||||
message_num = 0
|
|
||||||
|
|
||||||
for a in generator:
|
|
||||||
to_send = a[skip_index:]
|
|
||||||
if to_send is None or chr(0xfffd) in to_send: # partial unicode character, don't send it yet.
|
|
||||||
continue
|
|
||||||
|
|
||||||
await websocket.send(json.dumps({
|
|
||||||
'event': 'text_stream',
|
|
||||||
'message_num': message_num,
|
|
||||||
'text': to_send
|
|
||||||
}))
|
|
||||||
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
skip_index += len(to_send)
|
|
||||||
message_num += 1
|
|
||||||
|
|
||||||
await websocket.send(json.dumps({
|
|
||||||
'event': 'stream_end',
|
|
||||||
'message_num': message_num
|
|
||||||
}))
|
|
||||||
|
|
||||||
elif path == '/api/v1/chat-stream':
|
elif path == '/api/v1/chat-stream':
|
||||||
async for message in websocket:
|
async for message in websocket:
|
||||||
body = json.loads(message)
|
await _handle_chat_stream_message(websocket, message)
|
||||||
|
|
||||||
user_input = body['user_input']
|
|
||||||
generate_params = build_parameters(body, chat=True)
|
|
||||||
generate_params['stream'] = True
|
|
||||||
regenerate = body.get('regenerate', False)
|
|
||||||
_continue = body.get('_continue', False)
|
|
||||||
|
|
||||||
generator = generate_chat_reply(
|
|
||||||
user_input, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)
|
|
||||||
|
|
||||||
message_num = 0
|
|
||||||
for a in generator:
|
|
||||||
await websocket.send(json.dumps({
|
|
||||||
'event': 'text_stream',
|
|
||||||
'message_num': message_num,
|
|
||||||
'history': a
|
|
||||||
}))
|
|
||||||
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
message_num += 1
|
|
||||||
|
|
||||||
await websocket.send(json.dumps({
|
|
||||||
'event': 'stream_end',
|
|
||||||
'message_num': message_num
|
|
||||||
}))
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
print(f'Streaming api: unknown path: {path}')
|
print(f'Streaming api: unknown path: {path}')
|
||||||
|
@ -1,3 +1,6 @@
|
|||||||
|
import asyncio
|
||||||
|
import functools
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
@ -8,6 +11,13 @@ from modules.chat import load_character_memoized
|
|||||||
from modules.presets import load_preset_memoized
|
from modules.presets import load_preset_memoized
|
||||||
|
|
||||||
|
|
||||||
|
# We use a thread local to store the asyncio lock, so that each thread
|
||||||
|
# has its own lock. This isn't strictly necessary, but it makes it
|
||||||
|
# such that if we can support multiple worker threads in the future,
|
||||||
|
# thus handling multiple requests in parallel.
|
||||||
|
api_tls = threading.local()
|
||||||
|
|
||||||
|
|
||||||
def build_parameters(body, chat=False):
|
def build_parameters(body, chat=False):
|
||||||
|
|
||||||
generate_params = {
|
generate_params = {
|
||||||
@ -97,3 +107,35 @@ def _start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Call
|
|||||||
time.sleep(3)
|
time.sleep(3)
|
||||||
|
|
||||||
raise Exception('Could not start cloudflared.')
|
raise Exception('Could not start cloudflared.')
|
||||||
|
|
||||||
|
|
||||||
|
def _get_api_lock(tls) -> asyncio.Lock:
|
||||||
|
"""
|
||||||
|
The streaming and blocking API implementations each run on their own
|
||||||
|
thread, and multiplex requests using asyncio. If multiple outstanding
|
||||||
|
requests are received at once, we will try to acquire the shared lock
|
||||||
|
shared.generation_lock multiple times in succession in the same thread,
|
||||||
|
which will cause a deadlock.
|
||||||
|
|
||||||
|
To avoid this, we use this wrapper function to block on an asyncio
|
||||||
|
lock, and then try and grab the shared lock only while holding
|
||||||
|
the asyncio lock.
|
||||||
|
"""
|
||||||
|
if not hasattr(tls, "asyncio_lock"):
|
||||||
|
tls.asyncio_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
return tls.asyncio_lock
|
||||||
|
|
||||||
|
|
||||||
|
def with_api_lock(func):
|
||||||
|
"""
|
||||||
|
This decorator should be added to all streaming API methods which
|
||||||
|
require access to the shared.generation_lock. It ensures that the
|
||||||
|
tls.asyncio_lock is acquired before the method is called, and
|
||||||
|
released afterwards.
|
||||||
|
"""
|
||||||
|
@functools.wraps(func)
|
||||||
|
async def api_wrapper(*args, **kwargs):
|
||||||
|
async with _get_api_lock(api_tls):
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
return api_wrapper
|
||||||
|
Loading…
Reference in New Issue
Block a user