From ee65f4f014b859edcf20a298d9039dc494b57986 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 30 Jan 2024 09:14:11 -0800 Subject: [PATCH] Downloader: don't assume that huggingface_hub is installed --- download-model.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/download-model.py b/download-model.py index 5e62036f..82e956d6 100644 --- a/download-model.py +++ b/download-model.py @@ -20,7 +20,6 @@ import requests import tqdm from requests.adapters import HTTPAdapter from tqdm.contrib.concurrent import thread_map -from huggingface_hub import get_token base = "https://huggingface.co" @@ -31,10 +30,18 @@ class ModelDownloader: if max_retries: self.session.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=max_retries)) self.session.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries)) + if os.getenv('HF_USER') is not None and os.getenv('HF_PASS') is not None: self.session.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS')) - if get_token() is not None: - self.session.headers = {'authorization': f'Bearer {get_token()}'} + + try: + from huggingface_hub import get_token + token = get_token() + except ImportError: + token = os.getenv("HF_TOKEN") + + if token is not None: + self.session.headers = {'authorization': f'Bearer {token}'} def sanitize_model_and_branch_names(self, model, branch): if model[-1] == '/':