Downloader: use HF get_token function (#5381)

This commit is contained in:
Anthony Guijarro 2024-01-27 14:13:09 -06:00 committed by GitHub
parent de387069da
commit 828be63f2c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -20,6 +20,7 @@ import requests
import tqdm import tqdm
from requests.adapters import HTTPAdapter from requests.adapters import HTTPAdapter
from tqdm.contrib.concurrent import thread_map from tqdm.contrib.concurrent import thread_map
from huggingface_hub import get_token
base = "https://huggingface.co" base = "https://huggingface.co"
@ -32,8 +33,8 @@ class ModelDownloader:
self.session.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries)) self.session.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries))
if os.getenv('HF_USER') is not None and os.getenv('HF_PASS') is not None: if os.getenv('HF_USER') is not None and os.getenv('HF_PASS') is not None:
self.session.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS')) self.session.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS'))
if os.getenv('HF_TOKEN') is not None: if get_token() is not None:
self.session.headers = {'authorization': f'Bearer {os.getenv("HF_TOKEN")}'} self.session.headers = {'authorization': f'Bearer {get_token()}'}
def sanitize_model_and_branch_names(self, model, branch): def sanitize_model_and_branch_names(self, model, branch):
if model[-1] == '/': if model[-1] == '/':