Downloader: Add --model-dir argument, respect --model-dir in the UI

This commit is contained in:
oobabooga 2024-05-23 20:42:46 -07:00
parent ad54d524f7
commit 4f1e96b9e3
2 changed files with 15 additions and 5 deletions

View File

@ -167,8 +167,11 @@ class ModelDownloader:
is_llamacpp = has_gguf and specific_file is not None
return links, sha256, is_lora, is_llamacpp
def get_output_folder(self, model, branch, is_lora, is_llamacpp=False):
base_folder = 'models' if not is_lora else 'loras'
def get_output_folder(self, model, branch, is_lora, is_llamacpp=False, model_dir=None):
if model_dir:
base_folder = model_dir
else:
base_folder = 'models' if not is_lora else 'loras'
# If the model is of type GGUF, save directly in the base_folder
if is_llamacpp:
@ -304,7 +307,8 @@ if __name__ == '__main__':
parser.add_argument('--threads', type=int, default=4, help='Number of files to download simultaneously.')
parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
parser.add_argument('--specific-file', type=str, default=None, help='Name of the specific file to download (if not provided, downloads all).')
parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
parser.add_argument('--output', type=str, default=None, help='Save the model files to this folder.')
parser.add_argument('--model-dir', type=str, default=None, help='Save the model files to a subfolder of this folder instead of the default one (text-generation-webui/models).')
parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
parser.add_argument('--max-retries', type=int, default=5, help='Max retries count when get error in download time.')
@ -333,7 +337,7 @@ if __name__ == '__main__':
if args.output:
output_folder = Path(args.output)
else:
output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp)
output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp, model_dir=args.model_dir)
if args.check:
# Check previously downloaded files

View File

@ -290,7 +290,13 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur
return
yield ("Getting the output folder")
output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp)
output_folder = downloader.get_output_folder(
model,
branch,
is_lora,
is_llamacpp=is_llamacpp,
model_dir=shared.args.model_dir if shared.args.model_dir != shared.args_defaults.model_dir else None
)
if output_folder == Path("models"):
output_folder = Path(shared.args.model_dir)