From f54a13929f6c70fe48a8b70634c7499d6270a496 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Fri, 6 Jan 2023 19:56:44 -0300
Subject: [PATCH] Load default model with --model flag

---
 server.py | 35 ++++++++++++++++++++++++-----------
 1 file changed, 24 insertions(+), 11 deletions(-)

diff --git a/server.py b/server.py
index 4a9eb030..2783743a 100644
--- a/server.py
+++ b/server.py
@@ -2,23 +2,19 @@ import os
 import re
 import time
 import glob
+from sys import exit
 import torch
+import argparse
 import gradio as gr
 import transformers
 from transformers import AutoTokenizer
 from transformers import GPTJForCausalLM, AutoModelForCausalLM, AutoModelForSeq2SeqLM, OPTForCausalLM, T5Tokenizer, T5ForConditionalGeneration, GPTJModel, AutoModel
 
-#model_name = "bloomz-7b1-p3"
-#model_name = 'gpt-j-6B-float16'
-#model_name = "opt-6.7b"
-#model_name = 'opt-13b'
-model_name = "gpt4chan_model_float16"
-#model_name = 'galactica-6.7b'
-#model_name = 'gpt-neox-20b'
-#model_name = 'flan-t5'
-#model_name = 'OPT-13B-Erebus'
-
+parser = argparse.ArgumentParser()
+parser.add_argument('--model', type=str, help='Name of the model to load by default')
+args = parser.parse_args()
 loaded_preset = None
+available_models = sorted(set(map(lambda x : x.split('/')[-1].replace('.pt', ''), glob.glob("models/*[!\.][!t][!x][!t]")+ glob.glob("torch-dumps/*[!\.][!t][!x][!t]"))))
 
 def load_model(model_name):
     print(f"Loading {model_name}...")
@@ -85,7 +81,24 @@ def generate_reply(question, temperature, max_length, inference_settings, selected_model):
 
     return reply
 
+# Choosing the default model
+if args.model is not None:
+    model_name = args.model
+else:
+    if len(available_models) == 0:
+        print("No models are available! Please download at least one.")
+        exit(0)
+    elif len(available_models) == 1:
+        i = 0
+    else:
+        print("The following models are available:\n")
+        for i,model in enumerate(available_models):
+            print(f"{i+1}. {model}")
+        print(f"\nWhich one do you want to load? 1-{len(available_models)}\n")
+        i = int(input())-1
+    model_name = available_models[i]
 model, tokenizer = load_model(model_name)
+
 if model_name.startswith('gpt4chan'):
     default_text = "-----\n--- 865467536\nInput text\n--- 865467537\n"
 else:
@@ -98,7 +111,7 @@ interface = gr.Interface(
         gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7),
         gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200),
         gr.Dropdown(choices=list(map(lambda x : x.split('/')[-1].split('.')[0], glob.glob("presets/*.txt"))), value="Default"),
-        gr.Dropdown(choices=sorted(set(map(lambda x : x.split('/')[-1].replace('.pt', ''), glob.glob("models/*") + glob.glob("torch-dumps/*")))), value=model_name),
+        gr.Dropdown(choices=available_models, value=model_name),
     ],
     outputs=[
         gr.Textbox(placeholder="", lines=15),
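
With this change, the default model is chosen from the command line instead of being hard-coded in server.py. A usage sketch (the model name is illustrative; any folder under models/ or dump under torch-dumps/ works):

    python server.py --model opt-6.7b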
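
If --model is omitted, the script falls back to the models it discovers on disk: a single available model is loaded directly, and with several available the script prints a numbered menu and reads the choice from standard input. An illustrative session, using model names taken from the comments the patch removes:

    The following models are available:

    1. gpt4chan_model_float16
    2. opt-6.7b

    Which one do you want to load? 1-2

    2
    Loading opt-6.7b...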
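
A note on the model-discovery line: the [!\.][!t][!x][!t] character classes are meant to keep .txt files out of the model list, but each class checks one trailing character independently, so any name whose last character is "t" (and any name shorter than four characters) is filtered out as well. A more explicit sketch with the same intent, assuming a hypothetical helper name list_models that is not part of the patch:

    import glob
    import os

    def list_models():
        # Gather entries from both model directories and skip .txt files explicitly.
        paths = glob.glob("models/*") + glob.glob("torch-dumps/*")
        names = (os.path.basename(p) for p in paths if not p.endswith(".txt"))
        # Drop the .pt extension the same way the patch does (str.replace).
        return sorted(set(name.replace(".pt", "") for name in names))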