diff --git a/pantalaimon/client.py b/pantalaimon/client.py index e45c9b3..2f4f4ec 100644 --- a/pantalaimon/client.py +++ b/pantalaimon/client.py @@ -1,3 +1,4 @@ +import asyncio from typing import Any, Dict from pprint import pformat @@ -6,7 +7,10 @@ from nio import ( RoomEncryptedEvent, MegolmEvent, EncryptionError, - SyncResponse + SyncResponse, + KeysQueryResponse, + LocalProtocolError, + GroupEncryptionError ) from pantalaimon.log import logger @@ -15,6 +19,96 @@ from pantalaimon.log import logger class PantaClient(AsyncClient): """A wrapper class around a nio AsyncClient extending its functionality.""" + def __init__( + self, + homeserver, + user="", + device_id="", + store_path="", + config=None, + ssl=None, + proxy=None + ): + super().__init__(homeserver, user, device_id, store_path, config, + ssl, proxy) + + self.loop_running = False + self.loop_stopped = asyncio.Event() + self.synced = asyncio.Event() + + def verify_devices(self, changed_devices): + # Verify new devices automatically for now. + for user_id, device_dict in changed_devices.items(): + for device in device_dict.values(): + if device.deleted: + continue + + logger.info("Automatically verifying device {} of " + "user {}".format(device.id, user_id)) + self.verify_device(device) + + async def loop(self): + """Start a loop that runs forever and keeps on syncing with the server. + + The loop can be stopped with the stop_loop() method. + """ + self.loop_running = True + self.loop_stopped.clear() + + logger.info(f"Starting sync loop for {self.user_id}") + + while self.loop_running: + if not self.logged_in: + # TODO login + pass + + # TODO use user lazy loading here + response = await self.sync(30000) + + if self.should_upload_keys: + await self.keys_upload() + + if self.should_query_keys: + key_query_response = await self.keys_query() + if isinstance(key_query_response, KeysQueryResponse): + self.verify_devices(key_query_response.changed) + + if not isinstance(response, SyncResponse): + # TODO error handling + pass + + self.synced.set() + self.synced.clear() + + logger.info("Stopping the sync loop") + self.loop_stopped.set() + + async def loop_stop(self): + """Stop the client loop. + + Raises LocalProtocolError if the loop isn't running. + """ + if not self.loop_running: + LocalProtocolError("Loop is not running") + + self.loop_running = False + await self.loop_stopped.wait() + + async def encrypt(self, room_id, msgtype, content): + try: + return super().encrypt( + room_id, + msgtype, + content + ) + except GroupEncryptionError: + await self.share_group_session(room_id) + return super().encrypt( + room_id, + msgtype, + content + ) + def decrypt_sync_body(self, body): # type: (Dict[Any, Any]) -> Dict[Any, Any] """Go through a json sync response and decrypt megolm encrypted events. @@ -34,7 +128,7 @@ class PantaClient(AsyncClient): for event in room_dict["timeline"]["events"]: if event["type"] != "m.room.encrypted": logger.info("Event is not encrypted: " - "{}".format(pformat(event))) + "\n{}".format(pformat(event))) continue parsed_event = RoomEncryptedEvent.parse_event(event) @@ -42,7 +136,7 @@ class PantaClient(AsyncClient): if not isinstance(parsed_event, MegolmEvent): logger.warn("Encrypted event is not a megolm event:" - "{}".format(pformat(event))) + "\n{}".format(pformat(event))) continue try: diff --git a/pantalaimon/daemon.py b/pantalaimon/daemon.py index e8dd328..ea11bd3 100755 --- a/pantalaimon/daemon.py +++ b/pantalaimon/daemon.py @@ -5,6 +5,7 @@ import asyncio import aiohttp import os import json +import ssl import click from ipaddress import ip_address @@ -26,6 +27,12 @@ from pantalaimon.client import PantaClient from pantalaimon.log import logger +@attr.s +class Client: + user_id = attr.ib(type=str) + access_token = attr.ib(type=str) + + @attr.s class ProxyDaemon: homeserver = attr.ib() @@ -33,7 +40,12 @@ class ProxyDaemon: proxy = attr.ib(default=None) ssl = attr.ib(default=None) - client_sessions = attr.ib(init=False, default=attr.Factory(dict)) + panta_clients = attr.ib(init=False, default=attr.Factory(dict)) + client_info = attr.ib( + init=False, + default=attr.Factory(dict), + type=dict + ) default_session = attr.ib(init=False, default=None) def get_access_token(self, request): @@ -55,7 +67,12 @@ class ProxyDaemon: return access_token - async def forward_request(self, request, session): + async def forward_request( + self, + request, + params=None, + session=None + ): # type: (aiohttp.BaseRequest, aiohttp.ClientSession) -> str """Forward the given request to our configured homeserver. @@ -65,14 +82,21 @@ class ProxyDaemon: session (aiohttp.ClientSession): The client session that should be used to forward the request. """ + + if not session: + if not self.default_session: + self.default_session = ClientSession() + session = self.default_session + path = request.path method = request.method - data = await request.text() headers = CIMultiDict(request.headers) headers.pop("Host", None) - params = request.query + params = params or request.query + + data = await request.text() return await session.request( method, @@ -81,25 +105,16 @@ class ProxyDaemon: params=params, headers=headers, proxy=self.proxy, - ssl=False + ssl=self.ssl ) async def router(self, request): """Catchall request router.""" - session = None + resp = await self.forward_request(request) - token = self.get_access_token(request) - client = self.client_sessions.get(token, None) - - if client: - session = client.client_session - else: - if not self.default_session: - self.default_session = ClientSession() - session = self.default_session - - resp = await self.forward_request(request, session) - return(web.Response(text=await resp.text())) + return( + await self.to_web_response(resp) + ) def _get_login_user(self, body): identifier = body.get("identifier", None) @@ -114,6 +129,35 @@ class ProxyDaemon: return user + async def start_panta_client(self, access_token, user, user_id, password): + client = Client(user_id, access_token) + self.client_info[access_token] = client + + if user_id in self.panta_clients: + logger.info(f"Background sync client already exists for {user_id}," + f" not starting new one") + return + + panta_client = PantaClient( + self.homeserver, + user, + store_path=self.data_dir, + ssl=self.ssl, + proxy=self.proxy + ) + response = await panta_client.login(password, "pantalaimon") + + if not isinstance(response, LoginResponse): + await panta_client.close() + return + + logger.info(f"Succesfully started new background sync client for " + f"{user_id}") + + self.panta_clients[user_id] = panta_client + + loop = asyncio.get_event_loop() + loop.create_task(panta_client.loop()) async def login(self, request): try: @@ -137,28 +181,30 @@ class ProxyDaemon: user = self._get_login_user(body) password = body.get("password", "") - device_id = body.get("device_id", "") - device_name = body.get("initial_device_display_name", "pantalaimon") - client = PantaClient( - self.homeserver, - user, - device_id, - store_path=self.data_dir, - ssl=self.ssl, - proxy=self.proxy - ) + logger.info(f"New user logging in: {user}") - response = await client.login(password, device_name) + response = await self.forward_request(request) - if isinstance(response, LoginResponse): - self.client_sessions[response.access_token] = client - else: - await client.close() + try: + json_response = await response.json() + except JSONDecodeError: + json_response = None + pass + + if response.status == 200 and json_response: + user_id = json_response.get("user_id", None) + access_token = json_response.get("access_token", None) + + if user_id and access_token: + logger.info(f"User: {user} succesfully logged in, starting " + f"a background sync client.") + await self.start_panta_client(access_token, user, user_id, + password) return web.Response( - status=response.transport_response.status, - text=await response.transport_response.text() + status=response.status, + text=await response.text() ) @property @@ -198,7 +244,8 @@ class ProxyDaemon: return self._missing_token try: - client = self.client_sessions[access_token] + client_info = self.client_info[access_token] + client = self.panta_clients[client_info.user_id] except KeyError: return self._unknown_token @@ -223,39 +270,27 @@ class ProxyDaemon: # if timeline_filter: # types_filter = timeline_filter.get("types", None) - response = await client.sync(timeout, sync_filter) + query = CIMultiDict(request.query) + query.pop("filter", None) + + response = await self.forward_request(request, query) + + if response.status == 200: + json_response = await response.json() + json_response = client.decrypt_sync_body(json_response) - if not isinstance(response, SyncResponse): return web.Response( - status=response.transport_response.status, + status=response.status, + text=json.dumps(json_response) + ) + else: + return web.Response( + status=response.status, text=await response.text() ) - if client.should_upload_keys: - await client.keys_upload() - - if client.should_query_keys: - key_query_response = await client.keys_query() - - # Verify new devices automatically for now. - if isinstance(key_query_response, KeysQueryResponse): - for user_id, device_dict in key_query_response.changed.items(): - for device in device_dict.values(): - if device.deleted: - continue - - logger.info("Automatically verifying device {} of " - "user {}".format(device.id, user_id)) - client.verify_device(device) - - json_response = await response.transport_response.json() - - decrypted_response = client.decrypt_sync_body(json_response) - - return web.Response( - status=response.transport_response.status, - text=json.dumps(decrypted_response) - ) + async def to_web_response(self, response): + return web.Response(status=response.status, text=await response.text()) async def send_message(self, request): access_token = self.get_access_token(request) @@ -264,12 +299,26 @@ class ProxyDaemon: return self._missing_token try: - client = self.client_sessions[access_token] + client_info = self.client_info[access_token] + client = self.panta_clients[client_info.user_id] except KeyError: return self._unknown_token - msgtype = request.match_info["event_type"] room_id = request.match_info["room_id"] + + try: + encrypt = client.rooms[room_id].encrypted + except KeyError: + return await self.to_web_response( + await self.forward_request(request) + ) + + if not encrypt: + return await self.to_web_response( + await self.forward_request(request) + ) + + msgtype = request.match_info["event_type"] txnid = request.match_info["txnid"] try: @@ -293,7 +342,8 @@ class ProxyDaemon: This method is called when we shut the whole app down """ - for client in self.client_sessions.values(): + for client in self.panta_clients.values(): + await client.loop_stop() await client.close() if self.default_session: