diff --git a/download-model.py b/download-model.py index b36865d7..c448292c 100644 --- a/download-model.py +++ b/download-model.py @@ -47,7 +47,7 @@ class ModelDownloader: return model, branch - def get_download_links_from_huggingface(self, model, branch, text_only=False): + def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None): base = "https://huggingface.co" page = f"/api/models/{model}/tree/{branch}" cursor = b"" @@ -73,6 +73,9 @@ class ModelDownloader: for i in range(len(dict)): fname = dict[i]['path'] + if specific_file is not None and fname != specific_file: + continue + if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')): is_lora = True @@ -126,12 +129,16 @@ class ModelDownloader: if classifications[i] == 'ggml': links.pop(i) - return links, sha256, is_lora + return links, sha256, is_lora, ((has_ggml or has_gguf) and specific_file is not None) - def get_output_folder(self, model, branch, is_lora, base_folder=None): + def get_output_folder(self, model, branch, is_lora, is_llamacpp=False, base_folder=None): if base_folder is None: base_folder = 'models' if not is_lora else 'loras' + # If the model is of type GGUF or GGML, save directly in the base_folder + if is_llamacpp: + return Path(base_folder) + output_folder = f"{'_'.join(model.split('/')[-2:])}" if branch != 'main': output_folder += f'_{branch}' @@ -173,7 +180,7 @@ class ModelDownloader: def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=1): thread_map(lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True) - def download_model_files(self, model, branch, links, sha256, output_folder, progress_bar=None, start_from_scratch=False, threads=1): + def download_model_files(self, model, branch, links, sha256, output_folder, progress_bar=None, start_from_scratch=False, threads=1, specific_file=None): self.progress_bar = progress_bar # Creating the folder and writing the metadata @@ -189,8 +196,11 @@ class ModelDownloader: metadata += '\n' (output_folder / 'huggingface-metadata.txt').write_text(metadata) - # Downloading the files - print(f"Downloading the model to {output_folder}") + if specific_file: + print(f"Downloading {specific_file} to {output_folder}") + else: + print(f"Downloading the model to {output_folder}") + self.start_download_threads(links, output_folder, start_from_scratch=start_from_scratch, threads=threads) def check_model_files(self, model, branch, links, sha256, output_folder): @@ -226,6 +236,7 @@ if __name__ == '__main__': parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.') parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.') parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).') + parser.add_argument('--specific-file', type=str, default=None, help='Name of the specific file to download (if not provided, downloads all).') parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.') parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.') parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.') @@ -234,28 +245,29 @@ if __name__ == '__main__': branch = args.branch model = args.MODEL + specific_file = args.specific_file if model is None: print("Error: Please specify the model you'd like to download (e.g. 'python download-model.py facebook/opt-1.3b').") sys.exit() downloader = ModelDownloader(max_retries=args.max_retries) - # Cleaning up the model/branch names + # Clean up the model/branch names try: model, branch = downloader.sanitize_model_and_branch_names(model, branch) except ValueError as err_branch: print(f"Error: {err_branch}") sys.exit() - # Getting the download links from Hugging Face - links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=args.text_only) + # Get the download links from Hugging Face + links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(model, branch, text_only=args.text_only, specific_file=specific_file) - # Getting the output folder - output_folder = downloader.get_output_folder(model, branch, is_lora, base_folder=args.output) + # Get the output folder + output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp, base_folder=args.output) if args.check: # Check previously downloaded files downloader.check_model_files(model, branch, links, sha256, output_folder) else: # Download files - downloader.download_model_files(model, branch, links, sha256, output_folder, threads=args.threads) + downloader.download_model_files(model, branch, links, sha256, output_folder, specific_file=specific_file, threads=args.threads)