# text-generation-webui/download-model.py
# 2023-07-11 18:46:59 -03:00
#
# 249 lines
# 10 KiB
# Python

'''
Downloads models from Hugging Face to models/username_modelname
(LoRA adapters are saved to loras/username_modelname instead).

Example:
python download-model.py facebook/opt-1.3b
'''
import argparse
import base64
import datetime
import hashlib
import json
import os
import re
import sys
from pathlib import Path
import requests
import tqdm
from requests.adapters import HTTPAdapter
from tqdm.contrib.concurrent import thread_map
class ModelDownloader:
    """Download model repositories from the Hugging Face Hub.

    Lists a repository's files through the Hub tree API, selects the ones
    worth fetching (text/tokenizer files and weights), downloads them with
    resume support, and can validate the SHA-256 checksums of previously
    downloaded files.
    """

    # Filename patterns used to classify repository files. They are compiled
    # once instead of per file, and written as raw strings so "\." is not an
    # invalid escape sequence.
    _PYTORCH_RE = re.compile(r"(pytorch|adapter|gptq)_model.*\.bin")
    _SAFETENSORS_RE = re.compile(r".*\.safetensors")
    _PT_RE = re.compile(r".*\.pt")
    _GGML_RE = re.compile(r".*ggml.*\.bin")
    _TOKENIZER_RE = re.compile(r"(tokenizer|ice|spiece).*\.model")
    _TEXT_RE = re.compile(r".*\.(txt|json|py|md)")
    # Allowed branch names; used with fullmatch so a trailing "\n" is rejected.
    _BRANCH_RE = re.compile(r"[a-zA-Z0-9._-]+")

    def __init__(self, max_retries=5):
        """Create the HTTP session used for all Hub requests.

        Args:
            max_retries: Retry count mounted on the huggingface.co and
                cdn-lfs.huggingface.co adapters. A falsy value keeps the
                session's default adapters.
        """
        self.s = requests.Session()
        if max_retries:
            self.s.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=max_retries))
            self.s.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries))

        # HTTP basic auth from HF_USER / HF_PASS, only when both are set.
        hf_user = os.getenv('HF_USER')
        hf_pass = os.getenv('HF_PASS')
        if hf_user is not None and hf_pass is not None:
            self.s.auth = (hf_user, hf_pass)

        # Optional callback(fraction, message) used by get_single_file to report
        # progress. It is initialised here so get_single_file works even when
        # called without download_model_files.
        self.progress_bar = None

    def sanitize_model_and_branch_names(self, model, branch):
        """Normalize the model id and validate the branch name.

        Args:
            model: Repository id such as "facebook/opt-1.3b"; trailing
                slashes are removed.
            branch: Branch name, or None for "main".

        Returns:
            The (model, branch) tuple.

        Raises:
            ValueError: If the model id is empty, or the branch name contains
                characters other than letters, digits, ".", "_" and "-".
        """
        model = model.rstrip('/')
        if not model:
            raise ValueError("Invalid model name. The model name must not be empty.")

        if branch is None:
            branch = "main"
        elif not self._BRANCH_RE.fullmatch(branch):
            raise ValueError(
                "Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")

        return model, branch

    @classmethod
    def _classify_file(cls, fname):
        """Return 'text', 'safetensors', 'pytorch', 'pt', 'ggml' or None for fname.

        Text (including tokenizer models) wins over every other kind, and
        the weight kinds are checked in the order listed.
        """
        if cls._TEXT_RE.match(fname) or cls._TOKENIZER_RE.match(fname):
            return 'text'
        if cls._SAFETENSORS_RE.match(fname):
            return 'safetensors'
        if cls._PYTORCH_RE.match(fname):
            return 'pytorch'
        if cls._PT_RE.match(fname):
            return 'pt'
        if cls._GGML_RE.match(fname):
            return 'ggml'
        return None

    @classmethod
    def _select_files(cls, entries, text_only=False):
        """Choose which tree-API entries should be downloaded.

        Args:
            entries: Listing entries, each a dict with a 'path' key and,
                for LFS files, an 'lfs' dict holding an 'oid'.
            text_only: Skip every non-text (weight) file.

        Returns:
            (files, sha256, is_lora):
            - files: selected paths, in listing order.
            - sha256: [path, oid] pairs, only for selected LFS files.
            - is_lora: True if any path ends with adapter_config.json or
              adapter_model.bin.
        """
        is_lora = False
        selected = []   # (path, classification), in listing order
        lfs_oids = []   # (path, oid) for selected entries carrying LFS info
        for entry in entries:
            fname = entry['path']
            if fname.endswith(('adapter_config.json', 'adapter_model.bin')):
                is_lora = True

            kind = cls._classify_file(fname)
            if kind is None or (text_only and kind != 'text'):
                continue

            selected.append((fname, kind))
            if 'lfs' in entry:
                lfs_oids.append((fname, entry['lfs']['oid']))

        # If both pytorch and safetensors are available, download safetensors only
        kinds = {kind for _, kind in selected}
        if 'safetensors' in kinds and kinds & {'pytorch', 'pt'}:
            selected = [(fname, kind) for fname, kind in selected if kind not in ('pytorch', 'pt')]

        files = [fname for fname, _ in selected]
        # Only report checksums for files that will actually be downloaded.
        # Otherwise --check flags skipped weights as missing.
        kept = set(files)
        sha256 = [[fname, oid] for fname, oid in lfs_oids if fname in kept]
        return files, sha256, is_lora

    def _list_repo_tree(self, model, branch):
        """Return every entry of /api/models/{model}/tree/{branch}, across pages.

        Raises:
            requests.HTTPError: If the Hub API returns an error status.
        """
        base = "https://huggingface.co"
        page = f"/api/models/{model}/tree/{branch}"
        cursor = b""
        entries = []
        while True:
            url = f"{base}{page}" + (f"?cursor={cursor.decode()}" if cursor else "")
            r = self.s.get(url, timeout=20)
            r.raise_for_status()
            batch = r.json()
            if not batch:
                return entries

            entries.extend(batch)

            # Build the next page's cursor from this page's last path: base64 of
            # (base64 of the JSON object + b':50'), with '=' percent-encoded.
            cursor = base64.b64encode(f'{{"file_name":"{batch[-1]["path"]}"}}'.encode()) + b':50'
            cursor = base64.b64encode(cursor)
            cursor = cursor.replace(b'=', b'%3D')

    def get_download_links_from_huggingface(self, model, branch, text_only=False):
        """List the URLs of the files of ``model`` at ``branch`` to download.

        Selects text/tokenizer files and, unless ``text_only`` is set, weight
        files (safetensors, pytorch .bin, .pt, ggml .bin). When safetensors
        weights exist, pytorch/.pt weights are skipped.

        Args:
            model: Repository id, e.g. "facebook/opt-1.3b".
            branch: Branch to list and download from.
            text_only: Only return text/tokenizer files.

        Returns:
            (links, sha256, is_lora): resolve URLs of the selected files,
            [filename, sha256] pairs for the selected LFS files, and whether
            the repository contains LoRA adapter files.

        Raises:
            requests.HTTPError: If the Hub API returns an error status.
        """
        entries = self._list_repo_tree(model, branch)
        files, sha256, is_lora = self._select_files(entries, text_only=text_only)
        links = [f"https://huggingface.co/{model}/resolve/{branch}/{fname}" for fname in files]
        return links, sha256, is_lora

    def get_output_folder(self, model, branch, is_lora, base_folder=None):
        """Return the local folder a model should be saved to.

        The folder name joins the last two components of ``model`` with "_"
        and appends "_{branch}" for non-main branches.

        Args:
            model: Repository id, e.g. "facebook/opt-1.3b".
            branch: Branch being downloaded.
            is_lora: Selects "loras" instead of "models" when base_folder is None.
            base_folder: Parent directory that overrides the default.

        Returns:
            A Path such as models/facebook_opt-1.3b.
        """
        if base_folder is None:
            base_folder = 'loras' if is_lora else 'models'

        folder_name = '_'.join(model.split('/')[-2:])
        if branch != 'main':
            folder_name += f'_{branch}'

        return Path(base_folder) / folder_name

    def get_single_file(self, url, output_folder, start_from_scratch=False):
        """Download one file into ``output_folder``, resuming when possible.

        An existing file at least as large as the server-reported
        content-length is treated as complete and skipped. A smaller one is
        resumed with an HTTP Range request unless ``start_from_scratch`` is
        set. If the server ignores the Range header (any status other than
        206 Partial Content), the file is rewritten from the start. Appending
        a full body to partial data would corrupt it.

        Args:
            url: File URL; its last path segment becomes the local filename.
            output_folder: Destination directory (a Path).
            start_from_scratch: Ignore any existing file and download anew.

        Raises:
            requests.HTTPError: If the download request fails.
        """
        filename = Path(url.rsplit('/', 1)[1])
        output_path = output_folder / filename
        headers = {}
        mode = 'wb'

        if output_path.exists() and not start_from_scratch:
            # Check if the file has already been downloaded completely. Only the
            # headers are needed, so close the streamed response immediately
            # instead of leaking its pooled connection.
            with self.s.get(url, stream=True, timeout=20) as r:
                total_size = int(r.headers.get('content-length', 0))

            existing_size = output_path.stat().st_size
            if existing_size >= total_size:
                return

            # Otherwise, resume the download from where it left off
            headers = {'Range': f'bytes={existing_size}-'}
            mode = 'ab'

        with self.s.get(url, stream=True, headers=headers, timeout=20) as r:
            r.raise_for_status()  # Do not continue the download if the request was unsuccessful
            if mode == 'ab' and r.status_code != 206:
                # Range was ignored and the whole file is coming back: overwrite.
                mode = 'wb'

            total_size = int(r.headers.get('content-length', 0))
            block_size = 1024 * 1024  # 1MB
            with open(output_path, mode) as f:
                with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
                    count = 0
                    for data in r.iter_content(block_size):
                        t.update(len(data))
                        f.write(data)
                        if total_size != 0 and self.progress_bar is not None:
                            count += len(data)
                            self.progress_bar(float(count) / float(total_size), f"Downloading {filename}")

    def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=1):
        """Download every URL in ``file_list`` with up to ``threads`` workers.

        ``disable=True`` hides thread_map's own aggregate progress bar; each
        file still shows the per-file bar from get_single_file.
        """
        thread_map(
            lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch),
            file_list,
            max_workers=threads,
            disable=True,
        )

    def download_model_files(self, model, branch, links, sha256, output_folder, progress_bar=None, start_from_scratch=False, threads=1):
        """Write huggingface-metadata.txt and download ``links`` into ``output_folder``.

        Args:
            model: Repository id (recorded in the metadata file).
            branch: Branch name (recorded in the metadata file).
            links: URLs to download.
            sha256: [filename, sha256] pairs listed in the metadata file.
            output_folder: Destination directory; created if missing.
            progress_bar: Optional callable invoked as
                progress_bar(fraction, message) while downloading.
            start_from_scratch: Re-download files instead of resuming or
                skipping existing ones.
            threads: Number of files downloaded concurrently.
        """
        self.progress_bar = progress_bar

        # Creating the folder and writing the metadata
        output_folder.mkdir(parents=True, exist_ok=True)
        metadata = (
            f'url: https://huggingface.co/{model}\n'
            f'branch: {branch}\n'
            f'download date: {datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}\n'
        )

        sha256_str = '\n'.join([f' {item[1]} {item[0]}' for item in sha256])
        if sha256_str:
            metadata += f'sha256sum:\n{sha256_str}'

        metadata += '\n'
        (output_folder / 'huggingface-metadata.txt').write_text(metadata, encoding='utf-8')

        # Downloading the files
        print(f"Downloading the model to {output_folder}")
        self.start_download_threads(links, output_folder, start_from_scratch=start_from_scratch, threads=threads)

    def check_model_files(self, model, branch, links, sha256, output_folder):
        """Validate downloaded files against their expected SHA-256 digests.

        Files are hashed in 1 MiB chunks, so multi-gigabyte weights are never
        loaded into memory at once. A line is printed per file and a summary
        at the end.

        Args:
            model, branch, links: Unused; kept for interface compatibility.
            sha256: [filename, expected_hex_digest] pairs.
            output_folder: Directory containing the downloaded files (a Path).

        Returns:
            True if every listed file exists and matches, False otherwise.
        """
        validated = True
        for fname, expected in sha256:
            fpath = output_folder / fname
            if not fpath.exists():
                print(f"The following file is missing: {fpath}")
                validated = False
                continue

            digest = hashlib.sha256()
            with open(fpath, "rb") as f:
                for chunk in iter(lambda: f.read(1024 * 1024), b""):
                    digest.update(chunk)

            if digest.hexdigest() != expected:
                print(f'Checksum failed: {fname} {expected}')
                validated = False
            else:
                print(f'Checksum validated: {fname} {expected}')

        if validated:
            print('[+] Validated checksums of all model files!')
        else:
            print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')

        return validated
if __name__ == '__main__':
    # Command-line entry point: list the repository's files, then either
    # download them or, with --check, validate previously downloaded ones.
    parser = argparse.ArgumentParser()
    parser.add_argument('MODEL', type=str, default=None, nargs='?')
    parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
    parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.')
    parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
    parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
    parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
    parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
    parser.add_argument('--max-retries', type=int, default=5, help='Max retries count when get error in download time.')
    args = parser.parse_args()
    if args.threads < 1:
        # thread_map cannot run with zero or negative workers.
        parser.error('--threads must be at least 1.')

    branch = args.branch
    model = args.MODEL
    if model is None:
        print("Error: Please specify the model you'd like to download (e.g. 'python download-model.py facebook/opt-1.3b').")
        sys.exit(1)

    downloader = ModelDownloader(max_retries=args.max_retries)

    # Cleaning up the model/branch names
    try:
        model, branch = downloader.sanitize_model_and_branch_names(model, branch)
    except ValueError as err:
        print(f"Error: {err}")
        sys.exit(1)

    # Getting the download links from Hugging Face. Report Hub failures
    # (unknown model, auth errors, network problems) without a traceback.
    try:
        links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=args.text_only)
    except requests.exceptions.RequestException as err:
        print(f"Error: could not list the files of {model} (branch {branch}): {err}")
        sys.exit(1)

    # Getting the output folder
    output_folder = downloader.get_output_folder(model, branch, is_lora, base_folder=args.output)

    if args.check:
        # Check previously downloaded files
        downloader.check_model_files(model, branch, links, sha256, output_folder)
    else:
        # Download files. --clean re-downloads every file instead of resuming
        # partial downloads or skipping complete ones.
        downloader.download_model_files(model, branch, links, sha256, output_folder, start_from_scratch=args.clean, threads=args.threads)