diff --git a/.full-env.example b/.full-env.example index de666f9..45f2d9e 100644 --- a/.full-env.example +++ b/.full-env.example @@ -1,6 +1,7 @@ HOMESERVER="https://matrix-client.matrix.org" USER_ID="@lullap:xxxxxxxxxxxxx.xxx" PASSWORD="xxxxxxxxxxxxxxx" +ACCESS_TOKEN="xxxxxxxxxxx" DEVICE_ID="xxxxxxxxxxxxxx" ROOM_ID="!FYCmBSkCRUXXXXXXXXX:matrix.XXX.XXX" IMPORT_KEYS_PATH="element-keys.txt" diff --git a/.gitignore b/.gitignore index 9f22577..f0bc170 100644 --- a/.gitignore +++ b/.gitignore @@ -172,3 +172,4 @@ cython_debug/ # Custom sync_db manage_db +element-keys.txt diff --git a/full-config.json.example b/full-config.json.example index aef2104..4d7d708 100644 --- a/full-config.json.example +++ b/full-config.json.example @@ -2,6 +2,7 @@ "homeserver": "https://matrix-client.matrix.org", "user_id": "@lullap:xxxxx.org", "password": "xxxxxxxxxxxxxxxxxx", + "access_token": "xxxxxxxxxxxxxx", "device_id": "MatrixChatGPTBot", "room_id": "!xxxxxxxxxxxxxxxxxxxxxx:xxxxx.org", "import_keys_path": "element-keys.txt", diff --git a/src/bot.py b/src/bot.py index e62c611..7ee7590 100644 --- a/src/bot.py +++ b/src/bot.py @@ -26,6 +26,7 @@ from nio import ( MegolmEvent, RoomMessageText, ToDeviceError, + WhoamiResponse, ) from nio.store.database import SqliteStore @@ -48,8 +49,9 @@ class Bot: self, homeserver: str, user_id: str, + device_id: str, password: Union[str, None] = None, - device_id: str = "MatrixChatGPTBot", + access_token: Union[str, None] = None, room_id: Union[str, None] = None, import_keys_path: Optional[str] = None, import_keys_password: Optional[str] = None, @@ -72,7 +74,7 @@ class Bot: logger.warning("homeserver && user_id && device_id is required") sys.exit(1) - if password is None: + if password is None and access_token is None: logger.warning("password is required") sys.exit(1) @@ -87,6 +89,7 @@ class Bot: self.homeserver: str = homeserver self.user_id: str = user_id self.password: str = password + self.access_token: str = access_token self.device_id: str = device_id self.room_id: str = room_id @@ -1418,13 +1421,33 @@ class Bot: # bot login async def login(self) -> None: - resp = await self.client.login(password=self.password, device_name=DEVICE_NAME) - if not isinstance(resp, LoginResponse): - logger.error("Login Failed") - await self.httpx_client.aclose() - await self.client.close() + try: + if self.password is not None: + resp = await self.client.login( + password=self.password, device_name=DEVICE_NAME + ) + if not isinstance(resp, LoginResponse): + logger.error("Login Failed") + await self.httpx_client.aclose() + await self.client.close() + sys.exit(1) + logger.info("Successfully login via password") + elif self.access_token is not None: + self.client.restore_login( + user_id=self.user_id, + device_id=self.device_id, + access_token=self.access_token, + ) + resp = await self.client.whoami() + if not isinstance(resp, WhoamiResponse): + logger.error("Login Failed") + await self.close() + sys.exit(1) + logger.info("Successfully login via access_token") + except Exception as e: + logger.error(e) + await self.close() sys.exit(1) - logger.info("Success login via password") # import keys async def import_keys(self): @@ -1434,9 +1457,7 @@ class Bot: if isinstance(resp, EncryptionError): logger.error(f"import_keys failed with {resp}") else: - logger.info( - "import_keys success, please remove import_keys configuration!!!" - ) + logger.info("import_keys success, you can remove import_keys configuration") # sync messages in the room async def sync_forever(self, timeout=30000, full_state=True) -> None: diff --git a/src/main.py b/src/main.py index ac4d73d..3641283 100644 --- a/src/main.py +++ b/src/main.py @@ -26,6 +26,7 @@ async def main(): homeserver=config.get("homeserver"), user_id=config.get("user_id"), password=config.get("password"), + access_token=config.get("access_token"), device_id=config.get("device_id"), room_id=config.get("room_id"), import_keys_path=config.get("import_keys_path"), @@ -56,6 +57,7 @@ async def main(): homeserver=os.environ.get("HOMESERVER"), user_id=os.environ.get("USER_ID"), password=os.environ.get("PASSWORD"), + access_token=os.environ.get("ACCESS_TOKEN"), device_id=os.environ.get("DEVICE_ID"), room_id=os.environ.get("ROOM_ID"), import_keys_path=os.environ.get("IMPORT_KEYS_PATH"), @@ -98,6 +100,9 @@ async def main(): lambda: asyncio.create_task(matrix_bot.close(sync_task)), ) + if matrix_bot.client.should_upload_keys: + await matrix_bot.client.keys_upload() + await sync_task