Add a retry mechanism to the model downloader (#5943)

This commit is contained in:
oobabooga 2024-04-27 12:25:28 -03:00 committed by GitHub
parent dfdb6fee22
commit 5770e06c48
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -15,10 +15,12 @@ import os
import re
import sys
from pathlib import Path
from time import sleep
import requests
import tqdm
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError, RequestException, Timeout
from tqdm.contrib.concurrent import thread_map
base = os.environ.get("HF_ENDPOINT") or "https://huggingface.co"
@ -177,25 +179,30 @@ class ModelDownloader:
return output_folder
def get_single_file(self, url, output_folder, start_from_scratch=False):
session = self.get_session()
filename = Path(url.rsplit('/', 1)[1])
output_path = output_folder / filename
max_retries = 7
attempt = 0
while attempt < max_retries:
attempt += 1
session = self.get_session()
headers = {}
mode = 'wb'
if output_path.exists() and not start_from_scratch:
# Check if the file has already been downloaded completely
r = session.get(url, stream=True, timeout=10)
if output_path.exists() and not start_from_scratch:
# Resume download
r = session.get(url, stream=True, timeout=20)
total_size = int(r.headers.get('content-length', 0))
if output_path.stat().st_size >= total_size:
return
# Otherwise, resume the download from where it left off
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
mode = 'ab'
with session.get(url, stream=True, headers=headers, timeout=10) as r:
r.raise_for_status() # Do not continue the download if the request was unsuccessful
try:
with session.get(url, stream=True, headers=headers, timeout=30) as r:
r.raise_for_status() # If status is not 2xx, raise an error
total_size = int(r.headers.get('content-length', 0))
block_size = 1024 * 1024 # 1MB
@ -203,7 +210,7 @@ class ModelDownloader:
'total': total_size,
'unit': 'iB',
'unit_scale': True,
'bar_format': '{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}'
'bar_format': '{l_bar}{bar}| {n_fmt}/{total_fmt} {rate_fmt}'
}
if 'COLAB_GPU' in os.environ:
@ -216,12 +223,22 @@ class ModelDownloader:
with tqdm.tqdm(**tqdm_kwargs) as t:
count = 0
for data in r.iter_content(block_size):
t.update(len(data))
f.write(data)
t.update(len(data))
if total_size != 0 and self.progress_bar is not None:
count += len(data)
self.progress_bar(float(count) / float(total_size), f"{filename}")
break # Exit loop if successful
except (RequestException, ConnectionError, Timeout) as e:
print(f"Error downloading {filename}: {e}.")
print(f"That was attempt {attempt}/{max_retries}.", end=' ')
if attempt < max_retries:
print(f"Retry begins in {2 ** attempt} seconds.")
sleep(2 ** attempt)
else:
print("Failed to download after the maximum number of attempts.")
def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=4):
    """Download every URL in ``file_list`` into ``output_folder`` using a thread pool.

    Each URL is passed to ``self.get_single_file`` together with
    ``output_folder`` and ``start_from_scratch``. The calls are dispatched
    through ``thread_map`` with at most ``threads`` workers.

    ``disable=True`` is forwarded to ``thread_map``; per the tqdm
    documentation this turns off the aggregate progress bar that
    ``thread_map`` would otherwise draw.

    The results collected by ``thread_map`` are discarded, so this method
    returns ``None``.
    """

    def fetch(url):
        # Bind the shared arguments so thread_map only has to supply the URL.
        return self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch)

    thread_map(fetch, file_list, max_workers=threads, disable=True)