text-generation-webui/extensions/api/streaming_api.py

112 lines
3.5 KiB
Python
Raw Normal View History

import asyncio
2023-05-10 01:49:39 +00:00
import json
from threading import Thread
2023-05-10 01:49:39 +00:00
from websockets.server import serve
from extensions.api.util import build_parameters, try_start_cloudflared
2023-05-10 01:49:39 +00:00
from modules import shared
2023-05-20 21:42:17 +00:00
from modules.chat import generate_chat_reply
2023-05-10 01:49:39 +00:00
from modules.text_generation import generate_reply
# URL path of the text-streaming websocket endpoint; appended to the server
# address in the startup messages printed by _run_server.
PATH = '/api/v1/stream'
async def _stream_text_reply(websocket):
    """Serve the text-completion stream for one websocket connection.

    Every incoming frame is parsed as a JSON object that must contain a
    'prompt' key; the whole object is also handed to build_parameters to
    obtain the generation parameters. For each request the client receives a
    sequence of 'text_stream' events followed by one 'stream_end' event.

    Raises whatever json.loads, the 'prompt' lookup, build_parameters,
    generate_reply or websocket.send raise; nothing is caught here.
    """
    async for raw_message in websocket:
        body = json.loads(raw_message)

        prompt = body['prompt']
        generate_params = build_parameters(body)
        stopping_strings = generate_params.pop('stopping_strings')
        generate_params['stream'] = True

        generator = generate_reply(
            prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)

        # As we stream, only send the new text: every yielded value is sliced
        # from skip_index (this assumes the generator yields the reply
        # accumulated so far -- confirm against generate_reply).
        skip_index = 0
        message_num = 0

        for a in generator:
            to_send = a[skip_index:]
            # U+FFFD marks a partially decoded unicode character; don't send
            # it yet. skip_index is left untouched so the same span is
            # retried on the next yielded value. (A slice is never None, so
            # no separate None check is needed.)
            if '\ufffd' in to_send:
                continue

            await websocket.send(json.dumps({
                'event': 'text_stream',
                'message_num': message_num,
                'text': to_send
            }))

            # Yield control to the event loop between chunks.
            await asyncio.sleep(0)
            skip_index += len(to_send)
            message_num += 1

        # message_num now equals the number of 'text_stream' events sent.
        await websocket.send(json.dumps({
            'event': 'stream_end',
            'message_num': message_num
        }))


async def _stream_chat_reply(websocket):
    """Serve the chat stream for one websocket connection.

    Every incoming frame is parsed as a JSON object that must contain a
    'user_input' key and may contain 'regenerate' and '_continue' (both
    default to False). Each value yielded by generate_chat_reply is sent
    as-is under the 'history' key of a 'text_stream' event; a 'stream_end'
    event closes each request.

    Raises whatever json.loads, the 'user_input' lookup, build_parameters,
    generate_chat_reply or websocket.send raise; nothing is caught here.
    """
    async for raw_message in websocket:
        body = json.loads(raw_message)

        user_input = body['user_input']
        generate_params = build_parameters(body, chat=True)
        generate_params['stream'] = True
        regenerate = body.get('regenerate', False)
        _continue = body.get('_continue', False)

        generator = generate_chat_reply(
            user_input, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)

        message_num = 0
        for a in generator:
            await websocket.send(json.dumps({
                'event': 'text_stream',
                'message_num': message_num,
                'history': a
            }))

            # Yield control to the event loop between chunks.
            await asyncio.sleep(0)
            message_num += 1

        # message_num now equals the number of 'text_stream' events sent.
        await websocket.send(json.dumps({
            'event': 'stream_end',
            'message_num': message_num
        }))


async def _handle_connection(websocket, path):
    """Route a websocket connection to the streaming loop for its URL path.

    Args:
        websocket: the connection object; it is iterated for incoming frames
            and used to send the JSON events.
        path: the request path of the connection. PATH selects the text
            stream and '/api/v1/chat-stream' selects the chat stream.

    Any other path is reported on stdout and the handler returns without
    sending anything.
    """
    if path == PATH:
        await _stream_text_reply(websocket)
    elif path == '/api/v1/chat-stream':
        await _stream_chat_reply(websocket)
    else:
        print(f'Streaming api: unknown path: {path}')
async def _run(host: str, port: int):
    """Serve websocket connections on host:port until the task is cancelled.

    Connections are handled by _handle_connection. The server is created
    with ping_interval=None.
    """
    server = serve(_handle_connection, host, port, ping_interval=None)
    async with server:
        # Nothing ever resolves this future, so awaiting it keeps the server
        # context open forever.
        never_resolved = asyncio.Future()
        await never_resolved
def _run_server(port: int, share: bool = False):
    """Announce the streaming endpoint and run the websocket server.

    Args:
        port: TCP port the server binds to.
        share: when true, try_start_cloudflared is called for this port and
            the public URL it reports is printed; otherwise the local
            ws:// address is printed.

    The server binds to 0.0.0.0 when shared.args.listen is truthy and to
    127.0.0.1 otherwise. The final asyncio.run call blocks the calling thread.
    """
    if shared.args.listen:
        address = '0.0.0.0'
    else:
        address = '127.0.0.1'

    def announce_public_url(public_url: str):
        # Advertise the reported https:// URL with the wss:// scheme instead.
        public_url = public_url.replace('https://', 'wss://')
        print(f'Starting streaming server at public url {public_url}{PATH}')

    if not share:
        print(f'Starting streaming server at ws://{address}:{port}{PATH}')
    else:
        try:
            try_start_cloudflared(port, max_attempts=3, on_start=announce_public_url)
        except Exception as err:
            # Best effort: a failure here is only printed and the local
            # server is still started below.
            print(err)

    asyncio.run(_run(host=address, port=port))
def start_server(port: int, share: bool = False):
    """Launch the streaming server on a daemon thread and return immediately.

    Args:
        port: forwarded to _run_server.
        share: forwarded to _run_server.
    """
    worker = Thread(target=_run_server, args=(port, share), daemon=True)
    worker.start()