From 6136da419cfaf008a9a6c2c10bb8f746f522603a Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 15 Jan 2023 12:20:04 -0300 Subject: [PATCH] Add --cai-chat option that mimics Character.AI's interface --- README.md | 3 ++ html_generator.py | 104 ++++++++++++++++++++++++++++++++++++++++++++++ server.py | 38 +++++++++++++---- 3 files changed, 136 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 5428afd6..64e21455 100644 --- a/README.md +++ b/README.md @@ -125,6 +125,9 @@ Optionally, you can use the following command-line flags: --notebook Launch the webui in notebook mode, where the output is written to the same text box as the input. --chat Launch the webui in chat mode. +--cai-chat Launch the webui in chat mode with a style similar to Character.AI's. If the + file profile.png exists in the same folder as server.py, this image will be used + as the bot's profile picture. --cpu Use the CPU to generate text. --auto-devices Automatically split the model across the available GPU(s) and CPU. --load-in-8bit Load the model with 8-bit precision. diff --git a/html_generator.py b/html_generator.py index e679eb3b..7e4f8b4b 100644 --- a/html_generator.py +++ b/html_generator.py @@ -5,6 +5,7 @@ This is a library for formatting gpt4chan outputs as nice HTML. ''' import re +from pathlib import Path def process_post(post, c): t = post.split('\n') @@ -153,3 +154,106 @@ def generate_4chan_html(f): output = '\n'.join(output) return output + +def generate_chat_html(history, name1, name2): + css = """ + .chat { + margin-left: auto; + margin-right: auto; + max-width: 800px; + height: 50vh; + overflow-y: auto; + padding-right: 20px; + display: flex; + flex-direction: column-reverse; + } + + .message { + display: grid; + grid-template-columns: 50px 1fr; + padding-bottom: 20px; + font-size: 15px; + font-family: helvetica; + } + + .circle-you { + width: 45px; + height: 45px; + background-color: rgb(244, 78, 59); + border-radius: 50%; + } + + .circle-bot { + width: 45px; + height: 45px; + background-color: rgb(59, 78, 244); + border-radius: 50%; + } + + .circle-bot img { + border-radius: 50%; + width: 100%; + height: 100%; + object-fit: cover; + } + + .text { + } + + .text p { + margin-top: 5px; + } + + .username { + font-weight: bold; + } + + .body { + } + """ + + output = '' + output += f'
' + if Path("profile.png").exists(): + img = '' + else: + img = '' + + for row in history[::-1]: + p = '\n'.join([f"

{x}

" for x in row[1].split('\n')]) + output += f""" +
+
+ {img} +
+
+
+ {name2} +
+
+ {p} +
+
+
+ """ + + p = '\n'.join([f"

{x}

" for x in row[0].split('\n')]) + output += f""" +
+
+
+
+
+ {name1} +
+
+ {p} +
+
+
+ """ + + output += '' + output += "
" + + return output diff --git a/server.py b/server.py index 05fa1985..48df0720 100644 --- a/server.py +++ b/server.py @@ -16,6 +16,7 @@ parser = argparse.ArgumentParser() parser.add_argument('--model', type=str, help='Name of the model to load by default.') parser.add_argument('--notebook', action='store_true', help='Launch the webui in notebook mode, where the output is written to the same text box as the input.') parser.add_argument('--chat', action='store_true', help='Launch the webui in chat mode.') +parser.add_argument('--cai-chat', action='store_true', help='Launch the webui in chat mode with a style similar to Character.AI\'s. If the file profile.png exists in the same folder as server.py, this image will be used as the bot\'s profile picture.') parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.') parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.') parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.') @@ -189,7 +190,7 @@ if args.notebook: btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=True, api_name="textgen") textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=True) -elif args.chat: +elif args.chat or args.cai_chat: history = [] # This gets the new line characters right. @@ -218,19 +219,29 @@ elif args.chat: idx = reply.find(f"\n{name1}:") if idx != -1: reply = reply[:idx] - reply = chat_response_cleaner(response) + reply = chat_response_cleaner(reply) history.append((text, reply)) return history - def remove_last_message(): + def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check): + history = chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check) + return generate_chat_html(history, name1, name2) + + def remove_last_message(name1, name2): history.pop() - return history + if args.cai_chat: + return generate_chat_html(history, name1, name2) + else: + return history def clear(): global history history = [] + def clear_html(): + return generate_chat_html([], "", "") + if 'pygmalion' in model_name.lower(): context_str = "This is a conversation between two people.\n" name1_str = "You" @@ -258,7 +269,10 @@ elif args.chat: check = gr.Checkbox(value=True, label='Stop generating at new line character?') with gr.Column(): - display1 = gr.Chatbot() + if args.cai_chat: + display1 = gr.HTML(value=generate_chat_html([], "", "")) + else: + display1 = gr.Chatbot() textbox = gr.Textbox(lines=2, label='Input') btn = gr.Button("Generate") with gr.Row(): @@ -267,13 +281,19 @@ elif args.chat: with gr.Column(): btn2 = gr.Button("Clear history") - btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True, api_name="textgen") - textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True) - btn3.click(remove_last_message, [], display1, show_progress=False) + if args.cai_chat: + btn.click(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True, api_name="textgen") + textbox.submit(cai_chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True) + btn2.click(clear_html, [], display1, show_progress=False) + else: + btn.click(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True, api_name="textgen") + textbox.submit(chatbot_wrapper, [textbox, length_slider, preset_menu, model_menu, name1, name2, context, check], display1, show_progress=True) + btn2.click(lambda x: "", display1, display1) + btn2.click(clear) + btn3.click(remove_last_message, [name1, name2], display1, show_progress=False) btn.click(lambda x: "", textbox, textbox, show_progress=False) textbox.submit(lambda x: "", textbox, textbox, show_progress=False) - btn2.click(lambda x: "", display1, display1) else: def continue_wrapper(question, tokens, inference_settings, selected_model):