diff --git a/server.py b/server.py index 740020ea..36cf57b4 100644 --- a/server.py +++ b/server.py @@ -10,7 +10,8 @@ import time import zipfile from datetime import datetime from pathlib import Path - +import os +import requests import gradio as gr from PIL import Image @@ -20,6 +21,7 @@ from modules.html_generator import chat_html_wrapper from modules.LoRA import add_lora_to_model from modules.models import load_model, load_soft_prompt, unload_model from modules.text_generation import generate_reply, stop_everything_event +from huggingface_hub import HfApi # Loading custom settings settings_file = None @@ -172,6 +174,62 @@ def create_prompt_menus(): shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False) +def download_model_wrapper(repo_id): + print(repo_id) + if repo_id == '': + print("Please enter a valid repo ID. This field cant be empty") + else: + try: + print('Downloading repo') + hf_api = HfApi() + # Get repo info + repo_info = hf_api.repo_info( + repo_id=repo_id, + repo_type="model", + revision="main" + ) + # create model and repo folder and check for lora + is_lora = False + for file in repo_info.siblings: + if 'adapter_model.bin' in file.rfilename: + is_lora = True + repo_dir_name = repo_id.replace("/", "--") + if is_lora is True: + models_dir = ".loras" + else: + models_dir = ".models" + if not os.path.exists(models_dir): + os.makedirs(models_dir) + repo_dir = os.path.join(models_dir, repo_dir_name) + if not os.path.exists(repo_dir): + os.makedirs(repo_dir) + + for sibling in repo_info.siblings: + filename = sibling.rfilename + url = f"https://huggingface.co/{repo_id}/resolve/main/{filename}" + download_path = os.path.join(repo_dir, filename) + response = requests.get(url, stream=True) + # Get the total file size from the content-length header + total_size = int(response.headers.get('content-length', 0)) + + # Download the file in chunks and print progress + with open(download_path, 'wb') as f: + downloaded_size = 0 + for data in response.iter_content(chunk_size=10000000): + downloaded_size += len(data) + f.write(data) + progress = downloaded_size * 100 // total_size + downloaded_size_mb = downloaded_size / (1024 * 1024) + total_size_mb = total_size / (1024 * 1024) + print(f"\rDownloading {filename}... {progress}% complete " + f"({downloaded_size_mb:.2f}/{total_size_mb:.2f} MB)", end="", flush=True) + print(f"\rDownloading {filename}... Complete!") + + print('Repo Downloaded') + except ValueError as e: + raise ValueError("Please enter a valid repo ID. Error: {}".format(e)) + + def create_model_menus(): with gr.Row(): with gr.Column(): @@ -182,6 +240,15 @@ def create_model_menus(): with gr.Row(): shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA') ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button') + with gr.Row(): + with gr.Column(scale=0.5): + shared.gradio['custom_model_menu'] = gr.Textbox(label="Download Custom Model", + info="Enter hugging face username/model path e.g: 'decapoda-research/llama-7b-hf'") + with gr.Row(): + with gr.Column(scale=0.5): + shared.gradio['download_button'] = gr.Button("Download", show_progress=True) + shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], + show_progress=True) shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True) shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)