feat: Add GPT Vision
This commit is contained in:
parent
80409811dc
commit
c9fa6065fa
|
@ -24,4 +24,6 @@ IMAGE_FORMAT="webp"
|
|||
SDWUI_STEPS=20
|
||||
SDWUI_SAMPLER_NAME="Euler a"
|
||||
SDWUI_CFG_SCALE=7
|
||||
GPT_VISION_MODEL="llava"
|
||||
GPT_VISION_API_ENDPOINT="https://localai.xxxxxxx.xxxxxxx/v1/chat/completions"
|
||||
TIMEOUT=120.0
|
||||
|
|
13
README.md
13
README.md
|
@ -1,7 +1,8 @@
|
|||
## Introduction
|
||||
|
||||
This is a simple Matrix bot that support using OpenAI API, Langchain to generate responses from user inputs. The bot responds to these commands: `!gpt`, `!chat`, `!pic`, `!new`, `!lc` and `!help` depending on the first word of the prompt.
|
||||
This is a simple Matrix bot that supports using the OpenAI API and Langchain to generate responses from user inputs. The bot responds to these commands: `!gpt`, `!chat`, `!v`, `!pic`, `!new`, `!lc` and `!help`, depending on the first word of the prompt.
|
||||
![ChatGPT](https://i.imgur.com/kK4rnPf.jpeg)
|
||||
![GPT Vision](https://i.imgur.com/6EqC603.jpeg)
|
||||
|
||||
## Feature
|
||||
|
||||
|
@ -10,7 +11,7 @@ This is a simple Matrix bot that support using OpenAI API, Langchain to generate
|
|||
3. Colorful code blocks
|
||||
4. Langchain([Flowise](https://github.com/FlowiseAI/Flowise))
|
||||
5. Image Generation with [DALL·E](https://platform.openai.com/docs/api-reference/images/create) or [LocalAI](https://localai.io/features/image-generation/) or [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API)
|
||||
|
||||
6. GPT Vision (OpenAI, or any [GPT Vision API](https://platform.openai.com/docs/guides/vision)-compatible backend such as [LocalAI](https://localai.io/features/gpt-vision/))
|
||||
|
||||
## Installation and Setup
|
||||
|
||||
|
@ -83,6 +84,14 @@ To interact with the bot, simply send a message to the bot in the Matrix room wi
|
|||
!chat Can you tell me a joke?
|
||||
```
|
||||
|
||||
- `!v` GPT Vision command
|
||||
```
|
||||
# if image is uploaded by bot
|
||||
!v what is in the image?
|
||||
# else you should @bot_account before the command
|
||||
@bot !v what is in the image?
|
||||
```
|
||||
|
||||
- `!lc` To chat using langchain api endpoint
|
||||
```
|
||||
!lc All the world is a stage
|
||||
|
|
|
@ -25,5 +25,7 @@
|
|||
"sdwui_sampler_name": "Euler a",
|
||||
"sdwui_cfg_scale": 7,
|
||||
"image_format": "webp",
|
||||
"gpt_vision_api_endpoint": "https://api.openai.com/v1/chat/completions",
|
||||
"gpt_vision_model": "gpt-4-vision-preview",
|
||||
"timeout": 120.0
|
||||
}
|
||||
|
|
143
src/bot.py
143
src/bot.py
|
@ -6,6 +6,7 @@ import sys
|
|||
import traceback
|
||||
from typing import Union, Optional
|
||||
import aiofiles.os
|
||||
import base64
|
||||
|
||||
import httpx
|
||||
|
||||
|
@ -21,6 +22,7 @@ from nio import (
|
|||
KeyVerificationMac,
|
||||
KeyVerificationStart,
|
||||
LocalProtocolError,
|
||||
DownloadError,
|
||||
LoginResponse,
|
||||
MatrixRoom,
|
||||
MegolmEvent,
|
||||
|
@ -29,6 +31,7 @@ from nio import (
|
|||
WhoamiResponse,
|
||||
)
|
||||
from nio.store.database import SqliteStore
|
||||
from nio.api import Api
|
||||
|
||||
from log import getlogger
|
||||
from send_image import send_room_image
|
||||
|
@ -36,6 +39,7 @@ from send_message import send_room_message
|
|||
from flowise import flowise_query
|
||||
from lc_manager import LCManager
|
||||
from gptbot import Chatbot
|
||||
from gpt_vision import gpt_vision_query
|
||||
import imagegen
|
||||
|
||||
logger = getlogger()
|
||||
|
@ -73,6 +77,8 @@ class Bot:
|
|||
sdwui_steps: Optional[int] = None,
|
||||
sdwui_sampler_name: Optional[str] = None,
|
||||
sdwui_cfg_scale: Optional[float] = None,
|
||||
gpt_vision_model: Optional[str] = None,
|
||||
gpt_vision_api_endpoint: Optional[str] = None,
|
||||
timeout: Union[float, None] = None,
|
||||
):
|
||||
if homeserver is None or user_id is None or device_id is None:
|
||||
|
@ -127,6 +133,9 @@ class Bot:
|
|||
self.image_generation_endpoint: str = image_generation_endpoint
|
||||
self.image_generation_backend: str = image_generation_backend
|
||||
|
||||
self.gpt_vision_model = gpt_vision_model
|
||||
self.gpt_vision_api_endpoint = gpt_vision_api_endpoint
|
||||
|
||||
if image_format:
|
||||
self.image_format: str = image_format
|
||||
else:
|
||||
|
@ -206,15 +215,16 @@ class Bot:
|
|||
self.to_device_callback, (KeyVerificationEvent,)
|
||||
)
|
||||
|
||||
# regular expression to match keyword commands
|
||||
self.gpt_prog = re.compile(r"^\s*!gpt\s+(.+)$")
|
||||
self.chat_prog = re.compile(r"^\s*!chat\s+(.+)$")
|
||||
self.pic_prog = re.compile(r"^\s*!pic\s+(.+)$")
|
||||
self.lc_prog = re.compile(r"^\s*!lc\s+(.+)$")
|
||||
self.lcadmin_prog = re.compile(r"^\s*!lcadmin\s+(.+)$")
|
||||
self.agent_prog = re.compile(r"^\s*!agent\s+(.+)$")
|
||||
self.help_prog = re.compile(r"^\s*!help\s*.*$")
|
||||
self.new_prog = re.compile(r"^\s*!new\s+(.+)$")
|
||||
# regular expression to search keyword commands
|
||||
self.gpt_prog = re.compile(r"\s*!gpt\s+(.+)$")
|
||||
self.chat_prog = re.compile(r"\s*!chat\s+(.+)$")
|
||||
self.pic_prog = re.compile(r"\s*!pic\s+(.+)$")
|
||||
self.lc_prog = re.compile(r"\s*!lc\s+(.+)$")
|
||||
self.lcadmin_prog = re.compile(r"\s*!lcadmin\s+(.+)$")
|
||||
self.agent_prog = re.compile(r"\s*!agent\s+(.+)$")
|
||||
self.gpt_vision_prog = re.compile(r"\s*!v\s+(.+)$")
|
||||
self.help_prog = re.compile(r"\s*!help\s*.*$")
|
||||
self.new_prog = re.compile(r"\s*!new\s+(.+)$")
|
||||
|
||||
async def close(self, task: asyncio.Task) -> None:
|
||||
await self.httpx_client.aclose()
|
||||
|
@ -297,10 +307,59 @@ class Bot:
|
|||
except Exception as e:
|
||||
logger.error(e, exc_info=True)
|
||||
|
||||
# !v command
|
||||
# not work in E2EE room
|
||||
if self.gpt_vision_api_endpoint and self.gpt_vision_model:
|
||||
if (
|
||||
"m.relates_to" in event.source["content"]
|
||||
and "m.mentions" in event.source["content"]
|
||||
and "user_ids" in event.source["content"]["m.mentions"]
|
||||
):
|
||||
if (
|
||||
self.user_id
|
||||
in event.source["content"]["m.mentions"]["user_ids"]
|
||||
):
|
||||
v = self.gpt_vision_prog.search(content_body)
|
||||
if v:
|
||||
prompt = v.group(1)
|
||||
# Trigger gpt vision flow
|
||||
in_reply_to_event_id = event.source["content"][
|
||||
"m.relates_to"
|
||||
]["m.in_reply_to"]["event_id"]
|
||||
event_info = await self.get_event(
|
||||
room_id, in_reply_to_event_id
|
||||
)
|
||||
msgtype = event_info["content"]["msgtype"]
|
||||
if "m.image" == msgtype:
|
||||
image_mimetype = event_info["content"]["info"][
|
||||
"mimetype"
|
||||
]
|
||||
url = event_info["content"]["url"]
|
||||
resp = await self.download_mxc(url)
|
||||
if isinstance(resp, DownloadError):
|
||||
logger.error("Download of image failed")
|
||||
else:
|
||||
b64_image = base64.b64encode(resp.body).decode(
|
||||
"utf-8"
|
||||
)
|
||||
image_url = (
|
||||
f"data:{image_mimetype};base64,{b64_image}"
|
||||
)
|
||||
asyncio.create_task(
|
||||
self.gpt_vision_cmd(
|
||||
room_id,
|
||||
reply_to_event_id,
|
||||
prompt,
|
||||
image_url,
|
||||
sender_id,
|
||||
raw_user_message,
|
||||
)
|
||||
)
|
||||
|
||||
# lc command
|
||||
if self.lc_admin is not None:
|
||||
perm_flags = 0
|
||||
m = self.lc_prog.match(content_body)
|
||||
m = self.lc_prog.search(content_body)
|
||||
if m:
|
||||
try:
|
||||
# room_level permission
|
||||
|
@ -443,7 +502,7 @@ class Bot:
|
|||
!lcadmin list
|
||||
""" # noqa: E501
|
||||
if self.lc_admin is not None:
|
||||
q = self.lcadmin_prog.match(content_body)
|
||||
q = self.lcadmin_prog.search(content_body)
|
||||
if q:
|
||||
if sender_id in self.lc_admin:
|
||||
try:
|
||||
|
@ -885,7 +944,7 @@ class Bot:
|
|||
)
|
||||
|
||||
# !agent command
|
||||
a = self.agent_prog.match(content_body)
|
||||
a = self.agent_prog.search(content_body)
|
||||
if a:
|
||||
command_with_params = a.group(1).strip()
|
||||
split_items = re.sub("\s{1,}", " ", command_with_params).split(" ")
|
||||
|
@ -956,7 +1015,7 @@ class Bot:
|
|||
logger.error(e, exc_info=True)
|
||||
|
||||
# !new command
|
||||
n = self.new_prog.match(content_body)
|
||||
n = self.new_prog.search(content_body)
|
||||
if n:
|
||||
new_command = n.group(1)
|
||||
try:
|
||||
|
@ -973,7 +1032,7 @@ class Bot:
|
|||
logger.error(e, exc_info=True)
|
||||
|
||||
# !pic command
|
||||
p = self.pic_prog.match(content_body)
|
||||
p = self.pic_prog.search(content_body)
|
||||
if p:
|
||||
prompt = p.group(1)
|
||||
try:
|
||||
|
@ -990,7 +1049,7 @@ class Bot:
|
|||
logger.error(e, exc_info=True)
|
||||
|
||||
# help command
|
||||
h = self.help_prog.match(content_body)
|
||||
h = self.help_prog.search(content_body)
|
||||
if h:
|
||||
try:
|
||||
asyncio.create_task(
|
||||
|
@ -1288,6 +1347,42 @@ class Bot:
|
|||
room_id, reply_to_event_id, sender_id, user_message
|
||||
)
|
||||
|
||||
# !v command
|
||||
async def gpt_vision_cmd(
    self,
    room_id: str,
    reply_to_event_id: str,
    prompt: str,
    image_url: str,
    sender_id: str,
    user_message: str,
) -> None:
    """Handle the ``!v`` command: query the vision model and reply in the room.

    Sends a typing notification, forwards ``prompt`` and ``image_url`` to
    ``gpt_vision_query`` using the configured endpoint/model, and posts the
    stripped answer as a reply. Any exception is logged and answered with the
    general error message instead of propagating.
    """
    try:
        # self.timeout is in seconds; the typing state wants milliseconds
        typing_timeout_ms = int(self.timeout) * 1000
        await self.client.room_typing(room_id, timeout=typing_timeout_ms)

        answer = await gpt_vision_query(
            self.gpt_vision_api_endpoint,
            prompt,
            image_url,
            self.gpt_vision_model,
            self.httpx_client,
            api_key=self.openai_api_key,
            timeout=self.timeout,
        )

        await send_room_message(
            self.client,
            room_id,
            reply_message=answer.strip(),
            reply_to_event_id=reply_to_event_id,
            sender_id=sender_id,
            user_message=user_message,
        )
    except Exception as err:
        logger.error(err, exc_info=True)
        await self.send_general_error_message(
            room_id, reply_to_event_id, sender_id, user_message
        )
|
||||
|
||||
# !lc command
|
||||
async def lc(
|
||||
self,
|
||||
|
@ -1413,6 +1508,7 @@ class Bot:
|
|||
+ "!pic [prompt], Image generation by DALL·E or LocalAI or stable-diffusion-webui\n" # noqa: E501
|
||||
+ "!new + chat, start a new conversation \n"
|
||||
+ "!lc [prompt], chat using langchain api\n"
|
||||
+ "!v [prompt], gpt_vision\n"
|
||||
+ "!help, help message"
|
||||
) # noqa: E501
|
||||
|
||||
|
@ -1464,6 +1560,7 @@ class Bot:
|
|||
await self.client.close()
|
||||
sys.exit(1)
|
||||
logger.info("Successfully login via password")
|
||||
self.access_token = resp.access_token
|
||||
elif self.access_token is not None:
|
||||
self.client.restore_login(
|
||||
user_id=self.user_id,
|
||||
|
@ -1494,3 +1591,19 @@ class Bot:
|
|||
# sync messages in the room
|
||||
async def sync_forever(self, timeout=30000, full_state=True) -> None:
    """Delegate to the client's ``sync_forever`` with the given options."""
    sync_options = {"timeout": timeout, "full_state": full_state}
    await self.client.sync_forever(**sync_options)
|
||||
|
||||
# get event from http
|
||||
async def get_event(self, room_id: str, event_id: str) -> dict:
    """Fetch a single room event over HTTP and return the decoded JSON body.

    The HTTP method and path come from ``Api.room_get_event``; the path is
    appended to ``self.homeserver``. Only GET and POST are dispatched; for
    any other method nothing is requested and ``None`` is returned.
    """
    method, path = Api.room_get_event(self.access_token, room_id, event_id)
    url = self.homeserver + path

    senders = {
        "GET": self.httpx_client.get,
        "POST": self.httpx_client.post,
    }
    send = senders.get(method)
    if send is None:
        return None

    resp = await send(url)
    return resp.json()
|
||||
|
||||
# download mxc
|
||||
async def download_mxc(self, mxc: str, filename: Optional[str] = None):
    """Download the content behind an ``mxc`` URI via the client.

    Returns whatever ``self.client.download`` returns; the caller checks the
    result for ``DownloadError``.
    """
    return await self.client.download(mxc, filename)
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
import httpx
|
||||
|
||||
|
||||
async def gpt_vision_query(
    api_url: str,
    prompt: str,
    image_url: str,
    model: str,
    session: httpx.AsyncClient,
    **kwargs,
) -> str:
    """Ask an OpenAI-compatible chat-completions endpoint about an image.

    Args:
        api_url: Full URL of the chat-completions endpoint.
        prompt: Text question sent alongside the image.
        image_url: URL (or ``data:`` URI) of the image, sent as the
            ``image_url`` content part.
        model: Model name, e.g. ``gpt-4-vision-preview`` or ``llava``.
        session: Client used to POST the request.
        **kwargs: ``api_key`` - bearer token; the ``Authorization`` header is
            omitted when it is missing or empty. ``timeout`` - request
            timeout passed to the client, 120.0 seconds when not given.

    Returns:
        The ``content`` of the first choice's message.

    Raises:
        httpx.HTTPStatusError: when the response status indicates an error.
    """
    headers = {"Content-Type": "application/json"}
    # Avoid sending "Bearer None" / "Bearer " when no key is configured
    # (e.g. a self-hosted backend without authentication).
    api_key = kwargs.get("api_key")
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    payload = {
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": image_url}},
                ],
            }
        ],
    }

    response = await session.post(
        api_url,
        headers=headers,
        json=payload,
        # numeric default; an explicitly passed value (including None) is kept
        timeout=kwargs.get("timeout", 120.0),
    )
    # Fail loudly on error statuses instead of silently returning None,
    # so the function always either returns a str or raises.
    response.raise_for_status()
    first_choice = response.json()["choices"][0]
    return first_choice["message"]["content"]
|
||||
|
||||
|
||||
async def test():
    """Manual smoke test: send one vision query to a local endpoint and print it."""
    endpoint = "http://127.0.0.1:12345/v1/chat/completions"
    question = "What is in the image?"
    picture = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
    async with httpx.AsyncClient() as session:
        answer = await gpt_vision_query(
            endpoint,
            question,
            picture,
            "llava",
            session,
            api_key="xxxx",
            timeout=300,
        )
        print(answer)
|
||||
|
||||
|
||||
# Run the manual smoke test when this module is executed as a script.
if __name__ == "__main__":
    from asyncio import run

    run(test())
|
|
@ -49,6 +49,8 @@ async def main():
|
|||
sdwui_sampler_name=config.get("sdwui_sampler_name"),
|
||||
sdwui_cfg_scale=config.get("sdwui_cfg_scale"),
|
||||
image_format=config.get("image_format"),
|
||||
gpt_vision_model=config.get("gpt_vision_model"),
|
||||
gpt_vision_api_endpoint=config.get("gpt_vision_api_endpoint"),
|
||||
timeout=config.get("timeout"),
|
||||
)
|
||||
if (
|
||||
|
@ -85,6 +87,8 @@ async def main():
|
|||
sdwui_sampler_name=os.environ.get("SDWUI_SAMPLER_NAME"),
|
||||
sdwui_cfg_scale=float(os.environ.get("SDWUI_CFG_SCALE", 7)),
|
||||
image_format=os.environ.get("IMAGE_FORMAT"),
|
||||
gpt_vision_model=os.environ.get("GPT_VISION_MODEL"),
|
||||
gpt_vision_api_endpoint=os.environ.get("GPT_VISION_API_ENDPOINT"),
|
||||
timeout=float(os.environ.get("TIMEOUT", 120.0)),
|
||||
)
|
||||
if (
|
||||
|
|
Loading…
Reference in New Issue