diff --git a/pantalaimon/daemon.py b/pantalaimon/daemon.py index f00185f..b0949e4 100755 --- a/pantalaimon/daemon.py +++ b/pantalaimon/daemon.py @@ -12,7 +12,8 @@ import keyring from aiohttp import ClientSession, web from aiohttp.client_exceptions import ClientConnectionError, ContentTypeError from multidict import CIMultiDict -from nio import EncryptionError, LoginResponse, SendRetryError, OlmTrustError +from nio import (EncryptionError, LoginResponse, SendRetryError, OlmTrustError, + Api) from pantalaimon.client import PanClient from pantalaimon.log import logger @@ -61,8 +62,6 @@ class ProxyDaemon: self.store = PanStore(self.data_dir) accounts = self.store.load_users(self.name) - self.client_info = self.store.load_clients(self.name) - for user_id, device_id in accounts: token = keyring.get_password( "pantalaimon", @@ -92,6 +91,50 @@ class ProxyDaemon: pan_client.start_loop() + async def _find_client(self, access_token): + client_info = self.client_info.get(access_token, None) + + if not client_info: + async with aiohttp.ClientSession() as session: + try: + method, path = Api.whoami(access_token) + resp = await session.request( + method, + self.homeserver_url + path, + proxy=self.proxy, + ssl=self.ssl + ) + except ClientConnectionError: + return None + + if resp.status != 200: + return None + + try: + body = await resp.json() + except (JSONDecodeError, ContentTypeError): + return None + + try: + user_id = body["user_id"] + except KeyError: + return None + + if user_id not in self.pan_clients: + logger.warn(f"User {user_id} doesn't have a matching pan " + f"client.") + return None + + logger.info(f"Homeserver confirmed valid access token " + f"for user {user_id}, caching info.") + + client_info = ClientInfo(user_id, access_token) + self.client_info[access_token] = client_info + + client = self.pan_clients.get(client_info.user_id, None) + + return client + async def _verify_device(self, message_id, client, device): ret = client.verify_device(device) @@ -424,7 +467,6 @@ class ProxyDaemon: async def start_pan_client(self, access_token, user, user_id, password): client = ClientInfo(user_id, access_token) self.client_info[access_token] = client - self.store.save_client(self.name, client) self.store.save_server_user(self.name, user_id) if user_id in self.pan_clients: @@ -578,10 +620,8 @@ class ProxyDaemon: if not access_token: return self._missing_token - try: - client_info = self.client_info[access_token] - client = self.pan_clients[client_info.user_id] - except KeyError: + client = await self._find_client(access_token) + if not client: return self._unknown_token sync_filter = request.query.get("filter", None) @@ -631,10 +671,8 @@ class ProxyDaemon: if not access_token: return self._missing_token - try: - client_info = self.client_info[access_token] - client = self.pan_clients[client_info.user_id] - except KeyError: + client = await self._find_client(access_token) + if not client: return self._unknown_token try: @@ -670,10 +708,8 @@ class ProxyDaemon: if not access_token: return self._missing_token - try: - client_info = self.client_info[access_token] - client = self.pan_clients[client_info.user_id] - except KeyError: + client = await self._find_client(access_token) + if not client: return self._unknown_token room_id = request.match_info["room_id"]