From b4dbbd148574bcf7fc42c896387caa18363144a0 Mon Sep 17 00:00:00 2001 From: 385olt <385olt@gmail.com> Date: Thu, 20 Jul 2023 00:36:49 +0200 Subject: [PATCH] Python bindings: Custom callbacks, chat session improvement, refactoring (#1145) * Added the following features: \n 1) Now prompt_model uses the positional argument callback to return the response tokens. \n 2) Due to the callback argument of prompt_model, prompt_model_streaming only manages the queue and threading now, which reduces duplication of the code. \n 3) Added optional verbose argument to prompt_model which prints out the prompt that is passed to the model. \n 4) Chat sessions can now have a header, i.e. an instruction before the transcript of the conversation. The header is set at the creation of the chat session context. \n 5) generate function now accepts an optional callback. \n 6) When streaming and using chat session, the user doesn't need to save assistant's messages by himself. This is done automatically. * added _empty_response_callback so I don't have to check if callback is None * added docs * now if the callback stop generation, the last token is ignored * fixed type hints, reimplemented chat session header as a system prompt, minor refactoring, docs: removed section about manual update of chat session for streaming * forgot to add some type hints! * keep the config of the model in GPT4All class which is taken from models.json if the download is allowed * During chat sessions, the model-specific systemPrompt and promptTemplate are applied. * implemented the changes * Fixed typing. Now the user can set a prompt template that will be applied even outside of a chat session. The template can also have multiple placeholders that can be filled by passing a dictionary to the generate function * reversed some changes concerning the prompt templates and their functionality * fixed some type hints, changed list[float] to List[Float] * fixed type hints, changed List[Float] to List[float] * fix typo in the comment: Pepare => Prepare --------- Signed-off-by: 385olt <385olt@gmail.com> --- .../python/docs/gpt4all_python.md | 18 -- gpt4all-bindings/python/gpt4all/gpt4all.py | 198 ++++++++++++++---- gpt4all-bindings/python/gpt4all/pyllmodel.py | 143 ++++++------- 3 files changed, 213 insertions(+), 146 deletions(-) diff --git a/gpt4all-bindings/python/docs/gpt4all_python.md b/gpt4all-bindings/python/docs/gpt4all_python.md index 0d179b06..ad3e1c55 100644 --- a/gpt4all-bindings/python/docs/gpt4all_python.md +++ b/gpt4all-bindings/python/docs/gpt4all_python.md @@ -91,22 +91,4 @@ To interact with GPT4All responses as the model generates, use the `streaming = [' Paris', ' is', ' a', ' city', ' that', ' has', ' been', ' a', ' major', ' cultural', ' and', ' economic', ' center', ' for', ' over', ' ', '2', ',', '0', '0'] ``` -#### Streaming and Chat Sessions -When streaming tokens in a chat session, you must manually handle collection and updating of the chat history. - -```python -from gpt4all import GPT4All -model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin") - -with model.chat_session(): - tokens = list(model.generate(prompt='hello', top_k=1, streaming=True)) - model.current_chat_session.append({'role': 'assistant', 'content': ''.join(tokens)}) - - tokens = list(model.generate(prompt='write me a poem about dogs', top_k=1, streaming=True)) - model.current_chat_session.append({'role': 'assistant', 'content': ''.join(tokens)}) - - print(model.current_chat_session) -``` - -### API documentation ::: gpt4all.gpt4all.GPT4All diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py index bf409b51..ed642731 100644 --- a/gpt4all-bindings/python/gpt4all/gpt4all.py +++ b/gpt4all-bindings/python/gpt4all/gpt4all.py @@ -5,7 +5,7 @@ import os import time from contextlib import contextmanager from pathlib import Path -from typing import Dict, Iterable, List, Union, Optional +from typing import Any, Dict, Iterable, List, Union, Optional import requests from tqdm import tqdm @@ -13,7 +13,17 @@ from tqdm import tqdm from . import pyllmodel # TODO: move to config -DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\") +DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace( + "\\", "\\\\" +) + +DEFAULT_MODEL_CONFIG = { + "systemPrompt": "", + "promptTemplate": "### Human: \n{0}\n### Assistant:\n", +} + +ConfigType = Dict[str,str] +MessageType = Dict[str, str] class Embed4All: """ @@ -34,7 +44,7 @@ class Embed4All: def embed( self, text: str - ) -> list[float]: + ) -> List[float]: """ Generate an embedding. @@ -74,17 +84,20 @@ class GPT4All: self.model_type = model_type self.model = pyllmodel.LLModel() # Retrieve model and download if allowed - model_dest = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download) - self.model.load_model(model_dest) + self.config: ConfigType = self.retrieve_model( + model_name, model_path=model_path, allow_download=allow_download + ) + self.model.load_model(self.config["path"]) # Set n_threads if n_threads is not None: self.model.set_thread_count(n_threads) - self._is_chat_session_activated = False - self.current_chat_session = [] + self._is_chat_session_activated: bool = False + self.current_chat_session: List[MessageType] = empty_chat_session() + self._current_prompt_template: str = "{0}" @staticmethod - def list_models() -> Dict: + def list_models() -> List[ConfigType]: """ Fetch model list from https://gpt4all.io/models/models.json. @@ -95,8 +108,11 @@ class GPT4All: @staticmethod def retrieve_model( - model_name: str, model_path: Optional[str] = None, allow_download: bool = True, verbose: bool = True - ) -> str: + model_name: str, + model_path: Optional[str] = None, + allow_download: bool = True, + verbose: bool = True, + ) -> ConfigType: """ Find model file, and if it doesn't exist, download the model. @@ -108,11 +124,25 @@ class GPT4All: verbose: If True (default), print debug messages. Returns: - Model file destination. + Model config. """ model_filename = append_bin_suffix_if_missing(model_name) + # get the config for the model + config: ConfigType = DEFAULT_MODEL_CONFIG + if allow_download: + available_models = GPT4All.list_models() + + for m in available_models: + if model_filename == m["filename"]: + config.update(m) + config["systemPrompt"] = config["systemPrompt"].strip() + config["promptTemplate"] = config["promptTemplate"].replace( + "%1", "{0}", 1 + ) # change to Python-style formatting + break + # Validate download directory if model_path is None: try: @@ -131,31 +161,34 @@ class GPT4All: model_dest = os.path.join(model_path, model_filename).replace("\\", "\\\\") if os.path.exists(model_dest): + config.pop("url", None) + config["path"] = model_dest if verbose: print("Found model file at ", model_dest) - return model_dest # If model file does not exist, download elif allow_download: # Make sure valid model filename before attempting download - available_models = GPT4All.list_models() - selected_model = None - for m in available_models: - if model_filename == m['filename']: - selected_model = m - break - - if selected_model is None: + if "url" not in config: raise ValueError(f"Model filename not in model list: {model_filename}") - url = selected_model.pop('url', None) + url = config.pop("url", None) - return GPT4All.download_model(model_filename, model_path, verbose=verbose, url=url) + config["path"] = GPT4All.download_model( + model_filename, model_path, verbose=verbose, url=url + ) else: raise ValueError("Failed to retrieve model") + return config + @staticmethod - def download_model(model_filename: str, model_path: str, verbose: bool = True, url: Optional[str] = None) -> str: + def download_model( + model_filename: str, + model_path: str, + verbose: bool = True, + url: Optional[str] = None, + ) -> str: """ Download model from https://gpt4all.io. @@ -191,7 +224,7 @@ class GPT4All: except Exception: if os.path.exists(download_path): if verbose: - print('Cleaning up the interrupted download...') + print("Cleaning up the interrupted download...") os.remove(download_path) raise @@ -218,7 +251,8 @@ class GPT4All: n_batch: int = 8, n_predict: Optional[int] = None, streaming: bool = False, - ) -> Union[str, Iterable]: + callback: pyllmodel.ResponseCallbackType = pyllmodel.empty_response_callback, + ) -> Union[str, Iterable[str]]: """ Generate outputs from any GPT4All model. @@ -233,12 +267,14 @@ class GPT4All: n_batch: Number of prompt tokens processed in parallel. Larger values decrease latency but increase resource requirements. n_predict: Equivalent to max_tokens, exists for backwards compatibility. streaming: If True, this method will instead return a generator that yields tokens as the model generates them. + callback: A function with arguments token_id:int and response:str, which receives the tokens from the model as they are generated and stops the generation by returning False. Returns: Either the entire completion or a generator that yields the completion token by token. """ - generate_kwargs = dict( - prompt=prompt, + + # Preparing the model request + generate_kwargs: Dict[str, Any] = dict( temp=temp, top_k=top_k, top_p=top_p, @@ -249,42 +285,87 @@ class GPT4All: ) if self._is_chat_session_activated: + generate_kwargs["reset_context"] = len(self.current_chat_session) == 1 # check if there is only one message, i.e. system prompt self.current_chat_session.append({"role": "user", "content": prompt}) - generate_kwargs['prompt'] = self._format_chat_prompt_template(messages=self.current_chat_session[-1:]) - generate_kwargs['reset_context'] = len(self.current_chat_session) == 1 + + prompt = self._format_chat_prompt_template( + messages = self.current_chat_session[-1:], + default_prompt_header = self.current_chat_session[0]["content"] if generate_kwargs["reset_context"] else "", + ) else: - generate_kwargs['reset_context'] = True + generate_kwargs["reset_context"] = True - if streaming: - return self.model.prompt_model_streaming(**generate_kwargs) - - output = self.model.prompt_model(**generate_kwargs) + # Prepare the callback, process the model response + output_collector: List[MessageType] + output_collector = [{"content": ""}] # placeholder for the self.current_chat_session if chat session is not activated if self._is_chat_session_activated: - self.current_chat_session.append({"role": "assistant", "content": output}) + self.current_chat_session.append({"role": "assistant", "content": ""}) + output_collector = self.current_chat_session - return output + def _callback_wrapper( + callback: pyllmodel.ResponseCallbackType, + output_collector: List[MessageType], + ) -> pyllmodel.ResponseCallbackType: + + def _callback(token_id: int, response: str) -> bool: + nonlocal callback, output_collector + + output_collector[-1]["content"] += response + + return callback(token_id, response) + + return _callback + + # Send the request to the model + if streaming: + return self.model.prompt_model_streaming( + prompt=prompt, + callback=_callback_wrapper(callback, output_collector), + **generate_kwargs, + ) + + self.model.prompt_model( + prompt=prompt, + callback=_callback_wrapper(callback, output_collector), + **generate_kwargs, + ) + + return output_collector[-1]["content"] @contextmanager - def chat_session(self): - ''' + def chat_session( + self, + system_prompt: str = "", + prompt_template: str = "", + ): + """ Context manager to hold an inference optimized chat session with a GPT4All model. - ''' + + Args: + system_prompt: An initial instruction for the model. + prompt_template: Template for the prompts with {0} being replaced by the user message. + """ # Code to acquire resource, e.g.: self._is_chat_session_activated = True - self.current_chat_session = [] + self.current_chat_session = empty_chat_session(system_prompt or self.config["systemPrompt"]) + self._current_prompt_template = prompt_template or self.config["promptTemplate"] try: yield self finally: # Code to release resource, e.g.: self._is_chat_session_activated = False - self.current_chat_session = [] + self.current_chat_session = empty_chat_session() + self._current_prompt_template = "{0}" def _format_chat_prompt_template( - self, messages: List[Dict], default_prompt_header=True, default_prompt_footer=True + self, + messages: List[MessageType], + default_prompt_header: str = "", + default_prompt_footer: str = "", ) -> str: """ - Helper method for building a prompt using template from list of messages. + Helper method for building a prompt from list of messages using the self._current_prompt_template as a template for each message. Args: messages: List of dictionaries. Each dictionary should have a "role" key @@ -296,19 +377,44 @@ class GPT4All: Returns: Formatted prompt. """ - full_prompt = "" + + if isinstance(default_prompt_header, bool): + import warnings + + warnings.warn( + "Using True/False for the 'default_prompt_header' is deprecated. Use a string instead.", + DeprecationWarning, + ) + default_prompt_header = "" + + if isinstance(default_prompt_footer, bool): + import warnings + + warnings.warn( + "Using True/False for the 'default_prompt_footer' is deprecated. Use a string instead.", + DeprecationWarning, + ) + default_prompt_footer = "" + + full_prompt = default_prompt_header + "\n\n" if default_prompt_header != "" else "" for message in messages: if message["role"] == "user": - user_message = "### Human: \n" + message["content"] + "\n### Assistant:\n" + user_message = self._current_prompt_template.format(message["content"]) full_prompt += user_message if message["role"] == "assistant": - assistant_message = message["content"] + '\n' + assistant_message = message["content"] + "\n" full_prompt += assistant_message + full_prompt += "\n\n" + default_prompt_footer if default_prompt_footer != "" else "" + return full_prompt +def empty_chat_session(system_prompt: str = "") -> List[MessageType]: + return [{"role": "system", "content": system_prompt}] + + def append_bin_suffix_if_missing(model_name): if not model_name.endswith(".bin"): model_name += ".bin" diff --git a/gpt4all-bindings/python/gpt4all/pyllmodel.py b/gpt4all-bindings/python/gpt4all/pyllmodel.py index 91395f53..14f35626 100644 --- a/gpt4all-bindings/python/gpt4all/pyllmodel.py +++ b/gpt4all-bindings/python/gpt4all/pyllmodel.py @@ -6,26 +6,19 @@ import re import subprocess import sys import threading -from typing import Iterable +import logging +from typing import Iterable, Callable, List import pkg_resources - -class DualStreamProcessor: - def __init__(self, stream=None): - self.stream = stream - self.output = "" - - def write(self, text): - if self.stream is not None: - self.stream.write(text) - self.stream.flush() - self.output += text +logger: logging.Logger = logging.getLogger(__name__) # TODO: provide a config file to make this more robust LLMODEL_PATH = os.path.join("llmodel_DO_NOT_MODIFY", "build").replace("\\", "\\\\") -MODEL_LIB_PATH = str(pkg_resources.resource_filename("gpt4all", LLMODEL_PATH)).replace("\\", "\\\\") +MODEL_LIB_PATH = str(pkg_resources.resource_filename("gpt4all", LLMODEL_PATH)).replace( + "\\", "\\\\" +) def load_llmodel_library(): @@ -43,9 +36,9 @@ def load_llmodel_library(): c_lib_ext = get_c_shared_lib_extension() - llmodel_file = "libllmodel" + '.' + c_lib_ext + llmodel_file = "libllmodel" + "." + c_lib_ext - llmodel_dir = str(pkg_resources.resource_filename('gpt4all', os.path.join(LLMODEL_PATH, llmodel_file))).replace( + llmodel_dir = str(pkg_resources.resource_filename("gpt4all", os.path.join(LLMODEL_PATH, llmodel_file))).replace( "\\", "\\\\" ) @@ -134,7 +127,15 @@ llmodel.llmodel_set_implementation_search_path.restype = None llmodel.llmodel_threadCount.argtypes = [ctypes.c_void_p] llmodel.llmodel_threadCount.restype = ctypes.c_int32 -llmodel.llmodel_set_implementation_search_path(MODEL_LIB_PATH.encode('utf-8')) +llmodel.llmodel_set_implementation_search_path(MODEL_LIB_PATH.encode("utf-8")) + + +ResponseCallbackType = Callable[[int, str], bool] +RawResponseCallbackType = Callable[[int, bytes], bool] + + +def empty_response_callback(token_id: int, response: str) -> bool: + return True class LLModel: @@ -250,9 +251,10 @@ class LLModel: def generate_embedding( self, text: str - ) -> list[float]: + ) -> List[float]: if not text: raise ValueError("Text must not be None or empty") + embedding_size = ctypes.c_size_t() c_text = ctypes.c_char_p(text.encode('utf-8')) embedding_ptr = llmodel.llmodel_embedding(self.model, c_text, ctypes.byref(embedding_size)) @@ -263,6 +265,7 @@ class LLModel: def prompt_model( self, prompt: str, + callback: ResponseCallbackType, n_predict: int = 4096, top_k: int = 40, top_p: float = 0.9, @@ -272,8 +275,7 @@ class LLModel: repeat_last_n: int = 10, context_erase: float = 0.75, reset_context: bool = False, - streaming=False, - ) -> str: + ): """ Generate response from model from a prompt. @@ -281,26 +283,24 @@ class LLModel: ---------- prompt: str Question, task, or conversation for model to respond to - streaming: bool - Stream response to stdout + callback(token_id:int, response:str): bool + The model sends response tokens to callback Returns ------- - Model response str + None """ - prompt_bytes = prompt.encode('utf-8') + logger.info( + "LLModel.prompt_model -- prompt:\n" + + "%s\n" + + "===/LLModel.prompt_model -- prompt/===", + prompt, + ) + + prompt_bytes = prompt.encode("utf-8") prompt_ptr = ctypes.c_char_p(prompt_bytes) - old_stdout = sys.stdout - - stream_processor = DualStreamProcessor() - - if streaming: - stream_processor.stream = sys.stdout - - sys.stdout = stream_processor - self._set_context( n_predict=n_predict, top_k=top_k, @@ -317,56 +317,37 @@ class LLModel: self.model, prompt_ptr, PromptCallback(self._prompt_callback), - ResponseCallback(self._response_callback), + ResponseCallback(self._callback_decoder(callback)), RecalculateCallback(self._recalculate_callback), self.context, ) - # Revert to old stdout - sys.stdout = old_stdout - # Force new line - return stream_processor.output - def prompt_model_streaming( self, prompt: str, - n_predict: int = 4096, - top_k: int = 40, - top_p: float = 0.9, - temp: float = 0.1, - n_batch: int = 8, - repeat_penalty: float = 1.2, - repeat_last_n: int = 10, - context_erase: float = 0.75, - reset_context: bool = False, - ) -> Iterable: + callback: ResponseCallbackType = empty_response_callback, + **kwargs + ) -> Iterable[str]: # Symbol to terminate from generator TERMINATING_SYMBOL = object() output_queue = queue.Queue() - prompt_bytes = prompt.encode('utf-8') - prompt_ptr = ctypes.c_char_p(prompt_bytes) - - self._set_context( - n_predict=n_predict, - top_k=top_k, - top_p=top_p, - temp=temp, - n_batch=n_batch, - repeat_penalty=repeat_penalty, - repeat_last_n=repeat_last_n, - context_erase=context_erase, - reset_context=reset_context, - ) - # Put response tokens into an output queue - def _generator_response_callback(token_id, response): - output_queue.put(response.decode('utf-8', 'replace')) - return True + def _generator_callback_wrapper(callback: ResponseCallbackType) -> ResponseCallbackType: + def _generator_callback(token_id: int, response: str): + nonlocal callback - def run_llmodel_prompt(model, prompt, prompt_callback, response_callback, recalculate_callback, context): - llmodel.llmodel_prompt(model, prompt, prompt_callback, response_callback, recalculate_callback, context) + if callback(token_id, response): + output_queue.put(response) + return True + + return False + + return _generator_callback + + def run_llmodel_prompt(prompt: str, callback: ResponseCallbackType, **kwargs): + self.prompt_model(prompt, callback, **kwargs) output_queue.put(TERMINATING_SYMBOL) # Kick off llmodel_prompt in separate thread so we can return generator @@ -374,13 +355,10 @@ class LLModel: thread = threading.Thread( target=run_llmodel_prompt, args=( - self.model, - prompt_ptr, - PromptCallback(self._prompt_callback), - ResponseCallback(_generator_response_callback), - RecalculateCallback(self._recalculate_callback), - self.context, + prompt, + _generator_callback_wrapper(callback) ), + kwargs=kwargs, ) thread.start() @@ -391,18 +369,19 @@ class LLModel: break yield response + def _callback_decoder(self, callback: ResponseCallbackType) -> RawResponseCallbackType: + def _raw_callback(token_id: int, response: bytes) -> bool: + nonlocal callback + return callback(token_id, response.decode("utf-8", "replace")) + + return _raw_callback + # Empty prompt callback @staticmethod - def _prompt_callback(token_id): - return True - - # Empty response callback method that just prints response to be collected - @staticmethod - def _response_callback(token_id, response): - sys.stdout.write(response.decode('utf-8', 'replace')) + def _prompt_callback(token_id: int) -> bool: return True # Empty recalculate callback @staticmethod - def _recalculate_callback(is_recalculating): + def _recalculate_callback(is_recalculating: bool) -> bool: return is_recalculating