mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Add include_event_in_state
to _get_state_for_room (#6521)
Make it return the state *after* the requested event, rather than the one before it. This is a bit easier and requires fewer calls to get_events_from_store_or_dest.
This commit is contained in:
parent
894d2addac
commit
2045356517
1
changelog.d/6521.misc
Normal file
1
changelog.d/6521.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Refactor some code in the event authentication path for clarity.
|
@ -378,22 +378,10 @@ class FederationHandler(BaseHandler):
|
|||||||
(
|
(
|
||||||
remote_state,
|
remote_state,
|
||||||
got_auth_chain,
|
got_auth_chain,
|
||||||
) = await self._get_state_for_room(origin, room_id, p)
|
) = await self._get_state_for_room(
|
||||||
|
origin, room_id, p, include_event_in_state=True
|
||||||
# we want the state *after* p; _get_state_for_room returns the
|
|
||||||
# state *before* p.
|
|
||||||
remote_event = await self.federation_client.get_pdu(
|
|
||||||
[origin], p, room_version, outlier=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if remote_event is None:
|
|
||||||
raise Exception(
|
|
||||||
"Unable to get missing prev_event %s" % (p,)
|
|
||||||
)
|
|
||||||
|
|
||||||
if remote_event.is_state():
|
|
||||||
remote_state.append(remote_event)
|
|
||||||
|
|
||||||
# XXX hrm I'm not convinced that duplicate events will compare
|
# XXX hrm I'm not convinced that duplicate events will compare
|
||||||
# for equality, so I'm not sure this does what the author
|
# for equality, so I'm not sure this does what the author
|
||||||
# hoped.
|
# hoped.
|
||||||
@ -579,20 +567,25 @@ class FederationHandler(BaseHandler):
|
|||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@log_function
|
|
||||||
async def _get_state_for_room(
|
async def _get_state_for_room(
|
||||||
self, destination: str, room_id: str, event_id: str
|
self,
|
||||||
|
destination: str,
|
||||||
|
room_id: str,
|
||||||
|
event_id: str,
|
||||||
|
include_event_in_state: bool = False,
|
||||||
) -> Tuple[List[EventBase], List[EventBase]]:
|
) -> Tuple[List[EventBase], List[EventBase]]:
|
||||||
"""Requests all of the room state at a given event from a remote homeserver.
|
"""Requests all of the room state at a given event from a remote homeserver.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
destination:: The remote homeserver to query for the state.
|
destination: The remote homeserver to query for the state.
|
||||||
room_id: The id of the room we're interested in.
|
room_id: The id of the room we're interested in.
|
||||||
event_id: The id of the event we want the state at.
|
event_id: The id of the event we want the state at.
|
||||||
|
include_event_in_state: if true, the event itself will be included in the
|
||||||
|
returned state event list.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of events in the state, and a list of events in the auth chain
|
A list of events in the state, possibly including the event itself, and
|
||||||
for the given event.
|
a list of events in the auth chain for the given event.
|
||||||
"""
|
"""
|
||||||
(
|
(
|
||||||
state_event_ids,
|
state_event_ids,
|
||||||
@ -602,6 +595,10 @@ class FederationHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
desired_events = set(state_event_ids + auth_event_ids)
|
desired_events = set(state_event_ids + auth_event_ids)
|
||||||
|
|
||||||
|
if include_event_in_state:
|
||||||
|
desired_events.add(event_id)
|
||||||
|
|
||||||
event_map = await self._get_events_from_store_or_dest(
|
event_map = await self._get_events_from_store_or_dest(
|
||||||
destination, room_id, desired_events
|
destination, room_id, desired_events
|
||||||
)
|
)
|
||||||
@ -614,12 +611,21 @@ class FederationHandler(BaseHandler):
|
|||||||
failed_to_fetch,
|
failed_to_fetch,
|
||||||
)
|
)
|
||||||
|
|
||||||
pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
|
remote_state = [
|
||||||
auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
|
event_map[e_id] for e_id in state_event_ids if e_id in event_map
|
||||||
|
]
|
||||||
|
|
||||||
|
if include_event_in_state:
|
||||||
|
remote_event = event_map.get(event_id)
|
||||||
|
if not remote_event:
|
||||||
|
raise Exception("Unable to get missing prev_event %s" % (event_id,))
|
||||||
|
if remote_event.is_state():
|
||||||
|
remote_state.append(remote_event)
|
||||||
|
|
||||||
|
auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
|
||||||
auth_chain.sort(key=lambda e: e.depth)
|
auth_chain.sort(key=lambda e: e.depth)
|
||||||
|
|
||||||
return pdus, auth_chain
|
return remote_state, auth_chain
|
||||||
|
|
||||||
async def _get_events_from_store_or_dest(
|
async def _get_events_from_store_or_dest(
|
||||||
self, destination: str, room_id: str, event_ids: Iterable[str]
|
self, destination: str, room_id: str, event_ids: Iterable[str]
|
||||||
|
Loading…
Reference in New Issue
Block a user