# matrix_chatgpt_bot/v3.py
# 2023-03-05 22:24:15 +08:00
#
# 193 lines
# 5.9 KiB
# Python

"""
A simple wrapper for the official ChatGPT API
https://github.com/acheong08/ChatGPT/blob/main/src/revChatGPT/V3.py
"""
import json
import os
from typing import Generator

import requests
import tiktoken
# Model name; overridable through the GPT_ENGINE environment variable.
ENGINE = os.environ.get("GPT_ENGINE") or "gpt-3.5-turbo"
# Tokenizer used only for local token counting / conversation truncation.
# gpt-3.5-turbo (and gpt-4) use the "cl100k_base" encoding; the older "gpt2"
# encoding produces different token counts for these chat models.
ENCODER = tiktoken.get_encoding("cl100k_base")
class Chatbot:
    """
    Official ChatGPT API

    Thin client for the OpenAI chat completions endpoint that keeps the
    message history of one or more named conversations in ``self.conversation``.
    """

    def __init__(
        self,
        api_key: str = None,
        engine: str = None,
        proxy: str = None,
        max_tokens: int = 4096,
        temperature: float = 0.5,
        top_p: float = 1.0,
        reply_count: int = 1,
        system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
    ) -> None:
        """
        Initialize Chatbot with API key (from https://platform.openai.com/account/api-keys)

        :param api_key: OpenAI API key, sent as a Bearer token.
        :param engine: model name; falls back to the module-level ENGINE.
        :param proxy: accepted for signature compatibility but currently
            ignored (proxy support is disabled in this copy).
        :param max_tokens: token budget the stored conversation is truncated to
            before each request.
        :param temperature: default sampling temperature.
        :param top_p: default nucleus-sampling value.
        :param reply_count: default number of completions requested ("n").
        :param system_prompt: content of the system message that starts every
            conversation.
        :raises Exception: if the system prompt alone exceeds ``max_tokens``.
        """
        self.engine = engine or ENGINE
        self.session = requests.Session()
        self.api_key = api_key
        # Proxy support is disabled here, so `proxy` has no effect. To enable:
        #     self.session.proxies = {"http": proxy, "https": proxy}
        self.conversation: dict = {
            "default": [
                {
                    "role": "system",
                    "content": system_prompt,
                },
            ],
        }
        self.system_prompt = system_prompt
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.reply_count = reply_count
        initial_conversation = "\n".join(
            [x["content"] for x in self.conversation["default"]],
        )
        if len(ENCODER.encode(initial_conversation)) > self.max_tokens:
            raise Exception("System prompt is too long")

    def add_to_conversation(
        self, message: str, role: str, convo_id: str = "default"
    ) -> None:
        """
        Append a ``{"role": role, "content": message}`` entry to a conversation.

        :raises KeyError: if ``convo_id`` does not exist.
        """
        self.conversation[convo_id].append({"role": role, "content": message})

    def get_token_count(self, convo_id: str = "default") -> int:
        """
        Count the tokens of a conversation, rendered as one
        ``"role: content\\n"`` line per message and encoded with ENCODER.
        """
        full_conversation = "".join(
            message["role"] + ": " + message["content"] + "\n"
            for message in self.conversation[convo_id]
        )
        return len(ENCODER.encode(full_conversation))

    def __truncate_conversation(self, convo_id: str = "default") -> None:
        """
        Drop the oldest messages until the conversation fits in ``self.max_tokens``.

        The system prompt (index 0) and the newest message (the prompt just
        added by ``ask_stream``) are never removed. An oversized prompt is
        therefore sent as-is instead of being silently discarded, which would
        leave only the system prompt in the request.
        """
        messages = self.conversation[convo_id]
        while len(messages) > 2 and self.get_token_count(convo_id) > self.max_tokens:
            messages.pop(1)

    def get_max_tokens(self, convo_id: str) -> int:
        """
        Return how many tokens remain under a fixed 4000-token budget.

        Not currently sent with requests.
        """
        # NOTE(review): 4000 is hard-coded and independent of self.max_tokens.
        return 4000 - self.get_token_count(convo_id)

    def ask_stream(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        **kwargs,
    ) -> Generator[str, None, None]:
        """
        Send ``prompt`` and yield the reply piece by piece as it streams in.

        The prompt is stored with role "user"; once the stream has been fully
        consumed, the complete reply is appended to the conversation.

        :param prompt: message text.
        :param role: forwarded as the request's ``user`` field.
        :param convo_id: conversation to use; created with the default system
            prompt if missing.
        :param kwargs: per-call overrides for ``api_key``, ``temperature``,
            ``top_p`` and ``n``.
        :raises Exception: if the API answers with a non-200 status.
        """
        # Make conversation if it doesn't exist
        if convo_id not in self.conversation:
            self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
        self.add_to_conversation(prompt, "user", convo_id=convo_id)
        self.__truncate_conversation(convo_id=convo_id)
        response = self.session.post(
            "https://api.openai.com/v1/chat/completions",
            headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
            json={
                "model": self.engine,
                "messages": self.conversation[convo_id],
                "stream": True,
                "temperature": kwargs.get("temperature", self.temperature),
                "top_p": kwargs.get("top_p", self.top_p),
                "n": kwargs.get("n", self.reply_count),
                # NOTE(review): the API's "user" field identifies the end user;
                # `role` is sent here rather than used as the message role.
                # Kept as-is for compatibility with existing callers.
                "user": role,
            },
            stream=True,
        )
        # Release the streamed connection even if the caller stops iterating
        # early (closing the generator raises GeneratorExit at `yield`).
        try:
            if response.status_code != 200:
                raise Exception(
                    f"Error: {response.status_code} {response.reason} {response.text}",
                )
            # Fallback so a valid role is stored even if no delta carries one.
            response_role: str = "assistant"
            parts = []
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue
                line = raw_line.decode("utf-8")
                # Server-sent events: payload lines are "data:" followed by an
                # optional single space; other fields are ignored.
                if not line.startswith("data:"):
                    continue
                payload = line[5:]
                if payload.startswith(" "):
                    payload = payload[1:]
                if payload == "[DONE]":
                    break
                resp: dict = json.loads(payload)
                choices = resp.get("choices")
                if not choices:
                    continue
                delta = choices[0].get("delta")
                if not delta:
                    continue
                if "role" in delta:
                    response_role = delta["role"]
                if "content" in delta:
                    content = delta["content"]
                    parts.append(content)
                    yield content
        finally:
            response.close()
        self.add_to_conversation("".join(parts), response_role, convo_id=convo_id)

    def ask(
        self, prompt: str, role: str = "user", convo_id: str = "default", **kwargs
    ) -> str:
        """
        Non-streaming ask: consume ``ask_stream`` and return the whole reply.
        """
        response = self.ask_stream(
            prompt=prompt,
            role=role,
            convo_id=convo_id,
            **kwargs,
        )
        full_response: str = "".join(response)
        return full_response

    def rollback(self, n: int = 1, convo_id: str = "default") -> None:
        """
        Remove the last ``n`` messages from a conversation.

        :raises IndexError: if fewer than ``n`` messages remain.
        """
        for _ in range(n):
            self.conversation[convo_id].pop()

    def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
        """
        Replace a conversation with a single system message.

        :param system_prompt: content of that message; defaults to the
            instance's ``system_prompt`` when falsy.
        """
        self.conversation[convo_id] = [
            {"role": "system", "content": system_prompt or self.system_prompt},
        ]