mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-10-01 01:06:10 -04:00
adds a simple cli chat repl (#566)
* adds a simple cli chat repl * add n thread support and append assistant response
This commit is contained in:
parent
95a4516844
commit
d4861030b7
118
gpt4all-bindings/cli/app.py
Normal file
118
gpt4all-bindings/cli/app.py
Normal file
@ -0,0 +1,118 @@
|
||||
import sys
|
||||
import typer
|
||||
|
||||
from typing_extensions import Annotated
|
||||
from gpt4all import GPT4All
|
||||
|
||||
MESSAGES = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello there."},
|
||||
{"role": "assistant", "content": "Hi, how can I help you?"},
|
||||
]
|
||||
|
||||
SPECIAL_COMMANDS = {
|
||||
"/reset": lambda messages: messages.clear(),
|
||||
"/exit": lambda _: sys.exit(),
|
||||
"/clear": lambda _: print("\n" * 100),
|
||||
"/help": lambda _: print("Special commands: /reset, /exit, /help and /clear"),
|
||||
}
|
||||
|
||||
VERSION = "0.1.0"
|
||||
|
||||
CLI_START_MESSAGE = f"""
|
||||
|
||||
██████ ██████ ████████ ██ ██ █████ ██ ██
|
||||
██ ██ ██ ██ ██ ██ ██ ██ ██ ██
|
||||
██ ███ ██████ ██ ███████ ███████ ██ ██
|
||||
██ ██ ██ ██ ██ ██ ██ ██ ██
|
||||
██████ ██ ██ ██ ██ ██ ███████ ███████
|
||||
|
||||
|
||||
Welcome to the GPT4All CLI! Version {VERSION}
|
||||
Type /help for special commands.
|
||||
|
||||
"""
|
||||
|
||||
def _cli_override_response_callback(token_id, response):
|
||||
resp = response.decode("utf-8")
|
||||
print(resp, end="", flush=True)
|
||||
return True
|
||||
|
||||
|
||||
# create typer app
|
||||
app = typer.Typer()
|
||||
|
||||
@app.command()
|
||||
def repl(
|
||||
model: Annotated[
|
||||
str,
|
||||
typer.Option("--model", "-m", help="Model to use for chatbot"),
|
||||
] = "ggml-gpt4all-j-v1.3-groovy",
|
||||
n_threads: Annotated[
|
||||
int,
|
||||
typer.Option("--n-threads", "-t", help="Number of threads to use for chatbot"),
|
||||
] = 4,
|
||||
):
|
||||
gpt4all_instance = GPT4All(model)
|
||||
|
||||
# if threads are passed, set them
|
||||
if n_threads != 4:
|
||||
num_threads = gpt4all_instance.model.thread_count()
|
||||
print(f"\nAdjusted: {num_threads} →", end="")
|
||||
|
||||
# set number of threads
|
||||
gpt4all_instance.model.set_thread_count(n_threads)
|
||||
|
||||
num_threads = gpt4all_instance.model.thread_count()
|
||||
print(f" {num_threads} threads", end="", flush=True)
|
||||
|
||||
|
||||
# overwrite _response_callback on model
|
||||
gpt4all_instance.model._response_callback = _cli_override_response_callback
|
||||
|
||||
print(CLI_START_MESSAGE)
|
||||
|
||||
while True:
|
||||
message = input(" ⇢ ")
|
||||
|
||||
# Check if special command and take action
|
||||
if message in SPECIAL_COMMANDS:
|
||||
SPECIAL_COMMANDS[message](MESSAGES)
|
||||
continue
|
||||
|
||||
# if regular message, append to messages
|
||||
MESSAGES.append({"role": "user", "content": message})
|
||||
|
||||
# execute chat completion and ignore the full response since
|
||||
# we are outputting it incrementally
|
||||
full_response = gpt4all_instance.chat_completion(
|
||||
MESSAGES,
|
||||
# preferential kwargs for chat ux
|
||||
logits_size=0,
|
||||
tokens_size=0,
|
||||
n_past=0,
|
||||
n_ctx=0,
|
||||
n_predict=200,
|
||||
top_k=40,
|
||||
top_p=0.9,
|
||||
temp=0.9,
|
||||
n_batch=9,
|
||||
repeat_penalty=1.1,
|
||||
repeat_last_n=64,
|
||||
context_erase=0.0,
|
||||
# required kwargs for cli ux (incremental response)
|
||||
verbose=False,
|
||||
std_passthrough=True,
|
||||
)
|
||||
# record assistant's response to messages
|
||||
MESSAGES.append(full_response.get("choices")[0].get("message"))
|
||||
print() # newline before next prompt
|
||||
|
||||
|
||||
@app.command()
|
||||
def version():
|
||||
print("gpt4all-cli v0.1.0")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
@ -6,6 +6,21 @@ import platform
|
||||
import re
|
||||
import sys
|
||||
|
||||
class DualOutput:
|
||||
def __init__(self, stdout, string_io):
|
||||
self.stdout = stdout
|
||||
self.string_io = string_io
|
||||
|
||||
def write(self, text):
|
||||
self.stdout.write(text)
|
||||
self.string_io.write(text)
|
||||
|
||||
def flush(self):
|
||||
# It's a good idea to also define a flush method that flushes both
|
||||
# outputs, as sys.stdout is expected to have this method.
|
||||
self.stdout.flush()
|
||||
self.string_io.flush()
|
||||
|
||||
# TODO: provide a config file to make this more robust
|
||||
LLMODEL_PATH = os.path.join("llmodel_DO_NOT_MODIFY", "build").replace("\\", "\\\\")
|
||||
|
||||
@ -81,6 +96,15 @@ llmodel.llmodel_prompt.argtypes = [ctypes.c_void_p,
|
||||
RecalculateCallback,
|
||||
ctypes.POINTER(LLModelPromptContext)]
|
||||
|
||||
llmodel.llmodel_prompt.restype = None
|
||||
|
||||
llmodel.llmodel_setThreadCount.argtypes = [ctypes.c_void_p, ctypes.c_int32]
|
||||
llmodel.llmodel_setThreadCount.restype = None
|
||||
|
||||
llmodel.llmodel_threadCount.argtypes = [ctypes.c_void_p]
|
||||
llmodel.llmodel_threadCount.restype = ctypes.c_int32
|
||||
|
||||
|
||||
class LLModel:
|
||||
"""
|
||||
Base class and universal wrapper for GPT4All language models
|
||||
@ -125,6 +149,18 @@ class LLModel:
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def set_thread_count(self, n_threads):
|
||||
if not llmodel.llmodel_isModelLoaded(self.model):
|
||||
raise Exception("Model not loaded")
|
||||
llmodel.llmodel_setThreadCount(self.model, n_threads)
|
||||
|
||||
def thread_count(self):
|
||||
if not llmodel.llmodel_isModelLoaded(self.model):
|
||||
raise Exception("Model not loaded")
|
||||
return llmodel.llmodel_threadCount(self.model)
|
||||
|
||||
|
||||
def generate(self,
|
||||
prompt: str,
|
||||
logits_size: int = 0,
|
||||
@ -138,7 +174,8 @@ class LLModel:
|
||||
n_batch: int = 8,
|
||||
repeat_penalty: float = 1.2,
|
||||
repeat_last_n: int = 10,
|
||||
context_erase: float = .5) -> str:
|
||||
context_erase: float = .5,
|
||||
std_passthrough: bool = False) -> str:
|
||||
"""
|
||||
Generate response from model from a prompt.
|
||||
|
||||
@ -164,6 +201,9 @@ class LLModel:
|
||||
# Change stdout to StringIO so we can collect response
|
||||
old_stdout = sys.stdout
|
||||
collect_response = StringIO()
|
||||
if std_passthrough:
|
||||
sys.stdout = DualOutput(old_stdout, collect_response)
|
||||
else:
|
||||
sys.stdout = collect_response
|
||||
|
||||
context = LLModelPromptContext(
|
||||
@ -222,7 +262,7 @@ class GPTJModel(LLModel):
|
||||
self.model = llmodel.llmodel_gptj_create()
|
||||
|
||||
def __del__(self):
|
||||
if self.model is not None:
|
||||
if self.model is not None and llmodel is not None:
|
||||
llmodel.llmodel_gptj_destroy(self.model)
|
||||
super().__del__()
|
||||
|
||||
@ -236,7 +276,7 @@ class LlamaModel(LLModel):
|
||||
self.model = llmodel.llmodel_llama_create()
|
||||
|
||||
def __del__(self):
|
||||
if self.model is not None:
|
||||
if self.model is not None and llmodel is not None:
|
||||
llmodel.llmodel_llama_destroy(self.model)
|
||||
super().__del__()
|
||||
|
||||
@ -250,6 +290,6 @@ class MPTModel(LLModel):
|
||||
self.model = llmodel.llmodel_mpt_create()
|
||||
|
||||
def __del__(self):
|
||||
if self.model is not None:
|
||||
if self.model is not None and llmodel is not None:
|
||||
llmodel.llmodel_mpt_destroy(self.model)
|
||||
super().__del__()
|
||||
|
Loading…
Reference in New Issue
Block a user