From 62b533f34420a698f5ee5c8b9314bc503a22e36a Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 22 Jan 2023 02:19:58 -0300 Subject: [PATCH] Add "regenerate" button to the chat --- server.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/server.py b/server.py index 42fb802e..712388e7 100644 --- a/server.py +++ b/server.py @@ -322,6 +322,16 @@ if args.chat or args.cai_chat: for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size): yield generate_chat_html(history, name1, name2, character) + def regenerate_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size): + last = history.pop() + text = last[0] + if args.cai_chat: + for i in cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size): + yield i + else: + for i in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size): + yield i + def remove_last_message(name1, name2): history.pop() if args.cai_chat: @@ -411,9 +421,10 @@ if args.chat or args.cai_chat: textbox = gr.Textbox(lines=2, label='Input') btn = gr.Button("Generate") with gr.Row(): - btn2 = gr.Button("Clear history") stop = gr.Button("Stop") - btn3 = gr.Button("Remove last message") + btn_regenerate = gr.Button("Regenerate") + btn_remove_last = gr.Button("Remove last") + btn_clear = gr.Button("Clear history") with gr.Row(): with gr.Column(): @@ -447,17 +458,19 @@ if args.chat or args.cai_chat: if args.cai_chat: gen_event = btn.click(cai_chatbot_wrapper, input_params, display1, show_progress=args.no_stream, api_name="textgen") gen_event2 = textbox.submit(cai_chatbot_wrapper, input_params, display1, show_progress=args.no_stream) - btn2.click(clear_html, [], display1, show_progress=False) + btn_clear.click(clear_html, [], display1, show_progress=False) else: gen_event = btn.click(chatbot_wrapper, input_params, display1, show_progress=args.no_stream, api_name="textgen") gen_event2 = textbox.submit(chatbot_wrapper, input_params, display1, show_progress=args.no_stream) - btn2.click(lambda x: "", display1, display1, show_progress=False) + btn_clear.click(lambda x: "", display1, display1, show_progress=False) + gen_event3 = btn_regenerate.click(regenerate_wrapper, input_params, display1, show_progress=args.no_stream) - btn2.click(clear) - btn3.click(remove_last_message, [name1, name2], display1, show_progress=False) + btn_clear.click(clear) + btn_remove_last.click(remove_last_message, [name1, name2], display1, show_progress=False) btn.click(lambda x: "", textbox, textbox, show_progress=False) + btn_regenerate.click(lambda x: "", textbox, textbox, show_progress=False) textbox.submit(lambda x: "", textbox, textbox, show_progress=False) - stop.click(None, None, None, cancels=[gen_event, gen_event2]) + stop.click(None, None, None, cancels=[gen_event, gen_event2, gen_event3]) save_btn.click(save_history, inputs=[], outputs=[download]) upload.upload(load_history, [upload], []) character_menu.change(load_character, [character_menu, name1, name2], [name2, context, display1])