mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
Support specify retry times in download-model.py (#2908)
This commit is contained in:
parent
70a4d5dbcf
commit
be4582be40
@ -18,12 +18,16 @@ from pathlib import Path
|
||||
|
||||
import requests
|
||||
import tqdm
|
||||
from requests.adapters import HTTPAdapter
|
||||
from tqdm.contrib.concurrent import thread_map
|
||||
|
||||
|
||||
class ModelDownloader:
|
||||
def __init__(self):
|
||||
def __init__(self, max_retries):
|
||||
self.s = requests.Session()
|
||||
if max_retries:
|
||||
self.s.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=max_retries))
|
||||
self.s.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries))
|
||||
if os.getenv('HF_USER') is not None and os.getenv('HF_PASS') is not None:
|
||||
self.s.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS'))
|
||||
|
||||
@ -212,6 +216,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
|
||||
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
|
||||
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
|
||||
parser.add_argument('--max-retries', type=int, default=5, help='Max retries count when get error in download time.')
|
||||
args = parser.parse_args()
|
||||
|
||||
branch = args.branch
|
||||
@ -221,7 +226,7 @@ if __name__ == '__main__':
|
||||
print("Error: Please specify the model you'd like to download (e.g. 'python download-model.py facebook/opt-1.3b').")
|
||||
sys.exit()
|
||||
|
||||
downloader = ModelDownloader()
|
||||
downloader = ModelDownloader(max_retries=args.max_retries)
|
||||
# Cleaning up the model/branch names
|
||||
try:
|
||||
model, branch = downloader.sanitize_model_and_branch_names(model, branch)
|
||||
|
Loading…
Reference in New Issue
Block a user