From 104293f411cd517babf19ecb7d80031b9e6df5f6 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 16 Mar 2023 21:31:39 -0300 Subject: [PATCH] Add LoRA support --- css/main.css | 11 ++++++++++- download-model.py | 17 +++++++++++------ modules/models.py | 2 ++ modules/shared.py | 3 ++- requirements.txt | 1 + server.py | 25 +++++++++++++++++++++++++ 6 files changed, 51 insertions(+), 8 deletions(-) diff --git a/css/main.css b/css/main.css index f5ccfe94..87c3bded 100644 --- a/css/main.css +++ b/css/main.css @@ -1,12 +1,15 @@ .tabs.svelte-710i53 { margin-top: 0 } + .py-6 { padding-top: 2.5rem } + .dark #refresh-button { background-color: #ffffff1f; } + #refresh-button { flex: none; margin: 0; @@ -17,22 +20,28 @@ border-radius: 10px; background-color: #0000000d; } + #download-label, #upload-label { min-height: 0 } + #accordion { } + .dark svg { fill: white; } + svg { display: unset !important; vertical-align: middle !important; margin: 5px; } + ol li p, ul li p { display: inline-block; } -#main, #parameters, #chat-settings, #interface-mode { + +#main, #parameters, #chat-settings, #interface-mode, #lora { border: 0; } diff --git a/download-model.py b/download-model.py index 8be398c4..808b9fc2 100644 --- a/download-model.py +++ b/download-model.py @@ -101,6 +101,7 @@ def get_download_links_from_huggingface(model, branch): classifications = [] has_pytorch = False has_safetensors = False + is_lora = False while True: content = requests.get(f"{base}{page}{cursor.decode()}").content @@ -110,8 +111,10 @@ def get_download_links_from_huggingface(model, branch): for i in range(len(dict)): fname = dict[i]['path'] + if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')): + is_lora = True - is_pytorch = re.match("pytorch_model.*\.bin", fname) + is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname) is_safetensors = re.match("model.*\.safetensors", fname) is_tokenizer = re.match("tokenizer.*\.model", fname) is_text = re.match(".*\.(txt|json)", fname) or is_tokenizer @@ -130,6 +133,7 @@ def get_download_links_from_huggingface(model, branch): has_pytorch = True classifications.append('pytorch') + cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50' cursor = base64.b64encode(cursor) cursor = cursor.replace(b'=', b'%3D') @@ -140,7 +144,7 @@ def get_download_links_from_huggingface(model, branch): if classifications[i] == 'pytorch': links.pop(i) - return links + return links, is_lora if __name__ == '__main__': model = args.MODEL @@ -159,15 +163,16 @@ if __name__ == '__main__': except ValueError as err_branch: print(f"Error: {err_branch}") sys.exit() + + links, is_lora = get_download_links_from_huggingface(model, branch) + base_folder = 'models' if not is_lora else 'loras' if branch != 'main': - output_folder = Path("models") / (model.split('/')[-1] + f'_{branch}') + output_folder = Path(base_folder) / (model.split('/')[-1] + f'_{branch}') else: - output_folder = Path("models") / model.split('/')[-1] + output_folder = Path(base_folder) / model.split('/')[-1] if not output_folder.exists(): output_folder.mkdir() - links = get_download_links_from_huggingface(model, branch) - # Downloading the files print(f"Downloading the model to {output_folder}") pool = multiprocessing.Pool(processes=args.threads) diff --git a/modules/models.py b/modules/models.py index 63060d43..6df67d3c 100644 --- a/modules/models.py +++ b/modules/models.py @@ -11,6 +11,8 @@ from accelerate import infer_auto_device_map, init_empty_weights from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig) +from peft import PeftModel + import modules.shared as shared transformers.logging.set_verbosity_error() diff --git a/modules/shared.py b/modules/shared.py index da5efbd3..908455e1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -2,7 +2,8 @@ import argparse model = None tokenizer = None -model_name = "" +model_name = "None" +lora_name = "None" soft_prompt_tensor = None soft_prompt = False is_RWKV = False diff --git a/requirements.txt b/requirements.txt index b9a9b385..fcf000a9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ flexgen==0.1.7 gradio==3.18.0 markdown numpy +peft==0.2.0 requests rwkv==0.4.2 safetensors==0.3.0 diff --git a/server.py b/server.py index 2024fd42..dd35d9aa 100644 --- a/server.py +++ b/server.py @@ -17,6 +17,7 @@ import modules.ui as ui from modules.html_generator import generate_chat_html from modules.models import load_model, load_soft_prompt from modules.text_generation import generate_reply +from modules.LoRA import add_lora_to_model # Loading custom settings settings_file = None @@ -48,6 +49,9 @@ def get_available_extensions(): def get_available_softprompts(): return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower) +def get_available_loras(): + return ['None'] + sorted([item.name for item in list(Path('loras/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) + def load_model_wrapper(selected_model): if selected_model != shared.model_name: shared.model_name = selected_model @@ -59,6 +63,13 @@ def load_model_wrapper(selected_model): return selected_model +def load_lora_wrapper(selected_lora): + if not shared.args.cpu: + gc.collect() + torch.cuda.empty_cache() + add_lora_to_model(selected_lora) + return selected_lora + def load_preset_values(preset_menu, return_dict=False): generate_params = { 'do_sample': True, @@ -181,6 +192,7 @@ available_models = get_available_models() available_presets = get_available_presets() available_characters = get_available_characters() available_softprompts = get_available_softprompts() +available_loras = get_available_loras() # Default extensions extensions_module.available_extensions = get_available_extensions() @@ -401,6 +413,19 @@ def create_interface(): shared.gradio['Stop'].click(None, None, None, cancels=gen_events) shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") + with gr.Tab("LoRA", elem_id="lora"): + with gr.Row(): + with gr.Column(): + gr.Markdown("Load") + with gr.Row(): + shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA') + ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button') + with gr.Column(): + gr.Markdown("Train (TODO)") + gr.Button("Practice your button clicking skills") + + shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu']], show_progress=True) + with gr.Tab("Interface mode", elem_id="interface-mode"): modes = ["default", "notebook", "chat", "cai_chat"] current_mode = "default"