Add a custom model download menu (from Hugging Face) in the Model tab

Usama Kenway 2023-04-09 16:11:43 +05:00
parent bce1b7fbb2
commit 7436dd5b4a


@@ -10,7 +10,8 @@ import time
import zipfile
from datetime import datetime
from pathlib import Path
import os
import requests
import gradio as gr
from PIL import Image
@@ -20,6 +21,7 @@ from modules.html_generator import chat_html_wrapper
from modules.LoRA import add_lora_to_model
from modules.models import load_model, load_soft_prompt, unload_model
from modules.text_generation import generate_reply, stop_everything_event
from huggingface_hub import HfApi

# Loading custom settings
settings_file = None
@@ -172,6 +174,62 @@ def create_prompt_menus():
    shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
def download_model_wrapper(repo_id):
    print(repo_id)
    if repo_id == '':
        print("Please enter a valid repo ID. This field can't be empty.")
    else:
        try:
            print('Downloading repo')
            hf_api = HfApi()

            # Get repo info
            repo_info = hf_api.repo_info(
                repo_id=repo_id,
                repo_type="model",
                revision="main"
            )

            # Check whether the repo is a LoRA (LoRA repos ship an adapter_model.bin)
            is_lora = False
            for file in repo_info.siblings:
                if 'adapter_model.bin' in file.rfilename:
                    is_lora = True

            # Create the target folder: .loras for LoRAs, .models otherwise
            repo_dir_name = repo_id.replace("/", "--")
            if is_lora:
                models_dir = ".loras"
            else:
                models_dir = ".models"
            if not os.path.exists(models_dir):
                os.makedirs(models_dir)
            repo_dir = os.path.join(models_dir, repo_dir_name)
            if not os.path.exists(repo_dir):
                os.makedirs(repo_dir)

            for sibling in repo_info.siblings:
                filename = sibling.rfilename
                url = f"https://huggingface.co/{repo_id}/resolve/main/{filename}"
                download_path = os.path.join(repo_dir, filename)
                response = requests.get(url, stream=True)

                # Get the total file size from the content-length header
                # (0 if the header is missing)
                total_size = int(response.headers.get('content-length', 0))

                # Download the file in chunks and print progress
                with open(download_path, 'wb') as f:
                    downloaded_size = 0
                    for data in response.iter_content(chunk_size=10000000):
                        downloaded_size += len(data)
                        f.write(data)
                        # Guard against a missing content-length header
                        progress = downloaded_size * 100 // total_size if total_size else 0
                        downloaded_size_mb = downloaded_size / (1024 * 1024)
                        total_size_mb = total_size / (1024 * 1024)
                        print(f"\rDownloading {filename}... {progress}% complete "
                              f"({downloaded_size_mb:.2f}/{total_size_mb:.2f} MB)", end="", flush=True)
                print(f"\rDownloading {filename}... Complete!")

            print('Repo downloaded')
        except ValueError as e:
            raise ValueError("Please enter a valid repo ID. Error: {}".format(e))

def create_model_menus():
    with gr.Row():
        with gr.Column():

@@ -182,6 +240,15 @@ def create_model_menus():
    with gr.Row():
        shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
        ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
    with gr.Row():
        with gr.Column(scale=0.5):
            shared.gradio['custom_model_menu'] = gr.Textbox(label="Download Custom Model",
                                                            info="Enter the Hugging Face username/model path, e.g.: 'decapoda-research/llama-7b-hf'")

    with gr.Row():
        with gr.Column(scale=0.5):
            # show_progress is an event-listener argument, not a Button parameter,
            # so it belongs on .click() below rather than on the Button itself
            shared.gradio['download_button'] = gr.Button("Download")

    shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'],
                                           show_progress=True)
    shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
    shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
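
The download loop in this commit hand-rolls per-file HTTP streaming; huggingface_hub can fetch a whole repo in one call. A minimal standalone sketch of that alternative, reusing the commit's LoRA heuristic and .models/.loras folder layout (HfApi.repo_info and snapshot_download are real huggingface_hub APIs, but treat the exact keyword arguments as assumptions for your installed version; download_model_sketch is a hypothetical name, not part of the commit):

import os

from huggingface_hub import HfApi, snapshot_download

def download_model_sketch(repo_id: str) -> str:
    # Same heuristic as the commit: a repo containing adapter_model.bin is a LoRA
    repo_info = HfApi().repo_info(repo_id=repo_id, repo_type="model", revision="main")
    is_lora = any('adapter_model.bin' in f.rfilename for f in repo_info.siblings)

    # Mirror the commit's folder layout: .loras/ for LoRAs, .models/ otherwise
    target_dir = os.path.join(".loras" if is_lora else ".models", repo_id.replace("/", "--"))

    # snapshot_download fetches every file in the repo and writes it under local_dir,
    # with progress display, retries, and caching handled by the library
    snapshot_download(repo_id=repo_id, revision="main", local_dir=target_dir)
    return target_dir

# Example usage:
# download_model_sketch('decapoda-research/llama-7b-hf')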