Factor out an _AsyncEventContextImpl (#6298)

The intention here is to make it clearer which fields we can expect to be
populated when: notably, that the _event_type etc aren't used for the
synchronous impl of EventContext.
This commit is contained in:
Richard van der Hoff 2019-11-01 16:19:09 +00:00 committed by GitHub
parent fa7e52caf1
commit c6516adbe0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 65 additions and 85 deletions

1
changelog.d/6298.misc Normal file
View File

@ -0,0 +1 @@
Refactor EventContext for clarity.

View File

@ -37,9 +37,6 @@ class EventContext:
delta_ids (dict[(str, str), str]): Delta from ``prev_group``. delta_ids (dict[(str, str), str]): Delta from ``prev_group``.
(type, state_key) -> event_id. ``None`` for an outlier. (type, state_key) -> event_id. ``None`` for an outlier.
prev_state_events (?): XXX: is this ever set to anything other than
the empty list?
app_service: FIXME app_service: FIXME
_current_state_ids (dict[(str, str), str]|None): _current_state_ids (dict[(str, str), str]|None):
@ -51,36 +48,16 @@ class EventContext:
The current state map excluding the current event. None if outlier The current state map excluding the current event. None if outlier
or we haven't fetched the state from DB yet. or we haven't fetched the state from DB yet.
(type, state_key) -> event_id (type, state_key) -> event_id
_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
been calculated. None if we haven't started calculating yet
_event_type (str): The type of the event the context is associated with.
Only set when state has not been fetched yet.
_event_state_key (str|None): The state_key of the event the context is
associated with. Only set when state has not been fetched yet.
_prev_state_id (str|None): If the event associated with the context is
a state event, then `_prev_state_id` is the event_id of the state
that was replaced.
Only set when state has not been fetched yet.
""" """
state_group = attr.ib(default=None) state_group = attr.ib(default=None)
rejected = attr.ib(default=False) rejected = attr.ib(default=False)
prev_group = attr.ib(default=None) prev_group = attr.ib(default=None)
delta_ids = attr.ib(default=None) delta_ids = attr.ib(default=None)
prev_state_events = attr.ib(default=attr.Factory(list))
app_service = attr.ib(default=None) app_service = attr.ib(default=None)
_current_state_ids = attr.ib(default=None)
_prev_state_ids = attr.ib(default=None) _prev_state_ids = attr.ib(default=None)
_prev_state_id = attr.ib(default=None) _current_state_ids = attr.ib(default=None)
_event_type = attr.ib(default=None)
_event_state_key = attr.ib(default=None)
_fetching_state_deferred = attr.ib(default=None)
@staticmethod @staticmethod
def with_state( def with_state(
@ -90,7 +67,6 @@ class EventContext:
current_state_ids=current_state_ids, current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids, prev_state_ids=prev_state_ids,
state_group=state_group, state_group=state_group,
fetching_state_deferred=defer.succeed(None),
prev_group=prev_group, prev_group=prev_group,
delta_ids=delta_ids, delta_ids=delta_ids,
) )
@ -125,7 +101,6 @@ class EventContext:
"rejected": self.rejected, "rejected": self.rejected,
"prev_group": self.prev_group, "prev_group": self.prev_group,
"delta_ids": _encode_state_dict(self.delta_ids), "delta_ids": _encode_state_dict(self.delta_ids),
"prev_state_events": self.prev_state_events,
"app_service_id": self.app_service.id if self.app_service else None, "app_service_id": self.app_service.id if self.app_service else None,
} }
@ -141,7 +116,7 @@ class EventContext:
Returns: Returns:
EventContext EventContext
""" """
context = EventContext( context = _AsyncEventContextImpl(
# We use the state_group and prev_state_id stuff to pull the # We use the state_group and prev_state_id stuff to pull the
# current_state_ids out of the DB and construct prev_state_ids. # current_state_ids out of the DB and construct prev_state_ids.
prev_state_id=input["prev_state_id"], prev_state_id=input["prev_state_id"],
@ -151,7 +126,6 @@ class EventContext:
prev_group=input["prev_group"], prev_group=input["prev_group"],
delta_ids=_decode_state_dict(input["delta_ids"]), delta_ids=_decode_state_dict(input["delta_ids"]),
rejected=input["rejected"], rejected=input["rejected"],
prev_state_events=input["prev_state_events"],
) )
app_service_id = input["app_service_id"] app_service_id = input["app_service_id"]
@ -170,14 +144,7 @@ class EventContext:
Maps a (type, state_key) to the event ID of the state event matching Maps a (type, state_key) to the event ID of the state event matching
this tuple. this tuple.
""" """
yield self._ensure_fetched(store)
if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(
self._fill_out_state, store
)
yield make_deferred_yieldable(self._fetching_state_deferred)
return self._current_state_ids return self._current_state_ids
@defer.inlineCallbacks @defer.inlineCallbacks
@ -190,14 +157,7 @@ class EventContext:
Maps a (type, state_key) to the event ID of the state event matching Maps a (type, state_key) to the event ID of the state event matching
this tuple. this tuple.
""" """
yield self._ensure_fetched(store)
if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(
self._fill_out_state, store
)
yield make_deferred_yieldable(self._fetching_state_deferred)
return self._prev_state_ids return self._prev_state_ids
def get_cached_current_state_ids(self): def get_cached_current_state_ids(self):
@ -211,6 +171,44 @@ class EventContext:
return self._current_state_ids return self._current_state_ids
def _ensure_fetched(self, store):
return defer.succeed(None)
@attr.s(slots=True)
class _AsyncEventContextImpl(EventContext):
"""
An implementation of EventContext which fetches _current_state_ids and
_prev_state_ids from the database on demand.
Attributes:
_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
been calculated. None if we haven't started calculating yet
_event_type (str): The type of the event the context is associated with.
_event_state_key (str): The state_key of the event the context is
associated with.
_prev_state_id (str|None): If the event associated with the context is
a state event, then `_prev_state_id` is the event_id of the state
that was replaced.
"""
_prev_state_id = attr.ib(default=None)
_event_type = attr.ib(default=None)
_event_state_key = attr.ib(default=None)
_fetching_state_deferred = attr.ib(default=None)
def _ensure_fetched(self, store):
if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(
self._fill_out_state, store
)
return make_deferred_yieldable(self._fetching_state_deferred)
@defer.inlineCallbacks @defer.inlineCallbacks
def _fill_out_state(self, store): def _fill_out_state(self, store):
"""Called to populate the _current_state_ids and _prev_state_ids """Called to populate the _current_state_ids and _prev_state_ids
@ -228,27 +226,6 @@ class EventContext:
else: else:
self._prev_state_ids = self._current_state_ids self._prev_state_ids = self._current_state_ids
@defer.inlineCallbacks
def update_state(
self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids
):
"""Replace the state in the context
"""
# We need to make sure we wait for any ongoing fetching of state
# to complete so that the updated state doesn't get clobbered
if self._fetching_state_deferred:
yield make_deferred_yieldable(self._fetching_state_deferred)
self.state_group = state_group
self._prev_state_ids = prev_state_ids
self.prev_group = prev_group
self._current_state_ids = current_state_ids
self.delta_ids = delta_ids
# We need to ensure that that we've marked as having fetched the state
self._fetching_state_deferred = defer.succeed(None)
def _encode_state_dict(state_dict): def _encode_state_dict(state_dict):
"""Since dicts of (type, state_key) -> event_id cannot be serialized in """Since dicts of (type, state_key) -> event_id cannot be serialized in

View File

@ -45,6 +45,7 @@ from synapse.api.errors import (
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.crypto.event_signing import compute_event_signature from synapse.crypto.event_signing import compute_event_signature
from synapse.event_auth import auth_types_for_event from synapse.event_auth import auth_types_for_event
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.logging.context import ( from synapse.logging.context import (
make_deferred_yieldable, make_deferred_yieldable,
@ -1871,14 +1872,7 @@ class FederationHandler(BaseHandler):
if c and c.type == EventTypes.Create: if c and c.type == EventTypes.Create:
auth_events[(c.type, c.state_key)] = c auth_events[(c.type, c.state_key)] = c
try: context = yield self.do_auth(origin, event, context, auth_events=auth_events)
yield self.do_auth(origin, event, context, auth_events=auth_events)
except AuthError as e:
logger.warning(
"[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg
)
context.rejected = RejectedReason.AUTH_ERROR
if not context.rejected: if not context.rejected:
yield self._check_for_soft_fail(event, state, backfilled) yield self._check_for_soft_fail(event, state, backfilled)
@ -2047,12 +2041,12 @@ class FederationHandler(BaseHandler):
Also NB that this function adds entries to it. Also NB that this function adds entries to it.
Returns: Returns:
defer.Deferred[None] defer.Deferred[EventContext]: updated context object
""" """
room_version = yield self.store.get_room_version(event.room_id) room_version = yield self.store.get_room_version(event.room_id)
try: try:
yield self._update_auth_events_and_context_for_auth( context = yield self._update_auth_events_and_context_for_auth(
origin, event, context, auth_events origin, event, context, auth_events
) )
except Exception: except Exception:
@ -2070,7 +2064,9 @@ class FederationHandler(BaseHandler):
event_auth.check(room_version, event, auth_events=auth_events) event_auth.check(room_version, event, auth_events=auth_events)
except AuthError as e: except AuthError as e:
logger.warning("Failed auth resolution for %r because %s", event, e) logger.warning("Failed auth resolution for %r because %s", event, e)
raise e context.rejected = RejectedReason.AUTH_ERROR
return context
@defer.inlineCallbacks @defer.inlineCallbacks
def _update_auth_events_and_context_for_auth( def _update_auth_events_and_context_for_auth(
@ -2094,7 +2090,7 @@ class FederationHandler(BaseHandler):
auth_events (dict[(str, str)->synapse.events.EventBase]): auth_events (dict[(str, str)->synapse.events.EventBase]):
Returns: Returns:
defer.Deferred[None] defer.Deferred[EventContext]: updated context
""" """
event_auth_events = set(event.auth_event_ids()) event_auth_events = set(event.auth_event_ids())
@ -2133,7 +2129,7 @@ class FederationHandler(BaseHandler):
# The other side isn't around or doesn't implement the # The other side isn't around or doesn't implement the
# endpoint, so lets just bail out. # endpoint, so lets just bail out.
logger.info("Failed to get event auth from remote: %s", e) logger.info("Failed to get event auth from remote: %s", e)
return return context
seen_remotes = yield self.store.have_seen_events( seen_remotes = yield self.store.have_seen_events(
[e.event_id for e in remote_auth_chain] [e.event_id for e in remote_auth_chain]
@ -2174,7 +2170,7 @@ class FederationHandler(BaseHandler):
if event.internal_metadata.is_outlier(): if event.internal_metadata.is_outlier():
logger.info("Skipping auth_event fetch for outlier") logger.info("Skipping auth_event fetch for outlier")
return return context
# FIXME: Assumes we have and stored all the state for all the # FIXME: Assumes we have and stored all the state for all the
# prev_events # prev_events
@ -2183,7 +2179,7 @@ class FederationHandler(BaseHandler):
) )
if not different_auth: if not different_auth:
return return context
logger.info( logger.info(
"auth_events refers to events which are not in our calculated auth " "auth_events refers to events which are not in our calculated auth "
@ -2230,10 +2226,12 @@ class FederationHandler(BaseHandler):
auth_events.update(new_state) auth_events.update(new_state)
yield self._update_context_for_auth_events( context = yield self._update_context_for_auth_events(
event, context, auth_events, event_key event, context, auth_events, event_key
) )
return context
@defer.inlineCallbacks @defer.inlineCallbacks
def _update_context_for_auth_events(self, event, context, auth_events, event_key): def _update_context_for_auth_events(self, event, context, auth_events, event_key):
"""Update the state_ids in an event context after auth event resolution, """Update the state_ids in an event context after auth event resolution,
@ -2242,14 +2240,16 @@ class FederationHandler(BaseHandler):
Args: Args:
event (Event): The event we're handling the context for event (Event): The event we're handling the context for
context (synapse.events.snapshot.EventContext): event context context (synapse.events.snapshot.EventContext): initial event context
to be updated
auth_events (dict[(str, str)->str]): Events to update in the event auth_events (dict[(str, str)->str]): Events to update in the event
context. context.
event_key ((str, str)): (type, state_key) for the current event. event_key ((str, str)): (type, state_key) for the current event.
this will not be included in the current_state in the context. this will not be included in the current_state in the context.
Returns:
Deferred[EventContext]: new event context
""" """
state_updates = { state_updates = {
k: a.event_id for k, a in iteritems(auth_events) if k != event_key k: a.event_id for k, a in iteritems(auth_events) if k != event_key
@ -2274,7 +2274,7 @@ class FederationHandler(BaseHandler):
current_state_ids=current_state_ids, current_state_ids=current_state_ids,
) )
yield context.update_state( return EventContext.with_state(
state_group=state_group, state_group=state_group,
current_state_ids=current_state_ids, current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids, prev_state_ids=prev_state_ids,

View File

@ -59,7 +59,9 @@ class MessageAcceptTests(unittest.TestCase):
) )
self.handler = self.homeserver.get_handlers().federation_handler self.handler = self.homeserver.get_handlers().federation_handler
self.handler.do_auth = lambda *a, **b: succeed(True) self.handler.do_auth = lambda origin, event, context, auth_events: succeed(
context
)
self.client = self.homeserver.get_federation_client() self.client = self.homeserver.get_federation_client()
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed( self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
pdus pdus