Add include_event_in_state to _get_state_for_room (#6521)

Make it return the state *after* the requested event, rather than the one
before it. This is a bit easier and requires fewer calls to
get_events_from_store_or_dest.
This commit is contained in:
Richard van der Hoff 2019-12-11 16:37:51 +00:00 committed by GitHub
parent 894d2addac
commit 2045356517
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 22 deletions

1
changelog.d/6521.misc Normal file
View File

@ -0,0 +1 @@
Refactor some code in the event authentication path for clarity.

View File

@ -378,22 +378,10 @@ class FederationHandler(BaseHandler):
( (
remote_state, remote_state,
got_auth_chain, got_auth_chain,
) = await self._get_state_for_room(origin, room_id, p) ) = await self._get_state_for_room(
origin, room_id, p, include_event_in_state=True
# we want the state *after* p; _get_state_for_room returns the
# state *before* p.
remote_event = await self.federation_client.get_pdu(
[origin], p, room_version, outlier=True
) )
if remote_event is None:
raise Exception(
"Unable to get missing prev_event %s" % (p,)
)
if remote_event.is_state():
remote_state.append(remote_event)
# XXX hrm I'm not convinced that duplicate events will compare # XXX hrm I'm not convinced that duplicate events will compare
# for equality, so I'm not sure this does what the author # for equality, so I'm not sure this does what the author
# hoped. # hoped.
@ -579,20 +567,25 @@ class FederationHandler(BaseHandler):
else: else:
raise raise
@log_function
async def _get_state_for_room( async def _get_state_for_room(
self, destination: str, room_id: str, event_id: str self,
destination: str,
room_id: str,
event_id: str,
include_event_in_state: bool = False,
) -> Tuple[List[EventBase], List[EventBase]]: ) -> Tuple[List[EventBase], List[EventBase]]:
"""Requests all of the room state at a given event from a remote homeserver. """Requests all of the room state at a given event from a remote homeserver.
Args: Args:
destination:: The remote homeserver to query for the state. destination: The remote homeserver to query for the state.
room_id: The id of the room we're interested in. room_id: The id of the room we're interested in.
event_id: The id of the event we want the state at. event_id: The id of the event we want the state at.
include_event_in_state: if true, the event itself will be included in the
returned state event list.
Returns: Returns:
A list of events in the state, and a list of events in the auth chain A list of events in the state, possibly including the event itself, and
for the given event. a list of events in the auth chain for the given event.
""" """
( (
state_event_ids, state_event_ids,
@ -602,6 +595,10 @@ class FederationHandler(BaseHandler):
) )
desired_events = set(state_event_ids + auth_event_ids) desired_events = set(state_event_ids + auth_event_ids)
if include_event_in_state:
desired_events.add(event_id)
event_map = await self._get_events_from_store_or_dest( event_map = await self._get_events_from_store_or_dest(
destination, room_id, desired_events destination, room_id, desired_events
) )
@ -614,12 +611,21 @@ class FederationHandler(BaseHandler):
failed_to_fetch, failed_to_fetch,
) )
pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] remote_state = [
auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] event_map[e_id] for e_id in state_event_ids if e_id in event_map
]
if include_event_in_state:
remote_event = event_map.get(event_id)
if not remote_event:
raise Exception("Unable to get missing prev_event %s" % (event_id,))
if remote_event.is_state():
remote_state.append(remote_event)
auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
auth_chain.sort(key=lambda e: e.depth) auth_chain.sort(key=lambda e: e.depth)
return pdus, auth_chain return remote_state, auth_chain
async def _get_events_from_store_or_dest( async def _get_events_from_store_or_dest(
self, destination: str, room_id: str, event_ids: Iterable[str] self, destination: str, room_id: str, event_ids: Iterable[str]