mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
Add a retry mechanism to the model downloader (#5943)
This commit is contained in:
parent
dfdb6fee22
commit
5770e06c48
@ -15,10 +15,12 @@ import os
|
|||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from time import sleep
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import tqdm
|
import tqdm
|
||||||
from requests.adapters import HTTPAdapter
|
from requests.adapters import HTTPAdapter
|
||||||
|
from requests.exceptions import ConnectionError, RequestException, Timeout
|
||||||
from tqdm.contrib.concurrent import thread_map
|
from tqdm.contrib.concurrent import thread_map
|
||||||
|
|
||||||
base = os.environ.get("HF_ENDPOINT") or "https://huggingface.co"
|
base = os.environ.get("HF_ENDPOINT") or "https://huggingface.co"
|
||||||
@ -177,25 +179,30 @@ class ModelDownloader:
|
|||||||
return output_folder
|
return output_folder
|
||||||
|
|
||||||
def get_single_file(self, url, output_folder, start_from_scratch=False):
|
def get_single_file(self, url, output_folder, start_from_scratch=False):
|
||||||
session = self.get_session()
|
|
||||||
filename = Path(url.rsplit('/', 1)[1])
|
filename = Path(url.rsplit('/', 1)[1])
|
||||||
output_path = output_folder / filename
|
output_path = output_folder / filename
|
||||||
|
|
||||||
|
max_retries = 7
|
||||||
|
attempt = 0
|
||||||
|
while attempt < max_retries:
|
||||||
|
attempt += 1
|
||||||
|
session = self.get_session()
|
||||||
headers = {}
|
headers = {}
|
||||||
mode = 'wb'
|
mode = 'wb'
|
||||||
if output_path.exists() and not start_from_scratch:
|
|
||||||
|
|
||||||
# Check if the file has already been downloaded completely
|
if output_path.exists() and not start_from_scratch:
|
||||||
r = session.get(url, stream=True, timeout=10)
|
# Resume download
|
||||||
|
r = session.get(url, stream=True, timeout=20)
|
||||||
total_size = int(r.headers.get('content-length', 0))
|
total_size = int(r.headers.get('content-length', 0))
|
||||||
if output_path.stat().st_size >= total_size:
|
if output_path.stat().st_size >= total_size:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Otherwise, resume the download from where it left off
|
|
||||||
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
|
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
|
||||||
mode = 'ab'
|
mode = 'ab'
|
||||||
|
|
||||||
with session.get(url, stream=True, headers=headers, timeout=10) as r:
|
try:
|
||||||
r.raise_for_status() # Do not continue the download if the request was unsuccessful
|
with session.get(url, stream=True, headers=headers, timeout=30) as r:
|
||||||
|
r.raise_for_status() # If status is not 2xx, raise an error
|
||||||
total_size = int(r.headers.get('content-length', 0))
|
total_size = int(r.headers.get('content-length', 0))
|
||||||
block_size = 1024 * 1024 # 1MB
|
block_size = 1024 * 1024 # 1MB
|
||||||
|
|
||||||
@ -203,7 +210,7 @@ class ModelDownloader:
|
|||||||
'total': total_size,
|
'total': total_size,
|
||||||
'unit': 'iB',
|
'unit': 'iB',
|
||||||
'unit_scale': True,
|
'unit_scale': True,
|
||||||
'bar_format': '{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}'
|
'bar_format': '{l_bar}{bar}| {n_fmt}/{total_fmt} {rate_fmt}'
|
||||||
}
|
}
|
||||||
|
|
||||||
if 'COLAB_GPU' in os.environ:
|
if 'COLAB_GPU' in os.environ:
|
||||||
@ -216,12 +223,22 @@ class ModelDownloader:
|
|||||||
with tqdm.tqdm(**tqdm_kwargs) as t:
|
with tqdm.tqdm(**tqdm_kwargs) as t:
|
||||||
count = 0
|
count = 0
|
||||||
for data in r.iter_content(block_size):
|
for data in r.iter_content(block_size):
|
||||||
t.update(len(data))
|
|
||||||
f.write(data)
|
f.write(data)
|
||||||
|
t.update(len(data))
|
||||||
if total_size != 0 and self.progress_bar is not None:
|
if total_size != 0 and self.progress_bar is not None:
|
||||||
count += len(data)
|
count += len(data)
|
||||||
self.progress_bar(float(count) / float(total_size), f"{filename}")
|
self.progress_bar(float(count) / float(total_size), f"{filename}")
|
||||||
|
|
||||||
|
break # Exit loop if successful
|
||||||
|
except (RequestException, ConnectionError, Timeout) as e:
|
||||||
|
print(f"Error downloading {filename}: {e}.")
|
||||||
|
print(f"That was attempt {attempt}/{max_retries}.", end=' ')
|
||||||
|
if attempt < max_retries:
|
||||||
|
print(f"Retry begins in {2 ** attempt} seconds.")
|
||||||
|
sleep(2 ** attempt)
|
||||||
|
else:
|
||||||
|
print("Failed to download after the maximum number of attempts.")
|
||||||
|
|
||||||
def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=4):
|
def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=4):
|
||||||
thread_map(lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True)
|
thread_map(lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user