diff --git a/changelog.d/10896.misc b/changelog.d/10896.misc index 41de99584..9a765435d 100644 --- a/changelog.d/10896.misc +++ b/changelog.d/10896.misc @@ -1 +1 @@ - Clean up some of the federation event authentication code for clarity. +Clean up some of the federation event authentication code for clarity. diff --git a/changelog.d/10926.misc b/changelog.d/10926.misc new file mode 100644 index 000000000..9a765435d --- /dev/null +++ b/changelog.d/10926.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 01fd84112..2c4644b4a 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -68,11 +68,7 @@ from synapse.types import ( UserID, get_domain_from_id, ) -from synapse.util.async_helpers import ( - Linearizer, - concurrently_execute, - yieldable_gather_results, -) +from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.iterutils import batch_iter from synapse.util.retryutils import NotRetryingDestination from synapse.util.stringutils import shortstr @@ -1189,7 +1185,10 @@ class FederationEventHandler: allow_rejected=True, ) - async def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]: + room_version = await self._store.get_room_version_id(room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + + def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]: with nested_logging_context(suffix=event.event_id): auth = {} for auth_event_id in event.auth_event_ids(): @@ -1207,17 +1206,15 @@ class FederationEventHandler: auth[(ae.type, ae.state_key)] = ae context = EventContext.for_outlier() - context = await self._check_event_auth( - origin, - event, - context, - claimed_auth_event_map=auth, - ) + try: + event_auth.check(room_version_obj, event, auth_events=auth) + except AuthError as e: + logger.warning("Rejecting %r because %s", event, e) + context.rejected = RejectedReason.AUTH_ERROR + return event, context - events_to_persist = ( - x for x in await yieldable_gather_results(prep, fetched_events) if x - ) + events_to_persist = (x for x in (prep(event) for event in fetched_events) if x) await self.persist_events_and_notify(room_id, tuple(events_to_persist)) async def _check_event_auth( @@ -1226,7 +1223,6 @@ class FederationEventHandler: event: EventBase, context: EventContext, state: Optional[Iterable[EventBase]] = None, - claimed_auth_event_map: Optional[StateMap[EventBase]] = None, backfilled: bool = False, ) -> EventContext: """ @@ -1242,42 +1238,36 @@ class FederationEventHandler: The state events used to check the event for soft-fail. If this is not provided the current state events will be used. - claimed_auth_event_map: - A map of (type, state_key) => event for the event's claimed auth_events. - Possibly including events that were rejected, or are in the wrong room. - - Only populated when populating outliers. - backfilled: True if the event was backfilled. Returns: The updated context object. """ - # claimed_auth_event_map should be given iff the event is an outlier - assert bool(claimed_auth_event_map) == event.internal_metadata.outlier + # This method should only be used for non-outliers + assert not event.internal_metadata.outlier room_version = await self._store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] - if claimed_auth_event_map: - # if we have a copy of the auth events from the event, use that as the - # basis for auth. - auth_events = claimed_auth_event_map - else: - # otherwise, we calculate what the auth events *should* be, and use that - prev_state_ids = await context.get_prev_state_ids() - auth_events_ids = self._event_auth_handler.compute_auth_events( - event, prev_state_ids, for_verification=True - ) - auth_events_x = await self._store.get_events(auth_events_ids) - auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()} + # calculate what the auth events *should* be, to use as a basis for auth. + prev_state_ids = await context.get_prev_state_ids() + auth_events_ids = self._event_auth_handler.compute_auth_events( + event, prev_state_ids, for_verification=True + ) + auth_events_x = await self._store.get_events(auth_events_ids) + calculated_auth_event_map = { + (e.type, e.state_key): e for e in auth_events_x.values() + } try: ( context, auth_events_for_auth, ) = await self._update_auth_events_and_context_for_auth( - origin, event, context, auth_events + origin, + event, + context, + calculated_auth_event_map=calculated_auth_event_map, ) except Exception: # We don't really mind if the above fails, so lets not fail @@ -1289,7 +1279,7 @@ class FederationEventHandler: "Ignoring failure and continuing processing of event.", event.event_id, ) - auth_events_for_auth = auth_events + auth_events_for_auth = calculated_auth_event_map try: event_auth.check(room_version_obj, event, auth_events=auth_events_for_auth) @@ -1425,7 +1415,7 @@ class FederationEventHandler: origin: str, event: EventBase, context: EventContext, - input_auth_events: StateMap[EventBase], + calculated_auth_event_map: StateMap[EventBase], ) -> Tuple[EventContext, StateMap[EventBase]]: """Helper for _check_event_auth. See there for docs. @@ -1443,19 +1433,17 @@ class FederationEventHandler: event: context: - input_auth_events: - Map from (event_type, state_key) to event - - Normally, our calculated auth_events based on the state of the room - at the event's position in the DAG, though occasionally (eg if the - event is an outlier), may be the auth events claimed by the remote - server. + calculated_auth_event_map: + Our calculated auth_events based on the state of the room + at the event's position in the DAG. Returns: updated context, updated auth event map """ - # take a copy of input_auth_events before we modify it. - auth_events: MutableStateMap[EventBase] = dict(input_auth_events) + assert not event.internal_metadata.outlier + + # take a copy of calculated_auth_event_map before we modify it. + auth_events: MutableStateMap[EventBase] = dict(calculated_auth_event_map) event_auth_events = set(event.auth_event_ids()) @@ -1496,15 +1484,6 @@ class FederationEventHandler: } ) - if event.internal_metadata.is_outlier(): - # XXX: given that, for an outlier, we'll be working with the - # event's *claimed* auth events rather than those we calculated: - # (a) is there any point in this test, since different_auth below will - # obviously be empty - # (b) alternatively, why don't we do it earlier? - logger.info("Skipping auth_event fetch for outlier") - return context, auth_events - different_auth = event_auth_events.difference( e.event_id for e in auth_events.values() ) diff --git a/tests/test_federation.py b/tests/test_federation.py index c51e018da..24fc77d7a 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -82,7 +82,6 @@ class MessageAcceptTests(unittest.HomeserverTestCase): event, context, state=None, - claimed_auth_event_map=None, backfilled=False, ): return context