From 4f1e96b9e3f9dc1593bed2096a40f68f8dc4d015 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 23 May 2024 20:42:46 -0700 Subject: [PATCH] Downloader: Add --model-dir argument, respect --model-dir in the UI --- download-model.py | 12 ++++++++---- modules/ui_model_menu.py | 8 +++++++- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/download-model.py b/download-model.py index 465c76f9..08b23d08 100644 --- a/download-model.py +++ b/download-model.py @@ -167,8 +167,11 @@ class ModelDownloader: is_llamacpp = has_gguf and specific_file is not None return links, sha256, is_lora, is_llamacpp - def get_output_folder(self, model, branch, is_lora, is_llamacpp=False): - base_folder = 'models' if not is_lora else 'loras' + def get_output_folder(self, model, branch, is_lora, is_llamacpp=False, model_dir=None): + if model_dir: + base_folder = model_dir + else: + base_folder = 'models' if not is_lora else 'loras' # If the model is of type GGUF, save directly in the base_folder if is_llamacpp: @@ -304,7 +307,8 @@ if __name__ == '__main__': parser.add_argument('--threads', type=int, default=4, help='Number of files to download simultaneously.') parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).') parser.add_argument('--specific-file', type=str, default=None, help='Name of the specific file to download (if not provided, downloads all).') - parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.') + parser.add_argument('--output', type=str, default=None, help='Save the model files to this folder.') + parser.add_argument('--model-dir', type=str, default=None, help='Save the model files to a subfolder of this folder instead of the default one (text-generation-webui/models).') parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.') parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.') parser.add_argument('--max-retries', type=int, default=5, help='Max retries count when get error in download time.') @@ -333,7 +337,7 @@ if __name__ == '__main__': if args.output: output_folder = Path(args.output) else: - output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp) + output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp, model_dir=args.model_dir) if args.check: # Check previously downloaded files diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py index 76ac7530..6cda7ecd 100644 --- a/modules/ui_model_menu.py +++ b/modules/ui_model_menu.py @@ -290,7 +290,13 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur return yield ("Getting the output folder") - output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp) + output_folder = downloader.get_output_folder( + model, + branch, + is_lora, + is_llamacpp=is_llamacpp, + model_dir=shared.args.model_dir if shared.args.model_dir != shared.args_defaults.model_dir else None + ) if output_folder == Path("models"): output_folder = Path(shared.args.model_dir)