Show progress on impersonate

This commit is contained in:
oobabooga 2023-09-13 11:22:53 -07:00
parent 7cd437e05c
commit 8ce94b735c
2 changed files with 7 additions and 4 deletions

View File

@ -266,18 +266,21 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
def impersonate_wrapper(text, state): def impersonate_wrapper(text, state):
static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'])
if shared.model_name == 'None' or shared.model is None: if shared.model_name == 'None' or shared.model is None:
logger.error("No model is loaded! Select one in the Model tab.") logger.error("No model is loaded! Select one in the Model tab.")
yield '' yield '', static_output
return return
prompt = generate_chat_prompt('', state, impersonate=True) prompt = generate_chat_prompt('', state, impersonate=True)
stopping_strings = get_stopping_strings(state) stopping_strings = get_stopping_strings(state)
yield text + '...' yield text + '...', static_output
reply = None reply = None
for reply in generate_reply(prompt + text, state, stopping_strings=stopping_strings, is_chat=True): for reply in generate_reply(prompt + text, state, stopping_strings=stopping_strings, is_chat=True):
yield (text + reply).lstrip(' ') yield (text + reply).lstrip(' '), static_output
if shared.stop_everything: if shared.stop_everything:
return return

View File

@ -166,7 +166,7 @@ def create_event_handlers():
shared.gradio['Impersonate'].click( shared.gradio['Impersonate'].click(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
lambda x: x, gradio('textbox'), gradio('Chat input'), show_progress=False).then( lambda x: x, gradio('textbox'), gradio('Chat input'), show_progress=False).then(
chat.impersonate_wrapper, gradio(inputs), gradio('textbox'), show_progress=False).then( chat.impersonate_wrapper, gradio(inputs), gradio('textbox', 'display'), show_progress=False).then(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
lambda: None, None, None, _js=f'() => {{{ui.audio_notification_js}}}') lambda: None, None, None, _js=f'() => {{{ui.audio_notification_js}}}')