Implement echo/suffix parameters
parent cee099f131
commit 3d59346871
@@ -349,8 +349,8 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
     generate_params['stream'] = stream
     requested_model = generate_params.pop('model')
     logprob_proc = generate_params.pop('logprob_proc', None)
-    # generate_params['suffix'] = body.get('suffix', generate_params['suffix'])
-    generate_params['echo'] = body.get('echo', generate_params['echo'])
+    suffix = body['suffix'] if body['suffix'] else ''
+    echo = body['echo']
 
     if not stream:
         prompt_arg = body[prompt_str]
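
Rather than threading suffix and echo through generate_params, the handler now reads them straight off the request body, matching the classic OpenAI /v1/completions semantics: with echo enabled the prompt is prepended to the returned text, and suffix, when present, is appended after the completion. A minimal sketch of that assembly; apply_echo_suffix is a hypothetical helper name, not a function in this repo:

# Sketch of the echo/suffix assembly this commit introduces; `body` mirrors
# an OpenAI-style /v1/completions request dict. Hypothetical helper, for
# illustration only.
def apply_echo_suffix(body: dict, prompt: str, answer: str) -> str:
    suffix = body['suffix'] if body.get('suffix') else ''
    prefix = prompt if body.get('echo') else ''
    return prefix + answer + suffix

# e.g. body={'echo': True, 'suffix': ' THE END'}, prompt='Once',
# answer=' upon a time'  ->  'Once upon a time THE END'
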
@@ -373,6 +373,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
                     except KeyError:
                         prompt = decode(prompt)[0]
 
+            prefix = prompt if echo else ''
             token_count = len(encode(prompt)[0])
             total_prompt_token_count += token_count
 
@@ -393,7 +394,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
             respi = {
                 "index": idx,
                 "finish_reason": stop_reason,
-                "text": answer,
+                "text": prefix + answer + suffix,
                 "logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
             }
 
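
In the non-streaming branch, each choice now carries prefix + answer + suffix as its text, so an echoed prompt never requires a second round trip. A hedged example of what one response entry looks like after this change; all values are invented for illustration:

# Example shape of a response entry after this change; values are invented.
prompt = "Hello"
answer = ", world"
prefix = prompt          # echo was true in the request
suffix = " [done]"       # taken from body['suffix']

respi = {
    "index": 0,
    "finish_reason": "stop",
    "text": prefix + answer + suffix,  # "Hello, world [done]"
    "logprobs": None,                  # filled only when logprob_proc is active
}
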
@@ -425,6 +426,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
         else:
             raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
 
+        prefix = prompt if echo else ''
         token_count = len(encode(prompt)[0])
 
         def text_streaming_chunk(content):
@@ -444,7 +446,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
 
             return chunk
 
-        yield text_streaming_chunk('')
+        yield text_streaming_chunk(prefix)
 
         # generate reply #######################################
         debug_msg({'prompt': prompt, 'generate_params': generate_params})
@@ -472,7 +474,7 @@ def completions_common(body: dict, is_legacy: bool = False, stream=False):
         if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
             stop_reason = "length"
 
-        chunk = text_streaming_chunk('')
+        chunk = text_streaming_chunk(suffix)
         chunk[resp_list][0]["finish_reason"] = stop_reason
         chunk["usage"] = {
             "prompt_tokens": token_count,
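
The streaming branch mirrors the non-streaming one: the echoed prompt goes out as the very first chunk (text_streaming_chunk(prefix)), and the suffix rides on the final chunk (text_streaming_chunk(suffix)) alongside the finish_reason and usage counts. A client that simply concatenates chunk texts therefore reconstructs the same string as the non-streaming path. A rough client-side sketch, assuming the usual OpenAI completions chunk layout with a choices[0]["text"] field (that layout is an assumption, not shown in this diff):

# Client-side sketch: joining streamed chunk texts reproduces
# prefix + answer + suffix. `chunks` stands in for already-parsed SSE
# events; the choices[0]["text"] shape is assumed.
def collect_stream(chunks) -> str:
    return ''.join(c["choices"][0]["text"] for c in chunks)
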
@@ -57,7 +57,7 @@ class CompletionRequestParams(BaseModel):
     suffix: str | None = None
     temperature: float | None = 1
     top_p: float | None = 1
-    user: str | None = None
+    user: str | None = Field(default=None, description="Unused parameter.")
 
 
 class CompletionRequest(GenerationOptions, CompletionRequestParams):
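
Replacing the bare None default with pydantic's Field leaves the model's runtime behavior unchanged but pushes the "Unused parameter." note into the generated JSON schema, and from there into the API docs. A standalone sketch of the pattern; ExampleParams is an invented model name, and model_json_schema assumes Pydantic v2:

# Standalone demonstration of the Field(default=..., description=...) pattern;
# ExampleParams is invented for illustration, not part of the repo.
from pydantic import BaseModel, Field

class ExampleParams(BaseModel):
    user: str | None = Field(default=None, description="Unused parameter.")

# The description surfaces in the schema that OpenAPI tooling reads:
print(ExampleParams.model_json_schema()["properties"]["user"])
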