diff --git a/download-model.py b/download-model.py index 87fd6261..9ee77906 100644 --- a/download-model.py +++ b/download-model.py @@ -18,12 +18,16 @@ from pathlib import Path import requests import tqdm +from requests.adapters import HTTPAdapter from tqdm.contrib.concurrent import thread_map class ModelDownloader: - def __init__(self): + def __init__(self, max_retries): self.s = requests.Session() + if max_retries: + self.s.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=max_retries)) + self.s.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries)) if os.getenv('HF_USER') is not None and os.getenv('HF_PASS') is not None: self.s.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS')) @@ -212,6 +216,7 @@ if __name__ == '__main__': parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.') parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.') parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.') + parser.add_argument('--max-retries', type=int, default=5, help='Max retries count when get error in download time.') args = parser.parse_args() branch = args.branch @@ -221,7 +226,7 @@ if __name__ == '__main__': print("Error: Please specify the model you'd like to download (e.g. 'python download-model.py facebook/opt-1.3b').") sys.exit() - downloader = ModelDownloader() + downloader = ModelDownloader(max_retries=args.max_retries) # Cleaning up the model/branch names try: model, branch = downloader.sanitize_model_and_branch_names(model, branch)