Fix generation_attempts continuing after an empty reply

This commit is contained in:
oobabooga 2023-05-21 22:14:50 -03:00
parent e18534fe12
commit fb91406e93

View File

@ -53,7 +53,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
history = kwargs.get('history', shared.history)['internal']
is_instruct = state['mode'] == 'instruct'
# Finding the maximum prompt size
# FInd the maximum prompt size
chat_prompt_size = state['chat_prompt_size']
if shared.soft_prompt:
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
@ -66,7 +66,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
substrings = all_substrings['instruct' if is_instruct else 'chat']
# Creating the template for "chat-instruct" mode
# Create the template for "chat-instruct" mode
if state['mode'] == 'chat-instruct':
wrapper = ''
command = state['chat-instruct_command'].replace('<|character|>', state['name2'] if not impersonate else state['name1'])
@ -83,7 +83,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
else:
wrapper = '<|prompt|>'
# Building the prompt
# Build the prompt
min_rows = 3
i = len(history) - 1
rows = [state['context_instruct'] if is_instruct else f"{state['context'].strip()}\n"]
@ -107,11 +107,11 @@ def generate_chat_prompt(user_input, state, **kwargs):
min_rows = 2
rows.append(substrings['user_turn_stripped'].rstrip(' '))
elif not _continue:
# Adding the user message
# Add the user message
if len(user_input) > 0:
rows.append(replace_all(substrings['user_turn'], {'<|user-message|>': user_input.strip(), '<|round|>': str(len(history))}))
# Adding the Character prefix
# Add the character prefix
if state['mode'] != 'chat-instruct':
rows.append(apply_extensions("bot_prefix", substrings['bot_turn_stripped'].rstrip(' ')))
@ -192,7 +192,6 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
return
# Defining some variables
cumulative_reply = ''
just_started = True
visible_text = None
eos_token = '\n' if state['stop_at_newline'] else None
@ -232,12 +231,13 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
prompt = generate_chat_prompt(text, state, **kwargs)
# Generate
cumulative_reply = ''
for i in range(state['chat_generation_attempts']):
reply = None
for j, reply in enumerate(generate_reply(prompt + cumulative_reply, state, eos_token=eos_token, stopping_strings=stopping_strings, is_chat=True)):
reply = cumulative_reply + reply
# Extracting the reply
# Extract the reply
reply, next_character_found = extract_message_from_reply(reply, state)
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
visible_reply = apply_extensions("output", visible_reply)
@ -268,7 +268,7 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
if next_character_found:
break
if reply in [None, '']:
if reply in [None, cumulative_reply]:
break
else:
cumulative_reply = reply