diff --git a/modules/LoRA.py b/modules/LoRA.py index 84e128fb..c95da6ee 100644 --- a/modules/LoRA.py +++ b/modules/LoRA.py @@ -10,6 +10,8 @@ def add_lora_to_model(lora_name): # Is there a more efficient way of returning to the base model? if lora_name == "None": + print(f"Reloading the model to remove the LoRA...") shared.model, shared.tokenizer = load_model(shared.model_name) else: + print(f"Adding the LoRA {lora_name} to the model...") shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}")) diff --git a/modules/shared.py b/modules/shared.py index 9d4484c4..488a1e96 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -56,7 +56,7 @@ settings = { }, 'lora_prompts': { 'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:', - 'alpaca-lora-7b': "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a Python script that generates text using the transformers library.\n### Response:\n" + 'alpaca-lora-7b': "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n" } } diff --git a/server.py b/server.py index 8dacc132..7d5ecc74 100644 --- a/server.py +++ b/server.py @@ -64,11 +64,15 @@ def load_model_wrapper(selected_model): return selected_model def load_lora_wrapper(selected_lora): + shared.lora_name = selected_lora + default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')] + if not shared.args.cpu: gc.collect() torch.cuda.empty_cache() add_lora_to_model(selected_lora) - return selected_lora + + return selected_lora, default_text def load_preset_values(preset_menu, return_dict=False): generate_params = { @@ -156,6 +160,10 @@ def create_settings_menus(default_preset): shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty') shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping') + with gr.Row(): + shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA') + ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button') + with gr.Accordion('Soft prompt', open=False): with gr.Row(): shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt') @@ -167,6 +175,7 @@ def create_settings_menus(default_preset): shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True) shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio['do_sample'], shared.gradio['temperature'], shared.gradio['top_p'], shared.gradio['typical_p'], shared.gradio['repetition_penalty'], shared.gradio['encoder_repetition_penalty'], shared.gradio['top_k'], shared.gradio['min_length'], shared.gradio['no_repeat_ngram_size'], shared.gradio['num_beams'], shared.gradio['penalty_alpha'], shared.gradio['length_penalty'], shared.gradio['early_stopping']]) + shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu'], shared.gradio['textbox']], show_progress=True) shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True) shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']]) @@ -226,8 +235,8 @@ else: shared.model_name = available_models[i] shared.model, shared.tokenizer = load_model(shared.model_name) if shared.args.lora: + print(shared.args.lora) shared.lora_name = shared.args.lora - print(f"Adding the LoRA {shared.lora_name} to the model...") add_lora_to_model(shared.lora_name) # Default UI settings @@ -419,19 +428,6 @@ def create_interface(): shared.gradio['Stop'].click(None, None, None, cancels=gen_events) shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") - with gr.Tab("LoRA", elem_id="lora"): - with gr.Row(): - with gr.Column(): - gr.Markdown("Load") - with gr.Row(): - shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA') - ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button') - with gr.Column(): - gr.Markdown("Train (TODO)") - gr.Button("Practice your button clicking skills") - - shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu']], show_progress=True) - with gr.Tab("Interface mode", elem_id="interface-mode"): modes = ["default", "notebook", "chat", "cai_chat"] current_mode = "default"