From 91aa5b460ed1f330e35b02fd7f5368912ea6526c Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 28 Mar 2023 13:08:38 -0300 Subject: [PATCH] If both .pt and .safetensors are present, download only safetensors --- download-model.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/download-model.py b/download-model.py index 25386e5f..dce7e749 100644 --- a/download-model.py +++ b/download-model.py @@ -100,6 +100,7 @@ def get_download_links_from_huggingface(model, branch): links = [] classifications = [] has_pytorch = False + has_pt = False has_safetensors = False is_lora = False while True: @@ -115,7 +116,7 @@ def get_download_links_from_huggingface(model, branch): is_lora = True is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname) - is_safetensors = re.match("model.*\.safetensors", fname) + is_safetensors = re.match(".*\.safetensors", fname) is_pt = re.match(".*\.pt", fname) is_tokenizer = re.match("tokenizer.*\.model", fname) is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer @@ -134,6 +135,7 @@ def get_download_links_from_huggingface(model, branch): has_pytorch = True classifications.append('pytorch') elif is_pt: + has_pt = True classifications.append('pt') cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50' @@ -141,9 +143,9 @@ def get_download_links_from_huggingface(model, branch): cursor = cursor.replace(b'=', b'%3D') # If both pytorch and safetensors are available, download safetensors only - if has_pytorch and has_safetensors: + if (has_pytorch or has_pt) and has_safetensors: for i in range(len(classifications)-1, -1, -1): - if classifications[i] == 'pytorch': + if classifications[i] in ['pytorch', 'pt']: links.pop(i) return links, is_lora