Move chat history into shared module

This commit is contained in:
oobabooga 2023-02-23 15:11:18 -03:00
parent c87800341c
commit 2e86a1ec04
3 changed files with 82 additions and 85 deletions

View File

@ -17,9 +17,6 @@ from modules.text_generation import encode, generate_reply, get_max_prompt_lengt
if shared.args.picture and (shared.args.cai_chat or shared.args.chat):
import modules.bot_picture as bot_picture
history = {'internal': [], 'visible': []}
character = None
# This gets the new line characters right.
def clean_chat_message(text):
text = text.replace('\n', '\n\n')
@ -30,7 +27,7 @@ def clean_chat_message(text):
def generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size, impersonate=False):
text = clean_chat_message(text)
rows = [f"{context.strip()}\n"]
i = len(history['internal'])-1
i = len(shared.history['internal'])-1
count = 0
if shared.soft_prompt:
@ -38,10 +35,10 @@ def generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size,
max_length = min(get_max_prompt_length(tokens), chat_prompt_size)
while i >= 0 and len(encode(''.join(rows), tokens)[0]) < max_length:
rows.insert(1, f"{name2}: {history['internal'][i][1].strip()}\n")
rows.insert(1, f"{name2}: {shared.history['internal'][i][1].strip()}\n")
count += 1
if not (history['internal'][i][0] == '<|BEGIN-VISIBLE-CHAT|>'):
rows.insert(1, f"{name1}: {history['internal'][i][0].strip()}\n")
if not (shared.history['internal'][i][0] == '<|BEGIN-VISIBLE-CHAT|>'):
rows.insert(1, f"{name1}: {shared.history['internal'][i][0].strip()}\n")
count += 1
i -= 1
@ -130,20 +127,20 @@ def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p,
# We need this global variable to handle the Stop event,
# otherwise gradio gets confused
if stop_everything:
return history['visible']
return shared.history['visible']
if first:
first = False
history['internal'].append(['', ''])
history['visible'].append(['', ''])
shared.history['internal'].append(['', ''])
shared.history['visible'].append(['', ''])
history['internal'][-1] = [text, reply]
history['visible'][-1] = [visible_text, visible_reply]
shared.history['internal'][-1] = [text, reply]
shared.history['visible'][-1] = [visible_text, visible_reply]
if not substring_found:
yield history['visible']
yield shared.history['visible']
if next_character_found:
break
yield history['visible']
yield shared.history['visible']
def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
if 'pygmalion' in shared.model_name.lower():
@ -161,78 +158,76 @@ def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, to
def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
for _history in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
yield generate_chat_html(_history, name1, name2, character)
yield generate_chat_html(_history, name1, name2, shared.character)
def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
if character is not None and len(history['visible']) == 1:
if shared.character is not None and len(shared.history['visible']) == 1:
if shared.args.cai_chat:
yield generate_chat_html(history['visible'], name1, name2, character)
yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
else:
yield history['visible']
yield shared.history['visible']
else:
last_visible = history['visible'].pop()
last_internal = history['internal'].pop()
last_visible = shared.history['visible'].pop()
last_internal = shared.history['internal'].pop()
for _history in chatbot_wrapper(last_internal[0], tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
if shared.args.cai_chat:
history['visible'][-1] = [last_visible[0], _history[-1][1]]
yield generate_chat_html(history['visible'], name1, name2, character)
shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
else:
history['visible'][-1] = (last_visible[0], _history[-1][1])
yield history['visible']
shared.history['visible'][-1] = (last_visible[0], _history[-1][1])
yield shared.history['visible']
def remove_last_message(name1, name2):
if not history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>':
last = history['visible'].pop()
history['internal'].pop()
if not shared.history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>':
last = shared.history['visible'].pop()
shared.history['internal'].pop()
else:
last = ['', '']
if shared.args.cai_chat:
return generate_chat_html(history['visible'], name1, name2, character), last[0]
return generate_chat_html(shared.history['visible'], name1, name2, shared.character), last[0]
else:
return history['visible'], last[0]
return shared.history['visible'], last[0]
def send_last_reply_to_input():
if len(history['internal']) > 0:
return history['internal'][-1][1]
if len(shared.history['internal']) > 0:
return shared.history['internal'][-1][1]
else:
return ''
def replace_last_reply(text, name1, name2):
if len(history['visible']) > 0:
if len(shared.history['visible']) > 0:
if shared.args.cai_chat:
history['visible'][-1][1] = text
shared.history['visible'][-1][1] = text
else:
history['visible'][-1] = (history['visible'][-1][0], text)
history['internal'][-1][1] = apply_extensions(text, "input")
shared.history['visible'][-1] = (shared.history['visible'][-1][0], text)
shared.history['internal'][-1][1] = apply_extensions(text, "input")
if shared.args.cai_chat:
return generate_chat_html(history['visible'], name1, name2, character)
return generate_chat_html(shared.history['visible'], name1, name2, shared.character)
else:
return history['visible']
return shared.history['visible']
def clear_html():
return generate_chat_html([], "", "", character)
return generate_chat_html([], "", "", shared.character)
def clear_chat_log(_character, name1, name2):
global history
if _character != 'None':
for i in range(len(history['internal'])):
if '<|BEGIN-VISIBLE-CHAT|>' in history['internal'][i][0]:
history['visible'] = [['', history['internal'][i][1]]]
history['internal'] = history['internal'][:i+1]
def clear_chat_log(name1, name2):
if shared.character != 'None':
for i in range(len(shared.history['internal'])):
if '<|BEGIN-VISIBLE-CHAT|>' in shared.history['internal'][i][0]:
shared.history['visible'] = [['', shared.history['internal'][i][1]]]
shared.history['internal'] = shared.history['internal'][:i+1]
break
else:
history['internal'] = []
history['visible'] = []
shared.history['internal'] = []
shared.history['visible'] = []
if shared.args.cai_chat:
return generate_chat_html(history['visible'], name1, name2, character)
return generate_chat_html(shared.history['visible'], name1, name2, shared.character)
else:
return history['visible']
return shared.history['visible']
def redraw_html(name1, name2):
global history
return generate_chat_html(history['visible'], name1, name2, character)
return generate_chat_html(shared.history['visible'], name1, name2, shared.character)
def tokenize_dialogue(dialogue, name1, name2):
_history = []
@ -273,47 +268,45 @@ def tokenize_dialogue(dialogue, name1, name2):
def save_history(timestamp=True):
if timestamp:
fname = f"{character or ''}{'_' if character else ''}{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
fname = f"{shared.character or ''}{'_' if shared.character else ''}{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
else:
fname = f"{character or ''}{'_' if character else ''}persistent.json"
fname = f"{shared.character or ''}{'_' if shared.character else ''}persistent.json"
if not Path('logs').exists():
Path('logs').mkdir()
with open(Path(f'logs/{fname}'), 'w') as f:
f.write(json.dumps({'data': history['internal'], 'data_visible': history['visible']}, indent=2))
f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
return Path(f'logs/{fname}')
def load_history(file, name1, name2):
global history
file = file.decode('utf-8')
try:
j = json.loads(file)
if 'data' in j:
history['internal'] = j['data']
shared.history['internal'] = j['data']
if 'data_visible' in j:
history['visible'] = j['data_visible']
shared.history['visible'] = j['data_visible']
else:
history['visible'] = copy.deepcopy(history['internal'])
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
# Compatibility with Pygmalion AI's official web UI
elif 'chat' in j:
history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
shared.history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'):
history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', history['internal'][0]]] + [[history['internal'][i], history['internal'][i+1]] for i in range(1, len(history['internal'])-1, 2)]
history['visible'] = copy.deepcopy(history['internal'])
history['visible'][0][0] = ''
shared.history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', shared.history['internal'][0]]] + [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(1, len(shared.history['internal'])-1, 2)]
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
shared.history['visible'][0][0] = ''
else:
history['internal'] = [[history['internal'][i], history['internal'][i+1]] for i in range(0, len(history['internal'])-1, 2)]
history['visible'] = copy.deepcopy(history['internal'])
shared.history['internal'] = [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(0, len(shared.history['internal'])-1, 2)]
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
except:
history['internal'] = tokenize_dialogue(file, name1, name2)
history['visible'] = copy.deepcopy(history['internal'])
shared.history['internal'] = tokenize_dialogue(file, name1, name2)
shared.history['visible'] = copy.deepcopy(shared.history['internal'])
def load_character(_character, name1, name2):
global history, character
context = ""
history['internal'] = []
history['visible'] = []
shared.history['internal'] = []
shared.history['visible'] = []
if _character != 'None':
character = _character
shared.character = _character
data = json.loads(open(Path(f'characters/{_character}.json'), 'r').read())
name2 = data['char_name']
if 'char_persona' in data and data['char_persona'] != '':
@ -322,25 +315,25 @@ def load_character(_character, name1, name2):
context += f"Scenario: {data['world_scenario']}\n"
context = f"{context.strip()}\n<START>\n"
if 'example_dialogue' in data and data['example_dialogue'] != '':
history['internal'] = tokenize_dialogue(data['example_dialogue'], name1, name2)
shared.history['internal'] = tokenize_dialogue(data['example_dialogue'], name1, name2)
if 'char_greeting' in data and len(data['char_greeting'].strip()) > 0:
history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]]
history['visible'] += [['', apply_extensions(data['char_greeting'], "output")]]
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]]
shared.history['visible'] += [['', apply_extensions(data['char_greeting'], "output")]]
else:
history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]]
history['visible'] += [['', "Hello there!"]]
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]]
shared.history['visible'] += [['', "Hello there!"]]
else:
character = None
shared.character = None
context = shared.settings['context_pygmalion']
name2 = shared.settings['name2_pygmalion']
if Path(f'logs/{character}_persistent.json').exists():
load_history(open(Path(f'logs/{character}_persistent.json'), 'rb').read(), name1, name2)
if Path(f'logs/{shared.character}_persistent.json').exists():
load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
if shared.args.cai_chat:
return name2, context, generate_chat_html(history['visible'], name1, name2, character)
return name2, context, generate_chat_html(shared.history['visible'], name1, name2, shared.character)
else:
return name2, context, history['visible']
return name2, context, shared.history['visible']
def upload_character(json_file, img, tavern=False):
json_file = json_file if type(json_file) == str else json_file.decode('utf-8')

View File

@ -7,6 +7,10 @@ soft_prompt_tensor = None
soft_prompt = False
stop_everything = False
# Chat variables
history = {'internal': [], 'visible': []}
character = 'None'
settings = {
'max_new_tokens': 200,
'max_new_tokens_min': 1,

View File

@ -191,9 +191,9 @@ if shared.args.chat or shared.args.cai_chat:
with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False) as interface:
if shared.args.cai_chat:
display = gr.HTML(value=generate_chat_html(chat.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], chat.character))
display = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character))
else:
display = gr.Chatbot(value=chat.history['visible'])
display = gr.Chatbot(value=shared.history['visible'])
textbox = gr.Textbox(label='Input')
with gr.Row():
buttons["Stop"] = gr.Button("Stop")
@ -272,7 +272,7 @@ if shared.args.chat or shared.args.cai_chat:
buttons["Send last reply to input"].click(chat.send_last_reply_to_input, [], textbox, show_progress=shared.args.no_stream)
buttons["Replace last reply"].click(chat.replace_last_reply, [textbox, name1, name2], display, show_progress=shared.args.no_stream)
buttons["Clear history"].click(chat.clear_chat_log, [character_menu, name1, name2], display)
buttons["Clear history"].click(chat.clear_chat_log, [name1, name2], display)
buttons["Remove last"].click(chat.remove_last_message, [name1, name2], [display, textbox], show_progress=False)
buttons["Download"].click(chat.save_history, inputs=[], outputs=[download])
buttons["Upload character"].click(chat.upload_character, [upload_char, upload_img], [character_menu])
@ -295,8 +295,8 @@ if shared.args.chat or shared.args.cai_chat:
upload_chat_history.upload(chat.redraw_html, [name1, name2], [display])
upload_img_me.upload(chat.redraw_html, [name1, name2], [display])
else:
upload_chat_history.upload(lambda : chat.history['visible'], [], [display])
upload_img_me.upload(lambda : chat.history['visible'], [], [display])
upload_chat_history.upload(lambda : shared.history['visible'], [], [display])
upload_img_me.upload(lambda : shared.history['visible'], [], [display])
elif shared.args.notebook:
with gr.Blocks(css=ui.css, analytics_enabled=False) as interface: