# Copyright 2019 The Matrix.org Foundation CIC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
from collections import defaultdict
from pprint import pformat
from typing import Any, Dict, Optional
from urllib.parse import urlparse

from aiohttp.client_exceptions import ClientConnectionError
from jsonschema import Draft4Validator, FormatChecker, validators
from playhouse.sqliteq import SqliteQueueDatabase
from nio import (
    AsyncClient,
    AsyncClientConfig,
    EncryptionError,
    Event,
    ToDeviceEvent,
    KeysQueryResponse,
    KeyVerificationEvent,
    KeyVerificationKey,
    KeyVerificationMac,
    KeyVerificationStart,
    LocalProtocolError,
    MegolmEvent,
    RoomContextError,
    RoomEncryptedMedia,
    RoomEncryptedImage,
    RoomEncryptedFile,
    RoomEncryptedVideo,
    RoomMessageMedia,
    RoomMessageText,
    RoomNameEvent,
    RoomTopicEvent,
    RoomKeyRequest,
    RoomKeyRequestCancellation,
    SyncResponse,
)
from nio.crypto import Sas
from nio.store import SqliteStore

from pantalaimon.index import INDEXING_ENABLED
from pantalaimon.log import logger
from pantalaimon.store import FetchTask, MediaInfo
from pantalaimon.thread_messages import (
    DaemonResponse,
    InviteSasSignal,
    SasDoneSignal,
    ShowSasSignal,
    UpdateDevicesMessage,
    KeyRequestMessage,
    ContinueKeyShare,
    CancelKeyShare,
)

SEARCH_KEYS = ["content.body", "content.name", "content.topic"]

SEARCH_TERMS_SCHEMA = {
    "type": "object",
    "properties": {
        "search_categories": {
            "type": "object",
            "properties": {
                "room_events": {
                    "type": "object",
                    "properties": {
                        "search_term": {"type": "string"},
                        "keys": {
                            "type": "array",
                            "items": {"type": "string", "enum": SEARCH_KEYS},
                            "default": SEARCH_KEYS,
                        },
                        "order_by": {"type": "string", "default": "rank"},
                        "include_state": {"type": "boolean", "default": False},
                        "filter": {"type": "object", "default": {}},
                        "event_context": {"type": "object"},
                        "groupings": {"type": "object", "default": {}},
                    },
                    "required": ["search_term"],
                }
            },
            "required": ["room_events"],
        },
    },
    "required": ["search_categories"],
}

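# Illustrative example (not part of the original file): a minimal request
# body that satisfies SEARCH_TERMS_SCHEMA; only "search_term" is required,
# everything else has a schema default:
#
#     {
#         "search_categories": {
#             "room_events": {
#                 "search_term": "panta",
#                 "keys": ["content.body"],
#             }
#         }
#     }
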
def extend_with_default(validator_class):
    validate_properties = validator_class.VALIDATORS["properties"]

    def set_defaults(validator, properties, instance, schema):
        for prop, subschema in properties.items():
            if "default" in subschema:
                instance.setdefault(prop, subschema["default"])

        for error in validate_properties(validator, properties, instance, schema):
            yield error

    return validators.extend(validator_class, {"properties": set_defaults})


Validator = extend_with_default(Draft4Validator)


def validate_json(instance, schema):
    """Validate a dictionary using the provided json schema."""
    Validator(schema, format_checker=FormatChecker()).validate(instance)

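# Illustrative sketch (not part of the original file): because Validator is
# extended with set_defaults above, validate_json both checks a request and
# fills in the schema's "default" values in place, e.g.:
#
#     body = {"search_categories": {"room_events": {"search_term": "cats"}}}
#     validate_json(body, SEARCH_TERMS_SCHEMA)
#     # body["search_categories"]["room_events"]["order_by"] == "rank"
#     # body["search_categories"]["room_events"]["filter"] == {}

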
class UnknownRoomError(Exception):
    pass


class InvalidOrderByError(Exception):
    pass


class InvalidLimit(Exception):
    pass


class SqliteQStore(SqliteStore):
    def _create_database(self):
        return SqliteQueueDatabase(
            self.database_path, pragmas=(("foreign_keys", 1), ("secure_delete", 1))
        )

    def close(self):
        self.database.stop()


class PanClient(AsyncClient):
    """A wrapper class around a nio AsyncClient extending its functionality."""

    def __init__(
        self,
        server_name,
        pan_store,
        pan_conf,
        homeserver,
        queue=None,
        user_id="",
        device_id="",
        store_path="",
        config=None,
        ssl=None,
        proxy=None,
        store_class=None,
        media_info=None,
    ):
        config = config or AsyncClientConfig(
            store=store_class or SqliteStore, store_name="pan.db"
        )
        super().__init__(homeserver, user_id, device_id, store_path, config, ssl, proxy)

        index_dir = os.path.join(store_path, server_name, user_id)

        try:
            os.makedirs(index_dir)
        except OSError:
            pass

        self.server_name = server_name
        self.pan_store = pan_store
        self.pan_conf = pan_conf
        self.media_info = media_info

        if INDEXING_ENABLED:
            logger.info("Indexing enabled.")
            from pantalaimon.index import IndexStore

            self.index = IndexStore(self.user_id, index_dir)
        else:
            logger.info("Indexing disabled.")
            self.index = None

        self.task = None
        self.queue = queue

        # These two events are mainly used for testing.
        self.new_fetch_task = asyncio.Event()
        self.fetch_loop_event = asyncio.Event()

        self.room_members_fetched = defaultdict(bool)

        self.send_semaphores = defaultdict(asyncio.Semaphore)
        self.send_decision_queues = dict()  # type: Dict[str, asyncio.Queue]
        self.last_sync_token = None

        self.history_fetcher_task = None
        self.history_fetch_queue = asyncio.Queue()

        self.add_to_device_callback(self.key_verification_cb, KeyVerificationEvent)
        self.add_to_device_callback(
            self.key_request_cb, (RoomKeyRequest, RoomKeyRequestCancellation)
        )
        self.add_event_callback(self.undecrypted_event_cb, MegolmEvent)
        self.add_event_callback(
            self.store_thumbnail_cb,
            (RoomEncryptedImage, RoomEncryptedVideo, RoomEncryptedFile),
        )

        if INDEXING_ENABLED:
            self.add_event_callback(
                self.store_message_cb,
                (
                    RoomMessageText,
                    RoomMessageMedia,
                    RoomEncryptedMedia,
                    RoomTopicEvent,
                    RoomNameEvent,
                ),
            )

        self.add_response_callback(self.keys_query_cb, KeysQueryResponse)
        self.add_response_callback(self.sync_tasks, SyncResponse)

    def store_message_cb(self, room, event):
        assert INDEXING_ENABLED

        display_name = room.user_name(event.sender)
        avatar_url = room.avatar_url(event.sender)

        if not room.encrypted and self.pan_conf.index_encrypted_only:
            return

        self.index.add_event(event, room.room_id, display_name, avatar_url)

    def store_thumbnail_cb(self, room, event):
        if not (
            event.thumbnail_url
            and event.thumbnail_key
            and event.thumbnail_iv
            and event.thumbnail_hashes
        ):
            return

        try:
            mxc = urlparse(event.thumbnail_url)
        except ValueError:
            return

        if mxc is None:
            return

        mxc_server = mxc.netloc.strip("/")
        mxc_path = mxc.path.strip("/")

        media = MediaInfo(
            mxc_server,
            mxc_path,
            event.thumbnail_key,
            event.thumbnail_iv,
            event.thumbnail_hashes,
        )
        self.media_info[(mxc_server, mxc_path)] = media
        self.pan_store.save_media(self.server_name, media)

    def store_event_media(self, event):
        try:
            mxc = urlparse(event.url)
        except ValueError:
            return

        if mxc is None:
            return

        mxc_server = mxc.netloc.strip("/")
        mxc_path = mxc.path.strip("/")

        logger.info(f"Adding media info for {mxc_server}/{mxc_path} to the store")

        media = MediaInfo(mxc_server, mxc_path, event.key, event.iv, event.hashes)
        self.media_info[(mxc_server, mxc_path)] = media
        self.pan_store.save_media(self.server_name, media)

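    # Illustrative note (not in the original file): for an mxc URI such as
    # "mxc://example.org/GCmhgzMPRjqgpODLsNQzVuHZ", urlparse yields
    # netloc == "example.org" and path == "/GCmhgzMPRjqgpODLsNQzVuHZ", so
    # after strip("/") the store key above becomes
    # ("example.org", "GCmhgzMPRjqgpODLsNQzVuHZ").
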
    @property
    def unable_to_decrypt(self):
        """Room event signaling that the message couldn't be decrypted."""
        return {
            "type": "m.room.message",
            "content": {
                "msgtype": "m.text",
                "body": (
                    "** Unable to decrypt: The sender's device has not "
                    "sent us the keys for this message. **"
                ),
            },
        }

    async def send_message(self, message):
        """Send a thread message to the UI thread."""
        if self.queue:
            await self.queue.put(message)

    async def send_update_devices(self, devices):
        """Send a dictionary of devices to the UI thread."""
        dict_devices = defaultdict(dict)

        for user_devices in devices.values():
            for device in user_devices.values():
                # Turn the OlmDevice type into a dictionary, flatten the
                # keys dict and remove the deleted key/value.
                # Since all the keys and values are strings this also
                # copies them, making it thread safe.
                device_dict = device.as_dict()
                device_dict = {**device_dict, **device_dict["keys"]}
                device_dict.pop("keys")
                display_name = device_dict.pop("display_name")
                device_dict["device_display_name"] = display_name
                dict_devices[device.user_id][device.id] = device_dict

        message = UpdateDevicesMessage(self.user_id, dict_devices)
        await self.send_message(message)

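    # Illustrative sketch (not in the original file): the flattening above
    # turns a device dict of the rough, hypothetical shape
    #
    #     {..., "keys": {"ed25519": "...", "curve25519": "..."},
    #      "display_name": "Phone"}
    #
    # into
    #
    #     {..., "ed25519": "...", "curve25519": "...",
    #      "device_display_name": "Phone"}
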
    async def send_update_device(self, device):
        """Send a single device to the UI thread to be updated."""
        await self.send_update_devices({device.user_id: {device.id: device}})

    def delete_fetcher_task(self, task):
        self.pan_store.delete_fetcher_task(self.server_name, self.user_id, task)

    async def fetcher_loop(self):
        assert INDEXING_ENABLED

        for t in self.pan_store.load_fetcher_tasks(self.server_name, self.user_id):
            await self.history_fetch_queue.put(t)

        while True:
            self.fetch_loop_event.set()
            self.fetch_loop_event.clear()

            try:
                await asyncio.sleep(self.pan_conf.history_fetch_delay)
                fetch_task = await self.history_fetch_queue.get()

                try:
                    room = self.rooms[fetch_task.room_id]
                except KeyError:
                    # The room is missing from our client, we probably left
                    # the room.
                    self.delete_fetcher_task(fetch_task)
                    continue

                try:
                    logger.debug(
                        f"Fetching room history for {room.display_name} "
                        f"({room.room_id}), token {fetch_task.token}."
                    )
                    response = await self.room_messages(
                        fetch_task.room_id,
                        fetch_task.token,
                        limit=self.pan_conf.indexing_batch_size,
                    )
                except ClientConnectionError as e:
                    logger.debug(f"Error fetching room history: {e}")
                    # Put the task back so the fetch is retried on the next
                    # loop iteration.
                    await self.history_fetch_queue.put(fetch_task)
                    continue

                # The chunk was empty, we're at the start of the timeline.
                if not response.chunk:
                    self.delete_fetcher_task(fetch_task)
                    continue

                for event in response.chunk:
                    if not isinstance(
                        event,
                        (
                            RoomMessageText,
                            RoomMessageMedia,
                            RoomEncryptedMedia,
                            RoomTopicEvent,
                            RoomNameEvent,
                        ),
                    ):
                        continue

                    display_name = room.user_name(event.sender)
                    avatar_url = room.avatar_url(event.sender)
                    self.index.add_event(event, room.room_id, display_name, avatar_url)

                last_event = response.chunk[-1]

                if not self.index.event_in_store(last_event.event_id, room.room_id):
                    # There may be even more events to fetch, add a new task
                    # to the queue.
                    task = FetchTask(room.room_id, response.end)
                    self.pan_store.replace_fetcher_task(
                        self.server_name, self.user_id, fetch_task, task
                    )
                    await self.history_fetch_queue.put(task)
                    self.new_fetch_task.set()
                    self.new_fetch_task.clear()
                else:
                    await self.index.commit_events()
                    self.delete_fetcher_task(fetch_task)

            except (asyncio.CancelledError, KeyboardInterrupt):
                return

    async def sync_tasks(self, response):
        if self.index:
            await self.index.commit_events()

        if self.last_sync_token == self.next_batch:
            return

        self.last_sync_token = self.next_batch

        self.pan_store.save_token(self.server_name, self.user_id, self.next_batch)

        for room_id, room_info in response.rooms.join.items():
            if room_info.timeline.limited:
                room = self.rooms[room_id]

                if not room.encrypted and self.pan_conf.index_encrypted_only:
                    continue

                logger.info(
                    "Room {} had a limited timeline, queueing "
                    "room for history fetching.".format(room.display_name)
                )
                task = FetchTask(room_id, room_info.timeline.prev_batch)
                self.pan_store.save_fetcher_task(self.server_name, self.user_id, task)

                await self.history_fetch_queue.put(task)
                self.new_fetch_task.set()
                self.new_fetch_task.clear()

    async def keys_query_cb(self, response):
        if response.changed:
            await self.send_update_devices(response.changed)

    async def undecrypted_event_cb(self, room, event):
        logger.info(
            "Unable to decrypt event from {} via {}.".format(
                event.sender, event.device_id
            )
        )

        if event.session_id not in self.outgoing_key_requests:
            logger.info("Requesting room key for undecrypted event.")

            # TODO we may want to retry this
            try:
                await self.request_room_key(event)
            except ClientConnectionError:
                pass

    async def key_request_cb(self, event):
        if isinstance(event, RoomKeyRequest):
            logger.info(
                f"{event.sender} via {event.requesting_device_id} has "
                f"requested room keys from us."
            )

            message = KeyRequestMessage(self.user_id, event)
            await self.send_message(message)

        elif isinstance(event, RoomKeyRequestCancellation):
            logger.info(
                f"{event.sender} via {event.requesting_device_id} has "
                f"canceled its key request."
            )

            message = KeyRequestMessage(self.user_id, event)
            await self.send_message(message)

        else:
            assert False

    async def key_verification_cb(self, event):
        logger.info("Received key verification event: {}".format(event))
        if isinstance(event, KeyVerificationStart):
            logger.info(
                f"{event.sender} via {event.from_device} has started "
                f"a key verification process."
            )

            message = InviteSasSignal(
                self.user_id, event.sender, event.from_device, event.transaction_id
            )

            await self.send_message(message)

        elif isinstance(event, KeyVerificationKey):
            sas = self.key_verifications.get(event.transaction_id, None)
            if not sas:
                return

            device = sas.other_olm_device
            emoji = sas.get_emoji()

            message = ShowSasSignal(
                self.user_id, device.user_id, device.id, sas.transaction_id, emoji
            )

            await self.send_message(message)

        elif isinstance(event, KeyVerificationMac):
            sas = self.key_verifications.get(event.transaction_id, None)
            if not sas:
                return
            device = sas.other_olm_device

            if sas.verified:
                await self.send_message(
                    SasDoneSignal(
                        self.user_id, device.user_id, device.id, sas.transaction_id
                    )
                )
                await self.send_update_device(device)

    def start_loop(self, loop_sleep_time=100):
        """Start a loop that runs forever and keeps on syncing with the server.

        The loop can be stopped with the loop_stop() method.
        """
        assert not self.task

        logger.info(f"Starting sync loop for {self.user_id}")

        loop = asyncio.get_event_loop()

        if INDEXING_ENABLED:
            self.history_fetcher_task = loop.create_task(self.fetcher_loop())

        timeout = 30000
        sync_filter = {"room": {"state": {"lazy_load_members": True}}}
        next_batch = self.pan_store.load_token(self.server_name, self.user_id)
        self.last_sync_token = next_batch

        # We don't store any room state, so the initial sync needs to use the
        # full_state parameter. Subsequent syncs are normal.
        task = loop.create_task(
            self.sync_forever(
                timeout,
                sync_filter,
                full_state=True,
                since=next_batch,
                loop_sleep_time=loop_sleep_time,
            )
        )
        self.task = task

        return task

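    # Illustrative usage sketch (not in the original file): the daemon drives
    # a client roughly like this, where store and conf are hypothetical
    # PanStore and PanConfig instances:
    #
    #     client = PanClient("example.org", store, conf, "https://example.org")
    #     task = client.start_loop()   # schedules sync_forever on the loop
    #     ...
    #     await client.loop_stop()     # cancels the sync and fetcher tasks
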
    async def start_sas(self, message, device):
        try:
            await self.start_key_verification(device)
            await self.send_message(
                DaemonResponse(
                    message.message_id,
                    self.user_id,
                    "m.ok",
                    "Successfully started the key verification request",
                )
            )
        except ClientConnectionError as e:
            await self.send_message(
                DaemonResponse(
                    message.message_id, self.user_id, "m.connection_error", str(e)
                )
            )

    async def accept_sas(self, message):
        user_id = message.user_id
        device_id = message.device_id

        sas = self.get_active_sas(user_id, device_id)

        if not sas:
            await self.send_message(
                DaemonResponse(
                    message.message_id,
                    self.user_id,
                    Sas._txid_error[0],
                    Sas._txid_error[1],
                )
            )
            return

        try:
            await self.accept_key_verification(sas.transaction_id)
            await self.send_message(
                DaemonResponse(
                    message.message_id,
                    self.user_id,
                    "m.ok",
                    "Successfully accepted the key verification request",
                )
            )
        except LocalProtocolError as e:
            await self.send_message(
                DaemonResponse(
                    message.message_id,
                    self.user_id,
                    Sas._unexpected_message_error[0],
                    str(e),
                )
            )
        except ClientConnectionError as e:
            await self.send_message(
                DaemonResponse(
                    message.message_id, self.user_id, "m.connection_error", str(e)
                )
            )

    async def cancel_sas(self, message):
        user_id = message.user_id
        device_id = message.device_id

        sas = self.get_active_sas(user_id, device_id)

        if not sas:
            await self.send_message(
                DaemonResponse(
                    message.message_id,
                    self.user_id,
                    Sas._txid_error[0],
                    Sas._txid_error[1],
                )
            )
            return

        try:
            await self.cancel_key_verification(sas.transaction_id)
            await self.send_message(
                DaemonResponse(
                    message.message_id,
                    self.user_id,
                    "m.ok",
                    "Successfully canceled the key verification request",
                )
            )
        except ClientConnectionError as e:
            await self.send_message(
                DaemonResponse(
                    message.message_id, self.user_id, "m.connection_error", str(e)
                )
            )

    async def confirm_sas(self, message):
        user_id = message.user_id
        device_id = message.device_id

        sas = self.get_active_sas(user_id, device_id)

        if not sas:
            await self.send_message(
                DaemonResponse(
                    message.message_id,
                    self.user_id,
                    Sas._txid_error[0],
                    Sas._txid_error[1],
                )
            )
            return

        try:
            await self.confirm_short_auth_string(sas.transaction_id)
        except ClientConnectionError as e:
            await self.send_message(
                DaemonResponse(
                    message.message_id, self.user_id, "m.connection_error", str(e)
                )
            )

            return

        device = sas.other_olm_device

        if sas.verified:
            await self.send_update_device(device)
            await self.send_message(
                SasDoneSignal(
                    self.user_id, device.user_id, device.id, sas.transaction_id
                )
            )
        else:
            await self.send_message(
                DaemonResponse(
                    message.message_id,
                    self.user_id,
                    "m.ok",
                    f"Waiting for {device.user_id} to confirm.",
                )
            )

    async def handle_key_request_message(self, message):
        if isinstance(message, ContinueKeyShare):
            continued = False
            for share in self.get_active_key_requests(
                message.user_id, message.device_id
            ):

                continued = True

                if not self.continue_key_share(share):
                    await self.send_message(
                        DaemonResponse(
                            message.message_id,
                            self.user_id,
                            "m.error",
                            (
                                f"Unable to continue the key sharing for "
                                f"{message.user_id} via {message.device_id}: The "
                                f"device is still not verified."
                            ),
                        )
                    )
                    return

            if continued:
                try:
                    await self.send_to_device_messages()
                except ClientConnectionError:
                    # We can safely ignore this since this will be retried
                    # after the next sync in the sync_forever method.
                    pass

                response = (
                    f"Successfully continued the key requests from "
                    f"{message.user_id} via {message.device_id}"
                )
                ret = "m.ok"
            else:
                response = (
                    f"No active key requests from {message.user_id} "
                    f"via {message.device_id} found."
                )
                ret = "m.error"

            await self.send_message(
                DaemonResponse(message.message_id, self.user_id, ret, response)
            )

        elif isinstance(message, CancelKeyShare):
            cancelled = False

            for share in self.get_active_key_requests(
                message.user_id, message.device_id
            ):
                cancelled = self.cancel_key_share(share)

            if cancelled:
                response = (
                    f"Successfully cancelled key requests from "
                    f"{message.user_id} via {message.device_id}"
                )
                ret = "m.ok"
            else:
                response = (
                    f"No active key requests from {message.user_id} "
                    f"via {message.device_id} found."
                )
                ret = "m.error"

            await self.send_message(
                DaemonResponse(message.message_id, self.user_id, ret, response)
            )

    async def loop_stop(self):
        """Stop the client loop."""
        logger.info("Stopping the sync loop")

        if self.task and not self.task.done():
            self.task.cancel()

            try:
                await self.task
            except KeyboardInterrupt:
                pass

            self.task = None

        if self.history_fetcher_task and not self.history_fetcher_task.done():
            self.history_fetcher_task.cancel()

            try:
                await self.history_fetcher_task
            except KeyboardInterrupt:
                pass

            self.history_fetcher_task = None

        if isinstance(self.store, SqliteQStore):
            self.store.close()

        self.history_fetch_queue = asyncio.Queue()

    def pan_decrypt_event(self, event_dict, room_id=None, ignore_failures=True):
        # type: (Dict[Any, Any], Optional[str], bool) -> bool
        event = Event.parse_encrypted_event(event_dict)

        if not isinstance(event, MegolmEvent):
            logger.warn(
                "Encrypted event is not a megolm event:"
                "\n{}".format(pformat(event_dict))
            )
            return False

        if not event.room_id:
            event.room_id = room_id

        try:
            decrypted_event = self.decrypt_event(event)
            logger.debug("Decrypted event: {}".format(decrypted_event))
            logger.info(
                "Decrypted event from {} in {}, event id: {}".format(
                    decrypted_event.sender,
                    decrypted_event.room_id,
                    decrypted_event.event_id,
                )
            )

            if isinstance(decrypted_event, RoomEncryptedMedia):
                self.store_event_media(decrypted_event)

                decrypted_event.source["content"]["url"] = decrypted_event.url

                if decrypted_event.thumbnail_url:
                    decrypted_event.source["content"]["info"][
                        "thumbnail_url"
                    ] = decrypted_event.thumbnail_url

            event_dict.update(decrypted_event.source)
            event_dict["decrypted"] = True
            event_dict["verified"] = decrypted_event.verified

            return True

        except EncryptionError as error:
            logger.warn(error)

            if ignore_failures:
                event_dict.update(self.unable_to_decrypt)
            else:
                raise

            return False

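    # Illustrative sketch (not in the original file): pan_decrypt_event
    # rewrites the event dict in place, so a timeline entry of type
    # "m.room.encrypted" comes back as the decrypted event, e.g. with a
    # hypothetical room id:
    #
    #     event = {"type": "m.room.encrypted", "content": {...}, ...}
    #     if client.pan_decrypt_event(event, room_id="!room:example.org"):
    #         event["type"]        # now e.g. "m.room.message"
    #         event["decrypted"]   # True
    #         event["verified"]    # whether the sender's device is verified
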
    def decrypt_messages_body(self, body, ignore_failures=True):
        # type: (Dict[Any, Any], bool) -> Dict[Any, Any]
        """Go through a messages response and decrypt megolm encrypted events.

        Args:
            body (Dict[Any, Any]): The dictionary of a room messages response.

        Returns the json response with decrypted events.
        """
        if "chunk" not in body:
            return body

        logger.info("Decrypting room messages")

        for event in body["chunk"]:
            if "type" not in event:
                continue

            if event["type"] != "m.room.encrypted":
                logger.debug("Event is not encrypted:\n{}".format(pformat(event)))
                continue

            self.pan_decrypt_event(event, ignore_failures=ignore_failures)

        return body

    def handle_to_device_from_sync_body(self, body):
        to_device_events = body.get("to_device")

        if not to_device_events or "events" not in to_device_events:
            return

        for event in to_device_events["events"]:
            event = ToDeviceEvent.parse_encrypted_event(event)

            if not isinstance(event, ToDeviceEvent):
                continue

            self.olm.handle_to_device_event(event)

    def decrypt_sync_body(self, body, ignore_failures=True):
        # type: (Dict[Any, Any], bool) -> Dict[Any, Any]
        """Go through a json sync response and decrypt megolm encrypted events.

        Args:
            body (Dict[Any, Any]): The dictionary of a Sync response.

        Returns the json response with decrypted events.
        """
        logger.info("Decrypting sync")

        self.handle_to_device_from_sync_body(body)

        for room_id, room_dict in body["rooms"]["join"].items():
            try:
                if not self.rooms[room_id].encrypted:
                    logger.info(
                        "Room {} is not encrypted, skipping...".format(
                            self.rooms[room_id].display_name
                        )
                    )
                    continue
            except KeyError:
                # We don't know if the room is encrypted or not, probably
                # because the client sync stream got to join the room before
                # the pan sync stream did. Let's assume that the room is
                # encrypted.
                pass

            for event in room_dict["timeline"]["events"]:
                if "type" not in event:
                    continue

                if event["type"] != "m.room.encrypted":
                    continue

                self.pan_decrypt_event(event, room_id, ignore_failures)

        return body

    async def search(self, search_terms):
        # type: (Dict[Any, Any]) -> Dict[Any, Any]
        assert INDEXING_ENABLED

        state_cache = dict()

        async def add_context(event_dict, room_id, event_id, include_state):
            try:
                context = await self.room_context(room_id, event_id, limit=0)
            except ClientConnectionError:
                return

            if isinstance(context, RoomContextError):
                return

            if include_state:
                state_cache[room_id] = [e.source for e in context.state]

            event_dict["context"]["start"] = context.start
            event_dict["context"]["end"] = context.end

        search_terms = search_terms["search_categories"]["room_events"]

        term = search_terms["search_term"]
        search_filter = search_terms["filter"]
        limit = search_filter.get("limit", 10)

        if limit <= 0:
            raise InvalidLimit("The limit must be strictly greater than 0.")

        rooms = search_filter.get("rooms", [])

        room_id = rooms[0] if len(rooms) == 1 else None

        order_by = search_terms.get("order_by")

        if order_by not in ["rank", "recent"]:
            raise InvalidOrderByError(f"Invalid order by: {order_by}")

        order_by_recent = order_by == "recent"

        before_limit = 0
        after_limit = 0
        include_profile = False

        event_context = search_terms.get("event_context")
        include_state = search_terms.get("include_state")

        if event_context:
            before_limit = event_context.get("before_limit", 5)
            after_limit = event_context.get("after_limit", 5)

            if before_limit < 0 or after_limit < 0:
                raise InvalidLimit(
                    "Invalid context limit, the limit must be a positive number"
                )

        response_dict = await self.index.search(
            term,
            room=room_id,
            max_results=limit,
            order_by_recent=order_by_recent,
            include_profile=include_profile,
            before_limit=before_limit,
            after_limit=after_limit,
        )

        if (event_context or include_state) and self.pan_conf.search_requests:
            for event_dict in response_dict["results"]:
                await add_context(
                    event_dict,
                    event_dict["result"]["room_id"],
                    event_dict["result"]["event_id"],
                    include_state,
                )

        if include_state:
            response_dict["state"] = state_cache

        return {"search_categories": {"room_events": response_dict}}
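

# Illustrative usage sketch (not in the original file): search expects a body
# that has already passed validate_json, since that pass fills in required
# defaults such as "filter" before they are read above:
#
#     body = {"search_categories": {"room_events": {"search_term": "cats"}}}
#     validate_json(body, SEARCH_TERMS_SCHEMA)  # injects "filter", "order_by"
#     result = await client.search(body)
#     result["search_categories"]["room_events"]["results"]  # list of hits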