Convert some of the federation handler methods to async/await. (#7338)

This commit is contained in:
Patrick Cloke 2020-04-24 14:36:38 -04:00 committed by GitHub
parent 69a1ac00b2
commit 33bceb7f70
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 25 deletions

1
changelog.d/7338.misc Normal file
View File

@ -0,0 +1 @@
Convert some federation handler code to async/await.

View File

@ -343,7 +343,7 @@ class FederationHandler(BaseHandler):
ours = await self.state_store.get_state_groups_ids(room_id, seen) ours = await self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id # state_maps is a list of mappings from (type, state_key) to event_id
state_maps = list(ours.values()) # type: list[StateMap[str]] state_maps = list(ours.values()) # type: List[StateMap[str]]
# we don't need this any more, let's delete it. # we don't need this any more, let's delete it.
del ours del ours
@ -1694,16 +1694,15 @@ class FederationHandler(BaseHandler):
return None return None
@defer.inlineCallbacks async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
def get_state_for_pdu(self, room_id, event_id):
"""Returns the state at the event. i.e. not including said event. """Returns the state at the event. i.e. not including said event.
""" """
event = yield self.store.get_event( event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id event_id, allow_none=False, check_room_id=room_id
) )
state_groups = yield self.state_store.get_state_groups(room_id, [event_id]) state_groups = await self.state_store.get_state_groups(room_id, [event_id])
if state_groups: if state_groups:
_, state = list(iteritems(state_groups)).pop() _, state = list(iteritems(state_groups)).pop()
@ -1714,7 +1713,7 @@ class FederationHandler(BaseHandler):
if "replaces_state" in event.unsigned: if "replaces_state" in event.unsigned:
prev_id = event.unsigned["replaces_state"] prev_id = event.unsigned["replaces_state"]
if prev_id != event.event_id: if prev_id != event.event_id:
prev_event = yield self.store.get_event(prev_id) prev_event = await self.store.get_event(prev_id)
results[(event.type, event.state_key)] = prev_event results[(event.type, event.state_key)] = prev_event
else: else:
del results[(event.type, event.state_key)] del results[(event.type, event.state_key)]
@ -1724,15 +1723,14 @@ class FederationHandler(BaseHandler):
else: else:
return [] return []
@defer.inlineCallbacks async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
def get_state_ids_for_pdu(self, room_id, event_id):
"""Returns the state at the event. i.e. not including said event. """Returns the state at the event. i.e. not including said event.
""" """
event = yield self.store.get_event( event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id event_id, allow_none=False, check_room_id=room_id
) )
state_groups = yield self.state_store.get_state_groups_ids(room_id, [event_id]) state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
if state_groups: if state_groups:
_, state = list(state_groups.items()).pop() _, state = list(state_groups.items()).pop()
@ -1751,49 +1749,50 @@ class FederationHandler(BaseHandler):
else: else:
return [] return []
@defer.inlineCallbacks
@log_function @log_function
def on_backfill_request(self, origin, room_id, pdu_list, limit): async def on_backfill_request(
in_room = yield self.auth.check_host_in_room(room_id, origin) self, origin: str, room_id: str, pdu_list: List[str], limit: int
) -> List[EventBase]:
in_room = await self.auth.check_host_in_room(room_id, origin)
if not in_room: if not in_room:
raise AuthError(403, "Host not in room.") raise AuthError(403, "Host not in room.")
# Synapse asks for 100 events per backfill request. Do not allow more. # Synapse asks for 100 events per backfill request. Do not allow more.
limit = min(limit, 100) limit = min(limit, 100)
events = yield self.store.get_backfill_events(room_id, pdu_list, limit) events = await self.store.get_backfill_events(room_id, pdu_list, limit)
events = yield filter_events_for_server(self.storage, origin, events) events = await filter_events_for_server(self.storage, origin, events)
return events return events
@defer.inlineCallbacks
@log_function @log_function
def get_persisted_pdu(self, origin, event_id): async def get_persisted_pdu(
self, origin: str, event_id: str
) -> Optional[EventBase]:
"""Get an event from the database for the given server. """Get an event from the database for the given server.
Args: Args:
origin [str]: hostname of server which is requesting the event; we origin: hostname of server which is requesting the event; we
will check that the server is allowed to see it. will check that the server is allowed to see it.
event_id [str]: id of the event being requested event_id: id of the event being requested
Returns: Returns:
Deferred[EventBase|None]: None if we know nothing about the event; None if we know nothing about the event; otherwise the (possibly-redacted) event.
otherwise the (possibly-redacted) event.
Raises: Raises:
AuthError if the server is not currently in the room AuthError if the server is not currently in the room
""" """
event = yield self.store.get_event( event = await self.store.get_event(
event_id, allow_none=True, allow_rejected=True event_id, allow_none=True, allow_rejected=True
) )
if event: if event:
in_room = yield self.auth.check_host_in_room(event.room_id, origin) in_room = await self.auth.check_host_in_room(event.room_id, origin)
if not in_room: if not in_room:
raise AuthError(403, "Host not in room.") raise AuthError(403, "Host not in room.")
events = yield filter_events_for_server(self.storage, origin, [event]) events = await filter_events_for_server(self.storage, origin, [event])
event = events[0] event = events[0]
return event return event
else: else:
@ -2397,7 +2396,7 @@ class FederationHandler(BaseHandler):
""" """
# exclude the state key of the new event from the current_state in the context. # exclude the state key of the new event from the current_state in the context.
if event.is_state(): if event.is_state():
event_key = (event.type, event.state_key) event_key = (event.type, event.state_key) # type: Optional[Tuple[str, str]]
else: else:
event_key = None event_key = None
state_updates = { state_updates = {