# Copyright 2019 The Matrix.org Foundation CIC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import copy
import os
from collections import defaultdict
from pprint import pformat
from typing import Any, Dict, Optional

from aiohttp.client_exceptions import ClientConnectionError
from jsonschema import Draft4Validator, FormatChecker, validators
from nio import (
    AsyncClient,
    ClientConfig,
    EncryptionError,
    KeysQueryResponse,
    KeyVerificationEvent,
    KeyVerificationKey,
    KeyVerificationMac,
    KeyVerificationStart,
    LocalProtocolError,
    MegolmEvent,
    RoomContextError,
    RoomEncryptedEvent,
    RoomEncryptedMedia,
    RoomMessageMedia,
    RoomMessagesError,
    RoomMessageText,
    RoomNameEvent,
    RoomTopicEvent,
    SyncResponse,
)
from nio.crypto import Sas
from nio.store import SqliteStore

from pantalaimon.index import INDEXING_ENABLED
from pantalaimon.log import logger
from pantalaimon.store import FetchTask
from pantalaimon.thread_messages import (
    DaemonResponse,
    InviteSasSignal,
    SasDoneSignal,
    ShowSasSignal,
    UpdateDevicesMessage,
)

# Event keys a search request may restrict the search to.
SEARCH_KEYS = ["content.body", "content.name", "content.topic"]

# Room event types that are added to the search index, both from live syncs
# and from background history fetching.
_INDEXED_EVENT_TYPES = (
    RoomMessageText,
    RoomMessageMedia,
    RoomEncryptedMedia,
    RoomTopicEvent,
    RoomNameEvent,
)

# JSON schema for search request bodies. Entries with a "default" are filled
# into the request by the ``Validator`` defined below.
SEARCH_TERMS_SCHEMA = {
    "type": "object",
    "properties": {
        "search_categories": {
            "type": "object",
            "properties": {
                "room_events": {
                    "type": "object",
                    "properties": {
                        "search_term": {"type": "string"},
                        "keys": {
                            "type": "array",
                            "items": {"type": "string", "enum": SEARCH_KEYS},
                            "default": SEARCH_KEYS,
                        },
                        "order_by": {"type": "string", "default": "rank"},
                        "include_state": {"type": "boolean", "default": False},
                        "filter": {"type": "object", "default": {}},
                        "event_context": {"type": "object"},
                        "groupings": {"type": "object", "default": {}},
                    },
                    "required": ["search_term"],
                }
            },
            # Must live next to "properties" (not inside it), otherwise it
            # is treated as a property named "required" and room_events is
            # never actually enforced.
            "required": ["room_events"],
        },
    },
    "required": ["search_categories"],
}


def extend_with_default(validator_class):
    """Extend a jsonschema validator class so that it fills in defaults.

    While validating the ``properties`` keyword, every property whose
    subschema declares a ``default`` and which is missing from the instance
    is inserted into the instance (the instance is mutated in place).
    """
    validate_properties = validator_class.VALIDATORS["properties"]

    def set_defaults(validator, properties, instance, schema):
        # Only JSON objects can receive defaults; anything else is left to
        # the regular "type" validation to reject instead of crashing here.
        if validator.is_type(instance, "object"):
            for prop, subschema in properties.items():
                if "default" in subschema and prop not in instance:
                    # Copy so request data never aliases (and can never
                    # mutate) the module-level default values.
                    instance[prop] = copy.deepcopy(subschema["default"])

        yield from validate_properties(validator, properties, instance, schema)

    return validators.extend(validator_class, {"properties": set_defaults})


Validator = extend_with_default(Draft4Validator)


def validate_json(instance, schema):
    """Validate a dictionary using the provided json schema.

    Defaults declared in the schema are inserted into ``instance``.
    """
    Validator(schema, format_checker=FormatChecker()).validate(instance)


class UnknownRoomError(Exception):
    pass


class InvalidOrderByError(Exception):
    """Raised by ``PanClient.search`` for an unsupported ``order_by``."""

    pass


class InvalidLimit(Exception):
    """Raised by ``PanClient.search`` for an out of range result limit."""

    pass


class PanClient(AsyncClient):
    """A wrapper class around a nio AsyncClient extending its functionality."""

    def __init__(
        self,
        server_name,
        pan_store,
        pan_conf,
        homeserver,
        queue=None,
        user_id="",
        device_id="",
        store_path="",
        config=None,
        ssl=None,
        proxy=None,
    ):
        config = config or ClientConfig(store=SqliteStore, store_name="pan.db")
        super().__init__(homeserver, user_id, device_id, store_path, config, ssl, proxy)

        index_dir = os.path.join(store_path, server_name, user_id)

        # Best effort: any OSError (including the directory already
        # existing) is ignored here.
        try:
            os.makedirs(index_dir)
        except OSError:
            pass

        self.server_name = server_name
        self.pan_store = pan_store
        self.pan_conf = pan_conf

        if INDEXING_ENABLED:
            logger.info("Indexing enabled.")
            from pantalaimon.index import IndexStore

            self.index = IndexStore(self.user_id, index_dir)
        else:
            logger.info("Indexing disabled.")
            self.index = None

        self.task = None
        self.queue = queue

        self.room_members_fetched = defaultdict(bool)

        self.send_semaphores = defaultdict(asyncio.Semaphore)
        self.send_decision_queues = dict()  # type: Dict[Any, asyncio.Queue]
        self.last_sync_token = None

        self.history_fetcher_task = None
        self.history_fetch_queue = asyncio.Queue()

        self.add_to_device_callback(self.key_verification_cb, KeyVerificationEvent)
        self.add_event_callback(self.undecrypted_event_cb, MegolmEvent)

        if INDEXING_ENABLED:
            self.add_event_callback(self.store_message_cb, _INDEXED_EVENT_TYPES)

        self.add_response_callback(self.keys_query_cb, KeysQueryResponse)
        self.add_response_callback(self.sync_tasks, SyncResponse)

    def store_message_cb(self, room, event):
        """Add a received room event to the search index.

        Events from unencrypted rooms are skipped when the
        ``index_encrypted_only`` option is set.
        """
        assert INDEXING_ENABLED

        if not room.encrypted and self.pan_conf.index_encrypted_only:
            return

        display_name = room.user_name(event.sender)
        avatar_url = room.avatar_url(event.sender)
        self.index.add_event(event, room.room_id, display_name, avatar_url)

    @property
    def unable_to_decrypt(self):
        """Room event signaling that the message couldn't be decrypted."""
        return {
            "type": "m.room.message",
            "content": {
                "msgtype": "m.text",
                "body": (
                    "** Unable to decrypt: The sender's device has not "
                    "sent us the keys for this message. **"
                ),
            },
        }

    async def send_message(self, message):
        """Send a thread message to the UI thread."""
        await self.queue.put(message)

    async def send_update_devices(self, devices):
        """Send a dictionary of devices to the UI thread.

        Args:
            devices: A nested mapping iterated as
                ``devices[user_id][device_id] -> device`` (the shape built by
                ``send_update_device``).
        """
        dict_devices = defaultdict(dict)

        for user_devices in devices.values():
            for device in user_devices.values():
                # Turn the OlmDevice type into a dictionary, flatten the
                # keys dict and remove the deleted key/value.
                # Since all the keys and values are strings this also
                # copies them making it thread safe.
                device_dict = device.as_dict()
                device_dict = {**device_dict, **device_dict["keys"]}
                device_dict.pop("keys")
                display_name = device_dict.pop("display_name")
                device_dict["device_display_name"] = display_name
                dict_devices[device.user_id][device.id] = device_dict

        message = UpdateDevicesMessage(self.user_id, dict_devices)
        await self.queue.put(message)

    async def send_update_device(self, device):
        """Send a single device to the UI thread to be updated."""
        await self.send_update_devices({device.user_id: {device.id: device}})

    def delete_fetcher_task(self, task):
        """Remove ``task`` from the fetcher tasks persisted in the pan store."""
        self.pan_store.delete_fetcher_task(self.server_name, self.user_id, task)

    async def fetcher_loop(self):
        """Fetch room history in the background and add it to the index.

        Tasks persisted in the pan store are queued on start. Every iteration
        sleeps ``history_fetch_delay``, takes one task and fetches a batch of
        ``indexing_batch_size`` events. If the last event of the batch is not
        yet in the index a follow-up task for the next batch is queued and
        persisted. The loop runs until it is cancelled.
        """
        assert INDEXING_ENABLED

        for t in self.pan_store.load_fetcher_tasks(self.server_name, self.user_id):
            await self.history_fetch_queue.put(t)

        while True:
            try:
                await asyncio.sleep(self.pan_conf.history_fetch_delay)
                fetch_task = await self.history_fetch_queue.get()

                try:
                    room = self.rooms[fetch_task.room_id]
                except KeyError:
                    # The room is missing from our client, we probably left the
                    # room.
                    self.delete_fetcher_task(fetch_task)
                    continue

                try:
                    logger.debug(
                        "Fetching room history for {}".format(room.display_name)
                    )
                    response = await self.room_messages(
                        fetch_task.room_id,
                        fetch_task.token,
                        limit=self.pan_conf.indexing_batch_size,
                    )
                except ClientConnectionError:
                    # Retry the same task later. Without the ``continue`` the
                    # code below would read an unbound ``response``.
                    await self.history_fetch_queue.put(fetch_task)
                    continue

                if isinstance(response, RoomMessagesError):
                    # Error responses have no chunk; retry the task later
                    # instead of crashing the whole fetcher loop.
                    logger.warning(
                        "Failed to fetch room history for {}: {}".format(
                            room.display_name, response
                        )
                    )
                    await self.history_fetch_queue.put(fetch_task)
                    continue

                # The chunk was empty, we're at the start of the timeline.
                if not response.chunk:
                    self.delete_fetcher_task(fetch_task)
                    continue

                for event in response.chunk:
                    if not isinstance(event, _INDEXED_EVENT_TYPES):
                        continue

                    display_name = room.user_name(event.sender)
                    avatar_url = room.avatar_url(event.sender)
                    self.index.add_event(event, room.room_id, display_name, avatar_url)

                last_event = response.chunk[-1]

                if not self.index.event_in_store(last_event.event_id, room.room_id):
                    # There may be even more events to fetch, add a new task to
                    # the queue.
                    task = FetchTask(room.room_id, response.end)
                    self.pan_store.save_fetcher_task(
                        self.server_name, self.user_id, task
                    )
                    await self.history_fetch_queue.put(task)

                await self.index.commit_events()
                self.delete_fetcher_task(fetch_task)
            except asyncio.CancelledError:
                return

    async def sync_tasks(self, response):
        """Response callback run after every sync response.

        Commits pending index writes, persists the sync token and, when
        indexing is enabled, queues history fetch tasks for joined rooms whose
        timeline was limited.
        """
        if self.index:
            await self.index.commit_events()

        self.pan_store.save_token(self.server_name, self.user_id, self.next_batch)

        # The fetch queue is only drained by fetcher_loop(), which start_loop()
        # only starts when indexing is enabled; queueing tasks otherwise would
        # grow the queue without bound.
        if not INDEXING_ENABLED:
            return

        for room_id, room_info in response.rooms.join.items():
            if room_info.timeline.limited:
                room = self.rooms[room_id]

                if not room.encrypted and self.pan_conf.index_encrypted_only:
                    continue

                logger.info(
                    "Room {} had a limited timeline, queueing "
                    "room for history fetching.".format(room.display_name)
                )
                task = FetchTask(room_id, room_info.timeline.prev_batch)
                self.pan_store.save_fetcher_task(self.server_name, self.user_id, task)

                await self.history_fetch_queue.put(task)

    async def keys_query_cb(self, response):
        """Forward the devices reported as changed to the UI thread."""
        if response.changed:
            await self.send_update_devices(response.changed)

    async def undecrypted_event_cb(self, room, event):
        """Request the room key for a megolm event that failed to decrypt.

        A request is only sent when the event's session id is not already in
        ``outgoing_key_requests``; connection errors are ignored.
        """
        logger.info(
            "Unable to decrypt event from {} via {}.".format(
                event.sender, event.device_id
            )
        )

        if event.session_id not in self.outgoing_key_requests:
            logger.info("Requesting room key for undecrypted event.")

            # TODO we may want to retry this
            try:
                await self.request_room_key(event)
            except ClientConnectionError:
                pass

    async def key_verification_cb(self, event):
        """Forward key verification progress to the UI thread.

        A start event produces an invite signal, a key event shows the SAS
        emoji and a MAC event signals completion once the SAS is verified.
        """
        logger.info("Received key verification event: {}".format(event))

        if isinstance(event, KeyVerificationStart):
            logger.info(
                f"{event.sender} via {event.from_device} has started "
                f"a key verification process."
            )

            message = InviteSasSignal(
                self.user_id, event.sender, event.from_device, event.transaction_id
            )

            await self.queue.put(message)

        elif isinstance(event, KeyVerificationKey):
            sas = self.key_verifications.get(event.transaction_id, None)

            if not sas:
                return

            device = sas.other_olm_device
            emoji = sas.get_emoji()

            message = ShowSasSignal(
                self.user_id, device.user_id, device.id, sas.transaction_id, emoji
            )

            await self.queue.put(message)

        elif isinstance(event, KeyVerificationMac):
            sas = self.key_verifications.get(event.transaction_id, None)

            if not sas:
                return

            device = sas.other_olm_device

            if sas.verified:
                await self.send_message(
                    SasDoneSignal(
                        self.user_id, device.user_id, device.id, sas.transaction_id
                    )
                )
                await self.send_update_device(device)

    def start_loop(self, loop_sleep_time=None):
        """Start a loop that runs forever and keeps on syncing with the server.

        The loop can be stopped with the loop_stop() method. When indexing is
        enabled the history fetcher task is started as well.

        Returns the created sync task.
        """
        assert not self.task

        logger.info(f"Starting sync loop for {self.user_id}")

        loop = asyncio.get_event_loop()

        if INDEXING_ENABLED:
            self.history_fetcher_task = loop.create_task(self.fetcher_loop())

        timeout = 30000
        sync_filter = {"room": {"state": {"lazy_load_members": True}}}
        next_batch = self.pan_store.load_token(self.server_name, self.user_id)
        self.last_sync_token = next_batch

        # We don't store any room state so initial sync needs to be with the
        # full_state parameter. Subsequent ones are normal.
        task = loop.create_task(
            self.sync_forever(
                timeout,
                sync_filter,
                full_state=True,
                since=next_batch,
                loop_sleep_time=loop_sleep_time,
            )
        )
        self.task = task

        return task

    async def _send_daemon_response(self, message, code, text):
        """Reply to the UI thread ``message`` with a ``DaemonResponse``."""
        await self.send_message(
            DaemonResponse(message.message_id, self.user_id, code, text)
        )

    async def start_sas(self, message, device):
        """Start a key verification with ``device`` and report the outcome."""
        try:
            await self.start_key_verification(device)
            await self._send_daemon_response(
                message,
                "m.ok",
                "Successfully started the key verification request",
            )
        except ClientConnectionError as e:
            await self._send_daemon_response(message, "m.connection_error", str(e))

    async def accept_sas(self, message):
        """Accept the active key verification for the device in ``message``."""
        sas = self.get_active_sas(message.user_id, message.device_id)

        if not sas:
            await self._send_daemon_response(
                message, Sas._txid_error[0], Sas._txid_error[1]
            )
            return

        try:
            await self.accept_key_verification(sas.transaction_id)
            await self._send_daemon_response(
                message,
                "m.ok",
                "Successfully accepted the key verification request",
            )
        except LocalProtocolError as e:
            await self._send_daemon_response(
                message, Sas._unexpected_message_error[0], str(e)
            )
        except ClientConnectionError as e:
            await self._send_daemon_response(message, "m.connection_error", str(e))

    async def cancel_sas(self, message):
        """Cancel the active key verification for the device in ``message``."""
        sas = self.get_active_sas(message.user_id, message.device_id)

        if not sas:
            await self._send_daemon_response(
                message, Sas._txid_error[0], Sas._txid_error[1]
            )
            return

        try:
            await self.cancel_key_verification(sas.transaction_id)
            await self._send_daemon_response(
                message,
                "m.ok",
                "Successfully canceled the key verification request",
            )
        except ClientConnectionError as e:
            await self._send_daemon_response(message, "m.connection_error", str(e))

    async def confirm_sas(self, message):
        """Confirm the short auth string of the active key verification.

        If the verification is complete afterwards the updated device and a
        done signal are sent; otherwise the UI is told to wait for the other
        side.
        """
        sas = self.get_active_sas(message.user_id, message.device_id)

        if not sas:
            await self._send_daemon_response(
                message, Sas._txid_error[0], Sas._txid_error[1]
            )
            return

        try:
            await self.confirm_short_auth_string(sas.transaction_id)
        except ClientConnectionError as e:
            await self._send_daemon_response(message, "m.connection_error", str(e))
            return

        device = sas.other_olm_device

        if sas.verified:
            await self.send_update_device(device)
            await self.send_message(
                SasDoneSignal(
                    self.user_id, device.user_id, device.id, sas.transaction_id
                )
            )
        else:
            await self._send_daemon_response(
                message, "m.ok", f"Waiting for {device.user_id} to confirm."
            )

    async def loop_stop(self):
        """Stop the client loop and the history fetcher task."""
        logger.info("Stopping the sync loop")

        if self.task and not self.task.done():
            self.task.cancel()

            # NOTE(review): awaiting a cancelled task usually surfaces
            # asyncio.CancelledError, not KeyboardInterrupt; this relies on
            # the awaited coroutine handling cancellation itself — confirm.
            try:
                await self.task
            except KeyboardInterrupt:
                pass

        # Reset even if the task had already finished on its own, so that
        # start_loop() can be called again.
        self.task = None

        if self.history_fetcher_task and not self.history_fetcher_task.done():
            self.history_fetcher_task.cancel()

            try:
                await self.history_fetcher_task
            except KeyboardInterrupt:
                pass

        self.history_fetcher_task = None
        self.history_fetch_queue = asyncio.Queue()

    def pan_decrypt_event(self, event_dict, room_id=None, ignore_failures=True):
        # type: (Dict[Any, Any], Optional[str], bool) -> bool
        """Decrypt a single raw megolm event dictionary in place.

        Args:
            event_dict: The raw event. On success it is updated with the
                decrypted event's source plus ``decrypted`` and ``verified``.
            room_id: Used when the parsed event carries no room id.
            ignore_failures: On a decryption failure, overwrite the event with
                ``unable_to_decrypt`` instead of re-raising.

        Returns True if the event was decrypted, False otherwise.

        Raises:
            EncryptionError: if decryption fails and ``ignore_failures`` is
                False.
        """
        event = RoomEncryptedEvent.parse_event(event_dict)

        if not isinstance(event, MegolmEvent):
            logger.warning(
                "Encrypted event is not a megolm event:"
                "\n{}".format(pformat(event_dict))
            )
            return False

        if not event.room_id:
            event.room_id = room_id

        try:
            decrypted_event = self.decrypt_event(event)
        except EncryptionError as error:
            logger.warning(error)

            if not ignore_failures:
                raise

            event_dict.update(self.unable_to_decrypt)
            return False

        logger.info("Decrypted event: {}".format(decrypted_event))
        event_dict.update(decrypted_event.source)
        event_dict["decrypted"] = True
        event_dict["verified"] = decrypted_event.verified

        return True

    def decrypt_messages_body(self, body, ignore_failures=True):
        # type: (Dict[Any, Any], bool) -> Dict[Any, Any]
        """Go through a messages response and decrypt megolm encrypted events.

        Args:
            body (Dict[Any, Any]): The dictionary of a room messages response.
            ignore_failures (bool): Passed on to ``pan_decrypt_event``.

        Returns the json response with decrypted events.
        """
        if "chunk" not in body:
            return body

        logger.info("Decrypting room messages")

        for event in body["chunk"]:
            if "type" not in event:
                continue

            if event["type"] != "m.room.encrypted":
                logger.debug("Event is not encrypted: \n{}".format(pformat(event)))
                continue

            self.pan_decrypt_event(event, ignore_failures=ignore_failures)

        return body

    def decrypt_sync_body(self, body, ignore_failures=True):
        # type: (Dict[Any, Any], bool) -> Dict[Any, Any]
        """Go through a json sync response and decrypt megolm encrypted events.

        Only timeline events of joined rooms known to the client and marked
        as encrypted are decrypted.

        Args:
            body (Dict[Any, Any]): The dictionary of a Sync response.
            ignore_failures (bool): Passed on to ``pan_decrypt_event``.

        Returns the json response with decrypted events.
        """
        logger.info("Decrypting sync")

        # A server may omit the rooms/join/timeline sections of a sync
        # response; treat missing sections as empty.
        joined_rooms = body.get("rooms", {}).get("join", {})

        for room_id, room_dict in joined_rooms.items():
            try:
                room = self.rooms[room_id]
            except KeyError:
                logger.info("Unknown room {} skipping...".format(room_id))
                continue

            if not room.encrypted:
                logger.info(
                    "Room {} is not encrypted skipping...".format(room.display_name)
                )
                continue

            for event in room_dict.get("timeline", {}).get("events", []):
                if "type" not in event:
                    continue

                if event["type"] != "m.room.encrypted":
                    continue

                self.pan_decrypt_event(event, room_id, ignore_failures)

        return body

    async def search(self, search_terms):
        # type: (Dict[Any, Any]) -> Dict[Any, Any]
        """Search the local index for room events.

        Args:
            search_terms (Dict[Any, Any]): The body of a search request. It is
                presumably validated with ``validate_json`` and
                ``SEARCH_TERMS_SCHEMA`` first, so that defaults such as
                ``filter`` and ``order_by`` are present — confirm at caller.

        Returns a dictionary of the form
        ``{"search_categories": {"room_events": <index results>}}``.

        Raises:
            InvalidLimit: if the result limit is not strictly positive or an
                event context limit is negative.
            InvalidOrderByError: if ``order_by`` is not "rank" or "recent".
        """
        assert INDEXING_ENABLED

        state_cache = dict()

        async def add_context(event_dict, room_id, event_id, include_state):
            # Best effort: on connection errors or error responses the result
            # is left without the server provided context.
            try:
                context = await self.room_context(room_id, event_id, limit=0)
            except ClientConnectionError:
                return

            if isinstance(context, RoomContextError):
                return

            if include_state:
                state_cache[room_id] = [e.source for e in context.state]

            event_dict["context"]["start"] = context.start
            event_dict["context"]["end"] = context.end

        search_terms = search_terms["search_categories"]["room_events"]

        term = search_terms["search_term"]
        search_filter = search_terms["filter"]
        limit = search_filter.get("limit", 10)

        if limit <= 0:
            raise InvalidLimit("The limit must be strictly greater than 0.")

        rooms = search_filter.get("rooms", [])

        # Only a single-room filter is forwarded to the index; with zero or
        # several rooms ``room`` is None.
        room_id = rooms[0] if len(rooms) == 1 else None

        order_by = search_terms.get("order_by")

        if order_by not in ["rank", "recent"]:
            raise InvalidOrderByError(f"Invalid order by: {order_by}")

        order_by_recent = order_by == "recent"

        before_limit = 0
        after_limit = 0
        include_profile = False

        event_context = search_terms.get("event_context")
        include_state = search_terms.get("include_state")

        if event_context:
            before_limit = event_context.get("before_limit", 5)
            after_limit = event_context.get("after_limit", 5)

            if before_limit < 0 or after_limit < 0:
                raise InvalidLimit(
                    "Invalid context limit, the limit must be a positive number"
                )

        response_dict = await self.index.search(
            term,
            room=room_id,
            max_results=limit,
            order_by_recent=order_by_recent,
            include_profile=include_profile,
            before_limit=before_limit,
            after_limit=after_limit,
        )

        if (event_context or include_state) and self.pan_conf.search_requests:
            for event_dict in response_dict["results"]:
                await add_context(
                    event_dict,
                    event_dict["result"]["room_id"],
                    event_dict["result"]["event_id"],
                    include_state,
                )

        if include_state:
            response_dict["state"] = state_cache

        return {"search_categories": {"room_events": response_dict}}