mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Factor out an _AsyncEventContextImpl (#6298)
The intention here is to make it clearer which fields we can expect to be populated when: notably, that the _event_type etc aren't used for the synchronous impl of EventContext.
This commit is contained in:
parent
fa7e52caf1
commit
c6516adbe0
1
changelog.d/6298.misc
Normal file
1
changelog.d/6298.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Refactor EventContext for clarity.
|
@ -37,9 +37,6 @@ class EventContext:
|
|||||||
delta_ids (dict[(str, str), str]): Delta from ``prev_group``.
|
delta_ids (dict[(str, str), str]): Delta from ``prev_group``.
|
||||||
(type, state_key) -> event_id. ``None`` for an outlier.
|
(type, state_key) -> event_id. ``None`` for an outlier.
|
||||||
|
|
||||||
prev_state_events (?): XXX: is this ever set to anything other than
|
|
||||||
the empty list?
|
|
||||||
|
|
||||||
app_service: FIXME
|
app_service: FIXME
|
||||||
|
|
||||||
_current_state_ids (dict[(str, str), str]|None):
|
_current_state_ids (dict[(str, str), str]|None):
|
||||||
@ -51,36 +48,16 @@ class EventContext:
|
|||||||
The current state map excluding the current event. None if outlier
|
The current state map excluding the current event. None if outlier
|
||||||
or we haven't fetched the state from DB yet.
|
or we haven't fetched the state from DB yet.
|
||||||
(type, state_key) -> event_id
|
(type, state_key) -> event_id
|
||||||
|
|
||||||
_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
|
|
||||||
been calculated. None if we haven't started calculating yet
|
|
||||||
|
|
||||||
_event_type (str): The type of the event the context is associated with.
|
|
||||||
Only set when state has not been fetched yet.
|
|
||||||
|
|
||||||
_event_state_key (str|None): The state_key of the event the context is
|
|
||||||
associated with. Only set when state has not been fetched yet.
|
|
||||||
|
|
||||||
_prev_state_id (str|None): If the event associated with the context is
|
|
||||||
a state event, then `_prev_state_id` is the event_id of the state
|
|
||||||
that was replaced.
|
|
||||||
Only set when state has not been fetched yet.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
state_group = attr.ib(default=None)
|
state_group = attr.ib(default=None)
|
||||||
rejected = attr.ib(default=False)
|
rejected = attr.ib(default=False)
|
||||||
prev_group = attr.ib(default=None)
|
prev_group = attr.ib(default=None)
|
||||||
delta_ids = attr.ib(default=None)
|
delta_ids = attr.ib(default=None)
|
||||||
prev_state_events = attr.ib(default=attr.Factory(list))
|
|
||||||
app_service = attr.ib(default=None)
|
app_service = attr.ib(default=None)
|
||||||
|
|
||||||
_current_state_ids = attr.ib(default=None)
|
|
||||||
_prev_state_ids = attr.ib(default=None)
|
_prev_state_ids = attr.ib(default=None)
|
||||||
_prev_state_id = attr.ib(default=None)
|
_current_state_ids = attr.ib(default=None)
|
||||||
|
|
||||||
_event_type = attr.ib(default=None)
|
|
||||||
_event_state_key = attr.ib(default=None)
|
|
||||||
_fetching_state_deferred = attr.ib(default=None)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def with_state(
|
def with_state(
|
||||||
@ -90,7 +67,6 @@ class EventContext:
|
|||||||
current_state_ids=current_state_ids,
|
current_state_ids=current_state_ids,
|
||||||
prev_state_ids=prev_state_ids,
|
prev_state_ids=prev_state_ids,
|
||||||
state_group=state_group,
|
state_group=state_group,
|
||||||
fetching_state_deferred=defer.succeed(None),
|
|
||||||
prev_group=prev_group,
|
prev_group=prev_group,
|
||||||
delta_ids=delta_ids,
|
delta_ids=delta_ids,
|
||||||
)
|
)
|
||||||
@ -125,7 +101,6 @@ class EventContext:
|
|||||||
"rejected": self.rejected,
|
"rejected": self.rejected,
|
||||||
"prev_group": self.prev_group,
|
"prev_group": self.prev_group,
|
||||||
"delta_ids": _encode_state_dict(self.delta_ids),
|
"delta_ids": _encode_state_dict(self.delta_ids),
|
||||||
"prev_state_events": self.prev_state_events,
|
|
||||||
"app_service_id": self.app_service.id if self.app_service else None,
|
"app_service_id": self.app_service.id if self.app_service else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -141,7 +116,7 @@ class EventContext:
|
|||||||
Returns:
|
Returns:
|
||||||
EventContext
|
EventContext
|
||||||
"""
|
"""
|
||||||
context = EventContext(
|
context = _AsyncEventContextImpl(
|
||||||
# We use the state_group and prev_state_id stuff to pull the
|
# We use the state_group and prev_state_id stuff to pull the
|
||||||
# current_state_ids out of the DB and construct prev_state_ids.
|
# current_state_ids out of the DB and construct prev_state_ids.
|
||||||
prev_state_id=input["prev_state_id"],
|
prev_state_id=input["prev_state_id"],
|
||||||
@ -151,7 +126,6 @@ class EventContext:
|
|||||||
prev_group=input["prev_group"],
|
prev_group=input["prev_group"],
|
||||||
delta_ids=_decode_state_dict(input["delta_ids"]),
|
delta_ids=_decode_state_dict(input["delta_ids"]),
|
||||||
rejected=input["rejected"],
|
rejected=input["rejected"],
|
||||||
prev_state_events=input["prev_state_events"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
app_service_id = input["app_service_id"]
|
app_service_id = input["app_service_id"]
|
||||||
@ -170,14 +144,7 @@ class EventContext:
|
|||||||
Maps a (type, state_key) to the event ID of the state event matching
|
Maps a (type, state_key) to the event ID of the state event matching
|
||||||
this tuple.
|
this tuple.
|
||||||
"""
|
"""
|
||||||
|
yield self._ensure_fetched(store)
|
||||||
if not self._fetching_state_deferred:
|
|
||||||
self._fetching_state_deferred = run_in_background(
|
|
||||||
self._fill_out_state, store
|
|
||||||
)
|
|
||||||
|
|
||||||
yield make_deferred_yieldable(self._fetching_state_deferred)
|
|
||||||
|
|
||||||
return self._current_state_ids
|
return self._current_state_ids
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -190,14 +157,7 @@ class EventContext:
|
|||||||
Maps a (type, state_key) to the event ID of the state event matching
|
Maps a (type, state_key) to the event ID of the state event matching
|
||||||
this tuple.
|
this tuple.
|
||||||
"""
|
"""
|
||||||
|
yield self._ensure_fetched(store)
|
||||||
if not self._fetching_state_deferred:
|
|
||||||
self._fetching_state_deferred = run_in_background(
|
|
||||||
self._fill_out_state, store
|
|
||||||
)
|
|
||||||
|
|
||||||
yield make_deferred_yieldable(self._fetching_state_deferred)
|
|
||||||
|
|
||||||
return self._prev_state_ids
|
return self._prev_state_ids
|
||||||
|
|
||||||
def get_cached_current_state_ids(self):
|
def get_cached_current_state_ids(self):
|
||||||
@ -211,6 +171,44 @@ class EventContext:
|
|||||||
|
|
||||||
return self._current_state_ids
|
return self._current_state_ids
|
||||||
|
|
||||||
|
def _ensure_fetched(self, store):
|
||||||
|
return defer.succeed(None)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True)
|
||||||
|
class _AsyncEventContextImpl(EventContext):
|
||||||
|
"""
|
||||||
|
An implementation of EventContext which fetches _current_state_ids and
|
||||||
|
_prev_state_ids from the database on demand.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
|
||||||
|
_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
|
||||||
|
been calculated. None if we haven't started calculating yet
|
||||||
|
|
||||||
|
_event_type (str): The type of the event the context is associated with.
|
||||||
|
|
||||||
|
_event_state_key (str): The state_key of the event the context is
|
||||||
|
associated with.
|
||||||
|
|
||||||
|
_prev_state_id (str|None): If the event associated with the context is
|
||||||
|
a state event, then `_prev_state_id` is the event_id of the state
|
||||||
|
that was replaced.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_prev_state_id = attr.ib(default=None)
|
||||||
|
_event_type = attr.ib(default=None)
|
||||||
|
_event_state_key = attr.ib(default=None)
|
||||||
|
_fetching_state_deferred = attr.ib(default=None)
|
||||||
|
|
||||||
|
def _ensure_fetched(self, store):
|
||||||
|
if not self._fetching_state_deferred:
|
||||||
|
self._fetching_state_deferred = run_in_background(
|
||||||
|
self._fill_out_state, store
|
||||||
|
)
|
||||||
|
|
||||||
|
return make_deferred_yieldable(self._fetching_state_deferred)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _fill_out_state(self, store):
|
def _fill_out_state(self, store):
|
||||||
"""Called to populate the _current_state_ids and _prev_state_ids
|
"""Called to populate the _current_state_ids and _prev_state_ids
|
||||||
@ -228,27 +226,6 @@ class EventContext:
|
|||||||
else:
|
else:
|
||||||
self._prev_state_ids = self._current_state_ids
|
self._prev_state_ids = self._current_state_ids
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def update_state(
|
|
||||||
self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids
|
|
||||||
):
|
|
||||||
"""Replace the state in the context
|
|
||||||
"""
|
|
||||||
|
|
||||||
# We need to make sure we wait for any ongoing fetching of state
|
|
||||||
# to complete so that the updated state doesn't get clobbered
|
|
||||||
if self._fetching_state_deferred:
|
|
||||||
yield make_deferred_yieldable(self._fetching_state_deferred)
|
|
||||||
|
|
||||||
self.state_group = state_group
|
|
||||||
self._prev_state_ids = prev_state_ids
|
|
||||||
self.prev_group = prev_group
|
|
||||||
self._current_state_ids = current_state_ids
|
|
||||||
self.delta_ids = delta_ids
|
|
||||||
|
|
||||||
# We need to ensure that that we've marked as having fetched the state
|
|
||||||
self._fetching_state_deferred = defer.succeed(None)
|
|
||||||
|
|
||||||
|
|
||||||
def _encode_state_dict(state_dict):
|
def _encode_state_dict(state_dict):
|
||||||
"""Since dicts of (type, state_key) -> event_id cannot be serialized in
|
"""Since dicts of (type, state_key) -> event_id cannot be serialized in
|
||||||
|
@ -45,6 +45,7 @@ from synapse.api.errors import (
|
|||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
|
||||||
from synapse.crypto.event_signing import compute_event_signature
|
from synapse.crypto.event_signing import compute_event_signature
|
||||||
from synapse.event_auth import auth_types_for_event
|
from synapse.event_auth import auth_types_for_event
|
||||||
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.events.validator import EventValidator
|
from synapse.events.validator import EventValidator
|
||||||
from synapse.logging.context import (
|
from synapse.logging.context import (
|
||||||
make_deferred_yieldable,
|
make_deferred_yieldable,
|
||||||
@ -1871,14 +1872,7 @@ class FederationHandler(BaseHandler):
|
|||||||
if c and c.type == EventTypes.Create:
|
if c and c.type == EventTypes.Create:
|
||||||
auth_events[(c.type, c.state_key)] = c
|
auth_events[(c.type, c.state_key)] = c
|
||||||
|
|
||||||
try:
|
context = yield self.do_auth(origin, event, context, auth_events=auth_events)
|
||||||
yield self.do_auth(origin, event, context, auth_events=auth_events)
|
|
||||||
except AuthError as e:
|
|
||||||
logger.warning(
|
|
||||||
"[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg
|
|
||||||
)
|
|
||||||
|
|
||||||
context.rejected = RejectedReason.AUTH_ERROR
|
|
||||||
|
|
||||||
if not context.rejected:
|
if not context.rejected:
|
||||||
yield self._check_for_soft_fail(event, state, backfilled)
|
yield self._check_for_soft_fail(event, state, backfilled)
|
||||||
@ -2047,12 +2041,12 @@ class FederationHandler(BaseHandler):
|
|||||||
|
|
||||||
Also NB that this function adds entries to it.
|
Also NB that this function adds entries to it.
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred[None]
|
defer.Deferred[EventContext]: updated context object
|
||||||
"""
|
"""
|
||||||
room_version = yield self.store.get_room_version(event.room_id)
|
room_version = yield self.store.get_room_version(event.room_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self._update_auth_events_and_context_for_auth(
|
context = yield self._update_auth_events_and_context_for_auth(
|
||||||
origin, event, context, auth_events
|
origin, event, context, auth_events
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -2070,7 +2064,9 @@ class FederationHandler(BaseHandler):
|
|||||||
event_auth.check(room_version, event, auth_events=auth_events)
|
event_auth.check(room_version, event, auth_events=auth_events)
|
||||||
except AuthError as e:
|
except AuthError as e:
|
||||||
logger.warning("Failed auth resolution for %r because %s", event, e)
|
logger.warning("Failed auth resolution for %r because %s", event, e)
|
||||||
raise e
|
context.rejected = RejectedReason.AUTH_ERROR
|
||||||
|
|
||||||
|
return context
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _update_auth_events_and_context_for_auth(
|
def _update_auth_events_and_context_for_auth(
|
||||||
@ -2094,7 +2090,7 @@ class FederationHandler(BaseHandler):
|
|||||||
auth_events (dict[(str, str)->synapse.events.EventBase]):
|
auth_events (dict[(str, str)->synapse.events.EventBase]):
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred[None]
|
defer.Deferred[EventContext]: updated context
|
||||||
"""
|
"""
|
||||||
event_auth_events = set(event.auth_event_ids())
|
event_auth_events = set(event.auth_event_ids())
|
||||||
|
|
||||||
@ -2133,7 +2129,7 @@ class FederationHandler(BaseHandler):
|
|||||||
# The other side isn't around or doesn't implement the
|
# The other side isn't around or doesn't implement the
|
||||||
# endpoint, so lets just bail out.
|
# endpoint, so lets just bail out.
|
||||||
logger.info("Failed to get event auth from remote: %s", e)
|
logger.info("Failed to get event auth from remote: %s", e)
|
||||||
return
|
return context
|
||||||
|
|
||||||
seen_remotes = yield self.store.have_seen_events(
|
seen_remotes = yield self.store.have_seen_events(
|
||||||
[e.event_id for e in remote_auth_chain]
|
[e.event_id for e in remote_auth_chain]
|
||||||
@ -2174,7 +2170,7 @@ class FederationHandler(BaseHandler):
|
|||||||
|
|
||||||
if event.internal_metadata.is_outlier():
|
if event.internal_metadata.is_outlier():
|
||||||
logger.info("Skipping auth_event fetch for outlier")
|
logger.info("Skipping auth_event fetch for outlier")
|
||||||
return
|
return context
|
||||||
|
|
||||||
# FIXME: Assumes we have and stored all the state for all the
|
# FIXME: Assumes we have and stored all the state for all the
|
||||||
# prev_events
|
# prev_events
|
||||||
@ -2183,7 +2179,7 @@ class FederationHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not different_auth:
|
if not different_auth:
|
||||||
return
|
return context
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"auth_events refers to events which are not in our calculated auth "
|
"auth_events refers to events which are not in our calculated auth "
|
||||||
@ -2230,10 +2226,12 @@ class FederationHandler(BaseHandler):
|
|||||||
|
|
||||||
auth_events.update(new_state)
|
auth_events.update(new_state)
|
||||||
|
|
||||||
yield self._update_context_for_auth_events(
|
context = yield self._update_context_for_auth_events(
|
||||||
event, context, auth_events, event_key
|
event, context, auth_events, event_key
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return context
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _update_context_for_auth_events(self, event, context, auth_events, event_key):
|
def _update_context_for_auth_events(self, event, context, auth_events, event_key):
|
||||||
"""Update the state_ids in an event context after auth event resolution,
|
"""Update the state_ids in an event context after auth event resolution,
|
||||||
@ -2242,14 +2240,16 @@ class FederationHandler(BaseHandler):
|
|||||||
Args:
|
Args:
|
||||||
event (Event): The event we're handling the context for
|
event (Event): The event we're handling the context for
|
||||||
|
|
||||||
context (synapse.events.snapshot.EventContext): event context
|
context (synapse.events.snapshot.EventContext): initial event context
|
||||||
to be updated
|
|
||||||
|
|
||||||
auth_events (dict[(str, str)->str]): Events to update in the event
|
auth_events (dict[(str, str)->str]): Events to update in the event
|
||||||
context.
|
context.
|
||||||
|
|
||||||
event_key ((str, str)): (type, state_key) for the current event.
|
event_key ((str, str)): (type, state_key) for the current event.
|
||||||
this will not be included in the current_state in the context.
|
this will not be included in the current_state in the context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[EventContext]: new event context
|
||||||
"""
|
"""
|
||||||
state_updates = {
|
state_updates = {
|
||||||
k: a.event_id for k, a in iteritems(auth_events) if k != event_key
|
k: a.event_id for k, a in iteritems(auth_events) if k != event_key
|
||||||
@ -2274,7 +2274,7 @@ class FederationHandler(BaseHandler):
|
|||||||
current_state_ids=current_state_ids,
|
current_state_ids=current_state_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
yield context.update_state(
|
return EventContext.with_state(
|
||||||
state_group=state_group,
|
state_group=state_group,
|
||||||
current_state_ids=current_state_ids,
|
current_state_ids=current_state_ids,
|
||||||
prev_state_ids=prev_state_ids,
|
prev_state_ids=prev_state_ids,
|
||||||
|
@ -59,7 +59,9 @@ class MessageAcceptTests(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.handler = self.homeserver.get_handlers().federation_handler
|
self.handler = self.homeserver.get_handlers().federation_handler
|
||||||
self.handler.do_auth = lambda *a, **b: succeed(True)
|
self.handler.do_auth = lambda origin, event, context, auth_events: succeed(
|
||||||
|
context
|
||||||
|
)
|
||||||
self.client = self.homeserver.get_federation_client()
|
self.client = self.homeserver.get_federation_client()
|
||||||
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
|
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
|
||||||
pdus
|
pdus
|
||||||
|
Loading…
Reference in New Issue
Block a user