convert to async: FederationHandler.on_receive_pdu

and associated functions:
 * on_receive_pdu
 * handle_queued_pdus
 * get_missing_events_for_pdu
This commit is contained in:
Richard van der Hoff 2019-12-10 17:01:37 +00:00
parent 7712e751b8
commit e77237b935
2 changed files with 31 additions and 32 deletions

View File

@ -165,8 +165,7 @@ class FederationHandler(BaseHandler):
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
@defer.inlineCallbacks async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
""" Process a PDU received via a federation /send/ transaction, or """ Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events via backfill of missing prev_events
@ -176,8 +175,6 @@ class FederationHandler(BaseHandler):
pdu (FrozenEvent): received PDU pdu (FrozenEvent): received PDU
sent_to_us_directly (bool): True if this event was pushed to us; False if sent_to_us_directly (bool): True if this event was pushed to us; False if
we pulled it as the result of a missing prev_event. we pulled it as the result of a missing prev_event.
Returns (Deferred): completes with None
""" """
room_id = pdu.room_id room_id = pdu.room_id
@ -186,7 +183,7 @@ class FederationHandler(BaseHandler):
logger.info("handling received PDU: %s", pdu) logger.info("handling received PDU: %s", pdu)
# We reprocess pdus when we have seen them only as outliers # We reprocess pdus when we have seen them only as outliers
existing = yield self.store.get_event( existing = await self.store.get_event(
event_id, allow_none=True, allow_rejected=True event_id, allow_none=True, allow_rejected=True
) )
@ -230,7 +227,7 @@ class FederationHandler(BaseHandler):
# #
# Note that if we were never in the room then we would have already # Note that if we were never in the room then we would have already
# dropped the event, since we wouldn't know the room version. # dropped the event, since we wouldn't know the room version.
is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name) is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
if not is_in_room: if not is_in_room:
logger.info( logger.info(
"[%s %s] Ignoring PDU from %s as we're not in the room", "[%s %s] Ignoring PDU from %s as we're not in the room",
@ -246,12 +243,12 @@ class FederationHandler(BaseHandler):
# Get missing pdus if necessary. # Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier(): if not pdu.internal_metadata.is_outlier():
# We only backfill backwards to the min depth. # We only backfill backwards to the min depth.
min_depth = yield self.get_min_depth_for_context(pdu.room_id) min_depth = await self.get_min_depth_for_context(pdu.room_id)
logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth) logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
prevs = set(pdu.prev_event_ids()) prevs = set(pdu.prev_event_ids())
seen = yield self.store.have_seen_events(prevs) seen = await self.store.have_seen_events(prevs)
if min_depth and pdu.depth < min_depth: if min_depth and pdu.depth < min_depth:
# This is so that we don't notify the user about this # This is so that we don't notify the user about this
@ -271,7 +268,7 @@ class FederationHandler(BaseHandler):
len(missing_prevs), len(missing_prevs),
shortstr(missing_prevs), shortstr(missing_prevs),
) )
with (yield self._room_pdu_linearizer.queue(pdu.room_id)): with (await self._room_pdu_linearizer.queue(pdu.room_id)):
logger.info( logger.info(
"[%s %s] Acquired room lock to fetch %d missing prev_events", "[%s %s] Acquired room lock to fetch %d missing prev_events",
room_id, room_id,
@ -280,7 +277,7 @@ class FederationHandler(BaseHandler):
) )
try: try:
yield self._get_missing_events_for_pdu( await self._get_missing_events_for_pdu(
origin, pdu, prevs, min_depth origin, pdu, prevs, min_depth
) )
except Exception as e: except Exception as e:
@ -291,7 +288,7 @@ class FederationHandler(BaseHandler):
# Update the set of things we've seen after trying to # Update the set of things we've seen after trying to
# fetch the missing stuff # fetch the missing stuff
seen = yield self.store.have_seen_events(prevs) seen = await self.store.have_seen_events(prevs)
if not prevs - seen: if not prevs - seen:
logger.info( logger.info(
@ -355,7 +352,7 @@ class FederationHandler(BaseHandler):
event_map = {event_id: pdu} event_map = {event_id: pdu}
try: try:
# Get the state of the events we know about # Get the state of the events we know about
ours = yield self.state_store.get_state_groups_ids(room_id, seen) ours = await self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id # state_maps is a list of mappings from (type, state_key) to event_id
state_maps = list( state_maps = list(
@ -372,7 +369,7 @@ class FederationHandler(BaseHandler):
"Requesting state at missing prev_event %s", event_id, "Requesting state at missing prev_event %s", event_id,
) )
room_version = yield self.store.get_room_version(room_id) room_version = await self.store.get_room_version(room_id)
with nested_logging_context(p): with nested_logging_context(p):
# note that if any of the missing prevs share missing state or # note that if any of the missing prevs share missing state or
@ -381,11 +378,11 @@ class FederationHandler(BaseHandler):
( (
remote_state, remote_state,
got_auth_chain, got_auth_chain,
) = yield self._get_state_for_room(origin, room_id, p) ) = await self._get_state_for_room(origin, room_id, p)
# we want the state *after* p; _get_state_for_room returns the # we want the state *after* p; _get_state_for_room returns the
# state *before* p. # state *before* p.
remote_event = yield self.federation_client.get_pdu( remote_event = await self.federation_client.get_pdu(
[origin], p, room_version, outlier=True [origin], p, room_version, outlier=True
) )
@ -410,7 +407,7 @@ class FederationHandler(BaseHandler):
for x in remote_state: for x in remote_state:
event_map[x.event_id] = x event_map[x.event_id] = x
state_map = yield resolve_events_with_store( state_map = await resolve_events_with_store(
room_version, room_version,
state_maps, state_maps,
event_map, event_map,
@ -422,7 +419,7 @@ class FederationHandler(BaseHandler):
# First though we need to fetch all the events that are in # First though we need to fetch all the events that are in
# state_map, so we can build up the state below. # state_map, so we can build up the state below.
evs = yield self.store.get_events( evs = await self.store.get_events(
list(state_map.values()), list(state_map.values()),
get_prev_content=False, get_prev_content=False,
redact_behaviour=EventRedactBehaviour.AS_IS, redact_behaviour=EventRedactBehaviour.AS_IS,
@ -446,12 +443,11 @@ class FederationHandler(BaseHandler):
affected=event_id, affected=event_id,
) )
yield self._process_received_pdu( await self._process_received_pdu(
origin, pdu, state=state, auth_chain=auth_chain origin, pdu, state=state, auth_chain=auth_chain
) )
@defer.inlineCallbacks async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
""" """
Args: Args:
origin (str): Origin of the pdu. Will be called to get the missing events origin (str): Origin of the pdu. Will be called to get the missing events
@ -463,12 +459,12 @@ class FederationHandler(BaseHandler):
room_id = pdu.room_id room_id = pdu.room_id
event_id = pdu.event_id event_id = pdu.event_id
seen = yield self.store.have_seen_events(prevs) seen = await self.store.have_seen_events(prevs)
if not prevs - seen: if not prevs - seen:
return return
latest = yield self.store.get_latest_event_ids_in_room(room_id) latest = await self.store.get_latest_event_ids_in_room(room_id)
# We add the prev events that we have seen to the latest # We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us # list to ensure the remote server doesn't give them to us
@ -532,7 +528,7 @@ class FederationHandler(BaseHandler):
# All that said: Let's try increasing the timout to 60s and see what happens. # All that said: Let's try increasing the timout to 60s and see what happens.
try: try:
missing_events = yield self.federation_client.get_missing_events( missing_events = await self.federation_client.get_missing_events(
origin, origin,
room_id, room_id,
earliest_events_ids=list(latest), earliest_events_ids=list(latest),
@ -571,7 +567,7 @@ class FederationHandler(BaseHandler):
) )
with nested_logging_context(ev.event_id): with nested_logging_context(ev.event_id):
try: try:
yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False) await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
except FederationError as e: except FederationError as e:
if e.code == 403: if e.code == 403:
logger.warning( logger.warning(
@ -1328,8 +1324,7 @@ class FederationHandler(BaseHandler):
return True return True
@defer.inlineCallbacks async def _handle_queued_pdus(self, room_queue):
def _handle_queued_pdus(self, room_queue):
"""Process PDUs which got queued up while we were busy send_joining. """Process PDUs which got queued up while we were busy send_joining.
Args: Args:
@ -1345,7 +1340,7 @@ class FederationHandler(BaseHandler):
p.room_id, p.room_id,
) )
with nested_logging_context(p.event_id): with nested_logging_context(p.event_id):
yield self.on_receive_pdu(origin, p, sent_to_us_directly=True) await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
except Exception as e: except Exception as e:
logger.warning( logger.warning(
"Error handling queued PDU %s from %s: %s", p.event_id, origin, e "Error handling queued PDU %s from %s: %s", p.event_id, origin, e

View File

@ -1,6 +1,6 @@
from mock import Mock from mock import Mock
from twisted.internet.defer import maybeDeferred, succeed from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.logging.context import LoggingContext from synapse.logging.context import LoggingContext
@ -70,8 +70,10 @@ class MessageAcceptTests(unittest.TestCase):
) )
# Send the join, it should return None (which is not an error) # Send the join, it should return None (which is not an error)
d = self.handler.on_receive_pdu( d = ensureDeferred(
"test.serv", join_event, sent_to_us_directly=True self.handler.on_receive_pdu(
"test.serv", join_event, sent_to_us_directly=True
)
) )
self.reactor.advance(1) self.reactor.advance(1)
self.assertEqual(self.successResultOf(d), None) self.assertEqual(self.successResultOf(d), None)
@ -119,8 +121,10 @@ class MessageAcceptTests(unittest.TestCase):
) )
with LoggingContext(request="lying_event"): with LoggingContext(request="lying_event"):
d = self.handler.on_receive_pdu( d = ensureDeferred(
"test.serv", lying_event, sent_to_us_directly=True self.handler.on_receive_pdu(
"test.serv", lying_event, sent_to_us_directly=True
)
) )
# Step the reactor, so the database fetches come back # Step the reactor, so the database fetches come back