import json
import os
from typing import AsyncGenerator

import httpx
import requests
import tiktoken


class Chatbot:
    """
    Official ChatGPT API
    """

    def __init__(
        self,
        api_key: str,
        engine: str = os.environ.get("GPT_ENGINE") or "gpt-3.5-turbo",
        proxy: str = None,
        timeout: float = None,
        max_tokens: int = None,
        temperature: float = 0.5,
        top_p: float = 1.0,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        reply_count: int = 1,
        system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
    ) -> None:
        """
        Initialize Chatbot with API key (from https://platform.openai.com/account/api-keys)
        """
        self.engine: str = engine
        self.api_key: str = api_key
        self.system_prompt: str = system_prompt
        # Per-request token budget; defaults scale with the engine's context window.
        self.max_tokens: int = max_tokens or (
            31000 if engine == "gpt-4-32k" else 7000 if engine == "gpt-4" else 4000
        )
        # History is trimmed once it exceeds this, leaving headroom for the reply.
        self.truncate_limit: int = (
            30500 if engine == "gpt-4-32k" else 6500 if engine == "gpt-4" else 3500
        )
        self.temperature: float = temperature
        self.top_p: float = top_p
        self.presence_penalty: float = presence_penalty
        self.frequency_penalty: float = frequency_penalty
        self.reply_count: int = reply_count
        self.timeout: float = timeout

        # BUGFIX: resolve the environment fallback *before* configuring either
        # client.  Previously the fallback ran after self.session.proxies was
        # set, so the sync session and the async client could disagree about
        # which proxy to use, and self.proxy held the pre-fallback value.
        proxy = (
            proxy or os.environ.get("all_proxy") or os.environ.get("ALL_PROXY") or None
        )
        self.proxy = proxy

        self.session = requests.Session()
        self.session.proxies.update(
            {
                "http": proxy,
                "https": proxy,
            },
        )

        # BUGFIX: self.aclient used to be created only when a proxy was set,
        # so every async call crashed with AttributeError without one.  (The
        # former "socks5h" special case was two byte-identical branches and
        # has been collapsed.)
        if proxy:
            self.aclient = httpx.AsyncClient(
                follow_redirects=True,
                proxies=proxy,
                timeout=timeout,
            )
        else:
            self.aclient = httpx.AsyncClient(
                follow_redirects=True,
                timeout=timeout,
            )

        # Per-conversation message history, keyed by convo_id; every history
        # starts with the system prompt.
        self.conversation: dict[str, list[dict]] = {
            "default": [
                {
                    "role": "system",
                    "content": system_prompt,
                },
            ],
        }

    def add_to_conversation(
        self,
        message: str,
        role: str,
        convo_id: str = "default",
    ) -> None:
        """
        Add a message with the given role to conversation `convo_id`.
        """
        self.conversation[convo_id].append({"role": role, "content": message})

    def __truncate_conversation(self, convo_id: str = "default") -> None:
        """
        Drop the oldest non-system messages until the conversation fits
        under `self.truncate_limit` tokens.
        """
        while (
            self.get_token_count(convo_id) > self.truncate_limit
            and len(self.conversation[convo_id]) > 1
        ):
            # Don't remove the first (system) message
            self.conversation[convo_id].pop(1)

    def get_token_count(self, convo_id: str = "default") -> int:
        """
        Count the prompt tokens the conversation will occupy.

        Raises:
            NotImplementedError: if `self.engine` is not a supported chat model.
        """
        if self.engine not in [
            "gpt-3.5-turbo",
            "gpt-3.5-turbo-0301",
            "gpt-4",
            "gpt-4-0314",
            "gpt-4-32k",
            "gpt-4-32k-0314",
        ]:
            # BUGFIX: was a plain string literal, so the error showed the
            # text "{self.engine}" verbatim instead of the engine name.
            raise NotImplementedError(f"Unsupported engine {self.engine}")

        tiktoken.model.MODEL_TO_ENCODING["gpt-4"] = "cl100k_base"

        encoding = tiktoken.encoding_for_model(self.engine)

        num_tokens = 0
        for message in self.conversation[convo_id]:
            # every message follows <im_start>{role/name}\n{content}<im_end>\n
            num_tokens += 5
            for key, value in message.items():
                num_tokens += len(encoding.encode(value))
                if key == "name":  # if there's a name, the role is omitted
                    num_tokens += 5  # role is always required and always 1 token
        num_tokens += 5  # every reply is primed with <im_start>assistant
        return num_tokens

    def get_max_tokens(self, convo_id: str) -> int:
        """
        Tokens still available for the completion after the prompt.
        """
        return self.max_tokens - self.get_token_count(convo_id)

    def ask_stream(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        **kwargs,
    ):
        """
        Ask a question and yield the answer incrementally as it streams in.

        kwargs may override api_key, temperature, top_p, presence_penalty,
        frequency_penalty, n, and timeout for this single request.
        """
        # Make conversation if it doesn't exist
        if convo_id not in self.conversation:
            self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
        # BUGFIX: honour the `role` argument (was hard-coded to "user").
        self.add_to_conversation(prompt, role, convo_id=convo_id)
        self.__truncate_conversation(convo_id=convo_id)
        # Get response
        response = self.session.post(
            os.environ.get("API_URL") or "https://api.openai.com/v1/chat/completions",
            headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
            json={
                "model": self.engine,
                "messages": self.conversation[convo_id],
                "stream": True,
                # kwargs
                "temperature": kwargs.get("temperature", self.temperature),
                "top_p": kwargs.get("top_p", self.top_p),
                "presence_penalty": kwargs.get(
                    "presence_penalty",
                    self.presence_penalty,
                ),
                "frequency_penalty": kwargs.get(
                    "frequency_penalty",
                    self.frequency_penalty,
                ),
                "n": kwargs.get("n", self.reply_count),
                "user": role,
                "max_tokens": self.get_max_tokens(convo_id=convo_id),
            },
            timeout=kwargs.get("timeout", self.timeout),
            stream=True,
        )
        # Surface HTTP errors instead of feeding an error body to the SSE parser.
        response.raise_for_status()

        response_role: str = None
        full_response: str = ""
        for line in response.iter_lines():
            if not line:
                continue
            # Remove "data: "
            line = line.decode("utf-8")[6:]
            if line == "[DONE]":
                break
            resp: dict = json.loads(line)
            choices = resp.get("choices")
            if not choices:
                continue
            delta = choices[0].get("delta")
            if not delta:
                continue
            if "role" in delta:
                response_role = delta["role"]
            if "content" in delta:
                content = delta["content"]
                full_response += content
                yield content
        self.add_to_conversation(full_response, response_role, convo_id=convo_id)

    async def ask_stream_async(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """
        Async variant of ask_stream: yield the answer incrementally.
        """
        # Make conversation if it doesn't exist
        if convo_id not in self.conversation:
            self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
        # BUGFIX: honour the `role` argument (was hard-coded to "user").
        self.add_to_conversation(prompt, role, convo_id=convo_id)
        self.__truncate_conversation(convo_id=convo_id)
        # Get response
        async with self.aclient.stream(
            "post",
            os.environ.get("API_URL") or "https://api.openai.com/v1/chat/completions",
            headers={"Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
            json={
                "model": self.engine,
                "messages": self.conversation[convo_id],
                "stream": True,
                # kwargs
                "temperature": kwargs.get("temperature", self.temperature),
                "top_p": kwargs.get("top_p", self.top_p),
                "presence_penalty": kwargs.get(
                    "presence_penalty",
                    self.presence_penalty,
                ),
                "frequency_penalty": kwargs.get(
                    "frequency_penalty",
                    self.frequency_penalty,
                ),
                "n": kwargs.get("n", self.reply_count),
                "user": role,
                "max_tokens": self.get_max_tokens(convo_id=convo_id),
            },
            timeout=kwargs.get("timeout", self.timeout),
        ) as response:
            if response.status_code != 200:
                # BUGFIX: the old code read the error body and then silently
                # carried on into the SSE loop; raise instead.
                await response.aread()
                response.raise_for_status()

            response_role: str = ""
            full_response: str = ""
            async for line in response.aiter_lines():
                line = line.strip()
                if not line:
                    continue
                # Remove "data: "
                line = line[6:]
                if line == "[DONE]":
                    break
                resp: dict = json.loads(line)
                choices = resp.get("choices")
                if not choices:
                    continue
                delta: dict[str, str] = choices[0].get("delta")
                if not delta:
                    continue
                if "role" in delta:
                    response_role = delta["role"]
                if "content" in delta:
                    content: str = delta["content"]
                    full_response += content
                    yield content
        self.add_to_conversation(full_response, response_role, convo_id=convo_id)

    async def ask_async(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        **kwargs,
    ) -> str:
        """
        Non-streaming async ask: collect the full streamed answer into one string.
        """
        response = self.ask_stream_async(
            prompt=prompt,
            role=role,
            convo_id=convo_id,
            **kwargs,
        )
        full_response: str = "".join([r async for r in response])
        return full_response

    def ask(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        **kwargs,
    ) -> str:
        """
        Non-streaming ask: collect the full streamed answer into one string.
        """
        response = self.ask_stream(
            prompt=prompt,
            role=role,
            convo_id=convo_id,
            **kwargs,
        )
        full_response: str = "".join(response)
        return full_response

    def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
        """
        Reset conversation `convo_id` to just a system message
        (`system_prompt` if given, else `self.system_prompt`).
        """
        self.conversation[convo_id] = [
            {"role": "system", "content": system_prompt or self.system_prompt},
        ]
|