Downloader: don't assume that huggingface_hub is installed

This commit is contained in:
oobabooga 2024-01-30 09:14:11 -08:00
parent 89f6036e98
commit ee65f4f014

View File

@ -20,7 +20,6 @@ import requests
import tqdm import tqdm
from requests.adapters import HTTPAdapter from requests.adapters import HTTPAdapter
from tqdm.contrib.concurrent import thread_map from tqdm.contrib.concurrent import thread_map
from huggingface_hub import get_token
base = "https://huggingface.co" base = "https://huggingface.co"
@ -31,10 +30,18 @@ class ModelDownloader:
if max_retries: if max_retries:
self.session.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=max_retries)) self.session.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=max_retries))
self.session.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries)) self.session.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries))
if os.getenv('HF_USER') is not None and os.getenv('HF_PASS') is not None: if os.getenv('HF_USER') is not None and os.getenv('HF_PASS') is not None:
self.session.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS')) self.session.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS'))
if get_token() is not None:
self.session.headers = {'authorization': f'Bearer {get_token()}'} try:
from huggingface_hub import get_token
token = get_token()
except ImportError:
token = os.getenv("HF_TOKEN")
if token is not None:
self.session.headers = {'authorization': f'Bearer {token}'}
def sanitize_model_and_branch_names(self, model, branch): def sanitize_model_and_branch_names(self, model, branch):
if model[-1] == '/': if model[-1] == '/':