diff --git a/modules/chat.py b/modules/chat.py index 6c6077a2..6400adcb 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -22,6 +22,7 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else '' impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False + _continue = kwargs['_continue'] if '_continue' in kwargs else False also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False rows = [f"{context.strip()}\n"] @@ -39,7 +40,10 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat i = len(shared.history['internal']) - 1 while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length: - rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n") + if _continue and i == len(shared.history['internal']) - 1: + rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}") + else: + rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n") string = shared.history['internal'][i][0] if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']: rows.insert(1, f"{prefix1}{string.strip()}{end_of_turn}\n") @@ -48,6 +52,8 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat if impersonate: rows.append(f"{prefix1.strip() if not is_instruct else prefix1}") limit = 2 + elif _continue: + limit = 3 else: # Adding the user message user_input = fix_newlines(user_input) @@ -56,12 +62,12 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat # Adding the Character prefix rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix")) + limit = 3 while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length: rows.pop(1) prompt = ''.join(rows) - if also_return_rows: return prompt, rows else: @@ -99,7 +105,7 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline): return reply, next_character_found -def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False): +def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False, _continue=False): if mode == 'instruct': stopping_strings = [f"\n{name1}", f"\n{name2}"] else: @@ -107,6 +113,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu # Defining some variables cumulative_reply = '' + last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None just_started = True name1_original = name1 visible_text = custom_generate_chat_prompt = None @@ -124,17 +131,22 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu if visible_text is None: visible_text = text - text = apply_extensions(text, "input") + if not _continue: + text = apply_extensions(text, "input") # Generating the prompt - kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'} + kwargs = { + 'end_of_turn': end_of_turn, + 'is_instruct': mode == 'instruct', + '_continue': _continue + } if custom_generate_chat_prompt is None: prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs) else: prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs) # Yield *Is typing...* - if not regenerate: + if not any((regenerate, _continue)): yield shared.history['visible'] + [[visible_text, shared.processing_message]] # Generate @@ -154,11 +166,16 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu return shared.history['visible'] if just_started: just_started = False - shared.history['internal'].append(['', '']) - shared.history['visible'].append(['', '']) + if not _continue: + shared.history['internal'].append(['', '']) + shared.history['visible'].append(['', '']) - shared.history['internal'][-1] = [text, reply] - shared.history['visible'][-1] = [visible_text, visible_reply] + if _continue: + shared.history['internal'][-1] = [text, f'{last_reply[0]} {reply}'] + shared.history['visible'][-1] = [visible_text, f'{last_reply[1]} {visible_reply}'] + else: + shared.history['internal'][-1] = [text, reply] + shared.history['visible'][-1] = [visible_text, visible_reply] if not shared.args.no_stream: yield shared.history['visible'] if next_character_found: @@ -220,6 +237,16 @@ def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of yield chat_html_wrapper(shared.history['visible'], name1, name2, mode) +def continue_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): + if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0: + yield chat_html_wrapper(shared.history['visible'], name1, name2, mode) + else: + # Yield ' ...' + yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], name1, name2, mode) + for history in chatbot_wrapper(shared.history['internal'][-1][0], generate_state, name1, name2, context, mode, end_of_turn, _continue=True): + yield chat_html_wrapper(shared.history['visible'], name1, name2, mode) + + def remove_last_message(name1, name2, mode): if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>': last = shared.history['visible'].pop() diff --git a/server.py b/server.py index 50305ec0..cbfbd241 100644 --- a/server.py +++ b/server.py @@ -327,8 +327,9 @@ def create_interface(): shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate') shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop") with gr.Row(): - shared.gradio['Impersonate'] = gr.Button('Impersonate') shared.gradio['Regenerate'] = gr.Button('Regenerate') + shared.gradio['Continue'] = gr.Button('Continue') + shared.gradio['Impersonate'] = gr.Button('Impersonate') with gr.Row(): shared.gradio['Copy last reply'] = gr.Button('Copy last reply') shared.gradio['Replace last reply'] = gr.Button('Replace last reply') @@ -411,7 +412,11 @@ def create_interface(): gen_events.append(shared.gradio['Regenerate'].click( chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then( - lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then( + lambda: chat.save_history(timestamp=False), None, None, show_progress=False) + ) + + gen_events.append(shared.gradio['Continue'].click( + chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then( lambda: chat.save_history(timestamp=False), None, None, show_progress=False) )