Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-10-01 01:26:03 -04:00)

Commit cec8db52e5 (parent fe1f7c6513): Add max_tokens_second param (#3533)
@@ -22,6 +22,7 @@ async def run(user_input, history):
         'user_input': user_input,
         'max_new_tokens': 250,
         'auto_max_new_tokens': False,
+        'max_tokens_second': 0,
         'history': history,
         'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
         'character': 'Example',
@@ -16,6 +16,7 @@ def run(user_input, history):
         'user_input': user_input,
         'max_new_tokens': 250,
         'auto_max_new_tokens': False,
+        'max_tokens_second': 0,
         'history': history,
         'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
         'character': 'Example',
@@ -21,6 +21,7 @@ async def run(context):
         'prompt': context,
         'max_new_tokens': 250,
         'auto_max_new_tokens': False,
+        'max_tokens_second': 0,
 
         # Generation params. If 'preset' is set to different than 'None', the values
         # in presets/preset-name.yaml are used instead of the individual numbers.
@@ -13,6 +13,7 @@ def run(prompt):
         'prompt': prompt,
         'max_new_tokens': 250,
         'auto_max_new_tokens': False,
+        'max_tokens_second': 0,
 
         # Generation params. If 'preset' is set to different than 'None', the values
         # in presets/preset-name.yaml are used instead of the individual numbers.
@@ -22,6 +22,7 @@ def build_parameters(body, chat=False):
     generate_params = {
         'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))),
         'auto_max_new_tokens': bool(body.get('auto_max_new_tokens', False)),
+        'max_tokens_second': int(body.get('max_tokens_second', 0)),
         'do_sample': bool(body.get('do_sample', True)),
         'temperature': float(body.get('temperature', 0.5)),
         'top_p': float(body.get('top_p', 1)),
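The new field simply rides along with the other generation parameters in the request body. Below is a minimal, hedged sketch of a blocking API call that passes it; the endpoint and response shape follow the bundled api-example.py, and the pacing itself only takes effect while output is being streamed (the throttle sits under `if is_stream:` in the generation hunk further down).

# Sketch only: pass max_tokens_second alongside the usual parameters.
# Endpoint and response shape follow the repository's api-example.py.
import requests

request = {
    'prompt': 'Write a haiku about latency.',
    'max_new_tokens': 100,
    'auto_max_new_tokens': False,
    'max_tokens_second': 5,  # 0 disables the limit
}

response = requests.post('http://localhost:5000/api/v1/generate', json=request)
if response.status_code == 200:
    print(response.json()['results'][0]['text'])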
@@ -5,6 +5,7 @@ import copy
 default_req_params = {
     'max_new_tokens': 16, # 'Inf' for chat
     'auto_max_new_tokens': False,
+    'max_tokens_second': 0,
     'temperature': 1.0,
     'top_p': 1.0,
     'top_k': 1, # choose 20 for chat in absence of another default
@@ -47,6 +47,7 @@ settings = {
     'truncation_length_max': 16384,
     'custom_stopping_strings': '',
     'auto_max_new_tokens': False,
+    'max_tokens_second': 0,
     'ban_eos_token': False,
     'add_bos_token': True,
     'skip_special_tokens': True,
@@ -80,10 +80,22 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
         reply, stop_found = apply_stopping_strings(reply, all_stop_strings)
         if is_stream:
             cur_time = time.time()
-            if cur_time - last_update > 0.041666666666666664: # Limit streaming to 24 fps
-                last_update = cur_time
+
+            # Maximum number of tokens/second
+            if state['max_tokens_second'] > 0:
+                diff = 1 / state['max_tokens_second'] - (cur_time - last_update)
+                if diff > 0:
+                    time.sleep(diff)
+
+                last_update = time.time()
                 yield reply
 
+            # Limit updates to 24 per second to not stress low latency networks
+            else:
+                if cur_time - last_update > 0.041666666666666664:
+                    last_update = cur_time
+                    yield reply
+
         if stop_found:
             break
 
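The throttle works by giving each token a time budget of 1 / max_tokens_second seconds, measuring how much of that budget has already elapsed since the last emission, and sleeping off the remainder. The standalone sketch below illustrates that pacing; it is not webui code, just the same arithmetic applied to a pre-generated token list.

# Illustration only: pace emission of pre-generated tokens so the stream
# never exceeds max_tokens_second tokens per second.
import time

def paced_stream(tokens, max_tokens_second):
    last_update = time.time()
    for token in tokens:
        if max_tokens_second > 0:
            # Sleep for whatever remains of this token's time budget.
            diff = 1 / max_tokens_second - (time.time() - last_update)
            if diff > 0:
                time.sleep(diff)
            last_update = time.time()
        yield token

# With max_tokens_second=5, each token takes at least 0.2 s, so ten tokens
# need roughly two seconds no matter how fast they were generated.
start = time.time()
for _ in paced_stream(['tok'] * 10, max_tokens_second=5):
    pass
print(f'elapsed: {time.time() - start:.2f} s')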
@@ -93,6 +93,7 @@ def list_interface_input_elements():
     elements = [
         'max_new_tokens',
         'auto_max_new_tokens',
+        'max_tokens_second',
         'seed',
         'temperature',
         'top_p',
@@ -105,7 +105,6 @@ def create_ui(default_preset):
 
                     with gr.Column():
                         shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha', info='For Contrastive Search. do_sample must be unchecked.')
-
                         shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams', info='For Beam Search, along with length_penalty and early_stopping.')
                         shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
                         shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
@@ -114,6 +113,7 @@ def create_ui(default_preset):
         with gr.Row():
             with gr.Column():
                 shared.gradio['truncation_length'] = gr.Slider(value=get_truncation_length(), minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=256, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
+                shared.gradio['max_tokens_second'] = gr.Slider(value=shared.settings['max_tokens_second'], minimum=0, maximum=20, step=1, label='Maximum number of tokens/second', info='To make text readable in real time.')
                 shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas.', placeholder='"\\n", "\\nYou:"')
             with gr.Column():
                 shared.gradio['auto_max_new_tokens'] = gr.Checkbox(value=shared.settings['auto_max_new_tokens'], label='auto_max_new_tokens', info='Expand max_new_tokens to the available context length.')
|
@ -17,6 +17,7 @@ truncation_length_min: 0
|
|||||||
truncation_length_max: 16384
|
truncation_length_max: 16384
|
||||||
custom_stopping_strings: ''
|
custom_stopping_strings: ''
|
||||||
auto_max_new_tokens: false
|
auto_max_new_tokens: false
|
||||||
|
max_tokens_second: 0
|
||||||
ban_eos_token: false
|
ban_eos_token: false
|
||||||
add_bos_token: true
|
add_bos_token: true
|
||||||
skip_special_tokens: true
|
skip_special_tokens: true
|
||||||
|
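For completeness, a hedged sketch of how this template default interacts with a user configuration: the webui merges a user-provided settings.yaml over these defaults at startup, so a non-zero value there enables the limit globally. The snippet below is a minimal stand-in for that merge, not the webui's own loader.

# Stand-in for the settings merge (not the webui's own loader): a user
# settings.yaml value overrides the template default shown above.
import yaml

template_defaults = yaml.safe_load('max_tokens_second: 0')   # from settings-template.yaml
user_settings = yaml.safe_load('max_tokens_second: 5')       # hypothetical user settings.yaml
effective = {**template_defaults, **user_settings}
print(effective['max_tokens_second'])  # -> 5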