mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
Refactor the UI
A single dictionary called 'interface_state' is now passed as input to all functions. The values are updated only when necessary. The goal is to make it easier to add new elements to the UI.
This commit is contained in:
parent
64f5c90ee7
commit
0f212093a3
@ -74,16 +74,16 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
|
|||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
def extract_message_from_reply(reply, state):
|
||||||
next_character_found = False
|
next_character_found = False
|
||||||
|
|
||||||
if stop_at_newline:
|
if state['stop_at_newline']:
|
||||||
lines = reply.split('\n')
|
lines = reply.split('\n')
|
||||||
reply = lines[0].strip()
|
reply = lines[0].strip()
|
||||||
if len(lines) > 1:
|
if len(lines) > 1:
|
||||||
next_character_found = True
|
next_character_found = True
|
||||||
else:
|
else:
|
||||||
for string in [f"\n{name1}:", f"\n{name2}:"]:
|
for string in [f"\n{state['name1']}:", f"\n{state['name2']}:"]:
|
||||||
idx = reply.find(string)
|
idx = reply.find(string)
|
||||||
if idx != -1:
|
if idx != -1:
|
||||||
reply = reply[:idx]
|
reply = reply[:idx]
|
||||||
@ -92,7 +92,7 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
|||||||
# If something like "\nYo" is generated just before "\nYou:"
|
# If something like "\nYo" is generated just before "\nYou:"
|
||||||
# is completed, trim it
|
# is completed, trim it
|
||||||
if not next_character_found:
|
if not next_character_found:
|
||||||
for string in [f"\n{name1}:", f"\n{name2}:"]:
|
for string in [f"\n{state['name1']}:", f"\n{state['name2']}:"]:
|
||||||
for j in range(len(string) - 1, 0, -1):
|
for j in range(len(string) - 1, 0, -1):
|
||||||
if reply[-j:] == string[:j]:
|
if reply[-j:] == string[:j]:
|
||||||
reply = reply[:-j]
|
reply = reply[:-j]
|
||||||
@ -105,21 +105,18 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
|||||||
return reply, next_character_found
|
return reply, next_character_found
|
||||||
|
|
||||||
|
|
||||||
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False, _continue=False):
|
def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
||||||
if mode == 'instruct':
|
if state['mode'] == 'instruct':
|
||||||
stopping_strings = [f"\n{name1}", f"\n{name2}"]
|
stopping_strings = [f"\n{state['name1']}", f"\n{state['name2']}"]
|
||||||
else:
|
else:
|
||||||
stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
|
stopping_strings = [f"\n{state['name1']}:", f"\n{state['name2']}:"]
|
||||||
|
|
||||||
# Defining some variables
|
# Defining some variables
|
||||||
cumulative_reply = ''
|
cumulative_reply = ''
|
||||||
last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None
|
last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None
|
||||||
just_started = True
|
just_started = True
|
||||||
name1_original = name1
|
|
||||||
visible_text = custom_generate_chat_prompt = None
|
visible_text = custom_generate_chat_prompt = None
|
||||||
eos_token = '\n' if generate_state['stop_at_newline'] else None
|
eos_token = '\n' if state['stop_at_newline'] else None
|
||||||
if 'pygmalion' in shared.model_name.lower():
|
|
||||||
name1 = "You"
|
|
||||||
|
|
||||||
# Check if any extension wants to hijack this function call
|
# Check if any extension wants to hijack this function call
|
||||||
for extension, _ in extensions_module.iterator():
|
for extension, _ in extensions_module.iterator():
|
||||||
@ -136,28 +133,28 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
|||||||
|
|
||||||
# Generating the prompt
|
# Generating the prompt
|
||||||
kwargs = {
|
kwargs = {
|
||||||
'end_of_turn': end_of_turn,
|
'end_of_turn': state['end_of_turn'],
|
||||||
'is_instruct': mode == 'instruct',
|
'is_instruct': state['mode'] == 'instruct',
|
||||||
'_continue': _continue
|
'_continue': _continue
|
||||||
}
|
}
|
||||||
if custom_generate_chat_prompt is None:
|
if custom_generate_chat_prompt is None:
|
||||||
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
prompt = generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], **kwargs)
|
||||||
else:
|
else:
|
||||||
prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
prompt = custom_generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], **kwargs)
|
||||||
|
|
||||||
# Yield *Is typing...*
|
# Yield *Is typing...*
|
||||||
if not any((regenerate, _continue)):
|
if not any((regenerate, _continue)):
|
||||||
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
|
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
|
||||||
|
|
||||||
# Generate
|
# Generate
|
||||||
for i in range(generate_state['chat_generation_attempts']):
|
for i in range(state['chat_generation_attempts']):
|
||||||
reply = None
|
reply = None
|
||||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
|
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", state, eos_token=eos_token, stopping_strings=stopping_strings):
|
||||||
reply = cumulative_reply + reply
|
reply = cumulative_reply + reply
|
||||||
|
|
||||||
# Extracting the reply
|
# Extracting the reply
|
||||||
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
|
reply, next_character_found = extract_message_from_reply(reply, state)
|
||||||
visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply)
|
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
|
||||||
visible_reply = apply_extensions(visible_reply, "output")
|
visible_reply = apply_extensions(visible_reply, "output")
|
||||||
|
|
||||||
# We need this global variable to handle the Stop event,
|
# We need this global variable to handle the Stop event,
|
||||||
@ -171,7 +168,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
|||||||
shared.history['visible'].append(['', ''])
|
shared.history['visible'].append(['', ''])
|
||||||
|
|
||||||
if _continue:
|
if _continue:
|
||||||
sep = list(map(lambda x : ' ' if x[-1] != ' ' else '', last_reply))
|
sep = list(map(lambda x: ' ' if x[-1] != ' ' else '', last_reply))
|
||||||
shared.history['internal'][-1] = [text, f'{last_reply[0]}{sep[0]}{reply}']
|
shared.history['internal'][-1] = [text, f'{last_reply[0]}{sep[0]}{reply}']
|
||||||
shared.history['visible'][-1] = [visible_text, f'{last_reply[1]}{sep[1]}{visible_reply}']
|
shared.history['visible'][-1] = [visible_text, f'{last_reply[1]}{sep[1]}{visible_reply}']
|
||||||
else:
|
else:
|
||||||
@ -188,28 +185,25 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
|||||||
yield shared.history['visible']
|
yield shared.history['visible']
|
||||||
|
|
||||||
|
|
||||||
def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
def impersonate_wrapper(text, state):
|
||||||
if mode == 'instruct':
|
if state['mode'] == 'instruct':
|
||||||
stopping_strings = [f"\n{name1}", f"\n{name2}"]
|
stopping_strings = [f"\n{state['name1']}", f"\n{state['name2']}"]
|
||||||
else:
|
else:
|
||||||
stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
|
stopping_strings = [f"\n{state['name1']}:", f"\n{state['name2']}:"]
|
||||||
|
|
||||||
# Defining some variables
|
# Defining some variables
|
||||||
cumulative_reply = ''
|
cumulative_reply = ''
|
||||||
eos_token = '\n' if generate_state['stop_at_newline'] else None
|
eos_token = '\n' if state['stop_at_newline'] else None
|
||||||
if 'pygmalion' in shared.model_name.lower():
|
prompt = generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], end_of_turn=state['end_of_turn'], impersonate=True)
|
||||||
name1 = "You"
|
|
||||||
|
|
||||||
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], impersonate=True, end_of_turn=end_of_turn)
|
|
||||||
|
|
||||||
# Yield *Is typing...*
|
# Yield *Is typing...*
|
||||||
yield shared.processing_message
|
yield shared.processing_message
|
||||||
|
|
||||||
for i in range(generate_state['chat_generation_attempts']):
|
for i in range(state['chat_generation_attempts']):
|
||||||
reply = None
|
reply = None
|
||||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
|
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", state, eos_token=eos_token, stopping_strings=stopping_strings):
|
||||||
reply = cumulative_reply + reply
|
reply = cumulative_reply + reply
|
||||||
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
|
reply, next_character_found = extract_message_from_reply(reply, state)
|
||||||
yield reply
|
yield reply
|
||||||
if next_character_found:
|
if next_character_found:
|
||||||
break
|
break
|
||||||
@ -220,32 +214,32 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
|
|||||||
yield reply
|
yield reply
|
||||||
|
|
||||||
|
|
||||||
def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
def cai_chatbot_wrapper(text, state):
|
||||||
for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
for history in chatbot_wrapper(text, state):
|
||||||
yield chat_html_wrapper(history, name1, name2, mode)
|
yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'])
|
||||||
|
|
||||||
|
|
||||||
def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
def regenerate_wrapper(text, state):
|
||||||
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
|
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
|
||||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
|
||||||
else:
|
else:
|
||||||
last_visible = shared.history['visible'].pop()
|
last_visible = shared.history['visible'].pop()
|
||||||
last_internal = shared.history['internal'].pop()
|
last_internal = shared.history['internal'].pop()
|
||||||
# Yield '*Is typing...*'
|
# Yield '*Is typing...*'
|
||||||
yield chat_html_wrapper(shared.history['visible'] + [[last_visible[0], shared.processing_message]], name1, name2, mode)
|
yield chat_html_wrapper(shared.history['visible'] + [[last_visible[0], shared.processing_message]], state['name1'], state['name2'], state['mode'])
|
||||||
for history in chatbot_wrapper(last_internal[0], generate_state, name1, name2, context, mode, end_of_turn, regenerate=True):
|
for history in chatbot_wrapper(last_internal[0], state, regenerate=True):
|
||||||
shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
|
shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
|
||||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
|
||||||
|
|
||||||
|
|
||||||
def continue_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
def continue_wrapper(text, state):
|
||||||
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
|
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
|
||||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
|
||||||
else:
|
else:
|
||||||
# Yield ' ...'
|
# Yield ' ...'
|
||||||
yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], name1, name2, mode)
|
yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], state['name1'], state['name2'], state['mode'])
|
||||||
for history in chatbot_wrapper(shared.history['internal'][-1][0], generate_state, name1, name2, context, mode, end_of_turn, _continue=True):
|
for history in chatbot_wrapper(shared.history['internal'][-1][0], state, _continue=True):
|
||||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
|
||||||
|
|
||||||
|
|
||||||
def remove_last_message(name1, name2, mode):
|
def remove_last_message(name1, name2, mode):
|
||||||
|
@ -69,6 +69,7 @@ def generate_softprompt_input_tensors(input_ids):
|
|||||||
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
|
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
|
||||||
return inputs_embeds, filler_input_ids
|
return inputs_embeds, filler_input_ids
|
||||||
|
|
||||||
|
|
||||||
# Removes empty replies from gpt4chan outputs
|
# Removes empty replies from gpt4chan outputs
|
||||||
def fix_gpt4chan(s):
|
def fix_gpt4chan(s):
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
@ -77,6 +78,7 @@ def fix_gpt4chan(s):
|
|||||||
s = re.sub("--- [0-9]*\n\n\n---", "---", s)
|
s = re.sub("--- [0-9]*\n\n\n---", "---", s)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
# Fix the LaTeX equations in galactica
|
# Fix the LaTeX equations in galactica
|
||||||
def fix_galactica(s):
|
def fix_galactica(s):
|
||||||
s = s.replace(r'\[', r'$')
|
s = s.replace(r'\[', r'$')
|
||||||
@ -117,9 +119,9 @@ def stop_everything_event():
|
|||||||
shared.stop_everything = True
|
shared.stop_everything = True
|
||||||
|
|
||||||
|
|
||||||
def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]):
|
def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||||
clear_torch_cache()
|
clear_torch_cache()
|
||||||
seed = set_manual_seed(generate_state['seed'])
|
seed = set_manual_seed(state['seed'])
|
||||||
shared.stop_everything = False
|
shared.stop_everything = False
|
||||||
generate_params = {}
|
generate_params = {}
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
@ -134,8 +136,8 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||||||
# separately and terminate the function call earlier
|
# separately and terminate the function call earlier
|
||||||
if any((shared.is_RWKV, shared.is_llamacpp)):
|
if any((shared.is_RWKV, shared.is_llamacpp)):
|
||||||
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
|
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
|
||||||
generate_params[k] = generate_state[k]
|
generate_params[k] = state[k]
|
||||||
generate_params['token_count'] = generate_state['max_new_tokens']
|
generate_params['token_count'] = state['max_new_tokens']
|
||||||
try:
|
try:
|
||||||
if shared.args.no_stream:
|
if shared.args.no_stream:
|
||||||
reply = shared.model.generate(context=question, **generate_params)
|
reply = shared.model.generate(context=question, **generate_params)
|
||||||
@ -164,7 +166,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||||||
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
|
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
|
||||||
return
|
return
|
||||||
|
|
||||||
input_ids = encode(question, generate_state['max_new_tokens'], add_bos_token=generate_state['add_bos_token'])
|
input_ids = encode(question, state['max_new_tokens'], add_bos_token=state['add_bos_token'])
|
||||||
original_input_ids = input_ids
|
original_input_ids = input_ids
|
||||||
output = input_ids[0]
|
output = input_ids[0]
|
||||||
|
|
||||||
@ -179,13 +181,13 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||||||
|
|
||||||
if not shared.args.flexgen:
|
if not shared.args.flexgen:
|
||||||
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
|
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
|
||||||
generate_params[k] = generate_state[k]
|
generate_params[k] = state[k]
|
||||||
generate_params['eos_token_id'] = eos_token_ids
|
generate_params['eos_token_id'] = eos_token_ids
|
||||||
generate_params['stopping_criteria'] = stopping_criteria_list
|
generate_params['stopping_criteria'] = stopping_criteria_list
|
||||||
else:
|
else:
|
||||||
for k in ['max_new_tokens', 'do_sample', 'temperature']:
|
for k in ['max_new_tokens', 'do_sample', 'temperature']:
|
||||||
generate_params[k] = generate_state[k]
|
generate_params[k] = state[k]
|
||||||
generate_params['stop'] = generate_state['eos_token_ids'][-1]
|
generate_params['stop'] = state['eos_token_ids'][-1]
|
||||||
if not shared.args.no_stream:
|
if not shared.args.no_stream:
|
||||||
generate_params['max_new_tokens'] = 8
|
generate_params['max_new_tokens'] = 8
|
||||||
|
|
||||||
@ -248,7 +250,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||||||
|
|
||||||
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
||||||
else:
|
else:
|
||||||
for i in range(generate_state['max_new_tokens'] // 8 + 1):
|
for i in range(state['max_new_tokens'] // 8 + 1):
|
||||||
clear_torch_cache()
|
clear_torch_cache()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = shared.model.generate(**generate_params)[0]
|
output = shared.model.generate(**generate_params)[0]
|
||||||
|
126
server.py
126
server.py
@ -184,22 +184,22 @@ def download_model_wrapper(repo_id):
|
|||||||
branch = "main"
|
branch = "main"
|
||||||
check = False
|
check = False
|
||||||
|
|
||||||
yield("Cleaning up the model/branch names")
|
yield ("Cleaning up the model/branch names")
|
||||||
model, branch = downloader.sanitize_model_and_branch_names(model, branch)
|
model, branch = downloader.sanitize_model_and_branch_names(model, branch)
|
||||||
|
|
||||||
yield("Getting the download links from Hugging Face")
|
yield ("Getting the download links from Hugging Face")
|
||||||
links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
|
links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
|
||||||
|
|
||||||
yield("Getting the output folder")
|
yield ("Getting the output folder")
|
||||||
output_folder = downloader.get_output_folder(model, branch, is_lora)
|
output_folder = downloader.get_output_folder(model, branch, is_lora)
|
||||||
|
|
||||||
if check:
|
if check:
|
||||||
yield("Checking previously downloaded files")
|
yield ("Checking previously downloaded files")
|
||||||
downloader.check_model_files(model, branch, links, sha256, output_folder)
|
downloader.check_model_files(model, branch, links, sha256, output_folder)
|
||||||
else:
|
else:
|
||||||
yield(f"Downloading files to {output_folder}")
|
yield (f"Downloading files to {output_folder}")
|
||||||
downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1)
|
downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1)
|
||||||
yield("Done!")
|
yield ("Done!")
|
||||||
except:
|
except:
|
||||||
yield traceback.format_exc()
|
yield traceback.format_exc()
|
||||||
|
|
||||||
@ -357,6 +357,20 @@ else:
|
|||||||
title = 'Text generation web UI'
|
title = 'Text generation web UI'
|
||||||
|
|
||||||
|
|
||||||
|
def list_interface_input_elements(chat=False):
|
||||||
|
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token']
|
||||||
|
if chat:
|
||||||
|
elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode']
|
||||||
|
return elements
|
||||||
|
|
||||||
|
|
||||||
|
def gather_interface_values(*args):
|
||||||
|
output = {}
|
||||||
|
for i, element in enumerate(shared.input_elements):
|
||||||
|
output[element] = args[i]
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
def create_interface():
|
def create_interface():
|
||||||
gen_events = []
|
gen_events = []
|
||||||
if shared.args.extensions is not None and len(shared.args.extensions) > 0:
|
if shared.args.extensions is not None and len(shared.args.extensions) > 0:
|
||||||
@ -364,7 +378,11 @@ def create_interface():
|
|||||||
|
|
||||||
with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
|
with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
|
||||||
if shared.is_chat():
|
if shared.is_chat():
|
||||||
|
|
||||||
|
shared.input_elements = list_interface_input_elements(chat=True)
|
||||||
|
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
|
||||||
shared.gradio['Chat input'] = gr.State()
|
shared.gradio['Chat input'] = gr.State()
|
||||||
|
|
||||||
with gr.Tab("Text generation", elem_id="main"):
|
with gr.Tab("Text generation", elem_id="main"):
|
||||||
shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat'))
|
shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat'))
|
||||||
shared.gradio['textbox'] = gr.Textbox(label='Input')
|
shared.gradio['textbox'] = gr.Textbox(label='Input')
|
||||||
@ -384,7 +402,7 @@ def create_interface():
|
|||||||
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
|
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
|
||||||
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
|
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
|
||||||
|
|
||||||
shared.gradio["Chat mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode")
|
shared.gradio["mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode")
|
||||||
shared.gradio["Instruction templates"] = gr.Dropdown(choices=get_available_instruction_templates(), label="Instruction template", value="None", visible=False, info="Change this according to the model/LoRA that you are using.")
|
shared.gradio["Instruction templates"] = gr.Dropdown(choices=get_available_instruction_templates(), label="Instruction template", value="None", visible=False, info="Change this according to the model/LoRA that you are using.")
|
||||||
|
|
||||||
with gr.Tab("Character", elem_id="chat-settings"):
|
with gr.Tab("Character", elem_id="chat-settings"):
|
||||||
@ -439,75 +457,85 @@ def create_interface():
|
|||||||
|
|
||||||
create_settings_menus(default_preset)
|
create_settings_menus(default_preset)
|
||||||
|
|
||||||
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'generate_state', 'name1', 'name2', 'context', 'Chat mode', 'end_of_turn']]
|
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'interface_state']]
|
||||||
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
|
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
|
||||||
reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']]
|
reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'mode']]
|
||||||
|
|
||||||
gen_events.append(shared.gradio['Generate'].click(
|
gen_events.append(shared.gradio['Generate'].click(
|
||||||
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
|
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
|
||||||
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||||
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False)
|
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
gen_events.append(shared.gradio['textbox'].submit(
|
gen_events.append(shared.gradio['textbox'].submit(
|
||||||
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
|
lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
|
||||||
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||||
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False)
|
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
gen_events.append(shared.gradio['Regenerate'].click(
|
gen_events.append(shared.gradio['Regenerate'].click(
|
||||||
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||||
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False)
|
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
gen_events.append(shared.gradio['Continue'].click(
|
gen_events.append(shared.gradio['Continue'].click(
|
||||||
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||||
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False)
|
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['Impersonate'].click(
|
||||||
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
|
chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
||||||
)
|
)
|
||||||
|
|
||||||
shared.gradio['Replace last reply'].click(
|
shared.gradio['Replace last reply'].click(
|
||||||
chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
|
||||||
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
|
lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
|
||||||
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False)
|
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||||
|
|
||||||
shared.gradio['Clear history-confirm'].click(
|
shared.gradio['Clear history-confirm'].click(
|
||||||
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr).then(
|
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr).then(
|
||||||
chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display']).then(
|
chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'mode']], shared.gradio['display']).then(
|
||||||
chat.save_history, shared.gradio['Chat mode'], None, show_progress=False)
|
chat.save_history, shared.gradio['mode'], None, show_progress=False)
|
||||||
|
|
||||||
shared.gradio['Stop'].click(
|
shared.gradio['Stop'].click(
|
||||||
stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None).then(
|
stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None).then(
|
||||||
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
||||||
|
|
||||||
shared.gradio['Chat mode'].change(
|
shared.gradio['mode'].change(
|
||||||
lambda x: gr.update(visible=x == 'instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates']).then(
|
lambda x: gr.update(visible=x == 'instruct'), shared.gradio['mode'], shared.gradio['Instruction templates']).then(
|
||||||
lambda x: gr.update(interactive=x != 'instruct'), shared.gradio['Chat mode'], shared.gradio['character_menu']).then(
|
lambda x: gr.update(interactive=x != 'instruct'), shared.gradio['mode'], shared.gradio['character_menu']).then(
|
||||||
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
||||||
|
|
||||||
shared.gradio['Instruction templates'].change(
|
shared.gradio['Instruction templates'].change(
|
||||||
lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then(
|
lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then(
|
||||||
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
||||||
|
|
||||||
shared.gradio['upload_chat_history'].upload(
|
shared.gradio['upload_chat_history'].upload(
|
||||||
chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], None).then(
|
chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], None).then(
|
||||||
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
chat.redraw_html, reload_inputs, shared.gradio['display'])
|
||||||
|
|
||||||
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
|
|
||||||
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, None, shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, None, shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
||||||
shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
|
shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
|
||||||
shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
||||||
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
|
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
|
||||||
shared.gradio['download_button'].click(lambda x: chat.save_history(x, timestamp=True), shared.gradio['Chat mode'], shared.gradio['download'])
|
shared.gradio['download_button'].click(lambda x: chat.save_history(x, timestamp=True), shared.gradio['mode'], shared.gradio['download'])
|
||||||
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
|
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
|
||||||
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
|
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
|
||||||
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
|
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
|
||||||
shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'Chat mode']], shared.gradio['display'])
|
shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'mode']], shared.gradio['display'])
|
||||||
|
|
||||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
|
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
|
||||||
shared.gradio['interface'].load(chat.load_default_history, [shared.gradio[k] for k in ['name1', 'name2']], None)
|
shared.gradio['interface'].load(chat.load_default_history, [shared.gradio[k] for k in ['name1', 'name2']], None)
|
||||||
shared.gradio['interface'].load(chat.redraw_html, reload_inputs, shared.gradio['display'], show_progress=True)
|
shared.gradio['interface'].load(chat.redraw_html, reload_inputs, shared.gradio['display'], show_progress=True)
|
||||||
|
|
||||||
elif shared.args.notebook:
|
elif shared.args.notebook:
|
||||||
|
shared.input_elements = list_interface_input_elements(chat=False)
|
||||||
|
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
|
||||||
with gr.Tab("Text generation", elem_id="main"):
|
with gr.Tab("Text generation", elem_id="main"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=4):
|
with gr.Column(scale=4):
|
||||||
@ -537,12 +565,23 @@ def create_interface():
|
|||||||
|
|
||||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
|
shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
|
||||||
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
|
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
|
||||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
|
||||||
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['Generate'].click(
|
||||||
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
|
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['textbox'].submit(
|
||||||
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
|
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)
|
||||||
|
)
|
||||||
|
|
||||||
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
|
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
shared.input_elements = list_interface_input_elements(chat=False)
|
||||||
|
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
|
||||||
with gr.Tab("Text generation", elem_id="main"):
|
with gr.Tab("Text generation", elem_id="main"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
@ -570,9 +609,22 @@ def create_interface():
|
|||||||
|
|
||||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
|
shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
|
||||||
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
|
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
|
||||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
|
||||||
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
gen_events.append(shared.gradio['Generate'].click(
|
||||||
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
|
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['textbox'].submit(
|
||||||
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
|
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_events.append(shared.gradio['Continue'].click(
|
||||||
|
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
|
generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)
|
||||||
|
)
|
||||||
|
|
||||||
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
|
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
||||||
|
|
||||||
@ -607,18 +659,6 @@ def create_interface():
|
|||||||
if shared.args.extensions is not None:
|
if shared.args.extensions is not None:
|
||||||
extensions_module.create_extensions_block()
|
extensions_module.create_extensions_block()
|
||||||
|
|
||||||
def change_dict_value(d, key, value):
|
|
||||||
d[key] = value
|
|
||||||
return d
|
|
||||||
|
|
||||||
for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'add_bos_token', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
|
|
||||||
if k not in shared.gradio:
|
|
||||||
continue
|
|
||||||
if type(shared.gradio[k]) in [gr.Checkbox, gr.Number]:
|
|
||||||
shared.gradio[k].change(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
|
|
||||||
else:
|
|
||||||
shared.gradio[k].release(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
|
|
||||||
|
|
||||||
if not shared.is_chat():
|
if not shared.is_chat():
|
||||||
api.create_apis()
|
api.create_apis()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user