diff --git a/maubot/matrix.py b/maubot/matrix.py index 756ec95..f93b80d 100644 --- a/maubot/matrix.py +++ b/maubot/matrix.py @@ -23,7 +23,8 @@ from mautrix.client import Client as MatrixClient, SyncStream from mautrix.util.formatter import MatrixParser, MarkdownString, EntityType from mautrix.util import markdown from mautrix.types import (EventType, MessageEvent, Event, EventID, RoomID, MessageEventContent, - MessageType, TextMessageEventContent, Format, RelatesTo) + MessageType, TextMessageEventContent, Format, RelatesTo, EncryptedEvent) +from mautrix.errors import DecryptionError class HumanReadableString(MarkdownString): @@ -132,10 +133,19 @@ class MaubotMatrixClient(MatrixClient): return super().dispatch_event(event, source) async def get_event(self, room_id: RoomID, event_id: EventID) -> Event: - event = await super().get_event(room_id, event_id) - if isinstance(event, MessageEvent): - event.content.trim_reply_fallback() - return MaubotMessageEvent(event, self) + evt = await super().get_event(room_id, event_id) + if isinstance(evt, EncryptedEvent) and self.crypto: + try: + self.crypto_log.trace(f"get_event: Decrypting {evt.event_id} in {evt.room_id}...") + decrypted = await self.crypto.decrypt_megolm_event(evt) + except DecryptionError as e: + self.crypto_log.warning(f"get_event: Failed to decrypt {evt.event_id}: {e}") + return + self.crypto_log.trace(f"get_event: Decrypted {evt.event_id}: {decrypted}") + evt = decrypted + if isinstance(evt, MessageEvent): + evt.content.trim_reply_fallback() + return MaubotMessageEvent(evt, self) else: - event.client = self - return event + evt.client = self + return evt