mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
extensions/openai: Major openai extension updates & fixes (#3049)
* many openai updates * total reorg & cleanup. * fixups * missing import os for images * +moderations, custom_stopping_strings, more fixes * fix bugs in completion streaming * moderation fix (flagged) * updated moderation categories --------- Co-authored-by: Matthew Ashton <mashton-gitlab@zhero.org>
This commit is contained in:
parent
8db7e857b1
commit
3e7feb699c
@ -218,12 +218,11 @@ but there are some exceptions.
|
|||||||
| ✅❌ | langchain | https://github.com/hwchase17/langchain | OPENAI_API_BASE=http://127.0.0.1:5001/v1 even with a good 30B-4bit model the result is poor so far. It assumes zero shot python/json coding. Some model tailored prompt formatting improves results greatly. |
|
| ✅❌ | langchain | https://github.com/hwchase17/langchain | OPENAI_API_BASE=http://127.0.0.1:5001/v1 even with a good 30B-4bit model the result is poor so far. It assumes zero shot python/json coding. Some model tailored prompt formatting improves results greatly. |
|
||||||
| ✅❌ | Auto-GPT | https://github.com/Significant-Gravitas/Auto-GPT | OPENAI_API_BASE=http://127.0.0.1:5001/v1 Same issues as langchain. Also assumes a 4k+ context |
|
| ✅❌ | Auto-GPT | https://github.com/Significant-Gravitas/Auto-GPT | OPENAI_API_BASE=http://127.0.0.1:5001/v1 Same issues as langchain. Also assumes a 4k+ context |
|
||||||
| ✅❌ | babyagi | https://github.com/yoheinakajima/babyagi | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
|
| ✅❌ | babyagi | https://github.com/yoheinakajima/babyagi | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
|
||||||
|
| ❌ | guidance | https://github.com/microsoft/guidance | logit_bias and logprobs not yet supported |
|
||||||
|
|
||||||
## Future plans
|
## Future plans
|
||||||
* better error handling
|
|
||||||
* model changing, esp. something for swapping loras or embedding models
|
* model changing, esp. something for swapping loras or embedding models
|
||||||
* consider switching to FastAPI + starlette for SSE (openai SSE seems non-standard)
|
* consider switching to FastAPI + starlette for SSE (openai SSE seems non-standard)
|
||||||
* do something about rate limiting or locking requests for completions, most systems will only be able handle a single request at a time before OOM
|
|
||||||
|
|
||||||
## Bugs? Feedback? Comments? Pull requests?
|
## Bugs? Feedback? Comments? Pull requests?
|
||||||
|
|
||||||
|
599
extensions/openai/completions.py
Normal file
599
extensions/openai/completions.py
Normal file
@ -0,0 +1,599 @@
|
|||||||
|
import time
|
||||||
|
import yaml
|
||||||
|
import tiktoken
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from transformers import LogitsProcessor, LogitsProcessorList
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
from modules.text_generation import encode, decode, generate_reply
|
||||||
|
|
||||||
|
from extensions.openai.defaults import get_default_req_params, default, clamp
|
||||||
|
from extensions.openai.utils import end_line, debug_msg
|
||||||
|
from extensions.openai.errors import *
|
||||||
|
|
||||||
|
|
||||||
|
# Thanks to @Cypherfox [Cypherfoxy] for the logits code, blame to @matatonic
|
||||||
|
class LogitsBiasProcessor(LogitsProcessor):
|
||||||
|
def __init__(self, logit_bias={}):
|
||||||
|
self.logit_bias = logit_bias
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
if self.logit_bias:
|
||||||
|
keys = list([int(key) for key in self.logit_bias.keys()])
|
||||||
|
values = list([int(val) for val in self.logit_bias.values()])
|
||||||
|
logits[0, keys] += torch.tensor(values).cuda()
|
||||||
|
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
class LogprobProcessor(LogitsProcessor):
|
||||||
|
def __init__(self, logprobs=None):
|
||||||
|
self.logprobs = logprobs
|
||||||
|
self.token_alternatives = {}
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
if self.logprobs is not None: # 0-5
|
||||||
|
log_e_probabilities = F.log_softmax(logits, dim=1)
|
||||||
|
# XXX hack. should find the selected token and include the prob of that
|
||||||
|
# ... but we just +1 here instead because we don't know it yet.
|
||||||
|
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
|
||||||
|
top_tokens = [ decode(tok) for tok in top_indices[0] ]
|
||||||
|
self.token_alternatives = dict(zip(top_tokens, top_values[0].tolist()))
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
def convert_logprobs_to_tiktoken(model, logprobs):
|
||||||
|
try:
|
||||||
|
encoder = tiktoken.encoding_for_model(model)
|
||||||
|
# just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall.
|
||||||
|
return dict([ (encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items() ])
|
||||||
|
except KeyError:
|
||||||
|
# assume native tokens if we can't find the tokenizer
|
||||||
|
return logprobs
|
||||||
|
|
||||||
|
|
||||||
|
def marshal_common_params(body):
|
||||||
|
# Request Parameters
|
||||||
|
# Try to use openai defaults or map them to something with the same intent
|
||||||
|
|
||||||
|
req_params = get_default_req_params()
|
||||||
|
|
||||||
|
# Common request parameters
|
||||||
|
req_params['truncation_length'] = shared.settings['truncation_length']
|
||||||
|
req_params['add_bos_token'] = shared.settings.get('add_bos_token', req_params['add_bos_token'])
|
||||||
|
req_params['seed'] = shared.settings.get('seed', req_params['seed'])
|
||||||
|
req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings']
|
||||||
|
|
||||||
|
# OpenAI API Parameters
|
||||||
|
# model - ignored for now, TODO: When we can reliably load a model or lora from a name only change this
|
||||||
|
req_params['requested_model'] = body.get('model', shared.model_name)
|
||||||
|
|
||||||
|
req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
|
||||||
|
req_params['temperature'] = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0/2.0
|
||||||
|
req_params['top_p'] = clamp(default(body, 'top_p', req_params['top_p']), 0.001, 1.0)
|
||||||
|
n = default(body, 'n', 1)
|
||||||
|
if n != 1:
|
||||||
|
raise InvalidRequestError(message="Only n = 1 is supported.", param='n')
|
||||||
|
|
||||||
|
if 'stop' in body: # str or array, max len 4 (ignored)
|
||||||
|
if isinstance(body['stop'], str):
|
||||||
|
req_params['stopping_strings'] = [body['stop']] # non-standard parameter
|
||||||
|
elif isinstance(body['stop'], list):
|
||||||
|
req_params['stopping_strings'] = body['stop']
|
||||||
|
|
||||||
|
# presence_penalty - ignored
|
||||||
|
# frequency_penalty - ignored
|
||||||
|
# user - ignored
|
||||||
|
|
||||||
|
logits_processor = []
|
||||||
|
logit_bias = body.get('logit_bias', None)
|
||||||
|
if logit_bias: # {str: float, ...}
|
||||||
|
# XXX convert tokens from tiktoken based on requested model
|
||||||
|
# Ex.: 'logit_bias': {'1129': 100, '11442': 100, '16243': 100}
|
||||||
|
try:
|
||||||
|
encoder = tiktoken.encoding_for_model(req_params['requested_model'])
|
||||||
|
new_logit_bias = {}
|
||||||
|
for logit, bias in logit_bias.items():
|
||||||
|
for x in encode(encoder.decode([int(logit)]))[0]:
|
||||||
|
new_logit_bias[str(int(x))] = bias
|
||||||
|
print(logit_bias, '->', new_logit_bias)
|
||||||
|
logit_bias = new_logit_bias
|
||||||
|
except KeyError:
|
||||||
|
pass # assume native tokens if we can't find the tokenizer
|
||||||
|
|
||||||
|
logits_processor = [LogitsBiasProcessor(logit_bias)]
|
||||||
|
|
||||||
|
logprobs = None # coming to chat eventually
|
||||||
|
if 'logprobs' in body:
|
||||||
|
logprobs = default(body, 'logprobs', 0) # maybe cap at topk? don't clamp 0-5.
|
||||||
|
req_params['logprob_proc'] = LogprobProcessor(logprobs)
|
||||||
|
logits_processor.extend([req_params['logprob_proc']])
|
||||||
|
else:
|
||||||
|
logprobs = None
|
||||||
|
|
||||||
|
if logits_processor: # requires logits_processor support
|
||||||
|
req_params['logits_processor'] = LogitsProcessorList(logits_processor)
|
||||||
|
|
||||||
|
return req_params
|
||||||
|
|
||||||
|
|
||||||
|
def messages_to_prompt(body: dict, req_params: dict, max_tokens):
|
||||||
|
# functions
|
||||||
|
if body.get('functions', []): # chat only
|
||||||
|
raise InvalidRequestError(message="functions is not supported.", param='functions')
|
||||||
|
if body.get('function_call', ''): # chat only, 'none', 'auto', {'name': 'func'}
|
||||||
|
raise InvalidRequestError(message="function_call is not supported.", param='function_call')
|
||||||
|
|
||||||
|
if not 'messages' in body:
|
||||||
|
raise InvalidRequestError(message="messages is required", param='messages')
|
||||||
|
|
||||||
|
messages = body['messages']
|
||||||
|
|
||||||
|
role_formats = {
|
||||||
|
'user': 'user: {message}\n',
|
||||||
|
'assistant': 'assistant: {message}\n',
|
||||||
|
'system': '{message}',
|
||||||
|
'context': 'You are a helpful assistant. Answer as concisely as possible.',
|
||||||
|
'prompt': 'assistant:',
|
||||||
|
}
|
||||||
|
|
||||||
|
if not 'stopping_strings' in req_params:
|
||||||
|
req_params['stopping_strings'] = []
|
||||||
|
|
||||||
|
# Instruct models can be much better
|
||||||
|
if shared.settings['instruction_template']:
|
||||||
|
try:
|
||||||
|
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
|
||||||
|
|
||||||
|
template = instruct['turn_template']
|
||||||
|
system_message_template = "{message}"
|
||||||
|
system_message_default = instruct['context']
|
||||||
|
bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token
|
||||||
|
user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct['user'])
|
||||||
|
bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct['bot'])
|
||||||
|
bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')
|
||||||
|
|
||||||
|
role_formats = {
|
||||||
|
'user': user_message_template,
|
||||||
|
'assistant': bot_message_template,
|
||||||
|
'system': system_message_template,
|
||||||
|
'context': system_message_default,
|
||||||
|
'prompt': bot_prompt,
|
||||||
|
}
|
||||||
|
|
||||||
|
if 'Alpaca' in shared.settings['instruction_template']:
|
||||||
|
req_params['stopping_strings'].extend(['\n###'])
|
||||||
|
elif instruct['user']: # WizardLM and some others have no user prompt.
|
||||||
|
req_params['stopping_strings'].extend(['\n' + instruct['user'], instruct['user']])
|
||||||
|
|
||||||
|
debug_msg(f"Loaded instruction role format: {shared.settings['instruction_template']}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
req_params['stopping_strings'].extend(['\nuser:'])
|
||||||
|
|
||||||
|
print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
|
||||||
|
print("Warning: Loaded default instruction-following template for model.")
|
||||||
|
|
||||||
|
else:
|
||||||
|
req_params['stopping_strings'].extend(['\nuser:'])
|
||||||
|
print("Warning: Loaded default instruction-following template for model.")
|
||||||
|
|
||||||
|
system_msgs = []
|
||||||
|
chat_msgs = []
|
||||||
|
|
||||||
|
# You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
|
||||||
|
context_msg = role_formats['system'].format(message=role_formats['context']) if role_formats['context'] else ''
|
||||||
|
context_msg = end_line(context_msg)
|
||||||
|
|
||||||
|
# Maybe they sent both? This is not documented in the API, but some clients seem to do this.
|
||||||
|
if 'prompt' in body:
|
||||||
|
context_msg = end_line(role_formats['system'].format(message=body['prompt'])) + context_msg
|
||||||
|
|
||||||
|
for m in messages:
|
||||||
|
role = m['role']
|
||||||
|
content = m['content']
|
||||||
|
# name = m.get('name', None)
|
||||||
|
# function_call = m.get('function_call', None) # user name or function name with output in content
|
||||||
|
msg = role_formats[role].format(message=content)
|
||||||
|
if role == 'system':
|
||||||
|
system_msgs.extend([msg])
|
||||||
|
elif role == 'function':
|
||||||
|
raise InvalidRequestError(message="role: function is not supported.", param='messages')
|
||||||
|
else:
|
||||||
|
chat_msgs.extend([msg])
|
||||||
|
|
||||||
|
system_msg = '\n'.join(system_msgs)
|
||||||
|
system_msg = end_line(system_msg)
|
||||||
|
|
||||||
|
prompt = system_msg + context_msg + ''.join(chat_msgs) + role_formats['prompt']
|
||||||
|
|
||||||
|
token_count = len(encode(prompt)[0])
|
||||||
|
|
||||||
|
if token_count >= req_params['truncation_length']:
|
||||||
|
err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens."
|
||||||
|
raise InvalidRequestError(message=err_msg)
|
||||||
|
|
||||||
|
if max_tokens > 0 and token_count + max_tokens > req_params['truncation_length']:
|
||||||
|
err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens and max_tokens is {max_tokens}."
|
||||||
|
print(f"Warning: ${err_msg}")
|
||||||
|
#raise InvalidRequestError(message=err_msg)
|
||||||
|
|
||||||
|
return prompt, token_count
|
||||||
|
|
||||||
|
|
||||||
|
def chat_completions(body: dict, is_legacy: bool=False) -> dict:
|
||||||
|
# Chat Completions
|
||||||
|
object_type = 'chat.completions'
|
||||||
|
created_time = int(time.time())
|
||||||
|
cmpl_id = "chatcmpl-%d" % (int(time.time()*1000000000))
|
||||||
|
resp_list = 'data' if is_legacy else 'choices'
|
||||||
|
|
||||||
|
# common params
|
||||||
|
req_params = marshal_common_params(body)
|
||||||
|
req_params['stream'] = False
|
||||||
|
requested_model = req_params.pop('requested_model')
|
||||||
|
logprob_proc = req_params.pop('logprob_proc', None)
|
||||||
|
req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.
|
||||||
|
|
||||||
|
# chat default max_tokens is 'inf', but also flexible
|
||||||
|
max_tokens = 0
|
||||||
|
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||||
|
if max_tokens_str in body:
|
||||||
|
max_tokens = default(body, max_tokens_str, req_params['truncation_length'])
|
||||||
|
req_params['max_new_tokens'] = max_tokens
|
||||||
|
else:
|
||||||
|
req_params['max_new_tokens'] = req_params['truncation_length']
|
||||||
|
|
||||||
|
# format the prompt from messages
|
||||||
|
prompt, token_count = messages_to_prompt(body, req_params, max_tokens)
|
||||||
|
|
||||||
|
# generate reply #######################################
|
||||||
|
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||||
|
stopping_strings = req_params.pop('stopping_strings', [])
|
||||||
|
logprob_proc = req_params.pop('logprob_proc', None)
|
||||||
|
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||||
|
|
||||||
|
answer = ''
|
||||||
|
for a in generator:
|
||||||
|
answer = a
|
||||||
|
|
||||||
|
# strip extra leading space off new generated content
|
||||||
|
if answer and answer[0] == ' ':
|
||||||
|
answer = answer[1:]
|
||||||
|
|
||||||
|
completion_token_count = len(encode(answer)[0])
|
||||||
|
stop_reason = "stop"
|
||||||
|
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||||
|
stop_reason = "length"
|
||||||
|
|
||||||
|
resp = {
|
||||||
|
"id": cmpl_id,
|
||||||
|
"object": object_type,
|
||||||
|
"created": created_time,
|
||||||
|
"model": shared.model_name, # TODO: add Lora info?
|
||||||
|
resp_list: [{
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": stop_reason,
|
||||||
|
"message": {"role": "assistant", "content": answer}
|
||||||
|
}],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": token_count,
|
||||||
|
"completion_tokens": completion_token_count,
|
||||||
|
"total_tokens": token_count + completion_token_count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if logprob_proc: # not official for chat yet
|
||||||
|
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||||
|
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||||
|
# else:
|
||||||
|
# resp[resp_list][0]["logprobs"] = None
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
# generator
|
||||||
|
def stream_chat_completions(body: dict, is_legacy: bool=False):
|
||||||
|
|
||||||
|
# Chat Completions
|
||||||
|
stream_object_type = 'chat.completions.chunk'
|
||||||
|
created_time = int(time.time())
|
||||||
|
cmpl_id = "chatcmpl-%d" % (int(time.time()*1000000000))
|
||||||
|
resp_list = 'data' if is_legacy else 'choices'
|
||||||
|
|
||||||
|
# common params
|
||||||
|
req_params = marshal_common_params(body)
|
||||||
|
req_params['stream'] = True
|
||||||
|
requested_model = req_params.pop('requested_model')
|
||||||
|
logprob_proc = req_params.pop('logprob_proc', None)
|
||||||
|
req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.
|
||||||
|
|
||||||
|
# chat default max_tokens is 'inf', but also flexible
|
||||||
|
max_tokens = 0
|
||||||
|
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||||
|
if max_tokens_str in body:
|
||||||
|
max_tokens = default(body, max_tokens_str, req_params['truncation_length'])
|
||||||
|
req_params['max_new_tokens'] = max_tokens
|
||||||
|
else:
|
||||||
|
req_params['max_new_tokens'] = req_params['truncation_length']
|
||||||
|
|
||||||
|
# format the prompt from messages
|
||||||
|
prompt, token_count = messages_to_prompt(body, req_params, max_tokens)
|
||||||
|
|
||||||
|
def chat_streaming_chunk(content):
|
||||||
|
# begin streaming
|
||||||
|
chunk = {
|
||||||
|
"id": cmpl_id,
|
||||||
|
"object": stream_object_type,
|
||||||
|
"created": created_time,
|
||||||
|
"model": shared.model_name,
|
||||||
|
resp_list: [{
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": None,
|
||||||
|
# So yeah... do both methods? delta and messages.
|
||||||
|
"message": {'role': 'assistant', 'content': content},
|
||||||
|
"delta": {'role': 'assistant', 'content': content},
|
||||||
|
}],
|
||||||
|
}
|
||||||
|
|
||||||
|
if logprob_proc: # not official for chat yet
|
||||||
|
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||||
|
chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||||
|
#else:
|
||||||
|
# chunk[resp_list][0]["logprobs"] = None
|
||||||
|
return chunk
|
||||||
|
|
||||||
|
yield chat_streaming_chunk('')
|
||||||
|
|
||||||
|
# generate reply #######################################
|
||||||
|
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||||
|
|
||||||
|
stopping_strings = req_params.pop('stopping_strings', [])
|
||||||
|
logprob_proc = req_params.pop('logprob_proc', None)
|
||||||
|
|
||||||
|
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||||
|
|
||||||
|
answer = ''
|
||||||
|
seen_content = ''
|
||||||
|
completion_token_count = 0
|
||||||
|
|
||||||
|
for a in generator:
|
||||||
|
answer = a
|
||||||
|
|
||||||
|
len_seen = len(seen_content)
|
||||||
|
new_content = answer[len_seen:]
|
||||||
|
|
||||||
|
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
||||||
|
continue
|
||||||
|
|
||||||
|
seen_content = answer
|
||||||
|
|
||||||
|
# strip extra leading space off new generated content
|
||||||
|
if len_seen == 0 and new_content[0] == ' ':
|
||||||
|
new_content = new_content[1:]
|
||||||
|
|
||||||
|
completion_token_count += len(encode(new_content)[0])
|
||||||
|
chunk = chat_streaming_chunk(new_content)
|
||||||
|
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
|
stop_reason = "stop"
|
||||||
|
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||||
|
stop_reason = "length"
|
||||||
|
|
||||||
|
chunk = chat_streaming_chunk('')
|
||||||
|
chunk[resp_list][0]['finish_reason'] = stop_reason
|
||||||
|
chunk['usage'] = {
|
||||||
|
"prompt_tokens": token_count,
|
||||||
|
"completion_tokens": completion_token_count,
|
||||||
|
"total_tokens": token_count + completion_token_count
|
||||||
|
}
|
||||||
|
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
|
def completions(body: dict, is_legacy: bool=False):
|
||||||
|
# Legacy
|
||||||
|
# Text Completions
|
||||||
|
object_type = 'text_completion'
|
||||||
|
created_time = int(time.time())
|
||||||
|
cmpl_id = "conv-%d" % (int(time.time()*1000000000))
|
||||||
|
resp_list = 'data' if is_legacy else 'choices'
|
||||||
|
|
||||||
|
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
|
||||||
|
prompt_str = 'context' if is_legacy else 'prompt'
|
||||||
|
if not prompt_str in body:
|
||||||
|
raise InvalidRequestError("Missing required input", param=prompt_str)
|
||||||
|
|
||||||
|
prompt = body[prompt_str]
|
||||||
|
if isinstance(prompt, list):
|
||||||
|
if prompt and isinstance(prompt[0], int):
|
||||||
|
try:
|
||||||
|
encoder = tiktoken.encoding_for_model(requested_model)
|
||||||
|
prompt = encode(encoder.decode(prompt))[0]
|
||||||
|
except KeyError:
|
||||||
|
prompt = decode(prompt)[0]
|
||||||
|
else:
|
||||||
|
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
|
||||||
|
|
||||||
|
# common params
|
||||||
|
req_params = marshal_common_params(body)
|
||||||
|
req_params['stream'] = False
|
||||||
|
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||||
|
max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
|
||||||
|
req_params['max_new_tokens'] = max_tokens
|
||||||
|
requested_model = req_params.pop('requested_model')
|
||||||
|
logprob_proc = req_params.pop('logprob_proc', None)
|
||||||
|
|
||||||
|
token_count = len(encode(prompt)[0])
|
||||||
|
|
||||||
|
if token_count + max_tokens > req_params['truncation_length']:
|
||||||
|
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
|
||||||
|
#print(f"Warning: ${err_msg}")
|
||||||
|
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
|
||||||
|
|
||||||
|
req_params['echo'] = default(body, 'echo', req_params['echo'])
|
||||||
|
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
|
||||||
|
|
||||||
|
# generate reply #######################################
|
||||||
|
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||||
|
stopping_strings = req_params.pop('stopping_strings', [])
|
||||||
|
logprob_proc = req_params.pop('logprob_proc', None)
|
||||||
|
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||||
|
|
||||||
|
answer = ''
|
||||||
|
|
||||||
|
for a in generator:
|
||||||
|
answer = a
|
||||||
|
|
||||||
|
# strip extra leading space off new generated content
|
||||||
|
if answer and answer[0] == ' ':
|
||||||
|
answer = answer[1:]
|
||||||
|
|
||||||
|
completion_token_count = len(encode(answer)[0])
|
||||||
|
stop_reason = "stop"
|
||||||
|
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||||
|
stop_reason = "length"
|
||||||
|
|
||||||
|
resp = {
|
||||||
|
"id": cmpl_id,
|
||||||
|
"object": object_type,
|
||||||
|
"created": created_time,
|
||||||
|
"model": shared.model_name, # TODO: add Lora info?
|
||||||
|
resp_list: [{
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": stop_reason,
|
||||||
|
"text": answer,
|
||||||
|
}],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": token_count,
|
||||||
|
"completion_tokens": completion_token_count,
|
||||||
|
"total_tokens": token_count + completion_token_count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if logprob_proc:
|
||||||
|
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||||
|
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||||
|
else:
|
||||||
|
resp[resp_list][0]["logprobs"] = None
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
# generator
|
||||||
|
def stream_completions(body: dict, is_legacy: bool=False):
|
||||||
|
# Legacy
|
||||||
|
# Text Completions
|
||||||
|
#object_type = 'text_completion'
|
||||||
|
stream_object_type = 'text_completion.chunk'
|
||||||
|
created_time = int(time.time())
|
||||||
|
cmpl_id = "conv-%d" % (int(time.time()*1000000000))
|
||||||
|
resp_list = 'data' if is_legacy else 'choices'
|
||||||
|
|
||||||
|
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
|
||||||
|
prompt_str = 'context' if is_legacy else 'prompt'
|
||||||
|
if not prompt_str in body:
|
||||||
|
raise InvalidRequestError("Missing required input", param=prompt_str)
|
||||||
|
|
||||||
|
prompt = body[prompt_str]
|
||||||
|
if isinstance(prompt, list):
|
||||||
|
if prompt and isinstance(prompt[0], int):
|
||||||
|
try:
|
||||||
|
encoder = tiktoken.encoding_for_model(requested_model)
|
||||||
|
prompt = encode(encoder.decode(prompt))[0]
|
||||||
|
except KeyError:
|
||||||
|
prompt = decode(prompt)[0]
|
||||||
|
else:
|
||||||
|
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
|
||||||
|
|
||||||
|
# common params
|
||||||
|
req_params = marshal_common_params(body)
|
||||||
|
req_params['stream'] = True
|
||||||
|
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||||
|
max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
|
||||||
|
req_params['max_new_tokens'] = max_tokens
|
||||||
|
requested_model = req_params.pop('requested_model')
|
||||||
|
logprob_proc = req_params.pop('logprob_proc', None)
|
||||||
|
|
||||||
|
token_count = len(encode(prompt)[0])
|
||||||
|
|
||||||
|
if token_count + max_tokens > req_params['truncation_length']:
|
||||||
|
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
|
||||||
|
#print(f"Warning: ${err_msg}")
|
||||||
|
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
|
||||||
|
|
||||||
|
req_params['echo'] = default(body, 'echo', req_params['echo'])
|
||||||
|
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
|
||||||
|
|
||||||
|
def text_streaming_chunk(content):
|
||||||
|
# begin streaming
|
||||||
|
chunk = {
|
||||||
|
"id": cmpl_id,
|
||||||
|
"object": stream_object_type,
|
||||||
|
"created": created_time,
|
||||||
|
"model": shared.model_name,
|
||||||
|
resp_list: [{
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": None,
|
||||||
|
"text": content,
|
||||||
|
}],
|
||||||
|
}
|
||||||
|
if logprob_proc:
|
||||||
|
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||||
|
chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||||
|
else:
|
||||||
|
chunk[resp_list][0]["logprobs"] = None
|
||||||
|
|
||||||
|
return chunk
|
||||||
|
|
||||||
|
yield text_streaming_chunk('')
|
||||||
|
|
||||||
|
# generate reply #######################################
|
||||||
|
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||||
|
stopping_strings = req_params.pop('stopping_strings', [])
|
||||||
|
logprob_proc = req_params.pop('logprob_proc', None)
|
||||||
|
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||||
|
|
||||||
|
answer = ''
|
||||||
|
seen_content = ''
|
||||||
|
completion_token_count = 0
|
||||||
|
|
||||||
|
for a in generator:
|
||||||
|
answer = a
|
||||||
|
|
||||||
|
len_seen = len(seen_content)
|
||||||
|
new_content = answer[len_seen:]
|
||||||
|
|
||||||
|
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
||||||
|
continue
|
||||||
|
|
||||||
|
seen_content = answer
|
||||||
|
|
||||||
|
# strip extra leading space off new generated content
|
||||||
|
if len_seen == 0 and new_content[0] == ' ':
|
||||||
|
new_content = new_content[1:]
|
||||||
|
|
||||||
|
chunk = text_streaming_chunk(new_content)
|
||||||
|
|
||||||
|
completion_token_count += len(encode(new_content)[0])
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
|
stop_reason = "stop"
|
||||||
|
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||||
|
stop_reason = "length"
|
||||||
|
|
||||||
|
chunk = text_streaming_chunk('')
|
||||||
|
chunk[resp_list][0]["finish_reason"] = stop_reason
|
||||||
|
chunk["usage"] = {
|
||||||
|
"prompt_tokens": token_count,
|
||||||
|
"completion_tokens": completion_token_count,
|
||||||
|
"total_tokens": token_count + completion_token_count
|
||||||
|
}
|
||||||
|
|
||||||
|
yield chunk
|
64
extensions/openai/defaults.py
Normal file
64
extensions/openai/defaults.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
import copy
|
||||||
|
|
||||||
|
# Slightly different defaults for OpenAI's API
|
||||||
|
# Data type is important, Ex. use 0.0 for a float 0
|
||||||
|
default_req_params = {
|
||||||
|
'max_new_tokens': 16, # 'Inf' for chat
|
||||||
|
'temperature': 1.0,
|
||||||
|
'top_p': 1.0,
|
||||||
|
'top_k': 1, # choose 20 for chat in absence of another default
|
||||||
|
'repetition_penalty': 1.18,
|
||||||
|
'repetition_penalty_range': 0,
|
||||||
|
'encoder_repetition_penalty': 1.0,
|
||||||
|
'suffix': None,
|
||||||
|
'stream': False,
|
||||||
|
'echo': False,
|
||||||
|
'seed': -1,
|
||||||
|
# 'n' : default(body, 'n', 1), # 'n' doesn't have a direct map
|
||||||
|
'truncation_length': 2048, # first use shared.settings value
|
||||||
|
'add_bos_token': True,
|
||||||
|
'do_sample': True,
|
||||||
|
'typical_p': 1.0,
|
||||||
|
'epsilon_cutoff': 0.0, # In units of 1e-4
|
||||||
|
'eta_cutoff': 0.0, # In units of 1e-4
|
||||||
|
'tfs': 1.0,
|
||||||
|
'top_a': 0.0,
|
||||||
|
'min_length': 0,
|
||||||
|
'no_repeat_ngram_size': 0,
|
||||||
|
'num_beams': 1,
|
||||||
|
'penalty_alpha': 0.0,
|
||||||
|
'length_penalty': 1.0,
|
||||||
|
'early_stopping': False,
|
||||||
|
'mirostat_mode': 0,
|
||||||
|
'mirostat_tau': 5.0,
|
||||||
|
'mirostat_eta': 0.1,
|
||||||
|
'ban_eos_token': False,
|
||||||
|
'skip_special_tokens': True,
|
||||||
|
'custom_stopping_strings': '',
|
||||||
|
# 'logits_processor' - conditionally passed
|
||||||
|
# 'stopping_strings' - temporarily used
|
||||||
|
# 'logprobs' - temporarily used
|
||||||
|
# 'requested_model' - temporarily used
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_default_req_params():
|
||||||
|
return copy.deepcopy(default_req_params)
|
||||||
|
|
||||||
|
# little helper to get defaults if arg is present but None and should be the same type as default.
|
||||||
|
def default(dic, key, default):
|
||||||
|
val = dic.get(key, default)
|
||||||
|
if type(val) != type(default):
|
||||||
|
# maybe it's just something like 1 instead of 1.0
|
||||||
|
try:
|
||||||
|
v = type(default)(val)
|
||||||
|
if type(val)(v) == val: # if it's the same value passed in, it's ok.
|
||||||
|
return v
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
val = default
|
||||||
|
return val
|
||||||
|
|
||||||
|
def clamp(value, minvalue, maxvalue):
|
||||||
|
return max(minvalue, min(value, maxvalue))
|
||||||
|
|
102
extensions/openai/edits.py
Normal file
102
extensions/openai/edits.py
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
import time
|
||||||
|
import yaml
|
||||||
|
import os
|
||||||
|
from modules import shared
|
||||||
|
from extensions.openai.defaults import get_default_req_params
|
||||||
|
from extensions.openai.utils import debug_msg
|
||||||
|
from extensions.openai.errors import *
|
||||||
|
from modules.text_generation import encode, generate_reply
|
||||||
|
|
||||||
|
|
||||||
|
def edits(instruction: str, input: str, temperature = 1.0, top_p = 1.0) -> dict:
|
||||||
|
|
||||||
|
created_time = int(time.time()*1000)
|
||||||
|
|
||||||
|
# Request parameters
|
||||||
|
req_params = get_default_req_params()
|
||||||
|
stopping_strings = []
|
||||||
|
|
||||||
|
# Alpaca is verbose so a good default prompt
|
||||||
|
default_template = (
|
||||||
|
"Below is an instruction that describes a task, paired with an input that provides further context. "
|
||||||
|
"Write a response that appropriately completes the request.\n\n"
|
||||||
|
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
instruction_template = default_template
|
||||||
|
|
||||||
|
# Use the special instruction/input/response template for anything trained like Alpaca
|
||||||
|
if shared.settings['instruction_template']:
|
||||||
|
if 'Alpaca' in shared.settings['instruction_template']:
|
||||||
|
stopping_strings.extend(['\n###'])
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
|
||||||
|
|
||||||
|
template = instruct['turn_template']
|
||||||
|
template = template\
|
||||||
|
.replace('<|user|>', instruct.get('user', ''))\
|
||||||
|
.replace('<|bot|>', instruct.get('bot', ''))\
|
||||||
|
.replace('<|user-message|>', '{instruction}\n{input}')
|
||||||
|
|
||||||
|
instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
|
||||||
|
if instruct['user']:
|
||||||
|
stopping_strings.extend(['\n' + instruct['user'], instruct['user'] ])
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
instruction_template = default_template
|
||||||
|
print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
|
||||||
|
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
|
||||||
|
else:
|
||||||
|
stopping_strings.extend(['\n###'])
|
||||||
|
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
|
||||||
|
|
||||||
|
edit_task = instruction_template.format(instruction=instruction, input=input)
|
||||||
|
|
||||||
|
truncation_length = shared.settings['truncation_length']
|
||||||
|
|
||||||
|
token_count = len(encode(edit_task)[0])
|
||||||
|
max_tokens = truncation_length - token_count
|
||||||
|
|
||||||
|
if max_tokens < 1:
|
||||||
|
err_msg = f"This model maximum context length is {truncation_length} tokens. However, your messages resulted in over {truncation_length - max_tokens} tokens."
|
||||||
|
raise InvalidRequestError(err_msg, param='input')
|
||||||
|
|
||||||
|
req_params['max_new_tokens'] = max_tokens
|
||||||
|
req_params['truncation_length'] = truncation_length
|
||||||
|
req_params['temperature'] = temperature
|
||||||
|
req_params['top_p'] = top_p
|
||||||
|
req_params['seed'] = shared.settings.get('seed', req_params['seed'])
|
||||||
|
req_params['add_bos_token'] = shared.settings.get('add_bos_token', req_params['add_bos_token'])
|
||||||
|
req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings']
|
||||||
|
|
||||||
|
debug_msg({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})
|
||||||
|
|
||||||
|
generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||||
|
|
||||||
|
longest_stop_len = max([len(x) for x in stopping_strings] + [0])
|
||||||
|
answer = ''
|
||||||
|
for a in generator:
|
||||||
|
answer = a
|
||||||
|
|
||||||
|
# some reply's have an extra leading space to fit the instruction template, just clip it off from the reply.
|
||||||
|
if edit_task[-1] != '\n' and answer and answer[0] == ' ':
|
||||||
|
answer = answer[1:]
|
||||||
|
|
||||||
|
completion_token_count = len(encode(answer)[0])
|
||||||
|
|
||||||
|
resp = {
|
||||||
|
"object": "edit",
|
||||||
|
"created": created_time,
|
||||||
|
"choices": [{
|
||||||
|
"text": answer,
|
||||||
|
"index": 0,
|
||||||
|
}],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": token_count,
|
||||||
|
"completion_tokens": completion_token_count,
|
||||||
|
"total_tokens": token_count + completion_token_count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp
|
50
extensions/openai/embeddings.py
Normal file
50
extensions/openai/embeddings.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
import os
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from extensions.openai.utils import float_list_to_base64, debug_msg
|
||||||
|
from extensions.openai.errors import *
|
||||||
|
|
||||||
|
st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
|
||||||
|
embeddings_model = None
|
||||||
|
|
||||||
|
def load_embedding_model(model):
|
||||||
|
try:
|
||||||
|
emb_model = SentenceTransformer(model)
|
||||||
|
print(f"\nLoaded embedding model: {model}, max sequence length: {emb_model.max_seq_length}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\nError: Failed to load embedding model: {model}")
|
||||||
|
raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message = repr(e))
|
||||||
|
|
||||||
|
return emb_model
|
||||||
|
|
||||||
|
def get_embeddings_model():
|
||||||
|
global embeddings_model, st_model
|
||||||
|
if st_model and not embeddings_model:
|
||||||
|
embeddings_model = load_embedding_model(st_model) # lazy load the model
|
||||||
|
return embeddings_model
|
||||||
|
|
||||||
|
def get_embeddings_model_name():
|
||||||
|
global st_model
|
||||||
|
return st_model
|
||||||
|
|
||||||
|
def embeddings(input: list, encoding_format: str):
|
||||||
|
|
||||||
|
embeddings = get_embeddings_model().encode(input).tolist()
|
||||||
|
|
||||||
|
if encoding_format == "base64":
|
||||||
|
data = [{"object": "embedding", "embedding": float_list_to_base64(emb), "index": n} for n, emb in enumerate(embeddings)]
|
||||||
|
else:
|
||||||
|
data = [{"object": "embedding", "embedding": emb, "index": n} for n, emb in enumerate(embeddings)]
|
||||||
|
|
||||||
|
response = {
|
||||||
|
"object": "list",
|
||||||
|
"data": data,
|
||||||
|
"model": st_model, # return the real model
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 0,
|
||||||
|
"total_tokens": 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
debug_msg(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
|
||||||
|
|
||||||
|
return response
|
27
extensions/openai/errors.py
Normal file
27
extensions/openai/errors.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
class OpenAIError(Exception):
|
||||||
|
def __init__(self, message = None, code = 500, internal_message = ''):
|
||||||
|
self.message = message
|
||||||
|
self.code = code
|
||||||
|
self.internal_message = internal_message
|
||||||
|
def __repr__(self):
|
||||||
|
return "%s(message=%r, code=%d)" % (
|
||||||
|
self.__class__.__name__,
|
||||||
|
self.message,
|
||||||
|
self.code,
|
||||||
|
)
|
||||||
|
|
||||||
|
class InvalidRequestError(OpenAIError):
|
||||||
|
def __init__(self, message, param, code = 400, error_type ='InvalidRequestError', internal_message = ''):
|
||||||
|
super(OpenAIError, self).__init__(message, code, error_type, internal_message)
|
||||||
|
self.param = param
|
||||||
|
def __repr__(self):
|
||||||
|
return "%s(message=%r, code=%d, param=%s)" % (
|
||||||
|
self.__class__.__name__,
|
||||||
|
self.message,
|
||||||
|
self.code,
|
||||||
|
self.param,
|
||||||
|
)
|
||||||
|
|
||||||
|
class ServiceUnavailableError(OpenAIError):
|
||||||
|
def __init__(self, message = None, code = 500, error_type ='ServiceUnavailableError', internal_message = ''):
|
||||||
|
super(OpenAIError, self).__init__(message, code, error_type, internal_message)
|
48
extensions/openai/images.py
Normal file
48
extensions/openai/images.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
import requests
|
||||||
|
from extensions.openai.errors import *
|
||||||
|
|
||||||
|
def generations(prompt: str, size: str, response_format: str, n: int):
|
||||||
|
# Stable Diffusion callout wrapper for txt2img
|
||||||
|
# Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E
|
||||||
|
# the results will be limited and likely poor. SD has hundreds of models and dozens of settings.
|
||||||
|
# If you want high quality tailored results you should just use the Stable Diffusion API directly.
|
||||||
|
# it's too general an API to try and shape the result with specific tags like "masterpiece", etc,
|
||||||
|
# Will probably work best with the stock SD models.
|
||||||
|
# SD configuration is beyond the scope of this API.
|
||||||
|
# At this point I will not add the edits and variations endpoints (ie. img2img) because they
|
||||||
|
# require changing the form data handling to accept multipart form data, also to properly support
|
||||||
|
# url return types will require file management and a web serving files... Perhaps later!
|
||||||
|
|
||||||
|
width, height = [ int(x) for x in size.split('x') ] # ignore the restrictions on size
|
||||||
|
|
||||||
|
# to hack on better generation, edit default payload.
|
||||||
|
payload = {
|
||||||
|
'prompt': prompt, # ignore prompt limit of 1000 characters
|
||||||
|
'width': width,
|
||||||
|
'height': height,
|
||||||
|
'batch_size': n,
|
||||||
|
'restore_faces': True, # slightly less horrible
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = {
|
||||||
|
'created': int(time.time()),
|
||||||
|
'data': []
|
||||||
|
}
|
||||||
|
|
||||||
|
# TODO: support SD_WEBUI_AUTH username:password pair.
|
||||||
|
sd_url = f"{os.environ['SD_WEBUI_URL']}/sdapi/v1/txt2img"
|
||||||
|
|
||||||
|
response = requests.post(url=sd_url, json=payload)
|
||||||
|
r = response.json()
|
||||||
|
if response.status_code != 200 or 'images' not in r:
|
||||||
|
raise ServiceUnavailableError(r.get('detail', [{'msg': 'Unknown error calling Stable Diffusion'}])[0]['msg'], code = response.status_code)
|
||||||
|
# r['parameters']...
|
||||||
|
for b64_json in r['images']:
|
||||||
|
if response_format == 'b64_json':
|
||||||
|
resp['data'].extend([{'b64_json': b64_json}])
|
||||||
|
else:
|
||||||
|
resp['data'].extend([{'url': f'data:image/png;base64,{b64_json}'}]) # yeah it's lazy. requests.get() will not work with this
|
||||||
|
|
||||||
|
return resp
|
77
extensions/openai/models.py
Normal file
77
extensions/openai/models.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
from modules import shared
|
||||||
|
from modules.utils import get_available_models
|
||||||
|
from modules.models import load_model, unload_model
|
||||||
|
from modules.models_settings import (get_model_settings_from_yamls,
|
||||||
|
update_model_parameters)
|
||||||
|
|
||||||
|
from extensions.openai.embeddings import get_embeddings_model_name
|
||||||
|
from extensions.openai.errors import *
|
||||||
|
|
||||||
|
def get_current_model_list() -> list:
|
||||||
|
return [ shared.model_name ] # The real chat/completions model, maybe "None"
|
||||||
|
|
||||||
|
def get_pseudo_model_list() -> list:
|
||||||
|
return [ # these are expected by so much, so include some here as a dummy
|
||||||
|
'gpt-3.5-turbo',
|
||||||
|
'text-embedding-ada-002',
|
||||||
|
]
|
||||||
|
|
||||||
|
def load_model(model_name: str) -> dict:
|
||||||
|
resp = {
|
||||||
|
"id": model_name,
|
||||||
|
"object": "engine",
|
||||||
|
"owner": "self",
|
||||||
|
"ready": True,
|
||||||
|
}
|
||||||
|
if model_name not in get_pseudo_model_list() + [ get_embeddings_model_name() ] + get_current_model_list(): # Real model only
|
||||||
|
# No args. Maybe it works anyways!
|
||||||
|
# TODO: hack some heuristics into args for better results
|
||||||
|
|
||||||
|
shared.model_name = model_name
|
||||||
|
unload_model()
|
||||||
|
|
||||||
|
model_settings = get_model_settings_from_yamls(shared.model_name)
|
||||||
|
shared.settings.update(model_settings)
|
||||||
|
update_model_parameters(model_settings, initial=True)
|
||||||
|
|
||||||
|
if shared.settings['mode'] != 'instruct':
|
||||||
|
shared.settings['instruction_template'] = None
|
||||||
|
|
||||||
|
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||||
|
|
||||||
|
if not shared.model: # load failed.
|
||||||
|
shared.model_name = "None"
|
||||||
|
raise OpenAIError(f"Model load failed for: {shared.model_name}")
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
def list_models(is_legacy: bool = False) -> dict:
|
||||||
|
# TODO: Lora's?
|
||||||
|
all_model_list = get_current_model_list() + [ get_embeddings_model_name() ] + get_pseudo_model_list() + get_available_models()
|
||||||
|
|
||||||
|
models = {}
|
||||||
|
|
||||||
|
if is_legacy:
|
||||||
|
models = [{ "id": id, "object": "engine", "owner": "user", "ready": True } for id in all_model_list ]
|
||||||
|
if not shared.model:
|
||||||
|
models[0]['ready'] = False
|
||||||
|
else:
|
||||||
|
models = [{ "id": id, "object": "model", "owned_by": "user", "permission": [] } for id in all_model_list ]
|
||||||
|
|
||||||
|
resp = {
|
||||||
|
"object": "list",
|
||||||
|
"data": models,
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
def model_info(model_name: str) -> dict:
|
||||||
|
return {
|
||||||
|
"id": model_name,
|
||||||
|
"object": "model",
|
||||||
|
"owned_by": "user",
|
||||||
|
"permission": []
|
||||||
|
}
|
||||||
|
|
70
extensions/openai/moderations.py
Normal file
70
extensions/openai/moderations.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
from numpy.linalg import norm
|
||||||
|
from extensions.openai.embeddings import get_embeddings_model
|
||||||
|
|
||||||
|
|
||||||
|
moderations_disabled = False # return 0/false
|
||||||
|
category_embeddings = None
|
||||||
|
antonym_embeddings = None
|
||||||
|
categories = [ "sexual", "hate", "harassment", "self-harm", "sexual/minors", "hate/threatening", "violence/graphic", "self-harm/intent", "self-harm/instructions", "harassment/threatening", "violence" ]
|
||||||
|
flag_threshold = 0.5
|
||||||
|
|
||||||
|
|
||||||
|
def get_category_embeddings():
|
||||||
|
global category_embeddings, categories
|
||||||
|
if category_embeddings is None:
|
||||||
|
embeddings = get_embeddings_model().encode(categories).tolist()
|
||||||
|
category_embeddings = dict(zip(categories, embeddings))
|
||||||
|
|
||||||
|
return category_embeddings
|
||||||
|
|
||||||
|
|
||||||
|
def cosine_similarity(a, b):
|
||||||
|
return np.dot(a, b) / (norm(a) * norm(b))
|
||||||
|
|
||||||
|
|
||||||
|
# seems most openai like with all-mpnet-base-v2
|
||||||
|
def mod_score(a, b):
|
||||||
|
return 2.0 * np.dot(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
def moderations(input):
|
||||||
|
global category_embeddings, categories, flag_threshold, moderations_disabled
|
||||||
|
results = {
|
||||||
|
"id": f"modr-{int(time.time()*1e9)}",
|
||||||
|
"model": "text-moderation-001",
|
||||||
|
"results": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
embeddings_model = get_embeddings_model()
|
||||||
|
if not embeddings_model or moderations_disabled:
|
||||||
|
results['results'] = [{
|
||||||
|
'categories': dict([ (C, False) for C in categories]),
|
||||||
|
'category_scores': dict([ (C, 0.0) for C in categories]),
|
||||||
|
'flagged': False,
|
||||||
|
}]
|
||||||
|
return results
|
||||||
|
|
||||||
|
category_embeddings = get_category_embeddings()
|
||||||
|
|
||||||
|
|
||||||
|
# input, string or array
|
||||||
|
if isinstance(input, str):
|
||||||
|
input = [input]
|
||||||
|
|
||||||
|
for in_str in input:
|
||||||
|
for ine in embeddings_model.encode([in_str]).tolist():
|
||||||
|
category_scores = dict([ (C, mod_score(category_embeddings[C], ine)) for C in categories ])
|
||||||
|
category_flags = dict([ (C, bool(category_scores[C] > flag_threshold)) for C in categories ])
|
||||||
|
flagged = any(category_flags.values())
|
||||||
|
|
||||||
|
results['results'].extend([{
|
||||||
|
'flagged': flagged,
|
||||||
|
'categories': category_flags,
|
||||||
|
'category_scores': category_scores,
|
||||||
|
}])
|
||||||
|
|
||||||
|
print(results)
|
||||||
|
|
||||||
|
return results
|
@ -1,2 +1,3 @@
|
|||||||
flask_cloudflared==0.0.12
|
flask_cloudflared==0.0.12
|
||||||
sentence-transformers
|
sentence-transformers
|
||||||
|
tiktoken
|
@ -1,108 +1,27 @@
|
|||||||
import base64
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import traceback
|
||||||
import requests
|
|
||||||
import yaml
|
|
||||||
import numpy as np
|
|
||||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from modules.utils import get_available_models
|
|
||||||
from modules.models import load_model, unload_model
|
|
||||||
from modules.models_settings import (get_model_settings_from_yamls,
|
|
||||||
update_model_parameters)
|
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.text_generation import encode, generate_reply
|
|
||||||
|
from extensions.openai.tokens import token_count, token_encode, token_decode
|
||||||
|
import extensions.openai.models as OAImodels
|
||||||
|
import extensions.openai.edits as OAIedits
|
||||||
|
import extensions.openai.embeddings as OAIembeddings
|
||||||
|
import extensions.openai.images as OAIimages
|
||||||
|
import extensions.openai.moderations as OAImoderations
|
||||||
|
import extensions.openai.completions as OAIcompletions
|
||||||
|
from extensions.openai.errors import *
|
||||||
|
from extensions.openai.utils import debug_msg
|
||||||
|
from extensions.openai.defaults import (get_default_req_params, default, clamp)
|
||||||
|
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'port': int(os.environ.get('OPENEDAI_PORT')) if 'OPENEDAI_PORT' in os.environ else 5001,
|
'port': int(os.environ.get('OPENEDAI_PORT')) if 'OPENEDAI_PORT' in os.environ else 5001,
|
||||||
}
|
}
|
||||||
|
|
||||||
debug = True if 'OPENEDAI_DEBUG' in os.environ else False
|
|
||||||
|
|
||||||
# Slightly different defaults for OpenAI's API
|
|
||||||
# Data type is important, Ex. use 0.0 for a float 0
|
|
||||||
default_req_params = {
|
|
||||||
'max_new_tokens': 200,
|
|
||||||
'temperature': 1.0,
|
|
||||||
'top_p': 1.0,
|
|
||||||
'top_k': 1,
|
|
||||||
'repetition_penalty': 1.18,
|
|
||||||
'repetition_penalty_range': 0,
|
|
||||||
'encoder_repetition_penalty': 1.0,
|
|
||||||
'suffix': None,
|
|
||||||
'stream': False,
|
|
||||||
'echo': False,
|
|
||||||
'seed': -1,
|
|
||||||
# 'n' : default(body, 'n', 1), # 'n' doesn't have a direct map
|
|
||||||
'truncation_length': 2048,
|
|
||||||
'add_bos_token': True,
|
|
||||||
'do_sample': True,
|
|
||||||
'typical_p': 1.0,
|
|
||||||
'epsilon_cutoff': 0.0, # In units of 1e-4
|
|
||||||
'eta_cutoff': 0.0, # In units of 1e-4
|
|
||||||
'tfs': 1.0,
|
|
||||||
'top_a': 0.0,
|
|
||||||
'min_length': 0,
|
|
||||||
'no_repeat_ngram_size': 0,
|
|
||||||
'num_beams': 1,
|
|
||||||
'penalty_alpha': 0.0,
|
|
||||||
'length_penalty': 1.0,
|
|
||||||
'early_stopping': False,
|
|
||||||
'mirostat_mode': 0,
|
|
||||||
'mirostat_tau': 5.0,
|
|
||||||
'mirostat_eta': 0.1,
|
|
||||||
'ban_eos_token': False,
|
|
||||||
'skip_special_tokens': True,
|
|
||||||
'custom_stopping_strings': '',
|
|
||||||
}
|
|
||||||
|
|
||||||
# Optional, install the module and download the model to enable
|
|
||||||
# v1/embeddings
|
|
||||||
try:
|
|
||||||
from sentence_transformers import SentenceTransformer
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
|
|
||||||
embedding_model = None
|
|
||||||
|
|
||||||
# little helper to get defaults if arg is present but None and should be the same type as default.
|
|
||||||
def default(dic, key, default):
|
|
||||||
val = dic.get(key, default)
|
|
||||||
if type(val) != type(default):
|
|
||||||
# maybe it's just something like 1 instead of 1.0
|
|
||||||
try:
|
|
||||||
v = type(default)(val)
|
|
||||||
if type(val)(v) == val: # if it's the same value passed in, it's ok.
|
|
||||||
return v
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
val = default
|
|
||||||
return val
|
|
||||||
|
|
||||||
|
|
||||||
def clamp(value, minvalue, maxvalue):
|
|
||||||
return max(minvalue, min(value, maxvalue))
|
|
||||||
|
|
||||||
|
|
||||||
def float_list_to_base64(float_list):
|
|
||||||
# Convert the list to a float32 array that the OpenAPI client expects
|
|
||||||
float_array = np.array(float_list, dtype="float32")
|
|
||||||
|
|
||||||
# Get raw bytes
|
|
||||||
bytes_array = float_array.tobytes()
|
|
||||||
|
|
||||||
# Encode bytes into base64
|
|
||||||
encoded_bytes = base64.b64encode(bytes_array)
|
|
||||||
|
|
||||||
# Turn raw base64 encoded bytes into ASCII
|
|
||||||
ascii_string = encoded_bytes.decode('ascii')
|
|
||||||
return ascii_string
|
|
||||||
|
|
||||||
|
|
||||||
class Handler(BaseHTTPRequestHandler):
|
class Handler(BaseHTTPRequestHandler):
|
||||||
def send_access_control_headers(self):
|
def send_access_control_headers(self):
|
||||||
self.send_header("Access-Control-Allow-Origin", "*")
|
self.send_header("Access-Control-Allow-Origin", "*")
|
||||||
@ -118,11 +37,43 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
"Authorization"
|
"Authorization"
|
||||||
)
|
)
|
||||||
|
|
||||||
def openai_error(self, message, code = 500, error_type = 'APIError', param = '', internal_message = ''):
|
def do_OPTIONS(self):
|
||||||
|
self.send_response(200)
|
||||||
|
self.send_access_control_headers()
|
||||||
|
self.send_header('Content-Type', 'application/json')
|
||||||
|
self.end_headers()
|
||||||
|
self.wfile.write("OK".encode('utf-8'))
|
||||||
|
|
||||||
|
def start_sse(self):
|
||||||
|
self.send_response(200)
|
||||||
|
self.send_access_control_headers()
|
||||||
|
self.send_header('Content-Type', 'text/event-stream')
|
||||||
|
self.send_header('Cache-Control', 'no-cache')
|
||||||
|
# self.send_header('Connection', 'keep-alive')
|
||||||
|
self.end_headers()
|
||||||
|
|
||||||
|
def send_sse(self, chunk: dict):
|
||||||
|
response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
|
||||||
|
debug_msg(response)
|
||||||
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
|
||||||
|
def end_sse(self):
|
||||||
|
self.wfile.write('data: [DONE]\r\n\r\n'.encode('utf-8'))
|
||||||
|
|
||||||
|
def return_json(self, ret: dict, code: int = 200, no_debug=False):
|
||||||
self.send_response(code)
|
self.send_response(code)
|
||||||
self.send_access_control_headers()
|
self.send_access_control_headers()
|
||||||
self.send_header('Content-Type', 'application/json')
|
self.send_header('Content-Type', 'application/json')
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
|
|
||||||
|
response = json.dumps(ret)
|
||||||
|
r_utf8 = response.encode('utf-8')
|
||||||
|
self.wfile.write(r_utf8)
|
||||||
|
if not no_debug:
|
||||||
|
debug_msg(r_utf8)
|
||||||
|
|
||||||
|
def openai_error(self, message, code = 500, error_type = 'APIError', param = '', internal_message = ''):
|
||||||
|
|
||||||
error_resp = {
|
error_resp = {
|
||||||
'error': {
|
'error': {
|
||||||
'message': message,
|
'message': message,
|
||||||
@ -132,121 +83,61 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if internal_message:
|
if internal_message:
|
||||||
error_resp['internal_message'] = internal_message
|
print(internal_message)
|
||||||
|
#error_resp['internal_message'] = internal_message
|
||||||
|
|
||||||
response = json.dumps(error_resp)
|
self.return_json(error_resp, code)
|
||||||
self.wfile.write(response.encode('utf-8'))
|
|
||||||
|
|
||||||
def do_OPTIONS(self):
|
def openai_error_handler(func):
|
||||||
self.send_response(200)
|
def wrapper(self):
|
||||||
self.send_access_control_headers()
|
try:
|
||||||
self.send_header('Content-Type', 'application/json')
|
func(self)
|
||||||
self.end_headers()
|
except ServiceUnavailableError as e:
|
||||||
self.wfile.write("OK".encode('utf-8'))
|
self.openai_error(e.message, e.code, e.error_type, internal_message=e.internal_message)
|
||||||
|
except InvalidRequestError as e:
|
||||||
|
self.openai_error(e.message, e.code, e.error_type, e.param, internal_message=e.internal_message)
|
||||||
|
except OpenAIError as e:
|
||||||
|
self.openai_error(e.message, e.code, e.error_type, internal_message=e.internal_message)
|
||||||
|
except Exception as e:
|
||||||
|
self.openai_error(repr(e), 500, 'OpenAIError', internal_message=traceback.format_exc())
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
@openai_error_handler
|
||||||
def do_GET(self):
|
def do_GET(self):
|
||||||
if self.path.startswith('/v1/engines') or self.path.startswith('/v1/models'):
|
debug_msg(self.requestline)
|
||||||
current_model_list = [ shared.model_name ] # The real chat/completions model, maybe "None"
|
debug_msg(self.headers)
|
||||||
embeddings_model_list = [ st_model ] if embedding_model else [] # The real sentence transformer embeddings model
|
|
||||||
pseudo_model_list = [ # these are expected by so much, so include some here as a dummy
|
|
||||||
'gpt-3.5-turbo', # /v1/chat/completions
|
|
||||||
'text-curie-001', # /v1/completions, 2k context
|
|
||||||
'text-davinci-002' # /v1/embeddings text-embedding-ada-002:1536, text-davinci-002:768
|
|
||||||
]
|
|
||||||
|
|
||||||
|
if self.path.startswith('/v1/engines') or self.path.startswith('/v1/models'):
|
||||||
is_legacy = 'engines' in self.path
|
is_legacy = 'engines' in self.path
|
||||||
is_list = self.path in ['/v1/engines', '/v1/models']
|
is_list = self.path in ['/v1/engines', '/v1/models']
|
||||||
|
if is_legacy and not is_list:
|
||||||
resp = ''
|
|
||||||
|
|
||||||
if is_legacy and not is_list: # load model
|
|
||||||
model_name = self.path[self.path.find('/v1/engines/') + len('/v1/engines/'):]
|
model_name = self.path[self.path.find('/v1/engines/') + len('/v1/engines/'):]
|
||||||
|
resp = OAImodels.load_model(model_name)
|
||||||
resp = {
|
|
||||||
"id": model_name,
|
|
||||||
"object": "engine",
|
|
||||||
"owner": "self",
|
|
||||||
"ready": True,
|
|
||||||
}
|
|
||||||
if model_name not in pseudo_model_list + embeddings_model_list + current_model_list: # Real model only
|
|
||||||
# No args. Maybe it works anyways!
|
|
||||||
# TODO: hack some heuristics into args for better results
|
|
||||||
|
|
||||||
shared.model_name = model_name
|
|
||||||
unload_model()
|
|
||||||
|
|
||||||
model_settings = get_model_settings_from_yamls(shared.model_name)
|
|
||||||
shared.settings.update(model_settings)
|
|
||||||
update_model_parameters(model_settings, initial=True)
|
|
||||||
|
|
||||||
if shared.settings['mode'] != 'instruct':
|
|
||||||
shared.settings['instruction_template'] = None
|
|
||||||
|
|
||||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
|
||||||
|
|
||||||
if not shared.model: # load failed.
|
|
||||||
shared.model_name = "None"
|
|
||||||
resp['id'] = "None"
|
|
||||||
resp['ready'] = False
|
|
||||||
|
|
||||||
elif is_list:
|
elif is_list:
|
||||||
# TODO: Lora's?
|
resp = OAImodels.list_models(is_legacy)
|
||||||
available_model_list = get_available_models()
|
|
||||||
all_model_list = current_model_list + embeddings_model_list + pseudo_model_list + available_model_list
|
|
||||||
|
|
||||||
models = {}
|
|
||||||
|
|
||||||
if is_legacy:
|
|
||||||
models = [{ "id": id, "object": "engine", "owner": "user", "ready": True } for id in all_model_list ]
|
|
||||||
if not shared.model:
|
|
||||||
models[0]['ready'] = False
|
|
||||||
else:
|
else:
|
||||||
models = [{ "id": id, "object": "model", "owned_by": "user", "permission": [] } for id in all_model_list ]
|
model_name = self.path[len('/v1/models/'):]
|
||||||
|
resp = OAImodels.model_info()
|
||||||
|
|
||||||
resp = {
|
self.return_json(resp)
|
||||||
"object": "list",
|
|
||||||
"data": models,
|
|
||||||
}
|
|
||||||
|
|
||||||
else:
|
|
||||||
the_model_name = self.path[len('/v1/models/'):]
|
|
||||||
resp = {
|
|
||||||
"id": the_model_name,
|
|
||||||
"object": "model",
|
|
||||||
"owned_by": "user",
|
|
||||||
"permission": []
|
|
||||||
}
|
|
||||||
|
|
||||||
self.send_response(200)
|
|
||||||
self.send_access_control_headers()
|
|
||||||
self.send_header('Content-Type', 'application/json')
|
|
||||||
self.end_headers()
|
|
||||||
response = json.dumps(resp)
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
|
||||||
|
|
||||||
elif '/billing/usage' in self.path:
|
elif '/billing/usage' in self.path:
|
||||||
# Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
|
# Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
|
||||||
self.send_response(200)
|
self.return_json({"total_usage": 0}, no_debug=True)
|
||||||
self.send_access_control_headers()
|
|
||||||
self.send_header('Content-Type', 'application/json')
|
|
||||||
self.end_headers()
|
|
||||||
|
|
||||||
response = json.dumps({
|
|
||||||
"total_usage": 0,
|
|
||||||
})
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.send_error(404)
|
self.send_error(404)
|
||||||
|
|
||||||
|
@openai_error_handler
|
||||||
def do_POST(self):
|
def do_POST(self):
|
||||||
if debug:
|
debug_msg(self.requestline)
|
||||||
print(self.headers) # did you know... python-openai sends your linux kernel & python version?
|
debug_msg(self.headers)
|
||||||
|
|
||||||
content_length = int(self.headers['Content-Length'])
|
content_length = int(self.headers['Content-Length'])
|
||||||
body = json.loads(self.rfile.read(content_length).decode('utf-8'))
|
body = json.loads(self.rfile.read(content_length).decode('utf-8'))
|
||||||
|
|
||||||
if debug:
|
debug_msg(body)
|
||||||
print(body)
|
|
||||||
|
|
||||||
if '/completions' in self.path or '/generate' in self.path:
|
if '/completions' in self.path or '/generate' in self.path:
|
||||||
|
|
||||||
@ -255,621 +146,109 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
return
|
return
|
||||||
|
|
||||||
is_legacy = '/generate' in self.path
|
is_legacy = '/generate' in self.path
|
||||||
is_chat_request = 'chat' in self.path
|
is_streaming = body.get('stream', False)
|
||||||
resp_list = 'data' if is_legacy else 'choices'
|
|
||||||
|
|
||||||
# XXX model is ignored for now
|
|
||||||
# model = body.get('model', shared.model_name) # ignored, use existing for now
|
|
||||||
model = shared.model_name
|
|
||||||
created_time = int(time.time())
|
|
||||||
|
|
||||||
cmpl_id = "chatcmpl-%d" % (created_time) if is_chat_request else "conv-%d" % (created_time)
|
|
||||||
|
|
||||||
# Request Parameters
|
|
||||||
# Try to use openai defaults or map them to something with the same intent
|
|
||||||
req_params = default_req_params.copy()
|
|
||||||
stopping_strings = []
|
|
||||||
|
|
||||||
if 'stop' in body:
|
|
||||||
if isinstance(body['stop'], str):
|
|
||||||
stopping_strings.extend([body['stop']])
|
|
||||||
elif isinstance(body['stop'], list):
|
|
||||||
stopping_strings.extend(body['stop'])
|
|
||||||
|
|
||||||
truncation_length = default(shared.settings, 'truncation_length', 2048)
|
|
||||||
truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length)
|
|
||||||
|
|
||||||
default_max_tokens = truncation_length if is_chat_request else 16 # completions default, chat default is 'inf' so we need to cap it.
|
|
||||||
|
|
||||||
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
|
||||||
max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))
|
|
||||||
# if the user assumes OpenAI, the max_tokens is way too large - try to ignore it unless it's small enough
|
|
||||||
|
|
||||||
req_params['max_new_tokens'] = max_tokens
|
|
||||||
req_params['truncation_length'] = truncation_length
|
|
||||||
req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
|
|
||||||
req_params['top_p'] = clamp(default(body, 'top_p', default_req_params['top_p']), 0.001, 1.0)
|
|
||||||
req_params['top_k'] = default(body, 'best_of', default_req_params['top_k'])
|
|
||||||
req_params['suffix'] = default(body, 'suffix', default_req_params['suffix'])
|
|
||||||
req_params['stream'] = default(body, 'stream', default_req_params['stream'])
|
|
||||||
req_params['echo'] = default(body, 'echo', default_req_params['echo'])
|
|
||||||
req_params['seed'] = shared.settings.get('seed', default_req_params['seed'])
|
|
||||||
req_params['add_bos_token'] = shared.settings.get('add_bos_token', default_req_params['add_bos_token'])
|
|
||||||
|
|
||||||
is_streaming = req_params['stream']
|
|
||||||
|
|
||||||
self.send_response(200)
|
|
||||||
self.send_access_control_headers()
|
|
||||||
if is_streaming:
|
|
||||||
self.send_header('Content-Type', 'text/event-stream')
|
|
||||||
self.send_header('Cache-Control', 'no-cache')
|
|
||||||
# self.send_header('Connection', 'keep-alive')
|
|
||||||
else:
|
|
||||||
self.send_header('Content-Type', 'application/json')
|
|
||||||
self.end_headers()
|
|
||||||
|
|
||||||
token_count = 0
|
|
||||||
completion_token_count = 0
|
|
||||||
prompt = ''
|
|
||||||
stream_object_type = ''
|
|
||||||
object_type = ''
|
|
||||||
|
|
||||||
if is_chat_request:
|
|
||||||
# Chat Completions
|
|
||||||
stream_object_type = 'chat.completions.chunk'
|
|
||||||
object_type = 'chat.completions'
|
|
||||||
|
|
||||||
messages = body['messages']
|
|
||||||
|
|
||||||
role_formats = {
|
|
||||||
'user': 'user: {message}\n',
|
|
||||||
'assistant': 'assistant: {message}\n',
|
|
||||||
'system': '{message}',
|
|
||||||
'context': 'You are a helpful assistant. Answer as concisely as possible.',
|
|
||||||
'prompt': 'assistant:',
|
|
||||||
}
|
|
||||||
|
|
||||||
# Instruct models can be much better
|
|
||||||
if shared.settings['instruction_template']:
|
|
||||||
try:
|
|
||||||
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
|
|
||||||
|
|
||||||
template = instruct['turn_template']
|
|
||||||
system_message_template = "{message}"
|
|
||||||
system_message_default = instruct['context']
|
|
||||||
bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token
|
|
||||||
user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct['user'])
|
|
||||||
bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct['bot'])
|
|
||||||
bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')
|
|
||||||
|
|
||||||
role_formats = {
|
|
||||||
'user': user_message_template,
|
|
||||||
'assistant': bot_message_template,
|
|
||||||
'system': system_message_template,
|
|
||||||
'context': system_message_default,
|
|
||||||
'prompt': bot_prompt,
|
|
||||||
}
|
|
||||||
|
|
||||||
if 'Alpaca' in shared.settings['instruction_template']:
|
|
||||||
stopping_strings.extend(['\n###'])
|
|
||||||
elif instruct['user']: # WizardLM and some others have no user prompt.
|
|
||||||
stopping_strings.extend(['\n' + instruct['user'], instruct['user']])
|
|
||||||
|
|
||||||
if debug:
|
|
||||||
print(f"Loaded instruction role format: {shared.settings['instruction_template']}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
stopping_strings.extend(['\nuser:'])
|
|
||||||
|
|
||||||
print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
|
|
||||||
print("Warning: Loaded default instruction-following template for model.")
|
|
||||||
|
|
||||||
else:
|
|
||||||
stopping_strings.extend(['\nuser:'])
|
|
||||||
print("Warning: Loaded default instruction-following template for model.")
|
|
||||||
|
|
||||||
system_msgs = []
|
|
||||||
chat_msgs = []
|
|
||||||
|
|
||||||
# You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
|
|
||||||
context_msg = role_formats['system'].format(message=role_formats['context']) if role_formats['context'] else ''
|
|
||||||
if context_msg:
|
|
||||||
system_msgs.extend([context_msg])
|
|
||||||
|
|
||||||
# Maybe they sent both? This is not documented in the API, but some clients seem to do this.
|
|
||||||
if 'prompt' in body:
|
|
||||||
prompt_msg = role_formats['system'].format(message=body['prompt'])
|
|
||||||
system_msgs.extend([prompt_msg])
|
|
||||||
|
|
||||||
for m in messages:
|
|
||||||
role = m['role']
|
|
||||||
content = m['content']
|
|
||||||
msg = role_formats[role].format(message=content)
|
|
||||||
if role == 'system':
|
|
||||||
system_msgs.extend([msg])
|
|
||||||
else:
|
|
||||||
chat_msgs.extend([msg])
|
|
||||||
|
|
||||||
# can't really truncate the system messages
|
|
||||||
system_msg = '\n'.join(system_msgs)
|
|
||||||
if system_msg and system_msg[-1] != '\n':
|
|
||||||
system_msg = system_msg + '\n'
|
|
||||||
|
|
||||||
system_token_count = len(encode(system_msg)[0])
|
|
||||||
remaining_tokens = truncation_length - system_token_count
|
|
||||||
chat_msg = ''
|
|
||||||
|
|
||||||
while chat_msgs:
|
|
||||||
new_msg = chat_msgs.pop()
|
|
||||||
new_size = len(encode(new_msg)[0])
|
|
||||||
if new_size <= remaining_tokens:
|
|
||||||
chat_msg = new_msg + chat_msg
|
|
||||||
remaining_tokens -= new_size
|
|
||||||
else:
|
|
||||||
print(f"Warning: too many messages for context size, dropping {len(chat_msgs) + 1} oldest message(s).")
|
|
||||||
break
|
|
||||||
|
|
||||||
prompt = system_msg + chat_msg + role_formats['prompt']
|
|
||||||
|
|
||||||
token_count = len(encode(prompt)[0])
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Text Completions
|
|
||||||
stream_object_type = 'text_completion.chunk'
|
|
||||||
object_type = 'text_completion'
|
|
||||||
|
|
||||||
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
|
|
||||||
if is_legacy:
|
|
||||||
prompt = body['context'] # Older engines.generate API
|
|
||||||
else:
|
|
||||||
prompt = body['prompt'] # XXX this can be different types
|
|
||||||
|
|
||||||
if isinstance(prompt, list):
|
|
||||||
self.openai_error("API Batched generation not yet supported.")
|
|
||||||
return
|
|
||||||
|
|
||||||
token_count = len(encode(prompt)[0])
|
|
||||||
if token_count >= truncation_length:
|
|
||||||
new_len = int(len(prompt) * shared.settings['truncation_length'] / token_count)
|
|
||||||
prompt = prompt[-new_len:]
|
|
||||||
new_token_count = len(encode(prompt)[0])
|
|
||||||
print(f"Warning: truncating prompt to {new_len} characters, was {token_count} tokens. Now: {new_token_count} tokens.")
|
|
||||||
token_count = new_token_count
|
|
||||||
|
|
||||||
if truncation_length - token_count < req_params['max_new_tokens']:
|
|
||||||
print(f"Warning: Ignoring max_new_tokens ({req_params['max_new_tokens']}), too large for the remaining context. Remaining tokens: {truncation_length - token_count}")
|
|
||||||
req_params['max_new_tokens'] = truncation_length - token_count
|
|
||||||
print(f"Warning: Set max_new_tokens = {req_params['max_new_tokens']}")
|
|
||||||
|
|
||||||
if is_streaming:
|
if is_streaming:
|
||||||
# begin streaming
|
self.start_sse()
|
||||||
chunk = {
|
|
||||||
"id": cmpl_id,
|
|
||||||
"object": stream_object_type,
|
|
||||||
"created": created_time,
|
|
||||||
"model": shared.model_name,
|
|
||||||
resp_list: [{
|
|
||||||
"index": 0,
|
|
||||||
"finish_reason": None,
|
|
||||||
}],
|
|
||||||
}
|
|
||||||
|
|
||||||
if stream_object_type == 'text_completion.chunk':
|
response = []
|
||||||
chunk[resp_list][0]["text"] = ""
|
if 'chat' in self.path:
|
||||||
|
response = OAIcompletions.stream_chat_completions(body, is_legacy=is_legacy)
|
||||||
else:
|
else:
|
||||||
# So yeah... do both methods? delta and messages.
|
response = OAIcompletions.stream_completions(body, is_legacy=is_legacy)
|
||||||
chunk[resp_list][0]["message"] = {'role': 'assistant', 'content': ''}
|
|
||||||
chunk[resp_list][0]["delta"] = {'role': 'assistant', 'content': ''}
|
|
||||||
|
|
||||||
response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
|
for resp in response:
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.send_sse(resp)
|
||||||
|
|
||||||
# generate reply #######################################
|
self.end_sse()
|
||||||
if debug:
|
|
||||||
print({'prompt': prompt, 'req_params': req_params})
|
|
||||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
|
||||||
|
|
||||||
answer = ''
|
|
||||||
seen_content = ''
|
|
||||||
longest_stop_len = max([len(x) for x in stopping_strings] + [0])
|
|
||||||
|
|
||||||
for a in generator:
|
|
||||||
answer = a
|
|
||||||
|
|
||||||
stop_string_found = False
|
|
||||||
len_seen = len(seen_content)
|
|
||||||
search_start = max(len_seen - longest_stop_len, 0)
|
|
||||||
|
|
||||||
for string in stopping_strings:
|
|
||||||
idx = answer.find(string, search_start)
|
|
||||||
if idx != -1:
|
|
||||||
answer = answer[:idx] # clip it.
|
|
||||||
stop_string_found = True
|
|
||||||
|
|
||||||
if stop_string_found:
|
|
||||||
break
|
|
||||||
|
|
||||||
# If something like "\nYo" is generated just before "\nYou:"
|
|
||||||
# is completed, buffer and generate more, don't send it
|
|
||||||
buffer_and_continue = False
|
|
||||||
|
|
||||||
for string in stopping_strings:
|
|
||||||
for j in range(len(string) - 1, 0, -1):
|
|
||||||
if answer[-j:] == string[:j]:
|
|
||||||
buffer_and_continue = True
|
|
||||||
break
|
|
||||||
else:
|
else:
|
||||||
continue
|
response = ''
|
||||||
break
|
if 'chat' in self.path:
|
||||||
|
response = OAIcompletions.chat_completions(body, is_legacy=is_legacy)
|
||||||
if buffer_and_continue:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if is_streaming:
|
|
||||||
# Streaming
|
|
||||||
new_content = answer[len_seen:]
|
|
||||||
|
|
||||||
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
|
||||||
continue
|
|
||||||
|
|
||||||
seen_content = answer
|
|
||||||
chunk = {
|
|
||||||
"id": cmpl_id,
|
|
||||||
"object": stream_object_type,
|
|
||||||
"created": created_time,
|
|
||||||
"model": shared.model_name,
|
|
||||||
resp_list: [{
|
|
||||||
"index": 0,
|
|
||||||
"finish_reason": None,
|
|
||||||
}],
|
|
||||||
}
|
|
||||||
|
|
||||||
# strip extra leading space off new generated content
|
|
||||||
if len_seen == 0 and new_content[0] == ' ':
|
|
||||||
new_content = new_content[1:]
|
|
||||||
|
|
||||||
if stream_object_type == 'text_completion.chunk':
|
|
||||||
chunk[resp_list][0]['text'] = new_content
|
|
||||||
else:
|
else:
|
||||||
# So yeah... do both methods? delta and messages.
|
response = OAIcompletions.completions(body, is_legacy=is_legacy)
|
||||||
chunk[resp_list][0]['message'] = {'content': new_content}
|
|
||||||
chunk[resp_list][0]['delta'] = {'content': new_content}
|
|
||||||
response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
|
||||||
completion_token_count += len(encode(new_content)[0])
|
|
||||||
|
|
||||||
if is_streaming:
|
self.return_json(response)
|
||||||
chunk = {
|
|
||||||
"id": cmpl_id,
|
|
||||||
"object": stream_object_type,
|
|
||||||
"created": created_time,
|
|
||||||
"model": model, # TODO: add Lora info?
|
|
||||||
resp_list: [{
|
|
||||||
"index": 0,
|
|
||||||
"finish_reason": "stop",
|
|
||||||
}],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": token_count,
|
|
||||||
"completion_tokens": completion_token_count,
|
|
||||||
"total_tokens": token_count + completion_token_count
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if stream_object_type == 'text_completion.chunk':
|
|
||||||
chunk[resp_list][0]['text'] = ''
|
|
||||||
else:
|
|
||||||
# So yeah... do both methods? delta and messages.
|
|
||||||
chunk[resp_list][0]['message'] = {'content': ''}
|
|
||||||
chunk[resp_list][0]['delta'] = {'content': ''}
|
|
||||||
|
|
||||||
response = 'data: ' + json.dumps(chunk) + '\r\n\r\ndata: [DONE]\r\n\r\n'
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
|
||||||
# Finished if streaming.
|
|
||||||
if debug:
|
|
||||||
if answer and answer[0] == ' ':
|
|
||||||
answer = answer[1:]
|
|
||||||
print({'answer': answer}, chunk)
|
|
||||||
return
|
|
||||||
|
|
||||||
# strip extra leading space off new generated content
|
|
||||||
if answer and answer[0] == ' ':
|
|
||||||
answer = answer[1:]
|
|
||||||
|
|
||||||
if debug:
|
|
||||||
print({'response': answer})
|
|
||||||
|
|
||||||
completion_token_count = len(encode(answer)[0])
|
|
||||||
stop_reason = "stop"
|
|
||||||
if token_count + completion_token_count >= truncation_length:
|
|
||||||
stop_reason = "length"
|
|
||||||
|
|
||||||
resp = {
|
|
||||||
"id": cmpl_id,
|
|
||||||
"object": object_type,
|
|
||||||
"created": created_time,
|
|
||||||
"model": model, # TODO: add Lora info?
|
|
||||||
resp_list: [{
|
|
||||||
"index": 0,
|
|
||||||
"finish_reason": stop_reason,
|
|
||||||
}],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": token_count,
|
|
||||||
"completion_tokens": completion_token_count,
|
|
||||||
"total_tokens": token_count + completion_token_count
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if is_chat_request:
|
|
||||||
resp[resp_list][0]["message"] = {"role": "assistant", "content": answer}
|
|
||||||
else:
|
|
||||||
resp[resp_list][0]["text"] = answer
|
|
||||||
|
|
||||||
response = json.dumps(resp)
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
|
||||||
|
|
||||||
elif '/edits' in self.path:
|
elif '/edits' in self.path:
|
||||||
|
# deprecated
|
||||||
|
|
||||||
if not shared.model:
|
if not shared.model:
|
||||||
self.openai_error("No model loaded.")
|
self.openai_error("No model loaded.")
|
||||||
return
|
return
|
||||||
|
|
||||||
self.send_response(200)
|
req_params = get_default_req_params()
|
||||||
self.send_access_control_headers()
|
|
||||||
self.send_header('Content-Type', 'application/json')
|
|
||||||
self.end_headers()
|
|
||||||
|
|
||||||
created_time = int(time.time())
|
|
||||||
|
|
||||||
# Using Alpaca format, this may work with other models too.
|
|
||||||
instruction = body['instruction']
|
instruction = body['instruction']
|
||||||
input = body.get('input', '')
|
input = body.get('input', '')
|
||||||
|
temperature = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
|
||||||
|
top_p = clamp(default(body, 'top_p', req_params['top_p']), 0.001, 1.0)
|
||||||
|
|
||||||
# Request parameters
|
response = OAIedits.edits(instruction, input, temperature, top_p)
|
||||||
req_params = default_req_params.copy()
|
|
||||||
stopping_strings = []
|
|
||||||
|
|
||||||
# Alpaca is verbose so a good default prompt
|
self.return_json(response)
|
||||||
default_template = (
|
|
||||||
"Below is an instruction that describes a task, paired with an input that provides further context. "
|
|
||||||
"Write a response that appropriately completes the request.\n\n"
|
|
||||||
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
instruction_template = default_template
|
|
||||||
|
|
||||||
# Use the special instruction/input/response template for anything trained like Alpaca
|
|
||||||
if shared.settings['instruction_template']:
|
|
||||||
if 'Alpaca' in shared.settings['instruction_template']:
|
|
||||||
stopping_strings.extend(['\n###'])
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
|
|
||||||
|
|
||||||
template = instruct['turn_template']
|
|
||||||
template = template\
|
|
||||||
.replace('<|user|>', instruct.get('user', ''))\
|
|
||||||
.replace('<|bot|>', instruct.get('bot', ''))\
|
|
||||||
.replace('<|user-message|>', '{instruction}\n{input}')
|
|
||||||
|
|
||||||
instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
|
|
||||||
if instruct['user']:
|
|
||||||
stopping_strings.extend(['\n' + instruct['user'], instruct['user'] ])
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
instruction_template = default_template
|
|
||||||
print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
|
|
||||||
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
|
|
||||||
else:
|
|
||||||
stopping_strings.extend(['\n###'])
|
|
||||||
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
|
|
||||||
|
|
||||||
|
|
||||||
edit_task = instruction_template.format(instruction=instruction, input=input)
|
|
||||||
|
|
||||||
truncation_length = default(shared.settings, 'truncation_length', 2048)
|
|
||||||
token_count = len(encode(edit_task)[0])
|
|
||||||
max_tokens = truncation_length - token_count
|
|
||||||
|
|
||||||
req_params['max_new_tokens'] = max_tokens
|
|
||||||
req_params['truncation_length'] = truncation_length
|
|
||||||
req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
|
|
||||||
req_params['top_p'] = clamp(default(body, 'top_p', default_req_params['top_p']), 0.001, 1.0)
|
|
||||||
req_params['seed'] = shared.settings.get('seed', default_req_params['seed'])
|
|
||||||
req_params['add_bos_token'] = shared.settings.get('add_bos_token', default_req_params['add_bos_token'])
|
|
||||||
|
|
||||||
if debug:
|
|
||||||
print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})
|
|
||||||
|
|
||||||
generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False)
|
|
||||||
|
|
||||||
longest_stop_len = max([len(x) for x in stopping_strings] + [0])
|
|
||||||
answer = ''
|
|
||||||
seen_content = ''
|
|
||||||
for a in generator:
|
|
||||||
answer = a
|
|
||||||
|
|
||||||
stop_string_found = False
|
|
||||||
len_seen = len(seen_content)
|
|
||||||
search_start = max(len_seen - longest_stop_len, 0)
|
|
||||||
|
|
||||||
for string in stopping_strings:
|
|
||||||
idx = answer.find(string, search_start)
|
|
||||||
if idx != -1:
|
|
||||||
answer = answer[:idx] # clip it.
|
|
||||||
stop_string_found = True
|
|
||||||
|
|
||||||
if stop_string_found:
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
# some reply's have an extra leading space to fit the instruction template, just clip it off from the reply.
|
|
||||||
if edit_task[-1] != '\n' and answer and answer[0] == ' ':
|
|
||||||
answer = answer[1:]
|
|
||||||
|
|
||||||
completion_token_count = len(encode(answer)[0])
|
|
||||||
|
|
||||||
resp = {
|
|
||||||
"object": "edit",
|
|
||||||
"created": created_time,
|
|
||||||
"choices": [{
|
|
||||||
"text": answer,
|
|
||||||
"index": 0,
|
|
||||||
}],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": token_count,
|
|
||||||
"completion_tokens": completion_token_count,
|
|
||||||
"total_tokens": token_count + completion_token_count
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if debug:
|
|
||||||
print({'answer': answer, 'completion_token_count': completion_token_count})
|
|
||||||
|
|
||||||
response = json.dumps(resp)
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
|
||||||
|
|
||||||
elif '/images/generations' in self.path and 'SD_WEBUI_URL' in os.environ:
|
elif '/images/generations' in self.path and 'SD_WEBUI_URL' in os.environ:
|
||||||
# Stable Diffusion callout wrapper for txt2img
|
prompt = body['prompt']
|
||||||
# Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E
|
size = default(body, 'size', '1024x1024')
|
||||||
# the results will be limited and likely poor. SD has hundreds of models and dozens of settings.
|
|
||||||
# If you want high quality tailored results you should just use the Stable Diffusion API directly.
|
|
||||||
# it's too general an API to try and shape the result with specific tags like "masterpiece", etc,
|
|
||||||
# Will probably work best with the stock SD models.
|
|
||||||
# SD configuration is beyond the scope of this API.
|
|
||||||
# At this point I will not add the edits and variations endpoints (ie. img2img) because they
|
|
||||||
# require changing the form data handling to accept multipart form data, also to properly support
|
|
||||||
# url return types will require file management and a web serving files... Perhaps later!
|
|
||||||
|
|
||||||
self.send_response(200)
|
|
||||||
self.send_access_control_headers()
|
|
||||||
self.send_header('Content-Type', 'application/json')
|
|
||||||
self.end_headers()
|
|
||||||
|
|
||||||
width, height = [ int(x) for x in default(body, 'size', '1024x1024').split('x') ] # ignore the restrictions on size
|
|
||||||
response_format = default(body, 'response_format', 'url') # or b64_json
|
response_format = default(body, 'response_format', 'url') # or b64_json
|
||||||
|
n = default(body, 'n', 1) # ignore the batch limits of max 10
|
||||||
|
|
||||||
payload = {
|
response = OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n)
|
||||||
'prompt': body['prompt'], # ignore prompt limit of 1000 characters
|
|
||||||
'width': width,
|
|
||||||
'height': height,
|
|
||||||
'batch_size': default(body, 'n', 1) # ignore the batch limits of max 10
|
|
||||||
}
|
|
||||||
|
|
||||||
resp = {
|
self.return_json(response, no_debug=True)
|
||||||
'created': int(time.time()),
|
|
||||||
'data': []
|
|
||||||
}
|
|
||||||
|
|
||||||
# TODO: support SD_WEBUI_AUTH username:password pair.
|
elif '/embeddings' in self.path:
|
||||||
sd_url = f"{os.environ['SD_WEBUI_URL']}/sdapi/v1/txt2img"
|
encoding_format = body.get('encoding_format', '')
|
||||||
|
|
||||||
response = requests.post(url=sd_url, json=payload)
|
input = body.get('input', body.get('text', ''))
|
||||||
r = response.json()
|
if not input:
|
||||||
# r['parameters']...
|
raise InvalidRequestError("Missing required argument input", params='input')
|
||||||
for b64_json in r['images']:
|
|
||||||
if response_format == 'b64_json':
|
|
||||||
resp['data'].extend([{'b64_json': b64_json}])
|
|
||||||
else:
|
|
||||||
resp['data'].extend([{'url': f'data:image/png;base64,{b64_json}'}]) # yeah it's lazy. requests.get() will not work with this
|
|
||||||
|
|
||||||
response = json.dumps(resp)
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
|
||||||
|
|
||||||
elif '/embeddings' in self.path and embedding_model is not None:
|
|
||||||
self.send_response(200)
|
|
||||||
self.send_access_control_headers()
|
|
||||||
self.send_header('Content-Type', 'application/json')
|
|
||||||
self.end_headers()
|
|
||||||
|
|
||||||
input = body['input'] if 'input' in body else body['text']
|
|
||||||
if type(input) is str:
|
if type(input) is str:
|
||||||
input = [input]
|
input = [input]
|
||||||
|
|
||||||
embeddings = embedding_model.encode(input).tolist()
|
response = OAIembeddings.embeddings(input, encoding_format)
|
||||||
|
|
||||||
def enc_emb(emb):
|
self.return_json(response, no_debug=True)
|
||||||
# If base64 is specified, encode. Otherwise, do nothing.
|
|
||||||
if body.get("encoding_format", "") == "base64":
|
|
||||||
return float_list_to_base64(emb)
|
|
||||||
else:
|
|
||||||
return emb
|
|
||||||
data = [{"object": "embedding", "embedding": enc_emb(emb), "index": n} for n, emb in enumerate(embeddings)]
|
|
||||||
|
|
||||||
response = json.dumps({
|
|
||||||
"object": "list",
|
|
||||||
"data": data,
|
|
||||||
"model": st_model, # return the real model
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 0,
|
|
||||||
"total_tokens": 0,
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
if debug:
|
|
||||||
print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
|
||||||
|
|
||||||
elif '/moderations' in self.path:
|
elif '/moderations' in self.path:
|
||||||
# for now do nothing, just don't error.
|
input = body['input']
|
||||||
self.send_response(200)
|
if not input:
|
||||||
self.send_access_control_headers()
|
raise InvalidRequestError("Missing required argument input", params='input')
|
||||||
self.send_header('Content-Type', 'application/json')
|
|
||||||
self.end_headers()
|
|
||||||
|
|
||||||
response = json.dumps({
|
response = OAImoderations.moderations(input)
|
||||||
"id": "modr-5MWoLO",
|
|
||||||
"model": "text-moderation-001",
|
self.return_json(response, no_debug=True)
|
||||||
"results": [{
|
|
||||||
"categories": {
|
|
||||||
"hate": False,
|
|
||||||
"hate/threatening": False,
|
|
||||||
"self-harm": False,
|
|
||||||
"sexual": False,
|
|
||||||
"sexual/minors": False,
|
|
||||||
"violence": False,
|
|
||||||
"violence/graphic": False
|
|
||||||
},
|
|
||||||
"category_scores": {
|
|
||||||
"hate": 0.0,
|
|
||||||
"hate/threatening": 0.0,
|
|
||||||
"self-harm": 0.0,
|
|
||||||
"sexual": 0.0,
|
|
||||||
"sexual/minors": 0.0,
|
|
||||||
"violence": 0.0,
|
|
||||||
"violence/graphic": 0.0
|
|
||||||
},
|
|
||||||
"flagged": False
|
|
||||||
}]
|
|
||||||
})
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
|
||||||
|
|
||||||
elif self.path == '/api/v1/token-count':
|
elif self.path == '/api/v1/token-count':
|
||||||
# NOT STANDARD. lifted from the api extension, but it's still very useful to calculate tokenized length client side.
|
# NOT STANDARD. lifted from the api extension, but it's still very useful to calculate tokenized length client side.
|
||||||
self.send_response(200)
|
response = token_count(body['prompt'])
|
||||||
self.send_access_control_headers()
|
|
||||||
self.send_header('Content-Type', 'application/json')
|
|
||||||
self.end_headers()
|
|
||||||
|
|
||||||
tokens = encode(body['prompt'])[0]
|
self.return_json(response, no_debug=True)
|
||||||
response = json.dumps({
|
|
||||||
'results': [{
|
elif self.path == '/api/v1/token/encode':
|
||||||
'tokens': len(tokens)
|
# NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models
|
||||||
}]
|
encoding_format = body.get('encoding_format', '')
|
||||||
})
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
response = token_encode(body['input'], encoding_format)
|
||||||
|
|
||||||
|
self.return_json(response, no_debug=True)
|
||||||
|
|
||||||
|
elif self.path == '/api/v1/token/decode':
|
||||||
|
# NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models
|
||||||
|
encoding_format = body.get('encoding_format', '')
|
||||||
|
|
||||||
|
response = token_decode(body['input'], encoding_format)
|
||||||
|
|
||||||
|
self.return_json(response, no_debug=True)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
print(self.path, self.headers)
|
|
||||||
self.send_error(404)
|
self.send_error(404)
|
||||||
|
|
||||||
|
|
||||||
def run_server():
|
def run_server():
|
||||||
global embedding_model
|
|
||||||
try:
|
|
||||||
embedding_model = SentenceTransformer(st_model)
|
|
||||||
print(f"\nLoaded embedding model: {st_model}, max sequence length: {embedding_model.max_seq_length}")
|
|
||||||
except:
|
|
||||||
print(f"\nFailed to load embedding model: {st_model}")
|
|
||||||
pass
|
|
||||||
|
|
||||||
server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
|
server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
|
||||||
server = ThreadingHTTPServer(server_addr, Handler)
|
server = ThreadingHTTPServer(server_addr, Handler)
|
||||||
if shared.args.share:
|
if shared.args.share:
|
||||||
|
37
extensions/openai/tokens.py
Normal file
37
extensions/openai/tokens.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
from extensions.openai.utils import float_list_to_base64
|
||||||
|
from modules.text_generation import encode, decode
|
||||||
|
|
||||||
|
def token_count(prompt):
|
||||||
|
tokens = encode(prompt)[0]
|
||||||
|
|
||||||
|
return {
|
||||||
|
'results': [{
|
||||||
|
'tokens': len(tokens)
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def token_encode(input, encoding_format = ''):
|
||||||
|
#if isinstance(input, list):
|
||||||
|
tokens = encode(input)[0]
|
||||||
|
|
||||||
|
return {
|
||||||
|
'results': [{
|
||||||
|
'encoding_format': encoding_format,
|
||||||
|
'tokens': float_list_to_base64(tokens) if encoding_format == "base64" else tokens,
|
||||||
|
'length': len(tokens),
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def token_decode(tokens, encoding_format):
|
||||||
|
#if isinstance(input, list):
|
||||||
|
# if encoding_format == "base64":
|
||||||
|
# tokens = base64_to_float_list(tokens)
|
||||||
|
output = decode(tokens)[0]
|
||||||
|
|
||||||
|
return {
|
||||||
|
'results': [{
|
||||||
|
'text': output
|
||||||
|
}]
|
||||||
|
}
|
26
extensions/openai/utils.py
Normal file
26
extensions/openai/utils.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import os
|
||||||
|
import base64
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def float_list_to_base64(float_list):
|
||||||
|
# Convert the list to a float32 array that the OpenAPI client expects
|
||||||
|
float_array = np.array(float_list, dtype="float32")
|
||||||
|
|
||||||
|
# Get raw bytes
|
||||||
|
bytes_array = float_array.tobytes()
|
||||||
|
|
||||||
|
# Encode bytes into base64
|
||||||
|
encoded_bytes = base64.b64encode(bytes_array)
|
||||||
|
|
||||||
|
# Turn raw base64 encoded bytes into ASCII
|
||||||
|
ascii_string = encoded_bytes.decode('ascii')
|
||||||
|
return ascii_string
|
||||||
|
|
||||||
|
def end_line(s):
|
||||||
|
if s and s[-1] != '\n':
|
||||||
|
s = s + '\n'
|
||||||
|
return s
|
||||||
|
|
||||||
|
def debug_msg(*args, **kwargs):
|
||||||
|
if 'OPENEDAI_DEBUG' in os.environ:
|
||||||
|
print(*args, **kwargs)
|
Loading…
Reference in New Issue
Block a user