mirror of
https://github.com/hibobmaster/matrix_chatgpt_bot.git
synced 2024-10-01 05:35:36 -04:00
add official api_endpoint, properly handle failing request
This commit is contained in:
parent
ecf96124cb
commit
fab7a36bc4
26
ask_gpt.py
26
ask_gpt.py
@ -1,15 +1,9 @@
|
||||
"""
|
||||
api_endpoint from https://github.com/ayaka14732/ChatGPTAPIFree
|
||||
"""
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
api_endpoint_free = "https://chatgpt-api.shn.hk/v1/"
|
||||
headers = {'Content-Type': "application/json"}
|
||||
|
||||
|
||||
async def ask(prompt: str) -> str:
|
||||
async def ask(prompt: str, api_endpoint: str, headers: dict) -> str:
|
||||
jsons = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
@ -20,13 +14,14 @@ async def ask(prompt: str) -> str:
|
||||
],
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
|
||||
while True:
|
||||
max_try = 5
|
||||
while max_try > 0:
|
||||
try:
|
||||
async with session.post(url=api_endpoint_free,
|
||||
async with session.post(url=api_endpoint,
|
||||
json=jsons, headers=headers, timeout=10) as response:
|
||||
status_code = response.status
|
||||
if not status_code == 200:
|
||||
max_try = max_try - 1
|
||||
# wait 2s
|
||||
await asyncio.sleep(2)
|
||||
continue
|
||||
@ -37,14 +32,3 @@ async def ask(prompt: str) -> str:
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
|
||||
|
||||
async def test() -> None:
|
||||
resp = await ask("Hello World")
|
||||
print(resp)
|
||||
# type: str
|
||||
print(type(resp))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test())
|
||||
|
23
bot.py
23
bot.py
@ -9,6 +9,14 @@ from ask_gpt import ask
|
||||
from send_message import send_room_message
|
||||
from v3 import Chatbot
|
||||
|
||||
"""
|
||||
free api_endpoint from https://github.com/ayaka14732/ChatGPTAPIFree
|
||||
"""
|
||||
api_endpoint_list = {
|
||||
"free": "https://chatgpt-api.shn.hk/v1/",
|
||||
"paid": "https://api.openai.com/v1/chat/completions"
|
||||
}
|
||||
|
||||
|
||||
class Bot:
|
||||
def __init__(
|
||||
@ -37,10 +45,21 @@ class Bot:
|
||||
# regular expression to match keyword [!gpt {prompt}] [!chat {prompt}]
|
||||
self.gpt_prog = re.compile(r"^\s*!gpt\s*(.+)$")
|
||||
self.chat_prog = re.compile(r"^\s*!chat\s*(.+)$")
|
||||
# initialize chatbot
|
||||
# initialize chatbot and api_endpoint
|
||||
if self.api_key != '':
|
||||
self.chatbot = Chatbot(api_key=self.api_key)
|
||||
|
||||
self.api_endpoint = api_endpoint_list['paid']
|
||||
# request header for !gpt command
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + self.api_key,
|
||||
}
|
||||
else:
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# message_callback event
|
||||
async def message_callback(self, room: MatrixRoom, event: RoomMessageText) -> None:
|
||||
# remove newline character from event.body
|
||||
@ -55,7 +74,7 @@ class Bot:
|
||||
# sending typing state
|
||||
await self.client.room_typing(room_id)
|
||||
prompt = m.group(1)
|
||||
text = await ask(prompt)
|
||||
text = await ask(prompt, self.api_endpoint, self.headers)
|
||||
text = text.strip()
|
||||
await send_room_message(self.client, room_id, send_text=text)
|
||||
|
||||
|
28
test.py
28
test.py
@ -4,17 +4,37 @@ from ask_gpt import ask
|
||||
import json
|
||||
fp = open("config.json", "r")
|
||||
config = json.load(fp)
|
||||
api_key = config.get('api_key', '')
|
||||
api_endpoint_list = {
|
||||
"free": "https://chatgpt-api.shn.hk/v1/",
|
||||
"paid": "https://api.openai.com/v1/chat/completions"
|
||||
}
|
||||
|
||||
|
||||
def test_v3(prompt: str):
|
||||
bot = Chatbot(api_key=config['api_key'])
|
||||
bot = Chatbot(api_key=api_key)
|
||||
resp = bot.ask(prompt=prompt)
|
||||
print(resp)
|
||||
|
||||
|
||||
async def test_ask(prompt: str):
|
||||
print(await ask(prompt=prompt))
|
||||
async def test_ask_gpt_paid(prompt: str):
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + api_key,
|
||||
}
|
||||
api_endpoint = api_endpoint_list['paid']
|
||||
# test ask_gpt.py ask()
|
||||
print(await ask(prompt, api_endpoint, headers))
|
||||
|
||||
|
||||
async def test_ask_gpt_free(prompt: str):
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
api_endpoint = api_endpoint_list['free']
|
||||
print(await ask(prompt, api_endpoint, headers))
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_v3("Hello World")
|
||||
asyncio.run(test_ask("Hello World"))
|
||||
asyncio.run(test_ask_gpt_paid("Hello World"))
|
||||
asyncio.run(test_ask_gpt_free("Hello World"))
|
||||
|
Loading…
Reference in New Issue
Block a user