From 1cb9246160bafca2599b20b69e7c4e9afff410e6 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Wed, 29 Mar 2023 21:47:36 -0300
Subject: [PATCH] Adapt to the new model names

---
 modules/GPTQ_loader.py     |  7 ++++---
 modules/models.py          |  4 ++--
 modules/shared.py          |  4 ----
 modules/text_generation.py |  6 +++---
 server.py                  | 13 ++++++-------
 settings-template.json     |  9 +++------
 6 files changed, 18 insertions(+), 25 deletions(-)

diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py
index 7926d0ab..e7877de7 100644
--- a/modules/GPTQ_loader.py
+++ b/modules/GPTQ_loader.py
@@ -51,11 +51,12 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
 def load_quantized(model_name):
     if not shared.args.model_type:
         # Try to determine model type from model name
-        if model_name.lower().startswith(('llama', 'alpaca')):
+        name = model_name.lower()
+        if any((k in name for k in ['llama', 'alpaca'])):
             model_type = 'llama'
-        elif model_name.lower().startswith(('opt', 'galactica')):
+        elif any((k in name for k in ['opt-', 'galactica'])):
             model_type = 'opt'
-        elif model_name.lower().startswith(('gpt-j', 'pygmalion-6b')):
+        elif any((k in name for k in ['gpt-j', 'pygmalion-6b'])):
             model_type = 'gptj'
         else:
             print("Can't determine model type from model name. Please specify it manually using --model_type "
diff --git a/modules/models.py b/modules/models.py
index a6839318..b19507db 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -41,7 +41,7 @@ def load_model(model_name):
     print(f"Loading {model_name}...")
     t0 = time.time()
 
-    shared.is_RWKV = model_name.lower().startswith('rwkv-')
+    shared.is_RWKV = 'rwkv-' in model_name.lower()
 
     # Default settings
     if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]):
@@ -159,7 +159,7 @@ def load_model(model_name):
             model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
 
     # Loading the tokenizer
-    if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
+    if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
         tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
     else:
         tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
diff --git a/modules/shared.py b/modules/shared.py
index 5d1b42d4..8bbf3b69 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -37,10 +37,6 @@ settings = {
     'chat_generation_attempts': 1,
     'chat_generation_attempts_min': 1,
     'chat_generation_attempts_max': 5,
-    'name1_pygmalion': 'You',
-    'name2_pygmalion': 'Kawaii',
-    'context_pygmalion': "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n",
-    'stop_at_newline_pygmalion': False,
     'default_extensions': [],
     'chat_default_extensions': ["gallery"],
     'presets': {
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 20a07ca3..7b5fcd6a 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -42,7 +42,7 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
 
 def decode(output_ids):
     # Open Assistant relies on special tokens like <|endoftext|>
-    if re.match('(oasst|galactica)-*', shared.model_name.lower()):
+    if re.match('.*(oasst|galactica)-*', shared.model_name.lower()):
         return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
     else:
         reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
@@ -77,10 +77,10 @@ def fix_galactica(s):
 
 def formatted_outputs(reply, model_name):
     if not (shared.args.chat or shared.args.cai_chat):
-        if model_name.lower().startswith('galactica'):
+        if 'galactica' in model_name.lower():
             reply = fix_galactica(reply)
             return reply, reply, generate_basic_html(reply)
-        elif model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
+        elif any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])):
             reply = fix_gpt4chan(reply)
             return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
         else:
diff --git a/server.py b/server.py
index 6023451b..62a7ebfb 100644
--- a/server.py
+++ b/server.py
@@ -282,7 +282,6 @@ else:
 default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
 title ='Text generation web UI'
 description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n'
-suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
 
 def create_interface():
 
@@ -294,7 +293,7 @@ def create_interface():
     if shared.args.chat or shared.args.cai_chat:
         with gr.Tab("Text generation", elem_id="main"):
             if shared.args.cai_chat:
-                shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character))
+                shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], shared.character))
             else:
                 shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528"))
             shared.gradio['textbox'] = gr.Textbox(label='Input')
@@ -314,9 +313,9 @@ def create_interface():
                 shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
 
             with gr.Tab("Character", elem_id="chat-settings"):
-                shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name')
-                shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
-                shared.gradio['context'] = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=5, label='Context')
+                shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name')
+                shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Bot\'s name')
+                shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=5, label='Context')
                 with gr.Row():
                     shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
                     ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')
@@ -354,7 +353,7 @@ def create_interface():
                         shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
                 with gr.Column():
                     shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
-                    shared.gradio['check'] = gr.Checkbox(value=shared.settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
+                    shared.gradio['check'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
 
             create_settings_menus(default_preset)
 
@@ -401,7 +400,7 @@ def create_interface():
             shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
 
         shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
-        shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
+        shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None)
         shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
 
     elif shared.args.notebook:
diff --git a/settings-template.json b/settings-template.json
index 79fd5023..2a2aaed9 100644
--- a/settings-template.json
+++ b/settings-template.json
@@ -12,10 +12,6 @@
     "chat_generation_attempts": 1,
     "chat_generation_attempts_min": 1,
     "chat_generation_attempts_max": 5,
-    "name1_pygmalion": "You",
-    "name2_pygmalion": "Kawaii",
-    "context_pygmalion": "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n",
-    "stop_at_newline_pygmalion": false,
     "default_extensions": [],
     "chat_default_extensions": [
        "gallery"
@@ -29,10 +25,11 @@
         "default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
         "^(gpt4chan|gpt-4chan|4chan)": "-----\n--- 865467536\nInput text\n--- 865467537\n",
         "(rosey|chip|joi)_.*_instruct.*": "User: \n",
-        "oasst-*": "<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>"
+        "oasst-*": "<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>",
+        "alpaca-*": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
     },
     "lora_prompts": {
         "default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
-        "alpaca-lora-7b": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
+        "(alpaca-lora-7b|alpaca-lora-13b|alpaca-lora-30b)": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
     }
 }
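
The GPTQ_loader.py hunk replaces prefix matching with substring matching. A minimal standalone sketch of the new detection logic (not part of the patch; the model names used below are hypothetical examples):

```python
# Sketch of the new model-type detection in load_quantized()
# (modules/GPTQ_loader.py): substring checks instead of startswith().
from typing import Optional

def detect_model_type(model_name: str) -> Optional[str]:
    name = model_name.lower()
    if any(k in name for k in ['llama', 'alpaca']):
        return 'llama'
    elif any(k in name for k in ['opt-', 'galactica']):
        return 'opt'
    elif any(k in name for k in ['gpt-j', 'pygmalion-6b']):
        return 'gptj'
    return None  # load_quantized() then asks for an explicit --model_type

# Substring checks catch a keyword anywhere in the name, which the old
# startswith() tuples missed:
assert detect_model_type('vicuna-llama-13b') == 'llama'    # old code: no match
assert detect_model_type('alpaca-native-4bit') == 'llama'
assert detect_model_type('facebook-opt-1.3b') == 'opt'     # 'opt-' requires the hyphen
assert detect_model_type('rwkv-4-pile') is None
```

Narrowing `'opt'` to `'opt-'` avoids false positives from names that merely contain the letters "opt".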
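The new `"alpaca-*"` and `"(alpaca-lora-7b|alpaca-lora-13b|alpaca-lora-30b)"` keys in settings-template.json work because server.py treats every key of `settings['prompts']` as a regex and picks the first one that `re.match`-es the lowercased model name. A sketch of that lookup, with prompt texts trimmed for brevity:

```python
# Sketch of the default-prompt lookup in server.py: each key is a regex
# tried with re.match (anchored at the start of the string) against the
# lowercased model name; 'default' is the fallback.
import re

prompts = {
    'default': 'Common sense questions and answers\n\n...',
    'oasst-*': '<|prompter|>...<|assistant|>',
    'alpaca-*': 'Below is an instruction that describes a task. ...',
}

def pick_prompt(model_name: str) -> str:
    key = next((k for k in prompts if re.match(k.lower(), model_name.lower())), 'default')
    return prompts[key]

assert pick_prompt('alpaca-13b') == prompts['alpaca-*']
assert pick_prompt('oasst-sft-1-pythia-12b') == prompts['oasst-*']
assert pick_prompt('llama-7b') == prompts['default']
```

Note that because `re.match` anchors at the start and `-*` means "zero or more hyphens", `'alpaca-*'` behaves as a prefix match on `alpaca`: it matches `alpaca-native` but not `gpt4-x-alpaca`.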
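The `.*` prefix added to the pattern in `decode()` is what moves that check from prefix matching to "contains" matching. A quick sketch (model names here are illustrative, not from the patch):

```python
# Sketch of the changed check in decode() (modules/text_generation.py).
# Prefixing the anchored re.match pattern with '.*' makes it behave like
# re.search: the keyword may now appear anywhere in the model name.
import re

def keeps_special_tokens(model_name: str) -> bool:
    return bool(re.match('.*(oasst|galactica)-*', model_name.lower()))

assert keeps_special_tokens('oasst-sft-1-pythia-12b')
assert keeps_special_tokens('my-galactica-finetune')  # old pattern missed this
assert not keeps_special_tokens('llama-7b')
```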