diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py index e358de942..719bfcc42 100644 --- a/synapse/federation/replication.py +++ b/synapse/federation/replication.py @@ -426,6 +426,11 @@ class ReplicationLayer(object): "auth_chain": [p.get_dict() for p in res_pdus["auth_chain"]], })) + @defer.inlineCallbacks + def on_event_auth(self, origin, context, event_id): + auth_pdus = yield self.handler.on_event_auth(event_id) + defer.returnValue((200, [a.get_dict() for a in auth_pdus])) + @defer.inlineCallbacks def make_join(self, destination, context, user_id): pdu_dict = yield self.transport_layer.make_join( diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py index b9f7d54c7..babe8447e 100644 --- a/synapse/federation/transport.py +++ b/synapse/federation/transport.py @@ -256,6 +256,21 @@ class TransportLayer(object): defer.returnValue(json.loads(content)) + @defer.inlineCallbacks + @log_function + def get_event_auth(self, destination, context, event_id): + path = PREFIX + "/event_auth/%s/%s" % ( + context, + event_id, + ) + + response = yield self.client.get_json( + destination=destination, + path=path, + ) + + defer.returnValue(response) + @defer.inlineCallbacks def _authenticate_request(self, request): json_request = { @@ -426,6 +441,17 @@ class TransportLayer(object): ) ) + self.server.register_path( + "GET", + re.compile("^" + PREFIX + "/event_auth/([^/]*)/([^/]*)$"), + self._with_authentication( + lambda origin, content, query, context, event_id: + handler.on_event_auth( + origin, context, event_id, + ) + ) + ) + self.server.register_path( "PUT", re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)$"), diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index e6afd95a5..ce65bbcd6 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -224,6 +224,11 @@ class FederationHandler(BaseHandler): defer.returnValue(self.pdu_codec.event_from_pdu(pdu)) + @defer.inlineCallbacks + def on_event_auth(self, event_id): + auth = yield self.store.get_auth_chain(event_id) + defer.returnValue([self.pdu_codec.pdu_from_event(e) for e in auth]) + @log_function @defer.inlineCallbacks def do_invite_join(self, target_host, room_id, joinee, content, snapshot): diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index d66a49e9f..06e32d592 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -32,6 +32,24 @@ class EventFederationStore(SQLBaseStore): ) def _get_auth_chain_txn(self, txn, event_id): + results = self._get_auth_chain_ids_txn(txn, event_id) + + sql = "SELECT * FROM events WHERE event_id = ?" + rows = [] + for ev_id in results: + c = txn.execute(sql, (ev_id,)) + rows.extend(self.cursor_to_dict(c)) + + return self._parse_events_txn(txn, rows) + + def get_auth_chain_ids(self, event_id): + return self.runInteraction( + "get_auth_chain_ids", + self._get_auth_chain_ids_txn, + event_id + ) + + def _get_auth_chain_ids_txn(self, txn, event_id): results = set() base_sql = ( @@ -48,13 +66,7 @@ class EventFederationStore(SQLBaseStore): front = [r[0] for r in txn.fetchall()] results.update(front) - sql = "SELECT * FROM events WHERE event_id = ?" - rows = [] - for ev_id in results: - c = txn.execute(sql, (ev_id,)) - rows.extend(self.cursor_to_dict(c)) - - return self._parse_events_txn(txn, rows) + return list(results) def get_oldest_events_in_room(self, room_id): return self.runInteraction(