From 78bbc66fc4c0ec2613606a1e22e480f1c87a63db Mon Sep 17 00:00:00 2001
From: catalpaaa <89681913+catalpaaa@users.noreply.github.com>
Date: Tue, 11 Apr 2023 08:30:06 -0700
Subject: [PATCH] allow custom stopping strings in all modes (#903)

---
 modules/chat.py            | 26 +++++++++++++++-----------
 modules/shared.py          |  1 +
 modules/text_generation.py | 10 +++++++---
 server.py                  | 22 +++++++++++++---------
 settings-template.json     |  1 +
 5 files changed, 37 insertions(+), 23 deletions(-)

diff --git a/modules/chat.py b/modules/chat.py
index 02b18477..0ef61f8c 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -74,8 +74,18 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
     return prompt


+def get_stopping_strings(state):
+    if state['mode'] == 'instruct':
+        stopping_strings = [f"\n{state['name1']}", f"\n{state['name2']}"]
+    else:
+        stopping_strings = [f"\n{state['name1']}:", f"\n{state['name2']}:"]
+    stopping_strings += state['custom_stopping_strings']
+    return stopping_strings
+
+
 def extract_message_from_reply(reply, state):
     next_character_found = False
+    stopping_strings = get_stopping_strings(state)

     if state['stop_at_newline']:
         lines = reply.split('\n')
@@ -83,7 +93,7 @@ def extract_message_from_reply(reply, state):
         if len(lines) > 1:
             next_character_found = True
     else:
-        for string in [f"\n{state['name1']}:", f"\n{state['name2']}:"]:
+        for string in stopping_strings:
             idx = reply.find(string)
             if idx != -1:
                 reply = reply[:idx]
@@ -92,7 +102,7 @@ def extract_message_from_reply(reply, state):
         # If something like "\nYo" is generated just before "\nYou:"
         # is completed, trim it
         if not next_character_found:
-            for string in [f"\n{state['name1']}:", f"\n{state['name2']}:"]:
+            for string in stopping_strings:
                 for j in range(len(string) - 1, 0, -1):
                     if reply[-j:] == string[:j]:
                         reply = reply[:-j]
@@ -106,10 +116,6 @@ def extract_message_from_reply(reply, state):


 def chatbot_wrapper(text, state, regenerate=False, _continue=False):
-    if state['mode'] == 'instruct':
-        stopping_strings = [f"\n{state['name1']}", f"\n{state['name2']}"]
-    else:
-        stopping_strings = [f"\n{state['name1']}:", f"\n{state['name2']}:"]

     # Defining some variables
     cumulative_reply = ''
@@ -117,6 +123,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
     just_started = True
     visible_text = custom_generate_chat_prompt = None
     eos_token = '\n' if state['stop_at_newline'] else None
+    stopping_strings = get_stopping_strings(state)

     # Check if any extension wants to hijack this function call
     for extension, _ in extensions_module.iterator():
@@ -186,15 +193,12 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):


 def impersonate_wrapper(text, state):
-    if state['mode'] == 'instruct':
-        stopping_strings = [f"\n{state['name1']}", f"\n{state['name2']}"]
-    else:
-        stopping_strings = [f"\n{state['name1']}:", f"\n{state['name2']}:"]

     # Defining some variables
     cumulative_reply = ''
     eos_token = '\n' if state['stop_at_newline'] else None
     prompt = generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], end_of_turn=state['end_of_turn'], impersonate=True)
+    stopping_strings = get_stopping_strings(state)

     # Yield *Is typing...*
     yield shared.processing_message
@@ -498,4 +502,4 @@ def upload_your_profile_picture(img, name1, name2, mode):
         img.save(Path('cache/pfp_me.png'))
         print('Profile picture saved to "cache/pfp_me.png"')

-    return chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
+    return chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
\ No newline at end of file
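Taken together, the chat.py changes replace the per-mode stopping-string lists that chatbot_wrapper and impersonate_wrapper previously built inline with a single get_stopping_strings() helper, which also appends the user-supplied custom strings. A minimal standalone sketch of the helper's behavior; the state dict here is a hypothetical stand-in for the UI state the real module receives, not part of the patch:

    def get_stopping_strings(state):
        # Instruct mode stops on bare names; chat modes stop on "Name:" turns.
        if state['mode'] == 'instruct':
            stopping_strings = [f"\n{state['name1']}", f"\n{state['name2']}"]
        else:
            stopping_strings = [f"\n{state['name1']}:", f"\n{state['name2']}:"]
        # Custom strings from the UI are appended to the defaults.
        stopping_strings += state['custom_stopping_strings']
        return stopping_strings

    state = {'mode': 'chat', 'name1': 'You', 'name2': 'Assistant',
             'custom_stopping_strings': ['\nThe assistant:']}
    print(get_stopping_strings(state))
    # ['\nYou:', '\nAssistant:', '\nThe assistant:']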
diff --git a/modules/shared.py b/modules/shared.py
index 9dc8d970..b278e2fd 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -34,6 +34,7 @@ settings = {
     'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.',
     'greeting': 'Hello there!',
     'end_of_turn': '',
+    'custom_stopping_strings': '',
     'stop_at_newline': False,
     'add_bos_token': True,
     'chat_prompt_size': 2048,
diff --git a/modules/text_generation.py b/modules/text_generation.py
index e8f283a9..7d62dddb 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -174,10 +174,14 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
     eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
     if eos_token is not None:
         eos_token_ids.append(int(encode(eos_token)[0][-1]))
+
+    # Handling the stopping strings
     stopping_criteria_list = transformers.StoppingCriteriaList()
-    if type(stopping_strings) is list and len(stopping_strings) > 0:
-        t = [encode(string, 0, add_special_tokens=False) for string in stopping_strings]
-        stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
+    for st in [stopping_strings, state['custom_stopping_strings']]:
+        if type(st) is list and len(st) > 0:
+            sentinel_token_ids = [encode(string, 0, add_special_tokens=False) for string in st]
+            stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=sentinel_token_ids, starting_idx=len(input_ids[0])))
+            break

     if not shared.args.flexgen:
         for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
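The new for/break in generate_reply registers a stopping criterion from the first non-empty list it finds: chat mode passes stopping_strings explicitly (already including the custom ones via get_stopping_strings), while the notebook and default modes pass nothing, so state['custom_stopping_strings'] takes effect there too; the break keeps the custom strings from being registered twice in chat mode. A minimal sketch of just that selection logic; the function name and inputs are hypothetical:

    def pick_stopping_strings(stopping_strings, custom_stopping_strings):
        # First non-empty list wins, mirroring the for/break in the hunk above.
        for st in [stopping_strings, custom_stopping_strings]:
            if type(st) is list and len(st) > 0:
                return st
        return []

    print(pick_stopping_strings([], ['\nYou:']))      # ['\nYou:']  (notebook/default modes)
    print(pick_stopping_strings(['\nA:'], ['\nB:']))  # ['\nA:']   (chat mode, custom already folded in)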
For instance: "\\nYour Assistant:", "\\nThe assistant:"') with gr.Accordion('Soft prompt', open=False): with gr.Row(): @@ -284,7 +287,7 @@ def create_settings_menus(default_preset): with gr.Row(): shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip']) - shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'generate_state']], [shared.gradio[k] for k in ['generate_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]) + shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]) shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True) shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu']) @@ -358,7 +361,7 @@ title = 'Text generation web UI' def list_interface_input_elements(chat=False): - elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token'] + elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'custom_stopping_strings'] if chat: elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode'] return elements @@ -368,6 +371,7 @@ def gather_interface_values(*args): output = {} for i, element in enumerate(shared.input_elements): output[element] = args[i] + output['custom_stopping_strings'] = eval(f"[{output['custom_stopping_strings']}]") return output @@ -453,7 +457,7 @@ def create_interface(): shared.gradio['chat_prompt_size'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size']) with gr.Column(): shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)') - shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?') + shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character') create_settings_menus(default_preset) @@ -563,7 +567,7 @@ def create_interface(): with gr.Tab("Parameters", elem_id="parameters"): create_settings_menus(default_preset) - shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']] + shared.input_params = 
diff --git a/settings-template.json b/settings-template.json
index 08db98fa..e38293df 100644
--- a/settings-template.json
+++ b/settings-template.json
@@ -8,6 +8,7 @@
     "context": "This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.",
     "greeting": "Hello there!",
     "end_of_turn": "",
+    "custom_stopping_strings": "",
     "stop_at_newline": false,
     "add_bos_token": true,
     "chat_prompt_size": 2048,
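settings-template.json gains the same key with an empty default, and its value feeds the textbox as its initial text. Because the stored value is the raw textbox string, a populated entry in a user's settings.json needs JSON escaping for both the inner quotes and the backslashes. A small sketch (the example strings are hypothetical) that prints a valid entry:

    import json

    # Raw textbox text, exactly as it would be typed in the UI.
    entry = '"\\nYour Assistant:", "\\nThe assistant:"'
    print(json.dumps({"custom_stopping_strings": entry}, indent=4))
    # {
    #     "custom_stopping_strings": "\"\\nYour Assistant:\", \"\\nThe assistant:\""
    # }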