Make it possible to download protected HF models from the command line. (#2408)

This commit is contained in:
Morgan Schweers 2023-05-31 20:11:21 -07:00 committed by GitHub
parent 419c34eca4
commit 1aed2b9e52
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 172 additions and 157 deletions

View File

@ -156,7 +156,9 @@ For example:
python download-model.py facebook/opt-1.3b
If you want to download a model manually, note that all you need are the json, txt, and pytorch\*.bin (or model*.safetensors) files. The remaining files are not necessary.
* If you want to download a model manually, note that all you need are the json, txt, and pytorch\*.bin (or model*.safetensors) files. The remaining files are not necessary.
* If you want to download a protected model (one gated behind accepting a license or otherwise private, like `bigcode/starcoder`), you can set the environment variables `HF_USER` to your Hugging Face username and `HF_PASS` to your password or (_as a better option_) to a [User Access Token](https://huggingface.co/settings/tokens). Note that you will need to accept the model terms on the Hugging Face website before starting the download.
#### GGML models

View File

@ -12,6 +12,7 @@ import datetime
import hashlib
import json
import re
import os
import sys
from pathlib import Path
@ -70,20 +71,29 @@ EleutherAI/pythia-1.4b-deduped
return model, branch
def sanitize_model_and_branch_names(model, branch):
class ModelDownloader:
def __init__(self):
    """Create a requests session, attaching HF credentials when provided.

    Reads ``HF_USER`` and ``HF_PASS`` from the environment; when both are
    set, they are installed as HTTP basic auth on the session so every
    request made through this downloader is authenticated (``HF_PASS``
    may be a password or a Hugging Face User Access Token).
    """
    self.s = requests.Session()
    # Read each variable once instead of calling os.getenv() twice per name.
    user = os.getenv('HF_USER')
    password = os.getenv('HF_PASS')
    if user is not None and password is not None:
        self.s.auth = (user, password)
def sanitize_model_and_branch_names(self, model, branch):
    """Normalize a model path and branch name, validating the branch.

    Strips a single trailing ``/`` from the model, defaults the branch to
    ``"main"`` when not given, and rejects branch names containing anything
    other than alphanumerics, ``.``, ``_`` or ``-``.

    Returns:
        tuple: the cleaned ``(model, branch)`` pair.

    Raises:
        ValueError: if the branch name contains disallowed characters.
    """
    if model[-1] == '/':
        model = model[:-1]

    if branch is None:
        branch = "main"
    else:
        pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
        if not pattern.match(branch):
            raise ValueError(
                "Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")

    return model, branch
def get_download_links_from_huggingface(model, branch, text_only=False):
def get_download_links_from_huggingface(self, model, branch, text_only=False):
base = "https://huggingface.co"
page = f"/api/models/{model}/tree/{branch}"
cursor = b""
@ -98,7 +108,7 @@ def get_download_links_from_huggingface(model, branch, text_only=False):
is_lora = False
while True:
url = f"{base}{page}" + (f"?cursor={cursor.decode()}" if cursor else "")
r = requests.get(url, timeout=10)
r = self.s.get(url, timeout=10)
r.raise_for_status()
content = r.content
@ -111,20 +121,21 @@ def get_download_links_from_huggingface(model, branch, text_only=False):
if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')):
is_lora = True
is_pytorch = re.match("(pytorch|adapter|gptq)_model.*\.bin", fname)
is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname)
is_safetensors = re.match(".*\.safetensors", fname)
is_pt = re.match(".*\.pt", fname)
is_ggml = re.match(".*ggml.*\.bin", fname)
is_tokenizer = re.match("(tokenizer|ice).*\.model", fname)
is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer
if any((is_pytorch, is_safetensors, is_pt, is_ggml, is_tokenizer, is_text)):
if 'lfs' in dict[i]:
sha256.append([fname, dict[i]['lfs']['oid']])
if is_text:
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
classifications.append('text')
continue
if not text_only:
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
if is_safetensors:
@ -153,7 +164,7 @@ def get_download_links_from_huggingface(model, branch, text_only=False):
return links, sha256, is_lora
def get_output_folder(model, branch, is_lora, base_folder=None):
def get_output_folder(self, model, branch, is_lora, base_folder=None):
if base_folder is None:
base_folder = 'models' if not is_lora else 'loras'
@ -164,12 +175,12 @@ def get_output_folder(model, branch, is_lora, base_folder=None):
return output_folder
def get_single_file(url, output_folder, start_from_scratch=False):
def get_single_file(self, url, output_folder, start_from_scratch=False):
filename = Path(url.rsplit('/', 1)[1])
output_path = output_folder / filename
if output_path.exists() and not start_from_scratch:
# Check if the file has already been downloaded completely
r = requests.get(url, stream=True, timeout=10)
r = self.s.get(url, stream=True, timeout=10)
total_size = int(r.headers.get('content-length', 0))
if output_path.stat().st_size >= total_size:
return
@ -180,7 +191,7 @@ def get_single_file(url, output_folder, start_from_scratch=False):
headers = {}
mode = 'wb'
r = requests.get(url, stream=True, headers=headers, timeout=10)
r = self.s.get(url, stream=True, headers=headers, timeout=10)
with open(output_path, mode) as f:
total_size = int(r.headers.get('content-length', 0))
block_size = 1024
@ -190,11 +201,11 @@ def get_single_file(url, output_folder, start_from_scratch=False):
f.write(data)
def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=1):
    """Download every URL in ``file_list`` into ``output_folder`` concurrently.

    Each URL is handed to ``self.get_single_file`` through tqdm's
    ``thread_map`` so the shared session (and any HF credentials attached
    to it) is reused per request; ``threads`` bounds the worker count and
    the per-item progress display is disabled in favor of per-file bars.
    """
    thread_map(
        lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch),
        file_list,
        max_workers=threads,
        disable=True,
    )
def download_model_files(model, branch, links, sha256, output_folder, start_from_scratch=False, threads=1):
def download_model_files(self, model, branch, links, sha256, output_folder, start_from_scratch=False, threads=1):
# Creating the folder and writing the metadata
if not output_folder.exists():
output_folder.mkdir(parents=True, exist_ok=True)
@ -210,10 +221,10 @@ def download_model_files(model, branch, links, sha256, output_folder, start_from
# Downloading the files
print(f"Downloading the model to {output_folder}")
start_download_threads(links, output_folder, start_from_scratch=start_from_scratch, threads=threads)
self.start_download_threads(links, output_folder, start_from_scratch=start_from_scratch, threads=threads)
def check_model_files(model, branch, links, sha256, output_folder):
def check_model_files(self, model, branch, links, sha256, output_folder):
# Validate the checksums
validated = True
for i in range(len(sha256)):
@ -256,22 +267,23 @@ if __name__ == '__main__':
if model is None:
model, branch = select_model_from_default_options()
downloader = ModelDownloader()
# Cleaning up the model/branch names
try:
model, branch = sanitize_model_and_branch_names(model, branch)
model, branch = downloader.sanitize_model_and_branch_names(model, branch)
except ValueError as err_branch:
print(f"Error: {err_branch}")
sys.exit()
# Getting the download links from Hugging Face
links, sha256, is_lora = get_download_links_from_huggingface(model, branch, text_only=args.text_only)
links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=args.text_only)
# Getting the output folder
output_folder = get_output_folder(model, branch, is_lora, base_folder=args.output)
output_folder = downloader.get_output_folder(model, branch, is_lora, base_folder=args.output)
if args.check:
# Check previously downloaded files
check_model_files(model, branch, links, sha256, output_folder)
downloader.check_model_files(model, branch, links, sha256, output_folder)
else:
# Download files
download_model_files(model, branch, links, sha256, output_folder, threads=args.threads)
downloader.download_model_files(model, branch, links, sha256, output_folder, threads=args.threads)

View File

@ -184,7 +184,8 @@ def count_tokens(text):
def download_model_wrapper(repo_id):
try:
downloader = importlib.import_module("download-model")
downloader_module = importlib.import_module("download-model")
downloader = downloader_module.ModelDownloader()
repo_id_parts = repo_id.split(":")
model = repo_id_parts[0] if len(repo_id_parts) > 0 else repo_id
branch = repo_id_parts[1] if len(repo_id_parts) > 1 else "main"