From cec8db52e572496ef53f4275bd39ec497df54e68 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Tue, 29 Aug 2023 17:44:31 -0300
Subject: [PATCH] Add max_tokens_second param (#3533)

---
 api-examples/api-example-chat-stream.py |  1 +
 api-examples/api-example-chat.py        |  1 +
 api-examples/api-example-stream.py      |  1 +
 api-examples/api-example.py             |  1 +
 extensions/api/util.py                  |  1 +
 extensions/openai/defaults.py           |  1 +
 modules/shared.py                       |  1 +
 modules/text_generation.py              | 16 ++++++++++++++--
 modules/ui.py                           |  1 +
 modules/ui_parameters.py                |  2 +-
 settings-template.yaml                  |  1 +
 11 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/api-examples/api-example-chat-stream.py b/api-examples/api-example-chat-stream.py
index c8dbbc5a..5670d4cf 100644
--- a/api-examples/api-example-chat-stream.py
+++ b/api-examples/api-example-chat-stream.py
@@ -22,6 +22,7 @@ async def run(user_input, history):
         'user_input': user_input,
         'max_new_tokens': 250,
         'auto_max_new_tokens': False,
+        'max_tokens_second': 0,
         'history': history,
         'mode': 'instruct',  # Valid options: 'chat', 'chat-instruct', 'instruct'
         'character': 'Example',
diff --git a/api-examples/api-example-chat.py b/api-examples/api-example-chat.py
index e1796963..26c69b73 100644
--- a/api-examples/api-example-chat.py
+++ b/api-examples/api-example-chat.py
@@ -16,6 +16,7 @@ def run(user_input, history):
         'user_input': user_input,
         'max_new_tokens': 250,
         'auto_max_new_tokens': False,
+        'max_tokens_second': 0,
         'history': history,
         'mode': 'instruct',  # Valid options: 'chat', 'chat-instruct', 'instruct'
         'character': 'Example',
diff --git a/api-examples/api-example-stream.py b/api-examples/api-example-stream.py
index bf5eabac..c042a50b 100644
--- a/api-examples/api-example-stream.py
+++ b/api-examples/api-example-stream.py
@@ -21,6 +21,7 @@ async def run(context):
         'prompt': context,
         'max_new_tokens': 250,
         'auto_max_new_tokens': False,
+        'max_tokens_second': 0,
 
         # Generation params. If 'preset' is set to different than 'None', the values
         # in presets/preset-name.yaml are used instead of the individual numbers.
diff --git a/api-examples/api-example.py b/api-examples/api-example.py
index 16029807..47362754 100644
--- a/api-examples/api-example.py
+++ b/api-examples/api-example.py
@@ -13,6 +13,7 @@ def run(prompt):
         'prompt': prompt,
         'max_new_tokens': 250,
         'auto_max_new_tokens': False,
+        'max_tokens_second': 0,
 
         # Generation params. If 'preset' is set to different than 'None', the values
         # in presets/preset-name.yaml are used instead of the individual numbers.
diff --git a/extensions/api/util.py b/extensions/api/util.py
index 032a9e5c..6d0cb170 100644
--- a/extensions/api/util.py
+++ b/extensions/api/util.py
@@ -22,6 +22,7 @@ def build_parameters(body, chat=False):
     generate_params = {
         'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))),
         'auto_max_new_tokens': bool(body.get('auto_max_new_tokens', False)),
+        'max_tokens_second': int(body.get('max_tokens_second', 0)),
         'do_sample': bool(body.get('do_sample', True)),
         'temperature': float(body.get('temperature', 0.5)),
         'top_p': float(body.get('top_p', 1)),
diff --git a/extensions/openai/defaults.py b/extensions/openai/defaults.py
index ffef12d0..c6a6adfd 100644
--- a/extensions/openai/defaults.py
+++ b/extensions/openai/defaults.py
@@ -5,6 +5,7 @@ import copy
 default_req_params = {
     'max_new_tokens': 16,  # 'Inf' for chat
     'auto_max_new_tokens': False,
+    'max_tokens_second': 0,
     'temperature': 1.0,
     'top_p': 1.0,
     'top_k': 1,  # choose 20 for chat in absence of another default
diff --git a/modules/shared.py b/modules/shared.py
index 28481a85..b587a99a 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -47,6 +47,7 @@ settings = {
     'truncation_length_max': 16384,
     'custom_stopping_strings': '',
     'auto_max_new_tokens': False,
+    'max_tokens_second': 0,
     'ban_eos_token': False,
     'add_bos_token': True,
     'skip_special_tokens': True,
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 5128e503..c3cf74da 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -80,10 +80,22 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
             reply, stop_found = apply_stopping_strings(reply, all_stop_strings)
             if is_stream:
                 cur_time = time.time()
-                if cur_time - last_update > 0.041666666666666664:  # Limit streaming to 24 fps
-                    last_update = cur_time
+
+                # Maximum number of tokens/second
+                if state['max_tokens_second'] > 0:
+                    diff = 1 / state['max_tokens_second'] - (cur_time - last_update)
+                    if diff > 0:
+                        time.sleep(diff)
+
+                    last_update = time.time()
                     yield reply
 
+                # Limit updates to 24 per second to not stress low latency networks
+                else:
+                    if cur_time - last_update > 0.041666666666666664:
+                        last_update = cur_time
+                        yield reply
+
             if stop_found:
                 break
 
diff --git a/modules/ui.py b/modules/ui.py
index aa72f287..021bbb6d 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -93,6 +93,7 @@ def list_interface_input_elements():
     elements = [
         'max_new_tokens',
         'auto_max_new_tokens',
+        'max_tokens_second',
         'seed',
         'temperature',
         'top_p',
diff --git a/modules/ui_parameters.py b/modules/ui_parameters.py
index b5ce5ac9..169ab500 100644
--- a/modules/ui_parameters.py
+++ b/modules/ui_parameters.py
@@ -105,7 +105,6 @@ def create_ui(default_preset):
                     with gr.Column():
                         shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha', info='For Contrastive Search. do_sample must be unchecked.')
-
                         shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams', info='For Beam Search, along with length_penalty and early_stopping.')
                         shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
                         shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
 
@@ -114,6 +113,7 @@ def create_ui(default_preset):
     with gr.Row():
         with gr.Column():
             shared.gradio['truncation_length'] = gr.Slider(value=get_truncation_length(), minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=256, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
+            shared.gradio['max_tokens_second'] = gr.Slider(value=shared.settings['max_tokens_second'], minimum=0, maximum=20, step=1, label='Maximum number of tokens/second', info='To make text readable in real time.')
             shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas.', placeholder='"\\n", "\\nYou:"')
         with gr.Column():
             shared.gradio['auto_max_new_tokens'] = gr.Checkbox(value=shared.settings['auto_max_new_tokens'], label='auto_max_new_tokens', info='Expand max_new_tokens to the available context length.')
diff --git a/settings-template.yaml b/settings-template.yaml
index b819eb39..ae2dd9ed 100644
--- a/settings-template.yaml
+++ b/settings-template.yaml
@@ -17,6 +17,7 @@ truncation_length_min: 0
 truncation_length_max: 16384
 custom_stopping_strings: ''
 auto_max_new_tokens: false
+max_tokens_second: 0
 ban_eos_token: false
 add_bos_token: true
 skip_special_tokens: true
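A minimal, runnable sketch of the pacing logic that the text_generation.py hunk
above adds (the throttled_stream() helper and the token list are illustrative,
not part of the patch): when max_tokens_second > 0, each streamed chunk is held
back until at least 1 / max_tokens_second seconds have elapsed since the
previous update; at the default of 0, the patched code instead keeps the
pre-existing 24-updates-per-second cap.

    import time

    def throttled_stream(tokens, max_tokens_second=0):
        """Yield items no faster than max_tokens_second (0 disables pacing)."""
        last_update = 0.0
        for token in tokens:
            if max_tokens_second > 0:
                # Sleep for whatever remains of this token's time budget.
                diff = 1 / max_tokens_second - (time.time() - last_update)
                if diff > 0:
                    time.sleep(diff)

                last_update = time.time()

            yield token

    # Usage: print four chunks at roughly two per second.
    for chunk in throttled_stream(['Hello', ',', ' world', '!'], max_tokens_second=2):
        print(chunk, flush=True)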