From 4d8e10100686a680026f76e9854be90ef279a797 Mon Sep 17 00:00:00 2001 From: Nikita Skakun Date: Tue, 28 Mar 2023 14:24:23 -0700 Subject: [PATCH 1/4] Refactor download process to use multiprocessing The previous implementation used threads to download files in parallel, which could lead to performance issues due to the Global Interpreter Lock (GIL). This commit refactors the download process to use multiprocessing instead, which allows for true parallelism across multiple CPUs. This results in significantly faster downloads, particularly for large models. --- download-model.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/download-model.py b/download-model.py index dce7e749..48ae449e 100644 --- a/download-model.py +++ b/download-model.py @@ -17,13 +17,6 @@ from pathlib import Path import requests import tqdm -parser = argparse.ArgumentParser() -parser.add_argument('MODEL', type=str, default=None, nargs='?') -parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.') -parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.') -parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).') -args = parser.parse_args() - def get_file(args): url = args[0] output_folder = args[1] @@ -150,7 +143,22 @@ def get_download_links_from_huggingface(model, branch): return links, is_lora +def download_files(file_list, output_folder, num_processes=8): + with multiprocessing.Pool(processes=num_processes) as pool: + args = [(url, output_folder, idx+1, len(file_list)) for idx, url in enumerate(file_list)] + for _ in tqdm.tqdm(pool.imap_unordered(get_file, args), total=len(args)): + pass + pool.close() + pool.join() + if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('MODEL', type=str, default=None, nargs='?') + parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.') + parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.') + parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).') + args = parser.parse_args() + model = args.MODEL branch = args.branch if model is None: @@ -179,7 +187,4 @@ if __name__ == '__main__': # Downloading the files print(f"Downloading the model to {output_folder}") - pool = multiprocessing.Pool(processes=args.threads) - results = pool.map(get_file, [[links[i], output_folder, i+1, len(links)] for i in range(len(links))]) - pool.close() - pool.join() + download_files(links, output_folder, num_processes=args.threads) From ff515ec2fe693cee7ea1d86d5e3f5bf0397aca2f Mon Sep 17 00:00:00 2001 From: Nikita Skakun Date: Tue, 28 Mar 2023 18:29:20 -0700 Subject: [PATCH 2/4] Improve progress bar visual style This commit reverts the performance improvements of the previous commit for for improved visual style of multithreaded progress bars. The style of the progress bar has been modified to take up the same amount of size to align them. --- download-model.py | 31 ++++++++++--------------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/download-model.py b/download-model.py index 48ae449e..2954f4b1 100644 --- a/download-model.py +++ b/download-model.py @@ -16,23 +16,17 @@ from pathlib import Path import requests import tqdm +from tqdm.contrib.concurrent import thread_map -def get_file(args): - url = args[0] - output_folder = args[1] - idx = args[2] - tot = args[3] - - print(f"Downloading file {idx} of {tot}...") +def get_file(url, output_folder): r = requests.get(url, stream=True) - with open(output_folder / Path(url.split('/')[-1]), 'wb') as f: + with open(output_folder / Path(url.rsplit('/', 1)[1]), 'wb') as f: total_size = int(r.headers.get('content-length', 0)) block_size = 1024 - t = tqdm.tqdm(total=total_size, unit='iB', unit_scale=True) - for data in r.iter_content(block_size): - t.update(len(data)) - f.write(data) - t.close() + with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t: + for data in r.iter_content(block_size): + t.update(len(data)) + f.write(data) def sanitize_branch_name(branch_name): pattern = re.compile(r"^[a-zA-Z0-9._-]+$") @@ -143,13 +137,8 @@ def get_download_links_from_huggingface(model, branch): return links, is_lora -def download_files(file_list, output_folder, num_processes=8): - with multiprocessing.Pool(processes=num_processes) as pool: - args = [(url, output_folder, idx+1, len(file_list)) for idx, url in enumerate(file_list)] - for _ in tqdm.tqdm(pool.imap_unordered(get_file, args), total=len(args)): - pass - pool.close() - pool.join() +def download_files(file_list, output_folder, num_threads=8): + thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, verbose=False) if __name__ == '__main__': parser = argparse.ArgumentParser() @@ -187,4 +176,4 @@ if __name__ == '__main__': # Downloading the files print(f"Downloading the model to {output_folder}") - download_files(links, output_folder, num_processes=args.threads) + download_files(links, output_folder, args.threads) From aaa218a10216483b48cec068d73a1f891efb55ec Mon Sep 17 00:00:00 2001 From: Nikita Skakun Date: Tue, 28 Mar 2023 18:32:49 -0700 Subject: [PATCH 3/4] Remove unused import. --- download-model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/download-model.py b/download-model.py index 2954f4b1..a2d3a6d6 100644 --- a/download-model.py +++ b/download-model.py @@ -9,7 +9,6 @@ python download-model.py facebook/opt-1.3b import argparse import base64 import json -import multiprocessing import re import sys from pathlib import Path From 37754164eb44338e9f9bf7642a49cc6f0a9802b9 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 29 Mar 2023 20:47:36 -0300 Subject: [PATCH 4/4] Move argparse --- download-model.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/download-model.py b/download-model.py index f67055ba..dc6f3a9d 100644 --- a/download-model.py +++ b/download-model.py @@ -149,13 +149,6 @@ def download_files(file_list, output_folder, num_threads=8): thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, verbose=False) if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('MODEL', type=str, default=None, nargs='?') - parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.') - parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.') - parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).') - args = parser.parse_args() - model = args.MODEL branch = args.branch if model is None: