diff --git a/pantalaimon/client.py b/pantalaimon/client.py new file mode 100644 index 0000000..fdd8329 --- /dev/null +++ b/pantalaimon/client.py @@ -0,0 +1,69 @@ +from typing import Any, Dict + +from nio import ( + AsyncClient, + RoomEncryptedEvent, + MegolmEvent, + EncryptionError, + SyncResponse +) + + +class PantaClient(AsyncClient): + """A wrapper class around a nio AsyncClient extending its functionality.""" + + def decrypt_sync_body(self, body): + # type: (Dict[Any, Any]) -> Dict[Any, Any] + """Go through a json sync response and decrypt megolm encrypted events. + + Args: + body (Dict[Any, Any]): The dictionary of a Sync response. + + Returns the json response with decrypted events. + """ + for room_id, room_dict in body["rooms"]["join"].items(): + if not self.rooms[room_id].encrypted: + print("Room {} not encrypted skipping...".format( + self.rooms[room_id].display_name + )) + continue + + for event in room_dict["timeline"]["events"]: + if event["type"] != "m.room.encrypted": + print("Event not encrypted skipping...") + continue + + parsed_event = RoomEncryptedEvent.parse_event(event) + parsed_event.room_id = room_id + + if not isinstance(parsed_event, MegolmEvent): + print("Not a megolm event.") + continue + + try: + decrypted_event = self.decrypt_event(parsed_event) + print("Decrypted event: {}".format(decrypted_event)) + event["type"] = "m.room.message" + + # TODO support other event types + # This should be best done in nio, modify events so they + # keep the dictionary from which they are built in a source + # attribute. + event["content"] = { + "msgtype": "m.text", + "body": decrypted_event.body + } + + if decrypted_event.formatted_body: + event["content"]["formatted_body"] = ( + decrypted_event.formatted_body) + event["content"]["format"] = decrypted_event.format + + event["decrypted"] = True + event["verified"] = decrypted_event.verified + + except EncryptionError as error: + print("ERROR decrypting {}".format(error)) + continue + + return body diff --git a/pantalaimon/daemon.py b/pantalaimon/daemon.py index e1b31b9..b213f62 100755 --- a/pantalaimon/daemon.py +++ b/pantalaimon/daemon.py @@ -12,19 +12,17 @@ from urllib.parse import urlparse from aiohttp import web, ClientSession from nio import ( - AsyncClient, LoginResponse, KeysQueryResponse, GroupEncryptionError, - RoomEncryptedEvent, - MegolmEvent, - EncryptionError, SyncResponse ) from appdirs import user_data_dir from json import JSONDecodeError from multidict import CIMultiDict +from pantalaimon.client import PantaClient + @attr.s class ProxyDaemon: @@ -130,7 +128,7 @@ class ProxyDaemon: device_id = body.get("device_id", "") device_name = body.get("initial_device_display_name", "pantalaimon") - client = AsyncClient( + client = PantaClient( self.homeserver, user, device_id, @@ -241,51 +239,11 @@ class ProxyDaemon: json_response = await response.transport_response.json() - for room_id, room_dict in json_response["rooms"]["join"].items(): - if not client.rooms[room_id].encrypted: - print("Room {} not encrypted skipping...".format( - client.rooms[room_id].display_name - )) - continue - - for event in room_dict["timeline"]["events"]: - if event["type"] != "m.room.encrypted": - print("Event not encrypted skipping...") - continue - - parsed_event = RoomEncryptedEvent.parse_event(event) - parsed_event.room_id = room_id - - if not isinstance(parsed_event, MegolmEvent): - print("Not a megolm event.") - continue - - try: - decrypted_event = client.decrypt_event(parsed_event) - print("Decrypted event: {}".format(decrypted_event)) - event["type"] = "m.room.message" - - # TODO support other event types - event["content"] = { - "msgtype": "m.text", - "body": decrypted_event.body - } - - if decrypted_event.formatted_body: - event["content"]["formatted_body"] = ( - decrypted_event.formatted_body) - event["content"]["format"] = decrypted_event.format - - event["decrypted"] = True - event["verified"] = decrypted_event.verified - - except EncryptionError as e: - print("ERROR decrypting {}".format(e)) - continue + decrypted_response = client.decrypt_sync_body(json_response) return web.Response( status=response.transport_response.status, - text=json.dumps(json_response) + text=json.dumps(decrypted_response) ) async def send_message(self, request):