From 9215e281ba6dea9855e49419458d566043012ca2 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Fri, 3 Feb 2023 18:57:12 -0300 Subject: [PATCH] Add --threads option to the download script --- download-model.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/download-model.py b/download-model.py index 220302d9..49b20eb3 100644 --- a/download-model.py +++ b/download-model.py @@ -18,12 +18,16 @@ import re parser = argparse.ArgumentParser() parser.add_argument('MODEL', type=str) parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.') +parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.') args = parser.parse_args() def get_file(args): url = args[0] output_folder = args[1] + idx = args[2] + tot = args[3] + print(f"Downloading file {idx} of {tot}...") r = requests.get(url, stream=True) with open(output_folder / Path(url.split('/')[-1]), 'wb') as f: total_size = int(r.headers.get('content-length', 0)) @@ -77,8 +81,8 @@ if __name__ == '__main__': downloads.append(f'https://huggingface.co/{href}') # Downloading the files - print(f"Downloading the model to {output_folder}...") - pool = multiprocessing.Pool(processes=4) - results = pool.map(get_file, [[downloads[i], output_folder] for i in range(len(downloads))]) + print(f"Downloading the model to {output_folder}") + pool = multiprocessing.Pool(processes=args.threads) + results = pool.map(get_file, [[downloads[i], output_folder, i+1, len(downloads)] for i in range(len(downloads))]) pool.close() pool.join()