Refactor the UI

A single dictionary called 'interface_state' is now passed as input to all functions. The values are updated only when necessary.

The goal is to make it easier to add new elements to the UI.
This commit is contained in:
oobabooga 2023-04-11 11:46:30 -03:00 committed by GitHub
parent 64f5c90ee7
commit 0f212093a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 136 additions and 100 deletions

View File

@ -74,16 +74,16 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
return prompt return prompt
def extract_message_from_reply(reply, name1, name2, stop_at_newline): def extract_message_from_reply(reply, state):
next_character_found = False next_character_found = False
if stop_at_newline: if state['stop_at_newline']:
lines = reply.split('\n') lines = reply.split('\n')
reply = lines[0].strip() reply = lines[0].strip()
if len(lines) > 1: if len(lines) > 1:
next_character_found = True next_character_found = True
else: else:
for string in [f"\n{name1}:", f"\n{name2}:"]: for string in [f"\n{state['name1']}:", f"\n{state['name2']}:"]:
idx = reply.find(string) idx = reply.find(string)
if idx != -1: if idx != -1:
reply = reply[:idx] reply = reply[:idx]
@ -92,7 +92,7 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
# If something like "\nYo" is generated just before "\nYou:" # If something like "\nYo" is generated just before "\nYou:"
# is completed, trim it # is completed, trim it
if not next_character_found: if not next_character_found:
for string in [f"\n{name1}:", f"\n{name2}:"]: for string in [f"\n{state['name1']}:", f"\n{state['name2']}:"]:
for j in range(len(string) - 1, 0, -1): for j in range(len(string) - 1, 0, -1):
if reply[-j:] == string[:j]: if reply[-j:] == string[:j]:
reply = reply[:-j] reply = reply[:-j]
@ -105,21 +105,18 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
return reply, next_character_found return reply, next_character_found
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False, _continue=False): def chatbot_wrapper(text, state, regenerate=False, _continue=False):
if mode == 'instruct': if state['mode'] == 'instruct':
stopping_strings = [f"\n{name1}", f"\n{name2}"] stopping_strings = [f"\n{state['name1']}", f"\n{state['name2']}"]
else: else:
stopping_strings = [f"\n{name1}:", f"\n{name2}:"] stopping_strings = [f"\n{state['name1']}:", f"\n{state['name2']}:"]
# Defining some variables # Defining some variables
cumulative_reply = '' cumulative_reply = ''
last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None
just_started = True just_started = True
name1_original = name1
visible_text = custom_generate_chat_prompt = None visible_text = custom_generate_chat_prompt = None
eos_token = '\n' if generate_state['stop_at_newline'] else None eos_token = '\n' if state['stop_at_newline'] else None
if 'pygmalion' in shared.model_name.lower():
name1 = "You"
# Check if any extension wants to hijack this function call # Check if any extension wants to hijack this function call
for extension, _ in extensions_module.iterator(): for extension, _ in extensions_module.iterator():
@ -136,28 +133,28 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
# Generating the prompt # Generating the prompt
kwargs = { kwargs = {
'end_of_turn': end_of_turn, 'end_of_turn': state['end_of_turn'],
'is_instruct': mode == 'instruct', 'is_instruct': state['mode'] == 'instruct',
'_continue': _continue '_continue': _continue
} }
if custom_generate_chat_prompt is None: if custom_generate_chat_prompt is None:
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs) prompt = generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], **kwargs)
else: else:
prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs) prompt = custom_generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], **kwargs)
# Yield *Is typing...* # Yield *Is typing...*
if not any((regenerate, _continue)): if not any((regenerate, _continue)):
yield shared.history['visible'] + [[visible_text, shared.processing_message]] yield shared.history['visible'] + [[visible_text, shared.processing_message]]
# Generate # Generate
for i in range(generate_state['chat_generation_attempts']): for i in range(state['chat_generation_attempts']):
reply = None reply = None
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings): for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", state, eos_token=eos_token, stopping_strings=stopping_strings):
reply = cumulative_reply + reply reply = cumulative_reply + reply
# Extracting the reply # Extracting the reply
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline']) reply, next_character_found = extract_message_from_reply(reply, state)
visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply) visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
visible_reply = apply_extensions(visible_reply, "output") visible_reply = apply_extensions(visible_reply, "output")
# We need this global variable to handle the Stop event, # We need this global variable to handle the Stop event,
@ -171,7 +168,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
shared.history['visible'].append(['', '']) shared.history['visible'].append(['', ''])
if _continue: if _continue:
sep = list(map(lambda x : ' ' if x[-1] != ' ' else '', last_reply)) sep = list(map(lambda x: ' ' if x[-1] != ' ' else '', last_reply))
shared.history['internal'][-1] = [text, f'{last_reply[0]}{sep[0]}{reply}'] shared.history['internal'][-1] = [text, f'{last_reply[0]}{sep[0]}{reply}']
shared.history['visible'][-1] = [visible_text, f'{last_reply[1]}{sep[1]}{visible_reply}'] shared.history['visible'][-1] = [visible_text, f'{last_reply[1]}{sep[1]}{visible_reply}']
else: else:
@ -188,28 +185,25 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
yield shared.history['visible'] yield shared.history['visible']
def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): def impersonate_wrapper(text, state):
if mode == 'instruct': if state['mode'] == 'instruct':
stopping_strings = [f"\n{name1}", f"\n{name2}"] stopping_strings = [f"\n{state['name1']}", f"\n{state['name2']}"]
else: else:
stopping_strings = [f"\n{name1}:", f"\n{name2}:"] stopping_strings = [f"\n{state['name1']}:", f"\n{state['name2']}:"]
# Defining some variables # Defining some variables
cumulative_reply = '' cumulative_reply = ''
eos_token = '\n' if generate_state['stop_at_newline'] else None eos_token = '\n' if state['stop_at_newline'] else None
if 'pygmalion' in shared.model_name.lower(): prompt = generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], end_of_turn=state['end_of_turn'], impersonate=True)
name1 = "You"
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], impersonate=True, end_of_turn=end_of_turn)
# Yield *Is typing...* # Yield *Is typing...*
yield shared.processing_message yield shared.processing_message
for i in range(generate_state['chat_generation_attempts']): for i in range(state['chat_generation_attempts']):
reply = None reply = None
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings): for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", state, eos_token=eos_token, stopping_strings=stopping_strings):
reply = cumulative_reply + reply reply = cumulative_reply + reply
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline']) reply, next_character_found = extract_message_from_reply(reply, state)
yield reply yield reply
if next_character_found: if next_character_found:
break break
@ -220,32 +214,32 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
yield reply yield reply
def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): def cai_chatbot_wrapper(text, state):
for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): for history in chatbot_wrapper(text, state):
yield chat_html_wrapper(history, name1, name2, mode) yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'])
def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): def regenerate_wrapper(text, state):
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0: if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode) yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
else: else:
last_visible = shared.history['visible'].pop() last_visible = shared.history['visible'].pop()
last_internal = shared.history['internal'].pop() last_internal = shared.history['internal'].pop()
# Yield '*Is typing...*' # Yield '*Is typing...*'
yield chat_html_wrapper(shared.history['visible'] + [[last_visible[0], shared.processing_message]], name1, name2, mode) yield chat_html_wrapper(shared.history['visible'] + [[last_visible[0], shared.processing_message]], state['name1'], state['name2'], state['mode'])
for history in chatbot_wrapper(last_internal[0], generate_state, name1, name2, context, mode, end_of_turn, regenerate=True): for history in chatbot_wrapper(last_internal[0], state, regenerate=True):
shared.history['visible'][-1] = [last_visible[0], history[-1][1]] shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode) yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
def continue_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): def continue_wrapper(text, state):
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0: if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode) yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
else: else:
# Yield ' ...' # Yield ' ...'
yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], name1, name2, mode) yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], state['name1'], state['name2'], state['mode'])
for history in chatbot_wrapper(shared.history['internal'][-1][0], generate_state, name1, name2, context, mode, end_of_turn, _continue=True): for history in chatbot_wrapper(shared.history['internal'][-1][0], state, _continue=True):
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode) yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
def remove_last_message(name1, name2, mode): def remove_last_message(name1, name2, mode):

View File

@ -69,6 +69,7 @@ def generate_softprompt_input_tensors(input_ids):
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens # filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
return inputs_embeds, filler_input_ids return inputs_embeds, filler_input_ids
# Removes empty replies from gpt4chan outputs # Removes empty replies from gpt4chan outputs
def fix_gpt4chan(s): def fix_gpt4chan(s):
for i in range(10): for i in range(10):
@ -77,6 +78,7 @@ def fix_gpt4chan(s):
s = re.sub("--- [0-9]*\n\n\n---", "---", s) s = re.sub("--- [0-9]*\n\n\n---", "---", s)
return s return s
# Fix the LaTeX equations in galactica # Fix the LaTeX equations in galactica
def fix_galactica(s): def fix_galactica(s):
s = s.replace(r'\[', r'$') s = s.replace(r'\[', r'$')
@ -117,9 +119,9 @@ def stop_everything_event():
shared.stop_everything = True shared.stop_everything = True
def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]): def generate_reply(question, state, eos_token=None, stopping_strings=[]):
clear_torch_cache() clear_torch_cache()
seed = set_manual_seed(generate_state['seed']) seed = set_manual_seed(state['seed'])
shared.stop_everything = False shared.stop_everything = False
generate_params = {} generate_params = {}
t0 = time.time() t0 = time.time()
@ -134,8 +136,8 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
# separately and terminate the function call earlier # separately and terminate the function call earlier
if any((shared.is_RWKV, shared.is_llamacpp)): if any((shared.is_RWKV, shared.is_llamacpp)):
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']: for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
generate_params[k] = generate_state[k] generate_params[k] = state[k]
generate_params['token_count'] = generate_state['max_new_tokens'] generate_params['token_count'] = state['max_new_tokens']
try: try:
if shared.args.no_stream: if shared.args.no_stream:
reply = shared.model.generate(context=question, **generate_params) reply = shared.model.generate(context=question, **generate_params)
@ -164,7 +166,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})') print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
return return
input_ids = encode(question, generate_state['max_new_tokens'], add_bos_token=generate_state['add_bos_token']) input_ids = encode(question, state['max_new_tokens'], add_bos_token=state['add_bos_token'])
original_input_ids = input_ids original_input_ids = input_ids
output = input_ids[0] output = input_ids[0]
@ -179,13 +181,13 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
if not shared.args.flexgen: if not shared.args.flexgen:
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']: for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
generate_params[k] = generate_state[k] generate_params[k] = state[k]
generate_params['eos_token_id'] = eos_token_ids generate_params['eos_token_id'] = eos_token_ids
generate_params['stopping_criteria'] = stopping_criteria_list generate_params['stopping_criteria'] = stopping_criteria_list
else: else:
for k in ['max_new_tokens', 'do_sample', 'temperature']: for k in ['max_new_tokens', 'do_sample', 'temperature']:
generate_params[k] = generate_state[k] generate_params[k] = state[k]
generate_params['stop'] = generate_state['eos_token_ids'][-1] generate_params['stop'] = state['eos_token_ids'][-1]
if not shared.args.no_stream: if not shared.args.no_stream:
generate_params['max_new_tokens'] = 8 generate_params['max_new_tokens'] = 8
@ -248,7 +250,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria' # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
else: else:
for i in range(generate_state['max_new_tokens'] // 8 + 1): for i in range(state['max_new_tokens'] // 8 + 1):
clear_torch_cache() clear_torch_cache()
with torch.no_grad(): with torch.no_grad():
output = shared.model.generate(**generate_params)[0] output = shared.model.generate(**generate_params)[0]

126
server.py
View File

@ -184,22 +184,22 @@ def download_model_wrapper(repo_id):
branch = "main" branch = "main"
check = False check = False
yield("Cleaning up the model/branch names") yield ("Cleaning up the model/branch names")
model, branch = downloader.sanitize_model_and_branch_names(model, branch) model, branch = downloader.sanitize_model_and_branch_names(model, branch)
yield("Getting the download links from Hugging Face") yield ("Getting the download links from Hugging Face")
links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False) links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
yield("Getting the output folder") yield ("Getting the output folder")
output_folder = downloader.get_output_folder(model, branch, is_lora) output_folder = downloader.get_output_folder(model, branch, is_lora)
if check: if check:
yield("Checking previously downloaded files") yield ("Checking previously downloaded files")
downloader.check_model_files(model, branch, links, sha256, output_folder) downloader.check_model_files(model, branch, links, sha256, output_folder)
else: else:
yield(f"Downloading files to {output_folder}") yield (f"Downloading files to {output_folder}")
downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1) downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1)
yield("Done!") yield ("Done!")
except: except:
yield traceback.format_exc() yield traceback.format_exc()
@ -357,6 +357,20 @@ else:
title = 'Text generation web UI' title = 'Text generation web UI'
def list_interface_input_elements(chat=False):
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token']
if chat:
elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode']
return elements
def gather_interface_values(*args):
output = {}
for i, element in enumerate(shared.input_elements):
output[element] = args[i]
return output
def create_interface(): def create_interface():
gen_events = [] gen_events = []
if shared.args.extensions is not None and len(shared.args.extensions) > 0: if shared.args.extensions is not None and len(shared.args.extensions) > 0:
@ -364,7 +378,11 @@ def create_interface():
with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']: with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
if shared.is_chat(): if shared.is_chat():
shared.input_elements = list_interface_input_elements(chat=True)
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
shared.gradio['Chat input'] = gr.State() shared.gradio['Chat input'] = gr.State()
with gr.Tab("Text generation", elem_id="main"): with gr.Tab("Text generation", elem_id="main"):
shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat')) shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat'))
shared.gradio['textbox'] = gr.Textbox(label='Input') shared.gradio['textbox'] = gr.Textbox(label='Input')
@ -384,7 +402,7 @@ def create_interface():
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False) shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False) shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
shared.gradio["Chat mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode") shared.gradio["mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode")
shared.gradio["Instruction templates"] = gr.Dropdown(choices=get_available_instruction_templates(), label="Instruction template", value="None", visible=False, info="Change this according to the model/LoRA that you are using.") shared.gradio["Instruction templates"] = gr.Dropdown(choices=get_available_instruction_templates(), label="Instruction template", value="None", visible=False, info="Change this according to the model/LoRA that you are using.")
with gr.Tab("Character", elem_id="chat-settings"): with gr.Tab("Character", elem_id="chat-settings"):
@ -439,75 +457,85 @@ def create_interface():
create_settings_menus(default_preset) create_settings_menus(default_preset)
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'generate_state', 'name1', 'name2', 'context', 'Chat mode', 'end_of_turn']] shared.input_params = [shared.gradio[k] for k in ['Chat input', 'interface_state']]
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']] clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']] reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'mode']]
gen_events.append(shared.gradio['Generate'].click( gen_events.append(shared.gradio['Generate'].click(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then( lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then( chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False) chat.save_history, shared.gradio['mode'], None, show_progress=False)
) )
gen_events.append(shared.gradio['textbox'].submit( gen_events.append(shared.gradio['textbox'].submit(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then( lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then( chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False) chat.save_history, shared.gradio['mode'], None, show_progress=False)
) )
gen_events.append(shared.gradio['Regenerate'].click( gen_events.append(shared.gradio['Regenerate'].click(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then( chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False) chat.save_history, shared.gradio['mode'], None, show_progress=False)
) )
gen_events.append(shared.gradio['Continue'].click( gen_events.append(shared.gradio['Continue'].click(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then( chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False) chat.save_history, shared.gradio['mode'], None, show_progress=False)
)
gen_events.append(shared.gradio['Impersonate'].click(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)
) )
shared.gradio['Replace last reply'].click( shared.gradio['Replace last reply'].click(
chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then( chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then( lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False) chat.save_history, shared.gradio['mode'], None, show_progress=False)
shared.gradio['Clear history-confirm'].click( shared.gradio['Clear history-confirm'].click(
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr).then( lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr).then(
chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display']).then( chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'mode']], shared.gradio['display']).then(
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False) chat.save_history, shared.gradio['mode'], None, show_progress=False)
shared.gradio['Stop'].click( shared.gradio['Stop'].click(
stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None).then( stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None).then(
chat.redraw_html, reload_inputs, shared.gradio['display']) chat.redraw_html, reload_inputs, shared.gradio['display'])
shared.gradio['Chat mode'].change( shared.gradio['mode'].change(
lambda x: gr.update(visible=x == 'instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates']).then( lambda x: gr.update(visible=x == 'instruct'), shared.gradio['mode'], shared.gradio['Instruction templates']).then(
lambda x: gr.update(interactive=x != 'instruct'), shared.gradio['Chat mode'], shared.gradio['character_menu']).then( lambda x: gr.update(interactive=x != 'instruct'), shared.gradio['mode'], shared.gradio['character_menu']).then(
chat.redraw_html, reload_inputs, shared.gradio['display']) chat.redraw_html, reload_inputs, shared.gradio['display'])
shared.gradio['Instruction templates'].change( shared.gradio['Instruction templates'].change(
lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then( lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then(
chat.redraw_html, reload_inputs, shared.gradio['display']) chat.redraw_html, reload_inputs, shared.gradio['display'])
shared.gradio['upload_chat_history'].upload( shared.gradio['upload_chat_history'].upload(
chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], None).then( chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], None).then(
chat.redraw_html, reload_inputs, shared.gradio['display']) chat.redraw_html, reload_inputs, shared.gradio['display'])
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, None, shared.gradio['textbox'], show_progress=shared.args.no_stream) shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, None, shared.gradio['textbox'], show_progress=shared.args.no_stream)
shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr) shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr) shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False) shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
shared.gradio['download_button'].click(lambda x: chat.save_history(x, timestamp=True), shared.gradio['Chat mode'], shared.gradio['download']) shared.gradio['download_button'].click(lambda x: chat.save_history(x, timestamp=True), shared.gradio['mode'], shared.gradio['download'])
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']]) shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]) shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']]) shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'Chat mode']], shared.gradio['display']) shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'mode']], shared.gradio['display'])
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}") shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
shared.gradio['interface'].load(chat.load_default_history, [shared.gradio[k] for k in ['name1', 'name2']], None) shared.gradio['interface'].load(chat.load_default_history, [shared.gradio[k] for k in ['name1', 'name2']], None)
shared.gradio['interface'].load(chat.redraw_html, reload_inputs, shared.gradio['display'], show_progress=True) shared.gradio['interface'].load(chat.redraw_html, reload_inputs, shared.gradio['display'], show_progress=True)
elif shared.args.notebook: elif shared.args.notebook:
shared.input_elements = list_interface_input_elements(chat=False)
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
with gr.Tab("Text generation", elem_id="main"): with gr.Tab("Text generation", elem_id="main"):
with gr.Row(): with gr.Row():
with gr.Column(scale=4): with gr.Column(scale=4):
@ -537,12 +565,23 @@ def create_interface():
shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']] shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']] output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Generate'].click(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)
)
gen_events.append(shared.gradio['textbox'].submit(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)
)
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None) shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
else: else:
shared.input_elements = list_interface_input_elements(chat=False)
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
with gr.Tab("Text generation", elem_id="main"): with gr.Tab("Text generation", elem_id="main"):
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
@ -570,9 +609,22 @@ def create_interface():
shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']] shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']] output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Generate'].click(
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)) gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)
)
gen_events.append(shared.gradio['textbox'].submit(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)
)
gen_events.append(shared.gradio['Continue'].click(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)
)
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None) shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
@ -607,18 +659,6 @@ def create_interface():
if shared.args.extensions is not None: if shared.args.extensions is not None:
extensions_module.create_extensions_block() extensions_module.create_extensions_block()
def change_dict_value(d, key, value):
d[key] = value
return d
for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'add_bos_token', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
if k not in shared.gradio:
continue
if type(shared.gradio[k]) in [gr.Checkbox, gr.Number]:
shared.gradio[k].change(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
else:
shared.gradio[k].release(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
if not shared.is_chat(): if not shared.is_chat():
api.create_apis() api.create_apis()