Merge pull request #618 from nikita-skakun/optimize-download-model

Improve download-model.py progress bar with multiple threads
oobabooga 2023-03-29 20:54:19 -03:00 committed by GitHub
commit 9104164297
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

download-model.py

@@ -10,13 +10,13 @@ import argparse
 import base64
 import datetime
 import json
-import multiprocessing
 import re
 import sys
 from pathlib import Path
 
 import requests
 import tqdm
+from tqdm.contrib.concurrent import thread_map
 
 parser = argparse.ArgumentParser()
 parser.add_argument('MODEL', type=str, default=None, nargs='?')
@@ -26,22 +26,15 @@ parser.add_argument('--text-only', action='store_true', help='Only download text
 parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
 args = parser.parse_args()
 
-def get_file(args):
-    url = args[0]
-    output_folder = args[1]
-    idx = args[2]
-    tot = args[3]
-
-    print(f"Downloading file {idx} of {tot}...")
+def get_file(url, output_folder):
     r = requests.get(url, stream=True)
-    with open(output_folder / Path(url.split('/')[-1]), 'wb') as f:
+    with open(output_folder / Path(url.rsplit('/', 1)[1]), 'wb') as f:
         total_size = int(r.headers.get('content-length', 0))
         block_size = 1024
-        t = tqdm.tqdm(total=total_size, unit='iB', unit_scale=True)
-        for data in r.iter_content(block_size):
-            t.update(len(data))
-            f.write(data)
-        t.close()
+        with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
+            for data in r.iter_content(block_size):
+                t.update(len(data))
+                f.write(data)
 
 def sanitize_branch_name(branch_name):
     pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
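
For reference, the rewritten get_file follows the standard requests streaming pattern: fetch with stream=True, read Content-Length for the bar total, and write the response in 1024-byte chunks while a per-file tqdm bar tracks progress, now managed by a with-block instead of an explicit close. A minimal standalone sketch of the same pattern, where the fetch name and the added raise_for_status call are illustrative rather than part of the commit:

    from pathlib import Path

    import requests
    import tqdm


    def fetch(url, output_folder):
        # Stream the response so the whole file is never held in memory at once.
        r = requests.get(url, stream=True)
        r.raise_for_status()  # not in the commit; added so the sketch fails loudly on bad URLs
        total_size = int(r.headers.get('content-length', 0))
        # rsplit('/', 1)[1] keeps only the final URL segment as the local file name.
        filename = Path(url.rsplit('/', 1)[1])
        block_size = 1024
        with open(Path(output_folder) / filename, 'wb') as f:
            # The context manager closes the bar even if the download raises,
            # which is what replaces the old explicit t.close().
            with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True) as t:
                for data in r.iter_content(block_size):
                    t.update(len(data))
                    f.write(data)
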
@@ -152,6 +145,9 @@ def get_download_links_from_huggingface(model, branch):
     return links, is_lora
 
 
+def download_files(file_list, output_folder, num_threads=8):
+    thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, verbose=False)
+
 if __name__ == '__main__':
     model = args.MODEL
     branch = args.branch
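
The new download_files helper hands the per-file work to tqdm.contrib.concurrent.thread_map, which runs the callable on a thread pool, accepts max_workers alongside the usual tqdm keyword arguments, and returns the results as a list in input order. A rough illustration of that dispatch, with a dummy slow_square task standing in for the real download (all names and numbers here are placeholders):

    import time

    from tqdm.contrib.concurrent import thread_map


    def slow_square(n):
        # Stand-in for an I/O-bound job such as downloading one file.
        time.sleep(0.1)
        return n * n


    # Maps slow_square over the inputs on up to 8 threads and shows a single
    # aggregate progress bar for the batch.
    results = thread_map(slow_square, range(32), max_workers=8)
    print(results[:5])
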
@@ -192,7 +188,4 @@ if __name__ == '__main__':
 
     # Downloading the files
     print(f"Downloading the model to {output_folder}")
-    pool = multiprocessing.Pool(processes=args.threads)
-    results = pool.map(get_file, [[links[i], output_folder, i+1, len(links)] for i in range(len(links))])
-    pool.close()
-    pool.join()
+    download_files(links, output_folder, args.threads)
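
With the multiprocessing pool gone, the number of parallel downloads comes straight from args.threads, so an invocation might look like the line below; the model name is an arbitrary example, and --threads is assumed to be defined with the script's other arguments, as the args.threads reference implies:

    python download-model.py facebook/opt-1.3b --threads 8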