2023-04-13 15:22:24 +00:00
|
|
|
import asyncio
|
2023-03-05 14:07:25 +00:00
|
|
|
import json
|
2023-04-10 02:52:18 +00:00
|
|
|
import os
|
2023-06-05 03:27:37 +00:00
|
|
|
from pathlib import Path
|
2023-09-13 07:27:34 +00:00
|
|
|
import signal
|
|
|
|
import sys
|
2023-09-13 06:36:35 +00:00
|
|
|
|
2023-03-05 14:07:25 +00:00
|
|
|
from bot import Bot
|
2023-04-10 02:52:18 +00:00
|
|
|
from log import getlogger
|
2023-03-05 14:07:25 +00:00
|
|
|
|
2023-04-10 02:52:18 +00:00
|
|
|
# Module-level logger, used by main() and by the __main__ entry point.
logger = getlogger()
|
2023-03-05 14:07:25 +00:00
|
|
|
|
2023-04-11 05:42:43 +00:00
|
|
|
|
2023-03-05 14:07:25 +00:00
|
|
|
async def main():
    """Configure the bot, log in, and run the Matrix sync loop until it ends.

    Settings are read from ``config.json`` in the parent directory of this
    file when that file exists; otherwise they come from environment
    variables. After login, keys are optionally imported, SIGINT/SIGTERM are
    wired to ``matrix_bot.close(sync_task)``, keys are uploaded if the client
    asks for it, and the coroutine then waits on the sync task.
    """
    config_path = Path(os.path.dirname(__file__)).parent / "config.json"

    if os.path.isfile(config_path):
        config = _load_config(config_path)
        matrix_bot = _bot_from_config(config)
        need_import_keys = _should_import_keys(
            config.get("import_keys_path"),
            config.get("import_keys_password"),
        )
    else:
        matrix_bot = _bot_from_env()
        need_import_keys = _should_import_keys(
            os.environ.get("IMPORT_KEYS_PATH"),
            os.environ.get("IMPORT_KEYS_PASSWORD"),
        )

    await matrix_bot.login()

    if need_import_keys:
        logger.info("start import_keys process, this may take a while...")
        await matrix_bot.import_keys()

    sync_task = asyncio.create_task(
        matrix_bot.sync_forever(timeout=30000, full_state=True)
    )

    # The event loop keeps only weak references to tasks, so a fire-and-forget
    # task can be garbage-collected before it finishes. Keep a strong
    # reference to each shutdown task until it completes.
    shutdown_tasks = set()

    def _request_shutdown():
        """Schedule a graceful close of the bot (signal-handler callback)."""
        task = asyncio.create_task(matrix_bot.close(sync_task))
        shutdown_tasks.add(task)
        task.add_done_callback(shutdown_tasks.discard)

    # handle signal interrupt
    loop = asyncio.get_running_loop()
    try:
        for signame in ("SIGINT", "SIGTERM"):
            loop.add_signal_handler(getattr(signal, signame), _request_shutdown)
    except NotImplementedError:
        # add_signal_handler is Unix-only; on other event loops (e.g. on
        # Windows) keep running instead of crashing at startup.
        logger.warning(
            "signal handlers are not supported on this platform; "
            "graceful shutdown on SIGINT/SIGTERM is disabled"
        )

    if matrix_bot.client.should_upload_keys:
        await matrix_bot.client.keys_upload()

    await sync_task


def _should_import_keys(import_keys_path, import_keys_password):
    """Return True when a key import has been configured.

    Requires a truthy key-file path and a password that is not None
    (an empty-string password is still accepted).
    """
    return bool(import_keys_path) and import_keys_password is not None


def _load_config(config_path):
    """Read and parse the JSON config file, exiting the process on failure.

    Args:
        config_path: Path to ``config.json``.

    Returns:
        The parsed top-level JSON object as a dict.

    Logs an error and calls ``sys.exit(1)`` if the file cannot be read,
    is not valid JSON, or does not contain a JSON object.
    """
    try:
        # ``with`` guarantees the file handle is closed after parsing.
        with open(config_path, encoding="utf8") as fp:
            config = json.load(fp)
    except (OSError, ValueError):
        # ValueError also covers json.JSONDecodeError and UnicodeDecodeError.
        logger.error("config.json load error, please check the file")
        sys.exit(1)
    if not isinstance(config, dict):
        # Every setting is read via config.get(...); a top-level list or
        # scalar would otherwise fail later with an AttributeError.
        logger.error("config.json load error, please check the file")
        sys.exit(1)
    return config


def _bot_from_config(config):
    """Build a Bot from the parsed ``config.json`` dict.

    Missing keys are passed through as None; presumably Bot applies its
    own defaults for them (not visible here).
    """
    return Bot(
        homeserver=config.get("homeserver"),
        user_id=config.get("user_id"),
        password=config.get("password"),
        access_token=config.get("access_token"),
        device_id=config.get("device_id"),
        room_id=config.get("room_id"),
        import_keys_path=config.get("import_keys_path"),
        import_keys_password=config.get("import_keys_password"),
        openai_api_key=config.get("openai_api_key"),
        gpt_api_endpoint=config.get("gpt_api_endpoint"),
        gpt_model=config.get("gpt_model"),
        max_tokens=config.get("max_tokens"),
        top_p=config.get("top_p"),
        presence_penalty=config.get("presence_penalty"),
        frequency_penalty=config.get("frequency_penalty"),
        reply_count=config.get("reply_count"),
        system_prompt=config.get("system_prompt"),
        temperature=config.get("temperature"),
        lc_admin=config.get("lc_admin"),
        image_generation_endpoint=config.get("image_generation_endpoint"),
        image_generation_backend=config.get("image_generation_backend"),
        image_generation_size=config.get("image_generation_size"),
        sdwui_steps=config.get("sdwui_steps"),
        sdwui_sampler_name=config.get("sdwui_sampler_name"),
        sdwui_cfg_scale=config.get("sdwui_cfg_scale"),
        image_format=config.get("image_format"),
        gpt_vision_model=config.get("gpt_vision_model"),
        gpt_vision_api_endpoint=config.get("gpt_vision_api_endpoint"),
        timeout=config.get("timeout"),
    )


def _bot_from_env():
    """Build a Bot from environment variables.

    Numeric settings are converted here with explicit fallbacks
    (MAX_TOKENS=4000, TOP_P=1.0, PRESENCE_PENALTY=0.0,
    FREQUENCY_PENALTY=0.0, REPLY_COUNT=1, TEMPERATURE=0.8, SDWUI_STEPS=20,
    SDWUI_CFG_SCALE=7, TIMEOUT=120.0). A malformed numeric value raises
    ValueError from int()/float().
    """
    return Bot(
        homeserver=os.environ.get("HOMESERVER"),
        user_id=os.environ.get("USER_ID"),
        password=os.environ.get("PASSWORD"),
        access_token=os.environ.get("ACCESS_TOKEN"),
        device_id=os.environ.get("DEVICE_ID"),
        room_id=os.environ.get("ROOM_ID"),
        import_keys_path=os.environ.get("IMPORT_KEYS_PATH"),
        import_keys_password=os.environ.get("IMPORT_KEYS_PASSWORD"),
        openai_api_key=os.environ.get("OPENAI_API_KEY"),
        gpt_api_endpoint=os.environ.get("GPT_API_ENDPOINT"),
        gpt_model=os.environ.get("GPT_MODEL"),
        max_tokens=int(os.environ.get("MAX_TOKENS", 4000)),
        top_p=float(os.environ.get("TOP_P", 1.0)),
        presence_penalty=float(os.environ.get("PRESENCE_PENALTY", 0.0)),
        frequency_penalty=float(os.environ.get("FREQUENCY_PENALTY", 0.0)),
        reply_count=int(os.environ.get("REPLY_COUNT", 1)),
        system_prompt=os.environ.get("SYSTEM_PROMPT"),
        temperature=float(os.environ.get("TEMPERATURE", 0.8)),
        lc_admin=os.environ.get("LC_ADMIN"),
        image_generation_endpoint=os.environ.get("IMAGE_GENERATION_ENDPOINT"),
        image_generation_backend=os.environ.get("IMAGE_GENERATION_BACKEND"),
        image_generation_size=os.environ.get("IMAGE_GENERATION_SIZE"),
        sdwui_steps=int(os.environ.get("SDWUI_STEPS", 20)),
        sdwui_sampler_name=os.environ.get("SDWUI_SAMPLER_NAME"),
        sdwui_cfg_scale=float(os.environ.get("SDWUI_CFG_SCALE", 7)),
        image_format=os.environ.get("IMAGE_FORMAT"),
        gpt_vision_model=os.environ.get("GPT_VISION_MODEL"),
        gpt_vision_api_endpoint=os.environ.get("GPT_VISION_API_ENDPOINT"),
        timeout=float(os.environ.get("TIMEOUT", 120.0)),
    )
|
2023-03-05 14:07:25 +00:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: log a startup banner, then run main() to
    # completion on a new asyncio event loop.
    startup_banner = "matrix chatgpt bot start....."
    logger.info(startup_banner)
    asyncio.run(main())
|