pantalaimon/pantalaimon/client.py

779 lines
25 KiB
Python

# Copyright 2019 The Matrix.org Foundation CIC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
from collections import defaultdict
from pprint import pformat
from typing import Any, Dict, Optional
from aiohttp.client_exceptions import ClientConnectionError
from jsonschema import Draft4Validator, FormatChecker, validators
from nio import (
AsyncClient,
ClientConfig,
EncryptionError,
KeysQueryResponse,
KeyVerificationEvent,
KeyVerificationKey,
KeyVerificationMac,
KeyVerificationStart,
LocalProtocolError,
MegolmEvent,
RoomContextError,
RoomEncryptedEvent,
RoomEncryptedMedia,
RoomMessageMedia,
RoomMessageText,
RoomNameEvent,
RoomTopicEvent,
SyncResponse,
)
from nio.crypto import Sas
from nio.store import SqliteStore
from pantalaimon.index import INDEXING_ENABLED
from pantalaimon.log import logger
from pantalaimon.store import FetchTask
from pantalaimon.thread_messages import (
DaemonResponse,
InviteSasSignal,
SasDoneSignal,
ShowSasSignal,
UpdateDevicesMessage,
)
# Event fields that a search request is allowed to match against.
SEARCH_KEYS = ["content.body", "content.name", "content.topic"]

# JSON schema for the body of a search request. The "default" entries are
# filled into the validated instance by the default-setting validator built
# with extend_with_default().
SEARCH_TERMS_SCHEMA = {
    "type": "object",
    "properties": {
        "search_categories": {
            "type": "object",
            "properties": {
                "room_events": {
                    "type": "object",
                    "properties": {
                        "search_term": {"type": "string"},
                        "keys": {
                            "type": "array",
                            "items": {"type": "string", "enum": SEARCH_KEYS},
                            "default": SEARCH_KEYS,
                        },
                        "order_by": {"type": "string", "default": "rank"},
                        "include_state": {"type": "boolean", "default": False},
                        "filter": {"type": "object", "default": {}},
                        "event_context": {"type": "object"},
                        "groupings": {"type": "object", "default": {}},
                    },
                    "required": ["search_term"],
                }
            },
            # "room_events" is a member of "search_categories", so the
            # requirement has to live on this schema. It previously sat
            # inside the top-level "properties" mapping where it was
            # interpreted as a property called "required" and never enforced.
            "required": ["room_events"],
        },
    },
    "required": ["search_categories"],
}
def extend_with_default(validator_class):
    """Build a validator class that also applies schema defaults.

    The returned class behaves like ``validator_class`` except that, while
    the "properties" keyword is validated, every property whose sub-schema
    declares a "default" and that is missing from the instance is set to
    that default on the instance.

    Args:
        validator_class: The jsonschema validator class to extend.

    Returns:
        The extended validator class.
    """
    validate_properties = validator_class.VALIDATORS["properties"]

    def set_defaults(validator, properties, instance, schema):
        # Defaults can only be filled into mappings. For any other instance
        # type we skip straight to the wrapped validator so that the problem
        # is reported as a validation error instead of an AttributeError
        # raised by a missing setdefault().
        if isinstance(instance, dict):
            for prop, subschema in properties.items():
                if isinstance(subschema, dict) and "default" in subschema:
                    instance.setdefault(prop, subschema["default"])

        yield from validate_properties(validator, properties, instance, schema)

    return validators.extend(validator_class, {"properties": set_defaults})
# Draft 4 validator class extended so that schema "default" values are filled
# into the instance while its "properties" are validated.
Validator = extend_with_default(Draft4Validator)
def validate_json(instance, schema):
    """Check a dictionary against a json schema.

    The default-setting ``Validator`` is used, so properties missing from
    ``instance`` that have a schema default are filled in as a side effect.
    Validation failures are raised by the validator's ``validate()`` call.
    """
    checker = FormatChecker()
    schema_validator = Validator(schema, format_checker=checker)
    schema_validator.validate(instance)
class UnknownRoomError(Exception):
    """Error for a room that is not known.

    Not raised in this part of the module; presumably used by callers that
    look up rooms (to be confirmed against them).
    """
class InvalidOrderByError(Exception):
    """Raised by a search whose "order_by" is neither "rank" nor "recent"."""
class InvalidLimit(Exception):
    """Raised by a search with a non-positive limit or negative context limit."""
class PanClient(AsyncClient):
    """A wrapper class around a nio AsyncClient extending its functionality."""

    def __init__(
        self,
        server_name,
        pan_store,
        pan_conf,
        homeserver,
        queue=None,
        user_id="",
        device_id="",
        store_path="",
        config=None,
        ssl=None,
        proxy=None,
    ):
        """Create the client, set up the optional index and register callbacks.

        Args:
            server_name: Name of the server; used as part of the index
                directory and as a key for everything kept in ``pan_store``.
            pan_store: Store used to persist sync tokens and fetcher tasks.
            pan_conf: Configuration object; this class reads
                ``index_encrypted_only``, ``history_fetch_delay``,
                ``indexing_batch_size`` and ``search_requests`` from it.
            homeserver, user_id, device_id, store_path, config, ssl, proxy:
                Passed through to ``AsyncClient``. If ``config`` is falsy a
                ``ClientConfig`` using ``SqliteStore`` and the store name
                "pan.db" is used.
            queue: Queue that carries messages to the UI thread.
        """
        config = config or ClientConfig(store=SqliteStore, store_name="pan.db")
        super().__init__(homeserver, user_id, device_id, store_path, config, ssl, proxy)

        index_dir = os.path.join(store_path, server_name, user_id)

        # Creating the directory stays best-effort, but an already existing
        # directory is no longer treated as an error and real failures are
        # logged instead of being silently dropped.
        try:
            os.makedirs(index_dir, exist_ok=True)
        except OSError as e:
            logger.warn(f"Unable to create the index directory {index_dir}: {e}")

        self.server_name = server_name
        self.pan_store = pan_store
        self.pan_conf = pan_conf

        if INDEXING_ENABLED:
            logger.info("Indexing enabled.")
            from pantalaimon.index import IndexStore

            self.index = IndexStore(self.user_id, index_dir)
        else:
            logger.info("Indexing disabled.")
            self.index = None

        self.task = None
        self.queue = queue

        self.room_members_fetched = defaultdict(bool)

        self.send_semaphores = defaultdict(asyncio.Semaphore)
        self.send_decision_queues = dict()  # type: Dict[Any, asyncio.Queue]
        self.last_sync_token = None

        self.history_fetcher_task = None
        self.history_fetch_queue = asyncio.Queue()

        self.add_to_device_callback(self.key_verification_cb, KeyVerificationEvent)
        self.add_event_callback(self.undecrypted_event_cb, MegolmEvent)

        if INDEXING_ENABLED:
            self.add_event_callback(
                self.store_message_cb,
                (
                    RoomMessageText,
                    RoomMessageMedia,
                    RoomEncryptedMedia,
                    RoomTopicEvent,
                    RoomNameEvent,
                ),
            )

        self.add_response_callback(self.keys_query_cb, KeysQueryResponse)
        self.add_response_callback(self.sync_tasks, SyncResponse)

    def store_message_cb(self, room, event):
        """Event callback that adds an event to the search index.

        Events from unencrypted rooms are skipped when the configuration
        asks for encrypted rooms only.
        """
        assert INDEXING_ENABLED

        # Check the cheap skip condition before looking up profile data.
        if not room.encrypted and self.pan_conf.index_encrypted_only:
            return

        display_name = room.user_name(event.sender)
        avatar_url = room.avatar_url(event.sender)

        self.index.add_event(event, room.room_id, display_name, avatar_url)

    @property
    def unable_to_decrypt(self):
        """Room event signaling that the message couldn't be decrypted."""
        return {
            "type": "m.room.message",
            "content": {
                "msgtype": "m.text",
                "body": (
                    "** Unable to decrypt: The sender's device has not "
                    "sent us the keys for this message. **"
                ),
            },
        }

    async def send_message(self, message):
        """Send a thread message to the UI thread."""
        await self.queue.put(message)

    async def send_update_devices(self, devices):
        """Send a dictionary of devices to the UI thread.

        Args:
            devices: Mapping whose values are mappings of device objects;
                only the device objects themselves are read.
        """
        dict_devices = defaultdict(dict)

        for user_devices in devices.values():
            for device in user_devices.values():
                # Turn the OlmDevice type into a dictionary, flatten the
                # keys dict into it and rename the display name entry.
                # Since all the keys and values are strings this also
                # copies them making it thread safe.
                device_dict = device.as_dict()
                device_dict = {**device_dict, **device_dict["keys"]}
                device_dict.pop("keys")
                display_name = device_dict.pop("display_name")
                device_dict["device_display_name"] = display_name
                dict_devices[device.user_id][device.id] = device_dict

        message = UpdateDevicesMessage(self.user_id, dict_devices)
        await self.queue.put(message)

    async def send_update_device(self, device):
        """Send a single device to the UI thread to be updated."""
        await self.send_update_devices({device.user_id: {device.id: device}})

    def delete_fetcher_task(self, task):
        """Remove a history fetch task from the persistent store."""
        self.pan_store.delete_fetcher_task(self.server_name, self.user_id, task)

    async def fetcher_loop(self):
        """Fetch room history in the background and add it to the index.

        Stored fetch tasks are loaded into the queue first. After that the
        loop sleeps for the configured delay, takes one task from the queue,
        downloads a batch of messages for it and indexes the supported event
        types. The loop runs until its task is cancelled.
        """
        assert INDEXING_ENABLED

        for t in self.pan_store.load_fetcher_tasks(self.server_name, self.user_id):
            await self.history_fetch_queue.put(t)

        while True:
            try:
                await asyncio.sleep(self.pan_conf.history_fetch_delay)
                fetch_task = await self.history_fetch_queue.get()

                try:
                    room = self.rooms[fetch_task.room_id]
                except KeyError:
                    # The room is missing from our client, we probably left the
                    # room.
                    self.delete_fetcher_task(fetch_task)
                    continue

                try:
                    logger.debug(
                        "Fetching room history for {}".format(room.display_name)
                    )
                    response = await self.room_messages(
                        fetch_task.room_id,
                        fetch_task.token,
                        limit=self.pan_conf.indexing_batch_size,
                    )
                except ClientConnectionError:
                    # Put the task back and try again later. The put has to
                    # be awaited to take effect, and there is no response to
                    # process in this iteration.
                    await self.history_fetch_queue.put(fetch_task)
                    continue

                # The chunk was empty, we're at the start of the timeline.
                # A response without a chunk is handled the same way instead
                # of crashing the loop.
                # NOTE(review): this assumes error responses carry no
                # `chunk` attribute - confirm against nio.
                if not getattr(response, "chunk", None):
                    self.delete_fetcher_task(fetch_task)
                    continue

                for event in response.chunk:
                    if not isinstance(
                        event,
                        (
                            RoomMessageText,
                            RoomMessageMedia,
                            RoomEncryptedMedia,
                            RoomTopicEvent,
                            RoomNameEvent,
                        ),
                    ):
                        continue

                    display_name = room.user_name(event.sender)
                    avatar_url = room.avatar_url(event.sender)
                    self.index.add_event(event, room.room_id, display_name, avatar_url)

                last_event = response.chunk[-1]

                if not self.index.event_in_store(last_event.event_id, room.room_id):
                    # There may be even more events to fetch, add a new task to
                    # the queue.
                    task = FetchTask(room.room_id, response.end)
                    self.pan_store.save_fetcher_task(
                        self.server_name, self.user_id, task
                    )
                    await self.history_fetch_queue.put(task)

                await self.index.commit_events()
                self.delete_fetcher_task(fetch_task)
            except asyncio.CancelledError:
                return

    async def sync_tasks(self, response):
        """Response callback run for every sync response.

        Commits pending index events, stores the sync token and queues a
        history fetch task for each joined room whose timeline was limited.
        """
        if self.index:
            await self.index.commit_events()

        self.pan_store.save_token(self.server_name, self.user_id, self.next_batch)

        for room_id, room_info in response.rooms.join.items():
            if room_info.timeline.limited:
                room = self.rooms[room_id]

                if not room.encrypted and self.pan_conf.index_encrypted_only:
                    continue

                logger.info(
                    "Room {} had a limited timeline, queueing "
                    "room for history fetching.".format(room.display_name)
                )
                task = FetchTask(room_id, room_info.timeline.prev_batch)
                self.pan_store.save_fetcher_task(self.server_name, self.user_id, task)

                await self.history_fetch_queue.put(task)

    async def keys_query_cb(self, response):
        """Response callback forwarding changed devices to the UI thread."""
        if response.changed:
            await self.send_update_devices(response.changed)

    async def undecrypted_event_cb(self, room, event):
        """Event callback for megolm events; requests the missing room key.

        A key request is only sent if none is outstanding for the session.
        Connection errors while sending the request are ignored.
        """
        logger.info(
            "Unable to decrypt event from {} via {}.".format(
                event.sender, event.device_id
            )
        )

        if event.session_id not in self.outgoing_key_requests:
            logger.info("Requesting room key for undecrypted event.")

            # TODO we may want to retry this
            try:
                await self.request_room_key(event)
            except ClientConnectionError:
                pass

    async def key_verification_cb(self, event):
        """To-device callback relaying key verification progress to the UI.

        A start event produces an invite signal, a key event a signal that
        carries the emoji to show, and a mac event a done signal plus a
        device update once the verification is marked as verified.
        """
        logger.info("Received key verification event: {}".format(event))

        if isinstance(event, KeyVerificationStart):
            logger.info(
                f"{event.sender} via {event.from_device} has started "
                f"a key verification process."
            )

            message = InviteSasSignal(
                self.user_id, event.sender, event.from_device, event.transaction_id
            )

            await self.queue.put(message)

        elif isinstance(event, KeyVerificationKey):
            sas = self.key_verifications.get(event.transaction_id, None)

            if not sas:
                return

            device = sas.other_olm_device
            emoji = sas.get_emoji()

            message = ShowSasSignal(
                self.user_id, device.user_id, device.id, sas.transaction_id, emoji
            )

            await self.queue.put(message)

        elif isinstance(event, KeyVerificationMac):
            sas = self.key_verifications.get(event.transaction_id, None)

            if not sas:
                return

            device = sas.other_olm_device

            if sas.verified:
                await self.send_message(
                    SasDoneSignal(
                        self.user_id, device.user_id, device.id, sas.transaction_id
                    )
                )
                await self.send_update_device(device)

    def start_loop(self, loop_sleep_time=None):
        """Start a loop that runs forever and keeps on syncing with the server.

        The loop can be stopped with the loop_stop() method.

        Args:
            loop_sleep_time: Passed through to ``sync_forever``.

        Returns:
            The task running the sync loop.
        """
        assert not self.task

        logger.info(f"Starting sync loop for {self.user_id}")

        loop = asyncio.get_event_loop()

        if INDEXING_ENABLED:
            self.history_fetcher_task = loop.create_task(self.fetcher_loop())

        timeout = 30000
        sync_filter = {"room": {"state": {"lazy_load_members": True}}}
        next_batch = self.pan_store.load_token(self.server_name, self.user_id)
        self.last_sync_token = next_batch

        # We don't store any room state so initial sync needs to be with the
        # full_state parameter. Subsequent ones are normal.
        task = loop.create_task(
            self.sync_forever(
                timeout,
                sync_filter,
                full_state=True,
                since=next_batch,
                loop_sleep_time=loop_sleep_time,
            )
        )
        self.task = task

        return task

    async def _send_sas_not_found(self, message):
        """Answer a daemon message for which no active SAS exists."""
        await self.send_message(
            DaemonResponse(
                message.message_id,
                self.user_id,
                Sas._txid_error[0],
                Sas._txid_error[1],
            )
        )

    async def _send_connection_error(self, message, error):
        """Answer a daemon message with an "m.connection_error" response."""
        await self.send_message(
            DaemonResponse(
                message.message_id, self.user_id, "m.connection_error", str(error)
            )
        )

    async def start_sas(self, message, device):
        """Start a key verification with a device and report the outcome."""
        try:
            await self.start_key_verification(device)
            await self.send_message(
                DaemonResponse(
                    message.message_id,
                    self.user_id,
                    "m.ok",
                    "Successfully started the key verification request",
                )
            )
        except ClientConnectionError as e:
            await self._send_connection_error(message, e)

    async def accept_sas(self, message):
        """Accept the active key verification named by a daemon message."""
        user_id = message.user_id
        device_id = message.device_id

        sas = self.get_active_sas(user_id, device_id)

        if not sas:
            await self._send_sas_not_found(message)
            return

        try:
            await self.accept_key_verification(sas.transaction_id)
            await self.send_message(
                DaemonResponse(
                    message.message_id,
                    self.user_id,
                    "m.ok",
                    "Successfully accepted the key verification request",
                )
            )
        except LocalProtocolError as e:
            await self.send_message(
                DaemonResponse(
                    message.message_id,
                    self.user_id,
                    Sas._unexpected_message_error[0],
                    str(e),
                )
            )
        except ClientConnectionError as e:
            await self._send_connection_error(message, e)

    async def cancel_sas(self, message):
        """Cancel the active key verification named by a daemon message."""
        user_id = message.user_id
        device_id = message.device_id

        sas = self.get_active_sas(user_id, device_id)

        if not sas:
            await self._send_sas_not_found(message)
            return

        try:
            await self.cancel_key_verification(sas.transaction_id)
            await self.send_message(
                DaemonResponse(
                    message.message_id,
                    self.user_id,
                    "m.ok",
                    "Successfully canceled the key verification request",
                )
            )
        except ClientConnectionError as e:
            await self._send_connection_error(message, e)

    async def confirm_sas(self, message):
        """Confirm the short auth string of the active key verification.

        If the verification is verified afterwards a device update and a
        done signal are sent, otherwise an "m.ok" response saying that we
        are waiting for the other user.
        """
        user_id = message.user_id
        device_id = message.device_id

        sas = self.get_active_sas(user_id, device_id)

        if not sas:
            await self._send_sas_not_found(message)
            return

        try:
            await self.confirm_short_auth_string(sas.transaction_id)
        except ClientConnectionError as e:
            await self._send_connection_error(message, e)
            return

        device = sas.other_olm_device

        if sas.verified:
            await self.send_update_device(device)
            await self.send_message(
                SasDoneSignal(
                    self.user_id, device.user_id, device.id, sas.transaction_id
                )
            )
        else:
            await self.send_message(
                DaemonResponse(
                    message.message_id,
                    self.user_id,
                    "m.ok",
                    f"Waiting for {device.user_id} to confirm.",
                )
            )

    async def loop_stop(self):
        """Stop the client loop."""
        logger.info("Stopping the sync loop")

        if self.task and not self.task.done():
            self.task.cancel()

            # Awaiting a cancelled task raises CancelledError unless the
            # task handled the cancellation itself; that is the expected
            # outcome here and must not abort the rest of the shutdown.
            try:
                await self.task
            except (asyncio.CancelledError, KeyboardInterrupt):
                pass

        # Reset the handle even if the task had already finished on its
        # own, otherwise start_loop() would refuse to start again.
        self.task = None

        if self.history_fetcher_task and not self.history_fetcher_task.done():
            self.history_fetcher_task.cancel()

            try:
                await self.history_fetcher_task
            except (asyncio.CancelledError, KeyboardInterrupt):
                pass

        self.history_fetcher_task = None

        self.history_fetch_queue = asyncio.Queue()

    def pan_decrypt_event(self, event_dict, room_id=None, ignore_failures=True):
        # type: (Dict[Any, Any], Optional[str], bool) -> bool
        """Decrypt a single encrypted event dictionary in place.

        Args:
            event_dict (Dict[Any, Any]): The encrypted event. On success it
                is updated with the decrypted source plus the "decrypted"
                and "verified" keys.
            room_id (str, optional): Room id used if the parsed event does
                not carry one.
            ignore_failures (bool): If True a decryption failure replaces
                the event content with the unable-to-decrypt message, if
                False the EncryptionError is raised again.

        Returns True if the event was decrypted, False otherwise.
        """
        event = RoomEncryptedEvent.parse_event(event_dict)

        if not isinstance(event, MegolmEvent):
            logger.warn(
                "Encrypted event is not a megolm event:"
                "\n{}".format(pformat(event_dict))
            )
            return False

        if not event.room_id:
            event.room_id = room_id

        try:
            decrypted_event = self.decrypt_event(event)
            logger.info("Decrypted event: {}".format(decrypted_event))

            event_dict.update(decrypted_event.source)
            event_dict["decrypted"] = True
            event_dict["verified"] = decrypted_event.verified

            return True

        except EncryptionError as error:
            logger.warn(error)

            if ignore_failures:
                event_dict.update(self.unable_to_decrypt)
            else:
                raise

            return False

    def decrypt_messages_body(self, body, ignore_failures=True):
        # type: (Dict[Any, Any], bool) -> Dict[Any, Any]
        """Go through a messages response and decrypt megolm encrypted events.

        Args:
            body (Dict[Any, Any]): The dictionary of a room messages
                response.
            ignore_failures (bool): Passed on to pan_decrypt_event().

        Returns the json response with decrypted events.
        """
        if "chunk" not in body:
            return body

        logger.info("Decrypting room messages")

        for event in body["chunk"]:
            if "type" not in event:
                continue

            if event["type"] != "m.room.encrypted":
                logger.debug("Event is not encrypted: " "\n{}".format(pformat(event)))
                continue

            self.pan_decrypt_event(event, ignore_failures=ignore_failures)

        return body

    def decrypt_sync_body(self, body, ignore_failures=True):
        # type: (Dict[Any, Any], bool) -> Dict[Any, Any]
        """Go through a json sync response and decrypt megolm encrypted events.

        Args:
            body (Dict[Any, Any]): The dictionary of a Sync response.
            ignore_failures (bool): Passed on to pan_decrypt_event().

        Returns the json response with decrypted events.
        """
        logger.info("Decrypting sync")

        # Look the nested keys up defensively so that a body without joined
        # rooms or without timeline events is returned unchanged instead of
        # raising a KeyError.
        joined_rooms = body.get("rooms", {}).get("join", {})

        for room_id, room_dict in joined_rooms.items():
            try:
                if not self.rooms[room_id].encrypted:
                    logger.info(
                        "Room {} is not encrypted skipping...".format(
                            self.rooms[room_id].display_name
                        )
                    )
                    continue
            except KeyError:
                logger.info("Unknown room {} skipping...".format(room_id))
                continue

            for event in room_dict.get("timeline", {}).get("events", []):
                if "type" not in event:
                    continue

                if event["type"] != "m.room.encrypted":
                    continue

                self.pan_decrypt_event(event, room_id, ignore_failures)

        return body

    async def search(self, search_terms):
        # type: (Dict[Any, Any]) -> Dict[Any, Any]
        """Search the local index and build a search response dictionary.

        Args:
            search_terms (Dict[Any, Any]): The search request body. The
                "filter" entry of "room_events" is accessed directly, so the
                body is expected to have had its defaults filled in.

        Returns a dictionary of the form
        ``{"search_categories": {"room_events": ...}}``.

        Raises InvalidLimit if the result limit is not greater than 0 or a
        context limit is negative, InvalidOrderByError if "order_by" is
        neither "rank" nor "recent".
        """
        assert INDEXING_ENABLED

        state_cache = dict()

        async def add_context(event_dict, room_id, event_id, include_state):
            # Enrich a result with the context tokens of its event and, if
            # requested, remember the room state. Failures leave the result
            # untouched.
            try:
                context = await self.room_context(room_id, event_id, limit=0)
            except ClientConnectionError:
                return

            if isinstance(context, RoomContextError):
                return

            if include_state:
                state_cache[room_id] = [e.source for e in context.state]

            event_dict["context"]["start"] = context.start
            event_dict["context"]["end"] = context.end

        search_terms = search_terms["search_categories"]["room_events"]

        term = search_terms["search_term"]
        search_filter = search_terms["filter"]
        limit = search_filter.get("limit", 10)

        if limit <= 0:
            raise InvalidLimit("The limit must be strictly greater than 0.")

        rooms = search_filter.get("rooms", [])

        # The index is only restricted to a room if exactly one was given.
        room_id = rooms[0] if len(rooms) == 1 else None

        order_by = search_terms.get("order_by")

        if order_by not in ["rank", "recent"]:
            raise InvalidOrderByError(f"Invalid order by: {order_by}")

        order_by_recent = order_by == "recent"

        before_limit = 0
        after_limit = 0
        include_profile = False

        event_context = search_terms.get("event_context")
        include_state = search_terms.get("include_state")

        if event_context:
            before_limit = event_context.get("before_limit", 5)
            # Read the limit from its own key; it used to be taken from
            # "before_limit" as well.
            after_limit = event_context.get("after_limit", 5)

        if before_limit < 0 or after_limit < 0:
            raise InvalidLimit(
                "Invalid context limit, the limit must be a " "positive number"
            )

        response_dict = await self.index.search(
            term,
            room=room_id,
            max_results=limit,
            order_by_recent=order_by_recent,
            include_profile=include_profile,
            before_limit=before_limit,
            after_limit=after_limit,
        )

        if (event_context or include_state) and self.pan_conf.search_requests:
            for event_dict in response_dict["results"]:
                await add_context(
                    event_dict,
                    event_dict["result"]["room_id"],
                    event_dict["result"]["event_id"],
                    include_state,
                )

        if include_state:
            response_dict["state"] = state_cache

        return {"search_categories": {"room_events": response_dict}}