mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
Fix stopping strings for llama-3 and phi (#6043)
This commit is contained in:
parent
8aaa0a6f4e
commit
5499bc9bc8
@ -45,34 +45,35 @@ yaml.add_representer(str, str_presenter)
|
|||||||
yaml.representer.SafeRepresenter.add_representer(str, str_presenter)
|
yaml.representer.SafeRepresenter.add_representer(str, str_presenter)
|
||||||
|
|
||||||
|
|
||||||
def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=True):
|
def extract_message_prefix_suffix(renderer, strip_trailing_spaces=True):
|
||||||
'''
|
'''
|
||||||
Given a Jinja template, reverse-engineers the prefix and the suffix for
|
Given a Jinja template, extracts the prefix and suffix for
|
||||||
an assistant message (if impersonate=False) or an user message
|
an assistant message and a user message. It assumes that they
|
||||||
(if impersonate=True)
|
share the same suffix.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
if impersonate:
|
messages = [
|
||||||
messages = [
|
{"role": "user", "content": "<<|user-message-1|>>"},
|
||||||
{"role": "user", "content": "<<|user-message-1|>>"},
|
{"role": "assistant", "content": "<<|assistant-message-1|>>"},
|
||||||
{"role": "user", "content": "<<|user-message-2|>>"},
|
{"role": "user", "content": "<<|user-message-2|>>"},
|
||||||
]
|
{"role": "assistant", "content": "<<|assistant-message-2|>>"},
|
||||||
else:
|
]
|
||||||
messages = [
|
|
||||||
{"role": "assistant", "content": "<<|user-message-1|>>"},
|
|
||||||
{"role": "assistant", "content": "<<|user-message-2|>>"},
|
|
||||||
]
|
|
||||||
|
|
||||||
prompt = renderer(messages=messages)
|
prompt = renderer(messages=messages)
|
||||||
|
unwanted_suffix = renderer(messages=[])
|
||||||
|
|
||||||
suffix_plus_prefix = prompt.split("<<|user-message-1|>>")[1].split("<<|user-message-2|>>")[0]
|
suffix = prompt.split('<<|assistant-message-2|>>')[1]
|
||||||
suffix = prompt.split("<<|user-message-2|>>")[1]
|
if unwanted_suffix != '':
|
||||||
prefix = suffix_plus_prefix[len(suffix):]
|
suffix = suffix[:-len(unwanted_suffix)]
|
||||||
|
|
||||||
|
prefix_user = prompt.split('<<|assistant-message-1|>>')[1].split('<<|user-message-2|>>')[0][len(suffix):]
|
||||||
|
prefix_assistant = prompt.split('<<|user-message-1|>>')[1].split('<<|assistant-message-1|>>')[0][len(suffix):]
|
||||||
|
|
||||||
if strip_trailing_spaces:
|
if strip_trailing_spaces:
|
||||||
prefix = prefix.rstrip(' ')
|
prefix_user = prefix_user.rstrip(' ')
|
||||||
|
prefix_assistant = prefix_assistant.rstrip(' ')
|
||||||
|
|
||||||
return prefix, suffix
|
return prefix_user, prefix_assistant, suffix
|
||||||
|
|
||||||
|
|
||||||
def generate_chat_prompt(user_input, state, **kwargs):
|
def generate_chat_prompt(user_input, state, **kwargs):
|
||||||
@ -125,7 +126,12 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||||||
messages.append({"role": "user", "content": user_input})
|
messages.append({"role": "user", "content": user_input})
|
||||||
|
|
||||||
def remove_extra_bos(prompt):
|
def remove_extra_bos(prompt):
|
||||||
for bos_token in ['<s>', '<|startoftext|>', '<BOS_TOKEN>', '<|endoftext|>']:
|
if hasattr(shared.tokenizer, 'bos_token_id'):
|
||||||
|
bos_tokens = [shared.tokenizer.decode(shared.tokenizer.bos_token_id)]
|
||||||
|
else:
|
||||||
|
bos_tokens = ['<s>', '<|startoftext|>', '<BOS_TOKEN>']
|
||||||
|
|
||||||
|
for bos_token in bos_tokens:
|
||||||
while prompt.startswith(bos_token):
|
while prompt.startswith(bos_token):
|
||||||
prompt = prompt[len(bos_token):]
|
prompt = prompt[len(bos_token):]
|
||||||
|
|
||||||
@ -137,6 +143,9 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||||||
else:
|
else:
|
||||||
prompt = renderer(messages=messages)
|
prompt = renderer(messages=messages)
|
||||||
|
|
||||||
|
prefix_user, prefix_assistant, suffix = extract_message_prefix_suffix(renderer, strip_trailing_spaces=not _continue)
|
||||||
|
prefix = prefix_user if impersonate else prefix_assistant
|
||||||
|
|
||||||
if state['mode'] == 'chat-instruct':
|
if state['mode'] == 'chat-instruct':
|
||||||
outer_messages = []
|
outer_messages = []
|
||||||
if state['custom_system_message'].strip() != '':
|
if state['custom_system_message'].strip() != '':
|
||||||
@ -148,29 +157,25 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||||||
command = command.replace('<|prompt|>', prompt)
|
command = command.replace('<|prompt|>', prompt)
|
||||||
command = replace_character_names(command, state['name1'], state['name2'])
|
command = replace_character_names(command, state['name1'], state['name2'])
|
||||||
|
|
||||||
|
|
||||||
if _continue:
|
if _continue:
|
||||||
prefix = get_generation_prompt(renderer, impersonate=impersonate, strip_trailing_spaces=False)[0]
|
|
||||||
prefix += messages[-1]["content"]
|
prefix += messages[-1]["content"]
|
||||||
else:
|
elif not impersonate:
|
||||||
prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
|
prefix = apply_extensions('bot_prefix', prefix, state)
|
||||||
if not impersonate:
|
|
||||||
prefix = apply_extensions('bot_prefix', prefix, state)
|
|
||||||
|
|
||||||
outer_messages.append({"role": "user", "content": command})
|
outer_messages.append({"role": "user", "content": command})
|
||||||
outer_messages.append({"role": "assistant", "content": prefix})
|
outer_messages.append({"role": "assistant", "content": prefix})
|
||||||
|
|
||||||
prompt = instruction_template.render(messages=outer_messages)
|
prompt = instruction_template.render(messages=outer_messages)
|
||||||
suffix = get_generation_prompt(instruct_renderer, impersonate=False)[1]
|
|
||||||
if len(suffix) > 0:
|
if len(suffix) > 0:
|
||||||
prompt = prompt[:-len(suffix)]
|
prompt = prompt[:-len(suffix)]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
if _continue:
|
if _continue:
|
||||||
suffix = get_generation_prompt(renderer, impersonate=impersonate)[1]
|
|
||||||
if len(suffix) > 0:
|
if len(suffix) > 0:
|
||||||
prompt = prompt[:-len(suffix)]
|
prompt = prompt[:-len(suffix)]
|
||||||
else:
|
else:
|
||||||
prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
|
|
||||||
if state['mode'] == 'chat' and not impersonate:
|
if state['mode'] == 'chat' and not impersonate:
|
||||||
prefix = apply_extensions('bot_prefix', prefix, state)
|
prefix = apply_extensions('bot_prefix', prefix, state)
|
||||||
|
|
||||||
@ -249,15 +254,11 @@ def get_stopping_strings(state):
|
|||||||
renderers.append(renderer)
|
renderers.append(renderer)
|
||||||
|
|
||||||
for renderer in renderers:
|
for renderer in renderers:
|
||||||
prefix_bot, suffix_bot = get_generation_prompt(renderer, impersonate=False)
|
prefix_user, prefix_assistant, suffix = extract_message_prefix_suffix(renderer)
|
||||||
prefix_user, suffix_user = get_generation_prompt(renderer, impersonate=True)
|
|
||||||
|
|
||||||
stopping_strings += [
|
for item in [suffix + prefix_assistant, suffix + prefix_user, suffix]:
|
||||||
suffix_user + prefix_bot,
|
stopping_strings.append(item)
|
||||||
suffix_user + prefix_user,
|
stopping_strings.append(item.rstrip())
|
||||||
suffix_bot + prefix_bot,
|
|
||||||
suffix_bot + prefix_user,
|
|
||||||
]
|
|
||||||
|
|
||||||
if 'stopping_strings' in state and isinstance(state['stopping_strings'], list):
|
if 'stopping_strings' in state and isinstance(state['stopping_strings'], list):
|
||||||
stopping_strings += state.pop('stopping_strings')
|
stopping_strings += state.pop('stopping_strings')
|
||||||
|
Loading…
Reference in New Issue
Block a user