mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-10-01 01:06:10 -04:00
Add optional verbosity
This commit is contained in:
parent
e05ee9466a
commit
5e61008424
@ -58,7 +58,7 @@ class GPT4All():
|
||||
return requests.get("https://gpt4all.io/models/models.json").json()
|
||||
|
||||
@staticmethod
|
||||
def retrieve_model(model_name: str, model_path: str = None, allow_download: bool = True) -> str:
|
||||
def retrieve_model(model_name: str, model_path: str = None, allow_download: bool = True, verbose: bool = True) -> str:
|
||||
"""
|
||||
Find model file, and if it doesn't exist, download the model.
|
||||
|
||||
@ -67,6 +67,7 @@ class GPT4All():
|
||||
model_path: Path to find model. Default is None in which case path is set to
|
||||
~/.cache/gpt4all/.
|
||||
allow_download: Allow API to download model from gpt4all.io. Default is True.
|
||||
verbose: If True (default), print debug messages.
|
||||
|
||||
Returns:
|
||||
Model file destination.
|
||||
@ -92,6 +93,7 @@ class GPT4All():
|
||||
|
||||
model_dest = os.path.join(model_path, model_filename).replace("\\", "\\\\")
|
||||
if os.path.exists(model_dest):
|
||||
if verbose:
|
||||
print("Found model file at ", model_dest)
|
||||
return model_dest
|
||||
|
||||
@ -106,13 +108,14 @@ class GPT4All():
|
||||
raise ValueError("Failed to retrieve model")
|
||||
|
||||
@staticmethod
|
||||
def download_model(model_filename: str, model_path: str) -> str:
|
||||
def download_model(model_filename: str, model_path: str, verbose: bool) -> str:
|
||||
"""
|
||||
Download model from https://gpt4all.io.
|
||||
|
||||
Args:
|
||||
model_filename: Filename of model (with .bin extension).
|
||||
model_path: Path to download model to.
|
||||
verbose: If True (default), print debug messages.
|
||||
|
||||
Returns:
|
||||
Model file destination.
|
||||
@ -137,6 +140,7 @@ class GPT4All():
|
||||
file.write(data)
|
||||
except Exception:
|
||||
if os.path.exists(download_path):
|
||||
if verbose:
|
||||
print('Cleaning up the interrupted download...')
|
||||
os.remove(download_path)
|
||||
raise
|
||||
@ -150,7 +154,8 @@ class GPT4All():
|
||||
# Sleep for a little bit so OS can remove file lock
|
||||
time.sleep(2)
|
||||
|
||||
print("Model downloaded at: " + download_path)
|
||||
if verbose:
|
||||
print("Model downloaded at: ", download_path)
|
||||
return download_path
|
||||
|
||||
def generate(self, prompt: str, streaming: bool = True, **generate_kwargs) -> str:
|
||||
|
Loading…
Reference in New Issue
Block a user