diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py
index 5787d293..26564980 100644
--- a/gpt4all-bindings/python/gpt4all/gpt4all.py
+++ b/gpt4all-bindings/python/gpt4all/gpt4all.py
@@ -169,7 +169,8 @@ class GPT4All():
                         messages: List[Dict],
                         default_prompt_header: bool = True,
                         default_prompt_footer: bool = True,
-                        verbose: bool = True) -> str:
+                        verbose: bool = True,
+                        **generate_kwargs) -> str:
         """
         Format list of message dictionaries into a prompt and call model
         generate on prompt. Returns a response dictionary with metadata and
@@ -201,7 +202,7 @@ class GPT4All():
         if verbose:
             print(full_prompt)
 
-        response = self.model.generate(full_prompt)
+        response = self.model.generate(full_prompt, **generate_kwargs)
 
         if verbose:
             print(response)
@@ -293,7 +294,9 @@ class GPT4All():
         ]
 
         MPT_MODELS = [
-            "ggml-mpt-7b-base.bin"
+            "ggml-mpt-7b-base.bin",
+            "ggml-mpt-7b-chat.bin",
+            "ggml-mpt-7b-instruct.bin"
         ]
 
         if model_name in GPTJ_MODELS:
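A minimal usage sketch of the new keyword forwarding, assuming the chat_completion() API at this commit; the sampling parameter names (n_predict, temp) are assumptions about what the underlying generate() accepts, not something this diff confirms:

    from gpt4all import GPT4All

    # Load one of the newly listed MPT chat models.
    model = GPT4All("ggml-mpt-7b-chat.bin")

    messages = [{"role": "user", "content": "Name three colors."}]

    # Extra keyword arguments are now forwarded by chat_completion()
    # to self.model.generate() via **generate_kwargs.
    # n_predict/temp are assumed parameter names for illustration only.
    response = model.chat_completion(messages, verbose=False, n_predict=128, temp=0.7)
    print(response)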