Implement notebook mode

This commit is contained in:
oobabooga 2023-01-06 20:22:26 -03:00
parent 6c02fe7a94
commit e5f547fc87
2 changed files with 42 additions and 17 deletions

View File

@ -72,6 +72,7 @@ Then browse to
Optionally, you can use the following command-line flags: Optionally, you can use the following command-line flags:
--model model-name: load this model by default. --model model-name: load this model by default.
--notebook: Launch the webui in notebook mode, where the output is written to the same text box as the input.
## Presets ## Presets

View File

@ -11,7 +11,8 @@ from transformers import AutoTokenizer
from transformers import GPTJForCausalLM, AutoModelForCausalLM, AutoModelForSeq2SeqLM, OPTForCausalLM, T5Tokenizer, T5ForConditionalGeneration, GPTJModel, AutoModel from transformers import GPTJForCausalLM, AutoModelForCausalLM, AutoModelForSeq2SeqLM, OPTForCausalLM, T5Tokenizer, T5ForConditionalGeneration, GPTJModel, AutoModel
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, help='Name of the model to load by default') parser.add_argument('--model', type=str, help='Name of the model to load by default.')
parser.add_argument('--notebook', action='store_true', help='Launch the webui in notebook mode, where the output is written to the same text box as the input.')
args = parser.parse_args() args = parser.parse_args()
loaded_preset = None loaded_preset = None
available_models = sorted(set(map(lambda x : x.split('/')[-1].replace('.pt', ''), glob.glob("models/*[!\.][!t][!x][!t]")+ glob.glob("torch-dumps/*[!\.][!t][!x][!t]")))) available_models = sorted(set(map(lambda x : x.split('/')[-1].replace('.pt', ''), glob.glob("models/*[!\.][!t][!x][!t]")+ glob.glob("torch-dumps/*[!\.][!t][!x][!t]"))))
@ -79,7 +80,10 @@ def generate_reply(question, temperature, max_length, inference_settings, select
if model_name.startswith('gpt4chan'): if model_name.startswith('gpt4chan'):
reply = fix_gpt4chan(reply) reply = fix_gpt4chan(reply)
return reply if model_name.lower().startswith('galactica'):
return reply, reply
else:
return reply, ''
# Choosing the default model # Choosing the default model
if args.model is not None: if args.model is not None:
@ -104,20 +108,40 @@ if model_name.startswith('gpt4chan'):
else: else:
default_text = "Common sense questions and answers\n\nQuestion: \nFactual answer:" default_text = "Common sense questions and answers\n\nQuestion: \nFactual answer:"
interface = gr.Interface( if args.notebook:
generate_reply, with gr.Blocks() as interface:
inputs=[ gr.Markdown(
gr.Textbox(value=default_text, lines=15), f"""
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7), # Text generation lab
gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200), Generate text using Large Language Models.
gr.Dropdown(choices=list(map(lambda x : x.split('/')[-1].split('.')[0], glob.glob("presets/*.txt"))), value="Default"), """
gr.Dropdown(choices=available_models, value=model_name), )
],
outputs=[ textbox = gr.Textbox(value=default_text, lines=23)
gr.Textbox(placeholder="", lines=15), temp_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7)
], length_slider = gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200)
title="Text generation lab", preset_menu = gr.Dropdown(choices=list(map(lambda x : x.split('/')[-1].split('.')[0], glob.glob("presets/*.txt"))), value="Default")
description=f"Generate text using Large Language Models.", model_menu = gr.Dropdown(choices=available_models, value=model_name)
) btn = gr.Button("Generate")
markdown = gr.Markdown()
btn.click(generate_reply, [textbox, temp_slider, length_slider, preset_menu, model_menu], [textbox, markdown], show_progress=False)
else:
interface = gr.Interface(
generate_reply,
inputs=[
gr.Textbox(value=default_text, lines=15),
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7),
gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200),
gr.Dropdown(choices=list(map(lambda x : x.split('/')[-1].split('.')[0], glob.glob("presets/*.txt"))), value="Default"),
gr.Dropdown(choices=available_models, value=model_name),
],
outputs=[
gr.Textbox(placeholder="", lines=15),
gr.Markdown()
],
title="Text generation lab",
description=f"Generate text using Large Language Models.",
)
interface.launch(share=False, server_name="0.0.0.0") interface.launch(share=False, server_name="0.0.0.0")