mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
Implement echo/suffix parameters
This commit is contained in:
parent
cee099f131
commit
3d59346871
@ -349,8 +349,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
|
|||||||
generate_params['stream'] = stream
|
generate_params['stream'] = stream
|
||||||
requested_model = generate_params.pop('model')
|
requested_model = generate_params.pop('model')
|
||||||
logprob_proc = generate_params.pop('logprob_proc', None)
|
logprob_proc = generate_params.pop('logprob_proc', None)
|
||||||
# generate_params['suffix'] = body.get('suffix', generate_params['suffix'])
|
suffix = body['suffix'] if body['suffix'] else ''
|
||||||
generate_params['echo'] = body.get('echo', generate_params['echo'])
|
echo = body['echo']
|
||||||
|
|
||||||
if not stream:
|
if not stream:
|
||||||
prompt_arg = body[prompt_str]
|
prompt_arg = body[prompt_str]
|
||||||
@ -373,6 +373,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
prompt = decode(prompt)[0]
|
prompt = decode(prompt)[0]
|
||||||
|
|
||||||
|
prefix = prompt if echo else ''
|
||||||
token_count = len(encode(prompt)[0])
|
token_count = len(encode(prompt)[0])
|
||||||
total_prompt_token_count += token_count
|
total_prompt_token_count += token_count
|
||||||
|
|
||||||
@ -393,7 +394,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
|
|||||||
respi = {
|
respi = {
|
||||||
"index": idx,
|
"index": idx,
|
||||||
"finish_reason": stop_reason,
|
"finish_reason": stop_reason,
|
||||||
"text": answer,
|
"text": prefix + answer + suffix,
|
||||||
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
|
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -425,6 +426,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
|
|||||||
else:
|
else:
|
||||||
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
|
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
|
||||||
|
|
||||||
|
prefix = prompt if echo else ''
|
||||||
token_count = len(encode(prompt)[0])
|
token_count = len(encode(prompt)[0])
|
||||||
|
|
||||||
def text_streaming_chunk(content):
|
def text_streaming_chunk(content):
|
||||||
@ -444,7 +446,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
|
|||||||
|
|
||||||
return chunk
|
return chunk
|
||||||
|
|
||||||
yield text_streaming_chunk('')
|
yield text_streaming_chunk(prefix)
|
||||||
|
|
||||||
# generate reply #######################################
|
# generate reply #######################################
|
||||||
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
||||||
@ -472,7 +474,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
|
|||||||
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
|
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||||
stop_reason = "length"
|
stop_reason = "length"
|
||||||
|
|
||||||
chunk = text_streaming_chunk('')
|
chunk = text_streaming_chunk(suffix)
|
||||||
chunk[resp_list][0]["finish_reason"] = stop_reason
|
chunk[resp_list][0]["finish_reason"] = stop_reason
|
||||||
chunk["usage"] = {
|
chunk["usage"] = {
|
||||||
"prompt_tokens": token_count,
|
"prompt_tokens": token_count,
|
||||||
|
@ -57,7 +57,7 @@ class CompletionRequestParams(BaseModel):
|
|||||||
suffix: str | None = None
|
suffix: str | None = None
|
||||||
temperature: float | None = 1
|
temperature: float | None = 1
|
||||||
top_p: float | None = 1
|
top_p: float | None = 1
|
||||||
user: str | None = None
|
user: str | None = Field(default=None, description="Unused parameter.")
|
||||||
|
|
||||||
|
|
||||||
class CompletionRequest(GenerationOptions, CompletionRequestParams):
|
class CompletionRequest(GenerationOptions, CompletionRequestParams):
|
||||||
|
Loading…
Reference in New Issue
Block a user