diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py index a4823670..38f1dfba 100644 --- a/gpt4all-bindings/python/gpt4all/gpt4all.py +++ b/gpt4all-bindings/python/gpt4all/gpt4all.py @@ -87,24 +87,24 @@ class GPT4All(): else: model_path = model_path.replace("\\", "\\\\") - if os.path.exists(model_path): - model_dest = os.path.join(model_path, model_filename).replace("\\", "\\\\") - if os.path.exists(model_dest): - print("Found model file at ", model_dest) - return model_dest + if not os.path.exists(model_path): + raise ValueError("Invalid model directory: {}".format(model_path)) - # If model file does not exist, download - elif allow_download: - # Make sure valid model filename before attempting download - available_models = GPT4All.list_models() - if model_filename not in (m["filename"] for m in available_models): - raise ValueError(f"Model filename not in model list: {model_filename}") - return GPT4All.download_model(model_filename, model_path) - else: - raise ValueError("Failed to retrieve model") + model_dest = os.path.join(model_path, model_filename).replace("\\", "\\\\") + if os.path.exists(model_dest): + print("Found model file at ", model_dest) + return model_dest + + # If model file does not exist, download + elif allow_download: + # Make sure valid model filename before attempting download + available_models = GPT4All.list_models() + if model_filename not in (m["filename"] for m in available_models): + raise ValueError(f"Model filename not in model list: {model_filename}") + return GPT4All.download_model(model_filename, model_path) else: - raise ValueError("Invalid model directory") - + raise ValueError("Failed to retrieve model") + @staticmethod def download_model(model_filename: str, model_path: str) -> str: """