convert to async: FederationHandler.on_receive_pdu

and associated functions:
 * on_receive_pdu
 * handle_queued_pdus
 * get_missing_events_for_pdu
This commit is contained in:
Richard van der Hoff 2019-12-10 17:01:37 +00:00
parent 7712e751b8
commit e77237b935
2 changed files with 31 additions and 32 deletions

View File

@ -165,8 +165,7 @@ class FederationHandler(BaseHandler):
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
@defer.inlineCallbacks
def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
""" Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events
@ -176,8 +175,6 @@ class FederationHandler(BaseHandler):
pdu (FrozenEvent): received PDU
sent_to_us_directly (bool): True if this event was pushed to us; False if
we pulled it as the result of a missing prev_event.
Returns (Deferred): completes with None
"""
room_id = pdu.room_id
@ -186,7 +183,7 @@ class FederationHandler(BaseHandler):
logger.info("handling received PDU: %s", pdu)
# We reprocess pdus when we have seen them only as outliers
existing = yield self.store.get_event(
existing = await self.store.get_event(
event_id, allow_none=True, allow_rejected=True
)
@ -230,7 +227,7 @@ class FederationHandler(BaseHandler):
#
# Note that if we were never in the room then we would have already
# dropped the event, since we wouldn't know the room version.
is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name)
is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
if not is_in_room:
logger.info(
"[%s %s] Ignoring PDU from %s as we're not in the room",
@ -246,12 +243,12 @@ class FederationHandler(BaseHandler):
# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
# We only backfill backwards to the min depth.
min_depth = yield self.get_min_depth_for_context(pdu.room_id)
min_depth = await self.get_min_depth_for_context(pdu.room_id)
logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
prevs = set(pdu.prev_event_ids())
seen = yield self.store.have_seen_events(prevs)
seen = await self.store.have_seen_events(prevs)
if min_depth and pdu.depth < min_depth:
# This is so that we don't notify the user about this
@ -271,7 +268,7 @@ class FederationHandler(BaseHandler):
len(missing_prevs),
shortstr(missing_prevs),
)
with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
with (await self._room_pdu_linearizer.queue(pdu.room_id)):
logger.info(
"[%s %s] Acquired room lock to fetch %d missing prev_events",
room_id,
@ -280,7 +277,7 @@ class FederationHandler(BaseHandler):
)
try:
yield self._get_missing_events_for_pdu(
await self._get_missing_events_for_pdu(
origin, pdu, prevs, min_depth
)
except Exception as e:
@ -291,7 +288,7 @@ class FederationHandler(BaseHandler):
# Update the set of things we've seen after trying to
# fetch the missing stuff
seen = yield self.store.have_seen_events(prevs)
seen = await self.store.have_seen_events(prevs)
if not prevs - seen:
logger.info(
@ -355,7 +352,7 @@ class FederationHandler(BaseHandler):
event_map = {event_id: pdu}
try:
# Get the state of the events we know about
ours = yield self.state_store.get_state_groups_ids(room_id, seen)
ours = await self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
state_maps = list(
@ -372,7 +369,7 @@ class FederationHandler(BaseHandler):
"Requesting state at missing prev_event %s", event_id,
)
room_version = yield self.store.get_room_version(room_id)
room_version = await self.store.get_room_version(room_id)
with nested_logging_context(p):
# note that if any of the missing prevs share missing state or
@ -381,11 +378,11 @@ class FederationHandler(BaseHandler):
(
remote_state,
got_auth_chain,
) = yield self._get_state_for_room(origin, room_id, p)
) = await self._get_state_for_room(origin, room_id, p)
# we want the state *after* p; _get_state_for_room returns the
# state *before* p.
remote_event = yield self.federation_client.get_pdu(
remote_event = await self.federation_client.get_pdu(
[origin], p, room_version, outlier=True
)
@ -410,7 +407,7 @@ class FederationHandler(BaseHandler):
for x in remote_state:
event_map[x.event_id] = x
state_map = yield resolve_events_with_store(
state_map = await resolve_events_with_store(
room_version,
state_maps,
event_map,
@ -422,7 +419,7 @@ class FederationHandler(BaseHandler):
# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
evs = yield self.store.get_events(
evs = await self.store.get_events(
list(state_map.values()),
get_prev_content=False,
redact_behaviour=EventRedactBehaviour.AS_IS,
@ -446,12 +443,11 @@ class FederationHandler(BaseHandler):
affected=event_id,
)
yield self._process_received_pdu(
await self._process_received_pdu(
origin, pdu, state=state, auth_chain=auth_chain
)
@defer.inlineCallbacks
def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
"""
Args:
origin (str): Origin of the pdu. Will be called to get the missing events
@ -463,12 +459,12 @@ class FederationHandler(BaseHandler):
room_id = pdu.room_id
event_id = pdu.event_id
seen = yield self.store.have_seen_events(prevs)
seen = await self.store.have_seen_events(prevs)
if not prevs - seen:
return
latest = yield self.store.get_latest_event_ids_in_room(room_id)
latest = await self.store.get_latest_event_ids_in_room(room_id)
# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
@ -532,7 +528,7 @@ class FederationHandler(BaseHandler):
# All that said: Let's try increasing the timout to 60s and see what happens.
try:
missing_events = yield self.federation_client.get_missing_events(
missing_events = await self.federation_client.get_missing_events(
origin,
room_id,
earliest_events_ids=list(latest),
@ -571,7 +567,7 @@ class FederationHandler(BaseHandler):
)
with nested_logging_context(ev.event_id):
try:
yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
except FederationError as e:
if e.code == 403:
logger.warning(
@ -1328,8 +1324,7 @@ class FederationHandler(BaseHandler):
return True
@defer.inlineCallbacks
def _handle_queued_pdus(self, room_queue):
async def _handle_queued_pdus(self, room_queue):
"""Process PDUs which got queued up while we were busy send_joining.
Args:
@ -1345,7 +1340,7 @@ class FederationHandler(BaseHandler):
p.room_id,
)
with nested_logging_context(p.event_id):
yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
except Exception as e:
logger.warning(
"Error handling queued PDU %s from %s: %s", p.event_id, origin, e

View File

@ -1,6 +1,6 @@
from mock import Mock
from twisted.internet.defer import maybeDeferred, succeed
from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
from synapse.events import FrozenEvent
from synapse.logging.context import LoggingContext
@ -70,9 +70,11 @@ class MessageAcceptTests(unittest.TestCase):
)
# Send the join, it should return None (which is not an error)
d = self.handler.on_receive_pdu(
d = ensureDeferred(
self.handler.on_receive_pdu(
"test.serv", join_event, sent_to_us_directly=True
)
)
self.reactor.advance(1)
self.assertEqual(self.successResultOf(d), None)
@ -119,9 +121,11 @@ class MessageAcceptTests(unittest.TestCase):
)
with LoggingContext(request="lying_event"):
d = self.handler.on_receive_pdu(
d = ensureDeferred(
self.handler.on_receive_pdu(
"test.serv", lying_event, sent_to_us_directly=True
)
)
# Step the reactor, so the database fetches come back
self.reactor.advance(1)