mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
Add a Continue button to chat mode
This commit is contained in:
parent
170e0c05c4
commit
d29f4624e9
@ -22,6 +22,7 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
|
||||
is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
|
||||
end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
|
||||
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
|
||||
_continue = kwargs['_continue'] if '_continue' in kwargs else False
|
||||
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
|
||||
rows = [f"{context.strip()}\n"]
|
||||
|
||||
@ -39,6 +40,9 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
|
||||
|
||||
i = len(shared.history['internal']) - 1
|
||||
while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
|
||||
if _continue and i == len(shared.history['internal']) - 1:
|
||||
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
|
||||
else:
|
||||
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
|
||||
string = shared.history['internal'][i][0]
|
||||
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
|
||||
@ -48,6 +52,8 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
|
||||
if impersonate:
|
||||
rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
|
||||
limit = 2
|
||||
elif _continue:
|
||||
limit = 3
|
||||
else:
|
||||
# Adding the user message
|
||||
user_input = fix_newlines(user_input)
|
||||
@ -56,12 +62,12 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
|
||||
|
||||
# Adding the Character prefix
|
||||
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
|
||||
|
||||
limit = 3
|
||||
|
||||
while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length:
|
||||
rows.pop(1)
|
||||
prompt = ''.join(rows)
|
||||
|
||||
if also_return_rows:
|
||||
return prompt, rows
|
||||
else:
|
||||
@ -99,7 +105,7 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
||||
return reply, next_character_found
|
||||
|
||||
|
||||
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
|
||||
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False, _continue=False):
|
||||
if mode == 'instruct':
|
||||
stopping_strings = [f"\n{name1}", f"\n{name2}"]
|
||||
else:
|
||||
@ -107,6 +113,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
||||
|
||||
# Defining some variables
|
||||
cumulative_reply = ''
|
||||
last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None
|
||||
just_started = True
|
||||
name1_original = name1
|
||||
visible_text = custom_generate_chat_prompt = None
|
||||
@ -124,17 +131,22 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
||||
|
||||
if visible_text is None:
|
||||
visible_text = text
|
||||
if not _continue:
|
||||
text = apply_extensions(text, "input")
|
||||
|
||||
# Generating the prompt
|
||||
kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
|
||||
kwargs = {
|
||||
'end_of_turn': end_of_turn,
|
||||
'is_instruct': mode == 'instruct',
|
||||
'_continue': _continue
|
||||
}
|
||||
if custom_generate_chat_prompt is None:
|
||||
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
||||
else:
|
||||
prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
||||
|
||||
# Yield *Is typing...*
|
||||
if not regenerate:
|
||||
if not any((regenerate, _continue)):
|
||||
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
|
||||
|
||||
# Generate
|
||||
@ -154,9 +166,14 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
||||
return shared.history['visible']
|
||||
if just_started:
|
||||
just_started = False
|
||||
if not _continue:
|
||||
shared.history['internal'].append(['', ''])
|
||||
shared.history['visible'].append(['', ''])
|
||||
|
||||
if _continue:
|
||||
shared.history['internal'][-1] = [text, f'{last_reply[0]} {reply}']
|
||||
shared.history['visible'][-1] = [visible_text, f'{last_reply[1]} {visible_reply}']
|
||||
else:
|
||||
shared.history['internal'][-1] = [text, reply]
|
||||
shared.history['visible'][-1] = [visible_text, visible_reply]
|
||||
if not shared.args.no_stream:
|
||||
@ -220,6 +237,16 @@ def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of
|
||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
def continue_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
|
||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
else:
|
||||
# Yield ' ...'
|
||||
yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], name1, name2, mode)
|
||||
for history in chatbot_wrapper(shared.history['internal'][-1][0], generate_state, name1, name2, context, mode, end_of_turn, _continue=True):
|
||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
def remove_last_message(name1, name2, mode):
|
||||
if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
|
||||
last = shared.history['visible'].pop()
|
||||
|
@ -327,8 +327,9 @@ def create_interface():
|
||||
shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate')
|
||||
shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
|
||||
with gr.Row():
|
||||
shared.gradio['Impersonate'] = gr.Button('Impersonate')
|
||||
shared.gradio['Regenerate'] = gr.Button('Regenerate')
|
||||
shared.gradio['Continue'] = gr.Button('Continue')
|
||||
shared.gradio['Impersonate'] = gr.Button('Impersonate')
|
||||
with gr.Row():
|
||||
shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
|
||||
shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
|
||||
@ -411,7 +412,11 @@ def create_interface():
|
||||
|
||||
gen_events.append(shared.gradio['Regenerate'].click(
|
||||
chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
|
||||
lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
|
||||
)
|
||||
|
||||
gen_events.append(shared.gradio['Continue'].click(
|
||||
chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||
lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user