Use download-model.py to download the model

This commit is contained in:
oobabooga 2023-04-10 11:36:39 -03:00
parent c6e9ba20a4
commit 2c14df81a8
2 changed files with 50 additions and 72 deletions

View File

@@ -20,17 +20,6 @@ import tqdm
from tqdm.contrib.concurrent import thread_map from tqdm.contrib.concurrent import thread_map
parser = argparse.ArgumentParser()
parser.add_argument('MODEL', type=str, default=None, nargs='?')
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
args = parser.parse_args()
def select_model_from_default_options(): def select_model_from_default_options():
models = { models = {
"OPT 6.7B": ("facebook", "opt-6.7b", "main"), "OPT 6.7B": ("facebook", "opt-6.7b", "main"),
@@ -244,6 +233,17 @@ def check_model_files(model, branch, links, sha256, output_folder):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('MODEL', type=str, default=None, nargs='?')
parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
args = parser.parse_args()
branch = args.branch branch = args.branch
model = args.MODEL model = args.MODEL
if model is None: if model is None:

100
server.py
View File

@@ -2,17 +2,21 @@ import os
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False' os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
import importlib
import io import io
import json import json
import os
import re import re
import sys import sys
import time import time
import traceback
import zipfile import zipfile
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
import os
import requests
import gradio as gr import gradio as gr
import requests
from huggingface_hub import HfApi
from PIL import Image from PIL import Image
import modules.extensions as extensions_module import modules.extensions as extensions_module
@@ -21,7 +25,6 @@ from modules.html_generator import chat_html_wrapper
from modules.LoRA import add_lora_to_model from modules.LoRA import add_lora_to_model
from modules.models import load_model, load_soft_prompt, unload_model from modules.models import load_model, load_soft_prompt, unload_model
from modules.text_generation import generate_reply, stop_everything_event from modules.text_generation import generate_reply, stop_everything_event
from huggingface_hub import HfApi
# Loading custom settings # Loading custom settings
settings_file = None settings_file = None
@@ -175,59 +178,31 @@ def create_prompt_menus():
def download_model_wrapper(repo_id): def download_model_wrapper(repo_id):
print(repo_id) try:
if repo_id == '': downloader = importlib.import_module("download-model")
print("Please enter a valid repo ID. This field can't be empty")
else:
try:
print('Downloading repo')
hf_api = HfApi()
# Get repo info
repo_info = hf_api.repo_info(
repo_id=repo_id,
repo_type="model",
revision="main"
)
# create model and repo folder and check for lora
is_lora = False
for file in repo_info.siblings:
if 'adapter_model.bin' in file.rfilename:
is_lora = True
repo_dir_name = repo_id.replace("/", "--")
if is_lora is True:
models_dir = "loras"
else:
models_dir = "models"
if not os.path.exists(models_dir):
os.makedirs(models_dir)
repo_dir = os.path.join(models_dir, repo_dir_name)
if not os.path.exists(repo_dir):
os.makedirs(repo_dir)
for sibling in repo_info.siblings: model = repo_id
filename = sibling.rfilename branch = "main"
url = f"https://huggingface.co/{repo_id}/resolve/main/{filename}" check = False
download_path = os.path.join(repo_dir, filename)
response = requests.get(url, stream=True)
# Get the total file size from the content-length header
total_size = int(response.headers.get('content-length', 0))
# Download the file in chunks and print progress yield("Cleaning up the model/branch names")
with open(download_path, 'wb') as f: model, branch = downloader.sanitize_model_and_branch_names(model, branch)
downloaded_size = 0
for data in response.iter_content(chunk_size=10000000):
downloaded_size += len(data)
f.write(data)
progress = downloaded_size * 100 // total_size
downloaded_size_mb = downloaded_size / (1024 * 1024)
total_size_mb = total_size / (1024 * 1024)
print(f"\rDownloading {filename}... {progress}% complete "
f"({downloaded_size_mb:.2f}/{total_size_mb:.2f} MB)", end="", flush=True)
print(f"\rDownloading {filename}... Complete!")
print('Repo Downloaded') yield("Getting the download links from Hugging Face")
except ValueError as e: links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
raise ValueError("Please enter a valid repo ID. Error: {}".format(e))
yield("Getting the output folder")
output_folder = downloader.get_output_folder(model, branch, is_lora)
if check:
yield("Checking previously downloaded files")
downloader.check_model_files(model, branch, links, sha256, output_folder)
else:
yield("Downloading files")
downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1)
yield("Done!")
except:
yield traceback.format_exc()
def create_model_menus(): def create_model_menus():
@@ -241,17 +216,20 @@ def create_model_menus():
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA') shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button') ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
with gr.Row(): with gr.Row():
with gr.Column(scale=0.5): with gr.Column():
shared.gradio['custom_model_menu'] = gr.Textbox(label="Download Custom Model", with gr.Row():
info="Enter hugging face username/model path e.g: 'decapoda-research/llama-7b-hf'") with gr.Column():
with gr.Row(): shared.gradio['custom_model_menu'] = gr.Textbox(label="Download Custom Model",
with gr.Column(scale=0.5): info="Enter Hugging Face username/model path e.g: facebook/galactica-125m")
shared.gradio['download_button'] = gr.Button("Download", show_progress=True) with gr.Column():
shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['download_button'] = gr.Button("Download", show_progress=True)
show_progress=True) shared.gradio['download_status'] = gr.Markdown()
with gr.Column():
pass
shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True) shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True) shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['download_status'], show_progress=False)
def create_settings_menus(default_preset): def create_settings_menus(default_preset):