text-generation-webui/extensions/openai/script.py

270 lines
9.6 KiB
Python
Raw Normal View History

2023-05-02 22:05:38 -04:00
import json
import os
import traceback
2023-05-02 21:49:53 -04:00
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread
2023-05-09 21:49:39 -04:00
2023-05-02 21:49:53 -04:00
from modules import shared
from extensions.openai.tokens import token_count, token_encode, token_decode
import extensions.openai.models as OAImodels
import extensions.openai.edits as OAIedits
import extensions.openai.embeddings as OAIembeddings
import extensions.openai.images as OAIimages
import extensions.openai.moderations as OAImoderations
import extensions.openai.completions as OAIcompletions
from extensions.openai.errors import *
from extensions.openai.utils import debug_msg
from extensions.openai.defaults import (get_default_req_params, default, clamp)
2023-05-02 21:49:53 -04:00
# Extension settings. The API port is read from the OPENEDAI_PORT environment
# variable when it is set, otherwise it falls back to 5001.
params = {
    'port': int(os.environ.get('OPENEDAI_PORT', 5001)),
}
2023-07-12 14:33:25 -04:00
2023-05-02 21:49:53 -04:00
class Handler(BaseHTTPRequestHandler):
    """HTTP request handler exposing an OpenAI-compatible REST API.

    GET serves model listing/info/loading (``/v1/models``, ``/v1/engines``)
    and a stub billing-usage endpoint. POST serves completions and chat
    completions (optionally streamed as server-sent events), edits, images,
    embeddings, moderations and a few non-standard token helpers. Exceptions
    raised while handling a request are turned into OpenAI-style JSON error
    responses by ``openai_error_handler``.
    """

    def send_access_control_headers(self):
        """Emit permissive CORS headers so browser clients can call the API."""
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Access-Control-Allow-Credentials", "true")
        self.send_header(
            "Access-Control-Allow-Methods",
            "GET,HEAD,OPTIONS,POST,PUT"
        )
        self.send_header(
            "Access-Control-Allow-Headers",
            "Origin, Accept, X-Requested-With, Content-Type, "
            "Access-Control-Request-Method, Access-Control-Request-Headers, "
            "Authorization"
        )

    def do_OPTIONS(self):
        """Answer CORS preflight requests with 200 and the CORS headers."""
        self.send_response(200)
        self.send_access_control_headers()
        self.send_header('Content-Type', 'application/json')
        self.end_headers()
        self.wfile.write("OK".encode('utf-8'))

    def start_sse(self):
        """Send the 200 response headers that open a server-sent-events stream."""
        self.send_response(200)
        self.send_access_control_headers()
        self.send_header('Content-Type', 'text/event-stream')
        self.send_header('Cache-Control', 'no-cache')
        # self.send_header('Connection', 'keep-alive')
        self.end_headers()

    def send_sse(self, chunk: dict):
        """Serialize *chunk* as JSON and write it as a single SSE ``data:`` event."""
        response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
        debug_msg(response)
        self.wfile.write(response.encode('utf-8'))

    def end_sse(self):
        """Write the ``[DONE]`` sentinel event that terminates an SSE stream."""
        self.wfile.write('data: [DONE]\r\n\r\n'.encode('utf-8'))

    def return_json(self, ret: dict, code: int = 200, no_debug=False):
        """Send *ret* as a JSON response with HTTP status *code*.

        The encoded body is also passed to ``debug_msg`` unless *no_debug* is
        set; endpoints with bulky payloads (e.g. embeddings, images) set it.
        """
        self.send_response(code)
        self.send_access_control_headers()
        self.send_header('Content-Type', 'application/json')
        self.end_headers()
        r_utf8 = json.dumps(ret).encode('utf-8')
        self.wfile.write(r_utf8)
        if not no_debug:
            debug_msg(r_utf8)

    def openai_error(self, message, code=500, error_type='APIError', param='', internal_message=''):
        """Send an OpenAI-style ``{"error": {...}}`` JSON response.

        Args:
            message: Human-readable error text returned to the client.
            code: HTTP status code, also echoed in the error body.
            error_type: Value of the ``type`` field in the error body.
            param: Name of the offending request parameter, if any.
            internal_message: Extra detail (e.g. a traceback) printed to the
                server console only; it is not included in the response.
        """
        error_resp = {
            'error': {
                'message': message,
                'code': code,
                'type': error_type,
                'param': param,
            }
        }
        if internal_message:
            print(internal_message)
        # error_resp['internal_message'] = internal_message
        self.return_json(error_resp, code)

    def openai_error_handler(func):
        """Decorator turning exceptions from a ``do_*`` handler into error responses.

        Known API errors keep their own code and type (plus ``param`` for
        InvalidRequestError). Any other exception becomes a 500 whose full
        traceback is printed to the server console.
        """
        from functools import wraps

        @wraps(func)
        def wrapper(self):
            try:
                func(self)
            except ServiceUnavailableError as e:
                self.openai_error(e.message, e.code, e.error_type, internal_message=e.internal_message)
            except InvalidRequestError as e:
                self.openai_error(e.message, e.code, e.error_type, e.param, internal_message=e.internal_message)
            except OpenAIError as e:
                self.openai_error(e.message, e.code, e.error_type, internal_message=e.internal_message)
            except Exception as e:
                self.openai_error(repr(e), 500, 'OpenAIError', internal_message=traceback.format_exc())
        return wrapper

    @openai_error_handler
    def do_GET(self):
        """Serve model list/info/load requests and the billing-usage stub."""
        debug_msg(self.requestline)
        debug_msg(self.headers)

        if self.path.startswith('/v1/engines') or self.path.startswith('/v1/models'):
            # Decide "legacy" by prefix, not substring: a model whose name
            # merely contains "engines" must not hit the load_model branch.
            is_legacy = self.path.startswith('/v1/engines')
            is_list = self.path in ['/v1/engines', '/v1/models']

            if is_legacy and not is_list:
                # Legacy GET /v1/engines/<name> calls OAImodels.load_model.
                model_name = self.path[len('/v1/engines/'):]
                resp = OAImodels.load_model(model_name)
            elif is_list:
                resp = OAImodels.list_models(is_legacy)
            else:
                model_name = self.path[len('/v1/models/'):]
                resp = OAImodels.model_info(model_name)

            self.return_json(resp)
        elif '/billing/usage' in self.path:
            # Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
            self.return_json({"total_usage": 0}, no_debug=True)
        else:
            self.send_error(404)

    @openai_error_handler
    def do_POST(self):
        """Parse the JSON request body and dispatch it to the matching endpoint."""
        debug_msg(self.requestline)
        debug_msg(self.headers)

        # Treat a missing Content-Length as an empty body instead of failing
        # on int(None).
        content_length = int(self.headers.get('Content-Length') or 0)
        raw_body = self.rfile.read(content_length) if content_length > 0 else b''
        body = json.loads(raw_body.decode('utf-8')) if raw_body else {}
        debug_msg(body)

        if '/completions' in self.path or '/generate' in self.path:
            if not shared.model:
                self.openai_error("No model loaded.")
                return

            is_legacy = '/generate' in self.path
            is_chat = 'chat' in self.path

            if body.get('stream', False):
                self.start_sse()

                if is_chat:
                    response = OAIcompletions.stream_chat_completions(body, is_legacy=is_legacy)
                else:
                    response = OAIcompletions.stream_completions(body, is_legacy=is_legacy)

                # NOTE(review): the SSE headers are already sent here, so if the
                # stream raises, openai_error_handler writes a second HTTP status
                # line into the event stream. Consider emitting an SSE error event.
                for resp in response:
                    self.send_sse(resp)

                self.end_sse()
            else:
                if is_chat:
                    response = OAIcompletions.chat_completions(body, is_legacy=is_legacy)
                else:
                    response = OAIcompletions.completions(body, is_legacy=is_legacy)

                self.return_json(response)

        elif '/edits' in self.path:
            # deprecated
            if not shared.model:
                self.openai_error("No model loaded.")
                return

            req_params = get_default_req_params()
            instruction = body['instruction']
            input_text = body.get('input', '')
            temperature = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999)  # fixup absolute 0.0
            top_p = clamp(default(body, 'top_p', req_params['top_p']), 0.001, 1.0)

            response = OAIedits.edits(instruction, input_text, temperature, top_p)
            self.return_json(response)

        elif '/images/generations' in self.path and 'SD_WEBUI_URL' in os.environ:
            prompt = body['prompt']
            size = default(body, 'size', '1024x1024')
            response_format = default(body, 'response_format', 'url')  # or b64_json
            n = default(body, 'n', 1)  # ignore the batch limits of max 10

            response = OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n)
            self.return_json(response, no_debug=True)

        elif '/embeddings' in self.path:
            encoding_format = body.get('encoding_format', '')

            # 'text' is accepted as a fallback key for the input.
            input_data = body.get('input', body.get('text', ''))
            if not input_data:
                raise InvalidRequestError("Missing required argument input", param='input')

            # A single string is embedded as a one-element batch.
            if isinstance(input_data, str):
                input_data = [input_data]

            response = OAIembeddings.embeddings(input_data, encoding_format)
            self.return_json(response, no_debug=True)

        elif '/moderations' in self.path:
            # .get so a missing key yields the 400 below instead of a KeyError (500).
            input_data = body.get('input')
            if not input_data:
                raise InvalidRequestError("Missing required argument input", param='input')

            response = OAImoderations.moderations(input_data)
            self.return_json(response, no_debug=True)

        elif self.path == '/api/v1/token-count':
            # NOT STANDARD. lifted from the api extension, but it's still very useful to calculate tokenized length client side.
            response = token_count(body['prompt'])
            self.return_json(response, no_debug=True)

        elif self.path == '/api/v1/token/encode':
            # NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models
            encoding_format = body.get('encoding_format', '')
            response = token_encode(body['input'], encoding_format)
            self.return_json(response, no_debug=True)

        elif self.path == '/api/v1/token/decode':
            # NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models
            encoding_format = body.get('encoding_format', '')
            response = token_decode(body['input'], encoding_format)
            self.return_json(response, no_debug=True)

        else:
            self.send_error(404)
def run_server():
    """Serve the OpenAI-compatible API forever on ``params['port']``.

    Binds to all interfaces when ``shared.args.listen`` is set, otherwise to
    127.0.0.1. When ``shared.args.share`` is set, a public URL is requested
    via ``flask_cloudflared``. If that package is not installed, the local
    address is printed as well, so the user still learns where the API runs.
    """
    server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
    local_base = f'http://{server_addr[0]}:{server_addr[1]}/v1'

    with ThreadingHTTPServer(server_addr, Handler) as server:
        if shared.args.share:
            try:
                from flask_cloudflared import _run_cloudflared
                public_url = _run_cloudflared(params['port'], params['port'] + 1)
                print(f'Starting OpenAI compatible api at\nOPENAI_API_BASE={public_url}/v1')
            except ImportError:
                print('You should install flask_cloudflared manually')
                # Without the tunnel the API is still served locally; say where.
                print(f'Starting OpenAI compatible api:\nOPENAI_API_BASE={local_base}')
        else:
            print(f'Starting OpenAI compatible api:\nOPENAI_API_BASE={local_base}')

        server.serve_forever()
def setup():
    """Launch ``run_server`` on a background daemon thread.

    Being a daemon thread, it does not keep the interpreter alive once the
    main program exits.
    """
    server_thread = Thread(target=run_server, daemon=True)
    server_thread.start()