# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from canonicaljson import json

from twisted.internet import defer

from synapse.api.errors import SynapseError
from synapse.logging.opentracing import (
    get_active_span_text_map,
    log_kv,
    set_tag,
    start_active_span,
)
from synapse.types import UserID, get_domain_from_id
from synapse.util.stringutils import random_string

logger = logging.getLogger(__name__)


class DeviceMessageHandler(object):
    """Handles "m.direct_to_device" messages, both those received as EDUs
    from remote servers and those sent by a user to local/remote devices.
    """

    def __init__(self, hs):
        """
        Args:
            hs (synapse.server.HomeServer): server
        """
        self.store = hs.get_datastore()
        self.notifier = hs.get_notifier()
        self.is_mine = hs.is_mine
        self.federation = hs.get_federation_sender()

        # Route incoming "m.direct_to_device" EDUs to this handler.
        hs.get_federation_registry().register_edu_handler(
            "m.direct_to_device", self.on_direct_to_device_edu
        )

    @defer.inlineCallbacks
    def on_direct_to_device_edu(self, origin, content):
        """Process an "m.direct_to_device" EDU received over federation.

        The EDU is dropped (nothing is stored, nobody is notified) when the
        domain of the claimed sender does not match ``origin``.

        Args:
            origin: the server the EDU was received from; compared against
                the domain part of ``content["sender"]``.
            content (dict): the EDU content. The keys "sender", "type",
                "message_id" and "messages" are read; "messages" maps
                user_id -> {device_id -> message content}.

        Raises:
            SynapseError: (400) if any recipient is not a local user.
        """
        local_messages = {}
        sender_user_id = content["sender"]
        if origin != get_domain_from_id(sender_user_id):
            logger.warning(
                "Dropping device message from %r with spoofed sender %r",
                origin,
                sender_user_id,
            )
            # Actually drop the message: without this return the spoofed
            # EDU would still be written to the recipients' inboxes.
            return

        message_type = content["type"]
        message_id = content["message_id"]
        for user_id, by_device in content["messages"].items():
            # we use UserID.from_string to catch invalid user ids
            if not self.is_mine(UserID.from_string(user_id)):
                # NOTE(review): the log text mentions "keys" although this
                # handles device messages; looks copy-pasted, left unchanged.
                logger.warning("Request for keys for non-local user %s", user_id)
                raise SynapseError(400, "Not a user here")

            messages_by_device = {
                device_id: {
                    "content": message_content,
                    "type": message_type,
                    "sender": sender_user_id,
                }
                for device_id, message_content in by_device.items()
            }
            # Skip users for whom no per-device messages were supplied.
            if messages_by_device:
                local_messages[user_id] = messages_by_device

        stream_id = yield self.store.add_messages_from_remote_to_device_inbox(
            origin, message_id, local_messages
        )

        self.notifier.on_new_event(
            "to_device_key", stream_id, users=local_messages.keys()
        )

    @defer.inlineCallbacks
    def send_device_message(self, sender_user_id, message_type, messages):
        """Send device messages to local and remote recipients.

        Messages for local users are wrapped with their type and sender and
        written to the device inbox; messages for other users are grouped by
        destination server into EDU contents, stored alongside the local
        messages, and the federation sender is then asked to send device
        messages to each destination.

        Args:
            sender_user_id: the user ID recorded as the "sender" of each
                message.
            message_type: the value recorded as the "type" of each message.
            messages (dict): maps user_id -> {device_id -> message content}.
        """
        set_tag("number_of_messages", len(messages))
        set_tag("sender", sender_user_id)
        local_messages = {}
        remote_messages = {}
        for user_id, by_device in messages.items():
            # we use UserID.from_string to catch invalid user ids
            if self.is_mine(UserID.from_string(user_id)):
                messages_by_device = {
                    device_id: {
                        "content": message_content,
                        "type": message_type,
                        "sender": sender_user_id,
                    }
                    for device_id, message_content in by_device.items()
                }
                if messages_by_device:
                    local_messages[user_id] = messages_by_device
            else:
                # Group remote recipients by the domain of their user ID.
                destination = get_domain_from_id(user_id)
                remote_messages.setdefault(destination, {})[user_id] = by_device

        # A single message ID is shared by the EDUs for every destination.
        message_id = random_string(16)

        context = get_active_span_text_map()

        remote_edu_contents = {}
        # The loop variable is deliberately not called `messages`, so that it
        # does not shadow the method's parameter.
        for destination, messages_for_destination in remote_messages.items():
            with start_active_span("to_device_for_user"):
                set_tag("destination", destination)
                remote_edu_contents[destination] = {
                    "messages": messages_for_destination,
                    "sender": sender_user_id,
                    "type": message_type,
                    "message_id": message_id,
                    "org.matrix.opentracing_context": json.dumps(context),
                }

        log_kv({"local_messages": local_messages})
        stream_id = yield self.store.add_messages_to_device_inbox(
            local_messages, remote_edu_contents
        )

        self.notifier.on_new_event(
            "to_device_key", stream_id, users=local_messages.keys()
        )

        log_kv({"remote_messages": remote_messages})
        for destination in remote_messages:
            # Enqueue a new federation transaction to send the new
            # device messages to each remote destination.
            self.federation.send_device_messages(destination)