From 2e86a1ec044c06a12c4e68c1b63a52040215199e Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 23 Feb 2023 15:11:18 -0300 Subject: [PATCH] Move chat history into shared module --- modules/chat.py | 153 ++++++++++++++++++++++------------------------ modules/shared.py | 4 ++ server.py | 10 +-- 3 files changed, 82 insertions(+), 85 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index fa9ee4f3..7c55feda 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -17,9 +17,6 @@ from modules.text_generation import encode, generate_reply, get_max_prompt_lengt if shared.args.picture and (shared.args.cai_chat or shared.args.chat): import modules.bot_picture as bot_picture -history = {'internal': [], 'visible': []} -character = None - # This gets the new line characters right. def clean_chat_message(text): text = text.replace('\n', '\n\n') @@ -30,7 +27,7 @@ def clean_chat_message(text): def generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size, impersonate=False): text = clean_chat_message(text) rows = [f"{context.strip()}\n"] - i = len(history['internal'])-1 + i = len(shared.history['internal'])-1 count = 0 if shared.soft_prompt: @@ -38,10 +35,10 @@ def generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size, max_length = min(get_max_prompt_length(tokens), chat_prompt_size) while i >= 0 and len(encode(''.join(rows), tokens)[0]) < max_length: - rows.insert(1, f"{name2}: {history['internal'][i][1].strip()}\n") + rows.insert(1, f"{name2}: {shared.history['internal'][i][1].strip()}\n") count += 1 - if not (history['internal'][i][0] == '<|BEGIN-VISIBLE-CHAT|>'): - rows.insert(1, f"{name1}: {history['internal'][i][0].strip()}\n") + if not (shared.history['internal'][i][0] == '<|BEGIN-VISIBLE-CHAT|>'): + rows.insert(1, f"{name1}: {shared.history['internal'][i][0].strip()}\n") count += 1 i -= 1 @@ -130,20 +127,20 @@ def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, # We need this global variable to handle the Stop event, # otherwise gradio gets confused if stop_everything: - return history['visible'] + return shared.history['visible'] if first: first = False - history['internal'].append(['', '']) - history['visible'].append(['', '']) + shared.history['internal'].append(['', '']) + shared.history['visible'].append(['', '']) - history['internal'][-1] = [text, reply] - history['visible'][-1] = [visible_text, visible_reply] + shared.history['internal'][-1] = [text, reply] + shared.history['visible'][-1] = [visible_text, visible_reply] if not substring_found: - yield history['visible'] + yield shared.history['visible'] if next_character_found: break - yield history['visible'] + yield shared.history['visible'] def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None): if 'pygmalion' in shared.model_name.lower(): @@ -161,78 +158,76 @@ def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, to def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None): for _history in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture): - yield generate_chat_html(_history, name1, name2, character) + yield generate_chat_html(_history, name1, name2, shared.character) def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None): - if character is not None and len(history['visible']) == 1: + if shared.character is not None and len(shared.history['visible']) == 1: if shared.args.cai_chat: - yield generate_chat_html(history['visible'], name1, name2, character) + yield generate_chat_html(shared.history['visible'], name1, name2, shared.character) else: - yield history['visible'] + yield shared.history['visible'] else: - last_visible = history['visible'].pop() - last_internal = history['internal'].pop() + last_visible = shared.history['visible'].pop() + last_internal = shared.history['internal'].pop() for _history in chatbot_wrapper(last_internal[0], tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture): if shared.args.cai_chat: - history['visible'][-1] = [last_visible[0], _history[-1][1]] - yield generate_chat_html(history['visible'], name1, name2, character) + shared.history['visible'][-1] = [last_visible[0], _history[-1][1]] + yield generate_chat_html(shared.history['visible'], name1, name2, shared.character) else: - history['visible'][-1] = (last_visible[0], _history[-1][1]) - yield history['visible'] + shared.history['visible'][-1] = (last_visible[0], _history[-1][1]) + yield shared.history['visible'] def remove_last_message(name1, name2): - if not history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>': - last = history['visible'].pop() - history['internal'].pop() + if not shared.history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>': + last = shared.history['visible'].pop() + shared.history['internal'].pop() else: last = ['', ''] if shared.args.cai_chat: - return generate_chat_html(history['visible'], name1, name2, character), last[0] + return generate_chat_html(shared.history['visible'], name1, name2, shared.character), last[0] else: - return history['visible'], last[0] + return shared.history['visible'], last[0] def send_last_reply_to_input(): - if len(history['internal']) > 0: - return history['internal'][-1][1] + if len(shared.history['internal']) > 0: + return shared.history['internal'][-1][1] else: return '' def replace_last_reply(text, name1, name2): - if len(history['visible']) > 0: + if len(shared.history['visible']) > 0: if shared.args.cai_chat: - history['visible'][-1][1] = text + shared.history['visible'][-1][1] = text else: - history['visible'][-1] = (history['visible'][-1][0], text) - history['internal'][-1][1] = apply_extensions(text, "input") + shared.history['visible'][-1] = (shared.history['visible'][-1][0], text) + shared.history['internal'][-1][1] = apply_extensions(text, "input") if shared.args.cai_chat: - return generate_chat_html(history['visible'], name1, name2, character) + return generate_chat_html(shared.history['visible'], name1, name2, shared.character) else: - return history['visible'] + return shared.history['visible'] def clear_html(): - return generate_chat_html([], "", "", character) + return generate_chat_html([], "", "", shared.character) -def clear_chat_log(_character, name1, name2): - global history - if _character != 'None': - for i in range(len(history['internal'])): - if '<|BEGIN-VISIBLE-CHAT|>' in history['internal'][i][0]: - history['visible'] = [['', history['internal'][i][1]]] - history['internal'] = history['internal'][:i+1] +def clear_chat_log(name1, name2): + if shared.character != 'None': + for i in range(len(shared.history['internal'])): + if '<|BEGIN-VISIBLE-CHAT|>' in shared.history['internal'][i][0]: + shared.history['visible'] = [['', shared.history['internal'][i][1]]] + shared.history['internal'] = shared.history['internal'][:i+1] break else: - history['internal'] = [] - history['visible'] = [] + shared.history['internal'] = [] + shared.history['visible'] = [] if shared.args.cai_chat: - return generate_chat_html(history['visible'], name1, name2, character) + return generate_chat_html(shared.history['visible'], name1, name2, shared.character) else: - return history['visible'] + return shared.history['visible'] def redraw_html(name1, name2): - global history - return generate_chat_html(history['visible'], name1, name2, character) + return generate_chat_html(shared.history['visible'], name1, name2, shared.character) def tokenize_dialogue(dialogue, name1, name2): _history = [] @@ -273,47 +268,45 @@ def tokenize_dialogue(dialogue, name1, name2): def save_history(timestamp=True): if timestamp: - fname = f"{character or ''}{'_' if character else ''}{datetime.now().strftime('%Y%m%d-%H%M%S')}.json" + fname = f"{shared.character or ''}{'_' if shared.character else ''}{datetime.now().strftime('%Y%m%d-%H%M%S')}.json" else: - fname = f"{character or ''}{'_' if character else ''}persistent.json" + fname = f"{shared.character or ''}{'_' if shared.character else ''}persistent.json" if not Path('logs').exists(): Path('logs').mkdir() with open(Path(f'logs/{fname}'), 'w') as f: - f.write(json.dumps({'data': history['internal'], 'data_visible': history['visible']}, indent=2)) + f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2)) return Path(f'logs/{fname}') def load_history(file, name1, name2): - global history file = file.decode('utf-8') try: j = json.loads(file) if 'data' in j: - history['internal'] = j['data'] + shared.history['internal'] = j['data'] if 'data_visible' in j: - history['visible'] = j['data_visible'] + shared.history['visible'] = j['data_visible'] else: - history['visible'] = copy.deepcopy(history['internal']) + shared.history['visible'] = copy.deepcopy(shared.history['internal']) # Compatibility with Pygmalion AI's official web UI elif 'chat' in j: - history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']] + shared.history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']] if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'): - history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', history['internal'][0]]] + [[history['internal'][i], history['internal'][i+1]] for i in range(1, len(history['internal'])-1, 2)] - history['visible'] = copy.deepcopy(history['internal']) - history['visible'][0][0] = '' + shared.history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', shared.history['internal'][0]]] + [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(1, len(shared.history['internal'])-1, 2)] + shared.history['visible'] = copy.deepcopy(shared.history['internal']) + shared.history['visible'][0][0] = '' else: - history['internal'] = [[history['internal'][i], history['internal'][i+1]] for i in range(0, len(history['internal'])-1, 2)] - history['visible'] = copy.deepcopy(history['internal']) + shared.history['internal'] = [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(0, len(shared.history['internal'])-1, 2)] + shared.history['visible'] = copy.deepcopy(shared.history['internal']) except: - history['internal'] = tokenize_dialogue(file, name1, name2) - history['visible'] = copy.deepcopy(history['internal']) + shared.history['internal'] = tokenize_dialogue(file, name1, name2) + shared.history['visible'] = copy.deepcopy(shared.history['internal']) def load_character(_character, name1, name2): - global history, character context = "" - history['internal'] = [] - history['visible'] = [] + shared.history['internal'] = [] + shared.history['visible'] = [] if _character != 'None': - character = _character + shared.character = _character data = json.loads(open(Path(f'characters/{_character}.json'), 'r').read()) name2 = data['char_name'] if 'char_persona' in data and data['char_persona'] != '': @@ -322,25 +315,25 @@ def load_character(_character, name1, name2): context += f"Scenario: {data['world_scenario']}\n" context = f"{context.strip()}\n\n" if 'example_dialogue' in data and data['example_dialogue'] != '': - history['internal'] = tokenize_dialogue(data['example_dialogue'], name1, name2) + shared.history['internal'] = tokenize_dialogue(data['example_dialogue'], name1, name2) if 'char_greeting' in data and len(data['char_greeting'].strip()) > 0: - history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]] - history['visible'] += [['', apply_extensions(data['char_greeting'], "output")]] + shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]] + shared.history['visible'] += [['', apply_extensions(data['char_greeting'], "output")]] else: - history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]] - history['visible'] += [['', "Hello there!"]] + shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]] + shared.history['visible'] += [['', "Hello there!"]] else: - character = None + shared.character = None context = shared.settings['context_pygmalion'] name2 = shared.settings['name2_pygmalion'] - if Path(f'logs/{character}_persistent.json').exists(): - load_history(open(Path(f'logs/{character}_persistent.json'), 'rb').read(), name1, name2) + if Path(f'logs/{shared.character}_persistent.json').exists(): + load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2) if shared.args.cai_chat: - return name2, context, generate_chat_html(history['visible'], name1, name2, character) + return name2, context, generate_chat_html(shared.history['visible'], name1, name2, shared.character) else: - return name2, context, history['visible'] + return name2, context, shared.history['visible'] def upload_character(json_file, img, tavern=False): json_file = json_file if type(json_file) == str else json_file.decode('utf-8') diff --git a/modules/shared.py b/modules/shared.py index 7744771f..29600d8d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -7,6 +7,10 @@ soft_prompt_tensor = None soft_prompt = False stop_everything = False +# Chat variables +history = {'internal': [], 'visible': []} +character = 'None' + settings = { 'max_new_tokens': 200, 'max_new_tokens_min': 1, diff --git a/server.py b/server.py index 414b1dc6..ee5c9372 100644 --- a/server.py +++ b/server.py @@ -191,9 +191,9 @@ if shared.args.chat or shared.args.cai_chat: with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False) as interface: if shared.args.cai_chat: - display = gr.HTML(value=generate_chat_html(chat.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], chat.character)) + display = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character)) else: - display = gr.Chatbot(value=chat.history['visible']) + display = gr.Chatbot(value=shared.history['visible']) textbox = gr.Textbox(label='Input') with gr.Row(): buttons["Stop"] = gr.Button("Stop") @@ -272,7 +272,7 @@ if shared.args.chat or shared.args.cai_chat: buttons["Send last reply to input"].click(chat.send_last_reply_to_input, [], textbox, show_progress=shared.args.no_stream) buttons["Replace last reply"].click(chat.replace_last_reply, [textbox, name1, name2], display, show_progress=shared.args.no_stream) - buttons["Clear history"].click(chat.clear_chat_log, [character_menu, name1, name2], display) + buttons["Clear history"].click(chat.clear_chat_log, [name1, name2], display) buttons["Remove last"].click(chat.remove_last_message, [name1, name2], [display, textbox], show_progress=False) buttons["Download"].click(chat.save_history, inputs=[], outputs=[download]) buttons["Upload character"].click(chat.upload_character, [upload_char, upload_img], [character_menu]) @@ -295,8 +295,8 @@ if shared.args.chat or shared.args.cai_chat: upload_chat_history.upload(chat.redraw_html, [name1, name2], [display]) upload_img_me.upload(chat.redraw_html, [name1, name2], [display]) else: - upload_chat_history.upload(lambda : chat.history['visible'], [], [display]) - upload_img_me.upload(lambda : chat.history['visible'], [], [display]) + upload_chat_history.upload(lambda : shared.history['visible'], [], [display]) + upload_img_me.upload(lambda : shared.history['visible'], [], [display]) elif shared.args.notebook: with gr.Blocks(css=ui.css, analytics_enabled=False) as interface: