Add max_tokens_second param (#3533)

This commit is contained in:
oobabooga 2023-08-29 17:44:31 -03:00 committed by GitHub
parent fe1f7c6513
commit cec8db52e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 24 additions and 3 deletions

View File

@ -22,6 +22,7 @@ async def run(user_input, history):
'user_input': user_input,
'max_new_tokens': 250,
'auto_max_new_tokens': False,
'max_tokens_second': 0,
'history': history,
'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
'character': 'Example',

View File

@ -16,6 +16,7 @@ def run(user_input, history):
'user_input': user_input,
'max_new_tokens': 250,
'auto_max_new_tokens': False,
'max_tokens_second': 0,
'history': history,
'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
'character': 'Example',

View File

@ -21,6 +21,7 @@ async def run(context):
'prompt': context,
'max_new_tokens': 250,
'auto_max_new_tokens': False,
'max_tokens_second': 0,
# Generation params. If 'preset' is set to different than 'None', the values
# in presets/preset-name.yaml are used instead of the individual numbers.

View File

@ -13,6 +13,7 @@ def run(prompt):
'prompt': prompt,
'max_new_tokens': 250,
'auto_max_new_tokens': False,
'max_tokens_second': 0,
# Generation params. If 'preset' is set to different than 'None', the values
# in presets/preset-name.yaml are used instead of the individual numbers.

View File

@ -22,6 +22,7 @@ def build_parameters(body, chat=False):
generate_params = {
'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))),
'auto_max_new_tokens': bool(body.get('auto_max_new_tokens', False)),
'max_tokens_second': int(body.get('max_tokens_second', 0)),
'do_sample': bool(body.get('do_sample', True)),
'temperature': float(body.get('temperature', 0.5)),
'top_p': float(body.get('top_p', 1)),

View File

@ -5,6 +5,7 @@ import copy
default_req_params = {
'max_new_tokens': 16, # 'Inf' for chat
'auto_max_new_tokens': False,
'max_tokens_second': 0,
'temperature': 1.0,
'top_p': 1.0,
'top_k': 1, # choose 20 for chat in absence of another default

View File

@ -47,6 +47,7 @@ settings = {
'truncation_length_max': 16384,
'custom_stopping_strings': '',
'auto_max_new_tokens': False,
'max_tokens_second': 0,
'ban_eos_token': False,
'add_bos_token': True,
'skip_special_tokens': True,

View File

@ -80,7 +80,19 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
reply, stop_found = apply_stopping_strings(reply, all_stop_strings)
if is_stream:
cur_time = time.time()
if cur_time - last_update > 0.041666666666666664: # Limit streaming to 24 fps
# Maximum number of tokens/second
if state['max_tokens_second'] > 0:
diff = 1 / state['max_tokens_second'] - (cur_time - last_update)
if diff > 0:
time.sleep(diff)
last_update = time.time()
yield reply
# Limit updates to 24 per second to not stress low latency networks
else:
if cur_time - last_update > 0.041666666666666664:
last_update = cur_time
yield reply

View File

@ -93,6 +93,7 @@ def list_interface_input_elements():
elements = [
'max_new_tokens',
'auto_max_new_tokens',
'max_tokens_second',
'seed',
'temperature',
'top_p',

View File

@ -105,7 +105,6 @@ def create_ui(default_preset):
with gr.Column():
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha', info='For Contrastive Search. do_sample must be unchecked.')
shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams', info='For Beam Search, along with length_penalty and early_stopping.')
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
@ -114,6 +113,7 @@ def create_ui(default_preset):
with gr.Row():
with gr.Column():
shared.gradio['truncation_length'] = gr.Slider(value=get_truncation_length(), minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=256, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
shared.gradio['max_tokens_second'] = gr.Slider(value=shared.settings['max_tokens_second'], minimum=0, maximum=20, step=1, label='Maximum number of tokens/second', info='To make text readable in real time.')
shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas.', placeholder='"\\n", "\\nYou:"')
with gr.Column():
shared.gradio['auto_max_new_tokens'] = gr.Checkbox(value=shared.settings['auto_max_new_tokens'], label='auto_max_new_tokens', info='Expand max_new_tokens to the available context length.')

View File

@ -17,6 +17,7 @@ truncation_length_min: 0
truncation_length_max: 16384
custom_stopping_strings: ''
auto_max_new_tokens: false
max_tokens_second: 0
ban_eos_token: false
add_bos_token: true
skip_special_tokens: true