diff --git a/changelog.d/6521.misc b/changelog.d/6521.misc new file mode 100644 index 000000000..d9a44389b --- /dev/null +++ b/changelog.d/6521.misc @@ -0,0 +1 @@ +Refactor some code in the event authentication path for clarity. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index bcd3b422a..62985bab9 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -378,22 +378,10 @@ class FederationHandler(BaseHandler): ( remote_state, got_auth_chain, - ) = await self._get_state_for_room(origin, room_id, p) - - # we want the state *after* p; _get_state_for_room returns the - # state *before* p. - remote_event = await self.federation_client.get_pdu( - [origin], p, room_version, outlier=True + ) = await self._get_state_for_room( + origin, room_id, p, include_event_in_state=True ) - if remote_event is None: - raise Exception( - "Unable to get missing prev_event %s" % (p,) - ) - - if remote_event.is_state(): - remote_state.append(remote_event) - # XXX hrm I'm not convinced that duplicate events will compare # for equality, so I'm not sure this does what the author # hoped. @@ -579,20 +567,25 @@ class FederationHandler(BaseHandler): else: raise - @log_function async def _get_state_for_room( - self, destination: str, room_id: str, event_id: str + self, + destination: str, + room_id: str, + event_id: str, + include_event_in_state: bool = False, ) -> Tuple[List[EventBase], List[EventBase]]: """Requests all of the room state at a given event from a remote homeserver. Args: - destination:: The remote homeserver to query for the state. + destination: The remote homeserver to query for the state. room_id: The id of the room we're interested in. event_id: The id of the event we want the state at. + include_event_in_state: if true, the event itself will be included in the + returned state event list. Returns: - A list of events in the state, and a list of events in the auth chain - for the given event. + A list of events in the state, possibly including the event itself, and + a list of events in the auth chain for the given event. """ ( state_event_ids, @@ -602,6 +595,10 @@ class FederationHandler(BaseHandler): ) desired_events = set(state_event_ids + auth_event_ids) + + if include_event_in_state: + desired_events.add(event_id) + event_map = await self._get_events_from_store_or_dest( destination, room_id, desired_events ) @@ -614,12 +611,21 @@ class FederationHandler(BaseHandler): failed_to_fetch, ) - pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] - auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] + remote_state = [ + event_map[e_id] for e_id in state_event_ids if e_id in event_map + ] + if include_event_in_state: + remote_event = event_map.get(event_id) + if not remote_event: + raise Exception("Unable to get missing prev_event %s" % (event_id,)) + if remote_event.is_state(): + remote_state.append(remote_event) + + auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] auth_chain.sort(key=lambda e: e.depth) - return pdus, auth_chain + return remote_state, auth_chain async def _get_events_from_store_or_dest( self, destination: str, room_id: str, event_ids: Iterable[str]