adds a simple cli chat repl (#566)

* adds a simple cli chat repl

* add n thread support and append assistant response
This commit is contained in:
drbh 2023-05-16 13:47:54 -07:00 committed by GitHub
parent 95a4516844
commit d4861030b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 163 additions and 5 deletions

118
gpt4all-bindings/cli/app.py Normal file
View File

@ -0,0 +1,118 @@
import sys
import typer
from typing_extensions import Annotated
from gpt4all import GPT4All
MESSAGES = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello there."},
{"role": "assistant", "content": "Hi, how can I help you?"},
]
SPECIAL_COMMANDS = {
"/reset": lambda messages: messages.clear(),
"/exit": lambda _: sys.exit(),
"/clear": lambda _: print("\n" * 100),
"/help": lambda _: print("Special commands: /reset, /exit, /help and /clear"),
}
VERSION = "0.1.0"
CLI_START_MESSAGE = f"""
Welcome to the GPT4All CLI! Version {VERSION}
Type /help for special commands.
"""
def _cli_override_response_callback(token_id, response):
resp = response.decode("utf-8")
print(resp, end="", flush=True)
return True
# create typer app
app = typer.Typer()
@app.command()
def repl(
model: Annotated[
str,
typer.Option("--model", "-m", help="Model to use for chatbot"),
] = "ggml-gpt4all-j-v1.3-groovy",
n_threads: Annotated[
int,
typer.Option("--n-threads", "-t", help="Number of threads to use for chatbot"),
] = 4,
):
gpt4all_instance = GPT4All(model)
# if threads are passed, set them
if n_threads != 4:
num_threads = gpt4all_instance.model.thread_count()
print(f"\nAdjusted: {num_threads}", end="")
# set number of threads
gpt4all_instance.model.set_thread_count(n_threads)
num_threads = gpt4all_instance.model.thread_count()
print(f" {num_threads} threads", end="", flush=True)
# overwrite _response_callback on model
gpt4all_instance.model._response_callback = _cli_override_response_callback
print(CLI_START_MESSAGE)
while True:
message = input("")
# Check if special command and take action
if message in SPECIAL_COMMANDS:
SPECIAL_COMMANDS[message](MESSAGES)
continue
# if regular message, append to messages
MESSAGES.append({"role": "user", "content": message})
# execute chat completion and ignore the full response since
# we are outputting it incrementally
full_response = gpt4all_instance.chat_completion(
MESSAGES,
# preferential kwargs for chat ux
logits_size=0,
tokens_size=0,
n_past=0,
n_ctx=0,
n_predict=200,
top_k=40,
top_p=0.9,
temp=0.9,
n_batch=9,
repeat_penalty=1.1,
repeat_last_n=64,
context_erase=0.0,
# required kwargs for cli ux (incremental response)
verbose=False,
std_passthrough=True,
)
# record assistant's response to messages
MESSAGES.append(full_response.get("choices")[0].get("message"))
print() # newline before next prompt
@app.command()
def version():
print("gpt4all-cli v0.1.0")
if __name__ == "__main__":
app()

View File

@ -6,6 +6,21 @@ import platform
import re
import sys
class DualOutput:
def __init__(self, stdout, string_io):
self.stdout = stdout
self.string_io = string_io
def write(self, text):
self.stdout.write(text)
self.string_io.write(text)
def flush(self):
# It's a good idea to also define a flush method that flushes both
# outputs, as sys.stdout is expected to have this method.
self.stdout.flush()
self.string_io.flush()
# TODO: provide a config file to make this more robust
LLMODEL_PATH = os.path.join("llmodel_DO_NOT_MODIFY", "build").replace("\\", "\\\\")
@ -81,6 +96,15 @@ llmodel.llmodel_prompt.argtypes = [ctypes.c_void_p,
RecalculateCallback,
ctypes.POINTER(LLModelPromptContext)]
llmodel.llmodel_prompt.restype = None
llmodel.llmodel_setThreadCount.argtypes = [ctypes.c_void_p, ctypes.c_int32]
llmodel.llmodel_setThreadCount.restype = None
llmodel.llmodel_threadCount.argtypes = [ctypes.c_void_p]
llmodel.llmodel_threadCount.restype = ctypes.c_int32
class LLModel:
"""
Base class and universal wrapper for GPT4All language models
@ -125,6 +149,18 @@ class LLModel:
else:
return False
def set_thread_count(self, n_threads):
if not llmodel.llmodel_isModelLoaded(self.model):
raise Exception("Model not loaded")
llmodel.llmodel_setThreadCount(self.model, n_threads)
def thread_count(self):
if not llmodel.llmodel_isModelLoaded(self.model):
raise Exception("Model not loaded")
return llmodel.llmodel_threadCount(self.model)
def generate(self,
prompt: str,
logits_size: int = 0,
@ -138,7 +174,8 @@ class LLModel:
n_batch: int = 8,
repeat_penalty: float = 1.2,
repeat_last_n: int = 10,
context_erase: float = .5) -> str:
context_erase: float = .5,
std_passthrough: bool = False) -> str:
"""
Generate response from model from a prompt.
@ -164,6 +201,9 @@ class LLModel:
# Change stdout to StringIO so we can collect response
old_stdout = sys.stdout
collect_response = StringIO()
if std_passthrough:
sys.stdout = DualOutput(old_stdout, collect_response)
else:
sys.stdout = collect_response
context = LLModelPromptContext(
@ -222,7 +262,7 @@ class GPTJModel(LLModel):
self.model = llmodel.llmodel_gptj_create()
def __del__(self):
if self.model is not None:
if self.model is not None and llmodel is not None:
llmodel.llmodel_gptj_destroy(self.model)
super().__del__()
@ -236,7 +276,7 @@ class LlamaModel(LLModel):
self.model = llmodel.llmodel_llama_create()
def __del__(self):
if self.model is not None:
if self.model is not None and llmodel is not None:
llmodel.llmodel_llama_destroy(self.model)
super().__del__()
@ -250,6 +290,6 @@ class MPTModel(LLModel):
self.model = llmodel.llmodel_mpt_create()
def __del__(self):
if self.model is not None:
if self.model is not None and llmodel is not None:
llmodel.llmodel_mpt_destroy(self.model)
super().__del__()