Allow full model URL to be used for download (#3919)

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
kalomaze 2023-09-16 08:06:13 -05:00 committed by GitHub
parent ed6b6411fb
commit 7c9664ed35
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 7 deletions

View File

@ -22,6 +22,9 @@ from requests.adapters import HTTPAdapter
from tqdm.contrib.concurrent import thread_map from tqdm.contrib.concurrent import thread_map
base = "https://huggingface.co"
class ModelDownloader: class ModelDownloader:
def __init__(self, max_retries=5): def __init__(self, max_retries=5):
self.session = requests.Session() self.session = requests.Session()
@ -37,6 +40,13 @@ class ModelDownloader:
if model[-1] == '/': if model[-1] == '/':
model = model[:-1] model = model[:-1]
if model.startswith(base + '/'):
model = model[len(base) + 1:]
model_parts = model.split(":")
model = model_parts[0] if len(model_parts) > 0 else model
branch = model_parts[1] if len(model_parts) > 1 else branch
if branch is None: if branch is None:
branch = "main" branch = "main"
else: else:
@ -48,7 +58,6 @@ class ModelDownloader:
return model, branch return model, branch
def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None): def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None):
base = "https://huggingface.co"
page = f"/api/models/{model}/tree/{branch}" page = f"/api/models/{model}/tree/{branch}"
cursor = b"" cursor = b""

View File

@ -216,18 +216,14 @@ def load_lora_wrapper(selected_loras):
yield ("Successfuly applied the LoRAs") yield ("Successfuly applied the LoRAs")
def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), return_links=False): def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), return_links=False, check=False):
try: try:
downloader_module = importlib.import_module("download-model") downloader_module = importlib.import_module("download-model")
downloader = downloader_module.ModelDownloader() downloader = downloader_module.ModelDownloader()
repo_id_parts = repo_id.split(":")
model = repo_id_parts[0] if len(repo_id_parts) > 0 else repo_id
branch = repo_id_parts[1] if len(repo_id_parts) > 1 else "main"
check = False
progress(0.0) progress(0.0)
yield ("Cleaning up the model/branch names") yield ("Cleaning up the model/branch names")
model, branch = downloader.sanitize_model_and_branch_names(model, branch) model, branch = downloader.sanitize_model_and_branch_names(repo_id, None)
yield ("Getting the download links from Hugging Face") yield ("Getting the download links from Hugging Face")
links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(model, branch, text_only=False, specific_file=specific_file) links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(model, branch, text_only=False, specific_file=specific_file)