mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-10-01 01:06:10 -04:00
Python bindings: Custom callbacks, chat session improvement, refactoring (#1145)
* Added the following features: \n 1) Now prompt_model uses the positional argument callback to return the response tokens. \n 2) Due to the callback argument of prompt_model, prompt_model_streaming only manages the queue and threading now, which reduces duplication of the code. \n 3) Added optional verbose argument to prompt_model which prints out the prompt that is passed to the model. \n 4) Chat sessions can now have a header, i.e. an instruction before the transcript of the conversation. The header is set at the creation of the chat session context. \n 5) generate function now accepts an optional callback. \n 6) When streaming and using chat session, the user doesn't need to save assistant's messages by himself. This is done automatically. * added _empty_response_callback so I don't have to check if callback is None * added docs * now if the callback stop generation, the last token is ignored * fixed type hints, reimplemented chat session header as a system prompt, minor refactoring, docs: removed section about manual update of chat session for streaming * forgot to add some type hints! * keep the config of the model in GPT4All class which is taken from models.json if the download is allowed * During chat sessions, the model-specific systemPrompt and promptTemplate are applied. * implemented the changes * Fixed typing. Now the user can set a prompt template that will be applied even outside of a chat session. The template can also have multiple placeholders that can be filled by passing a dictionary to the generate function * reversed some changes concerning the prompt templates and their functionality * fixed some type hints, changed list[float] to List[Float] * fixed type hints, changed List[Float] to List[float] * fix typo in the comment: Pepare => Prepare --------- Signed-off-by: 385olt <385olt@gmail.com>
This commit is contained in:
parent
5f0aaf8bdb
commit
b4dbbd1485
@ -91,22 +91,4 @@ To interact with GPT4All responses as the model generates, use the `streaming =
|
||||
[' Paris', ' is', ' a', ' city', ' that', ' has', ' been', ' a', ' major', ' cultural', ' and', ' economic', ' center', ' for', ' over', ' ', '2', ',', '0', '0']
|
||||
```
|
||||
|
||||
#### Streaming and Chat Sessions
|
||||
When streaming tokens in a chat session, you must manually handle collection and updating of the chat history.
|
||||
|
||||
```python
|
||||
from gpt4all import GPT4All
|
||||
model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
|
||||
|
||||
with model.chat_session():
|
||||
tokens = list(model.generate(prompt='hello', top_k=1, streaming=True))
|
||||
model.current_chat_session.append({'role': 'assistant', 'content': ''.join(tokens)})
|
||||
|
||||
tokens = list(model.generate(prompt='write me a poem about dogs', top_k=1, streaming=True))
|
||||
model.current_chat_session.append({'role': 'assistant', 'content': ''.join(tokens)})
|
||||
|
||||
print(model.current_chat_session)
|
||||
```
|
||||
|
||||
### API documentation
|
||||
::: gpt4all.gpt4all.GPT4All
|
||||
|
@ -5,7 +5,7 @@ import os
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable, List, Union, Optional
|
||||
from typing import Any, Dict, Iterable, List, Union, Optional
|
||||
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
@ -13,7 +13,17 @@ from tqdm import tqdm
|
||||
from . import pyllmodel
|
||||
|
||||
# TODO: move to config
|
||||
DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\")
|
||||
DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace(
|
||||
"\\", "\\\\"
|
||||
)
|
||||
|
||||
DEFAULT_MODEL_CONFIG = {
|
||||
"systemPrompt": "",
|
||||
"promptTemplate": "### Human: \n{0}\n### Assistant:\n",
|
||||
}
|
||||
|
||||
ConfigType = Dict[str,str]
|
||||
MessageType = Dict[str, str]
|
||||
|
||||
class Embed4All:
|
||||
"""
|
||||
@ -34,7 +44,7 @@ class Embed4All:
|
||||
def embed(
|
||||
self,
|
||||
text: str
|
||||
) -> list[float]:
|
||||
) -> List[float]:
|
||||
"""
|
||||
Generate an embedding.
|
||||
|
||||
@ -74,17 +84,20 @@ class GPT4All:
|
||||
self.model_type = model_type
|
||||
self.model = pyllmodel.LLModel()
|
||||
# Retrieve model and download if allowed
|
||||
model_dest = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download)
|
||||
self.model.load_model(model_dest)
|
||||
self.config: ConfigType = self.retrieve_model(
|
||||
model_name, model_path=model_path, allow_download=allow_download
|
||||
)
|
||||
self.model.load_model(self.config["path"])
|
||||
# Set n_threads
|
||||
if n_threads is not None:
|
||||
self.model.set_thread_count(n_threads)
|
||||
|
||||
self._is_chat_session_activated = False
|
||||
self.current_chat_session = []
|
||||
self._is_chat_session_activated: bool = False
|
||||
self.current_chat_session: List[MessageType] = empty_chat_session()
|
||||
self._current_prompt_template: str = "{0}"
|
||||
|
||||
@staticmethod
|
||||
def list_models() -> Dict:
|
||||
def list_models() -> List[ConfigType]:
|
||||
"""
|
||||
Fetch model list from https://gpt4all.io/models/models.json.
|
||||
|
||||
@ -95,8 +108,11 @@ class GPT4All:
|
||||
|
||||
@staticmethod
|
||||
def retrieve_model(
|
||||
model_name: str, model_path: Optional[str] = None, allow_download: bool = True, verbose: bool = True
|
||||
) -> str:
|
||||
model_name: str,
|
||||
model_path: Optional[str] = None,
|
||||
allow_download: bool = True,
|
||||
verbose: bool = True,
|
||||
) -> ConfigType:
|
||||
"""
|
||||
Find model file, and if it doesn't exist, download the model.
|
||||
|
||||
@ -108,11 +124,25 @@ class GPT4All:
|
||||
verbose: If True (default), print debug messages.
|
||||
|
||||
Returns:
|
||||
Model file destination.
|
||||
Model config.
|
||||
"""
|
||||
|
||||
model_filename = append_bin_suffix_if_missing(model_name)
|
||||
|
||||
# get the config for the model
|
||||
config: ConfigType = DEFAULT_MODEL_CONFIG
|
||||
if allow_download:
|
||||
available_models = GPT4All.list_models()
|
||||
|
||||
for m in available_models:
|
||||
if model_filename == m["filename"]:
|
||||
config.update(m)
|
||||
config["systemPrompt"] = config["systemPrompt"].strip()
|
||||
config["promptTemplate"] = config["promptTemplate"].replace(
|
||||
"%1", "{0}", 1
|
||||
) # change to Python-style formatting
|
||||
break
|
||||
|
||||
# Validate download directory
|
||||
if model_path is None:
|
||||
try:
|
||||
@ -131,31 +161,34 @@ class GPT4All:
|
||||
|
||||
model_dest = os.path.join(model_path, model_filename).replace("\\", "\\\\")
|
||||
if os.path.exists(model_dest):
|
||||
config.pop("url", None)
|
||||
config["path"] = model_dest
|
||||
if verbose:
|
||||
print("Found model file at ", model_dest)
|
||||
return model_dest
|
||||
|
||||
# If model file does not exist, download
|
||||
elif allow_download:
|
||||
# Make sure valid model filename before attempting download
|
||||
available_models = GPT4All.list_models()
|
||||
|
||||
selected_model = None
|
||||
for m in available_models:
|
||||
if model_filename == m['filename']:
|
||||
selected_model = m
|
||||
break
|
||||
|
||||
if selected_model is None:
|
||||
if "url" not in config:
|
||||
raise ValueError(f"Model filename not in model list: {model_filename}")
|
||||
url = selected_model.pop('url', None)
|
||||
url = config.pop("url", None)
|
||||
|
||||
return GPT4All.download_model(model_filename, model_path, verbose=verbose, url=url)
|
||||
config["path"] = GPT4All.download_model(
|
||||
model_filename, model_path, verbose=verbose, url=url
|
||||
)
|
||||
else:
|
||||
raise ValueError("Failed to retrieve model")
|
||||
|
||||
return config
|
||||
|
||||
@staticmethod
|
||||
def download_model(model_filename: str, model_path: str, verbose: bool = True, url: Optional[str] = None) -> str:
|
||||
def download_model(
|
||||
model_filename: str,
|
||||
model_path: str,
|
||||
verbose: bool = True,
|
||||
url: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Download model from https://gpt4all.io.
|
||||
|
||||
@ -191,7 +224,7 @@ class GPT4All:
|
||||
except Exception:
|
||||
if os.path.exists(download_path):
|
||||
if verbose:
|
||||
print('Cleaning up the interrupted download...')
|
||||
print("Cleaning up the interrupted download...")
|
||||
os.remove(download_path)
|
||||
raise
|
||||
|
||||
@ -218,7 +251,8 @@ class GPT4All:
|
||||
n_batch: int = 8,
|
||||
n_predict: Optional[int] = None,
|
||||
streaming: bool = False,
|
||||
) -> Union[str, Iterable]:
|
||||
callback: pyllmodel.ResponseCallbackType = pyllmodel.empty_response_callback,
|
||||
) -> Union[str, Iterable[str]]:
|
||||
"""
|
||||
Generate outputs from any GPT4All model.
|
||||
|
||||
@ -233,12 +267,14 @@ class GPT4All:
|
||||
n_batch: Number of prompt tokens processed in parallel. Larger values decrease latency but increase resource requirements.
|
||||
n_predict: Equivalent to max_tokens, exists for backwards compatibility.
|
||||
streaming: If True, this method will instead return a generator that yields tokens as the model generates them.
|
||||
callback: A function with arguments token_id:int and response:str, which receives the tokens from the model as they are generated and stops the generation by returning False.
|
||||
|
||||
Returns:
|
||||
Either the entire completion or a generator that yields the completion token by token.
|
||||
"""
|
||||
generate_kwargs = dict(
|
||||
prompt=prompt,
|
||||
|
||||
# Preparing the model request
|
||||
generate_kwargs: Dict[str, Any] = dict(
|
||||
temp=temp,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
@ -249,42 +285,87 @@ class GPT4All:
|
||||
)
|
||||
|
||||
if self._is_chat_session_activated:
|
||||
generate_kwargs["reset_context"] = len(self.current_chat_session) == 1 # check if there is only one message, i.e. system prompt
|
||||
self.current_chat_session.append({"role": "user", "content": prompt})
|
||||
generate_kwargs['prompt'] = self._format_chat_prompt_template(messages=self.current_chat_session[-1:])
|
||||
generate_kwargs['reset_context'] = len(self.current_chat_session) == 1
|
||||
|
||||
prompt = self._format_chat_prompt_template(
|
||||
messages = self.current_chat_session[-1:],
|
||||
default_prompt_header = self.current_chat_session[0]["content"] if generate_kwargs["reset_context"] else "",
|
||||
)
|
||||
else:
|
||||
generate_kwargs['reset_context'] = True
|
||||
generate_kwargs["reset_context"] = True
|
||||
|
||||
if streaming:
|
||||
return self.model.prompt_model_streaming(**generate_kwargs)
|
||||
|
||||
output = self.model.prompt_model(**generate_kwargs)
|
||||
# Prepare the callback, process the model response
|
||||
output_collector: List[MessageType]
|
||||
output_collector = [{"content": ""}] # placeholder for the self.current_chat_session if chat session is not activated
|
||||
|
||||
if self._is_chat_session_activated:
|
||||
self.current_chat_session.append({"role": "assistant", "content": output})
|
||||
self.current_chat_session.append({"role": "assistant", "content": ""})
|
||||
output_collector = self.current_chat_session
|
||||
|
||||
return output
|
||||
def _callback_wrapper(
|
||||
callback: pyllmodel.ResponseCallbackType,
|
||||
output_collector: List[MessageType],
|
||||
) -> pyllmodel.ResponseCallbackType:
|
||||
|
||||
def _callback(token_id: int, response: str) -> bool:
|
||||
nonlocal callback, output_collector
|
||||
|
||||
output_collector[-1]["content"] += response
|
||||
|
||||
return callback(token_id, response)
|
||||
|
||||
return _callback
|
||||
|
||||
# Send the request to the model
|
||||
if streaming:
|
||||
return self.model.prompt_model_streaming(
|
||||
prompt=prompt,
|
||||
callback=_callback_wrapper(callback, output_collector),
|
||||
**generate_kwargs,
|
||||
)
|
||||
|
||||
self.model.prompt_model(
|
||||
prompt=prompt,
|
||||
callback=_callback_wrapper(callback, output_collector),
|
||||
**generate_kwargs,
|
||||
)
|
||||
|
||||
return output_collector[-1]["content"]
|
||||
|
||||
@contextmanager
|
||||
def chat_session(self):
|
||||
'''
|
||||
def chat_session(
|
||||
self,
|
||||
system_prompt: str = "",
|
||||
prompt_template: str = "",
|
||||
):
|
||||
"""
|
||||
Context manager to hold an inference optimized chat session with a GPT4All model.
|
||||
'''
|
||||
|
||||
Args:
|
||||
system_prompt: An initial instruction for the model.
|
||||
prompt_template: Template for the prompts with {0} being replaced by the user message.
|
||||
"""
|
||||
# Code to acquire resource, e.g.:
|
||||
self._is_chat_session_activated = True
|
||||
self.current_chat_session = []
|
||||
self.current_chat_session = empty_chat_session(system_prompt or self.config["systemPrompt"])
|
||||
self._current_prompt_template = prompt_template or self.config["promptTemplate"]
|
||||
try:
|
||||
yield self
|
||||
finally:
|
||||
# Code to release resource, e.g.:
|
||||
self._is_chat_session_activated = False
|
||||
self.current_chat_session = []
|
||||
self.current_chat_session = empty_chat_session()
|
||||
self._current_prompt_template = "{0}"
|
||||
|
||||
def _format_chat_prompt_template(
|
||||
self, messages: List[Dict], default_prompt_header=True, default_prompt_footer=True
|
||||
self,
|
||||
messages: List[MessageType],
|
||||
default_prompt_header: str = "",
|
||||
default_prompt_footer: str = "",
|
||||
) -> str:
|
||||
"""
|
||||
Helper method for building a prompt using template from list of messages.
|
||||
Helper method for building a prompt from list of messages using the self._current_prompt_template as a template for each message.
|
||||
|
||||
Args:
|
||||
messages: List of dictionaries. Each dictionary should have a "role" key
|
||||
@ -296,19 +377,44 @@ class GPT4All:
|
||||
Returns:
|
||||
Formatted prompt.
|
||||
"""
|
||||
full_prompt = ""
|
||||
|
||||
if isinstance(default_prompt_header, bool):
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"Using True/False for the 'default_prompt_header' is deprecated. Use a string instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
default_prompt_header = ""
|
||||
|
||||
if isinstance(default_prompt_footer, bool):
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"Using True/False for the 'default_prompt_footer' is deprecated. Use a string instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
default_prompt_footer = ""
|
||||
|
||||
full_prompt = default_prompt_header + "\n\n" if default_prompt_header != "" else ""
|
||||
|
||||
for message in messages:
|
||||
if message["role"] == "user":
|
||||
user_message = "### Human: \n" + message["content"] + "\n### Assistant:\n"
|
||||
user_message = self._current_prompt_template.format(message["content"])
|
||||
full_prompt += user_message
|
||||
if message["role"] == "assistant":
|
||||
assistant_message = message["content"] + '\n'
|
||||
assistant_message = message["content"] + "\n"
|
||||
full_prompt += assistant_message
|
||||
|
||||
full_prompt += "\n\n" + default_prompt_footer if default_prompt_footer != "" else ""
|
||||
|
||||
return full_prompt
|
||||
|
||||
|
||||
def empty_chat_session(system_prompt: str = "") -> List[MessageType]:
|
||||
return [{"role": "system", "content": system_prompt}]
|
||||
|
||||
|
||||
def append_bin_suffix_if_missing(model_name):
|
||||
if not model_name.endswith(".bin"):
|
||||
model_name += ".bin"
|
||||
|
@ -6,26 +6,19 @@ import re
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
from typing import Iterable
|
||||
import logging
|
||||
from typing import Iterable, Callable, List
|
||||
|
||||
import pkg_resources
|
||||
|
||||
|
||||
class DualStreamProcessor:
|
||||
def __init__(self, stream=None):
|
||||
self.stream = stream
|
||||
self.output = ""
|
||||
|
||||
def write(self, text):
|
||||
if self.stream is not None:
|
||||
self.stream.write(text)
|
||||
self.stream.flush()
|
||||
self.output += text
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# TODO: provide a config file to make this more robust
|
||||
LLMODEL_PATH = os.path.join("llmodel_DO_NOT_MODIFY", "build").replace("\\", "\\\\")
|
||||
MODEL_LIB_PATH = str(pkg_resources.resource_filename("gpt4all", LLMODEL_PATH)).replace("\\", "\\\\")
|
||||
MODEL_LIB_PATH = str(pkg_resources.resource_filename("gpt4all", LLMODEL_PATH)).replace(
|
||||
"\\", "\\\\"
|
||||
)
|
||||
|
||||
|
||||
def load_llmodel_library():
|
||||
@ -43,9 +36,9 @@ def load_llmodel_library():
|
||||
|
||||
c_lib_ext = get_c_shared_lib_extension()
|
||||
|
||||
llmodel_file = "libllmodel" + '.' + c_lib_ext
|
||||
llmodel_file = "libllmodel" + "." + c_lib_ext
|
||||
|
||||
llmodel_dir = str(pkg_resources.resource_filename('gpt4all', os.path.join(LLMODEL_PATH, llmodel_file))).replace(
|
||||
llmodel_dir = str(pkg_resources.resource_filename("gpt4all", os.path.join(LLMODEL_PATH, llmodel_file))).replace(
|
||||
"\\", "\\\\"
|
||||
)
|
||||
|
||||
@ -134,7 +127,15 @@ llmodel.llmodel_set_implementation_search_path.restype = None
|
||||
llmodel.llmodel_threadCount.argtypes = [ctypes.c_void_p]
|
||||
llmodel.llmodel_threadCount.restype = ctypes.c_int32
|
||||
|
||||
llmodel.llmodel_set_implementation_search_path(MODEL_LIB_PATH.encode('utf-8'))
|
||||
llmodel.llmodel_set_implementation_search_path(MODEL_LIB_PATH.encode("utf-8"))
|
||||
|
||||
|
||||
ResponseCallbackType = Callable[[int, str], bool]
|
||||
RawResponseCallbackType = Callable[[int, bytes], bool]
|
||||
|
||||
|
||||
def empty_response_callback(token_id: int, response: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class LLModel:
|
||||
@ -250,9 +251,10 @@ class LLModel:
|
||||
def generate_embedding(
|
||||
self,
|
||||
text: str
|
||||
) -> list[float]:
|
||||
) -> List[float]:
|
||||
if not text:
|
||||
raise ValueError("Text must not be None or empty")
|
||||
|
||||
embedding_size = ctypes.c_size_t()
|
||||
c_text = ctypes.c_char_p(text.encode('utf-8'))
|
||||
embedding_ptr = llmodel.llmodel_embedding(self.model, c_text, ctypes.byref(embedding_size))
|
||||
@ -263,6 +265,7 @@ class LLModel:
|
||||
def prompt_model(
|
||||
self,
|
||||
prompt: str,
|
||||
callback: ResponseCallbackType,
|
||||
n_predict: int = 4096,
|
||||
top_k: int = 40,
|
||||
top_p: float = 0.9,
|
||||
@ -272,8 +275,7 @@ class LLModel:
|
||||
repeat_last_n: int = 10,
|
||||
context_erase: float = 0.75,
|
||||
reset_context: bool = False,
|
||||
streaming=False,
|
||||
) -> str:
|
||||
):
|
||||
"""
|
||||
Generate response from model from a prompt.
|
||||
|
||||
@ -281,26 +283,24 @@ class LLModel:
|
||||
----------
|
||||
prompt: str
|
||||
Question, task, or conversation for model to respond to
|
||||
streaming: bool
|
||||
Stream response to stdout
|
||||
callback(token_id:int, response:str): bool
|
||||
The model sends response tokens to callback
|
||||
|
||||
Returns
|
||||
-------
|
||||
Model response str
|
||||
None
|
||||
"""
|
||||
|
||||
prompt_bytes = prompt.encode('utf-8')
|
||||
logger.info(
|
||||
"LLModel.prompt_model -- prompt:\n"
|
||||
+ "%s\n"
|
||||
+ "===/LLModel.prompt_model -- prompt/===",
|
||||
prompt,
|
||||
)
|
||||
|
||||
prompt_bytes = prompt.encode("utf-8")
|
||||
prompt_ptr = ctypes.c_char_p(prompt_bytes)
|
||||
|
||||
old_stdout = sys.stdout
|
||||
|
||||
stream_processor = DualStreamProcessor()
|
||||
|
||||
if streaming:
|
||||
stream_processor.stream = sys.stdout
|
||||
|
||||
sys.stdout = stream_processor
|
||||
|
||||
self._set_context(
|
||||
n_predict=n_predict,
|
||||
top_k=top_k,
|
||||
@ -317,56 +317,37 @@ class LLModel:
|
||||
self.model,
|
||||
prompt_ptr,
|
||||
PromptCallback(self._prompt_callback),
|
||||
ResponseCallback(self._response_callback),
|
||||
ResponseCallback(self._callback_decoder(callback)),
|
||||
RecalculateCallback(self._recalculate_callback),
|
||||
self.context,
|
||||
)
|
||||
|
||||
# Revert to old stdout
|
||||
sys.stdout = old_stdout
|
||||
# Force new line
|
||||
return stream_processor.output
|
||||
|
||||
def prompt_model_streaming(
|
||||
self,
|
||||
prompt: str,
|
||||
n_predict: int = 4096,
|
||||
top_k: int = 40,
|
||||
top_p: float = 0.9,
|
||||
temp: float = 0.1,
|
||||
n_batch: int = 8,
|
||||
repeat_penalty: float = 1.2,
|
||||
repeat_last_n: int = 10,
|
||||
context_erase: float = 0.75,
|
||||
reset_context: bool = False,
|
||||
) -> Iterable:
|
||||
callback: ResponseCallbackType = empty_response_callback,
|
||||
**kwargs
|
||||
) -> Iterable[str]:
|
||||
# Symbol to terminate from generator
|
||||
TERMINATING_SYMBOL = object()
|
||||
|
||||
output_queue = queue.Queue()
|
||||
|
||||
prompt_bytes = prompt.encode('utf-8')
|
||||
prompt_ptr = ctypes.c_char_p(prompt_bytes)
|
||||
|
||||
self._set_context(
|
||||
n_predict=n_predict,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temp=temp,
|
||||
n_batch=n_batch,
|
||||
repeat_penalty=repeat_penalty,
|
||||
repeat_last_n=repeat_last_n,
|
||||
context_erase=context_erase,
|
||||
reset_context=reset_context,
|
||||
)
|
||||
|
||||
# Put response tokens into an output queue
|
||||
def _generator_response_callback(token_id, response):
|
||||
output_queue.put(response.decode('utf-8', 'replace'))
|
||||
return True
|
||||
def _generator_callback_wrapper(callback: ResponseCallbackType) -> ResponseCallbackType:
|
||||
def _generator_callback(token_id: int, response: str):
|
||||
nonlocal callback
|
||||
|
||||
def run_llmodel_prompt(model, prompt, prompt_callback, response_callback, recalculate_callback, context):
|
||||
llmodel.llmodel_prompt(model, prompt, prompt_callback, response_callback, recalculate_callback, context)
|
||||
if callback(token_id, response):
|
||||
output_queue.put(response)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
return _generator_callback
|
||||
|
||||
def run_llmodel_prompt(prompt: str, callback: ResponseCallbackType, **kwargs):
|
||||
self.prompt_model(prompt, callback, **kwargs)
|
||||
output_queue.put(TERMINATING_SYMBOL)
|
||||
|
||||
# Kick off llmodel_prompt in separate thread so we can return generator
|
||||
@ -374,13 +355,10 @@ class LLModel:
|
||||
thread = threading.Thread(
|
||||
target=run_llmodel_prompt,
|
||||
args=(
|
||||
self.model,
|
||||
prompt_ptr,
|
||||
PromptCallback(self._prompt_callback),
|
||||
ResponseCallback(_generator_response_callback),
|
||||
RecalculateCallback(self._recalculate_callback),
|
||||
self.context,
|
||||
prompt,
|
||||
_generator_callback_wrapper(callback)
|
||||
),
|
||||
kwargs=kwargs,
|
||||
)
|
||||
thread.start()
|
||||
|
||||
@ -391,18 +369,19 @@ class LLModel:
|
||||
break
|
||||
yield response
|
||||
|
||||
def _callback_decoder(self, callback: ResponseCallbackType) -> RawResponseCallbackType:
|
||||
def _raw_callback(token_id: int, response: bytes) -> bool:
|
||||
nonlocal callback
|
||||
return callback(token_id, response.decode("utf-8", "replace"))
|
||||
|
||||
return _raw_callback
|
||||
|
||||
# Empty prompt callback
|
||||
@staticmethod
|
||||
def _prompt_callback(token_id):
|
||||
return True
|
||||
|
||||
# Empty response callback method that just prints response to be collected
|
||||
@staticmethod
|
||||
def _response_callback(token_id, response):
|
||||
sys.stdout.write(response.decode('utf-8', 'replace'))
|
||||
def _prompt_callback(token_id: int) -> bool:
|
||||
return True
|
||||
|
||||
# Empty recalculate callback
|
||||
@staticmethod
|
||||
def _recalculate_callback(is_recalculating):
|
||||
def _recalculate_callback(is_recalculating: bool) -> bool:
|
||||
return is_recalculating
|
||||
|
Loading…
Reference in New Issue
Block a user