Add optional verbosity

This commit is contained in:
Konstantin Gukov 2023-05-26 09:53:53 +02:00 committed by Richard Guo
parent e05ee9466a
commit 5e61008424

View File

@@ -58,7 +58,7 @@ class GPT4All():
return requests.get("https://gpt4all.io/models/models.json").json()
@staticmethod
def retrieve_model(model_name: str, model_path: str = None, allow_download: bool = True) -> str:
def retrieve_model(model_name: str, model_path: str = None, allow_download: bool = True, verbose: bool = True) -> str:
"""
Find model file, and if it doesn't exist, download the model.
@@ -67,6 +67,7 @@ class GPT4All():
model_path: Path to find model. Default is None in which case path is set to
~/.cache/gpt4all/.
allow_download: Allow API to download model from gpt4all.io. Default is True.
verbose: If True (default), print debug messages.
Returns:
Model file destination.
@@ -92,6 +93,7 @@ class GPT4All():
model_dest = os.path.join(model_path, model_filename).replace("\\", "\\\\")
if os.path.exists(model_dest):
if verbose:
print("Found model file at ", model_dest)
return model_dest
@@ -106,13 +108,14 @@ class GPT4All():
raise ValueError("Failed to retrieve model")
@staticmethod
def download_model(model_filename: str, model_path: str) -> str:
def download_model(model_filename: str, model_path: str, verbose: bool) -> str:
"""
Download model from https://gpt4all.io.
Args:
model_filename: Filename of model (with .bin extension).
model_path: Path to download model to.
verbose: If True (default), print debug messages.
Returns:
Model file destination.
@@ -137,6 +140,7 @@ class GPT4All():
file.write(data)
except Exception:
if os.path.exists(download_path):
if verbose:
print('Cleaning up the interrupted download...')
os.remove(download_path)
raise
@@ -150,7 +154,8 @@ class GPT4All():
# Sleep for a little bit so OS can remove file lock
time.sleep(2)
print("Model downloaded at: " + download_path)
if verbose:
print("Model downloaded at: ", download_path)
return download_path
def generate(self, prompt: str, streaming: bool = True, **generate_kwargs) -> str: