Merge pull request #6517 from matrix-org/rav/event_auth/13

Port some of FederationHandler to async/await
Richard van der Hoff 2019-12-11 16:36:06 +00:00 committed by GitHub
commit 894d2addac
4 changed files with 98 additions and 103 deletions

changelog.d/6517.misc (new file)

@@ -0,0 +1 @@
+Port some of FederationHandler to async/await.
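
The port follows the same mechanical pattern throughout the diffs below: each @defer.inlineCallbacks generator that yields Deferreds becomes an async def coroutine that awaits them. A minimal before/after sketch of that pattern (illustrative only; store and get_event are hypothetical stand-ins, not code from this PR):

    from twisted.internet import defer

    # Before: a generator-based "coroutine" driven by inlineCallbacks.
    @defer.inlineCallbacks
    def get_event_old(store, event_id):
        event = yield store.get_event(event_id, allow_none=True)
        return event

    # After: a native coroutine. Callers that need a plain Deferred can wrap
    # it with defer.ensureDeferred(get_event_new(store, event_id)).
    async def get_event_new(store, event_id):
        event = await store.get_event(event_id, allow_none=True)
        return event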

synapse/handlers/federation.py

@@ -19,7 +19,7 @@
 import itertools
 import logging
-from typing import Dict, Iterable, Optional, Sequence, Tuple
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple
 import six
 from six import iteritems, itervalues
@@ -165,8 +165,7 @@ class FederationHandler(BaseHandler):
         self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
-    @defer.inlineCallbacks
-    def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
+    async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
         """ Process a PDU received via a federation /send/ transaction, or
             via backfill of missing prev_events
@@ -176,8 +175,6 @@ class FederationHandler(BaseHandler):
             pdu (FrozenEvent): received PDU
             sent_to_us_directly (bool): True if this event was pushed to us; False if
                 we pulled it as the result of a missing prev_event.
-        Returns (Deferred): completes with None
         """
         room_id = pdu.room_id
@@ -186,7 +183,7 @@ class FederationHandler(BaseHandler):
         logger.info("handling received PDU: %s", pdu)
         # We reprocess pdus when we have seen them only as outliers
-        existing = yield self.store.get_event(
+        existing = await self.store.get_event(
             event_id, allow_none=True, allow_rejected=True
         )
@@ -230,7 +227,7 @@ class FederationHandler(BaseHandler):
         #
         # Note that if we were never in the room then we would have already
         # dropped the event, since we wouldn't know the room version.
-        is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name)
+        is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
         if not is_in_room:
             logger.info(
                 "[%s %s] Ignoring PDU from %s as we're not in the room",
@@ -246,12 +243,12 @@ class FederationHandler(BaseHandler):
         # Get missing pdus if necessary.
         if not pdu.internal_metadata.is_outlier():
             # We only backfill backwards to the min depth.
-            min_depth = yield self.get_min_depth_for_context(pdu.room_id)
+            min_depth = await self.get_min_depth_for_context(pdu.room_id)
             logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
             prevs = set(pdu.prev_event_ids())
-            seen = yield self.store.have_seen_events(prevs)
+            seen = await self.store.have_seen_events(prevs)
             if min_depth and pdu.depth < min_depth:
                 # This is so that we don't notify the user about this
@@ -271,7 +268,7 @@ class FederationHandler(BaseHandler):
                         len(missing_prevs),
                         shortstr(missing_prevs),
                     )
-                    with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
+                    with (await self._room_pdu_linearizer.queue(pdu.room_id)):
                         logger.info(
                             "[%s %s] Acquired room lock to fetch %d missing prev_events",
                             room_id,
@@ -280,7 +277,7 @@ class FederationHandler(BaseHandler):
                         )
                         try:
-                            yield self._get_missing_events_for_pdu(
+                            await self._get_missing_events_for_pdu(
                                 origin, pdu, prevs, min_depth
                             )
                         except Exception as e:
@@ -291,7 +288,7 @@ class FederationHandler(BaseHandler):
             # Update the set of things we've seen after trying to
             # fetch the missing stuff
-            seen = yield self.store.have_seen_events(prevs)
+            seen = await self.store.have_seen_events(prevs)
             if not prevs - seen:
                 logger.info(
@@ -355,7 +352,7 @@ class FederationHandler(BaseHandler):
             event_map = {event_id: pdu}
             try:
                 # Get the state of the events we know about
-                ours = yield self.state_store.get_state_groups_ids(room_id, seen)
+                ours = await self.state_store.get_state_groups_ids(room_id, seen)
                 # state_maps is a list of mappings from (type, state_key) to event_id
                 state_maps = list(
@@ -372,7 +369,7 @@ class FederationHandler(BaseHandler):
                         "Requesting state at missing prev_event %s", event_id,
                     )
-                    room_version = yield self.store.get_room_version(room_id)
+                    room_version = await self.store.get_room_version(room_id)
                     with nested_logging_context(p):
                         # note that if any of the missing prevs share missing state or
@@ -381,11 +378,11 @@ class FederationHandler(BaseHandler):
                         (
                             remote_state,
                             got_auth_chain,
-                        ) = yield self._get_state_for_room(origin, room_id, p)
+                        ) = await self._get_state_for_room(origin, room_id, p)
                         # we want the state *after* p; _get_state_for_room returns the
                         # state *before* p.
-                        remote_event = yield self.federation_client.get_pdu(
+                        remote_event = await self.federation_client.get_pdu(
                             [origin], p, room_version, outlier=True
                         )
@@ -410,7 +407,7 @@ class FederationHandler(BaseHandler):
                         for x in remote_state:
                             event_map[x.event_id] = x
-                state_map = yield resolve_events_with_store(
+                state_map = await resolve_events_with_store(
                     room_version,
                     state_maps,
                     event_map,
@@ -422,7 +419,7 @@ class FederationHandler(BaseHandler):
                 # First though we need to fetch all the events that are in
                 # state_map, so we can build up the state below.
-                evs = yield self.store.get_events(
+                evs = await self.store.get_events(
                     list(state_map.values()),
                     get_prev_content=False,
                     redact_behaviour=EventRedactBehaviour.AS_IS,
@@ -446,12 +443,11 @@ class FederationHandler(BaseHandler):
                     affected=event_id,
                 )
-        yield self._process_received_pdu(
+        await self._process_received_pdu(
             origin, pdu, state=state, auth_chain=auth_chain
         )
-    @defer.inlineCallbacks
-    def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
+    async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
         """
         Args:
             origin (str): Origin of the pdu. Will be called to get the missing events
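
Most of the calls awaited in the hunks above (store.get_event, have_seen_events, get_room_version, and so on) still return Twisted Deferreds rather than coroutines; that works because a Deferred is itself awaitable, so a native coroutine can await it directly. A minimal sketch of that interaction, with a hypothetical legacy_lookup standing in for a storage method:

    from twisted.internet import defer

    def legacy_lookup(event_id):
        # stand-in for a storage method that still returns a Deferred
        return defer.succeed({"event_id": event_id})

    async def handle(event_id):
        # Deferreds are awaitable, so no wrapping is needed on the caller side
        event = await legacy_lookup(event_id)
        return event["event_id"]

    d = defer.ensureDeferred(handle("$abc"))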
@@ -463,12 +459,12 @@ class FederationHandler(BaseHandler):
         room_id = pdu.room_id
         event_id = pdu.event_id
-        seen = yield self.store.have_seen_events(prevs)
+        seen = await self.store.have_seen_events(prevs)
         if not prevs - seen:
             return
-        latest = yield self.store.get_latest_event_ids_in_room(room_id)
+        latest = await self.store.get_latest_event_ids_in_room(room_id)
         # We add the prev events that we have seen to the latest
         # list to ensure the remote server doesn't give them to us
@@ -532,7 +528,7 @@ class FederationHandler(BaseHandler):
         # All that said: Let's try increasing the timout to 60s and see what happens.
         try:
-            missing_events = yield self.federation_client.get_missing_events(
+            missing_events = await self.federation_client.get_missing_events(
                 origin,
                 room_id,
                 earliest_events_ids=list(latest),
@@ -571,7 +567,7 @@ class FederationHandler(BaseHandler):
             )
             with nested_logging_context(ev.event_id):
                 try:
-                    yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
+                    await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
                 except FederationError as e:
                     if e.code == 403:
                         logger.warning(
@@ -583,30 +579,30 @@ class FederationHandler(BaseHandler):
                     else:
                         raise
-    @defer.inlineCallbacks
     @log_function
-    def _get_state_for_room(self, destination, room_id, event_id):
+    async def _get_state_for_room(
+        self, destination: str, room_id: str, event_id: str
+    ) -> Tuple[List[EventBase], List[EventBase]]:
         """Requests all of the room state at a given event from a remote homeserver.
         Args:
-            destination (str): The remote homeserver to query for the state.
-            room_id (str): The id of the room we're interested in.
-            event_id (str): The id of the event we want the state at.
+            destination:: The remote homeserver to query for the state.
+            room_id: The id of the room we're interested in.
+            event_id: The id of the event we want the state at.
         Returns:
-            Deferred[Tuple[List[EventBase], List[EventBase]]]:
             A list of events in the state, and a list of events in the auth chain
             for the given event.
         """
         (
             state_event_ids,
             auth_event_ids,
-        ) = yield self.federation_client.get_room_state_ids(
+        ) = await self.federation_client.get_room_state_ids(
             destination, room_id, event_id=event_id
         )
         desired_events = set(state_event_ids + auth_event_ids)
-        event_map = yield self._get_events_from_store_or_dest(
+        event_map = await self._get_events_from_store_or_dest(
             destination, room_id, desired_events
         )
@@ -625,20 +621,20 @@ class FederationHandler(BaseHandler):
         return pdus, auth_chain
-    @defer.inlineCallbacks
-    def _get_events_from_store_or_dest(self, destination, room_id, event_ids):
+    async def _get_events_from_store_or_dest(
+        self, destination: str, room_id: str, event_ids: Iterable[str]
+    ) -> Dict[str, EventBase]:
         """Fetch events from a remote destination, checking if we already have them.
         Args:
-            destination (str)
-            room_id (str)
-            event_ids (Iterable[str])
+            destination
+            room_id
+            event_ids
         Returns:
-            Deferred[dict[str, EventBase]]: A deferred resolving to a map
-            from event_id to event
+            map from event_id to event
         """
-        fetched_events = yield self.store.get_events(event_ids, allow_rejected=True)
+        fetched_events = await self.store.get_events(event_ids, allow_rejected=True)
         missing_events = set(event_ids) - fetched_events.keys()
@@ -651,7 +647,7 @@ class FederationHandler(BaseHandler):
             event_ids,
         )
-        room_version = yield self.store.get_room_version(room_id)
+        room_version = await self.store.get_room_version(room_id)
         # XXX 20 requests at once? really?
         for batch in batch_iter(missing_events, 20):
@@ -665,7 +661,7 @@ class FederationHandler(BaseHandler):
                 for e_id in batch
             ]
-            res = yield make_deferred_yieldable(
+            res = await make_deferred_yieldable(
                 defer.DeferredList(deferreds, consumeErrors=True)
             )
             for success, result in res:
@@ -674,8 +670,7 @@ class FederationHandler(BaseHandler):
         return fetched_events
-    @defer.inlineCallbacks
-    def _process_received_pdu(self, origin, event, state, auth_chain):
+    async def _process_received_pdu(self, origin, event, state, auth_chain):
         """ Called when we have a new pdu. We need to do auth checks and put it
         through the StateHandler.
         """
@@ -690,7 +685,7 @@ class FederationHandler(BaseHandler):
         if auth_chain:
             event_ids |= {e.event_id for e in auth_chain}
-        seen_ids = yield self.store.have_seen_events(event_ids)
+        seen_ids = await self.store.have_seen_events(event_ids)
         if state and auth_chain is not None:
             # If we have any state or auth_chain given to us by the replication
@@ -717,18 +712,18 @@ class FederationHandler(BaseHandler):
                 event_id,
                 [e.event.event_id for e in event_infos],
             )
-            yield self._handle_new_events(origin, event_infos)
+            await self._handle_new_events(origin, event_infos)
         try:
-            context = yield self._handle_new_event(origin, event, state=state)
+            context = await self._handle_new_event(origin, event, state=state)
         except AuthError as e:
             raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
-        room = yield self.store.get_room(room_id)
+        room = await self.store.get_room(room_id)
         if not room:
             try:
-                yield self.store.store_room(
+                await self.store.store_room(
                     room_id=room_id, room_creator_user_id="", is_public=False
                 )
             except StoreError:
@@ -741,11 +736,11 @@ class FederationHandler(BaseHandler):
                 # changing their profile info.
                 newly_joined = True
-                prev_state_ids = yield context.get_prev_state_ids(self.store)
+                prev_state_ids = await context.get_prev_state_ids(self.store)
                 prev_state_id = prev_state_ids.get((event.type, event.state_key))
                 if prev_state_id:
-                    prev_state = yield self.store.get_event(
+                    prev_state = await self.store.get_event(
                         prev_state_id, allow_none=True
                     )
                     if prev_state and prev_state.membership == Membership.JOIN:
@@ -753,11 +748,10 @@ class FederationHandler(BaseHandler):
             if newly_joined:
                 user = UserID.from_string(event.state_key)
-                yield self.user_joined_room(user, room_id)
+                await self.user_joined_room(user, room_id)
     @log_function
-    @defer.inlineCallbacks
-    def backfill(self, dest, room_id, limit, extremities):
+    async def backfill(self, dest, room_id, limit, extremities):
         """ Trigger a backfill request to `dest` for the given `room_id`
         This will attempt to get more events from the remote. If the other side
@@ -774,9 +768,9 @@ class FederationHandler(BaseHandler):
         if dest == self.server_name:
             raise SynapseError(400, "Can't backfill from self.")
-        room_version = yield self.store.get_room_version(room_id)
+        room_version = await self.store.get_room_version(room_id)
-        events = yield self.federation_client.backfill(
+        events = await self.federation_client.backfill(
             dest, room_id, limit=limit, extremities=extremities
         )
@@ -791,7 +785,7 @@ class FederationHandler(BaseHandler):
         # self._sanity_check_event(ev)
         # Don't bother processing events we already have.
-        seen_events = yield self.store.have_events_in_timeline(
+        seen_events = await self.store.have_events_in_timeline(
             set(e.event_id for e in events)
         )
@@ -814,7 +808,7 @@ class FederationHandler(BaseHandler):
         state_events = {}
         events_to_state = {}
         for e_id in edges:
-            state, auth = yield self._get_state_for_room(
+            state, auth = await self._get_state_for_room(
                 destination=dest, room_id=room_id, event_id=e_id
             )
             auth_events.update({a.event_id: a for a in auth})
@@ -839,7 +833,7 @@ class FederationHandler(BaseHandler):
         # We repeatedly do this until we stop finding new auth events.
         while missing_auth - failed_to_fetch:
             logger.info("Missing auth for backfill: %r", missing_auth)
-            ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
+            ret_events = await self.store.get_events(missing_auth - failed_to_fetch)
             auth_events.update(ret_events)
             required_auth.update(
@@ -853,7 +847,7 @@ class FederationHandler(BaseHandler):
                 missing_auth - failed_to_fetch,
             )
-            results = yield make_deferred_yieldable(
+            results = await make_deferred_yieldable(
                 defer.gatherResults(
                     [
                         run_in_background(
@@ -880,7 +874,7 @@ class FederationHandler(BaseHandler):
             failed_to_fetch = missing_auth - set(auth_events)
-        seen_events = yield self.store.have_seen_events(
+        seen_events = await self.store.have_seen_events(
             set(auth_events.keys()) | set(state_events.keys())
         )
@@ -942,7 +936,7 @@ class FederationHandler(BaseHandler):
                 )
             )
-        yield self._handle_new_events(dest, ev_infos, backfilled=True)
+        await self._handle_new_events(dest, ev_infos, backfilled=True)
         # Step 2: Persist the rest of the events in the chunk one by one
         events.sort(key=lambda e: e.depth)
@@ -958,16 +952,15 @@ class FederationHandler(BaseHandler):
             # We store these one at a time since each event depends on the
             # previous to work out the state.
             # TODO: We can probably do something more clever here.
-            yield self._handle_new_event(dest, event, backfilled=True)
+            await self._handle_new_event(dest, event, backfilled=True)
         return events
-    @defer.inlineCallbacks
-    def maybe_backfill(self, room_id, current_depth):
+    async def maybe_backfill(self, room_id, current_depth):
         """Checks the database to see if we should backfill before paginating,
         and if so do.
         """
-        extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id)
+        extremities = await self.store.get_oldest_events_with_depth_in_room(room_id)
         if not extremities:
             logger.debug("Not backfilling as no extremeties found.")
@@ -999,9 +992,9 @@ class FederationHandler(BaseHandler):
         # state *before* the event, ignoring the special casing certain event
         # types have.
-        forward_events = yield self.store.get_successor_events(list(extremities))
+        forward_events = await self.store.get_successor_events(list(extremities))
-        extremities_events = yield self.store.get_events(
+        extremities_events = await self.store.get_events(
             forward_events,
             redact_behaviour=EventRedactBehaviour.AS_IS,
             get_prev_content=False,
@@ -1009,7 +1002,7 @@ class FederationHandler(BaseHandler):
         # We set `check_history_visibility_only` as we might otherwise get false
         # positives from users having been erased.
-        filtered_extremities = yield filter_events_for_server(
+        filtered_extremities = await filter_events_for_server(
             self.storage,
             self.server_name,
             list(extremities_events.values()),
@@ -1039,7 +1032,7 @@ class FederationHandler(BaseHandler):
         # First we try hosts that are already in the room
         # TODO: HEURISTIC ALERT.
-        curr_state = yield self.state_handler.get_current_state(room_id)
+        curr_state = await self.state_handler.get_current_state(room_id)
         def get_domains_from_state(state):
             """Get joined domains from state
@@ -1078,12 +1071,11 @@ class FederationHandler(BaseHandler):
             domain for domain, depth in curr_domains if domain != self.server_name
         ]
-        @defer.inlineCallbacks
-        def try_backfill(domains):
+        async def try_backfill(domains):
            # TODO: Should we try multiple of these at a time?
            for dom in domains:
                try:
-                    yield self.backfill(
+                    await self.backfill(
                        dom, room_id, limit=100, extremities=extremities
                    )
                    # If this succeeded then we probably already have the
@@ -1114,7 +1106,7 @@ class FederationHandler(BaseHandler):
             return False
-        success = yield try_backfill(likely_domains)
+        success = await try_backfill(likely_domains)
         if success:
             return True
@@ -1128,7 +1120,7 @@ class FederationHandler(BaseHandler):
         logger.debug("calling resolve_state_groups in _maybe_backfill")
         resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
-        states = yield make_deferred_yieldable(
+        states = await make_deferred_yieldable(
             defer.gatherResults(
                 [resolve(room_id, [e]) for e in event_ids], consumeErrors=True
             )
@@ -1138,7 +1130,7 @@ class FederationHandler(BaseHandler):
         # event_ids.
         states = dict(zip(event_ids, [s.state for s in states]))
-        state_map = yield self.store.get_events(
+        state_map = await self.store.get_events(
             [e_id for ids in itervalues(states) for e_id in itervalues(ids)],
             get_prev_content=False,
         )
@@ -1154,7 +1146,7 @@ class FederationHandler(BaseHandler):
         for e_id, _ in sorted_extremeties_tuple:
             likely_domains = get_domains_from_state(states[e_id])
-            success = yield try_backfill(
+            success = await try_backfill(
                 [dom for dom, _ in likely_domains if dom not in tried_domains]
             )
             if success:
@@ -1331,8 +1323,7 @@ class FederationHandler(BaseHandler):
             return True
-    @defer.inlineCallbacks
-    def _handle_queued_pdus(self, room_queue):
+    async def _handle_queued_pdus(self, room_queue):
         """Process PDUs which got queued up while we were busy send_joining.
         Args:
@@ -1348,7 +1339,7 @@ class FederationHandler(BaseHandler):
                     p.room_id,
                 )
                 with nested_logging_context(p.event_id):
-                    yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
+                    await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
             except Exception as e:
                 logger.warning(
                     "Error handling queued PDU %s from %s: %s", p.event_id, origin, e
@@ -2907,7 +2898,7 @@ class FederationHandler(BaseHandler):
                 room_id=room_id, user_id=user.to_string(), change="joined"
             )
         else:
-            return user_joined_room(self.distributor, user, room_id)
+            return defer.succeed(user_joined_room(self.distributor, user, room_id))
     @defer.inlineCallbacks
     def get_room_complexity(self, remote_room_hosts, room_id):
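
The last hunk above wraps the non-replication branch in defer.succeed(...) so that user_joined_room hands back something awaitable on both code paths now that its caller is a coroutine. A minimal sketch of that idea, using hypothetical stand-in functions rather than Synapse's own:

    from twisted.internet import defer

    def notify_via_replication():
        # stand-in for a call that already returns a Deferred
        return defer.succeed(None)

    def notify_locally():
        # stand-in for a synchronous call (e.g. firing a distributor signal)
        return None

    def user_joined(use_replication):
        if use_replication:
            return notify_via_replication()
        # wrap the plain return value so both branches are awaitable
        return defer.succeed(notify_locally())

    async def caller():
        await user_joined(use_replication=False)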

synapse/handlers/pagination.py

@@ -280,8 +280,7 @@ class PaginationHandler(object):
         await self.storage.purge_events.purge_room(room_id)
-    @defer.inlineCallbacks
-    def get_messages(
+    async def get_messages(
         self,
         requester,
         room_id=None,
@@ -307,7 +306,7 @@ class PaginationHandler(object):
             room_token = pagin_config.from_token.room_key
         else:
             pagin_config.from_token = (
-                yield self.hs.get_event_sources().get_current_token_for_pagination()
+                await self.hs.get_event_sources().get_current_token_for_pagination()
             )
             room_token = pagin_config.from_token.room_key
@@ -319,11 +318,11 @@ class PaginationHandler(object):
         source_config = pagin_config.get_source_config("room")
-        with (yield self.pagination_lock.read(room_id)):
+        with (await self.pagination_lock.read(room_id)):
             (
                 membership,
                 member_event_id,
-            ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
+            ) = await self.auth.check_in_room_or_world_readable(room_id, user_id)
             if source_config.direction == "b":
                 # if we're going backwards, we might need to backfill. This
@@ -331,7 +330,7 @@ class PaginationHandler(object):
                 if room_token.topological:
                     max_topo = room_token.topological
                 else:
-                    max_topo = yield self.store.get_max_topological_token(
+                    max_topo = await self.store.get_max_topological_token(
                         room_id, room_token.stream
                     )
@@ -339,18 +338,18 @@ class PaginationHandler(object):
                     # If they have left the room then clamp the token to be before
                     # they left the room, to save the effort of loading from the
                     # database.
-                    leave_token = yield self.store.get_topological_token_for_event(
+                    leave_token = await self.store.get_topological_token_for_event(
                         member_event_id
                     )
                     leave_token = RoomStreamToken.parse(leave_token)
                     if leave_token.topological < max_topo:
                         source_config.from_key = str(leave_token)
-                yield self.hs.get_handlers().federation_handler.maybe_backfill(
+                await self.hs.get_handlers().federation_handler.maybe_backfill(
                     room_id, max_topo
                 )
-            events, next_key = yield self.store.paginate_room_events(
+            events, next_key = await self.store.paginate_room_events(
                 room_id=room_id,
                 from_key=source_config.from_key,
                 to_key=source_config.to_key,
@@ -365,7 +364,7 @@ class PaginationHandler(object):
         if event_filter:
             events = event_filter.filter(events)
-        events = yield filter_events_for_client(
+        events = await filter_events_for_client(
             self.storage, user_id, events, is_peeking=(member_event_id is None)
         )
@@ -385,19 +384,19 @@ class PaginationHandler(object):
                 (EventTypes.Member, event.sender) for event in events
             )
-            state_ids = yield self.state_store.get_state_ids_for_event(
+            state_ids = await self.state_store.get_state_ids_for_event(
                 events[0].event_id, state_filter=state_filter
             )
             if state_ids:
-                state = yield self.store.get_events(list(state_ids.values()))
+                state = await self.store.get_events(list(state_ids.values()))
                 state = state.values()
         time_now = self.clock.time_msec()
         chunk = {
             "chunk": (
-                yield self._event_serializer.serialize_events(
+                await self._event_serializer.serialize_events(
                     events, time_now, as_client_event=as_client_event
                 )
             ),
@@ -406,7 +405,7 @@ class PaginationHandler(object):
         }
         if state:
-            chunk["state"] = yield self._event_serializer.serialize_events(
+            chunk["state"] = await self._event_serializer.serialize_events(
                 state, time_now, as_client_event=as_client_event
             )
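
The with (await self.pagination_lock.read(room_id)): form above relies on the lock's read() returning a Deferred that resolves to a context manager, so awaiting it yields the object to enter. A minimal sketch of that shape, with a hypothetical SimpleLock rather than Synapse's ReadWriteLock:

    from contextlib import contextmanager
    from twisted.internet import defer

    class SimpleLock:
        def acquire(self):
            @contextmanager
            def _ctx():
                # acquire/release bookkeeping would go here
                yield

            # resolve to a context manager, so callers write:
            #     with (await lock.acquire()):
            return defer.succeed(_ctx())

    async def critical_section(lock):
        with (await lock.acquire()):
            return "did work while holding the lock"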

tests/test_federation.py

@@ -1,6 +1,6 @@
 from mock import Mock
-from twisted.internet.defer import maybeDeferred, succeed
+from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
 from synapse.events import FrozenEvent
 from synapse.logging.context import LoggingContext
@@ -70,9 +70,11 @@ class MessageAcceptTests(unittest.TestCase):
         )
         # Send the join, it should return None (which is not an error)
-        d = self.handler.on_receive_pdu(
-            "test.serv", join_event, sent_to_us_directly=True
-        )
+        d = ensureDeferred(
+            self.handler.on_receive_pdu(
+                "test.serv", join_event, sent_to_us_directly=True
+            )
+        )
         self.reactor.advance(1)
         self.assertEqual(self.successResultOf(d), None)
@@ -119,9 +121,11 @@ class MessageAcceptTests(unittest.TestCase):
         )
         with LoggingContext(request="lying_event"):
-            d = self.handler.on_receive_pdu(
-                "test.serv", lying_event, sent_to_us_directly=True
-            )
+            d = ensureDeferred(
+                self.handler.on_receive_pdu(
+                    "test.serv", lying_event, sent_to_us_directly=True
+                )
+            )
         # Step the reactor, so the database fetches come back
         self.reactor.advance(1)
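
For reference, the test changes are needed because an async def method now returns a coroutine rather than a Deferred, so a test that drives the reactor by hand wraps the call in ensureDeferred() to get a Deferred it can poll. A self-contained sketch of that pattern, with a hypothetical on_receive standing in for handler.on_receive_pdu:

    from twisted.internet import defer

    async def on_receive(value):
        # no awaits here, so the wrapped Deferred fires immediately
        return value

    d = defer.ensureDeferred(on_receive(None))
    results = []
    d.addCallback(results.append)
    assert results == [None]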