From ad54d524f7586d2fa8ab33bad46d0a4b108aa822 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Wed, 22 May 2024 17:18:08 -0700
Subject: [PATCH] Revert "Fix stopping strings for llama-3 and phi (#6043)"

This reverts commit 5499bc9bc8d2b24f163c0026dce05df21a25a691.
---
 modules/chat.py | 73 ++++++++++++++++++++++++-------------------------
 1 file changed, 36 insertions(+), 37 deletions(-)

diff --git a/modules/chat.py b/modules/chat.py
index 6a388a04..43f5466b 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -45,35 +45,34 @@ yaml.add_representer(str, str_presenter)
 yaml.representer.SafeRepresenter.add_representer(str, str_presenter)
 
 
-def extract_message_prefix_suffix(renderer, strip_trailing_spaces=True):
+def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=True):
     '''
-    Given a Jinja template, extracts the prefix and suffix for
-    an assistant message and a user message. It assumes that they
-    share the same suffix.
+    Given a Jinja template, reverse-engineers the prefix and the suffix for
+    an assistant message (if impersonate=False) or an user message
+    (if impersonate=True)
     '''
 
-    messages = [
-        {"role": "user", "content": "<<|user-message-1|>>"},
-        {"role": "assistant", "content": "<<|assistant-message-1|>>"},
-        {"role": "user", "content": "<<|user-message-2|>>"},
-        {"role": "assistant", "content": "<<|assistant-message-2|>>"},
-    ]
+    if impersonate:
+        messages = [
+            {"role": "user", "content": "<<|user-message-1|>>"},
+            {"role": "user", "content": "<<|user-message-2|>>"},
+        ]
+    else:
+        messages = [
+            {"role": "assistant", "content": "<<|user-message-1|>>"},
+            {"role": "assistant", "content": "<<|user-message-2|>>"},
+        ]
 
     prompt = renderer(messages=messages)
-    unwanted_suffix = renderer(messages=[])
 
-    suffix = prompt.split('<<|assistant-message-2|>>')[1]
-    if unwanted_suffix != '':
-        suffix = suffix[:-len(unwanted_suffix)]
-
-    prefix_user = prompt.split('<<|assistant-message-1|>>')[1].split('<<|user-message-2|>>')[0][len(suffix):]
-    prefix_assistant = prompt.split('<<|user-message-1|>>')[1].split('<<|assistant-message-1|>>')[0][len(suffix):]
+    suffix_plus_prefix = prompt.split("<<|user-message-1|>>")[1].split("<<|user-message-2|>>")[0]
+    suffix = prompt.split("<<|user-message-2|>>")[1]
+    prefix = suffix_plus_prefix[len(suffix):]
 
     if strip_trailing_spaces:
-        prefix_user = prefix_user.rstrip(' ')
-        prefix_assistant = prefix_assistant.rstrip(' ')
+        prefix = prefix.rstrip(' ')
 
-    return prefix_user, prefix_assistant, suffix
+    return prefix, suffix
 
 
 def generate_chat_prompt(user_input, state, **kwargs):
@@ -126,12 +125,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
         messages.append({"role": "user", "content": user_input})
 
     def remove_extra_bos(prompt):
-        if hasattr(shared.tokenizer, 'bos_token_id'):
-            bos_tokens = [shared.tokenizer.decode(shared.tokenizer.bos_token_id)]
-        else:
-            bos_tokens = ['<s>', '<|startoftext|>', '<BOS_TOKEN>']
-
-        for bos_token in bos_tokens:
+        for bos_token in ['<s>', '<|startoftext|>', '<BOS_TOKEN>', '<|endoftext|>']:
             while prompt.startswith(bos_token):
                 prompt = prompt[len(bos_token):]
 
@@ -143,9 +137,6 @@ def generate_chat_prompt(user_input, state, **kwargs):
     else:
         prompt = renderer(messages=messages)
 
-    prefix_user, prefix_assistant, suffix = extract_message_prefix_suffix(renderer, strip_trailing_spaces=not _continue)
-    prefix = prefix_user if impersonate else prefix_assistant
-
     if state['mode'] == 'chat-instruct':
         outer_messages = []
         if state['custom_system_message'].strip() != '':
@@ -157,25 +148,29 @@ def generate_chat_prompt(user_input, state, **kwargs):
         command = command.replace('<|prompt|>', prompt)
         command = replace_character_names(command, state['name1'], state['name2'])
 
         if _continue:
+            prefix = get_generation_prompt(renderer, impersonate=impersonate, strip_trailing_spaces=False)[0]
             prefix += messages[-1]["content"]
-        elif not impersonate:
-            prefix = apply_extensions('bot_prefix', prefix, state)
+        else:
+            prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
+            if not impersonate:
+                prefix = apply_extensions('bot_prefix', prefix, state)
 
         outer_messages.append({"role": "user", "content": command})
         outer_messages.append({"role": "assistant", "content": prefix})
 
         prompt = instruction_template.render(messages=outer_messages)
+        suffix = get_generation_prompt(instruct_renderer, impersonate=False)[1]
         if len(suffix) > 0:
             prompt = prompt[:-len(suffix)]
 
     else:
         if _continue:
+            suffix = get_generation_prompt(renderer, impersonate=impersonate)[1]
             if len(suffix) > 0:
                 prompt = prompt[:-len(suffix)]
         else:
+            prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
             if state['mode'] == 'chat' and not impersonate:
                 prefix = apply_extensions('bot_prefix', prefix, state)
 
@@ -254,11 +249,15 @@ def get_stopping_strings(state):
             renderers.append(renderer)
 
     for renderer in renderers:
-        prefix_user, prefix_assistant, suffix = extract_message_prefix_suffix(renderer)
+        prefix_bot, suffix_bot = get_generation_prompt(renderer, impersonate=False)
+        prefix_user, suffix_user = get_generation_prompt(renderer, impersonate=True)
 
-        for item in [suffix + prefix_assistant, suffix + prefix_user, suffix]:
-            stopping_strings.append(item)
-            stopping_strings.append(item.rstrip())
+        stopping_strings += [
+            suffix_user + prefix_bot,
+            suffix_user + prefix_user,
+            suffix_bot + prefix_bot,
+            suffix_bot + prefix_user,
+        ]
 
     if 'stopping_strings' in state and isinstance(state['stopping_strings'], list):
         stopping_strings += state.pop('stopping_strings')
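
For reference, the mechanism this revert restores: get_generation_prompt() renders two consecutive placeholder messages of the same role through the model's Jinja chat template and splits on the placeholders to recover that role's message prefix and suffix, and get_stopping_strings() joins the resulting suffix/prefix pairs into stopping strings. Below is a minimal, self-contained sketch of that logic; the ChatML-style template and the renderer variable are assumptions made only for this demo and are not the repository's actual renderer, which comes from the model's own chat template.

# Standalone sketch of the restored helper, paired with an assumed
# ChatML-style template for illustration (not the repository's renderer).
from jinja2 import Template

chat_template = Template(
    "{% for message in messages %}"
    "<|im_start|>{{ message['role'] }}\n{{ message['content'] }}<|im_end|>\n"
    "{% endfor %}"
)
renderer = chat_template.render


def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=True):
    # Restored logic: render two consecutive same-role placeholder messages,
    # then split on the placeholders to recover that role's prefix and suffix.
    if impersonate:
        messages = [
            {"role": "user", "content": "<<|user-message-1|>>"},
            {"role": "user", "content": "<<|user-message-2|>>"},
        ]
    else:
        messages = [
            {"role": "assistant", "content": "<<|user-message-1|>>"},
            {"role": "assistant", "content": "<<|user-message-2|>>"},
        ]

    prompt = renderer(messages=messages)

    suffix_plus_prefix = prompt.split("<<|user-message-1|>>")[1].split("<<|user-message-2|>>")[0]
    suffix = prompt.split("<<|user-message-2|>>")[1]
    prefix = suffix_plus_prefix[len(suffix):]

    if strip_trailing_spaces:
        prefix = prefix.rstrip(' ')

    return prefix, suffix


prefix_bot, suffix_bot = get_generation_prompt(renderer, impersonate=False)
prefix_user, suffix_user = get_generation_prompt(renderer, impersonate=True)

print(repr(prefix_bot))   # '<|im_start|>assistant\n'
print(repr(suffix_bot))   # '<|im_end|>\n'

# The restored get_stopping_strings() combines the pairs per renderer:
stopping_strings = [
    suffix_user + prefix_bot,
    suffix_user + prefix_user,
    suffix_bot + prefix_bot,
    suffix_bot + prefix_user,
]

With an actual model's chat template, these suffix-plus-prefix combinations are exactly the strings that the restored get_stopping_strings() appends to the stopping list for each renderer.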