Simplify _auth_and_persist_fetched_events (#10901)

Combine the two loops over the list of events, and hence get rid of
`_NewEventInfo`. Also pass the event back alongside the context, so that it's
easier to process the result.
This commit is contained in:
Richard van der Hoff 2021-09-24 11:56:13 +01:00 committed by GitHub
parent 50022cff96
commit 261c9763c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 69 deletions

1
changelog.d/10901.misc Normal file
View File

@@ -0,0 +1 @@
Clean up some of the federation event authentication code for clarity.

View File

@@ -27,11 +27,8 @@ from typing import (
Tuple, Tuple,
) )
import attr
from prometheus_client import Counter from prometheus_client import Counter
from twisted.internet import defer
from synapse import event_auth from synapse import event_auth
from synapse.api.constants import ( from synapse.api.constants import (
EventContentFields, EventContentFields,
@@ -54,11 +51,7 @@ from synapse.event_auth import auth_types_for_event
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.federation.federation_client import InvalidResponseError from synapse.federation.federation_client import InvalidResponseError
from synapse.logging.context import ( from synapse.logging.context import nested_logging_context, run_in_background
make_deferred_yieldable,
nested_logging_context,
run_in_background,
)
from synapse.logging.utils import log_function from synapse.logging.utils import log_function
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
@@ -75,7 +68,11 @@ from synapse.types import (
UserID, UserID,
get_domain_from_id, get_domain_from_id,
) )
from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.async_helpers import (
Linearizer,
concurrently_execute,
yieldable_gather_results,
)
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import shortstr from synapse.util.stringutils import shortstr
@@ -92,30 +89,6 @@ soft_failed_event_counter = Counter(
) )
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _NewEventInfo:
"""Holds information about a received event, ready for passing to _auth_and_persist_events
Attributes:
event: the received event
claimed_auth_event_map: a map of (type, state_key) => event for the event's
claimed auth_events.
This can include events which have not yet been persisted, in the case that
we are backfilling a batch of events.
Note: May be incomplete: if we were unable to find all of the claimed auth
events. Also, treat the contents with caution: the events might also have
been rejected, might not yet have been authorized themselves, or they might
be in the wrong room.
"""
event: EventBase
claimed_auth_event_map: StateMap[EventBase]
class FederationEventHandler: class FederationEventHandler:
"""Handles events that originated from federation. """Handles events that originated from federation.
@@ -1203,47 +1176,27 @@ class FederationEventHandler:
allow_rejected=True, allow_rejected=True,
) )
event_infos = [] async def prep(event: EventBase) -> Optional[Tuple[EventBase, EventContext]]:
for event in fetched_events:
auth = {}
for auth_event_id in event.auth_event_ids():
ae = persisted_events.get(auth_event_id)
if ae:
auth[(ae.type, ae.state_key)] = ae
else:
logger.info("Missing auth event %s", auth_event_id)
event_infos.append(_NewEventInfo(event, auth))
if not event_infos:
return
async def prep(ev_info: _NewEventInfo) -> EventContext:
event = ev_info.event
with nested_logging_context(suffix=event.event_id): with nested_logging_context(suffix=event.event_id):
res = EventContext.for_outlier() auth = {}
res = await self._check_event_auth( for auth_event_id in event.auth_event_ids():
ae = persisted_events.get(auth_event_id)
if ae:
auth[(ae.type, ae.state_key)] = ae
else:
logger.info("Missing auth event %s", auth_event_id)
context = EventContext.for_outlier()
context = await self._check_event_auth(
origin, origin,
event, event,
res, context,
claimed_auth_event_map=ev_info.claimed_auth_event_map, claimed_auth_event_map=auth,
) )
return res return event, context
contexts = await make_deferred_yieldable( events_to_persist = await yieldable_gather_results(prep, fetched_events)
defer.gatherResults( await self.persist_events_and_notify(room_id, events_to_persist)
[run_in_background(prep, ev_info) for ev_info in event_infos],
consumeErrors=True,
)
)
await self.persist_events_and_notify(
room_id,
[
(ev_info.event, context)
for ev_info, context in zip(event_infos, contexts)
],
)
async def _check_event_auth( async def _check_event_auth(
self, self,