Support thread level context (#29)

This commit is contained in:
hibobmaster 2024-04-23 20:18:21 +08:00
parent a37df65c3d
commit 69ce5b43a3
No known key found for this signature in database
7 changed files with 177 additions and 43 deletions

1
.gitignore vendored
View File

@ -173,3 +173,4 @@ cython_debug/
sync_db
manage_db
element-keys.txt
context.db

View File

@ -1,5 +1,8 @@
# Changelog
## 1.7.0
- Support thread level context
## 1.6.0
- Add GPT Vision

View File

@ -12,6 +12,7 @@ This is a simple Matrix bot that support using OpenAI API, Langchain to generate
4. Langchain([Flowise](https://github.com/FlowiseAI/Flowise))
5. Image Generation with [DALL·E](https://platform.openai.com/docs/api-reference/images/create) or [LocalAI](https://localai.io/features/image-generation/) or [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API)
6. GPT Vision(openai or [GPT Vision API](https://platform.openai.com/docs/guides/vision) compatible such as [LocalAI](https://localai.io/features/gpt-vision/))
7. Room level and thread level chat context
## Installation and Setup
@ -21,10 +22,10 @@ For explainations and complete parameter list see: https://github.com/hibobmaste
Create three empty files, used only to persist the databases<br>
```bash
touch sync_db manage_db
touch sync_db context.db manage_db
sudo docker compose up -d
```
manage_db(can be ignored) is for langchain agent, sync_db is for matrix sync database<br>
manage_db(can be ignored) is for langchain agent, sync_db is for matrix sync database, context.db is for bot chat context<br>
<hr>
Normal Method:<br>
system dependece: <code>libolm-dev</code>
@ -115,12 +116,16 @@ LangChain(flowise) admin: https://github.com/hibobmaster/matrix_chatgpt_bot/wiki
![demo2](https://i.imgur.com/BKZktWd.jpg)
https://github.com/hibobmaster/matrix_chatgpt_bot/wiki/ <br>
## Thread level Context
Mention bot with prompt, bot will reply in thread.
To keep the context, just send your prompt directly in the thread without mentioning the bot.
![thread level context 1](https://i.imgur.com/4vLvNCt.jpeg)
![thread level context 2](https://i.imgur.com/1eb1Lmd.jpeg)
## Thanks
1. [matrix-nio](https://github.com/poljar/matrix-nio)
2. [acheong08](https://github.com/acheong08)
3. [8go](https://github.com/8go/)
<a href="https://jb.gg/OpenSourceSupport" target="_blank">
<img src="https://resources.jetbrains.com/storage/products/company/brand/logos/jb_beam.png" alt="JetBrains Logo (Main) logo." width="200" height="200">
</a>

View File

@ -12,8 +12,9 @@ services:
# use env file or config.json
# - ./config.json:/app/config.json
# use touch to create empty db file, for persist database only
# manage_db(can be ignored) is for langchain agent, sync_db is for matrix sync database
# manage_db(can be ignored) is for langchain agent, sync_db is for matrix sync database, context.db is for bot chat context
- ./sync_db:/app/sync_db
- ./context.db:/app/context.db
# - ./manage_db:/app/manage_db
# import_keys path
# - ./element-keys.txt:/app/element-keys.txt

View File

@ -227,6 +227,8 @@ class Bot:
self.new_prog = re.compile(r"\s*!new\s+(.+)$")
async def close(self, task: asyncio.Task) -> None:
self.chatbot.cursor.close()
self.chatbot.conn.close()
await self.httpx_client.aclose()
if self.lc_admin is not None:
self.lc_manager.c.close()
@ -251,6 +253,9 @@ class Bot:
# sender_id
sender_id = event.sender
# event source
event_source = event.source
# user_message
raw_user_message = event.body
@ -265,6 +270,48 @@ class Bot:
# remove newline character from event.body
content_body = re.sub("\r\n|\r|\n", " ", raw_user_message)
# @bot and reply in thread
if "m.mentions" in event_source["content"]:
if "user_ids" in event_source["content"]["m.mentions"]:
# @bot
if (
self.user_id
in event_source["content"]["m.mentions"]["user_ids"]
):
try:
asyncio.create_task(
self.thread_chat(
room_id,
reply_to_event_id,
sender_id=sender_id,
thread_root_id=reply_to_event_id,
prompt=content_body,
)
)
except Exception as e:
logger.error(e, exe_info=True)
# thread conversation
if "m.relates_to" in event_source["content"]:
if "rel_type" in event_source["content"]["m.relates_to"]:
thread_root_id = event_source["content"]["m.relates_to"]["event_id"]
# thread is created by @bot
if thread_root_id in self.chatbot.conversation:
try:
asyncio.create_task(
self.thread_chat(
room_id,
reply_to_event_id,
sender_id=sender_id,
thread_root_id=thread_root_id,
prompt=content_body,
)
)
except Exception as e:
logger.error(e, exe_info=True)
# common command
# !gpt command
if (
self.openai_api_key is not None
@ -1300,6 +1347,37 @@ class Bot:
estr = traceback.format_exc()
logger.info(estr)
# thread chat
async def thread_chat(
    self, room_id, reply_to_event_id, thread_root_id, prompt, sender_id
):
    """Answer a prompt inside a Matrix thread, keeping per-thread context.

    The thread root event id doubles as the chatbot conversation id, so
    each thread gets its own isolated chat history.

    Args:
        room_id: room in which the reply is sent.
        reply_to_event_id: event id the reply refers to.
        thread_root_id: root event id of the thread (conversation key).
        prompt: the user's message text.
        sender_id: Matrix user id of the prompt's author.
    """
    try:
        # Show a typing indicator while the model request is in flight.
        await self.client.room_typing(room_id, timeout=int(self.timeout) * 1000)
        content = await self.chatbot.ask_async_v2(
            prompt=prompt,
            convo_id=thread_root_id,
        )
        await send_room_message(
            self.client,
            room_id,
            reply_message=content,
            reply_to_event_id=reply_to_event_id,
            sender_id=sender_id,
            reply_in_thread=True,
            thread_root_id=thread_root_id,
        )
    except Exception as e:
        # FIX: the logging keyword is exc_info, not exe_info — the typo made
        # logger.error raise TypeError here, swallowing the original traceback.
        logger.error(e, exc_info=True)
        await send_room_message(
            self.client,
            room_id,
            reply_message=GENERAL_ERROR_MESSAGE,
            sender_id=sender_id,
            reply_to_event_id=reply_to_event_id,
            reply_in_thread=True,
            thread_root_id=thread_root_id,
        )
# !chat command
async def chat(self, room_id, reply_to_event_id, prompt, sender_id, user_message):
try:

View File

@ -2,6 +2,7 @@
Code derived from https://github.com/acheong08/ChatGPT/blob/main/src/revChatGPT/V3.py
A simple wrapper for the official ChatGPT API
"""
import sqlite3
import json
from typing import AsyncGenerator
from tenacity import retry, wait_random_exponential, stop_after_attempt
@ -9,16 +10,7 @@ import httpx
import tiktoken
ENGINES = [
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4",
"gpt-4-32k",
"gpt-4-0613",
"gpt-4-32k-0613",
]
ENGINES = ["gpt-3.5-turbo", "gpt-4", "gpt-4-32k", "gpt-4-turbo"]
class Chatbot:
@ -41,6 +33,7 @@ class Chatbot:
reply_count: int = 1,
truncate_limit: int = None,
system_prompt: str = None,
db_path: str = "context.db",
) -> None:
"""
Initialize Chatbot with API key (from https://platform.openai.com/account/api-keys)
@ -53,23 +46,24 @@ class Chatbot:
or "You are ChatGPT, \
a large language model trained by OpenAI. Respond conversationally"
)
# https://platform.openai.com/docs/models
self.max_tokens: int = max_tokens or (
31000
127000
if "gpt-4-turbo" in engine
else 31000
if "gpt-4-32k" in engine
else 7000
if "gpt-4" in engine
else 15000
if "gpt-3.5-turbo-16k" in engine
else 4000
else 16000
)
self.truncate_limit: int = truncate_limit or (
30500
126500
if "gpt-4-turbo" in engine
else 30500
if "gpt-4-32k" in engine
else 6500
if "gpt-4" in engine
else 14500
if "gpt-3.5-turbo-16k" in engine
else 3500
else 15500
)
self.temperature: float = temperature
self.top_p: float = top_p
@ -80,18 +74,50 @@ class Chatbot:
self.aclient = aclient
self.conversation: dict[str, list[dict]] = {
"default": [
{
"role": "system",
"content": system_prompt,
},
],
}
self.db_path = db_path
self.conn = sqlite3.connect(self.db_path)
self.cursor = self.conn.cursor()
self._create_tables()
self.conversation = self._load_conversation()
if self.get_token_count("default") > self.max_tokens:
raise Exception("System prompt is too long")
def _create_tables(self) -> None:
self.conn.execute(
"""
CREATE TABLE IF NOT EXISTS conversations(
id INTEGER PRIMARY KEY AUTOINCREMENT,
convo_id TEXT UNIQUE,
messages TEXT
)
"""
)
def _load_conversation(self) -> dict[str, list[dict]]:
conversations: dict[str, list[dict]] = {
"default": [
{
"role": "system",
"content": self.system_prompt,
},
],
}
self.cursor.execute("SELECT convo_id, messages FROM conversations")
for convo_id, messages in self.cursor.fetchall():
conversations[convo_id] = json.loads(messages)
return conversations
def _save_conversation(self, convo_id) -> None:
self.conn.execute(
"INSERT OR REPLACE INTO conversations (convo_id, messages) VALUES (?, ?)",
(convo_id, json.dumps(self.conversation[convo_id])),
)
self.conn.commit()
def add_to_conversation(
self,
message: str,
@ -102,6 +128,7 @@ class Chatbot:
Add a message to the conversation
"""
self.conversation[convo_id].append({"role": role, "content": message})
self._save_conversation(convo_id)
def __truncate_conversation(self, convo_id: str = "default") -> None:
"""
@ -116,6 +143,7 @@ class Chatbot:
self.conversation[convo_id].pop(1)
else:
break
self._save_conversation(convo_id)
# https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def get_token_count(self, convo_id: str = "default") -> int:
@ -305,6 +333,7 @@ class Chatbot:
self.conversation[convo_id] = [
{"role": "system", "content": system_prompt or self.system_prompt},
]
self._save_conversation(convo_id)
@retry(wait=wait_random_exponential(min=2, max=5), stop=stop_after_attempt(3))
async def oneTimeAsk(

View File

@ -12,6 +12,8 @@ async def send_room_message(
sender_id: str = "",
user_message: str = "",
reply_to_event_id: str = "",
reply_in_thread: bool = False,
thread_root_id: str = "",
) -> None:
if reply_to_event_id == "":
content = {
@ -23,6 +25,23 @@ async def send_room_message(
extensions=["nl2br", "tables", "fenced_code"],
),
}
elif reply_in_thread and thread_root_id:
content = {
"msgtype": "m.text",
"body": reply_message,
"format": "org.matrix.custom.html",
"formatted_body": markdown.markdown(
reply_message,
extensions=["nl2br", "tables", "fenced_code"],
),
"m.relates_to": {
"m.in_reply_to": {"event_id": reply_to_event_id},
"rel_type": "m.thread",
"event_id": thread_root_id,
"is_falling_back": True,
},
}
else:
body = "> <" + sender_id + "> " + user_message + "\n\n" + reply_message
format = r"org.matrix.custom.html"
@ -51,13 +70,11 @@ async def send_room_message(
"formatted_body": formatted_body,
"m.relates_to": {"m.in_reply_to": {"event_id": reply_to_event_id}},
}
try:
await client.room_send(
room_id,
message_type="m.room.message",
content=content,
ignore_unverified_devices=True,
)
await client.room_typing(room_id, typing_state=False)
except Exception as e:
logger.error(e)
await client.room_send(
room_id,
message_type="m.room.message",
content=content,
ignore_unverified_devices=True,
)
await client.room_typing(room_id, typing_state=False)