# Source: text-generation-webui/extensions/openai/script.py (794 lines, 33 KiB, Python)

import base64
import json
import os
import time
import requests
import yaml
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread
import numpy as np
from modules import shared
from modules.text_generation import encode, generate_reply
# Extension settings. The listening port is taken from OPENEDAI_PORT when that
# variable is set, and is 5001 otherwise.
params = {
    'port': int(os.environ.get('OPENEDAI_PORT', 5001)),
}
# Verbose request/response logging is enabled by the mere presence of
# OPENEDAI_DEBUG in the environment; its value is not inspected.
debug = 'OPENEDAI_DEBUG' in os.environ
# Slightly different defaults for OpenAI's API
# The POST handler copies this dict for each request, overwrites a handful of
# keys from the request body / shared.settings, and hands the result to
# generate_reply(). Keys that are not overwritten keep the values below.
default_req_params = {
    'max_new_tokens': 200,  # replaced per request (max_tokens / length)
    'temperature': 1.0,  # replaced per request, clamped to 0.001..1.999
    'top_p': 1.0,  # replaced per request, clamped to 0.001..1.0
    'top_k': 1,  # completions fill this from the request's 'best_of'
    'repetition_penalty': 1.18,
    'encoder_repetition_penalty': 1.0,
    'suffix': None,
    'stream': False,
    'echo': False,
    'seed': -1,  # replaced by shared.settings['seed'] when present
    # 'n' : default(body, 'n', 1), # 'n' doesn't have a direct map
    'truncation_length': 2048,  # replaced per request, never above the configured value
    'add_bos_token': True,  # replaced by shared.settings['add_bos_token'] when present
    'do_sample': True,
    'typical_p': 1.0,
    'epsilon_cutoff': 0,  # In units of 1e-4
    'eta_cutoff': 0,  # In units of 1e-4
    'tfs': 1.0,
    'top_a': 0.0,
    'min_length': 0,
    'no_repeat_ngram_size': 0,
    'num_beams': 1,
    'penalty_alpha': 0.0,
    'length_penalty': 1,
    'early_stopping': False,
    'mirostat_mode': 0,
    'mirostat_tau': 5,
    'mirostat_eta': 0.1,
    'ban_eos_token': False,
    'skip_special_tokens': True,
    'custom_stopping_strings': [],  # completions replace this with the effective stop list
}
# Optional, install the module and download the model to enable
# v1/embeddings
try:
    from sentence_transformers import SentenceTransformer
except ImportError:
    # Deliberately ignored: SentenceTransformer simply stays undefined. The
    # model load in run_server() then fails inside its own try/except,
    # embedding_model stays None, and the embeddings route is not served.
    pass
# Name of the sentence-transformers model used for embeddings;
# OPENEDAI_EMBEDDING_MODEL overrides the built-in choice.
st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", "all-mpnet-base-v2")
# Filled in by run_server(); while it is None the embeddings route is not served.
embedding_model = None
# Stop sequences appended to every completion request's stop list.
standard_stopping_strings = [
    '\nsystem:',
    '\nuser:',
    '\nhuman:',
    '\nassistant:',
    '\n###',
]
# little helper to get defaults if arg is present but None and should be the same type as default.
def default(dic, key, default):
    """Return ``dic[key]`` coerced to the type of ``default``, or ``default`` itself.

    Args:
        dic: Mapping to read from (a request body or a settings dict).
        key: Key to look up.
        default: Fallback value; its exact type is also the expected type.

    Returns:
        * the stored value when it already has exactly the type of ``default``;
        * the stored value converted to that type when the conversion is
          lossless (e.g. ``1`` for a ``1.0`` default);
        * the stored value unchanged when ``default`` is ``None``, because
          there is no type to enforce;
        * ``default`` in every other case (key missing, value ``None``,
          value of an incompatible type).
    """
    val = dic.get(key, default)

    # A None default carries no type information. Trying to "coerce" through
    # NoneType always raised, which used to throw away every supplied value.
    if default is None or type(val) is type(default):
        return val

    # maybe it's just something like 1 instead of 1.0
    try:
        coerced = type(default)(val)
        if type(val)(coerced) == val:  # if it's the same value passed in, it's ok.
            return coerced
    except (TypeError, ValueError, OverflowError):
        # Not convertible (None, wrong container, non-numeric text, inf, ...).
        pass

    return default
def clamp(value, minvalue, maxvalue):
    """Limit ``value`` to the closed range ``[minvalue, maxvalue]``.

    The upper bound is applied first and the lower bound last, so the lower
    bound wins if the two bounds are inconsistent.
    """
    capped = maxvalue if maxvalue < value else value
    return capped if capped > minvalue else minvalue
def deduce_template():
    """Return a prompt template with ``{instruction}`` and ``{input}`` placeholders.

    The template is derived from the instruction-following YAML file named by
    ``shared.settings['instruction_template']``: the configured context,
    followed by the turn template up to (not including) the bot message
    marker. The built-in Alpaca-style template is returned when the configured
    template is 'Alpaca' / 'Alpaca-Input', or when the YAML file cannot be
    read or does not have the expected structure.

    Returns:
        str: a template suitable for ``str.format(instruction=..., input=...)``.
    """
    # Alpaca is verbose so a good default prompt
    default_template = (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
    )

    template_name = shared.settings['instruction_template']

    # Use the special instruction/input/response template for anything trained like Alpaca
    if template_name in ['Alpaca', 'Alpaca-Input']:
        return default_template

    try:
        # Context manager so the file handle is released even if parsing fails.
        with open(f"characters/instruction-following/{template_name}.yaml", 'r') as f:
            instruct = yaml.safe_load(f)

        template = instruct['turn_template']
        template = template\
            .replace('<|user|>', instruct.get('user', ''))\
            .replace('<|bot|>', instruct.get('bot', ''))\
            .replace('<|user-message|>', '{instruction}\n{input}')

        # partition() keeps the whole template when the marker is absent;
        # slicing with find() == -1 would silently drop the last character.
        prompt_part = template.partition('<|bot-message|>')[0]
        return instruct.get('context', '') + prompt_part.rstrip(' ')
    except (OSError, ValueError, KeyError, TypeError, AttributeError, yaml.YAMLError):
        # Best effort: a missing, unreadable or malformed template file falls
        # back to the built-in template instead of failing the request.
        return default_template
def float_list_to_base64(float_list):
    """Encode a sequence of floats as base64 text.

    The values are packed as a float32 array (the layout the OpenAI client
    expects), the raw bytes are base64-encoded, and the result is returned as
    an ASCII ``str``.
    """
    packed = np.array(float_list, dtype="float32").tobytes()
    return base64.b64encode(packed).decode('ascii')
class Handler(BaseHTTPRequestHandler):
    """Request handler emulating a subset of the OpenAI HTTP API.

    GET serves the model list and a stub billing endpoint. POST serves
    completions, chat completions, the legacy ``generate`` call, edits, image
    generation (forwarded to a Stable Diffusion web UI), embeddings, a
    moderation stub and a non-standard token counter.
    """

    def send_access_control_headers(self):
        """Queue the permissive CORS headers on the response being built."""
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Access-Control-Allow-Credentials", "true")
        self.send_header(
            "Access-Control-Allow-Methods",
            "GET,HEAD,OPTIONS,POST,PUT"
        )
        self.send_header(
            "Access-Control-Allow-Headers",
            "Origin, Accept, X-Requested-With, Content-Type, "
            "Access-Control-Request-Method, Access-Control-Request-Headers, "
            "Authorization"
        )

    def _start_json_response(self):
        """Send status 200 with CORS headers and a JSON content type, then end the headers."""
        self.send_response(200)
        self.send_access_control_headers()
        self.send_header('Content-Type', 'application/json')
        self.end_headers()

    def _write_json(self, payload):
        """Serialize ``payload`` as JSON and write it to the response body."""
        self.wfile.write(json.dumps(payload).encode('utf-8'))

    def _write_stream_chunk(self, chunk, trailer=''):
        """Write one streaming event: hex length line, ``data: <json>`` record, optional trailer."""
        data_chunk = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
        chunk_size = hex(len(data_chunk))[2:] + '\r\n'
        self.wfile.write((chunk_size + data_chunk + trailer).encode('utf-8'))

    def do_OPTIONS(self):
        """Answer CORS preflight requests with 200 and the body ``OK``."""
        self._start_json_response()
        self.wfile.write("OK".encode('utf-8'))

    def do_GET(self):
        """Serve ``/v1/models``, ``/v1/models/<id>`` and ``.../billing/usage``; 404 otherwise."""
        if self.path.startswith('/v1/models'):
            self._start_json_response()

            if self.path == '/v1/models':
                # TODO: list all models and allow model changes via API? LoRAs?
                # This API should list capabilities, limits and pricing...
                models = [{
                    "id": shared.model_name,  # The real chat/completions model
                    "object": "model",
                    "owned_by": "user",
                    "permission": []
                }, {
                    "id": st_model,  # The real sentence transformer embeddings model
                    "object": "model",
                    "owned_by": "user",
                    "permission": []
                }, {  # these are expected by so much, so include some here as a dummy
                    "id": "gpt-3.5-turbo",  # /v1/chat/completions
                    "object": "model",
                    "owned_by": "user",
                    "permission": []
                }, {
                    "id": "text-curie-001",  # /v1/completions, 2k context
                    "object": "model",
                    "owned_by": "user",
                    "permission": []
                }, {
                    "id": "text-davinci-002",  # /v1/embeddings text-embedding-ada-002:1536, text-davinci-002:768
                    "object": "model",
                    "owned_by": "user",
                    "permission": []
                }]
                self._write_json({
                    "object": "list",
                    "data": models,
                })
            else:
                # Any requested id is echoed back as if it existed.
                the_model_name = self.path[len('/v1/models/'):]
                self._write_json({
                    "id": the_model_name,
                    "object": "model",
                    "owned_by": "user",
                    "permission": []
                })

        elif '/billing/usage' in self.path:
            # Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
            self._start_json_response()
            self._write_json({
                "total_usage": 0,
            })

        else:
            self.send_error(404)

    def do_POST(self):
        """Parse the JSON request body and dispatch on the request path."""
        if debug:
            print(self.headers)  # did you know... python-openai sends your linux kernel & python version?
        content_length = int(self.headers['Content-Length'])
        body = json.loads(self.rfile.read(content_length).decode('utf-8'))

        if debug:
            print(body)

        if '/completions' in self.path or '/generate' in self.path:
            self._handle_completions(body)
        elif '/edits' in self.path:
            self._handle_edits(body)
        elif '/images/generations' in self.path and 'SD_WEBUI_URL' in os.environ:
            self._handle_image_generations(body)
        elif '/embeddings' in self.path and embedding_model is not None:
            self._handle_embeddings(body)
        elif '/moderations' in self.path:
            self._handle_moderations()
        elif self.path == '/api/v1/token-count':
            self._handle_token_count(body)
        else:
            print(self.path, self.headers)
            self.send_error(404)

    def _load_role_formats(self):
        """Return the per-role message formats used to build a chat prompt.

        Starts from a generic ``user:`` / ``assistant:`` layout and replaces it
        with formats derived from the configured instruction template file
        when that file can be loaded and has the expected keys.
        """
        role_formats = {
            'user': 'user: {message}\n',
            'assistant': 'assistant: {message}\n',
            'system': '{message}',
            'context': 'You are a helpful assistant. Answer as concisely as possible.',
            'prompt': 'assistant:',
        }

        # Instruct models can be much better
        try:
            # Context manager so the file handle is not leaked.
            with open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r') as f:
                instruct = yaml.safe_load(f)

            template = instruct['turn_template']
            system_message_template = "{message}"
            system_message_default = instruct['context']
            bot_start = template.find('<|bot|>')  # So far, 100% of instruction templates have this token
            user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct['user'])
            bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct['bot'])
            bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')

            role_formats = {
                'user': user_message_template,
                'assistant': bot_message_template,
                'system': system_message_template,
                'context': system_message_default,
                'prompt': bot_prompt,
            }

            if debug:
                print(f"Loaded instruction role format: {shared.settings['instruction_template']}")
        except (OSError, ValueError, KeyError, TypeError, AttributeError, yaml.YAMLError):
            # Best effort: keep the generic formats when the template file is
            # missing, unreadable or lacks the expected keys.
            if debug:
                print("Loaded default role format.")

        return role_formats

    def _build_chat_prompt(self, body, truncation_length):
        """Flatten a chat request into one prompt string.

        System text is always kept; chat messages are added newest-first until
        the token budget ``truncation_length`` is used up, and older messages
        are dropped.

        Returns:
            tuple: ``(prompt, token_count)``.
        """
        messages = body['messages']
        role_formats = self._load_role_formats()

        system_msgs = []
        chat_msgs = []

        # You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
        context_msg = role_formats['system'].format(message=role_formats['context']) if role_formats['context'] else ''
        if context_msg:
            system_msgs.append(context_msg)

        # Maybe they sent both? This is not documented in the API, but some clients seem to do this.
        if 'prompt' in body:
            system_msgs.append(role_formats['system'].format(message=body['prompt']))

        for m in messages:
            role = m['role']
            content = m['content']
            msg = role_formats[role].format(message=content)
            if role == 'system':
                system_msgs.append(msg)
            else:
                chat_msgs.append(msg)

        # can't really truncate the system messages
        system_msg = '\n'.join(system_msgs)
        # Guarded: with an empty context and no system messages the string is
        # empty, and indexing its last character would raise IndexError.
        if system_msg and not system_msg.endswith('\n'):
            system_msg = system_msg + '\n'

        system_token_count = len(encode(system_msg)[0])
        remaining_tokens = truncation_length - system_token_count
        chat_msg = ''

        while chat_msgs:
            new_msg = chat_msgs.pop()
            new_size = len(encode(new_msg)[0])
            if new_size <= remaining_tokens:
                chat_msg = new_msg + chat_msg
                remaining_tokens -= new_size
            else:
                print(f"Warning: too many messages for context size, dropping {len(chat_msgs) + 1} oldest message(s).")
                break

        prompt = system_msg + chat_msg + role_formats['prompt']
        return prompt, len(encode(prompt)[0])

    def _handle_completions(self, body):
        """Serve text completions, chat completions and the legacy generate call.

        Maps the OpenAI request parameters onto ``generate_reply`` parameters,
        builds the prompt, runs generation, and writes either a single JSON
        response or a sequence of streaming chunks ending in ``[DONE]``.
        """
        is_legacy = '/generate' in self.path
        is_chat = 'chat' in self.path
        resp_list = 'data' if is_legacy else 'choices'

        # XXX model is ignored for now; the loaded model is always used.
        model = shared.model_name
        created_time = int(time.time())
        cmpl_id = "chatcmpl-%d" % (created_time) if is_chat else "conv-%d" % (created_time)

        # Try to use openai defaults or map them to something with the same intent
        stopping_strings = default(shared.settings, 'custom_stopping_strings', [])
        if 'stop' in body:
            if isinstance(body['stop'], str):
                stopping_strings = [body['stop']]
            elif isinstance(body['stop'], list):
                stopping_strings = body['stop']

        truncation_length = default(shared.settings, 'truncation_length', 2048)
        truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length)

        default_max_tokens = truncation_length if is_chat else 16  # completions default, chat default is 'inf' so we need to cap it.

        max_tokens_str = 'length' if is_legacy else 'max_tokens'
        max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))
        # if the user assumes OpenAI, the max_tokens is way too large - it is capped to the remaining context below

        req_params = default_req_params.copy()
        req_params['max_new_tokens'] = max_tokens
        req_params['truncation_length'] = truncation_length
        req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999)  # fixup absolute 0.0
        req_params['top_p'] = clamp(default(body, 'top_p', default_req_params['top_p']), 0.001, 1.0)
        req_params['top_k'] = default(body, 'best_of', default_req_params['top_k'])
        req_params['suffix'] = default(body, 'suffix', default_req_params['suffix'])
        req_params['stream'] = default(body, 'stream', default_req_params['stream'])
        req_params['echo'] = default(body, 'echo', default_req_params['echo'])
        req_params['seed'] = shared.settings.get('seed', default_req_params['seed'])
        req_params['add_bos_token'] = shared.settings.get('add_bos_token', default_req_params['add_bos_token'])

        self.send_response(200)
        self.send_access_control_headers()
        if req_params['stream']:
            self.send_header('Content-Type', 'text/event-stream')
            self.send_header('Cache-Control', 'no-cache')
            # self.send_header('Connection', 'keep-alive')
        else:
            self.send_header('Content-Type', 'application/json')
        self.end_headers()

        if is_chat:
            # Chat Completions
            stream_object_type = 'chat.completions.chunk'
            object_type = 'chat.completions'
            prompt, token_count = self._build_chat_prompt(body, req_params['truncation_length'])
        else:
            # Text Completions
            stream_object_type = 'text_completion.chunk'
            object_type = 'text_completion'

            # ... encoded as a string, array of strings, array of tokens, or array of token arrays.
            if is_legacy:
                prompt = body['context']  # Older engines.generate API
            else:
                prompt = body['prompt']  # XXX this can be different types

            if isinstance(prompt, list):
                prompt = ''.join(prompt)  # XXX this is wrong... need to split out to multiple calls?

            token_count = len(encode(prompt)[0])
            if token_count >= req_params['truncation_length']:
                # Scale by the same (per-request) limit the check above uses,
                # keeping the tail of the prompt.
                new_len = int(len(prompt) * req_params['truncation_length'] / token_count)
                prompt = prompt[-new_len:]
                new_token_count = len(encode(prompt)[0])
                print(f"Warning: truncating prompt to {new_len} characters, was {token_count} tokens. Now: {new_token_count} tokens.")
                token_count = new_token_count

        if req_params['truncation_length'] - token_count < req_params['max_new_tokens']:
            print(f"Warning: Ignoring max_new_tokens ({req_params['max_new_tokens']}), too large for the remaining context. Remaining tokens: {req_params['truncation_length'] - token_count}")
            req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
            print(f"Warning: Set max_new_tokens = {req_params['max_new_tokens']}")

        # pass with some expected stop strings.
        # some strange cases of "##| Instruction: " sneaking through.
        # Build a NEW list: `+=` would extend shared.settings' (or the request
        # body's) list in place, so the shared list would grow on every request.
        stopping_strings = stopping_strings + standard_stopping_strings
        req_params['custom_stopping_strings'] = stopping_strings

        def stream_chunk(content, finish_reason=None, with_role=False):
            """Build one streaming chunk carrying ``content`` in the shape of the active endpoint."""
            choice = {
                "index": 0,
                "finish_reason": finish_reason,
            }
            if stream_object_type == 'text_completion.chunk':
                choice["text"] = content
            else:
                # So yeah... do both methods? delta and messages.
                role = {'role': 'assistant'} if with_role else {}
                choice["message"] = {**role, 'content': content}
                choice["delta"] = {**role, 'content': content}
            return {
                "id": cmpl_id,
                "object": stream_object_type,
                "created": created_time,
                "model": model,  # TODO: add Lora info?
                resp_list: [choice],
            }

        if req_params['stream']:
            shared.args.chat = True
            # begin streaming
            self._write_stream_chunk(stream_chunk('', with_role=True))

        # generate reply #######################################
        if debug:
            print({'prompt': prompt, 'req_params': req_params, 'stopping_strings': stopping_strings})
        generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)

        answer = ''
        seen_content = ''
        completion_token_count = 0
        longest_stop_len = max([len(x) for x in stopping_strings])

        for a in generator:
            answer = a

            stop_string_found = False
            len_seen = len(seen_content)
            # Only the unseen tail (plus room for a stop string straddling
            # the boundary) needs to be searched.
            search_start = max(len_seen - longest_stop_len, 0)

            for string in stopping_strings:
                idx = answer.find(string, search_start)
                if idx != -1:
                    answer = answer[:idx]  # clip it.
                    stop_string_found = True

            if stop_string_found:
                break

            # If something like "\nYo" is generated just before "\nYou:"
            # is completed, buffer and generate more, don't send it
            buffer_and_continue = False

            for string in stopping_strings:
                for j in range(len(string) - 1, 0, -1):
                    if answer[-j:] == string[:j]:
                        buffer_and_continue = True
                        break
                else:
                    continue
                break

            if buffer_and_continue:
                continue

            if req_params['stream']:
                # Streaming
                new_content = answer[len_seen:]

                if not new_content or chr(0xfffd) in new_content:  # partial unicode character, don't send it yet.
                    continue

                seen_content = answer

                # strip extra leading space off new generated content
                if len_seen == 0 and new_content[0] == ' ':
                    new_content = new_content[1:]

                self._write_stream_chunk(stream_chunk(new_content))
                completion_token_count += len(encode(new_content)[0])

        if req_params['stream']:
            chunk = stream_chunk('', finish_reason="stop")
            chunk["usage"] = {
                "prompt_tokens": token_count,
                "completion_tokens": completion_token_count,
                "total_tokens": token_count + completion_token_count
            }
            self._write_stream_chunk(chunk, trailer='data: [DONE]\r\n\r\n')

            # Finished if streaming.
            if debug:
                if answer and answer[0] == ' ':
                    answer = answer[1:]
                print({'answer': answer}, chunk)
            return

        # strip extra leading space off new generated content
        if answer and answer[0] == ' ':
            answer = answer[1:]

        if debug:
            print({'response': answer})

        completion_token_count = len(encode(answer)[0])
        stop_reason = "stop"
        if token_count + completion_token_count >= req_params['truncation_length']:
            stop_reason = "length"

        resp = {
            "id": cmpl_id,
            "object": object_type,
            "created": created_time,
            "model": model,  # TODO: add Lora info?
            resp_list: [{
                "index": 0,
                "finish_reason": stop_reason,
            }],
            "usage": {
                "prompt_tokens": token_count,
                "completion_tokens": completion_token_count,
                "total_tokens": token_count + completion_token_count
            }
        }

        if is_chat:
            resp[resp_list][0]["message"] = {"role": "assistant", "content": answer}
        else:
            resp[resp_list][0]["text"] = answer

        self._write_json(resp)

    def _handle_edits(self, body):
        """Serve the edits endpoint by running instruction + input through the deduced template."""
        self._start_json_response()

        created_time = int(time.time())

        # Using Alpaca format, this may work with other models too.
        instruction = body['instruction']
        edit_input = body.get('input', '')

        instruction_template = deduce_template()
        edit_task = instruction_template.format(instruction=instruction, input=edit_input)

        truncation_length = default(shared.settings, 'truncation_length', 2048)
        token_count = len(encode(edit_task)[0])
        # Everything left in the context after the prompt may be generated.
        max_tokens = truncation_length - token_count

        req_params = default_req_params.copy()
        req_params['max_new_tokens'] = max_tokens
        req_params['truncation_length'] = truncation_length
        req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999)  # fixup absolute 0.0
        req_params['top_p'] = clamp(default(body, 'top_p', default_req_params['top_p']), 0.001, 1.0)
        req_params['seed'] = shared.settings.get('seed', default_req_params['seed'])
        req_params['add_bos_token'] = shared.settings.get('add_bos_token', default_req_params['add_bos_token'])

        if debug:
            print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})

        generator = generate_reply(edit_task, req_params, stopping_strings=standard_stopping_strings, is_chat=False)

        # The generator yields the reply so far; only the last value is needed.
        answer = ''
        for a in generator:
            answer = a

        # some replies have an extra leading space to fit the instruction template, just clip it off from the reply.
        if not edit_task.endswith('\n') and answer and answer[0] == ' ':
            answer = answer[1:]

        completion_token_count = len(encode(answer)[0])

        resp = {
            "object": "edit",
            "created": created_time,
            "choices": [{
                "text": answer,
                "index": 0,
            }],
            "usage": {
                "prompt_tokens": token_count,
                "completion_tokens": completion_token_count,
                "total_tokens": token_count + completion_token_count
            }
        }

        if debug:
            print({'answer': answer, 'completion_token_count': completion_token_count})

        self._write_json(resp)

    def _handle_image_generations(self, body):
        """Forward an image generation request to the Stable Diffusion web UI at SD_WEBUI_URL."""
        # Stable Diffusion callout wrapper for txt2img
        # Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E
        # the results will be limited and likely poor. SD has hundreds of models and dozens of settings.
        # If you want high quality tailored results you should just use the Stable Diffusion API directly.
        # it's too general an API to try and shape the result with specific tags like "masterpiece", etc,
        # Will probably work best with the stock SD models.
        # SD configuration is beyond the scope of this API.
        # At this point I will not add the edits and variations endpoints (ie. img2img) because they
        # require changing the form data handling to accept multipart form data, also to properly support
        # url return types will require file management and a web serving files... Perhaps later!
        self._start_json_response()

        width, height = [int(x) for x in default(body, 'size', '1024x1024').split('x')]  # ignore the restrictions on size
        response_format = default(body, 'response_format', 'url')  # or b64_json

        payload = {
            'prompt': body['prompt'],  # ignore prompt limit of 1000 characters
            'width': width,
            'height': height,
            'batch_size': default(body, 'n', 1)  # ignore the batch limits of max 10
        }

        resp = {
            'created': int(time.time()),
            'data': []
        }

        # TODO: support SD_WEBUI_AUTH username:password pair.
        sd_url = f"{os.environ['SD_WEBUI_URL']}/sdapi/v1/txt2img"

        sd_response = requests.post(url=sd_url, json=payload)
        r = sd_response.json()
        # r['parameters']...
        for b64_json in r['images']:
            if response_format == 'b64_json':
                resp['data'].append({'b64_json': b64_json})
            else:
                resp['data'].append({'url': f'data:image/png;base64,{b64_json}'})  # yeah it's lazy. requests.get() will not work with this

        self._write_json(resp)

    def _handle_embeddings(self, body):
        """Serve embeddings computed by the loaded sentence-transformers model."""
        self._start_json_response()

        # Accept the OpenAI key ('input') and fall back to 'text'.
        texts = body['input'] if 'input' in body else body['text']
        if isinstance(texts, str):
            texts = [texts]

        embeddings = embedding_model.encode(texts).tolist()

        # If base64 is specified, encode. Otherwise, return the float lists as-is.
        as_base64 = body.get("encoding_format", "") == "base64"

        data = [
            {
                "object": "embedding",
                "embedding": float_list_to_base64(emb) if as_base64 else emb,
                "index": n,
            }
            for n, emb in enumerate(embeddings)
        ]

        response = {
            "object": "list",
            "data": data,
            "model": st_model,  # return the real model
            "usage": {
                "prompt_tokens": 0,
                "total_tokens": 0,
            }
        }

        if debug:
            print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
        self._write_json(response)

    def _handle_moderations(self):
        """Return a fixed 'nothing flagged' moderation result."""
        # for now do nothing, just don't error.
        self._start_json_response()
        self._write_json({
            "id": "modr-5MWoLO",
            "model": "text-moderation-001",
            "results": [{
                "categories": {
                    "hate": False,
                    "hate/threatening": False,
                    "self-harm": False,
                    "sexual": False,
                    "sexual/minors": False,
                    "violence": False,
                    "violence/graphic": False
                },
                "category_scores": {
                    "hate": 0.0,
                    "hate/threatening": 0.0,
                    "self-harm": 0.0,
                    "sexual": 0.0,
                    "sexual/minors": 0.0,
                    "violence": 0.0,
                    "violence/graphic": 0.0
                },
                "flagged": False
            }]
        })

    def _handle_token_count(self, body):
        """Return the number of tokens in ``body['prompt']``."""
        # NOT STANDARD. lifted from the api extension, but it's still very useful to calculate tokenized length client side.
        self._start_json_response()
        tokens = encode(body['prompt'])[0]
        self._write_json({
            'results': [{
                'tokens': len(tokens)
            }]
        })
def run_server():
    """Load the optional embedding model, then serve the API until the process exits.

    Binds to 0.0.0.0 when ``shared.args.listen`` is set and to 127.0.0.1
    otherwise, on ``params['port']``. When ``shared.args.share`` is set, a
    cloudflared tunnel is started through ``flask_cloudflared`` if that
    package is installed. This call blocks in ``serve_forever()``.
    """
    global embedding_model
    try:
        embedding_model = SentenceTransformer(st_model)
        print(f"\nLoaded embedding model: {st_model}, max sequence length: {embedding_model.max_seq_length}")
    except Exception:
        # Best effort: embeddings are optional. This also covers the NameError
        # raised when sentence_transformers could not be imported. Unlike a
        # bare `except:`, SystemExit / KeyboardInterrupt are not swallowed.
        print(f"\nFailed to load embedding model: {st_model}")

    server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
    server = ThreadingHTTPServer(server_addr, Handler)
    if shared.args.share:
        try:
            from flask_cloudflared import _run_cloudflared
            public_url = _run_cloudflared(params['port'], params['port'] + 1)
            print(f'Starting OpenAI compatible api at\nOPENAI_API_BASE={public_url}/v1')
        except ImportError:
            print('You should install flask_cloudflared manually')
    else:
        print(f'Starting OpenAI compatible api:\nOPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1')

    server.serve_forever()
def setup():
    """Start the API server on a daemon thread and return immediately."""
    server_thread = Thread(target=run_server, daemon=True)
    server_thread.start()