From 9ae0eab989ed05c1d373135148eca65f652c6fe3 Mon Sep 17 00:00:00 2001
From: matatonic <73265741+matatonic@users.noreply.github.com>
Date: Tue, 1 Aug 2023 21:26:00 -0400
Subject: [PATCH] extensions/openai: +Array input (batched) , +Fixes (#3309)

---
 extensions/openai/README.md      |   3 +-
 extensions/openai/completions.py | 129 ++++++++++++++++---------------
 extensions/openai/script.py      |   2 +-
 3 files changed, 70 insertions(+), 64 deletions(-)

diff --git a/extensions/openai/README.md b/extensions/openai/README.md
index 2083734a..bce8efb6 100644
--- a/extensions/openai/README.md
+++ b/extensions/openai/README.md
@@ -174,7 +174,7 @@ print(text)
 | /v1/models | openai.Model.list() | Lists models, Currently loaded model first, plus some compatibility options |
 | /v1/models/{id} | openai.Model.get() | returns whatever you ask for |
 | /v1/edits | openai.Edit.create() | Deprecated by openai, good with instruction following models |
-| /v1/text_completion | openai.Completion.create() | Legacy endpoint, doesn't support array input, variable quality based on the model |
+| /v1/text_completion | openai.Completion.create() | Legacy endpoint, variable quality based on the model |
 | /v1/completions | openai api completions.create | Legacy endpoint (v0.25) |
 | /v1/engines/*/embeddings | python-openai v0.25 | Legacy endpoint |
 | /v1/engines/*/generate | openai engines.generate | Legacy endpoint |
@@ -204,6 +204,7 @@ Some hacky mappings:
 | 1.0 | typical_p | hardcoded to 1.0 |
 | logprobs & logit_bias | - | experimental, llama only, transformers-kin only (ExLlama_HF ok), can also use llama tokens if 'model' is not an openai model or will convert from tiktoken for the openai model specified in 'model' |
 | messages.name | - | not supported yet |
+| suffix | - | not supported yet |
 | user | - | not supported yet |
 | functions/function_call | - | function calls are not supported yet |

diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py
index b6c573fd..646da958 100644
--- a/extensions/openai/completions.py
+++ b/extensions/openai/completions.py
@@ -48,7 +48,7 @@ class LogprobProcessor(LogitsProcessor):
             top_tokens = [ decode(tok) for tok in top_indices[0] ]
             top_probs = [ float(x) for x in top_values[0] ]
             self.token_alternatives = dict(zip(top_tokens, top_probs))
-            debug_msg(f"{self.__class__.__name__}(logprobs+1={self.logprobs+1}, token_alternatives={self.token_alternatives})")
+            debug_msg(repr(self))
         return logits

     def __repr__(self):
@@ -63,7 +63,8 @@ def convert_logprobs_to_tiktoken(model, logprobs):
 #        return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
 #    except KeyError:
 #        # assume native tokens if we can't find the tokenizer
-    return logprobs
+#        return logprobs
+    return logprobs


 def marshal_common_params(body):
@@ -271,16 +272,16 @@ def chat_completions(body: dict, is_legacy: bool = False) -> dict:
         req_params['max_new_tokens'] = req_params['truncation_length']

     # format the prompt from messages
-    prompt, token_count = messages_to_prompt(body, req_params, max_tokens)
+    prompt, token_count = messages_to_prompt(body, req_params, max_tokens)  # updates req_params['stopping_strings']

     # set real max, avoid deeper errors
     if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
         req_params['max_new_tokens'] = req_params['truncation_length'] - token_count

+    stopping_strings = req_params.pop('stopping_strings', [])
+
     # generate reply #######################################
     debug_msg({'prompt': prompt, 'req_params': req_params})
-    stopping_strings = req_params.pop('stopping_strings', [])
-    logprob_proc = req_params.pop('logprob_proc', None)
     generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)

     answer = ''
@@ -347,7 +348,7 @@ def stream_chat_completions(body: dict, is_legacy: bool = False):
         req_params['max_new_tokens'] = req_params['truncation_length']

     # format the prompt from messages
-    prompt, token_count = messages_to_prompt(body, req_params, max_tokens)
+    prompt, token_count = messages_to_prompt(body, req_params, max_tokens)  # updates req_params['stopping_strings']

     # set real max, avoid deeper errors
     if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
@@ -441,16 +442,9 @@ def completions(body: dict, is_legacy: bool = False):
     if not prompt_str in body:
         raise InvalidRequestError("Missing required input", param=prompt_str)

-    prompt = body[prompt_str]
-    if isinstance(prompt, list):
-        if prompt and isinstance(prompt[0], int):
-            try:
-                encoder = tiktoken.encoding_for_model(requested_model)
-                prompt = encoder.decode(prompt)
-            except KeyError:
-                prompt = decode(prompt)[0]
-        else:
-            raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
+    prompt_arg = body[prompt_str]
+    if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and isinstance(prompt_arg[0], int)):
+        prompt_arg = [prompt_arg]

     # common params
     req_params = marshal_common_params(body)
@@ -460,59 +454,75 @@ def completions(body: dict, is_legacy: bool = False):
     req_params['max_new_tokens'] = max_tokens
     requested_model = req_params.pop('requested_model')
     logprob_proc = req_params.pop('logprob_proc', None)
-
-    token_count = len(encode(prompt)[0])
-
-    if token_count + max_tokens > req_params['truncation_length']:
-        err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
-        # print(f"Warning: ${err_msg}")
-        raise InvalidRequestError(message=err_msg, param=max_tokens_str)
-
+    stopping_strings = req_params.pop('stopping_strings', [])
+    #req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
     req_params['echo'] = default(body, 'echo', req_params['echo'])
     req_params['top_k'] = default(body, 'best_of', req_params['top_k'])

-    # generate reply #######################################
-    debug_msg({'prompt': prompt, 'req_params': req_params})
-    stopping_strings = req_params.pop('stopping_strings', [])
-    generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
+    resp_list_data = []
+    total_completion_token_count = 0
+    total_prompt_token_count = 0

-    answer = ''
+    for idx, prompt in enumerate(prompt_arg, start=0):
+        if isinstance(prompt[0], int):
+            # token lists
+            if requested_model == shared.model_name:
+                prompt = decode(prompt)[0]
+            else:
+                try:
+                    encoder = tiktoken.encoding_for_model(requested_model)
+                    prompt = encoder.decode(prompt)
+                except KeyError:
+                    prompt = decode(prompt)[0]

-    for a in generator:
-        answer = a
+        token_count = len(encode(prompt)[0])
+        total_prompt_token_count += token_count

-    # strip extra leading space off new generated content
-    if answer and answer[0] == ' ':
-        answer = answer[1:]
+        if token_count + max_tokens > req_params['truncation_length']:
+            err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
+            # print(f"Warning: ${err_msg}")
+            raise InvalidRequestError(message=err_msg, param=max_tokens_str)

-    completion_token_count = len(encode(answer)[0])
-    stop_reason = "stop"
-    if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
-        stop_reason = "length"
+        # generate reply #######################################
+        debug_msg({'prompt': prompt, 'req_params': req_params})
+        generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
+        answer = ''
+
+        for a in generator:
+            answer = a
+
+        # strip extra leading space off new generated content
+        if answer and answer[0] == ' ':
+            answer = answer[1:]
+
+        completion_token_count = len(encode(answer)[0])
+        total_completion_token_count += completion_token_count
+        stop_reason = "stop"
+        if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
+            stop_reason = "length"
+
+        respi = {
+            "index": idx,
+            "finish_reason": stop_reason,
+            "text": answer,
+            "logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
+        }
+
+        resp_list_data.extend([respi])

     resp = {
         "id": cmpl_id,
         "object": object_type,
         "created": created_time,
         "model": shared.model_name,  # TODO: add Lora info?
-        resp_list: [{
-            "index": 0,
-            "finish_reason": stop_reason,
-            "text": answer,
-        }],
+        resp_list: resp_list_data,
         "usage": {
-            "prompt_tokens": token_count,
-            "completion_tokens": completion_token_count,
-            "total_tokens": token_count + completion_token_count
+            "prompt_tokens": total_prompt_token_count,
+            "completion_tokens": total_completion_token_count,
+            "total_tokens": total_prompt_token_count + total_completion_token_count
         }
     }

-    if logprob_proc and logprob_proc.token_alternatives:
-        top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
-        resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
-    else:
-        resp[resp_list][0]["logprobs"] = None
-
     return resp


@@ -550,6 +560,10 @@ def stream_completions(body: dict, is_legacy: bool = False):
     req_params['max_new_tokens'] = max_tokens
     requested_model = req_params.pop('requested_model')
     logprob_proc = req_params.pop('logprob_proc', None)
+    stopping_strings = req_params.pop('stopping_strings', [])
+    #req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
+    req_params['echo'] = default(body, 'echo', req_params['echo'])
+    req_params['top_k'] = default(body, 'best_of', req_params['top_k'])

     token_count = len(encode(prompt)[0])

@@ -558,9 +572,6 @@ def stream_completions(body: dict, is_legacy: bool = False):
         # print(f"Warning: ${err_msg}")
         raise InvalidRequestError(message=err_msg, param=max_tokens_str)

-    req_params['echo'] = default(body, 'echo', req_params['echo'])
-    req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
-
     def text_streaming_chunk(content):
         # begin streaming
         chunk = {
@@ -572,13 +583,9 @@ def stream_completions(body: dict, is_legacy: bool = False):
                 "index": 0,
                 "finish_reason": None,
                 "text": content,
+                "logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
             }],
         }
-        if logprob_proc:
-            top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
-            chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
-        else:
-            chunk[resp_list][0]["logprobs"] = None

         return chunk

@@ -586,8 +593,6 @@ def stream_completions(body: dict, is_legacy: bool = False):

     # generate reply #######################################
     debug_msg({'prompt': prompt, 'req_params': req_params})
-    stopping_strings = req_params.pop('stopping_strings', [])
-    logprob_proc = req_params.pop('logprob_proc', None)
     generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)

     answer = ''
diff --git a/extensions/openai/script.py b/extensions/openai/script.py
index 86f2deb7..f95205a5 100644
--- a/extensions/openai/script.py
+++ b/extensions/openai/script.py
@@ -120,7 +120,7 @@ class Handler(BaseHTTPRequestHandler):
                 resp = OAImodels.list_models(is_legacy)
            else:
                 model_name = self.path[len('/v1/models/'):]
-                resp = OAImodels.model_info()
+                resp = OAImodels.model_info(model_name)

             self.return_json(resp)
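Usage sketch (illustrative, not part of the patch): the main behavioural change above is that the completions endpoint now accepts an array of prompts and returns one choice per prompt, with usage totals summed across the batch. The snippet below shows how a client might exercise that, using the legacy openai-python (pre-1.0) client as in the extension's README examples; the base URL/port, dummy key, and model name are assumptions, so adjust them to your local setup.

# Hypothetical client-side example; base URL, port, and model name are assumptions.
import openai

openai.api_key = "sk-dummy"                    # the extension does not validate the key
openai.api_base = "http://127.0.0.1:5001/v1"   # assumed default address of the extension

# 'prompt' may be a single string, a list of strings (batched), or a list of token ids.
resp = openai.Completion.create(
    model="anything",                          # the currently loaded model is used
    prompt=["Say hello.", "Count to three."],  # array input -> one choice per prompt
    max_tokens=16,
)

for choice in resp["choices"]:
    print(choice["index"], choice["text"])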