Send device messages over federation

This commit is contained in:
Mark Haines 2016-09-06 18:16:20 +01:00
parent e020834e4f
commit d4a35ada28
7 changed files with 178 additions and 47 deletions

View File

@ -188,7 +188,7 @@ class FederationServer(FederationBase):
except SynapseError as e: except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e) logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception as e: except Exception as e:
logger.exception("Failed to handle edu %r", edu_type, e) logger.exception("Failed to handle edu %r", edu_type)
else: else:
logger.warn("Received EDU of type %s with no handler", edu_type) logger.warn("Received EDU of type %s with no handler", edu_type)

View File

@ -17,7 +17,7 @@
from twisted.internet import defer from twisted.internet import defer
from .persistence import TransactionActions from .persistence import TransactionActions
from .units import Transaction from .units import Transaction, Edu
from synapse.api.errors import HttpResponseException from synapse.api.errors import HttpResponseException
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
@ -187,6 +187,24 @@ class TransactionQueue(object):
destination, pending_pdus, pending_edus, pending_failures destination, pending_pdus, pending_edus, pending_failures
) )
@defer.inlineCallbacks
def _get_new_device_messages(self, destination):
last_device_stream_id = 0
to_device_stream_id = self.store.get_to_device_stream_token()
contents, stream_id = yield self.store.get_new_device_msgs_for_remote(
destination, last_device_stream_id, to_device_stream_id
)
edus = [
Edu(
origin=self.server_name,
destination=destination,
edu_type="m.direct_to_device",
content=content,
)
for content in contents
]
defer.returnValue((edus, stream_id))
@measure_func("_send_new_transaction") @measure_func("_send_new_transaction")
@defer.inlineCallbacks @defer.inlineCallbacks
def _send_new_transaction(self, destination, pending_pdus, pending_edus, def _send_new_transaction(self, destination, pending_pdus, pending_edus,
@ -211,13 +229,19 @@ class TransactionQueue(object):
self.store, self.store,
) )
device_message_edus, device_stream_id = (
yield self._get_new_device_messages(destination)
)
edus.extend(device_message_edus)
logger.debug( logger.debug(
"TX [%s] {%s} Attempting new transaction" "TX [%s] {%s} Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)", " (pdus: %d, edus: %d, failures: %d)",
destination, txn_id, destination, txn_id,
len(pending_pdus), len(pdus),
len(pending_edus), len(edus),
len(pending_failures) len(failures)
) )
logger.debug("TX [%s] Persisting transaction...", destination) logger.debug("TX [%s] Persisting transaction...", destination)
@ -242,9 +266,9 @@ class TransactionQueue(object):
" (PDUs: %d, EDUs: %d, failures: %d)", " (PDUs: %d, EDUs: %d, failures: %d)",
destination, txn_id, destination, txn_id,
transaction.transaction_id, transaction.transaction_id,
len(pending_pdus), len(pdus),
len(pending_edus), len(edus),
len(pending_failures), len(failures),
) )
with limiter: with limiter:
@ -299,6 +323,11 @@ class TransactionQueue(object):
logger.info( logger.info(
"Failed to send event %s to %s", p.event_id, destination "Failed to send event %s to %s", p.event_id, destination
) )
else:
# Remove the acknowledged device messages from the database
yield self.store.delete_device_msgs_for_remote(
destination, device_stream_id
)
except NotRetryingDestination: except NotRetryingDestination:
logger.info( logger.info(
"TX [%s] not ready for retry yet - " "TX [%s] not ready for retry yet - "

View File

@ -0,0 +1,121 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.types import get_domain_from_id
from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__)
class DeviceMessageHandler(object):
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
self.federation = hs.get_replication_layer()
self.federation.register_edu_handler(
"m.direct_to_device", self.on_direct_to_device_edu
)
@defer.inlineCallbacks
def on_direct_to_device_edu(self, origin, content):
local_messages = {}
sender_user_id = content["sender"]
if origin != get_domain_from_id(sender_user_id):
logger.warn(
"Dropping device message from %r with spoofed sender %r",
origin, sender_user_id
)
message_type = content["type"]
message_id = content["message_id"]
for user_id, by_device in content["messages"].items():
messages_by_device = {
device_id: {
"content": message_content,
"type": message_type,
"sender": sender_user_id,
}
for device_id, message_content in by_device.items()
}
if messages_by_device:
local_messages[user_id] = messages_by_device
stream_id = yield self.store.add_messages_from_remote_to_device_inbox(
origin, message_id, local_messages
)
self.notifier.on_new_event(
"to_device_key", stream_id, users=local_messages.keys()
)
@defer.inlineCallbacks
def send_device_message(self, sender_user_id, message_type, messages):
local_messages = {}
remote_messages = {}
for user_id, by_device in messages.items():
if self.is_mine_id(user_id):
messages_by_device = {
device_id: {
"content": message_content,
"type": message_type,
"sender": sender_user_id,
}
for device_id, message_content in by_device.items()
}
if messages_by_device:
local_messages[user_id] = messages_by_device
else:
destination = get_domain_from_id(user_id)
remote_messages.setdefault(destination, {})[user_id] = by_device
message_id = random_string(16)
remote_edu_contents = {}
for destination, messages in remote_messages.items():
remote_edu_contents[destination] = {
"messages": messages,
"sender": sender_user_id,
"type": message_type,
"message_id": message_id,
}
stream_id = yield self.store.add_messages_to_device_inbox(
local_messages, remote_edu_contents
)
self.notifier.on_new_event(
"to_device_key", stream_id, users=local_messages.keys()
)
for destination in remote_messages.keys():
# Hack to send make synapse send a federation transaction
# to the remote servers.
self.federation.send_edu(
destination=destination,
edu_type="m.ping",
content={},
)

View File

@ -16,10 +16,11 @@
import logging import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.http.servlet import parse_json_object_from_request
from synapse.http import servlet from synapse.http import servlet
from synapse.http.servlet import parse_json_object_from_request
from synapse.rest.client.v1.transactions import HttpTransactionStore from synapse.rest.client.v1.transactions import HttpTransactionStore
from ._base import client_v2_patterns from ._base import client_v2_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -39,10 +40,8 @@ class SendToDeviceRestServlet(servlet.RestServlet):
super(SendToDeviceRestServlet, self).__init__() super(SendToDeviceRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
self.txns = HttpTransactionStore() self.txns = HttpTransactionStore()
self.device_message_handler = hs.get_device_message_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, message_type, txn_id): def on_PUT(self, request, message_type, txn_id):
@ -57,28 +56,10 @@ class SendToDeviceRestServlet(servlet.RestServlet):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
# TODO: Prod the notifier to wake up sync streams. sender_user_id = requester.user.to_string()
# TODO: Implement replication for the messages.
# TODO: Send the messages to remote servers if needed.
local_messages = {} yield self.device_message_handler.send_device_message(
for user_id, by_device in content["messages"].items(): sender_user_id, message_type, content["messages"]
if self.is_mine_id(user_id):
messages_by_device = {
device_id: {
"content": message_content,
"type": message_type,
"sender": requester.user.to_string(),
}
for device_id, message_content in by_device.items()
}
if messages_by_device:
local_messages[user_id] = messages_by_device
stream_id = yield self.store.add_messages_to_device_inbox(local_messages)
self.notifier.on_new_event(
"to_device_key", stream_id, users=local_messages.keys()
) )
response = (200, {}) response = (200, {})

View File

@ -35,6 +35,7 @@ from synapse.federation import initialize_http_replication
from synapse.handlers import Handlers from synapse.handlers import Handlers
from synapse.handlers.appservice import ApplicationServicesHandler from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler from synapse.handlers.auth import AuthHandler
from synapse.handlers.devicemessage import DeviceMessageHandler
from synapse.handlers.device import DeviceHandler from synapse.handlers.device import DeviceHandler
from synapse.handlers.e2e_keys import E2eKeysHandler from synapse.handlers.e2e_keys import E2eKeysHandler
from synapse.handlers.presence import PresenceHandler from synapse.handlers.presence import PresenceHandler
@ -100,6 +101,7 @@ class HomeServer(object):
'application_service_api', 'application_service_api',
'application_service_scheduler', 'application_service_scheduler',
'application_service_handler', 'application_service_handler',
'device_message_handler',
'notifier', 'notifier',
'distributor', 'distributor',
'client_resource', 'client_resource',
@ -205,6 +207,9 @@ class HomeServer(object):
def build_device_handler(self): def build_device_handler(self):
return DeviceHandler(self) return DeviceHandler(self)
def build_device_message_handler(self):
return DeviceMessageHandler(self)
def build_e2e_keys_handler(self): def build_e2e_keys_handler(self):
return E2eKeysHandler(self) return E2eKeysHandler(self)

View File

@ -59,10 +59,10 @@ class DeviceInboxStore(SQLBaseStore):
self._add_messages_to_local_device_inbox_txn( self._add_messages_to_local_device_inbox_txn(
txn, stream_id, local_messages_by_user_then_device txn, stream_id, local_messages_by_user_then_device
) )
add_messages_to_device_federation_outbox(now_ms, stream_id) add_messages_to_device_federation_outbox(txn, now_ms, stream_id)
with self._device_inbox_id_gen.get_next() as stream_id: with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_now_ms() now_ms = self.clock.time_msec()
yield self.runInteraction( yield self.runInteraction(
"add_messages_to_device_inbox", "add_messages_to_device_inbox",
add_messages_txn, add_messages_txn,
@ -100,7 +100,7 @@ class DeviceInboxStore(SQLBaseStore):
) )
with self._device_inbox_id_gen.get_next() as stream_id: with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_now_ms() now_ms = self.clock.time_msec()
yield self.runInteraction( yield self.runInteraction(
"add_messages_from_remote_to_device_inbox", "add_messages_from_remote_to_device_inbox",
add_messages_txn, add_messages_txn,
@ -239,8 +239,7 @@ class DeviceInboxStore(SQLBaseStore):
def get_to_device_stream_token(self): def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token() return self._device_inbox_id_gen.get_current_token()
@defer.inlineCallbacks def get_new_device_msgs_for_remote(
def get_new_device_messages_for_remote_destination(
self, destination, last_stream_id, current_stream_id, limit=100 self, destination, last_stream_id, current_stream_id, limit=100
): ):
""" """
@ -274,13 +273,11 @@ class DeviceInboxStore(SQLBaseStore):
return (messages, stream_pos) return (messages, stream_pos)
return self.runInteraction( return self.runInteraction(
"get_new_device_messages_for_remote_destination", "get_new_device_msgs_for_remote",
get_new_messages_for_remote_destination_txn, get_new_messages_for_remote_destination_txn,
) )
@defer.inlineCallbacks def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
def delete_device_messages_for_remote_destination(self, destination,
up_to_stream_id):
"""Used to delete messages when the remote destination acknowledges """Used to delete messages when the remote destination acknowledges
their receipt. their receipt.
@ -293,12 +290,12 @@ class DeviceInboxStore(SQLBaseStore):
def delete_messages_for_remote_destination_txn(txn): def delete_messages_for_remote_destination_txn(txn):
sql = ( sql = (
"DELETE FROM device_federation_outbox" "DELETE FROM device_federation_outbox"
" WHERE destination = ? AND" " WHERE destination = ?"
" AND stream_id <= ?" " AND stream_id <= ?"
) )
txn.execute(sql, (destination, up_to_stream_id)) txn.execute(sql, (destination, up_to_stream_id))
return self.runInteraction( return self.runInteraction(
"delete_device_messages_for_remote_destination", "delete_device_msgs_for_remote",
delete_messages_for_remote_destination_txn delete_messages_for_remote_destination_txn
) )

View File

@ -16,9 +16,7 @@
CREATE TABLE device_federation_outbox ( CREATE TABLE device_federation_outbox (
destination TEXT NOT NULL, destination TEXT NOT NULL,
stream_id BIGINT NOT NULL, stream_id BIGINT NOT NULL,
sender TEXT NOT NULL, queued_ts BIGINT NOT NULL,
message_id TEXT NOT NULL,
sent_ts BIGINT NOT NULL,
messages_json TEXT NOT NULL messages_json TEXT NOT NULL
); );