Support specifying the number of retries in download-model.py (#2908)

This commit is contained in:
AN Long 2023-07-05 09:26:30 +08:00 committed by GitHub
parent 70a4d5dbcf
commit be4582be40
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -18,12 +18,16 @@ from pathlib import Path
import requests
import tqdm
from requests.adapters import HTTPAdapter
from tqdm.contrib.concurrent import thread_map
class ModelDownloader:
def __init__(self):
def __init__(self, max_retries):
self.s = requests.Session()
if max_retries:
self.s.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=max_retries))
self.s.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries))
if os.getenv('HF_USER') is not None and os.getenv('HF_PASS') is not None:
self.s.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS'))
@ -212,6 +216,7 @@ if __name__ == '__main__':
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
parser.add_argument('--max-retries', type=int, default=5, help='Max retries count when get error in download time.')
args = parser.parse_args()
branch = args.branch
@ -221,7 +226,7 @@ if __name__ == '__main__':
print("Error: Please specify the model you'd like to download (e.g. 'python download-model.py facebook/opt-1.3b').")
sys.exit()
downloader = ModelDownloader()
downloader = ModelDownloader(max_retries=args.max_retries)
# Cleaning up the model/branch names
try:
model, branch = downloader.sanitize_model_and_branch_names(model, branch)