Refactor the UI

A single dictionary called 'interface_state' is now passed as input to all functions. The values are updated only when necessary.

The goal is to make it easier to add new elements to the UI.
This commit is contained in:
oobabooga 2023-04-11 11:46:30 -03:00 committed by GitHub
parent 64f5c90ee7
commit 0f212093a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 136 additions and 100 deletions

View File

@ -74,16 +74,16 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
return prompt
def extract_message_from_reply(reply, name1, name2, stop_at_newline):
def extract_message_from_reply(reply, state):
next_character_found = False
if stop_at_newline:
if state['stop_at_newline']:
lines = reply.split('\n')
reply = lines[0].strip()
if len(lines) > 1:
next_character_found = True
else:
for string in [f"\n{name1}:", f"\n{name2}:"]:
for string in [f"\n{state['name1']}:", f"\n{state['name2']}:"]:
idx = reply.find(string)
if idx != -1:
reply = reply[:idx]
@ -92,7 +92,7 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
# If something like "\nYo" is generated just before "\nYou:"
# is completed, trim it
if not next_character_found:
for string in [f"\n{name1}:", f"\n{name2}:"]:
for string in [f"\n{state['name1']}:", f"\n{state['name2']}:"]:
for j in range(len(string) - 1, 0, -1):
if reply[-j:] == string[:j]:
reply = reply[:-j]
@ -105,21 +105,18 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
return reply, next_character_found
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False, _continue=False):
if mode == 'instruct':
stopping_strings = [f"\n{name1}", f"\n{name2}"]
def chatbot_wrapper(text, state, regenerate=False, _continue=False):
if state['mode'] == 'instruct':
stopping_strings = [f"\n{state['name1']}", f"\n{state['name2']}"]
else:
stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
stopping_strings = [f"\n{state['name1']}:", f"\n{state['name2']}:"]
# Defining some variables
cumulative_reply = ''
last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None
just_started = True
name1_original = name1
visible_text = custom_generate_chat_prompt = None
eos_token = '\n' if generate_state['stop_at_newline'] else None
if 'pygmalion' in shared.model_name.lower():
name1 = "You"
eos_token = '\n' if state['stop_at_newline'] else None
# Check if any extension wants to hijack this function call
for extension, _ in extensions_module.iterator():
@ -136,28 +133,28 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
# Generating the prompt
kwargs = {
'end_of_turn': end_of_turn,
'is_instruct': mode == 'instruct',
'end_of_turn': state['end_of_turn'],
'is_instruct': state['mode'] == 'instruct',
'_continue': _continue
}
if custom_generate_chat_prompt is None:
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
prompt = generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], **kwargs)
else:
prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
prompt = custom_generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], **kwargs)
# Yield *Is typing...*
if not any((regenerate, _continue)):
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
# Generate
for i in range(generate_state['chat_generation_attempts']):
for i in range(state['chat_generation_attempts']):
reply = None
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", state, eos_token=eos_token, stopping_strings=stopping_strings):
reply = cumulative_reply + reply
# Extracting the reply
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply)
reply, next_character_found = extract_message_from_reply(reply, state)
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
visible_reply = apply_extensions(visible_reply, "output")
# We need this global variable to handle the Stop event,
@ -188,28 +185,25 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
yield shared.history['visible']
def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
if mode == 'instruct':
stopping_strings = [f"\n{name1}", f"\n{name2}"]
def impersonate_wrapper(text, state):
if state['mode'] == 'instruct':
stopping_strings = [f"\n{state['name1']}", f"\n{state['name2']}"]
else:
stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
stopping_strings = [f"\n{state['name1']}:", f"\n{state['name2']}:"]
# Defining some variables
cumulative_reply = ''
eos_token = '\n' if generate_state['stop_at_newline'] else None
if 'pygmalion' in shared.model_name.lower():
name1 = "You"
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], impersonate=True, end_of_turn=end_of_turn)
eos_token = '\n' if state['stop_at_newline'] else None
prompt = generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], end_of_turn=state['end_of_turn'], impersonate=True)
# Yield *Is typing...*
yield shared.processing_message
for i in range(generate_state['chat_generation_attempts']):
for i in range(state['chat_generation_attempts']):
reply = None
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", state, eos_token=eos_token, stopping_strings=stopping_strings):
reply = cumulative_reply + reply
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
reply, next_character_found = extract_message_from_reply(reply, state)
yield reply
if next_character_found:
break
@ -220,32 +214,32 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
yield reply
def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
yield chat_html_wrapper(history, name1, name2, mode)
def cai_chatbot_wrapper(text, state):
for history in chatbot_wrapper(text, state):
yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'])
def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
def regenerate_wrapper(text, state):
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
else:
last_visible = shared.history['visible'].pop()
last_internal = shared.history['internal'].pop()
# Yield '*Is typing...*'
yield chat_html_wrapper(shared.history['visible'] + [[last_visible[0], shared.processing_message]], name1, name2, mode)
for history in chatbot_wrapper(last_internal[0], generate_state, name1, name2, context, mode, end_of_turn, regenerate=True):
yield chat_html_wrapper(shared.history['visible'] + [[last_visible[0], shared.processing_message]], state['name1'], state['name2'], state['mode'])
for history in chatbot_wrapper(last_internal[0], state, regenerate=True):
shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
def continue_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
def continue_wrapper(text, state):
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
else:
# Yield ' ...'
yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], name1, name2, mode)
for history in chatbot_wrapper(shared.history['internal'][-1][0], generate_state, name1, name2, context, mode, end_of_turn, _continue=True):
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], state['name1'], state['name2'], state['mode'])
for history in chatbot_wrapper(shared.history['internal'][-1][0], state, _continue=True):
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
def remove_last_message(name1, name2, mode):

View File

@ -69,6 +69,7 @@ def generate_softprompt_input_tensors(input_ids):
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
return inputs_embeds, filler_input_ids
# Removes empty replies from gpt4chan outputs
def fix_gpt4chan(s):
for i in range(10):
@ -77,6 +78,7 @@ def fix_gpt4chan(s):
s = re.sub("--- [0-9]*\n\n\n---", "---", s)
return s
# Fix the LaTeX equations in galactica
def fix_galactica(s):
s = s.replace(r'\[', r'$')
@ -117,9 +119,9 @@ def stop_everything_event():
shared.stop_everything = True
def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]):
def generate_reply(question, state, eos_token=None, stopping_strings=[]):
clear_torch_cache()
seed = set_manual_seed(generate_state['seed'])
seed = set_manual_seed(state['seed'])
shared.stop_everything = False
generate_params = {}
t0 = time.time()
@ -134,8 +136,8 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
# separately and terminate the function call earlier
if any((shared.is_RWKV, shared.is_llamacpp)):
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
generate_params[k] = generate_state[k]
generate_params['token_count'] = generate_state['max_new_tokens']
generate_params[k] = state[k]
generate_params['token_count'] = state['max_new_tokens']
try:
if shared.args.no_stream:
reply = shared.model.generate(context=question, **generate_params)
@ -164,7 +166,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
return
input_ids = encode(question, generate_state['max_new_tokens'], add_bos_token=generate_state['add_bos_token'])
input_ids = encode(question, state['max_new_tokens'], add_bos_token=state['add_bos_token'])
original_input_ids = input_ids
output = input_ids[0]
@ -179,13 +181,13 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
if not shared.args.flexgen:
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
generate_params[k] = generate_state[k]
generate_params[k] = state[k]
generate_params['eos_token_id'] = eos_token_ids
generate_params['stopping_criteria'] = stopping_criteria_list
else:
for k in ['max_new_tokens', 'do_sample', 'temperature']:
generate_params[k] = generate_state[k]
generate_params['stop'] = generate_state['eos_token_ids'][-1]
generate_params[k] = state[k]
generate_params['stop'] = state['eos_token_ids'][-1]
if not shared.args.no_stream:
generate_params['max_new_tokens'] = 8
@ -248,7 +250,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
else:
for i in range(generate_state['max_new_tokens'] // 8 + 1):
for i in range(state['max_new_tokens'] // 8 + 1):
clear_torch_cache()
with torch.no_grad():
output = shared.model.generate(**generate_params)[0]

114
server.py
View File

@ -357,6 +357,20 @@ else:
title = 'Text generation web UI'
def list_interface_input_elements(chat=False):
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token']
if chat:
elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode']
return elements
def gather_interface_values(*args):
output = {}
for i, element in enumerate(shared.input_elements):
output[element] = args[i]
return output
def create_interface():
gen_events = []
if shared.args.extensions is not None and len(shared.args.extensions) > 0:
@ -364,7 +378,11 @@ def create_interface():
with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
if shared.is_chat():
shared.input_elements = list_interface_input_elements(chat=True)
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
shared.gradio['Chat input'] = gr.State()
with gr.Tab("Text generation", elem_id="main"):
shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat'))
shared.gradio['textbox'] = gr.Textbox(label='Input')
@ -384,7 +402,7 @@ def create_interface():
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
shared.gradio["Chat mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode")
shared.gradio["mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode")
shared.gradio["Instruction templates"] = gr.Dropdown(choices=get_available_instruction_templates(), label="Instruction template", value="None", visible=False, info="Change this according to the model/LoRA that you are using.")
with gr.Tab("Character", elem_id="chat-settings"):
@ -439,75 +457,85 @@ def create_interface():
create_settings_menus(default_preset)
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'generate_state', 'name1', 'name2', 'context', 'Chat mode', 'end_of_turn']]
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'interface_state']]
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']]
reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'mode']]
gen_events.append(shared.gradio['Generate'].click(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False)
chat.save_history, shared.gradio['mode'], None, show_progress=False)
)
gen_events.append(shared.gradio['textbox'].submit(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False)
chat.save_history, shared.gradio['mode'], None, show_progress=False)
)
gen_events.append(shared.gradio['Regenerate'].click(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False)
chat.save_history, shared.gradio['mode'], None, show_progress=False)
)
gen_events.append(shared.gradio['Continue'].click(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False)
chat.save_history, shared.gradio['mode'], None, show_progress=False)
)
gen_events.append(shared.gradio['Impersonate'].click(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)
)
shared.gradio['Replace last reply'].click(
chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False)
chat.save_history, shared.gradio['mode'], None, show_progress=False)
shared.gradio['Clear history-confirm'].click(
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr).then(
chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display']).then(
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False)
chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'mode']], shared.gradio['display']).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False)
shared.gradio['Stop'].click(
stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None).then(
chat.redraw_html, reload_inputs, shared.gradio['display'])
shared.gradio['Chat mode'].change(
lambda x: gr.update(visible=x == 'instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates']).then(
lambda x: gr.update(interactive=x != 'instruct'), shared.gradio['Chat mode'], shared.gradio['character_menu']).then(
shared.gradio['mode'].change(
lambda x: gr.update(visible=x == 'instruct'), shared.gradio['mode'], shared.gradio['Instruction templates']).then(
lambda x: gr.update(interactive=x != 'instruct'), shared.gradio['mode'], shared.gradio['character_menu']).then(
chat.redraw_html, reload_inputs, shared.gradio['display'])
shared.gradio['Instruction templates'].change(
lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then(
lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then(
chat.redraw_html, reload_inputs, shared.gradio['display'])
shared.gradio['upload_chat_history'].upload(
chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], None).then(
chat.redraw_html, reload_inputs, shared.gradio['display'])
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, None, shared.gradio['textbox'], show_progress=shared.args.no_stream)
shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
shared.gradio['download_button'].click(lambda x: chat.save_history(x, timestamp=True), shared.gradio['Chat mode'], shared.gradio['download'])
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
shared.gradio['download_button'].click(lambda x: chat.save_history(x, timestamp=True), shared.gradio['mode'], shared.gradio['download'])
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'Chat mode']], shared.gradio['display'])
shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'mode']], shared.gradio['display'])
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
shared.gradio['interface'].load(chat.load_default_history, [shared.gradio[k] for k in ['name1', 'name2']], None)
shared.gradio['interface'].load(chat.redraw_html, reload_inputs, shared.gradio['display'], show_progress=True)
elif shared.args.notebook:
shared.input_elements = list_interface_input_elements(chat=False)
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
with gr.Tab("Text generation", elem_id="main"):
with gr.Row():
with gr.Column(scale=4):
@ -537,12 +565,23 @@ def create_interface():
shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Generate'].click(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)
)
gen_events.append(shared.gradio['textbox'].submit(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)
)
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
else:
shared.input_elements = list_interface_input_elements(chat=False)
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
with gr.Tab("Text generation", elem_id="main"):
with gr.Row():
with gr.Column():
@ -570,9 +609,22 @@ def create_interface():
shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Generate'].click(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)
)
gen_events.append(shared.gradio['textbox'].submit(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)
)
gen_events.append(shared.gradio['Continue'].click(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)
)
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
@ -607,18 +659,6 @@ def create_interface():
if shared.args.extensions is not None:
extensions_module.create_extensions_block()
def change_dict_value(d, key, value):
d[key] = value
return d
for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'add_bos_token', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
if k not in shared.gradio:
continue
if type(shared.gradio[k]) in [gr.Checkbox, gr.Number]:
shared.gradio[k].change(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
else:
shared.gradio[k].release(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
if not shared.is_chat():
api.create_apis()