From 8144bc26a7432463b7e70f9c03198d4724952522 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 27 Jul 2020 12:21:34 -0400 Subject: [PATCH 01/26] Convert push to async/await. (#7948) --- changelog.d/7948.misc | 1 + synapse/push/action_generator.py | 7 +- synapse/push/bulk_push_rule_evaluator.py | 62 +++++++--------- synapse/push/httppusher.py | 58 +++++++-------- synapse/push/presentable_names.py | 15 ++-- synapse/push/push_tools.py | 22 +++--- synapse/push/pusherpool.py | 70 ++++++++----------- .../data_stores/main/event_push_actions.py | 4 +- .../replication/slave/storage/test_events.py | 6 +- tests/storage/test_event_push_actions.py | 6 +- 10 files changed, 106 insertions(+), 145 deletions(-) create mode 100644 changelog.d/7948.misc diff --git a/changelog.d/7948.misc b/changelog.d/7948.misc new file mode 100644 index 000000000..7c2e2b18b --- /dev/null +++ b/changelog.d/7948.misc @@ -0,0 +1 @@ +Convert push to async/await. diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index 1ffd5e2df..0d2314265 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.util.metrics import Measure from .bulk_push_rule_evaluator import BulkPushRuleEvaluator @@ -37,7 +35,6 @@ class ActionGenerator(object): # event stream, so we just run the rules for a client with no profile # tag (ie. we just need all the users). - @defer.inlineCallbacks - def handle_push_actions_for_event(self, event, context): + async def handle_push_actions_for_event(self, event, context): with Measure(self.clock, "action_for_event_by_user"): - yield self.bulk_evaluator.action_for_event_by_user(event, context) + await self.bulk_evaluator.action_for_event_by_user(event, context) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 472ddf9f7..04b9d8ac8 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -19,8 +19,6 @@ from collections import namedtuple from prometheus_client import Counter -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.event_auth import get_user_power_level from synapse.state import POWER_KEY @@ -70,8 +68,7 @@ class BulkPushRuleEvaluator(object): resizable=False, ) - @defer.inlineCallbacks - def _get_rules_for_event(self, event, context): + async def _get_rules_for_event(self, event, context): """This gets the rules for all users in the room at the time of the event, as well as the push rules for the invitee if the event is an invite. @@ -79,19 +76,19 @@ class BulkPushRuleEvaluator(object): dict of user_id -> push_rules """ room_id = event.room_id - rules_for_room = yield self._get_rules_for_room(room_id) + rules_for_room = await self._get_rules_for_room(room_id) - rules_by_user = yield rules_for_room.get_rules(event, context) + rules_by_user = await rules_for_room.get_rules(event, context) # if this event is an invite event, we may need to run rules for the user # who's been invited, otherwise they won't get told they've been invited if event.type == "m.room.member" and event.content["membership"] == "invite": invited = event.state_key if invited and self.hs.is_mine_id(invited): - has_pusher = yield self.store.user_has_pusher(invited) + has_pusher = await self.store.user_has_pusher(invited) if has_pusher: rules_by_user = dict(rules_by_user) - rules_by_user[invited] = yield self.store.get_push_rules_for_user( + rules_by_user[invited] = await self.store.get_push_rules_for_user( invited ) @@ -114,20 +111,19 @@ class BulkPushRuleEvaluator(object): self.room_push_rule_cache_metrics, ) - @defer.inlineCallbacks - def _get_power_levels_and_sender_level(self, event, context): - prev_state_ids = yield context.get_prev_state_ids() + async def _get_power_levels_and_sender_level(self, event, context): + prev_state_ids = await context.get_prev_state_ids() pl_event_id = prev_state_ids.get(POWER_KEY) if pl_event_id: # fastpath: if there's a power level event, that's all we need, and # not having a power level event is an extreme edge case - pl_event = yield self.store.get_event(pl_event_id) + pl_event = await self.store.get_event(pl_event_id) auth_events = {POWER_KEY: pl_event} else: - auth_events_ids = yield self.auth.compute_auth_events( + auth_events_ids = await self.auth.compute_auth_events( event, prev_state_ids, for_verification=False ) - auth_events = yield self.store.get_events(auth_events_ids) + auth_events = await self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events.values()} sender_level = get_user_power_level(event.sender, auth_events) @@ -136,23 +132,19 @@ class BulkPushRuleEvaluator(object): return pl_event.content if pl_event else {}, sender_level - @defer.inlineCallbacks - def action_for_event_by_user(self, event, context): + async def action_for_event_by_user(self, event, context) -> None: """Given an event and context, evaluate the push rules and insert the results into the event_push_actions_staging table. - - Returns: - Deferred """ - rules_by_user = yield self._get_rules_for_event(event, context) + rules_by_user = await self._get_rules_for_event(event, context) actions_by_user = {} - room_members = yield self.store.get_joined_users_from_context(event, context) + room_members = await self.store.get_joined_users_from_context(event, context) ( power_levels, sender_power_level, - ) = yield self._get_power_levels_and_sender_level(event, context) + ) = await self._get_power_levels_and_sender_level(event, context) evaluator = PushRuleEvaluatorForEvent( event, len(room_members), sender_power_level, power_levels @@ -165,7 +157,7 @@ class BulkPushRuleEvaluator(object): continue if not event.is_state(): - is_ignored = yield self.store.is_ignored_by(event.sender, uid) + is_ignored = await self.store.is_ignored_by(event.sender, uid) if is_ignored: continue @@ -197,7 +189,7 @@ class BulkPushRuleEvaluator(object): # Mark in the DB staging area the push actions for users who should be # notified for this event. (This will then get handled when we persist # the event) - yield self.store.add_push_actions_to_staging(event.event_id, actions_by_user) + await self.store.add_push_actions_to_staging(event.event_id, actions_by_user) def _condition_checker(evaluator, conditions, uid, display_name, cache): @@ -274,8 +266,7 @@ class RulesForRoom(object): # to self around in the callback. self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id) - @defer.inlineCallbacks - def get_rules(self, event, context): + async def get_rules(self, event, context): """Given an event context return the rules for all users who are currently in the room. """ @@ -286,7 +277,7 @@ class RulesForRoom(object): self.room_push_rule_cache_metrics.inc_hits() return self.rules_by_user - with (yield self.linearizer.queue(())): + with (await self.linearizer.queue(())): if state_group and self.state_group == state_group: logger.debug("Using cached rules for %r", self.room_id) self.room_push_rule_cache_metrics.inc_hits() @@ -304,9 +295,7 @@ class RulesForRoom(object): push_rules_delta_state_cache_metric.inc_hits() else: - current_state_ids = yield defer.ensureDeferred( - context.get_current_state_ids() - ) + current_state_ids = await context.get_current_state_ids() push_rules_delta_state_cache_metric.inc_misses() push_rules_state_size_counter.inc(len(current_state_ids)) @@ -353,7 +342,7 @@ class RulesForRoom(object): # If we have some memebr events we haven't seen, look them up # and fetch push rules for them if appropriate. logger.debug("Found new member events %r", missing_member_event_ids) - yield self._update_rules_with_member_event_ids( + await self._update_rules_with_member_event_ids( ret_rules_by_user, missing_member_event_ids, state_group, event ) else: @@ -371,8 +360,7 @@ class RulesForRoom(object): ) return ret_rules_by_user - @defer.inlineCallbacks - def _update_rules_with_member_event_ids( + async def _update_rules_with_member_event_ids( self, ret_rules_by_user, member_event_ids, state_group, event ): """Update the partially filled rules_by_user dict by fetching rules for @@ -388,7 +376,7 @@ class RulesForRoom(object): """ sequence = self.sequence - rows = yield self.store.get_membership_from_event_ids(member_event_ids.values()) + rows = await self.store.get_membership_from_event_ids(member_event_ids.values()) members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows} @@ -410,7 +398,7 @@ class RulesForRoom(object): logger.debug("Joined: %r", interested_in_user_ids) - if_users_with_pushers = yield self.store.get_if_users_have_pushers( + if_users_with_pushers = await self.store.get_if_users_have_pushers( interested_in_user_ids, on_invalidate=self.invalidate_all_cb ) @@ -420,7 +408,7 @@ class RulesForRoom(object): logger.debug("With pushers: %r", user_ids) - users_with_receipts = yield self.store.get_users_with_read_receipts_in_room( + users_with_receipts = await self.store.get_users_with_read_receipts_in_room( self.room_id, on_invalidate=self.invalidate_all_cb ) @@ -431,7 +419,7 @@ class RulesForRoom(object): if uid in interested_in_user_ids: user_ids.add(uid) - rules_by_user = yield self.store.bulk_get_push_rules( + rules_by_user = await self.store.bulk_get_push_rules( user_ids, on_invalidate=self.invalidate_all_cb ) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 2fac07593..4c469efb2 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -17,7 +17,6 @@ import logging from prometheus_client import Counter -from twisted.internet import defer from twisted.internet.error import AlreadyCalled, AlreadyCancelled from synapse.api.constants import EventTypes @@ -128,12 +127,11 @@ class HttpPusher(object): # but currently that's the only type of receipt anyway... run_as_background_process("http_pusher.on_new_receipts", self._update_badge) - @defer.inlineCallbacks - def _update_badge(self): + async def _update_badge(self): # XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems # to be largely redundant. perhaps we can remove it. - badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) - yield self._send_badge(badge) + badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) + await self._send_badge(badge) def on_timer(self): self._start_processing() @@ -152,8 +150,7 @@ class HttpPusher(object): run_as_background_process("httppush.process", self._process) - @defer.inlineCallbacks - def _process(self): + async def _process(self): # we should never get here if we are already processing assert not self._is_processing @@ -164,7 +161,7 @@ class HttpPusher(object): while True: starting_max_ordering = self.max_stream_ordering try: - yield self._unsafe_process() + await self._unsafe_process() except Exception: logger.exception("Exception processing notifs") if self.max_stream_ordering == starting_max_ordering: @@ -172,8 +169,7 @@ class HttpPusher(object): finally: self._is_processing = False - @defer.inlineCallbacks - def _unsafe_process(self): + async def _unsafe_process(self): """ Looks for unset notifications and dispatch them, in order Never call this directly: use _process which will only allow this to @@ -181,7 +177,7 @@ class HttpPusher(object): """ fn = self.store.get_unread_push_actions_for_user_in_range_for_http - unprocessed = yield fn( + unprocessed = await fn( self.user_id, self.last_stream_ordering, self.max_stream_ordering ) @@ -203,13 +199,13 @@ class HttpPusher(object): "app_display_name": self.app_display_name, }, ): - processed = yield self._process_one(push_action) + processed = await self._process_one(push_action) if processed: http_push_processed_counter.inc() self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.last_stream_ordering = push_action["stream_ordering"] - pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success( + pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success( self.app_id, self.pushkey, self.user_id, @@ -224,14 +220,14 @@ class HttpPusher(object): if self.failing_since: self.failing_since = None - yield self.store.update_pusher_failing_since( + await self.store.update_pusher_failing_since( self.app_id, self.pushkey, self.user_id, self.failing_since ) else: http_push_failed_counter.inc() if not self.failing_since: self.failing_since = self.clock.time_msec() - yield self.store.update_pusher_failing_since( + await self.store.update_pusher_failing_since( self.app_id, self.pushkey, self.user_id, self.failing_since ) @@ -250,7 +246,7 @@ class HttpPusher(object): ) self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.last_stream_ordering = push_action["stream_ordering"] - pusher_still_exists = yield self.store.update_pusher_last_stream_ordering( + pusher_still_exists = await self.store.update_pusher_last_stream_ordering( self.app_id, self.pushkey, self.user_id, @@ -263,7 +259,7 @@ class HttpPusher(object): return self.failing_since = None - yield self.store.update_pusher_failing_since( + await self.store.update_pusher_failing_since( self.app_id, self.pushkey, self.user_id, self.failing_since ) else: @@ -276,18 +272,17 @@ class HttpPusher(object): ) break - @defer.inlineCallbacks - def _process_one(self, push_action): + async def _process_one(self, push_action): if "notify" not in push_action["actions"]: return True tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"]) - badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) + badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) - event = yield self.store.get_event(push_action["event_id"], allow_none=True) + event = await self.store.get_event(push_action["event_id"], allow_none=True) if event is None: return True # It's been redacted - rejected = yield self.dispatch_push(event, tweaks, badge) + rejected = await self.dispatch_push(event, tweaks, badge) if rejected is False: return False @@ -301,11 +296,10 @@ class HttpPusher(object): ) else: logger.info("Pushkey %s was rejected: removing", pk) - yield self.hs.remove_pusher(self.app_id, pk, self.user_id) + await self.hs.remove_pusher(self.app_id, pk, self.user_id) return True - @defer.inlineCallbacks - def _build_notification_dict(self, event, tweaks, badge): + async def _build_notification_dict(self, event, tweaks, badge): priority = "low" if ( event.type == EventTypes.Encrypted @@ -335,7 +329,7 @@ class HttpPusher(object): } return d - ctx = yield push_tools.get_context_for_event( + ctx = await push_tools.get_context_for_event( self.storage, self.state_handler, event, self.user_id ) @@ -377,13 +371,12 @@ class HttpPusher(object): return d - @defer.inlineCallbacks - def dispatch_push(self, event, tweaks, badge): - notification_dict = yield self._build_notification_dict(event, tweaks, badge) + async def dispatch_push(self, event, tweaks, badge): + notification_dict = await self._build_notification_dict(event, tweaks, badge) if not notification_dict: return [] try: - resp = yield self.http_client.post_json_get_json( + resp = await self.http_client.post_json_get_json( self.url, notification_dict ) except Exception as e: @@ -400,8 +393,7 @@ class HttpPusher(object): rejected = resp["rejected"] return rejected - @defer.inlineCallbacks - def _send_badge(self, badge): + async def _send_badge(self, badge): """ Args: badge (int): number of unread messages @@ -424,7 +416,7 @@ class HttpPusher(object): } } try: - yield self.http_client.post_json_get_json(self.url, d) + await self.http_client.post_json_get_json(self.url, d) http_badges_processed_counter.inc() except Exception as e: logger.warning( diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py index 0644a13cf..d8f4a453c 100644 --- a/synapse/push/presentable_names.py +++ b/synapse/push/presentable_names.py @@ -16,8 +16,6 @@ import logging import re -from twisted.internet import defer - from synapse.api.constants import EventTypes logger = logging.getLogger(__name__) @@ -29,8 +27,7 @@ ALIAS_RE = re.compile(r"^#.*:.+$") ALL_ALONE = "Empty Room" -@defer.inlineCallbacks -def calculate_room_name( +async def calculate_room_name( store, room_state_ids, user_id, @@ -53,7 +50,7 @@ def calculate_room_name( """ # does it have a name? if (EventTypes.Name, "") in room_state_ids: - m_room_name = yield store.get_event( + m_room_name = await store.get_event( room_state_ids[(EventTypes.Name, "")], allow_none=True ) if m_room_name and m_room_name.content and m_room_name.content["name"]: @@ -61,7 +58,7 @@ def calculate_room_name( # does it have a canonical alias? if (EventTypes.CanonicalAlias, "") in room_state_ids: - canon_alias = yield store.get_event( + canon_alias = await store.get_event( room_state_ids[(EventTypes.CanonicalAlias, "")], allow_none=True ) if ( @@ -81,7 +78,7 @@ def calculate_room_name( my_member_event = None if (EventTypes.Member, user_id) in room_state_ids: - my_member_event = yield store.get_event( + my_member_event = await store.get_event( room_state_ids[(EventTypes.Member, user_id)], allow_none=True ) @@ -90,7 +87,7 @@ def calculate_room_name( and my_member_event.content["membership"] == "invite" ): if (EventTypes.Member, my_member_event.sender) in room_state_ids: - inviter_member_event = yield store.get_event( + inviter_member_event = await store.get_event( room_state_ids[(EventTypes.Member, my_member_event.sender)], allow_none=True, ) @@ -107,7 +104,7 @@ def calculate_room_name( # we're going to have to generate a name based on who's in the room, # so find out who is in the room that isn't the user. if EventTypes.Member in room_state_bytype_ids: - member_events = yield store.get_events( + member_events = await store.get_events( list(room_state_bytype_ids[EventTypes.Member].values()) ) all_members = [ diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 5dae4648c..d0145666b 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -13,18 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.storage import Storage -@defer.inlineCallbacks -def get_badge_count(store, user_id): - invites = yield store.get_invited_rooms_for_local_user(user_id) - joins = yield store.get_rooms_for_user(user_id) +async def get_badge_count(store, user_id): + invites = await store.get_invited_rooms_for_local_user(user_id) + joins = await store.get_rooms_for_user(user_id) - my_receipts_by_room = yield store.get_receipts_for_user(user_id, "m.read") + my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read") badge = len(invites) @@ -32,7 +29,7 @@ def get_badge_count(store, user_id): if room_id in my_receipts_by_room: last_unread_event_id = my_receipts_by_room[room_id] - notifs = yield ( + notifs = await ( store.get_unread_event_push_actions_by_room_for_user( room_id, user_id, last_unread_event_id ) @@ -43,23 +40,22 @@ def get_badge_count(store, user_id): return badge -@defer.inlineCallbacks -def get_context_for_event(storage: Storage, state_handler, ev, user_id): +async def get_context_for_event(storage: Storage, state_handler, ev, user_id): ctx = {} - room_state_ids = yield storage.state.get_state_ids_for_event(ev.event_id) + room_state_ids = await storage.state.get_state_ids_for_event(ev.event_id) # we no longer bother setting room_alias, and make room_name the # human-readable name instead, be that m.room.name, an alias or # a list of people in the room - name = yield calculate_room_name( + name = await calculate_room_name( storage.main, room_state_ids, user_id, fallback_to_single_member=False ) if name: ctx["name"] = name sender_state_event_id = room_state_ids[("m.room.member", ev.sender)] - sender_state_event = yield storage.main.get_event(sender_state_event_id) + sender_state_event = await storage.main.get_event(sender_state_event_id) ctx["sender_display_name"] = name_from_member_event(sender_state_event) return ctx diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 2456f12f4..3c3262a88 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -19,8 +19,6 @@ from typing import TYPE_CHECKING, Dict, Union from prometheus_client import Gauge -from twisted.internet import defer - from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import PusherConfigException from synapse.push.emailpusher import EmailPusher @@ -52,7 +50,7 @@ class PusherPool: Note that it is expected that each pusher will have its own 'processing' loop which will send out the notifications in the background, rather than blocking until the notifications are sent; accordingly Pusher.on_started, Pusher.on_new_notifications and - Pusher.on_new_receipts are not expected to return deferreds. + Pusher.on_new_receipts are not expected to return awaitables. """ def __init__(self, hs: "HomeServer"): @@ -77,8 +75,7 @@ class PusherPool: return run_as_background_process("start_pushers", self._start_pushers) - @defer.inlineCallbacks - def add_pusher( + async def add_pusher( self, user_id, access_token, @@ -94,7 +91,7 @@ class PusherPool: """Creates a new pusher and adds it to the pool Returns: - Deferred[EmailPusher|HttpPusher] + EmailPusher|HttpPusher """ time_now_msec = self.clock.time_msec() @@ -124,9 +121,9 @@ class PusherPool: # create the pusher setting last_stream_ordering to the current maximum # stream ordering in event_push_actions, so it will process # pushes from this point onwards. - last_stream_ordering = yield self.store.get_latest_push_action_stream_ordering() + last_stream_ordering = await self.store.get_latest_push_action_stream_ordering() - yield self.store.add_pusher( + await self.store.add_pusher( user_id=user_id, access_token=access_token, kind=kind, @@ -140,15 +137,14 @@ class PusherPool: last_stream_ordering=last_stream_ordering, profile_tag=profile_tag, ) - pusher = yield self.start_pusher_by_id(app_id, pushkey, user_id) + pusher = await self.start_pusher_by_id(app_id, pushkey, user_id) return pusher - @defer.inlineCallbacks - def remove_pushers_by_app_id_and_pushkey_not_user( + async def remove_pushers_by_app_id_and_pushkey_not_user( self, app_id, pushkey, not_user_id ): - to_remove = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) + to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) for p in to_remove: if p["user_name"] != not_user_id: logger.info( @@ -157,10 +153,9 @@ class PusherPool: pushkey, p["user_name"], ) - yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) + await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) - @defer.inlineCallbacks - def remove_pushers_by_access_token(self, user_id, access_tokens): + async def remove_pushers_by_access_token(self, user_id, access_tokens): """Remove the pushers for a given user corresponding to a set of access_tokens. @@ -173,7 +168,7 @@ class PusherPool: return tokens = set(access_tokens) - for p in (yield self.store.get_pushers_by_user_id(user_id)): + for p in await self.store.get_pushers_by_user_id(user_id): if p["access_token"] in tokens: logger.info( "Removing pusher for app id %s, pushkey %s, user %s", @@ -181,16 +176,15 @@ class PusherPool: p["pushkey"], p["user_name"], ) - yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) + await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) - @defer.inlineCallbacks - def on_new_notifications(self, min_stream_id, max_stream_id): + async def on_new_notifications(self, min_stream_id, max_stream_id): if not self.pushers: # nothing to do here. return try: - users_affected = yield self.store.get_push_action_users_in_range( + users_affected = await self.store.get_push_action_users_in_range( min_stream_id, max_stream_id ) @@ -202,8 +196,7 @@ class PusherPool: except Exception: logger.exception("Exception in pusher on_new_notifications") - @defer.inlineCallbacks - def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): + async def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): if not self.pushers: # nothing to do here. return @@ -211,7 +204,7 @@ class PusherPool: try: # Need to subtract 1 from the minimum because the lower bound here # is not inclusive - users_affected = yield self.store.get_users_sent_receipts_between( + users_affected = await self.store.get_users_sent_receipts_between( min_stream_id - 1, max_stream_id ) @@ -223,12 +216,11 @@ class PusherPool: except Exception: logger.exception("Exception in pusher on_new_receipts") - @defer.inlineCallbacks - def start_pusher_by_id(self, app_id, pushkey, user_id): + async def start_pusher_by_id(self, app_id, pushkey, user_id): """Look up the details for the given pusher, and start it Returns: - Deferred[EmailPusher|HttpPusher|None]: The pusher started, if any + EmailPusher|HttpPusher|None: The pusher started, if any """ if not self._should_start_pushers: return @@ -236,7 +228,7 @@ class PusherPool: if not self._pusher_shard_config.should_handle(self._instance_name, user_id): return - resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) + resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) pusher_dict = None for r in resultlist: @@ -245,34 +237,29 @@ class PusherPool: pusher = None if pusher_dict: - pusher = yield self._start_pusher(pusher_dict) + pusher = await self._start_pusher(pusher_dict) return pusher - @defer.inlineCallbacks - def _start_pushers(self): + async def _start_pushers(self) -> None: """Start all the pushers - - Returns: - Deferred """ - pushers = yield self.store.get_all_pushers() + pushers = await self.store.get_all_pushers() # Stagger starting up the pushers so we don't completely drown the # process on start up. - yield concurrently_execute(self._start_pusher, pushers, 10) + await concurrently_execute(self._start_pusher, pushers, 10) logger.info("Started pushers") - @defer.inlineCallbacks - def _start_pusher(self, pusherdict): + async def _start_pusher(self, pusherdict): """Start the given pusher Args: pusherdict (dict): dict with the values pulled from the db table Returns: - Deferred[EmailPusher|HttpPusher] + EmailPusher|HttpPusher """ if not self._pusher_shard_config.should_handle( self._instance_name, pusherdict["user_name"] @@ -315,7 +302,7 @@ class PusherPool: user_id = pusherdict["user_name"] last_stream_ordering = pusherdict["last_stream_ordering"] if last_stream_ordering: - have_notifs = yield self.store.get_if_maybe_push_in_range_for_user( + have_notifs = await self.store.get_if_maybe_push_in_range_for_user( user_id, last_stream_ordering ) else: @@ -327,8 +314,7 @@ class PusherPool: return p - @defer.inlineCallbacks - def remove_pusher(self, app_id, pushkey, user_id): + async def remove_pusher(self, app_id, pushkey, user_id): appid_pushkey = "%s:%s" % (app_id, pushkey) byuser = self.pushers.get(user_id, {}) @@ -340,6 +326,6 @@ class PusherPool: synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec() - yield self.store.delete_pusher_by_app_id_pushkey_user_id( + await self.store.delete_pusher_by_app_id_pushkey_user_id( app_id, pushkey, user_id ) diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 504babaa7..18297cf3b 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -411,7 +411,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): _get_if_maybe_push_in_range_for_user_txn, ) - def add_push_actions_to_staging(self, event_id, user_id_actions): + async def add_push_actions_to_staging(self, event_id, user_id_actions): """Add the push actions for the event to the push action staging area. Args: @@ -457,7 +457,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): ), ) - return self.db.runInteraction( + return await self.db.runInteraction( "add_push_actions_to_staging", _add_push_actions_to_staging_txn ) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 1a88c7fb8..0b5204654 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -366,7 +366,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): state_handler = self.hs.get_state_handler() context = self.get_success(state_handler.compute_event_context(event)) - self.master_store.add_push_actions_to_staging( - event.event_id, {user_id: actions for user_id, actions in push_actions} + self.get_success( + self.master_store.add_push_actions_to_staging( + event.event_id, {user_id: actions for user_id, actions in push_actions} + ) ) return event, context diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index b45bc9c11..43dbeb42c 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -72,8 +72,10 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): event.internal_metadata.stream_ordering = stream event.depth = stream - yield self.store.add_push_actions_to_staging( - event.event_id, {user_id: action} + yield defer.ensureDeferred( + self.store.add_push_actions_to_staging( + event.event_id, {user_id: action} + ) ) yield self.store.db.runInteraction( "", From 5f65e6268146a5ae7b8dafdfe2290b791e8b4c92 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 27 Jul 2020 12:32:08 -0400 Subject: [PATCH 02/26] Convert groups and visibility code to async / await. (#7951) --- changelog.d/7951.misc | 1 + synapse/groups/attestations.py | 25 +++++++++++-------------- synapse/visibility.py | 30 +++++++++++++----------------- tests/test_visibility.py | 12 ++++++------ 4 files changed, 31 insertions(+), 37 deletions(-) create mode 100644 changelog.d/7951.misc diff --git a/changelog.d/7951.misc b/changelog.d/7951.misc new file mode 100644 index 000000000..cbba4fa82 --- /dev/null +++ b/changelog.d/7951.misc @@ -0,0 +1 @@ +Convert groups and visibility code to async / await. diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index dab13c243..e674bf44a 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -41,8 +41,6 @@ from typing import Tuple from signedjson.sign import sign_json -from twisted.internet import defer - from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import get_domain_from_id @@ -72,8 +70,9 @@ class GroupAttestationSigning(object): self.server_name = hs.hostname self.signing_key = hs.signing_key - @defer.inlineCallbacks - def verify_attestation(self, attestation, group_id, user_id, server_name=None): + async def verify_attestation( + self, attestation, group_id, user_id, server_name=None + ): """Verifies that the given attestation matches the given parameters. An optional server_name can be supplied to explicitly set which server's @@ -102,7 +101,7 @@ class GroupAttestationSigning(object): if valid_until_ms < now: raise SynapseError(400, "Attestation expired") - yield self.keyring.verify_json_for_server( + await self.keyring.verify_json_for_server( server_name, attestation, now, "Group attestation" ) @@ -142,8 +141,7 @@ class GroupAttestionRenewer(object): self._start_renew_attestations, 30 * 60 * 1000 ) - @defer.inlineCallbacks - def on_renew_attestation(self, group_id, user_id, content): + async def on_renew_attestation(self, group_id, user_id, content): """When a remote updates an attestation """ attestation = content["attestation"] @@ -151,11 +149,11 @@ class GroupAttestionRenewer(object): if not self.is_mine_id(group_id) and not self.is_mine_id(user_id): raise SynapseError(400, "Neither user not group are on this server") - yield self.attestations.verify_attestation( + await self.attestations.verify_attestation( attestation, user_id=user_id, group_id=group_id ) - yield self.store.update_remote_attestion(group_id, user_id, attestation) + await self.store.update_remote_attestion(group_id, user_id, attestation) return {} @@ -172,8 +170,7 @@ class GroupAttestionRenewer(object): now + UPDATE_ATTESTATION_TIME_MS ) - @defer.inlineCallbacks - def _renew_attestation(group_user: Tuple[str, str]): + async def _renew_attestation(group_user: Tuple[str, str]): group_id, user_id = group_user try: if not self.is_mine_id(group_id): @@ -186,16 +183,16 @@ class GroupAttestionRenewer(object): user_id, group_id, ) - yield self.store.remove_attestation_renewal(group_id, user_id) + await self.store.remove_attestation_renewal(group_id, user_id) return attestation = self.attestations.create_attestation(group_id, user_id) - yield self.transport_client.renew_group_attestation( + await self.transport_client.renew_group_attestation( destination, group_id, user_id, content={"attestation": attestation} ) - yield self.store.update_attestation_renewal( + await self.store.update_attestation_renewal( group_id, user_id, attestation ) except (RequestSendFailed, HttpResponseException) as e: diff --git a/synapse/visibility.py b/synapse/visibility.py index 0f042c569..e3da7744d 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -16,8 +16,6 @@ import logging import operator -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.events.utils import prune_event from synapse.storage import Storage @@ -39,8 +37,7 @@ MEMBERSHIP_PRIORITY = ( ) -@defer.inlineCallbacks -def filter_events_for_client( +async def filter_events_for_client( storage: Storage, user_id, events, @@ -67,19 +64,19 @@ def filter_events_for_client( also be called to check whether a user can see the state at a given point. Returns: - Deferred[list[synapse.events.EventBase]] + list[synapse.events.EventBase] """ # Filter out events that have been soft failed so that we don't relay them # to clients. events = [e for e in events if not e.internal_metadata.is_soft_failed()] types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id)) - event_id_to_state = yield storage.state.get_state_for_events( + event_id_to_state = await storage.state.get_state_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types(types), ) - ignore_dict_content = yield storage.main.get_global_account_data_by_type_for_user( + ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user( "m.ignored_user_list", user_id ) @@ -90,7 +87,7 @@ def filter_events_for_client( else [] ) - erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) + erased_senders = await storage.main.are_users_erased((e.sender for e in events)) if filter_send_to_client: room_ids = {e.room_id for e in events} @@ -99,7 +96,7 @@ def filter_events_for_client( for room_id in room_ids: retention_policies[ room_id - ] = yield storage.main.get_retention_policy_for_room(room_id) + ] = await storage.main.get_retention_policy_for_room(room_id) def allowed(event): """ @@ -254,8 +251,7 @@ def filter_events_for_client( return list(filtered_events) -@defer.inlineCallbacks -def filter_events_for_server( +async def filter_events_for_server( storage: Storage, server_name, events, @@ -277,7 +273,7 @@ def filter_events_for_server( backfill or not. Returns - Deferred[list[FrozenEvent]] + list[FrozenEvent] """ def is_sender_erased(event, erased_senders): @@ -321,7 +317,7 @@ def filter_events_for_server( # Lets check to see if all the events have a history visibility # of "shared" or "world_readable". If that's the case then we don't # need to check membership (as we know the server is in the room). - event_to_state_ids = yield storage.state.get_state_ids_for_events( + event_to_state_ids = await storage.state.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( types=((EventTypes.RoomHistoryVisibility, ""),) @@ -339,14 +335,14 @@ def filter_events_for_server( if not visibility_ids: all_open = True else: - event_map = yield storage.main.get_events(visibility_ids) + event_map = await storage.main.get_events(visibility_ids) all_open = all( e.content.get("history_visibility") in (None, "shared", "world_readable") for e in event_map.values() ) if not check_history_visibility_only: - erased_senders = yield storage.main.are_users_erased((e.sender for e in events)) + erased_senders = await storage.main.are_users_erased((e.sender for e in events)) else: # We don't want to check whether users are erased, which is equivalent # to no users having been erased. @@ -375,7 +371,7 @@ def filter_events_for_server( # first, for each event we're wanting to return, get the event_ids # of the history vis and membership state at those events. - event_to_state_ids = yield storage.state.get_state_ids_for_events( + event_to_state_ids = await storage.state.get_state_ids_for_events( frozenset(e.event_id for e in events), state_filter=StateFilter.from_types( types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None)) @@ -405,7 +401,7 @@ def filter_events_for_server( return False return state_key[idx + 1 :] == server_name - event_map = yield storage.main.get_events( + event_map = await storage.main.get_events( [e_id for e_id, key in event_id_to_state_key.items() if include(key[0], key[1])] ) diff --git a/tests/test_visibility.py b/tests/test_visibility.py index b371efc0d..a7a36174e 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -64,8 +64,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): evt = yield self.inject_room_member(user, extra_content={"a": "b"}) events_to_filter.append(evt) - filtered = yield filter_events_for_server( - self.storage, "test_server", events_to_filter + filtered = yield defer.ensureDeferred( + filter_events_for_server(self.storage, "test_server", events_to_filter) ) # the result should be 5 redacted events, and 5 unredacted events. @@ -102,8 +102,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): yield self.hs.get_datastore().mark_user_erased("@erased:local_hs") # ... and the filtering happens. - filtered = yield filter_events_for_server( - self.storage, "test_server", events_to_filter + filtered = yield defer.ensureDeferred( + filter_events_for_server(self.storage, "test_server", events_to_filter) ) for i in range(0, len(events_to_filter)): @@ -265,8 +265,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): storage.main = test_store storage.state = test_store - filtered = yield filter_events_for_server( - test_store, "test_server", events_to_filter + filtered = yield defer.ensureDeferred( + filter_events_for_server(test_store, "test_server", events_to_filter) ) logger.info("Filtering took %f seconds", time.time() - start) From 8553f4649857c7862e30917adc925642ad684a10 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 27 Jul 2020 13:40:22 -0400 Subject: [PATCH 03/26] Convert a synapse.events to async/await. (#7949) --- changelog.d/7948.misc | 2 +- changelog.d/7949.misc | 1 + changelog.d/7951.misc | 2 +- synapse/api/auth.py | 2 +- synapse/events/builder.py | 19 ++++----- synapse/events/snapshot.py | 46 +++++++++++---------- synapse/events/third_party_rules.py | 55 ++++++++++++++------------ synapse/events/utils.py | 15 ++++--- synapse/handlers/federation.py | 2 +- synapse/replication/http/federation.py | 4 +- synapse/replication/http/send_event.py | 2 +- tests/storage/test_redaction.py | 4 +- tests/test_state.py | 14 +++---- 13 files changed, 86 insertions(+), 82 deletions(-) create mode 100644 changelog.d/7949.misc diff --git a/changelog.d/7948.misc b/changelog.d/7948.misc index 7c2e2b18b..dfe4c0317 100644 --- a/changelog.d/7948.misc +++ b/changelog.d/7948.misc @@ -1 +1 @@ -Convert push to async/await. +Convert various parts of the codebase to async/await. diff --git a/changelog.d/7949.misc b/changelog.d/7949.misc new file mode 100644 index 000000000..dfe4c0317 --- /dev/null +++ b/changelog.d/7949.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/changelog.d/7951.misc b/changelog.d/7951.misc index cbba4fa82..dfe4c0317 100644 --- a/changelog.d/7951.misc +++ b/changelog.d/7951.misc @@ -1 +1 @@ -Convert groups and visibility code to async / await. +Convert various parts of the codebase to async/await. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index b53e8451e..2178e623d 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -82,7 +82,7 @@ class Auth(object): @defer.inlineCallbacks def check_from_context(self, room_version: str, event, context, do_sig_check=True): - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) auth_events_ids = yield self.compute_auth_events( event, prev_state_ids, for_verification=True ) diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 0bb216419..69b53ca2b 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -17,8 +17,6 @@ from typing import Optional import attr from nacl.signing import SigningKey -from twisted.internet import defer - from synapse.api.constants import MAX_DEPTH from synapse.api.errors import UnsupportedRoomVersionError from synapse.api.room_versions import ( @@ -95,31 +93,30 @@ class EventBuilder(object): def is_state(self): return self._state_key is not None - @defer.inlineCallbacks - def build(self, prev_event_ids): + async def build(self, prev_event_ids): """Transform into a fully signed and hashed event Args: prev_event_ids (list[str]): The event IDs to use as the prev events Returns: - Deferred[FrozenEvent] + FrozenEvent """ - state_ids = yield defer.ensureDeferred( - self._state.get_current_state_ids(self.room_id, prev_event_ids) + state_ids = await self._state.get_current_state_ids( + self.room_id, prev_event_ids ) - auth_ids = yield self._auth.compute_auth_events(self, state_ids) + auth_ids = await self._auth.compute_auth_events(self, state_ids) format_version = self.room_version.event_format if format_version == EventFormatVersions.V1: - auth_events = yield self._store.add_event_hashes(auth_ids) - prev_events = yield self._store.add_event_hashes(prev_event_ids) + auth_events = await self._store.add_event_hashes(auth_ids) + prev_events = await self._store.add_event_hashes(prev_event_ids) else: auth_events = auth_ids prev_events = prev_event_ids - old_depth = yield self._store.get_max_depth_of(prev_event_ids) + old_depth = await self._store.get_max_depth_of(prev_event_ids) depth = old_depth + 1 # we cap depth of generated events, to ensure that they are not diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index f94cdcbab..cca93e3a4 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -12,17 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union import attr from frozendict import frozendict -from twisted.internet import defer - from synapse.appservice import ApplicationService +from synapse.events import EventBase from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.types import StateMap +if TYPE_CHECKING: + from synapse.storage.data_stores.main import DataStore + @attr.s(slots=True) class EventContext: @@ -129,8 +131,7 @@ class EventContext: delta_ids=delta_ids, ) - @defer.inlineCallbacks - def serialize(self, event, store): + async def serialize(self, event: EventBase, store: "DataStore") -> dict: """Converts self to a type that can be serialized as JSON, and then deserialized by `deserialize` @@ -146,7 +147,7 @@ class EventContext: # the prev_state_ids, so if we're a state event we include the event # id that we replaced in the state. if event.is_state(): - prev_state_ids = yield self.get_prev_state_ids() + prev_state_ids = await self.get_prev_state_ids() prev_state_id = prev_state_ids.get((event.type, event.state_key)) else: prev_state_id = None @@ -214,8 +215,7 @@ class EventContext: return self._state_group - @defer.inlineCallbacks - def get_current_state_ids(self): + async def get_current_state_ids(self) -> Optional[StateMap[str]]: """ Gets the room state map, including this event - ie, the state in ``state_group`` @@ -224,32 +224,31 @@ class EventContext: ``rejected`` is set. Returns: - Deferred[dict[(str, str), str]|None]: Returns None if state_group - is None, which happens when the associated event is an outlier. + Returns None if state_group is None, which happens when the associated + event is an outlier. - Maps a (type, state_key) to the event ID of the state event matching - this tuple. + Maps a (type, state_key) to the event ID of the state event matching + this tuple. """ if self.rejected: raise RuntimeError("Attempt to access state_ids of rejected event") - yield self._ensure_fetched() + await self._ensure_fetched() return self._current_state_ids - @defer.inlineCallbacks - def get_prev_state_ids(self): + async def get_prev_state_ids(self): """ Gets the room state map, excluding this event. For a non-state event, this will be the same as get_current_state_ids(). Returns: - Deferred[dict[(str, str), str]|None]: Returns None if state_group + dict[(str, str), str]|None: Returns None if state_group is None, which happens when the associated event is an outlier. Maps a (type, state_key) to the event ID of the state event matching this tuple. """ - yield self._ensure_fetched() + await self._ensure_fetched() return self._prev_state_ids def get_cached_current_state_ids(self): @@ -269,8 +268,8 @@ class EventContext: return self._current_state_ids - def _ensure_fetched(self): - return defer.succeed(None) + async def _ensure_fetched(self): + return None @attr.s(slots=True) @@ -303,21 +302,20 @@ class _AsyncEventContextImpl(EventContext): _event_state_key = attr.ib(default=None) _fetching_state_deferred = attr.ib(default=None) - def _ensure_fetched(self): + async def _ensure_fetched(self): if not self._fetching_state_deferred: self._fetching_state_deferred = run_in_background(self._fill_out_state) - return make_deferred_yieldable(self._fetching_state_deferred) + return await make_deferred_yieldable(self._fetching_state_deferred) - @defer.inlineCallbacks - def _fill_out_state(self): + async def _fill_out_state(self): """Called to populate the _current_state_ids and _prev_state_ids attributes by loading from the database. """ if self.state_group is None: return - self._current_state_ids = yield self._storage.state.get_state_ids_for_group( + self._current_state_ids = await self._storage.state.get_state_ids_for_group( self.state_group ) if self._event_state_key is not None: diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 459132d38..2956a6423 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -13,7 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer +from synapse.events import EventBase +from synapse.events.snapshot import EventContext +from synapse.types import Requester class ThirdPartyEventRules(object): @@ -39,76 +41,79 @@ class ThirdPartyEventRules(object): config=config, http_client=hs.get_simple_http_client() ) - @defer.inlineCallbacks - def check_event_allowed(self, event, context): + async def check_event_allowed( + self, event: EventBase, context: EventContext + ) -> bool: """Check if a provided event should be allowed in the given context. Args: - event (synapse.events.EventBase): The event to be checked. - context (synapse.events.snapshot.EventContext): The context of the event. + event: The event to be checked. + context: The context of the event. Returns: - defer.Deferred[bool]: True if the event should be allowed, False if not. + True if the event should be allowed, False if not. """ if self.third_party_rules is None: return True - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = await context.get_prev_state_ids() # Retrieve the state events from the database. state_events = {} for key, event_id in prev_state_ids.items(): - state_events[key] = yield self.store.get_event(event_id, allow_none=True) + state_events[key] = await self.store.get_event(event_id, allow_none=True) - ret = yield self.third_party_rules.check_event_allowed(event, state_events) + ret = await self.third_party_rules.check_event_allowed(event, state_events) return ret - @defer.inlineCallbacks - def on_create_room(self, requester, config, is_requester_admin): + async def on_create_room( + self, requester: Requester, config: dict, is_requester_admin: bool + ) -> bool: """Intercept requests to create room to allow, deny or update the request config. Args: - requester (Requester) - config (dict): The creation config from the client. - is_requester_admin (bool): If the requester is an admin + requester + config: The creation config from the client. + is_requester_admin: If the requester is an admin Returns: - defer.Deferred[bool]: Whether room creation is allowed or denied. + Whether room creation is allowed or denied. """ if self.third_party_rules is None: return True - ret = yield self.third_party_rules.on_create_room( + ret = await self.third_party_rules.on_create_room( requester, config, is_requester_admin ) return ret - @defer.inlineCallbacks - def check_threepid_can_be_invited(self, medium, address, room_id): + async def check_threepid_can_be_invited( + self, medium: str, address: str, room_id: str + ) -> bool: """Check if a provided 3PID can be invited in the given room. Args: - medium (str): The 3PID's medium. - address (str): The 3PID's address. - room_id (str): The room we want to invite the threepid to. + medium: The 3PID's medium. + address: The 3PID's address. + room_id: The room we want to invite the threepid to. Returns: - defer.Deferred[bool], True if the 3PID can be invited, False if not. + True if the 3PID can be invited, False if not. """ if self.third_party_rules is None: return True - state_ids = yield self.store.get_filtered_current_state_ids(room_id) - room_state_events = yield self.store.get_events(state_ids.values()) + state_ids = await self.store.get_filtered_current_state_ids(room_id) + room_state_events = await self.store.get_events(state_ids.values()) state_events = {} for key, event_id in state_ids.items(): state_events[key] = room_state_events[event_id] - ret = yield self.third_party_rules.check_threepid_can_be_invited( + ret = await self.third_party_rules.check_threepid_can_be_invited( medium, address, state_events ) return ret diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 11f0d34ec..2d42e268c 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -18,8 +18,6 @@ from typing import Any, Mapping, Union from frozendict import frozendict -from twisted.internet import defer - from synapse.api.constants import EventTypes, RelationTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersion @@ -337,8 +335,9 @@ class EventClientSerializer(object): hs.config.experimental_msc1849_support_enabled ) - @defer.inlineCallbacks - def serialize_event(self, event, time_now, bundle_aggregations=True, **kwargs): + async def serialize_event( + self, event, time_now, bundle_aggregations=True, **kwargs + ): """Serializes a single event. Args: @@ -348,7 +347,7 @@ class EventClientSerializer(object): **kwargs: Arguments to pass to `serialize_event` Returns: - Deferred[dict]: The serialized event + dict: The serialized event """ # To handle the case of presence events and the like if not isinstance(event, EventBase): @@ -363,8 +362,8 @@ class EventClientSerializer(object): if not event.internal_metadata.is_redacted() and ( self.experimental_msc1849_support_enabled and bundle_aggregations ): - annotations = yield self.store.get_aggregation_groups_for_event(event_id) - references = yield self.store.get_relations_for_event( + annotations = await self.store.get_aggregation_groups_for_event(event_id) + references = await self.store.get_relations_for_event( event_id, RelationTypes.REFERENCE, direction="f" ) @@ -378,7 +377,7 @@ class EventClientSerializer(object): edit = None if event.type == EventTypes.Message: - edit = yield self.store.get_applicable_edit(event_id) + edit = await self.store.get_applicable_edit(event_id) if edit: # If there is an edit replace the content, preserving existing diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index f5f683bfd..0d7d1adce 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2470,7 +2470,7 @@ class FederationHandler(BaseHandler): } current_state_ids = await context.get_current_state_ids() - current_state_ids = dict(current_state_ids) + current_state_ids = dict(current_state_ids) # type: ignore current_state_ids.update(state_updates) diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index c287c4e26..ca065e819 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -78,7 +78,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): """ event_payloads = [] for event, context in event_and_contexts: - serialized_context = yield context.serialize(event, store) + serialized_context = yield defer.ensureDeferred( + context.serialize(event, store) + ) event_payloads.append( { diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index c981723c1..b30e4d503 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -77,7 +77,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): extra_users (list(UserID)): Any extra users to notify about event """ - serialized_context = yield context.serialize(event, store) + serialized_context = yield defer.ensureDeferred(context.serialize(event, store)) payload = { "event": event.get_pdu_json(), diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index db3667dc4..0f0e1cd09 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -237,7 +237,9 @@ class RedactionTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def build(self, prev_event_ids): - built_event = yield self._base_builder.build(prev_event_ids) + built_event = yield defer.ensureDeferred( + self._base_builder.build(prev_event_ids) + ) built_event._event_id = self._event_id built_event._dict["event_id"] = self._event_id diff --git a/tests/test_state.py b/tests/test_state.py index 4858e8fc5..b5c3667d2 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -213,7 +213,7 @@ class StateTestCase(unittest.TestCase): ctx_c = context_store["C"] ctx_d = context_store["D"] - prev_state_ids = yield ctx_d.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids()) self.assertEqual(2, len(prev_state_ids)) self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event) @@ -259,7 +259,7 @@ class StateTestCase(unittest.TestCase): ctx_c = context_store["C"] ctx_d = context_store["D"] - prev_state_ids = yield ctx_d.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids()) self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values())) self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event) @@ -318,7 +318,7 @@ class StateTestCase(unittest.TestCase): ctx_c = context_store["C"] ctx_e = context_store["E"] - prev_state_ids = yield ctx_e.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(ctx_e.get_prev_state_ids()) self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values())) self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event) self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group) @@ -393,7 +393,7 @@ class StateTestCase(unittest.TestCase): ctx_b = context_store["B"] ctx_d = context_store["D"] - prev_state_ids = yield ctx_d.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids()) self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values())) self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event) @@ -425,7 +425,7 @@ class StateTestCase(unittest.TestCase): self.state.compute_event_context(event, old_state=old_state) ) - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) @@ -450,7 +450,7 @@ class StateTestCase(unittest.TestCase): self.state.compute_event_context(event, old_state=old_state) ) - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) @@ -519,7 +519,7 @@ class StateTestCase(unittest.TestCase): context = yield defer.ensureDeferred(self.state.compute_event_context(event)) - prev_state_ids = yield context.get_prev_state_ids() + prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values())) From 68626ff8e98443d6dc470970274a853a93fceefa Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 27 Jul 2020 14:40:11 -0400 Subject: [PATCH 04/26] Convert the remaining media repo code to async / await. (#7947) --- changelog.d/7947.misc | 1 + synapse/rest/media/v1/_base.py | 8 +- synapse/rest/media/v1/media_repository.py | 105 ++++++++++-------- synapse/rest/media/v1/media_storage.py | 52 +++++---- synapse/rest/media/v1/preview_url_resource.py | 10 +- synapse/rest/media/v1/storage_provider.py | 62 +++++------ 6 files changed, 131 insertions(+), 107 deletions(-) create mode 100644 changelog.d/7947.misc diff --git a/changelog.d/7947.misc b/changelog.d/7947.misc new file mode 100644 index 000000000..dfe4c0317 --- /dev/null +++ b/changelog.d/7947.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 9a847130c..20ddb9550 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -17,7 +17,9 @@ import logging import os import urllib +from typing import Awaitable +from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender from synapse.api.errors import Codes, SynapseError, cs_error @@ -240,14 +242,14 @@ class Responder(object): held can be cleaned up. """ - def write_to_consumer(self, consumer): + def write_to_consumer(self, consumer: IConsumer) -> Awaitable: """Stream response into consumer Args: - consumer (IConsumer) + consumer: The consumer to stream into. Returns: - Deferred: Resolves once the response has finished being written + Resolves once the response has finished being written """ pass diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 45628c07b..6fb4039e9 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -18,10 +18,11 @@ import errno import logging import os import shutil -from typing import Dict, Tuple +from typing import IO, Dict, Optional, Tuple import twisted.internet.error import twisted.web.http +from twisted.web.http import Request from twisted.web.resource import Resource from synapse.api.errors import ( @@ -40,6 +41,7 @@ from synapse.util.stringutils import random_string from ._base import ( FileInfo, + Responder, get_filename_from_headers, respond_404, respond_with_responder, @@ -135,19 +137,24 @@ class MediaRepository(object): self.recently_accessed_locals.add(media_id) async def create_content( - self, media_type, upload_name, content, content_length, auth_user - ): + self, + media_type: str, + upload_name: str, + content: IO, + content_length: int, + auth_user: str, + ) -> str: """Store uploaded content for a local user and return the mxc URL Args: - media_type(str): The content type of the file - upload_name(str): The name of the file + media_type: The content type of the file + upload_name: The name of the file content: A file like object that is the content to store - content_length(int): The length of the content - auth_user(str): The user_id of the uploader + content_length: The length of the content + auth_user: The user_id of the uploader Returns: - Deferred[str]: The mxc url of the stored content + The mxc url of the stored content """ media_id = random_string(24) @@ -170,19 +177,20 @@ class MediaRepository(object): return "mxc://%s/%s" % (self.server_name, media_id) - async def get_local_media(self, request, media_id, name): + async def get_local_media( + self, request: Request, media_id: str, name: Optional[str] + ) -> None: """Responds to reqests for local media, if exists, or returns 404. Args: - request(twisted.web.http.Request) - media_id (str): The media ID of the content. (This is the same as + request: The incoming request. + media_id: The media ID of the content. (This is the same as the file_id for local content.) - name (str|None): Optional name that, if specified, will be used as + name: Optional name that, if specified, will be used as the filename in the Content-Disposition header of the response. Returns: - Deferred: Resolves once a response has successfully been written - to request + Resolves once a response has successfully been written to request """ media_info = await self.store.get_local_media(media_id) if not media_info or media_info["quarantined_by"]: @@ -203,20 +211,20 @@ class MediaRepository(object): request, responder, media_type, media_length, upload_name ) - async def get_remote_media(self, request, server_name, media_id, name): + async def get_remote_media( + self, request: Request, server_name: str, media_id: str, name: Optional[str] + ) -> None: """Respond to requests for remote media. Args: - request(twisted.web.http.Request) - server_name (str): Remote server_name where the media originated. - media_id (str): The media ID of the content (as defined by the - remote server). - name (str|None): Optional name that, if specified, will be used as + request: The incoming request. + server_name: Remote server_name where the media originated. + media_id: The media ID of the content (as defined by the remote server). + name: Optional name that, if specified, will be used as the filename in the Content-Disposition header of the response. Returns: - Deferred: Resolves once a response has successfully been written - to request + Resolves once a response has successfully been written to request """ if ( self.federation_domain_whitelist is not None @@ -245,17 +253,16 @@ class MediaRepository(object): else: respond_404(request) - async def get_remote_media_info(self, server_name, media_id): + async def get_remote_media_info(self, server_name: str, media_id: str) -> dict: """Gets the media info associated with the remote file, downloading if necessary. Args: - server_name (str): Remote server_name where the media originated. - media_id (str): The media ID of the content (as defined by the - remote server). + server_name: Remote server_name where the media originated. + media_id: The media ID of the content (as defined by the remote server). Returns: - Deferred[dict]: The media_info of the file + The media info of the file """ if ( self.federation_domain_whitelist is not None @@ -278,7 +285,9 @@ class MediaRepository(object): return media_info - async def _get_remote_media_impl(self, server_name, media_id): + async def _get_remote_media_impl( + self, server_name: str, media_id: str + ) -> Tuple[Optional[Responder], dict]: """Looks for media in local cache, if not there then attempt to download from remote server. @@ -288,7 +297,7 @@ class MediaRepository(object): remote server). Returns: - Deferred[(Responder, media_info)] + A tuple of responder and the media info of the file. """ media_info = await self.store.get_cached_remote_media(server_name, media_id) @@ -319,19 +328,21 @@ class MediaRepository(object): responder = await self.media_storage.fetch_media(file_info) return responder, media_info - async def _download_remote_file(self, server_name, media_id, file_id): + async def _download_remote_file( + self, server_name: str, media_id: str, file_id: str + ) -> dict: """Attempt to download the remote file from the given server name, using the given file_id as the local id. Args: - server_name (str): Originating server - media_id (str): The media ID of the content (as defined by the + server_name: Originating server + media_id: The media ID of the content (as defined by the remote server). This is different than the file_id, which is locally generated. - file_id (str): Local file ID + file_id: Local file ID Returns: - Deferred[MediaInfo] + The media info of the file. """ file_info = FileInfo(server_name=server_name, file_id=file_id) @@ -549,25 +560,31 @@ class MediaRepository(object): return output_path async def _generate_thumbnails( - self, server_name, media_id, file_id, media_type, url_cache=False - ): + self, + server_name: Optional[str], + media_id: str, + file_id: str, + media_type: str, + url_cache: bool = False, + ) -> Optional[dict]: """Generate and store thumbnails for an image. Args: - server_name (str|None): The server name if remote media, else None if local - media_id (str): The media ID of the content. (This is the same as + server_name: The server name if remote media, else None if local + media_id: The media ID of the content. (This is the same as the file_id for local content) - file_id (str): Local file ID - media_type (str): The content type of the file - url_cache (bool): If we are thumbnailing images downloaded for the URL cache, + file_id: Local file ID + media_type: The content type of the file + url_cache: If we are thumbnailing images downloaded for the URL cache, used exclusively by the url previewer Returns: - Deferred[dict]: Dict with "width" and "height" keys of original image + Dict with "width" and "height" keys of original image or None if the + media cannot be thumbnailed. """ requirements = self._get_thumbnail_requirements(media_type) if not requirements: - return + return None input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(server_name, file_id, url_cache=url_cache) @@ -584,7 +601,7 @@ class MediaRepository(object): m_height, self.max_image_pixels, ) - return + return None if thumbnailer.transpose_method is not None: m_width, m_height = await defer_to_thread( diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 66bc1c336..858b6d300 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -12,13 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import contextlib import inspect import logging import os import shutil -from typing import Optional +from typing import IO, TYPE_CHECKING, Any, Optional, Sequence from twisted.protocols.basic import FileSender @@ -26,6 +25,12 @@ from synapse.logging.context import defer_to_thread, make_deferred_yieldable from synapse.util.file_consumer import BackgroundFileConsumer from ._base import FileInfo, Responder +from .filepath import MediaFilePaths + +if TYPE_CHECKING: + from synapse.server import HomeServer + + from .storage_provider import StorageProvider logger = logging.getLogger(__name__) @@ -34,20 +39,25 @@ class MediaStorage(object): """Responsible for storing/fetching files from local sources. Args: - hs (synapse.server.Homeserver) - local_media_directory (str): Base path where we store media on disk - filepaths (MediaFilePaths) - storage_providers ([StorageProvider]): List of StorageProvider that are - used to fetch and store files. + hs + local_media_directory: Base path where we store media on disk + filepaths + storage_providers: List of StorageProvider that are used to fetch and store files. """ - def __init__(self, hs, local_media_directory, filepaths, storage_providers): + def __init__( + self, + hs: "HomeServer", + local_media_directory: str, + filepaths: MediaFilePaths, + storage_providers: Sequence["StorageProvider"], + ): self.hs = hs self.local_media_directory = local_media_directory self.filepaths = filepaths self.storage_providers = storage_providers - async def store_file(self, source, file_info: FileInfo) -> str: + async def store_file(self, source: IO, file_info: FileInfo) -> str: """Write `source` to the on disk media store, and also any other configured storage providers @@ -69,7 +79,7 @@ class MediaStorage(object): return fname @contextlib.contextmanager - def store_into_file(self, file_info): + def store_into_file(self, file_info: FileInfo): """Context manager used to get a file like object to write into, as described by file_info. @@ -85,7 +95,7 @@ class MediaStorage(object): error. Args: - file_info (FileInfo): Info about the file to store + file_info: Info about the file to store Example: @@ -143,9 +153,9 @@ class MediaStorage(object): return FileResponder(open(local_path, "rb")) for provider in self.storage_providers: - res = provider.fetch(path, file_info) - # Fetch is supposed to return an Awaitable, but guard against - # improper implementations. + res = provider.fetch(path, file_info) # type: Any + # Fetch is supposed to return an Awaitable[Responder], but guard + # against improper implementations. if inspect.isawaitable(res): res = await res if res: @@ -174,9 +184,9 @@ class MediaStorage(object): os.makedirs(dirname) for provider in self.storage_providers: - res = provider.fetch(path, file_info) - # Fetch is supposed to return an Awaitable, but guard against - # improper implementations. + res = provider.fetch(path, file_info) # type: Any + # Fetch is supposed to return an Awaitable[Responder], but guard + # against improper implementations. if inspect.isawaitable(res): res = await res if res: @@ -190,17 +200,11 @@ class MediaStorage(object): raise Exception("file could not be found") - def _file_info_to_path(self, file_info): + def _file_info_to_path(self, file_info: FileInfo) -> str: """Converts file_info into a relative path. The path is suitable for storing files under a directory, e.g. used to store files on local FS under the base media repository directory. - - Args: - file_info (FileInfo) - - Returns: - str """ if file_info.url_cache: if file_info.thumbnail: diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 13d1a6d2e..e12f65a20 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -231,16 +231,16 @@ class PreviewUrlResource(DirectServeJsonResource): og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe)) respond_with_json_bytes(request, 200, og, send_cors=True) - async def _do_preview(self, url, user, ts): + async def _do_preview(self, url: str, user: str, ts: int) -> bytes: """Check the db, and download the URL and build a preview Args: - url (str): - user (str): - ts (int): + url: The URL to preview. + user: The user requesting the preview. + ts: The timestamp requested for the preview. Returns: - Deferred[bytes]: json-encoded og data + json-encoded og data """ # check the URL cache in the DB (which will also provide us with # historical previews, if we have any) diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 858680be2..a33f56e80 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -16,62 +16,62 @@ import logging import os import shutil - -from twisted.internet import defer +from typing import Optional from synapse.config._base import Config from synapse.logging.context import defer_to_thread, run_in_background +from ._base import FileInfo, Responder from .media_storage import FileResponder logger = logging.getLogger(__name__) -class StorageProvider(object): +class StorageProvider: """A storage provider is a service that can store uploaded media and retrieve them. """ - def store_file(self, path, file_info): + async def store_file(self, path: str, file_info: FileInfo): """Store the file described by file_info. The actual contents can be retrieved by reading the file in file_info.upload_path. Args: - path (str): Relative path of file in local cache - file_info (FileInfo) - - Returns: - Deferred + path: Relative path of file in local cache + file_info: The metadata of the file. """ - pass - def fetch(self, path, file_info): + async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: """Attempt to fetch the file described by file_info and stream it into writer. Args: - path (str): Relative path of file in local cache - file_info (FileInfo) + path: Relative path of file in local cache + file_info: The metadata of the file. Returns: - Deferred(Responder): Returns a Responder if the provider has the file, - otherwise returns None. + Returns a Responder if the provider has the file, otherwise returns None. """ - pass class StorageProviderWrapper(StorageProvider): """Wraps a storage provider and provides various config options Args: - backend (StorageProvider) - store_local (bool): Whether to store new local files or not. - store_synchronous (bool): Whether to wait for file to be successfully + backend: The storage provider to wrap. + store_local: Whether to store new local files or not. + store_synchronous: Whether to wait for file to be successfully uploaded, or todo the upload in the background. - store_remote (bool): Whether remote media should be uploaded + store_remote: Whether remote media should be uploaded """ - def __init__(self, backend, store_local, store_synchronous, store_remote): + def __init__( + self, + backend: StorageProvider, + store_local: bool, + store_synchronous: bool, + store_remote: bool, + ): self.backend = backend self.store_local = store_local self.store_synchronous = store_synchronous @@ -80,15 +80,15 @@ class StorageProviderWrapper(StorageProvider): def __str__(self): return "StorageProviderWrapper[%s]" % (self.backend,) - def store_file(self, path, file_info): + async def store_file(self, path, file_info): if not file_info.server_name and not self.store_local: - return defer.succeed(None) + return None if file_info.server_name and not self.store_remote: - return defer.succeed(None) + return None if self.store_synchronous: - return self.backend.store_file(path, file_info) + return await self.backend.store_file(path, file_info) else: # TODO: Handle errors. def store(): @@ -98,10 +98,10 @@ class StorageProviderWrapper(StorageProvider): logger.exception("Error storing file") run_in_background(store) - return defer.succeed(None) + return None - def fetch(self, path, file_info): - return self.backend.fetch(path, file_info) + async def fetch(self, path, file_info): + return await self.backend.fetch(path, file_info) class FileStorageProviderBackend(StorageProvider): @@ -120,7 +120,7 @@ class FileStorageProviderBackend(StorageProvider): def __str__(self): return "FileStorageProviderBackend[%s]" % (self.base_directory,) - def store_file(self, path, file_info): + async def store_file(self, path, file_info): """See StorageProvider.store_file""" primary_fname = os.path.join(self.cache_directory, path) @@ -130,11 +130,11 @@ class FileStorageProviderBackend(StorageProvider): if not os.path.exists(dirname): os.makedirs(dirname) - return defer_to_thread( + return await defer_to_thread( self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname ) - def fetch(self, path, file_info): + async def fetch(self, path, file_info): """See StorageProvider.fetch""" backup_fname = os.path.join(self.base_directory, path) From c4ce0da6fee5c03de9364f08922d4cb8fd3e1705 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 27 Jul 2020 17:26:50 -0700 Subject: [PATCH 05/26] Add script for finding files with unix line terminators (#7965) This PRs adds a script to check for unix-line terminators in the repo. It will be used to address https://github.com/matrix-org/synapse/issues/7943 by adding the check to CI. I've changed the original script slightly as proposed in https://github.com/matrix-org/pipelines/pull/81#discussion_r460580664 --- changelog.d/7965.misc | 1 + scripts-dev/check_line_terminators.sh | 31 +++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) create mode 100644 changelog.d/7965.misc create mode 100755 scripts-dev/check_line_terminators.sh diff --git a/changelog.d/7965.misc b/changelog.d/7965.misc new file mode 100644 index 000000000..ee9f1a711 --- /dev/null +++ b/changelog.d/7965.misc @@ -0,0 +1 @@ +Add a script to detect source code files using non-unix line terminators. \ No newline at end of file diff --git a/scripts-dev/check_line_terminators.sh b/scripts-dev/check_line_terminators.sh new file mode 100755 index 000000000..0f430e839 --- /dev/null +++ b/scripts-dev/check_line_terminators.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This script checks that line terminators in all repository files (excluding +# those in the .git directory) feature unix line terminators. +# +# Usage: +# +# ./check_line_terminators.sh +# +# The script will emit exit code 1 if any files that do not use unix line +# terminators are found, 0 otherwise. + +# cd to the root of the repository +cd `dirname $0`/.. + +# Find and print files with non-unix line terminators +find . -path './.git/*' -prune -o -type f -print0 | xargs -0 grep -I -l $'\r$' && ( echo 'found files with CRLF line endings'; exit 1 ) From aaf9ce72a0195ade561803e762dfe440969c90c7 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 28 Jul 2020 10:03:18 +0100 Subject: [PATCH 06/26] Fix typo in metrics docs (#7966) --- docs/metrics-howto.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/metrics-howto.md b/docs/metrics-howto.md index cf69938a2..b386ec91c 100644 --- a/docs/metrics-howto.md +++ b/docs/metrics-howto.md @@ -27,7 +27,7 @@ different thread to Synapse. This can make it more resilient to heavy load meaning metrics cannot be retrieved, and can be exposed to just internal networks easier. The served metrics are available - over HTTP only, and will be available at `/`. + over HTTP only, and will be available at `/_synapse/metrics`. Add a new listener to homeserver.yaml: From 3857de2194e3b2057c4af71e095eb6759508f25f Mon Sep 17 00:00:00 2001 From: lugino-emeritus Date: Tue, 28 Jul 2020 14:41:44 +0200 Subject: [PATCH 07/26] Option to allow server admins to join complex rooms (#7902) Fixes #7901. Signed-off-by: Niklas Tittjung --- changelog.d/7902.feature | 1 + docs/sample_config.yaml | 4 + synapse/config/server.py | 7 ++ synapse/handlers/room_member.py | 8 +- tests/federation/test_complexity.py | 109 ++++++++++++++++++++++++++++ 5 files changed, 127 insertions(+), 2 deletions(-) create mode 100644 changelog.d/7902.feature diff --git a/changelog.d/7902.feature b/changelog.d/7902.feature new file mode 100644 index 000000000..4feae8cc2 --- /dev/null +++ b/changelog.d/7902.feature @@ -0,0 +1 @@ +Add option to allow server admins to join rooms which fail complexity checks. Contributed by @lugino-emeritus. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 3227294e0..09a729987 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -314,6 +314,10 @@ limit_remote_rooms: # #complexity_error: "This room is too complex." + # allow server admins to join complex rooms. Default is false. + # + #admins_can_join: true + # Whether to require a user to be in the room to add an alias to it. # Defaults to 'true'. # diff --git a/synapse/config/server.py b/synapse/config/server.py index 3747a01ca..848587d23 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -439,6 +439,9 @@ class ServerConfig(Config): validator=attr.validators.instance_of(str), default=ROOM_COMPLEXITY_TOO_GREAT, ) + admins_can_join = attr.ib( + validator=attr.validators.instance_of(bool), default=False + ) self.limit_remote_rooms = LimitRemoteRoomsConfig( **(config.get("limit_remote_rooms") or {}) @@ -893,6 +896,10 @@ class ServerConfig(Config): # #complexity_error: "This room is too complex." + # allow server admins to join complex rooms. Default is false. + # + #admins_can_join: true + # Whether to require a user to be in the room to add an alias to it. # Defaults to 'true'. # diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index a1a8fa1d3..5a40e8c14 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -952,7 +952,11 @@ class RoomMemberMasterHandler(RoomMemberHandler): if len(remote_room_hosts) == 0: raise SynapseError(404, "No known servers") - if self.hs.config.limit_remote_rooms.enabled: + check_complexity = self.hs.config.limit_remote_rooms.enabled + if check_complexity and self.hs.config.limit_remote_rooms.admins_can_join: + check_complexity = not await self.hs.auth.is_server_admin(user) + + if check_complexity: # Fetch the room complexity too_complex = await self._is_remote_room_too_complex( room_id, remote_room_hosts @@ -975,7 +979,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): # Check the room we just joined wasn't too large, if we didn't fetch the # complexity of it before. - if self.hs.config.limit_remote_rooms.enabled: + if check_complexity: if too_complex is False: # We checked, and we're under the limit. return event_id, stream_id diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 0c9987be5..5cd0510f0 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -99,6 +99,37 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): self.assertEqual(f.value.code, 400, f.value) self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + def test_join_too_large_admin(self): + # Check whether an admin can join if option "admins_can_join" is undefined, + # this option defaults to false, so the join should fail. + + u1 = self.register_user("u1", "pass", admin=True) + + handler = self.hs.get_room_member_handler() + fed_transport = self.hs.get_federation_transport_client() + + # Mock out some things, because we don't want to test the whole join + fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) + handler.federation_handler.do_invite_join = Mock( + return_value=defer.succeed(("", 1)) + ) + + d = handler._remote_join( + None, + ["other.example.com"], + "roomid", + UserID.from_string(u1), + {"membership": "join"}, + ) + + self.pump() + + # The request failed with a SynapseError saying the resource limit was + # exceeded. + f = self.get_failure(d, SynapseError) + self.assertEqual(f.value.code, 400, f.value) + self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + def test_join_too_large_once_joined(self): u1 = self.register_user("u1", "pass") @@ -141,3 +172,81 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): f = self.get_failure(d, SynapseError) self.assertEqual(f.value.code, 400) self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + + +class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): + # Test the behavior of joining rooms which exceed the complexity if option + # limit_remote_rooms.admins_can_join is True. + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def default_config(self): + config = super().default_config() + config["limit_remote_rooms"] = { + "enabled": True, + "complexity": 0.05, + "admins_can_join": True, + } + return config + + def test_join_too_large_no_admin(self): + # A user which is not an admin should not be able to join a remote room + # which is too complex. + + u1 = self.register_user("u1", "pass") + + handler = self.hs.get_room_member_handler() + fed_transport = self.hs.get_federation_transport_client() + + # Mock out some things, because we don't want to test the whole join + fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) + handler.federation_handler.do_invite_join = Mock( + return_value=defer.succeed(("", 1)) + ) + + d = handler._remote_join( + None, + ["other.example.com"], + "roomid", + UserID.from_string(u1), + {"membership": "join"}, + ) + + self.pump() + + # The request failed with a SynapseError saying the resource limit was + # exceeded. + f = self.get_failure(d, SynapseError) + self.assertEqual(f.value.code, 400, f.value) + self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) + + def test_join_too_large_admin(self): + # An admin should be able to join rooms where a complexity check fails. + + u1 = self.register_user("u1", "pass", admin=True) + + handler = self.hs.get_room_member_handler() + fed_transport = self.hs.get_federation_transport_client() + + # Mock out some things, because we don't want to test the whole join + fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) + handler.federation_handler.do_invite_join = Mock( + return_value=defer.succeed(("", 1)) + ) + + d = handler._remote_join( + None, + ["other.example.com"], + "roomid", + UserID.from_string(u1), + {"membership": "join"}, + ) + + self.pump() + + # The request success since the user is an admin + self.get_success(d) From 8078dec3be29f849f730fdd91a83fd2b1f89726e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 28 Jul 2020 13:52:25 +0100 Subject: [PATCH 08/26] Fix exit code for `check_line_terminators.sh` (#7970) If there are *no* files with CRLF line endings, then the xargs exits with a non-zero exit code (as expected), but then, since that is the last thing to happen in the script, the script as a whole exits non-zero, making the whole thing fail. using `if/then/fi` instead of `&& (...)` means that the script exits with a zero exit code. --- changelog.d/7970.misc | 1 + scripts-dev/check_line_terminators.sh | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7970.misc diff --git a/changelog.d/7970.misc b/changelog.d/7970.misc new file mode 100644 index 000000000..ee9f1a711 --- /dev/null +++ b/changelog.d/7970.misc @@ -0,0 +1 @@ +Add a script to detect source code files using non-unix line terminators. \ No newline at end of file diff --git a/scripts-dev/check_line_terminators.sh b/scripts-dev/check_line_terminators.sh index 0f430e839..c98395623 100755 --- a/scripts-dev/check_line_terminators.sh +++ b/scripts-dev/check_line_terminators.sh @@ -28,4 +28,7 @@ cd `dirname $0`/.. # Find and print files with non-unix line terminators -find . -path './.git/*' -prune -o -type f -print0 | xargs -0 grep -I -l $'\r$' && ( echo 'found files with CRLF line endings'; exit 1 ) +if find . -path './.git/*' -prune -o -type f -print0 | xargs -0 grep -I -l $'\r$'; then + echo -e '\e[31mERROR: found files with CRLF line endings. See above.\e[39m' + exit 1 +fi From 2c1e1b153d7ca429b84c2cd0a2d657a066de8bc7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 28 Jul 2020 10:28:59 -0400 Subject: [PATCH 09/26] Use the JSON module from the std library instead of simplejson. (#7936) --- changelog.d/7936.misc | 1 + synapse/__init__.py | 12 ++++++++++++ synapse/python_dependencies.py | 2 +- 3 files changed, 14 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7936.misc diff --git a/changelog.d/7936.misc b/changelog.d/7936.misc new file mode 100644 index 000000000..4304bbdd2 --- /dev/null +++ b/changelog.d/7936.misc @@ -0,0 +1 @@ +Switch to the JSON implementation from the standard library and bump the minimum version of the canonicaljson library to 1.2.0. diff --git a/synapse/__init__.py b/synapse/__init__.py index 83ce2ae6f..72c93f6c4 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -17,6 +17,7 @@ """ This is a reference implementation of a Matrix homeserver. """ +import json import os import sys @@ -25,6 +26,9 @@ if sys.version_info < (3, 5): print("Synapse requires Python 3.5 or above.") sys.exit(1) +# Twisted and canonicaljson will fail to import when this file is executed to +# get the __version__ during a fresh install. That's OK and subsequent calls to +# actually start Synapse will import these libraries fine. try: from twisted.internet import protocol from twisted.internet.protocol import Factory @@ -36,6 +40,14 @@ try: except ImportError: pass +# Use the standard library json implementation instead of simplejson. +try: + from canonicaljson import set_json_library + + set_json_library(json) +except ImportError: + pass + __version__ = "1.18.0rc2" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 8cfcdb057..abea2be4e 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -43,7 +43,7 @@ REQUIREMENTS = [ "jsonschema>=2.5.1", "frozendict>=1", "unpaddedbase64>=1.1.0", - "canonicaljson>=1.1.3", + "canonicaljson>=1.2.0", # we use the type definitions added in signedjson 1.1. "signedjson>=1.1.0", "pynacl>=1.2.1", From 8a25332d946158f2028d050cf9f7d444c028ef2d Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 28 Jul 2020 10:52:13 -0700 Subject: [PATCH 10/26] Move some log lines from default logger to sql/transaction loggers (#7952) Idea from matrix-org/synapse-dinsic#49 --- changelog.d/7952.misc | 1 + synapse/storage/database.py | 18 ++++++++++-------- 2 files changed, 11 insertions(+), 8 deletions(-) create mode 100644 changelog.d/7952.misc diff --git a/changelog.d/7952.misc b/changelog.d/7952.misc new file mode 100644 index 000000000..93c25cb38 --- /dev/null +++ b/changelog.d/7952.misc @@ -0,0 +1 @@ +Move some database-related log lines from the default logger to the database/transaction loggers. \ No newline at end of file diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 3be20c866..ce8757a40 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -49,11 +49,11 @@ from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3E from synapse.storage.types import Connection, Cursor from synapse.types import Collection -logger = logging.getLogger(__name__) - # python 3 does not have a maximum int value MAX_TXN_ID = 2 ** 63 - 1 +logger = logging.getLogger(__name__) + sql_logger = logging.getLogger("synapse.storage.SQL") transaction_logger = logging.getLogger("synapse.storage.txn") perf_logger = logging.getLogger("synapse.storage.TIME") @@ -233,7 +233,7 @@ class LoggingTransaction: try: return func(sql, *args) except Exception as e: - logger.debug("[SQL FAIL] {%s} %s", self.name, e) + sql_logger.debug("[SQL FAIL] {%s} %s", self.name, e) raise finally: secs = time.time() - start @@ -419,7 +419,7 @@ class Database(object): except self.engine.module.OperationalError as e: # This can happen if the database disappears mid # transaction. - logger.warning( + transaction_logger.warning( "[TXN OPERROR] {%s} %s %d/%d", name, e, i, N, ) if i < N: @@ -427,18 +427,20 @@ class Database(object): try: conn.rollback() except self.engine.module.Error as e1: - logger.warning("[TXN EROLL] {%s} %s", name, e1) + transaction_logger.warning("[TXN EROLL] {%s} %s", name, e1) continue raise except self.engine.module.DatabaseError as e: if self.engine.is_deadlock(e): - logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N) + transaction_logger.warning( + "[TXN DEADLOCK] {%s} %d/%d", name, i, N + ) if i < N: i += 1 try: conn.rollback() except self.engine.module.Error as e1: - logger.warning( + transaction_logger.warning( "[TXN EROLL] {%s} %s", name, e1, ) continue @@ -478,7 +480,7 @@ class Database(object): # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236 cursor.close() except Exception as e: - logger.debug("[TXN FAIL] {%s} %s", name, e) + transaction_logger.debug("[TXN FAIL] {%s} %s", name, e) raise finally: end = monotonic_time() From e866e3b8966efc470038b48061a89aac513eb6e0 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 28 Jul 2020 21:08:23 +0200 Subject: [PATCH 11/26] Add an option to disable purge in delete room admin API (#7964) Add option ```purge``` to ```POST /_synapse/admin/v1/rooms//delete``` Fixes: #3761 Signed-off-by: Dirk Klimpel dirk@klimpel.org --- changelog.d/7964.feature | 1 + docs/admin_api/rooms.md | 13 +++++--- synapse/rest/admin/rooms.py | 11 ++++++- tests/rest/admin/test_room.py | 57 +++++++++++++++++++++++++++++++++-- 4 files changed, 75 insertions(+), 7 deletions(-) create mode 100644 changelog.d/7964.feature diff --git a/changelog.d/7964.feature b/changelog.d/7964.feature new file mode 100644 index 000000000..ffe861650 --- /dev/null +++ b/changelog.d/7964.feature @@ -0,0 +1 @@ +Add an option to purge room or not with delete room admin endpoint (`POST /_synapse/admin/v1/rooms//delete`). Contributed by @dklimpel. \ No newline at end of file diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index 15b83e982..0f267d2b7 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -369,7 +369,9 @@ to the new room will have power level `-10` by default, and thus be unable to sp If `block` is `True` it prevents new joins to the old room. This API will remove all trace of the old room from your database after removing -all local users. +all local users. If `purge` is `true` (the default), all traces of the old room will +be removed from your database after removing all local users. If you do not want +this to happen, set `purge` to `false`. Depending on the amount of history being purged a call to the API may take several minutes or longer. @@ -388,7 +390,8 @@ with a body of: "new_room_user_id": "@someuser:example.com", "room_name": "Content Violation Notification", "message": "Bad Room has been shutdown due to content violations on this server. Please review our Terms of Service.", - "block": true + "block": true, + "purge": true } ``` @@ -430,8 +433,10 @@ The following JSON body parameters are available: `new_room_user_id` in the new room. Ideally this will clearly convey why the original room was shut down. Defaults to `Sharing illegal content on this server is not permitted and rooms in violation will be blocked.` -* `block` - Optional. If set to `true`, this room will be added to a blocking list, preventing future attempts to - join the room. Defaults to `false`. +* `block` - Optional. If set to `true`, this room will be added to a blocking list, preventing + future attempts to join the room. Defaults to `false`. +* `purge` - Optional. If set to `true`, it will remove all traces of the room from your database. + Defaults to `true`. The JSON body must not be empty. The body must be at least `{}`. diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index b8c95d045..a8364d979 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -103,6 +103,14 @@ class DeleteRoomRestServlet(RestServlet): Codes.BAD_JSON, ) + purge = content.get("purge", True) + if not isinstance(purge, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'purge' must be a boolean, if given", + Codes.BAD_JSON, + ) + ret = await self.room_shutdown_handler.shutdown_room( room_id=room_id, new_room_user_id=content.get("new_room_user_id"), @@ -113,7 +121,8 @@ class DeleteRoomRestServlet(RestServlet): ) # Purge room - await self.pagination_handler.purge_room(room_id) + if purge: + await self.pagination_handler.purge_room(room_id) return (200, ret) diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index ba8552c29..cec1cf928 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -283,6 +283,23 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + def test_purge_is_not_bool(self): + """ + If parameter `purge` is not boolean, return an error + """ + body = json.dumps({"purge": "NotBool"}) + + request, channel = self.make_request( + "POST", + self.url, + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"]) + def test_purge_room_and_block(self): """Test to purge a room and block it. Members will not be moved to a new room and will not receive a message. @@ -297,7 +314,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): # Assert one user in room self._is_member(room_id=self.room_id, user_id=self.other_user) - body = json.dumps({"block": True}) + body = json.dumps({"block": True, "purge": True}) request, channel = self.make_request( "POST", @@ -331,7 +348,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): # Assert one user in room self._is_member(room_id=self.room_id, user_id=self.other_user) - body = json.dumps({"block": False}) + body = json.dumps({"block": False, "purge": True}) request, channel = self.make_request( "POST", @@ -351,6 +368,42 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): self._is_blocked(self.room_id, expect=False) self._has_no_members(self.room_id) + def test_block_room_and_not_purge(self): + """Test to block a room without purging it. + Members will not be moved to a new room and will not receive a message. + The room will not be purged. + """ + # Test that room is not purged + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + + # Test that room is not blocked + self._is_blocked(self.room_id, expect=False) + + # Assert one user in room + self._is_member(room_id=self.room_id, user_id=self.other_user) + + body = json.dumps({"block": False, "purge": False}) + + request, channel = self.make_request( + "POST", + self.url.encode("ascii"), + content=body.encode(encoding="utf_8"), + access_token=self.admin_user_tok, + ) + self.render(request) + + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(None, channel.json_body["new_room_id"]) + self.assertEqual(self.other_user, channel.json_body["kicked_users"][0]) + self.assertIn("failed_to_kick_users", channel.json_body) + self.assertIn("local_aliases", channel.json_body) + + with self.assertRaises(AssertionError): + self._is_purged(self.room_id) + self._is_blocked(self.room_id, expect=False) + self._has_no_members(self.room_id) + def test_shutdown_room_consent(self): """Test that we can shutdown rooms with local users who have not yet accepted the privacy policy. This used to fail when we tried to From 3345c166a45cb4a8f87c583ee0476c2bca5c41bd Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 28 Jul 2020 16:09:53 -0400 Subject: [PATCH 12/26] Convert storage layer to async/await. (#7963) --- changelog.d/7963.misc | 1 + synapse/storage/persist_events.py | 40 +++--- synapse/storage/purge_events.py | 38 +++--- synapse/storage/state.py | 207 ++++++++++++++++-------------- tests/storage/test_purge.py | 8 +- tests/storage/test_room.py | 6 +- tests/storage/test_state.py | 64 +++++---- tests/test_visibility.py | 14 +- tests/utils.py | 16 +-- tox.ini | 1 + 10 files changed, 210 insertions(+), 185 deletions(-) create mode 100644 changelog.d/7963.misc diff --git a/changelog.d/7963.misc b/changelog.d/7963.misc new file mode 100644 index 000000000..dfe4c0317 --- /dev/null +++ b/changelog.d/7963.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 78fbdcdee..4a164834d 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -25,7 +25,7 @@ from prometheus_client import Counter, Histogram from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.events import FrozenEvent +from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process @@ -192,12 +192,11 @@ class EventsPersistenceStorage(object): self._event_persist_queue = _EventPeristenceQueue() self._state_resolution_handler = hs.get_state_resolution_handler() - @defer.inlineCallbacks - def persist_events( + async def persist_events( self, - events_and_contexts: List[Tuple[FrozenEvent, EventContext]], + events_and_contexts: List[Tuple[EventBase, EventContext]], backfilled: bool = False, - ): + ) -> int: """ Write events to the database Args: @@ -207,7 +206,7 @@ class EventsPersistenceStorage(object): which might update the current state etc. Returns: - Deferred[int]: the stream ordering of the latest persisted event + the stream ordering of the latest persisted event """ partitioned = {} for event, ctx in events_and_contexts: @@ -223,22 +222,19 @@ class EventsPersistenceStorage(object): for room_id in partitioned: self._maybe_start_persisting(room_id) - yield make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults(deferreds, consumeErrors=True) ) - max_persisted_id = yield self.main_store.get_current_events_token() + return self.main_store.get_current_events_token() - return max_persisted_id - - @defer.inlineCallbacks - def persist_event( - self, event: FrozenEvent, context: EventContext, backfilled: bool = False - ): + async def persist_event( + self, event: EventBase, context: EventContext, backfilled: bool = False + ) -> Tuple[int, int]: """ Returns: - Deferred[Tuple[int, int]]: the stream ordering of ``event``, - and the stream ordering of the latest persisted event + The stream ordering of `event`, and the stream ordering of the + latest persisted event """ deferred = self._event_persist_queue.add_to_queue( event.room_id, [(event, context)], backfilled=backfilled @@ -246,9 +242,9 @@ class EventsPersistenceStorage(object): self._maybe_start_persisting(event.room_id) - yield make_deferred_yieldable(deferred) + await make_deferred_yieldable(deferred) - max_persisted_id = yield self.main_store.get_current_events_token() + max_persisted_id = self.main_store.get_current_events_token() return (event.internal_metadata.stream_ordering, max_persisted_id) def _maybe_start_persisting(self, room_id: str): @@ -262,7 +258,7 @@ class EventsPersistenceStorage(object): async def _persist_events( self, - events_and_contexts: List[Tuple[FrozenEvent, EventContext]], + events_and_contexts: List[Tuple[EventBase, EventContext]], backfilled: bool = False, ): """Calculates the change to current state and forward extremities, and @@ -439,7 +435,7 @@ class EventsPersistenceStorage(object): async def _calculate_new_extremities( self, room_id: str, - event_contexts: List[Tuple[FrozenEvent, EventContext]], + event_contexts: List[Tuple[EventBase, EventContext]], latest_event_ids: List[str], ): """Calculates the new forward extremities for a room given events to @@ -497,7 +493,7 @@ class EventsPersistenceStorage(object): async def _get_new_state_after_events( self, room_id: str, - events_context: List[Tuple[FrozenEvent, EventContext]], + events_context: List[Tuple[EventBase, EventContext]], old_latest_event_ids: Iterable[str], new_latest_event_ids: Iterable[str], ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]: @@ -683,7 +679,7 @@ class EventsPersistenceStorage(object): async def _is_server_still_joined( self, room_id: str, - ev_ctx_rm: List[Tuple[FrozenEvent, EventContext]], + ev_ctx_rm: List[Tuple[EventBase, EventContext]], delta: DeltaState, current_state: Optional[StateMap[str]], potentially_left_users: Set[str], diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py index fdc0abf5c..79d9f06e2 100644 --- a/synapse/storage/purge_events.py +++ b/synapse/storage/purge_events.py @@ -15,8 +15,7 @@ import itertools import logging - -from twisted.internet import defer +from typing import Set logger = logging.getLogger(__name__) @@ -28,49 +27,48 @@ class PurgeEventsStorage(object): def __init__(self, hs, stores): self.stores = stores - @defer.inlineCallbacks - def purge_room(self, room_id: str): + async def purge_room(self, room_id: str): """Deletes all record of a room """ - state_groups_to_delete = yield self.stores.main.purge_room(room_id) - yield self.stores.state.purge_room_state(room_id, state_groups_to_delete) + state_groups_to_delete = await self.stores.main.purge_room(room_id) + await self.stores.state.purge_room_state(room_id, state_groups_to_delete) - @defer.inlineCallbacks - def purge_history(self, room_id, token, delete_local_events): + async def purge_history( + self, room_id: str, token: str, delete_local_events: bool + ) -> None: """Deletes room history before a certain point Args: - room_id (str): + room_id: The room ID - token (str): A topological token to delete events before + token: A topological token to delete events before - delete_local_events (bool): + delete_local_events: if True, we will delete local events as well as remote ones (instead of just marking them as outliers and deleting their state groups). """ - state_groups = yield self.stores.main.purge_history( + state_groups = await self.stores.main.purge_history( room_id, token, delete_local_events ) logger.info("[purge] finding state groups that can be deleted") - sg_to_delete = yield self._find_unreferenced_groups(state_groups) + sg_to_delete = await self._find_unreferenced_groups(state_groups) - yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete) + await self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete) - @defer.inlineCallbacks - def _find_unreferenced_groups(self, state_groups): + async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]: """Used when purging history to figure out which state groups can be deleted. Args: - state_groups (set[int]): Set of state groups referenced by events + state_groups: Set of state groups referenced by events that are going to be deleted. Returns: - Deferred[set[int]] The set of state groups that can be deleted. + The set of state groups that can be deleted. """ # Graph of state group -> previous group graph = {} @@ -93,7 +91,7 @@ class PurgeEventsStorage(object): current_search = set(itertools.islice(next_to_search, 100)) next_to_search -= current_search - referenced = yield self.stores.main.get_referenced_state_groups( + referenced = await self.stores.main.get_referenced_state_groups( current_search ) referenced_groups |= referenced @@ -102,7 +100,7 @@ class PurgeEventsStorage(object): # groups that are referenced. current_search -= referenced - edges = yield self.stores.state.get_previous_state_groups(current_search) + edges = await self.stores.state.get_previous_state_groups(current_search) prevs = set(edges.values()) # We don't bother re-handling groups we've already seen diff --git a/synapse/storage/state.py b/synapse/storage/state.py index dc568476f..49ee9c9a7 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -14,13 +14,12 @@ # limitations under the License. import logging -from typing import Iterable, List, TypeVar +from typing import Dict, Iterable, List, Optional, Set, Tuple, TypeVar import attr -from twisted.internet import defer - from synapse.api.constants import EventTypes +from synapse.events import EventBase from synapse.types import StateMap logger = logging.getLogger(__name__) @@ -34,16 +33,16 @@ class StateFilter(object): """A filter used when querying for state. Attributes: - types (dict[str, set[str]|None]): Map from type to set of state keys (or - None). This specifies which state_keys for the given type to fetch - from the DB. If None then all events with that type are fetched. If - the set is empty then no events with that type are fetched. - include_others (bool): Whether to fetch events with types that do not + types: Map from type to set of state keys (or None). This specifies + which state_keys for the given type to fetch from the DB. If None + then all events with that type are fetched. If the set is empty + then no events with that type are fetched. + include_others: Whether to fetch events with types that do not appear in `types`. """ - types = attr.ib() - include_others = attr.ib(default=False) + types = attr.ib(type=Dict[str, Optional[Set[str]]]) + include_others = attr.ib(default=False, type=bool) def __attrs_post_init__(self): # If `include_others` is set we canonicalise the filter by removing @@ -52,36 +51,35 @@ class StateFilter(object): self.types = {k: v for k, v in self.types.items() if v is not None} @staticmethod - def all(): + def all() -> "StateFilter": """Creates a filter that fetches everything. Returns: - StateFilter + The new state filter. """ return StateFilter(types={}, include_others=True) @staticmethod - def none(): + def none() -> "StateFilter": """Creates a filter that fetches nothing. Returns: - StateFilter + The new state filter. """ return StateFilter(types={}, include_others=False) @staticmethod - def from_types(types): + def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter": """Creates a filter that only fetches the given types Args: - types (Iterable[tuple[str, str|None]]): A list of type and state - keys to fetch. A state_key of None fetches everything for - that type + types: A list of type and state keys to fetch. A state_key of None + fetches everything for that type Returns: - StateFilter + The new state filter. """ - type_dict = {} + type_dict = {} # type: Dict[str, Optional[Set[str]]] for typ, s in types: if typ in type_dict: if type_dict[typ] is None: @@ -91,24 +89,24 @@ class StateFilter(object): type_dict[typ] = None continue - type_dict.setdefault(typ, set()).add(s) + type_dict.setdefault(typ, set()).add(s) # type: ignore return StateFilter(types=type_dict) @staticmethod - def from_lazy_load_member_list(members): + def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter": """Creates a filter that returns all non-member events, plus the member events for the given users Args: - members (iterable[str]): Set of user IDs + members: Set of user IDs Returns: - StateFilter + The new state filter """ return StateFilter(types={EventTypes.Member: set(members)}, include_others=True) - def return_expanded(self): + def return_expanded(self) -> "StateFilter": """Creates a new StateFilter where type wild cards have been removed (except for memberships). The returned filter is a superset of the current one, i.e. anything that passes the current filter will pass @@ -130,7 +128,7 @@ class StateFilter(object): return all non-member events Returns: - StateFilter + The new state filter. """ if self.is_full(): @@ -167,7 +165,7 @@ class StateFilter(object): include_others=True, ) - def make_sql_filter_clause(self): + def make_sql_filter_clause(self) -> Tuple[str, List[str]]: """Converts the filter to an SQL clause. For example: @@ -179,13 +177,12 @@ class StateFilter(object): Returns: - tuple[str, list]: The SQL string (may be empty) and arguments. An - empty SQL string is returned when the filter matches everything - (i.e. is "full"). + The SQL string (may be empty) and arguments. An empty SQL string is + returned when the filter matches everything (i.e. is "full"). """ where_clause = "" - where_args = [] + where_args = [] # type: List[str] if self.is_full(): return where_clause, where_args @@ -221,7 +218,7 @@ class StateFilter(object): return where_clause, where_args - def max_entries_returned(self): + def max_entries_returned(self) -> Optional[int]: """Returns the maximum number of entries this filter will return if known, otherwise returns None. @@ -260,33 +257,33 @@ class StateFilter(object): return filtered_state - def is_full(self): + def is_full(self) -> bool: """Whether this filter fetches everything or not Returns: - bool + True if the filter fetches everything. """ return self.include_others and not self.types - def has_wildcards(self): + def has_wildcards(self) -> bool: """Whether the filter includes wildcards or is attempting to fetch specific state. Returns: - bool + True if the filter includes wildcards. """ return self.include_others or any( state_keys is None for state_keys in self.types.values() ) - def concrete_types(self): + def concrete_types(self) -> List[Tuple[str, str]]: """Returns a list of concrete type/state_keys (i.e. not None) that will be fetched. This will be a complete list if `has_wildcards` returns False, but otherwise will be a subset (or even empty). Returns: - list[tuple[str,str]] + A list of type/state_keys tuples. """ return [ (t, s) @@ -295,7 +292,7 @@ class StateFilter(object): for s in state_keys ] - def get_member_split(self): + def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]: """Return the filter split into two: one which assumes it's exclusively matching against member state, and one which assumes it's matching against non member state. @@ -307,7 +304,7 @@ class StateFilter(object): state caches). Returns: - tuple[StateFilter, StateFilter]: The member and non member filters + The member and non member filters """ if EventTypes.Member in self.types: @@ -340,6 +337,9 @@ class StateGroupStorage(object): """Given a state group try to return a previous group and a delta between the old and the new. + Args: + state_group: The state group used to retrieve state deltas. + Returns: Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]: (prev_group, delta_ids) @@ -347,55 +347,59 @@ class StateGroupStorage(object): return self.stores.state.get_state_group_delta(state_group) - @defer.inlineCallbacks - def get_state_groups_ids(self, _room_id, event_ids): + async def get_state_groups_ids( + self, _room_id: str, event_ids: Iterable[str] + ) -> Dict[int, StateMap[str]]: """Get the event IDs of all the state for the state groups for the given events Args: - _room_id (str): id of the room for these events - event_ids (iterable[str]): ids of the events + _room_id: id of the room for these events + event_ids: ids of the events Returns: - Deferred[dict[int, StateMap[str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) + dict of state_group_id -> (dict of (type, state_key) -> event id) """ if not event_ids: return {} - event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) - group_to_state = yield self.stores.state._get_state_for_groups(groups) + group_to_state = await self.stores.state._get_state_for_groups(groups) return group_to_state - @defer.inlineCallbacks - def get_state_ids_for_group(self, state_group): + async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]: """Get the event IDs of all the state in the given state group Args: - state_group (int) + state_group: A state group for which we want to get the state IDs. Returns: - Deferred[dict]: Resolves to a map of (type, state_key) -> event_id + Resolves to a map of (type, state_key) -> event_id """ - group_to_state = yield self._get_state_for_groups((state_group,)) + group_to_state = await self._get_state_for_groups((state_group,)) return group_to_state[state_group] - @defer.inlineCallbacks - def get_state_groups(self, room_id, event_ids): + async def get_state_groups( + self, room_id: str, event_ids: Iterable[str] + ) -> Dict[int, List[EventBase]]: """ Get the state groups for the given list of event_ids + + Args: + room_id: ID of the room for these events. + event_ids: The event IDs to retrieve state for. + Returns: - Deferred[dict[int, list[EventBase]]]: - dict of state_group_id -> list of state events. + dict of state_group_id -> list of state events. """ if not event_ids: return {} - group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) + group_to_ids = await self.get_state_groups_ids(room_id, event_ids) - state_event_map = yield self.stores.main.get_events( + state_event_map = await self.stores.main.get_events( [ ev_id for group_ids in group_to_ids.values() @@ -423,31 +427,34 @@ class StateGroupStorage(object): groups: list of state group IDs to query state_filter: The state filter used to fetch state from the database. + Returns: Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. """ return self.stores.state._get_state_groups_from_groups(groups, state_filter) - @defer.inlineCallbacks - def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): + async def get_state_for_events( + self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() + ): """Given a list of event_ids and type tuples, return a list of state dicts for each event. + Args: - event_ids (list[string]) - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_ids: The events to fetch the state of. + state_filter: The state filter used to fetch state. + Returns: - deferred: A dict of (event_id) -> (type, state_key) -> [state_events] + A dict of (event_id) -> (type, state_key) -> [state_events] """ - event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) - group_to_state = yield self.stores.state._get_state_for_groups( + group_to_state = await self.stores.state._get_state_for_groups( groups, state_filter ) - state_event_map = yield self.stores.main.get_events( + state_event_map = await self.stores.main.get_events( [ev_id for sd in group_to_state.values() for ev_id in sd.values()], get_prev_content=False, ) @@ -463,24 +470,24 @@ class StateGroupStorage(object): return {event: event_to_state[event] for event in event_ids} - @defer.inlineCallbacks - def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): + async def get_state_ids_for_events( + self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() + ): """ Get the state dicts corresponding to a list of events, containing the event_ids of the state events (as opposed to the events themselves) Args: - event_ids(list(str)): events whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_ids: events whose state should be returned + state_filter: The state filter used to fetch state from the database. Returns: - A deferred dict from event_id -> (type, state_key) -> event_id + A dict from event_id -> (type, state_key) -> event_id """ - event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) + event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) groups = set(event_to_groups.values()) - group_to_state = yield self.stores.state._get_state_for_groups( + group_to_state = await self.stores.state._get_state_for_groups( groups, state_filter ) @@ -491,36 +498,36 @@ class StateGroupStorage(object): return {event: event_to_state[event] for event in event_ids} - @defer.inlineCallbacks - def get_state_for_event(self, event_id, state_filter=StateFilter.all()): + async def get_state_for_event( + self, event_id: str, state_filter: StateFilter = StateFilter.all() + ): """ Get the state dict corresponding to a particular event Args: - event_id(str): event whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_id: event whose state should be returned + state_filter: The state filter used to fetch state from the database. Returns: - A deferred dict from (type, state_key) -> state_event + A dict from (type, state_key) -> state_event """ - state_map = yield self.get_state_for_events([event_id], state_filter) + state_map = await self.get_state_for_events([event_id], state_filter) return state_map[event_id] - @defer.inlineCallbacks - def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): + async def get_state_ids_for_event( + self, event_id: str, state_filter: StateFilter = StateFilter.all() + ): """ Get the state dict corresponding to a particular event Args: - event_id(str): event whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. + event_id: event whose state should be returned + state_filter: The state filter used to fetch state from the database. Returns: A deferred dict from (type, state_key) -> state_event """ - state_map = yield self.get_state_ids_for_events([event_id], state_filter) + state_map = await self.get_state_ids_for_events([event_id], state_filter) return state_map[event_id] def _get_state_for_groups( @@ -530,9 +537,8 @@ class StateGroupStorage(object): filtering by type/state_key Args: - groups (iterable[int]): list of state groups for which we want - to get the state. - state_filter (StateFilter): The state filter used to fetch state + groups: list of state groups for which we want to get the state. + state_filter: The state filter used to fetch state. from the database. Returns: Deferred[dict[int, StateMap[str]]]: Dict of state group to state map. @@ -540,18 +546,23 @@ class StateGroupStorage(object): return self.stores.state._get_state_for_groups(groups, state_filter) def store_state_group( - self, event_id, room_id, prev_group, delta_ids, current_state_ids + self, + event_id: str, + room_id: str, + prev_group: Optional[int], + delta_ids: Optional[dict], + current_state_ids: dict, ): """Store a new set of state, returning a newly assigned state group. Args: - event_id (str): The event ID for which the state was calculated - room_id (str) - prev_group (int|None): A previous state group for the room, optional. - delta_ids (dict|None): The delta between state at `prev_group` and + event_id: The event ID for which the state was calculated. + room_id: ID of the room for which the state was calculated. + prev_group: A previous state group for the room, optional. + delta_ids: The delta between state at `prev_group` and `current_state_ids`, if `prev_group` was given. Same format as `current_state_ids`. - current_state_ids (dict): The state to store. Map of (type, state_key) + current_state_ids: The state to store. Map of (type, state_key) to event_id. Returns: diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index b9fafaa1a..a6012c973 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + from synapse.rest.client.v1 import room from tests.unittest import HomeserverTestCase @@ -49,7 +51,9 @@ class PurgeTests(HomeserverTestCase): event = self.successResultOf(event) # Purge everything before this topological token - purge = storage.purge_events.purge_history(self.room_id, event, True) + purge = defer.ensureDeferred( + storage.purge_events.purge_history(self.room_id, event, True) + ) self.pump() self.assertEqual(self.successResultOf(purge), None) @@ -88,7 +92,7 @@ class PurgeTests(HomeserverTestCase): ) # Purge everything before this topological token - purge = storage.purge_history(self.room_id, event, True) + purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True)) self.pump() f = self.failureResultOf(purge) self.assertIn("greater than forward", f.value.args[0]) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 1d77b4a2d..a5f250d47 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -97,8 +97,10 @@ class RoomEventsStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def inject_room_event(self, **kwargs): - yield self.storage.persistence.persist_event( - self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) + yield defer.ensureDeferred( + self.storage.persistence.persist_event( + self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) + ) ) @defer.inlineCallbacks diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index a0e133cd4..6a48b9d3b 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -68,7 +68,9 @@ class StateStoreTestCase(tests.unittest.TestCase): self.event_creation_handler.create_new_client_event(builder) ) - yield self.storage.persistence.persist_event(event, context) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event @@ -87,8 +89,8 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.storage.state.get_state_groups_ids( - self.room, [e2.event_id] + state_group_map = yield defer.ensureDeferred( + self.storage.state.get_state_groups_ids(self.room, [e2.event_id]) ) self.assertEqual(len(state_group_map), 1) state_map = list(state_group_map.values())[0] @@ -106,8 +108,8 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} ) - state_group_map = yield self.storage.state.get_state_groups( - self.room, [e2.event_id] + state_group_map = yield defer.ensureDeferred( + self.storage.state.get_state_groups(self.room, [e2.event_id]) ) self.assertEqual(len(state_group_map), 1) state_list = list(state_group_map.values())[0] @@ -148,7 +150,9 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we get the full state as of the final event - state = yield self.storage.state.get_state_for_event(e5.event_id) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event(e5.event_id) + ) self.assertIsNotNone(e4) @@ -164,22 +168,28 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check we can filter to the m.room.name event (with a '' state key) - state = yield self.storage.state.get_state_for_event( - e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) + ) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can filter to the m.room.name event (with a wildcard None state key) - state = yield self.storage.state.get_state_for_event( - e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) + ) ) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) # check we can grab the m.room.member events (with a wildcard None state key) - state = yield self.storage.state.get_state_for_event( - e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) + ) ) self.assertStateMapEqual( @@ -188,12 +198,14 @@ class StateStoreTestCase(tests.unittest.TestCase): # check we can grab a specific room member without filtering out the # other event types - state = yield self.storage.state.get_state_for_event( - e5.event_id, - state_filter=StateFilter( - types={EventTypes.Member: {self.u_alice.to_string()}}, - include_others=True, - ), + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, + state_filter=StateFilter( + types={EventTypes.Member: {self.u_alice.to_string()}}, + include_others=True, + ), + ) ) self.assertStateMapEqual( @@ -206,11 +218,13 @@ class StateStoreTestCase(tests.unittest.TestCase): ) # check that we can grab everything except members - state = yield self.storage.state.get_state_for_event( - e5.event_id, - state_filter=StateFilter( - types={EventTypes.Member: set()}, include_others=True - ), + state = yield defer.ensureDeferred( + self.storage.state.get_state_for_event( + e5.event_id, + state_filter=StateFilter( + types={EventTypes.Member: set()}, include_others=True + ), + ) ) self.assertStateMapEqual( @@ -222,8 +236,8 @@ class StateStoreTestCase(tests.unittest.TestCase): ####################################################### room_id = self.room.to_string() - group_ids = yield self.storage.state.get_state_groups_ids( - room_id, [e5.event_id] + group_ids = yield defer.ensureDeferred( + self.storage.state.get_state_groups_ids(room_id, [e5.event_id]) ) group = list(group_ids.keys())[0] diff --git a/tests/test_visibility.py b/tests/test_visibility.py index a7a36174e..531a9b911 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -40,7 +40,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.store = self.hs.get_datastore() self.storage = self.hs.get_storage() - yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM") + yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")) @defer.inlineCallbacks def test_filtering(self): @@ -140,7 +140,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): event, context = yield defer.ensureDeferred( self.event_creation_handler.create_new_client_event(builder) ) - yield self.storage.persistence.persist_event(event, context) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event @defer.inlineCallbacks @@ -162,7 +164,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.event_creation_handler.create_new_client_event(builder) ) - yield self.storage.persistence.persist_event(event, context) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event @defer.inlineCallbacks @@ -183,7 +187,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase): self.event_creation_handler.create_new_client_event(builder) ) - yield self.storage.persistence.persist_event(event, context) + yield defer.ensureDeferred( + self.storage.persistence.persist_event(event, context) + ) return event @defer.inlineCallbacks diff --git a/tests/utils.py b/tests/utils.py index ac643679a..b33b6860d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -638,14 +638,8 @@ class DeferredMockCallable(object): ) -@defer.inlineCallbacks -def create_room(hs, room_id, creator_id): +async def create_room(hs, room_id: str, creator_id: str): """Creates and persist a creation event for the given room - - Args: - hs - room_id (str) - creator_id (str) """ persistence_store = hs.get_storage().persistence @@ -653,7 +647,7 @@ def create_room(hs, room_id, creator_id): event_builder_factory = hs.get_event_builder_factory() event_creation_handler = hs.get_event_creation_handler() - yield store.store_room( + await store.store_room( room_id=room_id, room_creator_user_id=creator_id, is_public=False, @@ -671,8 +665,6 @@ def create_room(hs, room_id, creator_id): }, ) - event, context = yield defer.ensureDeferred( - event_creation_handler.create_new_client_event(builder) - ) + event, context = await event_creation_handler.create_new_client_event(builder) - yield persistence_store.persist_event(event, context) + await persistence_store.persist_event(event, context) diff --git a/tox.ini b/tox.ini index 595ab3ba6..a394f6ead 100644 --- a/tox.ini +++ b/tox.ini @@ -206,6 +206,7 @@ commands = mypy \ synapse/storage/data_stores/main/ui_auth.py \ synapse/storage/database.py \ synapse/storage/engines \ + synapse/storage/state.py \ synapse/storage/util \ synapse/streams \ synapse/util/caches/stream_change_cache.py \ From 2184f61faeb5ce88c05d28913e3f881813c0c5dd Mon Sep 17 00:00:00 2001 From: Aaron Raimist Date: Wed, 29 Jul 2020 09:35:44 -0500 Subject: [PATCH 13/26] Various improvements to the docs (#7899) --- INSTALL.md | 109 ++++++++++++++++++++++++++++---- README.rst | 43 ++----------- changelog.d/7899.doc | 1 + debian/changelog | 10 +++ debian/matrix-synapse.default | 2 +- debian/synctl.ronn | 27 ++++---- docs/.sample_config_header.yaml | 11 ++++ docs/postgres.md | 3 + docs/sample_config.yaml | 29 ++++----- synapse/config/registration.py | 18 ------ 10 files changed, 153 insertions(+), 100 deletions(-) create mode 100644 changelog.d/7899.doc diff --git a/INSTALL.md b/INSTALL.md index b507de744..22f7b7c02 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -1,10 +1,12 @@ - [Choosing your server name](#choosing-your-server-name) +- [Picking a database engine](#picking-a-database-engine) - [Installing Synapse](#installing-synapse) - [Installing from source](#installing-from-source) - [Platform-Specific Instructions](#platform-specific-instructions) - [Prebuilt packages](#prebuilt-packages) - [Setting up Synapse](#setting-up-synapse) - [TLS certificates](#tls-certificates) + - [Client Well-Known URI](#client-well-known-uri) - [Email](#email) - [Registering a user](#registering-a-user) - [Setting up a TURN server](#setting-up-a-turn-server) @@ -27,6 +29,25 @@ that your email address is probably `user@example.com` rather than `user@email.example.com`) - but doing so may require more advanced setup: see [Setting up Federation](docs/federate.md). +# Picking a database engine + +Synapse offers two database engines: + * [PostgreSQL](https://www.postgresql.org) + * [SQLite](https://sqlite.org/) + +Almost all installations should opt to use PostgreSQL. Advantages include: + +* significant performance improvements due to the superior threading and + caching model, smarter query optimiser +* allowing the DB to be run on separate hardware + +For information on how to install and use PostgreSQL, please see +[docs/postgres.md](docs/postgres.md) + +By default Synapse uses SQLite and in doing so trades performance for convenience. +SQLite is only recommended in Synapse for testing purposes or for servers with +light workloads. + # Installing Synapse ## Installing from source @@ -234,9 +255,9 @@ for a number of platforms. There is an offical synapse image available at https://hub.docker.com/r/matrixdotorg/synapse which can be used with -the docker-compose file available at [contrib/docker](contrib/docker). Further information on -this including configuration options is available in the README on -hub.docker.com. +the docker-compose file available at [contrib/docker](contrib/docker). Further +information on this including configuration options is available in the README +on hub.docker.com. Alternatively, Andreas Peters (previously Silvio Fricke) has contributed a Dockerfile to automate a synapse server in a single Docker image, at @@ -244,7 +265,8 @@ https://hub.docker.com/r/avhost/docker-matrix/tags/ Slavi Pantaleev has created an Ansible playbook, which installs the offical Docker image of Matrix Synapse -along with many other Matrix-related services (Postgres database, riot-web, coturn, mxisd, SSL support, etc.). +along with many other Matrix-related services (Postgres database, Element, coturn, +ma1sd, SSL support, etc.). For more details, see https://github.com/spantaleev/matrix-docker-ansible-deploy @@ -277,22 +299,27 @@ The fingerprint of the repository signing key (as shown by `gpg /usr/share/keyrings/matrix-org-archive-keyring.gpg`) is `AAF9AE843A7584B5A3E4CD2BCF45A512DE2DA058`. -#### Downstream Debian/Ubuntu packages +#### Downstream Debian packages -For `buster` and `sid`, Synapse is available in the Debian repositories and -it should be possible to install it with simply: +We do not recommend using the packages from the default Debian `buster` +repository at this time, as they are old and suffer from known security +vulnerabilities. You can install the latest version of Synapse from +[our repository](#matrixorg-packages) or from `buster-backports`. Please +see the [Debian documentation](https://backports.debian.org/Instructions/) +for information on how to use backports. + +If you are using Debian `sid` or testing, Synapse is available in the default +repositories and it should be possible to install it simply with: ``` sudo apt install matrix-synapse ``` -There is also a version of `matrix-synapse` in `stretch-backports`. Please see -the [Debian documentation on -backports](https://backports.debian.org/Instructions/) for information on how -to use them. +#### Downstream Ubuntu packages -We do not recommend using the packages in downstream Ubuntu at this time, as -they are old and suffer from known security vulnerabilities. +We do not recommend using the packages in the default Ubuntu repository +at this time, as they are old and suffer from known security vulnerabilities. +The latest version of Synapse can be installed from [our repository](#matrixorg-packages). ### Fedora @@ -419,6 +446,60 @@ so, you will need to edit `homeserver.yaml`, as follows: For a more detailed guide to configuring your server for federation, see [federate.md](docs/federate.md). +## Client Well-Known URI + +Setting up the client Well-Known URI is optional but if you set it up, it will +allow users to enter their full username (e.g. `@user:`) into clients +which support well-known lookup to automatically configure the homeserver and +identity server URLs. This is useful so that users don't have to memorize or think +about the actual homeserver URL you are using. + +The URL `https:///.well-known/matrix/client` should return JSON in +the following format. + +``` +{ + "m.homeserver": { + "base_url": "https://" + } +} +``` + +It can optionally contain identity server information as well. + +``` +{ + "m.homeserver": { + "base_url": "https://" + }, + "m.identity_server": { + "base_url": "https://" + } +} +``` + +To work in browser based clients, the file must be served with the appropriate +Cross-Origin Resource Sharing (CORS) headers. A recommended value would be +`Access-Control-Allow-Origin: *` which would allow all browser based clients to +view it. + +In nginx this would be something like: +``` +location /.well-known/matrix/client { + return 200 '{"m.homeserver": {"base_url": "https://"}}'; + add_header Content-Type application/json; + add_header Access-Control-Allow-Origin *; +} +``` + +You should also ensure the `public_baseurl` option in `homeserver.yaml` is set +correctly. `public_baseurl` should be set to the URL that clients will use to +connect to your server. This is the same URL you put for the `m.homeserver` +`base_url` above. + +``` +public_baseurl: "https://" +``` ## Email @@ -437,7 +518,7 @@ email will be disabled. ## Registering a user -The easiest way to create a new user is to do so from a client like [Riot](https://riot.im). +The easiest way to create a new user is to do so from a client like [Element](https://element.io/). Alternatively you can do so from the command line if you have installed via pip. diff --git a/README.rst b/README.rst index f7116b348..4a189c8bc 100644 --- a/README.rst +++ b/README.rst @@ -45,7 +45,7 @@ which handle: - Eventually-consistent cryptographically secure synchronisation of room state across a global open network of federated servers and services - Sending and receiving extensible messages in a room with (optional) - end-to-end encryption[1] + end-to-end encryption - Inviting, joining, leaving, kicking, banning room members - Managing user accounts (registration, login, logout) - Using 3rd Party IDs (3PIDs) such as email addresses, phone numbers, @@ -82,9 +82,6 @@ at the `Matrix spec `_, and experiment with the Thanks for using Matrix! -[1] End-to-end encryption is currently in beta: `blog post `_. - - Support ======= @@ -115,12 +112,11 @@ Unless you are running a test instance of Synapse on your local machine, in general, you will need to enable TLS support before you can successfully connect from a client: see ``_. -An easy way to get started is to login or register via Riot at -https://riot.im/app/#/login or https://riot.im/app/#/register respectively. +An easy way to get started is to login or register via Element at +https://app.element.io/#/login or https://app.element.io/#/register respectively. You will need to change the server you are logging into from ``matrix.org`` and instead specify a Homeserver URL of ``https://:8448`` (or just ``https://`` if you are using a reverse proxy). -(Leave the identity server as the default - see `Identity servers`_.) If you prefer to use another client, refer to our `client breakdown `_. @@ -137,7 +133,7 @@ it, specify ``enable_registration: true`` in ``homeserver.yaml``. (It is then recommended to also set up CAPTCHA - see ``_.) Once ``enable_registration`` is set to ``true``, it is possible to register a -user via `riot.im `_ or other Matrix clients. +user via a Matrix client. Your new user name will be formed partly from the ``server_name``, and partly from a localpart you specify when you create the account. Your name will take @@ -183,30 +179,6 @@ versions of synapse. .. _UPGRADE.rst: UPGRADE.rst - -Using PostgreSQL -================ - -Synapse offers two database engines: - * `PostgreSQL `_ - * `SQLite `_ - -Almost all installations should opt to use PostgreSQL. Advantages include: - -* significant performance improvements due to the superior threading and - caching model, smarter query optimiser -* allowing the DB to be run on separate hardware -* allowing basic active/backup high-availability with a "hot spare" synapse - pointing at the same DB master, as well as enabling DB replication in - synapse itself. - -For information on how to install and use PostgreSQL, please see -`docs/postgres.md `_. - -By default Synapse uses SQLite and in doing so trades performance for convenience. -SQLite is only recommended in Synapse for testing purposes or for servers with -light workloads. - .. _reverse-proxy: Using a reverse proxy with Synapse @@ -255,10 +227,9 @@ email address. Password reset ============== -If a user has registered an email address to their account using an identity -server, they can request a password-reset token via clients such as Riot. - -A manual password reset can be done via direct database access as follows. +Users can reset their password through their client. Alternatively, a server admin +can reset a users password using the `admin API `_ +or by directly editing the database as shown below. First calculate the hash of the new password:: diff --git a/changelog.d/7899.doc b/changelog.d/7899.doc new file mode 100644 index 000000000..847c2cb62 --- /dev/null +++ b/changelog.d/7899.doc @@ -0,0 +1 @@ +Document how to set up a Client Well-Known file and fix several pieces of outdated documentation. diff --git a/debian/changelog b/debian/changelog index 3825603ae..99165b61f 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,13 @@ +matrix-synapse-py3 (1.xx.0) stable; urgency=medium + + [ Synapse Packaging team ] + * New synapse release 1.xx.0. + + [ Aaron Raimist ] + * Fix outdated documentation for SYNAPSE_CACHE_FACTOR + + -- Synapse Packaging team XXXXX + matrix-synapse-py3 (1.17.0) stable; urgency=medium * New synapse release 1.17.0. diff --git a/debian/matrix-synapse.default b/debian/matrix-synapse.default index 65dc2f33d..f402d73bb 100644 --- a/debian/matrix-synapse.default +++ b/debian/matrix-synapse.default @@ -1,2 +1,2 @@ # Specify environment variables used when running Synapse -# SYNAPSE_CACHE_FACTOR=1 (default) +# SYNAPSE_CACHE_FACTOR=0.5 (default) diff --git a/debian/synctl.ronn b/debian/synctl.ronn index a73c832f6..1bad6094f 100644 --- a/debian/synctl.ronn +++ b/debian/synctl.ronn @@ -46,19 +46,20 @@ Configuration file may be generated as follows: ## ENVIRONMENT * `SYNAPSE_CACHE_FACTOR`: - Synapse's architecture is quite RAM hungry currently - a lot of - recent room data and metadata is deliberately cached in RAM in - order to speed up common requests. This will be improved in - future, but for now the easiest way to either reduce the RAM usage - (at the risk of slowing things down) is to set the - SYNAPSE_CACHE_FACTOR environment variable. Roughly speaking, a - SYNAPSE_CACHE_FACTOR of 1.0 will max out at around 3-4GB of - resident memory - this is what we currently run the matrix.org - on. The default setting is currently 0.1, which is probably around - a ~700MB footprint. You can dial it down further to 0.02 if - desired, which targets roughly ~512MB. Conversely you can dial it - up if you need performance for lots of users and have a box with a - lot of RAM. + Synapse's architecture is quite RAM hungry currently - we deliberately + cache a lot of recent room data and metadata in RAM in order to speed up + common requests. We'll improve this in the future, but for now the easiest + way to either reduce the RAM usage (at the risk of slowing things down) + is to set the almost-undocumented ``SYNAPSE_CACHE_FACTOR`` environment + variable. The default is 0.5, which can be decreased to reduce RAM usage + in memory constrained enviroments, or increased if performance starts to + degrade. + + However, degraded performance due to a low cache factor, common on + machines with slow disks, often leads to explosions in memory use due + backlogged requests. In this case, reducing the cache factor will make + things worse. Instead, try increasing it drastically. 2.0 is a good + starting value. ## COPYRIGHT diff --git a/docs/.sample_config_header.yaml b/docs/.sample_config_header.yaml index 35a591d04..8c9b31acd 100644 --- a/docs/.sample_config_header.yaml +++ b/docs/.sample_config_header.yaml @@ -10,5 +10,16 @@ # homeserver.yaml. Instead, if you are starting from scratch, please generate # a fresh config using Synapse by following the instructions in INSTALL.md. +# Configuration options that take a time period can be set using a number +# followed by a letter. Letters have the following meanings: +# s = second +# m = minute +# h = hour +# d = day +# w = week +# y = year +# For example, setting redaction_retention_period: 5m would remove redacted +# messages from the database after 5 minutes, rather than 5 months. + ################################################################################ diff --git a/docs/postgres.md b/docs/postgres.md index 70fe29cdc..e71a1975d 100644 --- a/docs/postgres.md +++ b/docs/postgres.md @@ -188,6 +188,9 @@ to do step 2. It is safe to at any time kill the port script and restart it. +Note that the database may take up significantly more (25% - 100% more) +space on disk after porting to Postgres. + ### Using the port script Firstly, shut down the currently running synapse server and copy its diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 09a729987..598fcd4ef 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -10,6 +10,17 @@ # homeserver.yaml. Instead, if you are starting from scratch, please generate # a fresh config using Synapse by following the instructions in INSTALL.md. +# Configuration options that take a time period can be set using a number +# followed by a letter. Letters have the following meanings: +# s = second +# m = minute +# h = hour +# d = day +# w = week +# y = year +# For example, setting redaction_retention_period: 5m would remove redacted +# messages from the database after 5 minutes, rather than 5 months. + ################################################################################ # Configuration file for Synapse. @@ -1149,24 +1160,6 @@ account_validity: # #default_identity_server: https://matrix.org -# The list of identity servers trusted to verify third party -# identifiers by this server. -# -# Also defines the ID server which will be called when an account is -# deactivated (one will be picked arbitrarily). -# -# Note: This option is deprecated. Since v0.99.4, Synapse has tracked which identity -# server a 3PID has been bound to. For 3PIDs bound before then, Synapse runs a -# background migration script, informing itself that the identity server all of its -# 3PIDs have been bound to is likely one of the below. -# -# As of Synapse v1.4.0, all other functionality of this option has been deprecated, and -# it is now solely used for the purposes of the background migration script, and can be -# removed once it has run. -#trusted_third_party_id_servers: -# - matrix.org -# - vector.im - # Handle threepid (email/phone etc) registration and password resets through a set of # *trusted* identity servers. Note that this allows the configured identity server to # reset passwords for accounts! diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 6badf4e75..a18565577 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -333,24 +333,6 @@ class RegistrationConfig(Config): # #default_identity_server: https://matrix.org - # The list of identity servers trusted to verify third party - # identifiers by this server. - # - # Also defines the ID server which will be called when an account is - # deactivated (one will be picked arbitrarily). - # - # Note: This option is deprecated. Since v0.99.4, Synapse has tracked which identity - # server a 3PID has been bound to. For 3PIDs bound before then, Synapse runs a - # background migration script, informing itself that the identity server all of its - # 3PIDs have been bound to is likely one of the below. - # - # As of Synapse v1.4.0, all other functionality of this option has been deprecated, and - # it is now solely used for the purposes of the background migration script, and can be - # removed once it has run. - #trusted_third_party_id_servers: - # - matrix.org - # - vector.im - # Handle threepid (email/phone etc) registration and password resets through a set of # *trusted* identity servers. Note that this allows the configured identity server to # reset passwords for accounts! From 8dff4a12424cda9e4abaa5f2905d58aa6e723777 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 29 Jul 2020 18:26:55 +0100 Subject: [PATCH 14/26] Re-implement unread counts (#7736) --- changelog.d/7736.feature | 1 + scripts/synapse_port_db | 2 +- synapse/handlers/sync.py | 6 + synapse/push/push_tools.py | 17 +- synapse/rest/client/v2_alpha/sync.py | 1 + synapse/storage/data_stores/main/cache.py | 1 + synapse/storage/data_stores/main/events.py | 48 +++++- .../storage/data_stores/main/events_worker.py | 86 +++++++++- .../schema/delta/58/12unread_messages.sql | 18 ++ tests/rest/client/v1/utils.py | 20 +++ tests/rest/client/v2_alpha/test_sync.py | 157 +++++++++++++++++- 11 files changed, 339 insertions(+), 18 deletions(-) create mode 100644 changelog.d/7736.feature create mode 100644 synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql diff --git a/changelog.d/7736.feature b/changelog.d/7736.feature new file mode 100644 index 000000000..c97864677 --- /dev/null +++ b/changelog.d/7736.feature @@ -0,0 +1 @@ +Add unread messages count to sync responses. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 22a6abd7d..bee525197 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -69,7 +69,7 @@ logger = logging.getLogger("synapse_port_db") BOOLEAN_COLUMNS = { - "events": ["processed", "outlier", "contains_url"], + "events": ["processed", "outlier", "contains_url", "count_as_unread"], "rooms": ["is_public"], "event_edges": ["is_state"], "presence_list": ["accepted"], diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index ebd3e9810..eaa4eeadf 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -103,6 +103,7 @@ class JoinedSyncResult: account_data = attr.ib(type=List[JsonDict]) unread_notifications = attr.ib(type=JsonDict) summary = attr.ib(type=Optional[JsonDict]) + unread_count = attr.ib(type=int) def __nonzero__(self) -> bool: """Make the result appear empty if there are no updates. This is used @@ -1886,6 +1887,10 @@ class SyncHandler(object): if room_builder.rtype == "joined": unread_notifications = {} # type: Dict[str, str] + + unread_count = await self.store.get_unread_message_count_for_user( + room_id, sync_config.user.to_string(), + ) room_sync = JoinedSyncResult( room_id=room_id, timeline=batch, @@ -1894,6 +1899,7 @@ class SyncHandler(object): account_data=account_data_events, unread_notifications=unread_notifications, summary=summary, + unread_count=unread_count, ) if room_sync or always_include: diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index d0145666b..bc8f71916 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -21,22 +21,13 @@ async def get_badge_count(store, user_id): invites = await store.get_invited_rooms_for_local_user(user_id) joins = await store.get_rooms_for_user(user_id) - my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read") - badge = len(invites) for room_id in joins: - if room_id in my_receipts_by_room: - last_unread_event_id = my_receipts_by_room[room_id] - - notifs = await ( - store.get_unread_event_push_actions_by_room_for_user( - room_id, user_id, last_unread_event_id - ) - ) - # return one badge count per conversation, as count per - # message is so noisy as to be almost useless - badge += 1 if notifs["notify_count"] else 0 + unread_count = await store.get_unread_message_count_for_user(room_id, user_id) + # return one badge count per conversation, as count per + # message is so noisy as to be almost useless + badge += 1 if unread_count else 0 return badge diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index a5c24fbd6..3f5bf75e5 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -426,6 +426,7 @@ class SyncRestServlet(RestServlet): result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications result["summary"] = room.summary + result["org.matrix.msc2654.unread_count"] = room.unread_count return result diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py index f39f556c2..edc3624fe 100644 --- a/synapse/storage/data_stores/main/cache.py +++ b/synapse/storage/data_stores/main/cache.py @@ -172,6 +172,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_latest_event_ids_in_room.invalidate((room_id,)) + self.get_unread_message_count_for_user.invalidate_many((room_id,)) self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) if not backfilled: diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 6f2e0d15c..0c9c02afa 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -53,6 +53,47 @@ event_counter = Counter( ["type", "origin_type", "origin_entity"], ) +STATE_EVENT_TYPES_TO_MARK_UNREAD = { + EventTypes.Topic, + EventTypes.Name, + EventTypes.RoomAvatar, + EventTypes.Tombstone, +} + + +def should_count_as_unread(event: EventBase, context: EventContext) -> bool: + # Exclude rejected and soft-failed events. + if context.rejected or event.internal_metadata.is_soft_failed(): + return False + + # Exclude notices. + if ( + not event.is_state() + and event.type == EventTypes.Message + and event.content.get("msgtype") == "m.notice" + ): + return False + + # Exclude edits. + relates_to = event.content.get("m.relates_to", {}) + if relates_to.get("rel_type") == RelationTypes.REPLACE: + return False + + # Mark events that have a non-empty string body as unread. + body = event.content.get("body") + if isinstance(body, str) and body: + return True + + # Mark some state events as unread. + if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD: + return True + + # Mark encrypted events as unread. + if not event.is_state() and event.type == EventTypes.Encrypted: + return True + + return False + def encode_json(json_object): """ @@ -196,6 +237,10 @@ class PersistEventsStore: event_counter.labels(event.type, origin_type, origin_entity).inc() + self.store.get_unread_message_count_for_user.invalidate_many( + (event.room_id,), + ) + for room_id, new_state in current_state_for_room.items(): self.store.get_current_state_ids.prefill((room_id,), new_state) @@ -817,8 +862,9 @@ class PersistEventsStore: "contains_url": ( "url" in event.content and isinstance(event.content["url"], str) ), + "count_as_unread": should_count_as_unread(event, context), } - for event, _ in events_and_contexts + for event, context in events_and_contexts ], ) diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index e812c6707..b03b25963 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -41,9 +41,15 @@ from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams.events import EventsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import Database +from synapse.storage.types import Cursor from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import get_domain_from_id -from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import ( + Cache, + _CacheContext, + cached, + cachedInlineCallbacks, +) from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -1358,6 +1364,84 @@ class EventsWorkerStore(SQLBaseStore): desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) + @cached(tree=True, cache_context=True) + async def get_unread_message_count_for_user( + self, room_id: str, user_id: str, cache_context: _CacheContext, + ) -> int: + """Retrieve the count of unread messages for the given room and user. + + Args: + room_id: The ID of the room to count unread messages in. + user_id: The ID of the user to count unread messages for. + + Returns: + The number of unread messages for the given user in the given room. + """ + with Measure(self._clock, "get_unread_message_count_for_user"): + last_read_event_id = await self.get_last_receipt_event_id_for_user( + user_id=user_id, + room_id=room_id, + receipt_type="m.read", + on_invalidate=cache_context.invalidate, + ) + + return await self.db.runInteraction( + "get_unread_message_count_for_user", + self._get_unread_message_count_for_user_txn, + user_id, + room_id, + last_read_event_id, + ) + + def _get_unread_message_count_for_user_txn( + self, + txn: Cursor, + user_id: str, + room_id: str, + last_read_event_id: Optional[str], + ) -> int: + if last_read_event_id: + # Get the stream ordering for the last read event. + stream_ordering = self.db.simple_select_one_onecol_txn( + txn=txn, + table="events", + keyvalues={"room_id": room_id, "event_id": last_read_event_id}, + retcol="stream_ordering", + ) + else: + # If there's no read receipt for that room, it probably means the user hasn't + # opened it yet, in which case use the stream ID of their join event. + # We can't just set it to 0 otherwise messages from other local users from + # before this user joined will be counted as well. + txn.execute( + """ + SELECT stream_ordering FROM local_current_membership + LEFT JOIN events USING (event_id, room_id) + WHERE membership = 'join' + AND user_id = ? + AND room_id = ? + """, + (user_id, room_id), + ) + row = txn.fetchone() + + if row is None: + return 0 + + stream_ordering = row[0] + + # Count the messages that qualify as unread after the stream ordering we've just + # retrieved. + sql = """ + SELECT COUNT(*) FROM events + WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread + """ + + txn.execute(sql, (user_id, room_id, stream_ordering)) + row = txn.fetchone() + + return row[0] if row else 0 + AllNewEventsResult = namedtuple( "AllNewEventsResult", diff --git a/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql b/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql new file mode 100644 index 000000000..531b532c7 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Store a boolean value in the events table for whether the event should be counted in +-- the unread_count property of sync responses. +ALTER TABLE events ADD COLUMN count_as_unread BOOLEAN; diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 22d734e76..7f8252330 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -143,6 +143,26 @@ class RestHelper(object): return channel.json_body + def redact(self, room_id, event_id, txn_id=None, tok=None, expect_code=200): + if txn_id is None: + txn_id = "m%s" % (str(time.time())) + + path = "/_matrix/client/r0/rooms/%s/redact/%s/%s" % (room_id, event_id, txn_id) + if tok: + path = path + "?access_token=%s" % tok + + request, channel = make_request( + self.hs.get_reactor(), "PUT", path, json.dumps({}).encode("utf8") + ) + render(request, self.resource, self.hs.get_reactor()) + + assert int(channel.result["code"]) == expect_code, ( + "Expected: %d, got: %d, resp: %r" + % (expect_code, int(channel.result["code"]), channel.result["body"]) + ) + + return channel.json_body + def _read_write_state( self, room_id: str, diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py index fa3a3ec1b..a31e44c97 100644 --- a/tests/rest/client/v2_alpha/test_sync.py +++ b/tests/rest/client/v2_alpha/test_sync.py @@ -16,9 +16,9 @@ import json import synapse.rest.admin -from synapse.api.constants import EventContentFields, EventTypes +from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.rest.client.v1 import login, room -from synapse.rest.client.v2_alpha import sync +from synapse.rest.client.v2_alpha import read_marker, sync from tests import unittest from tests.server import TimedOutException @@ -324,3 +324,156 @@ class SyncTypingTests(unittest.HomeserverTestCase): "GET", sync_url % (access_token, next_batch) ) self.assertRaises(TimedOutException, self.render, request) + + +class UnreadMessagesTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + read_marker.register_servlets, + room.register_servlets, + sync.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.url = "/sync?since=%s" + self.next_batch = "s0" + + # Register the first user (used to check the unread counts). + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + # Create the room we'll check unread counts for. + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + # Register the second user (used to send events to the room). + self.user2 = self.register_user("kermit2", "monkey") + self.tok2 = self.login("kermit2", "monkey") + + # Change the power levels of the room so that the second user can send state + # events. + self.helper.send_state( + self.room_id, + EventTypes.PowerLevels, + { + "users": {self.user_id: 100, self.user2: 100}, + "users_default": 0, + "events": { + "m.room.name": 50, + "m.room.power_levels": 100, + "m.room.history_visibility": 100, + "m.room.canonical_alias": 50, + "m.room.avatar": 50, + "m.room.tombstone": 100, + "m.room.server_acl": 100, + "m.room.encryption": 100, + }, + "events_default": 0, + "state_default": 50, + "ban": 50, + "kick": 50, + "redact": 50, + "invite": 0, + }, + tok=self.tok, + ) + + def test_unread_counts(self): + """Tests that /sync returns the right value for the unread count (MSC2654).""" + + # Check that our own messages don't increase the unread count. + self.helper.send(self.room_id, "hello", tok=self.tok) + self._check_unread_count(0) + + # Join the new user and check that this doesn't increase the unread count. + self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) + self._check_unread_count(0) + + # Check that the new user sending a message increases our unread count. + res = self.helper.send(self.room_id, "hello", tok=self.tok2) + self._check_unread_count(1) + + # Send a read receipt to tell the server we've read the latest event. + body = json.dumps({"m.read": res["event_id"]}).encode("utf8") + request, channel = self.make_request( + "POST", + "/rooms/%s/read_markers" % self.room_id, + body, + access_token=self.tok, + ) + self.render(request) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check that the unread counter is back to 0. + self._check_unread_count(0) + + # Check that room name changes increase the unread counter. + self.helper.send_state( + self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2, + ) + self._check_unread_count(1) + + # Check that room topic changes increase the unread counter. + self.helper.send_state( + self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2, + ) + self._check_unread_count(2) + + # Check that encrypted messages increase the unread counter. + self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2) + self._check_unread_count(3) + + # Check that custom events with a body increase the unread counter. + self.helper.send_event( + self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2, + ) + self._check_unread_count(4) + + # Check that edits don't increase the unread counter. + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "body": "hello", + "msgtype": "m.text", + "m.relates_to": {"rel_type": RelationTypes.REPLACE}, + }, + tok=self.tok2, + ) + self._check_unread_count(4) + + # Check that notices don't increase the unread counter. + self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"body": "hello", "msgtype": "m.notice"}, + tok=self.tok2, + ) + self._check_unread_count(4) + + # Check that tombstone events changes increase the unread counter. + self.helper.send_state( + self.room_id, + EventTypes.Tombstone, + {"replacement_room": "!someroom:test"}, + tok=self.tok2, + ) + self._check_unread_count(5) + + def _check_unread_count(self, expected_count: True): + """Syncs and compares the unread count with the expected value.""" + + request, channel = self.make_request( + "GET", self.url % self.next_batch, access_token=self.tok, + ) + self.render(request) + + self.assertEqual(channel.code, 200, channel.json_body) + + room_entry = channel.json_body["rooms"]["join"][self.room_id] + self.assertEqual( + room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry, + ) + + # Store the next batch for the next request. + self.next_batch = channel.json_body["next_batch"] From f23c77389d1eddd14b132ea6f8b5bc013f87d7a6 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 29 Jul 2020 18:31:03 +0100 Subject: [PATCH 15/26] Add MSC reference to changelog for #7736 --- changelog.d/7736.feature | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/changelog.d/7736.feature b/changelog.d/7736.feature index c97864677..feb02be23 100644 --- a/changelog.d/7736.feature +++ b/changelog.d/7736.feature @@ -1 +1 @@ -Add unread messages count to sync responses. +Add unread messages count to sync responses, as specified in [MSC2654](https://github.com/matrix-org/matrix-doc/pull/2654). From 3a00bd1378d6c1f8bc46fd82f174398d5ed9e063 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 29 Jul 2020 13:54:44 -0400 Subject: [PATCH 16/26] Add additional logging for SAML sessions. (#7971) --- changelog.d/7971.misc | 1 + synapse/handlers/saml_handler.py | 3 +++ 2 files changed, 4 insertions(+) create mode 100644 changelog.d/7971.misc diff --git a/changelog.d/7971.misc b/changelog.d/7971.misc new file mode 100644 index 000000000..87a4eb1f4 --- /dev/null +++ b/changelog.d/7971.misc @@ -0,0 +1 @@ +Log the SAML session ID during creation. diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index abecaa831..2d506dc1f 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -96,6 +96,9 @@ class SamlHandler: relay_state=client_redirect_url ) + # Since SAML sessions timeout it is useful to log when they were created. + logger.info("Initiating a new SAML session: %s" % (reqid,)) + now = self._clock.time_msec() self._outstanding_requests_dict[reqid] = Saml2SessionData( creation_time=now, ui_auth_session_id=ui_auth_session_id, From d90087cffa362cef13f6823df29d2c121a2f9bfb Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 29 Jul 2020 13:55:01 -0400 Subject: [PATCH 17/26] Remove from the event_relations table when purging historical events. (#7978) --- changelog.d/7978.bugfix | 1 + synapse/storage/data_stores/main/purge_events.py | 2 ++ 2 files changed, 3 insertions(+) create mode 100644 changelog.d/7978.bugfix diff --git a/changelog.d/7978.bugfix b/changelog.d/7978.bugfix new file mode 100644 index 000000000..247b18db2 --- /dev/null +++ b/changelog.d/7978.bugfix @@ -0,0 +1 @@ +Fix a long standing bug: 'Duplicate key value violates unique constraint "event_relations_id"' when message retention is configured. diff --git a/synapse/storage/data_stores/main/purge_events.py b/synapse/storage/data_stores/main/purge_events.py index 654656913..b53fe35c3 100644 --- a/synapse/storage/data_stores/main/purge_events.py +++ b/synapse/storage/data_stores/main/purge_events.py @@ -62,6 +62,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): # event_json # event_push_actions # event_reference_hashes + # event_relations # event_search # event_to_state_groups # events @@ -209,6 +210,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): "event_edges", "event_forward_extremities", "event_reference_hashes", + "event_relations", "event_search", "rejections", ): From a53e0160a2d520b92365e00647640fb7eac955dd Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 29 Jul 2020 13:56:06 -0400 Subject: [PATCH 18/26] Ensure the msg property of HttpResponseException is a string. (#7979) --- changelog.d/7979.misc | 1 + synapse/http/client.py | 16 ++++++++++++---- synapse/http/matrixfederationclient.py | 7 ++++--- 3 files changed, 17 insertions(+), 7 deletions(-) create mode 100644 changelog.d/7979.misc diff --git a/changelog.d/7979.misc b/changelog.d/7979.misc new file mode 100644 index 000000000..4304bbdd2 --- /dev/null +++ b/changelog.d/7979.misc @@ -0,0 +1 @@ +Switch to the JSON implementation from the standard library and bump the minimum version of the canonicaljson library to 1.2.0. diff --git a/synapse/http/client.py b/synapse/http/client.py index 6bc51202c..155b7460d 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -395,7 +395,9 @@ class SimpleHttpClient(object): if 200 <= response.code < 300: return json.loads(body.decode("utf-8")) else: - raise HttpResponseException(response.code, response.phrase, body) + raise HttpResponseException( + response.code, response.phrase.decode("ascii", errors="replace"), body + ) @defer.inlineCallbacks def post_json_get_json(self, uri, post_json, headers=None): @@ -436,7 +438,9 @@ class SimpleHttpClient(object): if 200 <= response.code < 300: return json.loads(body.decode("utf-8")) else: - raise HttpResponseException(response.code, response.phrase, body) + raise HttpResponseException( + response.code, response.phrase.decode("ascii", errors="replace"), body + ) @defer.inlineCallbacks def get_json(self, uri, args={}, headers=None): @@ -509,7 +513,9 @@ class SimpleHttpClient(object): if 200 <= response.code < 300: return json.loads(body.decode("utf-8")) else: - raise HttpResponseException(response.code, response.phrase, body) + raise HttpResponseException( + response.code, response.phrase.decode("ascii", errors="replace"), body + ) @defer.inlineCallbacks def get_raw(self, uri, args={}, headers=None): @@ -544,7 +550,9 @@ class SimpleHttpClient(object): if 200 <= response.code < 300: return body else: - raise HttpResponseException(response.code, response.phrase, body) + raise HttpResponseException( + response.code, response.phrase.decode("ascii", errors="replace"), body + ) # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. # The two should be factored out. diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 148eeb19d..ea026ed9f 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -447,6 +447,7 @@ class MatrixFederationHttpClient(object): ).inc() set_tag(tags.HTTP_STATUS_CODE, response.code) + response_phrase = response.phrase.decode("ascii", errors="replace") if 200 <= response.code < 300: logger.debug( @@ -454,7 +455,7 @@ class MatrixFederationHttpClient(object): request.txn_id, request.destination, response.code, - response.phrase.decode("ascii", errors="replace"), + response_phrase, ) pass else: @@ -463,7 +464,7 @@ class MatrixFederationHttpClient(object): request.txn_id, request.destination, response.code, - response.phrase.decode("ascii", errors="replace"), + response_phrase, ) # :'( # Update transactions table? @@ -487,7 +488,7 @@ class MatrixFederationHttpClient(object): ) body = None - e = HttpResponseException(response.code, response.phrase, body) + e = HttpResponseException(response.code, response_phrase, body) # Retry if the error is a 429 (Too Many Requests), # otherwise just raise a standard HttpResponseException From 3950ae51ef3e7d0bdbe5002dbe8ef5c35a9e8eea Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 30 Jul 2020 06:56:55 -0400 Subject: [PATCH 19/26] Ensure that remove_pusher is always async (#7981) --- changelog.d/7981.misc | 1 + synapse/app/generic_worker.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7981.misc diff --git a/changelog.d/7981.misc b/changelog.d/7981.misc new file mode 100644 index 000000000..dfe4c0317 --- /dev/null +++ b/changelog.d/7981.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index ec0dbddb8..6e8130351 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -628,7 +628,7 @@ class GenericWorkerServer(HomeServer): self.get_tcp_replication().start_replication(self) - def remove_pusher(self, app_id, push_key, user_id): + async def remove_pusher(self, app_id, push_key, user_id): self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id) def build_replication_data_handler(self): From b3a97d6dac7f9f619b02e213bb8a745d65983d0d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 30 Jul 2020 07:20:41 -0400 Subject: [PATCH 20/26] Convert some of the data store to async. (#7976) --- changelog.d/7976.misc | 1 + .../data_stores/main/event_push_actions.py | 92 +++++++++-------- synapse/storage/data_stores/main/room.py | 98 ++++++++----------- synapse/storage/data_stores/main/state.py | 57 +++++------ synapse/storage/data_stores/main/stats.py | 53 +++++----- synapse/storage/data_stores/state/store.py | 37 +++---- synapse/storage/state.py | 11 ++- tests/storage/test_event_push_actions.py | 12 ++- tests/storage/test_room.py | 24 +++-- tests/storage/test_state.py | 12 ++- 10 files changed, 190 insertions(+), 207 deletions(-) create mode 100644 changelog.d/7976.misc diff --git a/changelog.d/7976.misc b/changelog.d/7976.misc new file mode 100644 index 000000000..dfe4c0317 --- /dev/null +++ b/changelog.d/7976.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py index 18297cf3b..ad8283890 100644 --- a/synapse/storage/data_stores/main/event_push_actions.py +++ b/synapse/storage/data_stores/main/event_push_actions.py @@ -15,11 +15,10 @@ # limitations under the License. import logging +from typing import List from canonicaljson import json -from twisted.internet import defer - from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json from synapse.storage.database import Database @@ -166,8 +165,9 @@ class EventPushActionsWorkerStore(SQLBaseStore): return {"notify_count": notify_count, "highlight_count": highlight_count} - @defer.inlineCallbacks - def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering): + async def get_push_action_users_in_range( + self, min_stream_ordering, max_stream_ordering + ): def f(txn): sql = ( "SELECT DISTINCT(user_id) FROM event_push_actions WHERE" @@ -176,26 +176,28 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, (min_stream_ordering, max_stream_ordering)) return [r[0] for r in txn] - ret = yield self.db.runInteraction("get_push_action_users_in_range", f) + ret = await self.db.runInteraction("get_push_action_users_in_range", f) return ret - @defer.inlineCallbacks - def get_unread_push_actions_for_user_in_range_for_http( - self, user_id, min_stream_ordering, max_stream_ordering, limit=20 - ): + async def get_unread_push_actions_for_user_in_range_for_http( + self, + user_id: str, + min_stream_ordering: int, + max_stream_ordering: int, + limit: int = 20, + ) -> List[dict]: """Get a list of the most recent unread push actions for a given user, within the given stream ordering range. Called by the httppusher. Args: - user_id (str): The user to fetch push actions for. - min_stream_ordering(int): The exclusive lower bound on the + user_id: The user to fetch push actions for. + min_stream_ordering: The exclusive lower bound on the stream ordering of event push actions to fetch. - max_stream_ordering(int): The inclusive upper bound on the + max_stream_ordering: The inclusive upper bound on the stream ordering of event push actions to fetch. - limit (int): The maximum number of rows to return. + limit: The maximum number of rows to return. Returns: - A promise which resolves to a list of dicts with the keys "event_id", - "room_id", "stream_ordering", "actions". + A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions". The list will be ordered by ascending stream_ordering. The list will have between 0~limit entries. """ @@ -228,7 +230,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - after_read_receipt = yield self.db.runInteraction( + after_read_receipt = await self.db.runInteraction( "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt ) @@ -256,7 +258,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - no_read_receipt = yield self.db.runInteraction( + no_read_receipt = await self.db.runInteraction( "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt ) @@ -280,23 +282,25 @@ class EventPushActionsWorkerStore(SQLBaseStore): # one of the subqueries may have hit the limit. return notifs[:limit] - @defer.inlineCallbacks - def get_unread_push_actions_for_user_in_range_for_email( - self, user_id, min_stream_ordering, max_stream_ordering, limit=20 - ): + async def get_unread_push_actions_for_user_in_range_for_email( + self, + user_id: str, + min_stream_ordering: int, + max_stream_ordering: int, + limit: int = 20, + ) -> List[dict]: """Get a list of the most recent unread push actions for a given user, within the given stream ordering range. Called by the emailpusher Args: - user_id (str): The user to fetch push actions for. - min_stream_ordering(int): The exclusive lower bound on the + user_id: The user to fetch push actions for. + min_stream_ordering: The exclusive lower bound on the stream ordering of event push actions to fetch. - max_stream_ordering(int): The inclusive upper bound on the + max_stream_ordering: The inclusive upper bound on the stream ordering of event push actions to fetch. - limit (int): The maximum number of rows to return. + limit: The maximum number of rows to return. Returns: - A promise which resolves to a list of dicts with the keys "event_id", - "room_id", "stream_ordering", "actions", "received_ts". + A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions", "received_ts". The list will be ordered by descending received_ts. The list will have between 0~limit entries. """ @@ -328,7 +332,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - after_read_receipt = yield self.db.runInteraction( + after_read_receipt = await self.db.runInteraction( "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt ) @@ -356,7 +360,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, args) return txn.fetchall() - no_read_receipt = yield self.db.runInteraction( + no_read_receipt = await self.db.runInteraction( "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt ) @@ -461,17 +465,13 @@ class EventPushActionsWorkerStore(SQLBaseStore): "add_push_actions_to_staging", _add_push_actions_to_staging_txn ) - @defer.inlineCallbacks - def remove_push_actions_from_staging(self, event_id): + async def remove_push_actions_from_staging(self, event_id: str) -> None: """Called if we failed to persist the event to ensure that stale push actions don't build up in the DB - - Args: - event_id (str) """ try: - res = yield self.db.simple_delete( + res = await self.db.simple_delete( table="event_push_actions_staging", keyvalues={"event_id": event_id}, desc="remove_push_actions_from_staging", @@ -606,8 +606,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): return range_end - @defer.inlineCallbacks - def get_time_of_last_push_action_before(self, stream_ordering): + async def get_time_of_last_push_action_before(self, stream_ordering): def f(txn): sql = ( "SELECT e.received_ts" @@ -620,7 +619,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): txn.execute(sql, (stream_ordering,)) return txn.fetchone() - result = yield self.db.runInteraction("get_time_of_last_push_action_before", f) + result = await self.db.runInteraction("get_time_of_last_push_action_before", f) return result[0] if result else None @@ -650,8 +649,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): self._start_rotate_notifs, 30 * 60 * 1000 ) - @defer.inlineCallbacks - def get_push_actions_for_user( + async def get_push_actions_for_user( self, user_id, before=None, limit=50, only_highlight=False ): def f(txn): @@ -682,18 +680,17 @@ class EventPushActionsStore(EventPushActionsWorkerStore): txn.execute(sql, args) return self.db.cursor_to_dict(txn) - push_actions = yield self.db.runInteraction("get_push_actions_for_user", f) + push_actions = await self.db.runInteraction("get_push_actions_for_user", f) for pa in push_actions: pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) return push_actions - @defer.inlineCallbacks - def get_latest_push_action_stream_ordering(self): + async def get_latest_push_action_stream_ordering(self): def f(txn): txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") return txn.fetchone() - result = yield self.db.runInteraction( + result = await self.db.runInteraction( "get_latest_push_action_stream_ordering", f ) return result[0] or 0 @@ -747,8 +744,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore): def _start_rotate_notifs(self): return run_as_background_process("rotate_notifs", self._rotate_notifs) - @defer.inlineCallbacks - def _rotate_notifs(self): + async def _rotate_notifs(self): if self._doing_notif_rotation or self.stream_ordering_day_ago is None: return self._doing_notif_rotation = True @@ -757,12 +753,12 @@ class EventPushActionsStore(EventPushActionsWorkerStore): while True: logger.info("Rotating notifications") - caught_up = yield self.db.runInteraction( + caught_up = await self.db.runInteraction( "_rotate_notifs", self._rotate_notifs_txn ) if caught_up: break - yield self.hs.get_clock().sleep(self._rotate_delay) + await self.hs.get_clock().sleep(self._rotate_delay) finally: self._doing_notif_rotation = False diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py index d2e1e36e7..ab48052cd 100644 --- a/synapse/storage/data_stores/main/room.py +++ b/synapse/storage/data_stores/main/room.py @@ -23,8 +23,6 @@ from typing import Any, Dict, List, Optional, Tuple from canonicaljson import json -from twisted.internet import defer - from synapse.api.constants import EventTypes from synapse.api.errors import StoreError from synapse.api.room_versions import RoomVersion, RoomVersions @@ -32,7 +30,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.data_stores.main.search import SearchStore from synapse.storage.database import Database, LoggingTransaction from synapse.types import ThirdPartyInstanceID -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -192,8 +190,7 @@ class RoomWorkerStore(SQLBaseStore): return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn) - @defer.inlineCallbacks - def get_largest_public_rooms( + async def get_largest_public_rooms( self, network_tuple: Optional[ThirdPartyInstanceID], search_filter: Optional[dict], @@ -330,10 +327,10 @@ class RoomWorkerStore(SQLBaseStore): return results - ret_val = yield self.db.runInteraction( + ret_val = await self.db.runInteraction( "get_largest_public_rooms", _get_largest_public_rooms_txn ) - defer.returnValue(ret_val) + return ret_val @cached(max_entries=10000) def is_room_blocked(self, room_id): @@ -509,8 +506,8 @@ class RoomWorkerStore(SQLBaseStore): "get_rooms_paginate", _get_rooms_paginate_txn, ) - @cachedInlineCallbacks(max_entries=10000) - def get_ratelimit_for_user(self, user_id): + @cached(max_entries=10000) + async def get_ratelimit_for_user(self, user_id): """Check if there are any overrides for ratelimiting for the given user @@ -522,7 +519,7 @@ class RoomWorkerStore(SQLBaseStore): of RatelimitOverride are None or 0 then ratelimitng has been disabled for that user entirely. """ - row = yield self.db.simple_select_one( + row = await self.db.simple_select_one( table="ratelimit_override", keyvalues={"user_id": user_id}, retcols=("messages_per_second", "burst_count"), @@ -538,8 +535,8 @@ class RoomWorkerStore(SQLBaseStore): else: return None - @cachedInlineCallbacks() - def get_retention_policy_for_room(self, room_id): + @cached() + async def get_retention_policy_for_room(self, room_id): """Get the retention policy for a given room. If no retention policy has been found for this room, returns a policy defined @@ -566,19 +563,17 @@ class RoomWorkerStore(SQLBaseStore): return self.db.cursor_to_dict(txn) - ret = yield self.db.runInteraction( + ret = await self.db.runInteraction( "get_retention_policy_for_room", get_retention_policy_for_room_txn, ) # If we don't know this room ID, ret will be None, in this case return the default # policy. if not ret: - defer.returnValue( - { - "min_lifetime": self.config.retention_default_min_lifetime, - "max_lifetime": self.config.retention_default_max_lifetime, - } - ) + return { + "min_lifetime": self.config.retention_default_min_lifetime, + "max_lifetime": self.config.retention_default_max_lifetime, + } row = ret[0] @@ -592,7 +587,7 @@ class RoomWorkerStore(SQLBaseStore): if row["max_lifetime"] is None: row["max_lifetime"] = self.config.retention_default_max_lifetime - defer.returnValue(row) + return row def get_media_mxcs_in_room(self, room_id): """Retrieves all the local and remote media MXC URIs in a given room @@ -881,8 +876,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): self._background_add_rooms_room_version_column, ) - @defer.inlineCallbacks - def _background_insert_retention(self, progress, batch_size): + async def _background_insert_retention(self, progress, batch_size): """Retrieves a list of all rooms within a range and inserts an entry for each of them into the room_retention table. NULLs the property's columns if missing from the retention event in the room's @@ -940,14 +934,14 @@ class RoomBackgroundUpdateStore(SQLBaseStore): else: return False - end = yield self.db.runInteraction( + end = await self.db.runInteraction( "insert_room_retention", _background_insert_retention_txn, ) if end: - yield self.db.updates._end_background_update("insert_room_retention") + await self.db.updates._end_background_update("insert_room_retention") - defer.returnValue(batch_size) + return batch_size async def _background_add_rooms_room_version_column( self, progress: dict, batch_size: int @@ -1096,8 +1090,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): lock=False, ) - @defer.inlineCallbacks - def store_room( + async def store_room( self, room_id: str, room_creator_user_id: str, @@ -1140,7 +1133,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) with self._public_room_id_gen.get_next() as next_id: - yield self.db.runInteraction("store_room_txn", store_room_txn, next_id) + await self.db.runInteraction("store_room_txn", store_room_txn, next_id) except Exception as e: logger.error("store_room with room_id=%s failed: %s", room_id, e) raise StoreError(500, "Problem creating room.") @@ -1165,8 +1158,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): lock=False, ) - @defer.inlineCallbacks - def set_room_is_public(self, room_id, is_public): + async def set_room_is_public(self, room_id, is_public): def set_room_is_public_txn(txn, next_id): self.db.simple_update_one_txn( txn, @@ -1206,13 +1198,12 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) with self._public_room_id_gen.get_next() as next_id: - yield self.db.runInteraction( + await self.db.runInteraction( "set_room_is_public", set_room_is_public_txn, next_id ) self.hs.get_notifier().on_new_replication_data() - @defer.inlineCallbacks - def set_room_is_public_appservice( + async def set_room_is_public_appservice( self, room_id, appservice_id, network_id, is_public ): """Edit the appservice/network specific public room list. @@ -1287,7 +1278,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) with self._public_room_id_gen.get_next() as next_id: - yield self.db.runInteraction( + await self.db.runInteraction( "set_room_is_public_appservice", set_room_is_public_appservice_txn, next_id, @@ -1327,52 +1318,47 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() - @defer.inlineCallbacks - def block_room(self, room_id, user_id): + async def block_room(self, room_id: str, user_id: str) -> None: """Marks the room as blocked. Can be called multiple times. Args: - room_id (str): Room to block - user_id (str): Who blocked it - - Returns: - Deferred + room_id: Room to block + user_id: Who blocked it """ - yield self.db.simple_upsert( + await self.db.simple_upsert( table="blocked_rooms", keyvalues={"room_id": room_id}, values={}, insertion_values={"user_id": user_id}, desc="block_room", ) - yield self.db.runInteraction( + await self.db.runInteraction( "block_room_invalidation", self._invalidate_cache_and_stream, self.is_room_blocked, (room_id,), ) - @defer.inlineCallbacks - def get_rooms_for_retention_period_in_range( - self, min_ms, max_ms, include_null=False - ): + async def get_rooms_for_retention_period_in_range( + self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False + ) -> Dict[str, dict]: """Retrieves all of the rooms within the given retention range. Optionally includes the rooms which don't have a retention policy. Args: - min_ms (int|None): Duration in milliseconds that define the lower limit of + min_ms: Duration in milliseconds that define the lower limit of the range to handle (exclusive). If None, doesn't set a lower limit. - max_ms (int|None): Duration in milliseconds that define the upper limit of + max_ms: Duration in milliseconds that define the upper limit of the range to handle (inclusive). If None, doesn't set an upper limit. - include_null (bool): Whether to include rooms which retention policy is NULL + include_null: Whether to include rooms which retention policy is NULL in the returned set. Returns: - dict[str, dict]: The rooms within this range, along with their retention - policy. The key is "room_id", and maps to a dict describing the retention - policy associated with this room ID. The keys for this nested dict are - "min_lifetime" (int|None), and "max_lifetime" (int|None). + The rooms within this range, along with their retention + policy. The key is "room_id", and maps to a dict describing the retention + policy associated with this room ID. The keys for this nested dict are + "min_lifetime" (int|None), and "max_lifetime" (int|None). """ def get_rooms_for_retention_period_in_range_txn(txn): @@ -1431,9 +1417,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): return rooms_dict - rooms = yield self.db.runInteraction( + rooms = await self.db.runInteraction( "get_rooms_for_retention_period_in_range", get_rooms_for_retention_period_in_range_txn, ) - defer.returnValue(rooms) + return rooms diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index bb38a04ed..a36069940 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -16,12 +16,12 @@ import collections.abc import logging from collections import namedtuple - -from twisted.internet import defer +from typing import Iterable, Optional, Set from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion +from synapse.events import EventBase from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore @@ -108,28 +108,27 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): create_event = await self.get_create_event_for_room(room_id) return create_event.content.get("room_version", "1") - @defer.inlineCallbacks - def get_room_predecessor(self, room_id): + async def get_room_predecessor(self, room_id: str) -> Optional[dict]: """Get the predecessor of an upgraded room if it exists. Otherwise return None. Args: - room_id (str) + room_id: The room ID. Returns: - Deferred[dict|None]: A dictionary containing the structure of the predecessor - field from the room's create event. The structure is subject to other servers, - but it is expected to be: - * room_id (str): The room ID of the predecessor room - * event_id (str): The ID of the tombstone event in the predecessor room + A dictionary containing the structure of the predecessor + field from the room's create event. The structure is subject to other servers, + but it is expected to be: + * room_id (str): The room ID of the predecessor room + * event_id (str): The ID of the tombstone event in the predecessor room - None if a predecessor key is not found, or is not a dictionary. + None if a predecessor key is not found, or is not a dictionary. Raises: NotFoundError if the given room is unknown """ # Retrieve the room's create event - create_event = yield self.get_create_event_for_room(room_id) + create_event = await self.get_create_event_for_room(room_id) # Retrieve the predecessor key of the create event predecessor = create_event.content.get("predecessor", None) @@ -140,20 +139,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return predecessor - @defer.inlineCallbacks - def get_create_event_for_room(self, room_id): + async def get_create_event_for_room(self, room_id: str) -> EventBase: """Get the create state event for a room. Args: - room_id (str) + room_id: The room ID. Returns: - Deferred[EventBase]: The room creation event. + The room creation event. Raises: NotFoundError if the room is unknown """ - state_ids = yield self.get_current_state_ids(room_id) + state_ids = await self.get_current_state_ids(room_id) create_id = state_ids.get((EventTypes.Create, "")) # If we can't find the create event, assume we've hit a dead end @@ -161,7 +159,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): raise NotFoundError("Unknown room %s" % (room_id,)) # Retrieve the room's create event and return - create_event = yield self.get_event(create_id) + create_event = await self.get_event(create_id) return create_event @cached(max_entries=100000, iterable=True) @@ -237,18 +235,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn ) - @defer.inlineCallbacks - def get_canonical_alias_for_room(self, room_id): + async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]: """Get canonical alias for room, if any Args: - room_id (str) + room_id: The room ID Returns: - Deferred[str|None]: The canonical alias, if any + The canonical alias, if any """ - state = yield self.get_filtered_current_state_ids( + state = await self.get_filtered_current_state_ids( room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]) ) @@ -256,7 +253,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): if not event_id: return - event = yield self.get_event(event_id, allow_none=True) + event = await self.get_event(event_id, allow_none=True) if not event: return @@ -292,19 +289,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return {row["event_id"]: row["state_group"] for row in rows} - @defer.inlineCallbacks - def get_referenced_state_groups(self, state_groups): + async def get_referenced_state_groups( + self, state_groups: Iterable[int] + ) -> Set[int]: """Check if the state groups are referenced by events. Args: - state_groups (Iterable[int]) + state_groups Returns: - Deferred[set[int]]: The subset of state groups that are - referenced. + The subset of state groups that are referenced. """ - rows = yield self.db.simple_select_many_batch( + rows = await self.db.simple_select_many_batch( table="event_to_state_groups", column="state_group", iterable=state_groups, diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py index 380c1ec7d..922400a7c 100644 --- a/synapse/storage/data_stores/main/stats.py +++ b/synapse/storage/data_stores/main/stats.py @@ -16,8 +16,8 @@ import logging from itertools import chain +from typing import Tuple -from twisted.internet import defer from twisted.internet.defer import DeferredLock from synapse.api.constants import EventTypes, Membership @@ -97,13 +97,12 @@ class StatsStore(StateDeltasStore): """ return (ts // self.stats_bucket_size) * self.stats_bucket_size - @defer.inlineCallbacks - def _populate_stats_process_users(self, progress, batch_size): + async def _populate_stats_process_users(self, progress, batch_size): """ This is a background update which regenerates statistics for users. """ if not self.stats_enabled: - yield self.db.updates._end_background_update("populate_stats_process_users") + await self.db.updates._end_background_update("populate_stats_process_users") return 1 last_user_id = progress.get("last_user_id", "") @@ -118,20 +117,20 @@ class StatsStore(StateDeltasStore): txn.execute(sql, (last_user_id, batch_size)) return [r for r, in txn] - users_to_work_on = yield self.db.runInteraction( + users_to_work_on = await self.db.runInteraction( "_populate_stats_process_users", _get_next_batch ) # No more rooms -- complete the transaction. if not users_to_work_on: - yield self.db.updates._end_background_update("populate_stats_process_users") + await self.db.updates._end_background_update("populate_stats_process_users") return 1 for user_id in users_to_work_on: - yield self._calculate_and_set_initial_state_for_user(user_id) + await self._calculate_and_set_initial_state_for_user(user_id) progress["last_user_id"] = user_id - yield self.db.runInteraction( + await self.db.runInteraction( "populate_stats_process_users", self.db.updates._background_update_progress_txn, "populate_stats_process_users", @@ -140,13 +139,12 @@ class StatsStore(StateDeltasStore): return len(users_to_work_on) - @defer.inlineCallbacks - def _populate_stats_process_rooms(self, progress, batch_size): + async def _populate_stats_process_rooms(self, progress, batch_size): """ This is a background update which regenerates statistics for rooms. """ if not self.stats_enabled: - yield self.db.updates._end_background_update("populate_stats_process_rooms") + await self.db.updates._end_background_update("populate_stats_process_rooms") return 1 last_room_id = progress.get("last_room_id", "") @@ -161,20 +159,20 @@ class StatsStore(StateDeltasStore): txn.execute(sql, (last_room_id, batch_size)) return [r for r, in txn] - rooms_to_work_on = yield self.db.runInteraction( + rooms_to_work_on = await self.db.runInteraction( "populate_stats_rooms_get_batch", _get_next_batch ) # No more rooms -- complete the transaction. if not rooms_to_work_on: - yield self.db.updates._end_background_update("populate_stats_process_rooms") + await self.db.updates._end_background_update("populate_stats_process_rooms") return 1 for room_id in rooms_to_work_on: - yield self._calculate_and_set_initial_state_for_room(room_id) + await self._calculate_and_set_initial_state_for_room(room_id) progress["last_room_id"] = room_id - yield self.db.runInteraction( + await self.db.runInteraction( "_populate_stats_process_rooms", self.db.updates._background_update_progress_txn, "populate_stats_process_rooms", @@ -696,16 +694,16 @@ class StatsStore(StateDeltasStore): return room_deltas, user_deltas - @defer.inlineCallbacks - def _calculate_and_set_initial_state_for_room(self, room_id): + async def _calculate_and_set_initial_state_for_room( + self, room_id: str + ) -> Tuple[dict, dict, int]: """Calculate and insert an entry into room_stats_current. Args: - room_id (str) + room_id: The room ID under calculation. Returns: - Deferred[tuple[dict, dict, int]]: A tuple of room state, membership - counts and stream position. + A tuple of room state, membership counts and stream position. """ def _fetch_current_state_stats(txn): @@ -767,11 +765,11 @@ class StatsStore(StateDeltasStore): current_state_events_count, users_in_room, pos, - ) = yield self.db.runInteraction( + ) = await self.db.runInteraction( "get_initial_state_for_room", _fetch_current_state_stats ) - state_event_map = yield self.get_events(event_ids, get_prev_content=False) + state_event_map = await self.get_events(event_ids, get_prev_content=False) room_state = { "join_rules": None, @@ -806,11 +804,11 @@ class StatsStore(StateDeltasStore): event.content.get("m.federate", True) is True ) - yield self.update_room_state(room_id, room_state) + await self.update_room_state(room_id, room_state) local_users_in_room = [u for u in users_in_room if self.hs.is_mine_id(u)] - yield self.update_stats_delta( + await self.update_stats_delta( ts=self.clock.time_msec(), stats_type="room", stats_id=room_id, @@ -826,8 +824,7 @@ class StatsStore(StateDeltasStore): }, ) - @defer.inlineCallbacks - def _calculate_and_set_initial_state_for_user(self, user_id): + async def _calculate_and_set_initial_state_for_user(self, user_id): def _calculate_and_set_initial_state_for_user_txn(txn): pos = self._get_max_stream_id_in_current_state_deltas_txn(txn) @@ -842,12 +839,12 @@ class StatsStore(StateDeltasStore): (count,) = txn.fetchone() return count, pos - joined_rooms, pos = yield self.db.runInteraction( + joined_rooms, pos = await self.db.runInteraction( "calculate_and_set_initial_state_for_user", _calculate_and_set_initial_state_for_user_txn, ) - yield self.update_stats_delta( + await self.update_stats_delta( ts=self.clock.time_msec(), stats_type="user", stats_id=user_id, diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py index 128c09a2c..7dada7f75 100644 --- a/synapse/storage/data_stores/state/store.py +++ b/synapse/storage/data_stores/state/store.py @@ -139,10 +139,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): "get_state_group_delta", _get_state_group_delta_txn ) - @defer.inlineCallbacks - def _get_state_groups_from_groups( + async def _get_state_groups_from_groups( self, groups: List[int], state_filter: StateFilter - ): + ) -> Dict[int, StateMap[str]]: """Returns the state groups for a given set of groups from the database, filtering on types of state events. @@ -151,13 +150,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): state_filter: The state filter used to fetch state from the database. Returns: - Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. + Dict of state group to state map. """ results = {} chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] for chunk in chunks: - res = yield self.db.runInteraction( + res = await self.db.runInteraction( "_get_state_groups_from_groups", self._get_state_groups_from_groups_txn, chunk, @@ -206,10 +205,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return state_filter.filter_state(state_dict_ids), not missing_types - @defer.inlineCallbacks - def _get_state_for_groups( + async def _get_state_for_groups( self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() - ): + ) -> Dict[int, StateMap[str]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key @@ -219,7 +217,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): state_filter: The state filter used to fetch state from the database. Returns: - Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. + Dict of state group to state map. """ member_filter, non_member_filter = state_filter.get_member_split() @@ -228,14 +226,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ( non_member_state, incomplete_groups_nm, - ) = yield self._get_state_for_groups_using_cache( + ) = self._get_state_for_groups_using_cache( groups, self._state_group_cache, state_filter=non_member_filter ) - ( - member_state, - incomplete_groups_m, - ) = yield self._get_state_for_groups_using_cache( + (member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache( groups, self._state_group_members_cache, state_filter=member_filter ) @@ -256,7 +251,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): # Help the cache hit ratio by expanding the filter a bit db_state_filter = state_filter.return_expanded() - group_to_state_dict = yield self._get_state_groups_from_groups( + group_to_state_dict = await self._get_state_groups_from_groups( list(incomplete_groups), state_filter=db_state_filter ) @@ -576,19 +571,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ((sg,) for sg in state_groups_to_delete), ) - @defer.inlineCallbacks - def get_previous_state_groups(self, state_groups): + async def get_previous_state_groups( + self, state_groups: Iterable[int] + ) -> Dict[int, int]: """Fetch the previous groups of the given state groups. Args: - state_groups (Iterable[int]) + state_groups Returns: - Deferred[dict[int, int]]: mapping from state group to previous - state group. + A mapping from state group to previous state group. """ - rows = yield self.db.simple_select_many_batch( + rows = await self.db.simple_select_many_batch( table="state_group_edges", column="prev_state_group", iterable=state_groups, diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 49ee9c9a7..534883361 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Dict, Iterable, List, Optional, Set, Tuple, TypeVar +from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar import attr @@ -419,7 +419,7 @@ class StateGroupStorage(object): def _get_state_groups_from_groups( self, groups: List[int], state_filter: StateFilter - ): + ) -> Awaitable[Dict[int, StateMap[str]]]: """Returns the state groups for a given set of groups, filtering on types of state events. @@ -429,7 +429,7 @@ class StateGroupStorage(object): from the database. Returns: - Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. + Dict of state group to state map. """ return self.stores.state._get_state_groups_from_groups(groups, state_filter) @@ -532,7 +532,7 @@ class StateGroupStorage(object): def _get_state_for_groups( self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() - ): + ) -> Awaitable[Dict[int, StateMap[str]]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key @@ -540,8 +540,9 @@ class StateGroupStorage(object): groups: list of state groups for which we want to get the state. state_filter: The state filter used to fetch state. from the database. + Returns: - Deferred[dict[int, StateMap[str]]]: Dict of state group to state map. + Dict of state group to state map. """ return self.stores.state._get_state_for_groups(groups, state_filter) diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 43dbeb42c..2b1580fee 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -39,14 +39,18 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): @defer.inlineCallbacks def test_get_unread_push_actions_for_user_in_range_for_http(self): - yield self.store.get_unread_push_actions_for_user_in_range_for_http( - USER_ID, 0, 1000, 20 + yield defer.ensureDeferred( + self.store.get_unread_push_actions_for_user_in_range_for_http( + USER_ID, 0, 1000, 20 + ) ) @defer.inlineCallbacks def test_get_unread_push_actions_for_user_in_range_for_email(self): - yield self.store.get_unread_push_actions_for_user_in_range_for_email( - USER_ID, 0, 1000, 20 + yield defer.ensureDeferred( + self.store.get_unread_push_actions_for_user_in_range_for_email( + USER_ID, 0, 1000, 20 + ) ) @defer.inlineCallbacks diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index a5f250d47..d07b985a8 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -37,11 +37,13 @@ class RoomStoreTestCase(unittest.TestCase): self.alias = RoomAlias.from_string("#a-room-name:test") self.u_creator = UserID.from_string("@creator:test") - yield self.store.store_room( - self.room.to_string(), - room_creator_user_id=self.u_creator.to_string(), - is_public=True, - room_version=RoomVersions.V1, + yield defer.ensureDeferred( + self.store.store_room( + self.room.to_string(), + room_creator_user_id=self.u_creator.to_string(), + is_public=True, + room_version=RoomVersions.V1, + ) ) @defer.inlineCallbacks @@ -88,11 +90,13 @@ class RoomEventsStoreTestCase(unittest.TestCase): self.room = RoomID.from_string("!abcde:test") - yield self.store.store_room( - self.room.to_string(), - room_creator_user_id="@creator:text", - is_public=True, - room_version=RoomVersions.V1, + yield defer.ensureDeferred( + self.store.store_room( + self.room.to_string(), + room_creator_user_id="@creator:text", + is_public=True, + room_version=RoomVersions.V1, + ) ) @defer.inlineCallbacks diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 6a48b9d3b..8bd12fa84 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -44,11 +44,13 @@ class StateStoreTestCase(tests.unittest.TestCase): self.room = RoomID.from_string("!abc123:test") - yield self.store.store_room( - self.room.to_string(), - room_creator_user_id="@creator:text", - is_public=True, - room_version=RoomVersions.V1, + yield defer.ensureDeferred( + self.store.store_room( + self.room.to_string(), + room_creator_user_id="@creator:text", + is_public=True, + room_version=RoomVersions.V1, + ) ) @defer.inlineCallbacks From 4cce8ef74ec233d8e49361bee705f2e38de2e11e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 30 Jul 2020 07:27:39 -0400 Subject: [PATCH 21/26] Convert appservice to async. (#7973) --- changelog.d/7973.misc | 1 + synapse/appservice/__init__.py | 31 +++++----- synapse/appservice/api.py | 21 +++---- synapse/appservice/scheduler.py | 49 +++++++--------- synapse/handlers/appservice.py | 10 ++-- tests/appservice/test_appservice.py | 89 ++++++++++++++++++++--------- tests/appservice/test_scheduler.py | 19 +++--- tests/handlers/test_appservice.py | 5 +- 8 files changed, 122 insertions(+), 103 deletions(-) create mode 100644 changelog.d/7973.misc diff --git a/changelog.d/7973.misc b/changelog.d/7973.misc new file mode 100644 index 000000000..dfe4c0317 --- /dev/null +++ b/changelog.d/7973.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index 032325647..1ffdc1ed9 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -15,11 +15,9 @@ import logging import re -from twisted.internet import defer - from synapse.api.constants import EventTypes from synapse.types import GroupID, get_domain_from_id -from synapse.util.caches.descriptors import cachedInlineCallbacks +from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -43,7 +41,7 @@ class AppServiceTransaction(object): Args: as_api(ApplicationServiceApi): The API to use to send. Returns: - A Deferred which resolves to True if the transaction was sent. + An Awaitable which resolves to True if the transaction was sent. """ return as_api.push_bulk( service=self.service, events=self.events, txn_id=self.id @@ -172,8 +170,7 @@ class ApplicationService(object): return regex_obj["exclusive"] return False - @defer.inlineCallbacks - def _matches_user(self, event, store): + async def _matches_user(self, event, store): if not event: return False @@ -188,12 +185,12 @@ class ApplicationService(object): if not store: return False - does_match = yield self._matches_user_in_member_list(event.room_id, store) + does_match = await self._matches_user_in_member_list(event.room_id, store) return does_match - @cachedInlineCallbacks(num_args=1, cache_context=True) - def _matches_user_in_member_list(self, room_id, store, cache_context): - member_list = yield store.get_users_in_room( + @cached(num_args=1, cache_context=True) + async def _matches_user_in_member_list(self, room_id, store, cache_context): + member_list = await store.get_users_in_room( room_id, on_invalidate=cache_context.invalidate ) @@ -208,35 +205,33 @@ class ApplicationService(object): return self.is_interested_in_room(event.room_id) return False - @defer.inlineCallbacks - def _matches_aliases(self, event, store): + async def _matches_aliases(self, event, store): if not store or not event: return False - alias_list = yield store.get_aliases_for_room(event.room_id) + alias_list = await store.get_aliases_for_room(event.room_id) for alias in alias_list: if self.is_interested_in_alias(alias): return True return False - @defer.inlineCallbacks - def is_interested(self, event, store=None): + async def is_interested(self, event, store=None) -> bool: """Check if this service is interested in this event. Args: event(Event): The event to check. store(DataStore) Returns: - bool: True if this service would like to know about this event. + True if this service would like to know about this event. """ # Do cheap checks first if self._matches_room_id(event): return True - if (yield self._matches_aliases(event, store)): + if await self._matches_aliases(event, store): return True - if (yield self._matches_user(event, store)): + if await self._matches_user(event, store): return True return False diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 1e0e4d497..db578bda7 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -93,13 +93,12 @@ class ApplicationServiceApi(SimpleHttpClient): hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS ) - @defer.inlineCallbacks - def query_user(self, service, user_id): + async def query_user(self, service, user_id): if service.url is None: return False uri = service.url + ("/users/%s" % urllib.parse.quote(user_id)) try: - response = yield self.get_json(uri, {"access_token": service.hs_token}) + response = await self.get_json(uri, {"access_token": service.hs_token}) if response is not None: # just an empty json object return True except CodeMessageException as e: @@ -110,14 +109,12 @@ class ApplicationServiceApi(SimpleHttpClient): logger.warning("query_user to %s threw exception %s", uri, ex) return False - @defer.inlineCallbacks - def query_alias(self, service, alias): + async def query_alias(self, service, alias): if service.url is None: return False uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias)) - response = None try: - response = yield self.get_json(uri, {"access_token": service.hs_token}) + response = await self.get_json(uri, {"access_token": service.hs_token}) if response is not None: # just an empty json object return True except CodeMessageException as e: @@ -128,8 +125,7 @@ class ApplicationServiceApi(SimpleHttpClient): logger.warning("query_alias to %s threw exception %s", uri, ex) return False - @defer.inlineCallbacks - def query_3pe(self, service, kind, protocol, fields): + async def query_3pe(self, service, kind, protocol, fields): if kind == ThirdPartyEntityKind.USER: required_field = "userid" elif kind == ThirdPartyEntityKind.LOCATION: @@ -146,7 +142,7 @@ class ApplicationServiceApi(SimpleHttpClient): urllib.parse.quote(protocol), ) try: - response = yield self.get_json(uri, fields) + response = await self.get_json(uri, fields) if not isinstance(response, list): logger.warning( "query_3pe to %s returned an invalid response %r", uri, response @@ -202,8 +198,7 @@ class ApplicationServiceApi(SimpleHttpClient): key = (service.id, protocol) return self.protocol_meta_cache.wrap(key, _get) - @defer.inlineCallbacks - def push_bulk(self, service, events, txn_id=None): + async def push_bulk(self, service, events, txn_id=None): if service.url is None: return True @@ -218,7 +213,7 @@ class ApplicationServiceApi(SimpleHttpClient): uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id)) try: - yield self.put_json( + await self.put_json( uri=uri, json_body={"events": events}, args={"access_token": service.hs_token}, diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 9998f822f..d5204b131 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -50,8 +50,6 @@ components. """ import logging -from twisted.internet import defer - from synapse.appservice import ApplicationServiceState from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process @@ -73,12 +71,11 @@ class ApplicationServiceScheduler(object): self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api) self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock) - @defer.inlineCallbacks - def start(self): + async def start(self): logger.info("Starting appservice scheduler") # check for any DOWN ASes and start recoverers for them. - services = yield self.store.get_appservices_by_state( + services = await self.store.get_appservices_by_state( ApplicationServiceState.DOWN ) @@ -117,8 +114,7 @@ class _ServiceQueuer(object): "as-sender-%s" % (service.id,), self._send_request, service ) - @defer.inlineCallbacks - def _send_request(self, service): + async def _send_request(self, service): # sanity-check: we shouldn't get here if this service already has a sender # running. assert service.id not in self.requests_in_flight @@ -130,7 +126,7 @@ class _ServiceQueuer(object): if not events: return try: - yield self.txn_ctrl.send(service, events) + await self.txn_ctrl.send(service, events) except Exception: logger.exception("AS request failed") finally: @@ -162,36 +158,33 @@ class _TransactionController(object): # for UTs self.RECOVERER_CLASS = _Recoverer - @defer.inlineCallbacks - def send(self, service, events): + async def send(self, service, events): try: - txn = yield self.store.create_appservice_txn(service=service, events=events) - service_is_up = yield self._is_service_up(service) + txn = await self.store.create_appservice_txn(service=service, events=events) + service_is_up = await self._is_service_up(service) if service_is_up: - sent = yield txn.send(self.as_api) + sent = await txn.send(self.as_api) if sent: - yield txn.complete(self.store) + await txn.complete(self.store) else: run_in_background(self._on_txn_fail, service) except Exception: logger.exception("Error creating appservice transaction") run_in_background(self._on_txn_fail, service) - @defer.inlineCallbacks - def on_recovered(self, recoverer): + async def on_recovered(self, recoverer): logger.info( "Successfully recovered application service AS ID %s", recoverer.service.id ) self.recoverers.pop(recoverer.service.id) logger.info("Remaining active recoverers: %s", len(self.recoverers)) - yield self.store.set_appservice_state( + await self.store.set_appservice_state( recoverer.service, ApplicationServiceState.UP ) - @defer.inlineCallbacks - def _on_txn_fail(self, service): + async def _on_txn_fail(self, service): try: - yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN) + await self.store.set_appservice_state(service, ApplicationServiceState.DOWN) self.start_recoverer(service) except Exception: logger.exception("Error starting AS recoverer") @@ -211,9 +204,8 @@ class _TransactionController(object): recoverer.recover() logger.info("Now %i active recoverers", len(self.recoverers)) - @defer.inlineCallbacks - def _is_service_up(self, service): - state = yield self.store.get_appservice_state(service) + async def _is_service_up(self, service): + state = await self.store.get_appservice_state(service) return state == ApplicationServiceState.UP or state is None @@ -254,25 +246,24 @@ class _Recoverer(object): self.backoff_counter += 1 self.recover() - @defer.inlineCallbacks - def retry(self): + async def retry(self): logger.info("Starting retries on %s", self.service.id) try: while True: - txn = yield self.store.get_oldest_unsent_txn(self.service) + txn = await self.store.get_oldest_unsent_txn(self.service) if not txn: # nothing left: we're done! - self.callback(self) + await self.callback(self) return logger.info( "Retrying transaction %s for AS ID %s", txn.id, txn.service.id ) - sent = yield txn.send(self.as_api) + sent = await txn.send(self.as_api) if not sent: break - yield txn.complete(self.store) + await txn.complete(self.store) # reset the backoff counter and then process the next transaction self.backoff_counter = 1 diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 92d4c6e16..fbc56c351 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -27,7 +27,6 @@ from synapse.metrics import ( event_processing_loop_room_count, ) from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.util import log_failure from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -100,10 +99,11 @@ class ApplicationServicesHandler(object): if not self.started_scheduler: - def start_scheduler(): - return self.scheduler.start().addErrback( - log_failure, "Application Services Failure" - ) + async def start_scheduler(): + try: + return self.scheduler.start() + except Exception: + logger.error("Application Services Failure") run_as_background_process("as_scheduler", start_scheduler) self.started_scheduler = True diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index 4003869ed..236b608d5 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -50,13 +50,17 @@ class ApplicationServiceTestCase(unittest.TestCase): def test_regex_user_id_prefix_match(self): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" - self.assertTrue((yield self.service.is_interested(self.event))) + self.assertTrue( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_regex_user_id_prefix_no_match(self): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@someone_else:matrix.org" - self.assertFalse((yield self.service.is_interested(self.event))) + self.assertFalse( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_regex_room_member_is_checked(self): @@ -64,7 +68,9 @@ class ApplicationServiceTestCase(unittest.TestCase): self.event.sender = "@someone_else:matrix.org" self.event.type = "m.room.member" self.event.state_key = "@irc_foobar:matrix.org" - self.assertTrue((yield self.service.is_interested(self.event))) + self.assertTrue( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_regex_room_id_match(self): @@ -72,7 +78,9 @@ class ApplicationServiceTestCase(unittest.TestCase): _regex("!some_prefix.*some_suffix:matrix.org") ) self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org" - self.assertTrue((yield self.service.is_interested(self.event))) + self.assertTrue( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_regex_room_id_no_match(self): @@ -80,19 +88,26 @@ class ApplicationServiceTestCase(unittest.TestCase): _regex("!some_prefix.*some_suffix:matrix.org") ) self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org" - self.assertFalse((yield self.service.is_interested(self.event))) + self.assertFalse( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_regex_alias_match(self): self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) - self.store.get_aliases_for_room.return_value = [ - "#irc_foobar:matrix.org", - "#athing:matrix.org", - ] - self.store.get_users_in_room.return_value = [] - self.assertTrue((yield self.service.is_interested(self.event, self.store))) + self.store.get_aliases_for_room.return_value = defer.succeed( + ["#irc_foobar:matrix.org", "#athing:matrix.org"] + ) + self.store.get_users_in_room.return_value = defer.succeed([]) + self.assertTrue( + ( + yield defer.ensureDeferred( + self.service.is_interested(self.event, self.store) + ) + ) + ) def test_non_exclusive_alias(self): self.service.namespaces[ApplicationService.NS_ALIASES].append( @@ -135,12 +150,17 @@ class ApplicationServiceTestCase(unittest.TestCase): self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) - self.store.get_aliases_for_room.return_value = [ - "#xmpp_foobar:matrix.org", - "#athing:matrix.org", - ] - self.store.get_users_in_room.return_value = [] - self.assertFalse((yield self.service.is_interested(self.event, self.store))) + self.store.get_aliases_for_room.return_value = defer.succeed( + ["#xmpp_foobar:matrix.org", "#athing:matrix.org"] + ) + self.store.get_users_in_room.return_value = defer.succeed([]) + self.assertFalse( + ( + yield defer.ensureDeferred( + self.service.is_interested(self.event, self.store) + ) + ) + ) @defer.inlineCallbacks def test_regex_multiple_matches(self): @@ -149,9 +169,17 @@ class ApplicationServiceTestCase(unittest.TestCase): ) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" - self.store.get_aliases_for_room.return_value = ["#irc_barfoo:matrix.org"] - self.store.get_users_in_room.return_value = [] - self.assertTrue((yield self.service.is_interested(self.event, self.store))) + self.store.get_aliases_for_room.return_value = defer.succeed( + ["#irc_barfoo:matrix.org"] + ) + self.store.get_users_in_room.return_value = defer.succeed([]) + self.assertTrue( + ( + yield defer.ensureDeferred( + self.service.is_interested(self.event, self.store) + ) + ) + ) @defer.inlineCallbacks def test_interested_in_self(self): @@ -161,19 +189,24 @@ class ApplicationServiceTestCase(unittest.TestCase): self.event.type = "m.room.member" self.event.content = {"membership": "invite"} self.event.state_key = self.service.sender - self.assertTrue((yield self.service.is_interested(self.event))) + self.assertTrue( + (yield defer.ensureDeferred(self.service.is_interested(self.event))) + ) @defer.inlineCallbacks def test_member_list_match(self): self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) - self.store.get_users_in_room.return_value = [ - "@alice:here", - "@irc_fo:here", # AS user - "@bob:here", - ] - self.store.get_aliases_for_room.return_value = [] + # Note that @irc_fo:here is the AS user. + self.store.get_users_in_room.return_value = defer.succeed( + ["@alice:here", "@irc_fo:here", "@bob:here"] + ) + self.store.get_aliases_for_room.return_value = defer.succeed([]) self.event.sender = "@xmpp_foobar:matrix.org" self.assertTrue( - (yield self.service.is_interested(event=self.event, store=self.store)) + ( + yield defer.ensureDeferred( + self.service.is_interested(event=self.event, store=self.store) + ) + ) ) diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 52f89d3f8..68a4caabb 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -25,6 +25,7 @@ from synapse.appservice.scheduler import ( from synapse.logging.context import make_deferred_yieldable from tests import unittest +from tests.test_utils import make_awaitable from ..utils import MockClock @@ -52,11 +53,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.store.get_appservice_state = Mock( return_value=defer.succeed(ApplicationServiceState.UP) ) - txn.send = Mock(return_value=defer.succeed(True)) + txn.send = Mock(return_value=make_awaitable(True)) self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) # actual call - self.txnctrl.send(service, events) + self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( service=service, events=events # txn made and saved @@ -77,7 +78,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) # actual call - self.txnctrl.send(service, events) + self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( service=service, events=events # txn made and saved @@ -98,11 +99,11 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): return_value=defer.succeed(ApplicationServiceState.UP) ) self.store.set_appservice_state = Mock(return_value=defer.succeed(True)) - txn.send = Mock(return_value=defer.succeed(False)) # fails to send + txn.send = Mock(return_value=make_awaitable(False)) # fails to send self.store.create_appservice_txn = Mock(return_value=defer.succeed(txn)) # actual call - self.txnctrl.send(service, events) + self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.store.create_appservice_txn.assert_called_once_with( service=service, events=events @@ -144,7 +145,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.recoverer.recover() # shouldn't have called anything prior to waiting for exp backoff self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count) - txn.send = Mock(return_value=True) + txn.send = Mock(return_value=make_awaitable(True)) + txn.complete.return_value = make_awaitable(None) # wait for exp backoff self.clock.advance_time(2) self.assertEquals(1, txn.send.call_count) @@ -169,7 +171,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.recoverer.recover() self.assertEquals(0, self.store.get_oldest_unsent_txn.call_count) - txn.send = Mock(return_value=False) + txn.send = Mock(return_value=make_awaitable(False)) + txn.complete.return_value = make_awaitable(None) self.clock.advance_time(2) self.assertEquals(1, txn.send.call_count) self.assertEquals(0, txn.complete.call_count) @@ -182,7 +185,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.assertEquals(3, txn.send.call_count) self.assertEquals(0, txn.complete.call_count) self.assertEquals(0, self.callback.call_count) - txn.send = Mock(return_value=True) # successfully send the txn + txn.send = Mock(return_value=make_awaitable(True)) # successfully send the txn pop_txn = True # returns the txn the first time, then no more. self.clock.advance_time(16) self.assertEquals(1, txn.send.call_count) # new mock reset call count diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index ebabe9a7d..628f7d8db 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -19,6 +19,7 @@ from twisted.internet import defer from synapse.handlers.appservice import ApplicationServicesHandler +from tests.test_utils import make_awaitable from tests.utils import MockClock from .. import unittest @@ -117,7 +118,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): self._mkservice_alias(is_interested_in_alias=False), ] - self.mock_as_api.query_alias.return_value = defer.succeed(True) + self.mock_as_api.query_alias.return_value = make_awaitable(True) self.mock_store.get_app_services.return_value = services self.mock_store.get_association_from_room_alias.return_value = defer.succeed( Mock(room_id=room_id, servers=servers) @@ -135,7 +136,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): def _mkservice(self, is_interested): service = Mock() - service.is_interested.return_value = defer.succeed(is_interested) + service.is_interested.return_value = make_awaitable(is_interested) service.token = "mock_service_token" service.url = "mock_service_url" return service From c978f6c4515a631f289aedb1844d8579b9334aaa Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 30 Jul 2020 08:01:33 -0400 Subject: [PATCH 22/26] Convert federation client to async/await. (#7975) --- changelog.d/7975.misc | 1 + contrib/cmdclient/console.py | 16 ++-- synapse/crypto/keyring.py | 60 ++++++------ synapse/federation/federation_client.py | 8 +- synapse/federation/sender/__init__.py | 19 ++-- synapse/federation/transport/client.py | 96 ++++++++----------- synapse/handlers/groups_local.py | 35 +++---- synapse/http/matrixfederationclient.py | 72 +++++++------- tests/crypto/test_keyring.py | 11 +-- tests/federation/test_complexity.py | 21 ++-- tests/federation/test_federation_sender.py | 10 +- tests/handlers/test_directory.py | 5 +- tests/handlers/test_profile.py | 3 +- tests/http/test_fedclient.py | 50 +++++++--- .../test_federation_sender_shard.py | 13 ++- tests/rest/admin/test_admin.py | 4 +- tests/rest/key/v2/test_remote_key_resource.py | 4 +- tests/test_federation.py | 2 +- 18 files changed, 209 insertions(+), 221 deletions(-) create mode 100644 changelog.d/7975.misc diff --git a/changelog.d/7975.misc b/changelog.d/7975.misc new file mode 100644 index 000000000..dfe4c0317 --- /dev/null +++ b/changelog.d/7975.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py index 77422f5e5..dfc1d294d 100755 --- a/contrib/cmdclient/console.py +++ b/contrib/cmdclient/console.py @@ -609,13 +609,15 @@ class SynapseCmd(cmd.Cmd): @defer.inlineCallbacks def _do_event_stream(self, timeout): - res = yield self.http_client.get_json( - self._url() + "/events", - { - "access_token": self._tok(), - "timeout": str(timeout), - "from": self.event_stream_token, - }, + res = yield defer.ensureDeferred( + self.http_client.get_json( + self._url() + "/events", + { + "access_token": self._tok(), + "timeout": str(timeout), + "from": self.event_stream_token, + }, + ) ) print(json.dumps(res, indent=4)) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index dbfc3e897..443cde0b6 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -632,18 +632,20 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): ) try: - query_response = yield self.client.post_json( - destination=perspective_name, - path="/_matrix/key/v2/query", - data={ - "server_keys": { - server_name: { - key_id: {"minimum_valid_until_ts": min_valid_ts} - for key_id, min_valid_ts in server_keys.items() + query_response = yield defer.ensureDeferred( + self.client.post_json( + destination=perspective_name, + path="/_matrix/key/v2/query", + data={ + "server_keys": { + server_name: { + key_id: {"minimum_valid_until_ts": min_valid_ts} + for key_id, min_valid_ts in server_keys.items() + } + for server_name, server_keys in keys_to_fetch.items() } - for server_name, server_keys in keys_to_fetch.items() - } - }, + }, + ) ) except (NotRetryingDestination, RequestSendFailed) as e: # these both have str() representations which we can't really improve upon @@ -792,23 +794,25 @@ class ServerKeyFetcher(BaseV2KeyFetcher): time_now_ms = self.clock.time_msec() try: - response = yield self.client.get_json( - destination=server_name, - path="/_matrix/key/v2/server/" - + urllib.parse.quote(requested_key_id), - ignore_backoff=True, - # we only give the remote server 10s to respond. It should be an - # easy request to handle, so if it doesn't reply within 10s, it's - # probably not going to. - # - # Furthermore, when we are acting as a notary server, we cannot - # wait all day for all of the origin servers, as the requesting - # server will otherwise time out before we can respond. - # - # (Note that get_json may make 4 attempts, so this can still take - # almost 45 seconds to fetch the headers, plus up to another 60s to - # read the response). - timeout=10000, + response = yield defer.ensureDeferred( + self.client.get_json( + destination=server_name, + path="/_matrix/key/v2/server/" + + urllib.parse.quote(requested_key_id), + ignore_backoff=True, + # we only give the remote server 10s to respond. It should be an + # easy request to handle, so if it doesn't reply within 10s, it's + # probably not going to. + # + # Furthermore, when we are acting as a notary server, we cannot + # wait all day for all of the origin servers, as the requesting + # server will otherwise time out before we can respond. + # + # (Note that get_json may make 4 attempts, so this can still take + # almost 45 seconds to fetch the headers, plus up to another 60s to + # read the response). + timeout=10000, + ) ) except (NotRetryingDestination, RequestSendFailed) as e: # these both have str() representations which we can't really improve diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 994e6c8d5..38ac7ec69 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -135,7 +135,7 @@ class FederationClient(FederationBase): and try the request anyway. Returns: - a Deferred which will eventually yield a JSON object from the + a Awaitable which will eventually yield a JSON object from the response """ sent_queries_counter.labels(query_type).inc() @@ -157,7 +157,7 @@ class FederationClient(FederationBase): content (dict): The query content. Returns: - a Deferred which will eventually yield a JSON object from the + an Awaitable which will eventually yield a JSON object from the response """ sent_queries_counter.labels("client_device_keys").inc() @@ -180,7 +180,7 @@ class FederationClient(FederationBase): content (dict): The query content. Returns: - a Deferred which will eventually yield a JSON object from the + an Awaitable which will eventually yield a JSON object from the response """ sent_queries_counter.labels("client_one_time_keys").inc() @@ -900,7 +900,7 @@ class FederationClient(FederationBase): party instance Returns: - Deferred[Dict[str, Any]]: The response from the remote server, or None if + Awaitable[Dict[str, Any]]: The response from the remote server, or None if `remote_server` is the same as the local server_name Raises: diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index ba4ddd237..8f549ae6e 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -288,8 +288,7 @@ class FederationSender(object): for destination in destinations: self._get_per_destination_queue(destination).send_pdu(pdu, order) - @defer.inlineCallbacks - def send_read_receipt(self, receipt: ReadReceipt): + async def send_read_receipt(self, receipt: ReadReceipt) -> None: """Send a RR to any other servers in the room Args: @@ -330,9 +329,7 @@ class FederationSender(object): room_id = receipt.room_id # Work out which remote servers should be poked and poke them. - domains = yield defer.ensureDeferred( - self.state.get_current_hosts_in_room(room_id) - ) + domains = await self.state.get_current_hosts_in_room(room_id) domains = [ d for d in domains @@ -387,8 +384,7 @@ class FederationSender(object): queue.flush_read_receipts_for_room(room_id) @preserve_fn # the caller should not yield on this - @defer.inlineCallbacks - def send_presence(self, states: List[UserPresenceState]): + async def send_presence(self, states: List[UserPresenceState]): """Send the new presence states to the appropriate destinations. This actually queues up the presence states ready for sending and @@ -423,7 +419,7 @@ class FederationSender(object): if not states_map: break - yield self._process_presence_inner(list(states_map.values())) + await self._process_presence_inner(list(states_map.values())) except Exception: logger.exception("Error sending presence states to servers") finally: @@ -450,14 +446,11 @@ class FederationSender(object): self._get_per_destination_queue(destination).send_presence(states) @measure_func("txnqueue._process_presence") - @defer.inlineCallbacks - def _process_presence_inner(self, states: List[UserPresenceState]): + async def _process_presence_inner(self, states: List[UserPresenceState]): """Given a list of states populate self.pending_presence_by_dest and poke to send a new transaction to each destination """ - hosts_and_states = yield defer.ensureDeferred( - get_interested_remotes(self.store, states, self.state) - ) + hosts_and_states = await get_interested_remotes(self.store, states, self.state) for destinations, states in hosts_and_states: for destination in destinations: diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index cfdf23d36..9ea821dbb 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -18,8 +18,6 @@ import logging import urllib from typing import Any, Dict, Optional -from twisted.internet import defer - from synapse.api.constants import Membership from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.api.urls import ( @@ -51,7 +49,7 @@ class TransportLayerClient(object): event_id (str): The event we want the context at. Returns: - Deferred: Results in a dict received from the remote homeserver. + Awaitable: Results in a dict received from the remote homeserver. """ logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id) @@ -75,7 +73,7 @@ class TransportLayerClient(object): giving up. None indicates no timeout. Returns: - Deferred: Results in a dict received from the remote homeserver. + Awaitable: Results in a dict received from the remote homeserver. """ logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id) @@ -96,7 +94,7 @@ class TransportLayerClient(object): limit (int) Returns: - Deferred: Results in a dict received from the remote homeserver. + Awaitable: Results in a dict received from the remote homeserver. """ logger.debug( "backfill dest=%s, room_id=%s, event_tuples=%r, limit=%s", @@ -118,16 +116,15 @@ class TransportLayerClient(object): destination, path=path, args=args, try_trailing_slash_on_400=True ) - @defer.inlineCallbacks @log_function - def send_transaction(self, transaction, json_data_callback=None): + async def send_transaction(self, transaction, json_data_callback=None): """ Sends the given Transaction to its destination Args: transaction (Transaction) Returns: - Deferred: Succeeds when we get a 2xx HTTP response. The result + Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Fails with ``HTTPRequestException`` if we get an HTTP response @@ -154,7 +151,7 @@ class TransportLayerClient(object): path = _create_v1_path("/send/%s", transaction.transaction_id) - response = yield self.client.put_json( + response = await self.client.put_json( transaction.destination, path=path, data=json_data, @@ -166,14 +163,13 @@ class TransportLayerClient(object): return response - @defer.inlineCallbacks @log_function - def make_query( + async def make_query( self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False ): path = _create_v1_path("/query/%s", query_type) - content = yield self.client.get_json( + content = await self.client.get_json( destination=destination, path=path, args=args, @@ -184,9 +180,10 @@ class TransportLayerClient(object): return content - @defer.inlineCallbacks @log_function - def make_membership_event(self, destination, room_id, user_id, membership, params): + async def make_membership_event( + self, destination, room_id, user_id, membership, params + ): """Asks a remote server to build and sign us a membership event Note that this does not append any events to any graphs. @@ -200,7 +197,7 @@ class TransportLayerClient(object): request. Returns: - Deferred: Succeeds when we get a 2xx HTTP response. The result + Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body (ie, the new event). Fails with ``HTTPRequestException`` if we get an HTTP response @@ -231,7 +228,7 @@ class TransportLayerClient(object): ignore_backoff = True retry_on_dns_fail = True - content = yield self.client.get_json( + content = await self.client.get_json( destination=destination, path=path, args=params, @@ -242,34 +239,31 @@ class TransportLayerClient(object): return content - @defer.inlineCallbacks @log_function - def send_join_v1(self, destination, room_id, event_id, content): + async def send_join_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/send_join/%s/%s", room_id, event_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=content ) return response - @defer.inlineCallbacks @log_function - def send_join_v2(self, destination, room_id, event_id, content): + async def send_join_v2(self, destination, room_id, event_id, content): path = _create_v2_path("/send_join/%s/%s", room_id, event_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=content ) return response - @defer.inlineCallbacks @log_function - def send_leave_v1(self, destination, room_id, event_id, content): + async def send_leave_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/send_leave/%s/%s", room_id, event_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=content, @@ -282,12 +276,11 @@ class TransportLayerClient(object): return response - @defer.inlineCallbacks @log_function - def send_leave_v2(self, destination, room_id, event_id, content): + async def send_leave_v2(self, destination, room_id, event_id, content): path = _create_v2_path("/send_leave/%s/%s", room_id, event_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=content, @@ -300,31 +293,28 @@ class TransportLayerClient(object): return response - @defer.inlineCallbacks @log_function - def send_invite_v1(self, destination, room_id, event_id, content): + async def send_invite_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/invite/%s/%s", room_id, event_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=content, ignore_backoff=True ) return response - @defer.inlineCallbacks @log_function - def send_invite_v2(self, destination, room_id, event_id, content): + async def send_invite_v2(self, destination, room_id, event_id, content): path = _create_v2_path("/invite/%s/%s", room_id, event_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=content, ignore_backoff=True ) return response - @defer.inlineCallbacks @log_function - def get_public_rooms( + async def get_public_rooms( self, remote_server: str, limit: Optional[int] = None, @@ -355,7 +345,7 @@ class TransportLayerClient(object): data["filter"] = search_filter try: - response = yield self.client.post_json( + response = await self.client.post_json( destination=remote_server, path=path, data=data, ignore_backoff=True ) except HttpResponseException as e: @@ -381,7 +371,7 @@ class TransportLayerClient(object): args["since"] = [since_token] try: - response = yield self.client.get_json( + response = await self.client.get_json( destination=remote_server, path=path, args=args, ignore_backoff=True ) except HttpResponseException as e: @@ -396,29 +386,26 @@ class TransportLayerClient(object): return response - @defer.inlineCallbacks @log_function - def exchange_third_party_invite(self, destination, room_id, event_dict): + async def exchange_third_party_invite(self, destination, room_id, event_dict): path = _create_v1_path("/exchange_third_party_invite/%s", room_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=event_dict ) return response - @defer.inlineCallbacks @log_function - def get_event_auth(self, destination, room_id, event_id): + async def get_event_auth(self, destination, room_id, event_id): path = _create_v1_path("/event_auth/%s/%s", room_id, event_id) - content = yield self.client.get_json(destination=destination, path=path) + content = await self.client.get_json(destination=destination, path=path) return content - @defer.inlineCallbacks @log_function - def query_client_keys(self, destination, query_content, timeout): + async def query_client_keys(self, destination, query_content, timeout): """Query the device keys for a list of user ids hosted on a remote server. @@ -453,14 +440,13 @@ class TransportLayerClient(object): """ path = _create_v1_path("/user/keys/query") - content = yield self.client.post_json( + content = await self.client.post_json( destination=destination, path=path, data=query_content, timeout=timeout ) return content - @defer.inlineCallbacks @log_function - def query_user_devices(self, destination, user_id, timeout): + async def query_user_devices(self, destination, user_id, timeout): """Query the devices for a user id hosted on a remote server. Response: @@ -493,14 +479,13 @@ class TransportLayerClient(object): """ path = _create_v1_path("/user/devices/%s", user_id) - content = yield self.client.get_json( + content = await self.client.get_json( destination=destination, path=path, timeout=timeout ) return content - @defer.inlineCallbacks @log_function - def claim_client_keys(self, destination, query_content, timeout): + async def claim_client_keys(self, destination, query_content, timeout): """Claim one-time keys for a list of devices hosted on a remote server. Request: @@ -532,14 +517,13 @@ class TransportLayerClient(object): path = _create_v1_path("/user/keys/claim") - content = yield self.client.post_json( + content = await self.client.post_json( destination=destination, path=path, data=query_content, timeout=timeout ) return content - @defer.inlineCallbacks @log_function - def get_missing_events( + async def get_missing_events( self, destination, room_id, @@ -551,7 +535,7 @@ class TransportLayerClient(object): ): path = _create_v1_path("/get_missing_events/%s", room_id) - content = yield self.client.post_json( + content = await self.client.post_json( destination=destination, path=path, data={ diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index ecdb12a7b..0e2656ccb 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -23,39 +23,32 @@ logger = logging.getLogger(__name__) def _create_rerouter(func_name): - """Returns a function that looks at the group id and calls the function + """Returns an async function that looks at the group id and calls the function on federation or the local group server if the group is local """ - def f(self, group_id, *args, **kwargs): + async def f(self, group_id, *args, **kwargs): if self.is_mine_id(group_id): - return getattr(self.groups_server_handler, func_name)( + return await getattr(self.groups_server_handler, func_name)( group_id, *args, **kwargs ) else: destination = get_domain_from_id(group_id) - d = getattr(self.transport_client, func_name)( - destination, group_id, *args, **kwargs - ) - # Capture errors returned by the remote homeserver and - # re-throw specific errors as SynapseErrors. This is so - # when the remote end responds with things like 403 Not - # In Group, we can communicate that to the client instead - # of a 500. - def http_response_errback(failure): - failure.trap(HttpResponseException) - e = failure.value + try: + return await getattr(self.transport_client, func_name)( + destination, group_id, *args, **kwargs + ) + except HttpResponseException as e: + # Capture errors returned by the remote homeserver and + # re-throw specific errors as SynapseErrors. This is so + # when the remote end responds with things like 403 Not + # In Group, we can communicate that to the client instead + # of a 500. raise e.to_synapse_error() - - def request_failed_errback(failure): - failure.trap(RequestSendFailed) + except RequestSendFailed: raise SynapseError(502, "Failed to contact group server") - d.addErrback(http_response_errback) - d.addErrback(request_failed_errback) - return d - return f diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index ea026ed9f..2a6373937 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -121,8 +121,7 @@ class MatrixFederationRequest(object): return self.json -@defer.inlineCallbacks -def _handle_json_response(reactor, timeout_sec, request, response): +async def _handle_json_response(reactor, timeout_sec, request, response): """ Reads the JSON body of a response, with a timeout @@ -141,7 +140,7 @@ def _handle_json_response(reactor, timeout_sec, request, response): d = treq.json_content(response) d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor) - body = yield make_deferred_yieldable(d) + body = await make_deferred_yieldable(d) except TimeoutError as e: logger.warning( "{%s} [%s] Timed out reading response", request.txn_id, request.destination, @@ -224,8 +223,7 @@ class MatrixFederationHttpClient(object): self._cooperator = Cooperator(scheduler=schedule) - @defer.inlineCallbacks - def _send_request_with_optional_trailing_slash( + async def _send_request_with_optional_trailing_slash( self, request, try_trailing_slash_on_400=False, **send_request_args ): """Wrapper for _send_request which can optionally retry the request @@ -246,10 +244,10 @@ class MatrixFederationHttpClient(object): (except 429). Returns: - Deferred[Dict]: Parsed JSON response body. + Dict: Parsed JSON response body. """ try: - response = yield self._send_request(request, **send_request_args) + response = await self._send_request(request, **send_request_args) except HttpResponseException as e: # Received an HTTP error > 300. Check if it meets the requirements # to retry with a trailing slash @@ -265,12 +263,11 @@ class MatrixFederationHttpClient(object): logger.info("Retrying request with trailing slash") request.path += "/" - response = yield self._send_request(request, **send_request_args) + response = await self._send_request(request, **send_request_args) return response - @defer.inlineCallbacks - def _send_request( + async def _send_request( self, request, retry_on_dns_fail=True, @@ -311,7 +308,7 @@ class MatrixFederationHttpClient(object): backoff_on_404 (bool): Back off if we get a 404 Returns: - Deferred[twisted.web.client.Response]: resolves with the HTTP + twisted.web.client.Response: resolves with the HTTP response object on success. Raises: @@ -335,7 +332,7 @@ class MatrixFederationHttpClient(object): ): raise FederationDeniedError(request.destination) - limiter = yield synapse.util.retryutils.get_retry_limiter( + limiter = await synapse.util.retryutils.get_retry_limiter( request.destination, self.clock, self._store, @@ -433,7 +430,7 @@ class MatrixFederationHttpClient(object): reactor=self.reactor, ) - response = yield request_deferred + response = await request_deferred except TimeoutError as e: raise RequestSendFailed(e, can_retry=True) from e except DNSLookupError as e: @@ -474,7 +471,7 @@ class MatrixFederationHttpClient(object): ) try: - body = yield make_deferred_yieldable(d) + body = await make_deferred_yieldable(d) except Exception as e: # Eh, we're already going to raise an exception so lets # ignore if this fails. @@ -528,7 +525,7 @@ class MatrixFederationHttpClient(object): delay, ) - yield self.clock.sleep(delay) + await self.clock.sleep(delay) retries_left -= 1 else: raise @@ -591,8 +588,7 @@ class MatrixFederationHttpClient(object): ) return auth_headers - @defer.inlineCallbacks - def put_json( + async def put_json( self, destination, path, @@ -636,7 +632,7 @@ class MatrixFederationHttpClient(object): enabled. Returns: - Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + dict|list: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Raises: @@ -658,7 +654,7 @@ class MatrixFederationHttpClient(object): json=data, ) - response = yield self._send_request_with_optional_trailing_slash( + response = await self._send_request_with_optional_trailing_slash( request, try_trailing_slash_on_400, backoff_on_404=backoff_on_404, @@ -667,14 +663,13 @@ class MatrixFederationHttpClient(object): timeout=timeout, ) - body = yield _handle_json_response( + body = await _handle_json_response( self.reactor, self.default_timeout, request, response ) return body - @defer.inlineCallbacks - def post_json( + async def post_json( self, destination, path, @@ -707,7 +702,7 @@ class MatrixFederationHttpClient(object): args (dict): query params Returns: - Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + dict|list: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Raises: @@ -725,7 +720,7 @@ class MatrixFederationHttpClient(object): method="POST", destination=destination, path=path, query=args, json=data ) - response = yield self._send_request( + response = await self._send_request( request, long_retries=long_retries, timeout=timeout, @@ -737,13 +732,12 @@ class MatrixFederationHttpClient(object): else: _sec_timeout = self.default_timeout - body = yield _handle_json_response( + body = await _handle_json_response( self.reactor, _sec_timeout, request, response ) return body - @defer.inlineCallbacks - def get_json( + async def get_json( self, destination, path, @@ -775,7 +769,7 @@ class MatrixFederationHttpClient(object): response we should try appending a trailing slash to the end of the request. Workaround for #3622 in Synapse <= v0.99.3. Returns: - Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + dict|list: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Raises: @@ -792,7 +786,7 @@ class MatrixFederationHttpClient(object): method="GET", destination=destination, path=path, query=args ) - response = yield self._send_request_with_optional_trailing_slash( + response = await self._send_request_with_optional_trailing_slash( request, try_trailing_slash_on_400, backoff_on_404=False, @@ -801,14 +795,13 @@ class MatrixFederationHttpClient(object): timeout=timeout, ) - body = yield _handle_json_response( + body = await _handle_json_response( self.reactor, self.default_timeout, request, response ) return body - @defer.inlineCallbacks - def delete_json( + async def delete_json( self, destination, path, @@ -836,7 +829,7 @@ class MatrixFederationHttpClient(object): args (dict): query params Returns: - Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + dict|list: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Raises: @@ -853,20 +846,19 @@ class MatrixFederationHttpClient(object): method="DELETE", destination=destination, path=path, query=args ) - response = yield self._send_request( + response = await self._send_request( request, long_retries=long_retries, timeout=timeout, ignore_backoff=ignore_backoff, ) - body = yield _handle_json_response( + body = await _handle_json_response( self.reactor, self.default_timeout, request, response ) return body - @defer.inlineCallbacks - def get_file( + async def get_file( self, destination, path, @@ -886,7 +878,7 @@ class MatrixFederationHttpClient(object): and try the request anyway. Returns: - Deferred[tuple[int, dict]]: Resolves with an (int,dict) tuple of + tuple[int, dict]: Resolves with an (int,dict) tuple of the file length and a dict of the response headers. Raises: @@ -903,7 +895,7 @@ class MatrixFederationHttpClient(object): method="GET", destination=destination, path=path, query=args ) - response = yield self._send_request( + response = await self._send_request( request, retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff ) @@ -912,7 +904,7 @@ class MatrixFederationHttpClient(object): try: d = _readBodyToFile(response, output_stream, max_size) d.addTimeout(self.default_timeout, self.reactor) - length = yield make_deferred_yieldable(d) + length = await make_deferred_yieldable(d) except Exception as e: logger.warning( "{%s} [%s] Error reading response: %s", diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index f9ce60992..e0ad8e8a7 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -102,11 +102,10 @@ class KeyringTestCase(unittest.HomeserverTestCase): } persp_deferred = defer.Deferred() - @defer.inlineCallbacks - def get_perspectives(**kwargs): + async def get_perspectives(**kwargs): self.assertEquals(current_context().request, "11") with PreserveLoggingContext(): - yield persp_deferred + await persp_deferred return persp_resp self.http_client.post_json.side_effect = get_perspectives @@ -355,7 +354,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): } signedjson.sign.sign_json(response, SERVER_NAME, testkey) - def get_json(destination, path, **kwargs): + async def get_json(destination, path, **kwargs): self.assertEqual(destination, SERVER_NAME) self.assertEqual(path, "/_matrix/key/v2/server/key1") return response @@ -444,7 +443,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): Tell the mock http client to expect a perspectives-server key query """ - def post_json(destination, path, data, **kwargs): + async def post_json(destination, path, data, **kwargs): self.assertEqual(destination, self.mock_perspective_server.server_name) self.assertEqual(path, "/_matrix/key/v2/query") @@ -580,14 +579,12 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): # remove the perspectives server's signature response = build_response() del response["signatures"][self.mock_perspective_server.server_name] - self.http_client.post_json.return_value = {"server_keys": [response]} keys = get_key_from_perspectives(response) self.assertEqual(keys, {}, "Expected empty dict with missing persp server sig") # remove the origin server's signature response = build_response() del response["signatures"][SERVER_NAME] - self.http_client.post_json.return_value = {"server_keys": [response]} keys = get_key_from_perspectives(response) self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig") diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 5cd0510f0..b8ca11871 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -23,6 +23,7 @@ from synapse.rest.client.v1 import login, room from synapse.types import UserID from tests import unittest +from tests.test_utils import make_awaitable class RoomComplexityTests(unittest.FederatingHomeserverTestCase): @@ -78,9 +79,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) handler.federation_handler.do_invite_join = Mock( - return_value=defer.succeed(("", 1)) + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( @@ -109,9 +110,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) handler.federation_handler.do_invite_join = Mock( - return_value=defer.succeed(("", 1)) + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( @@ -147,9 +148,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=defer.succeed(None)) + fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) handler.federation_handler.do_invite_join = Mock( - return_value=defer.succeed(("", 1)) + return_value=make_awaitable(("", 1)) ) # Artificially raise the complexity @@ -203,9 +204,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) handler.federation_handler.do_invite_join = Mock( - return_value=defer.succeed(("", 1)) + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( @@ -233,9 +234,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) handler.federation_handler.do_invite_join = Mock( - return_value=defer.succeed(("", 1)) + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index d1bd18da3..5f512ff8b 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -47,13 +47,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): mock_send_transaction = ( self.hs.get_federation_transport_client().send_transaction ) - mock_send_transaction.return_value = defer.succeed({}) + mock_send_transaction.return_value = make_awaitable({}) sender = self.hs.get_federation_sender() receipt = ReadReceipt( "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} ) - self.successResultOf(sender.send_read_receipt(receipt)) + self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) self.pump() @@ -87,13 +87,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): mock_send_transaction = ( self.hs.get_federation_transport_client().send_transaction ) - mock_send_transaction.return_value = defer.succeed({}) + mock_send_transaction.return_value = make_awaitable({}) sender = self.hs.get_federation_sender() receipt = ReadReceipt( "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} ) - self.successResultOf(sender.send_read_receipt(receipt)) + self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) self.pump() @@ -125,7 +125,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): receipt = ReadReceipt( "room_id", "m.read", "user_id", ["other_id"], {"ts": 1234} ) - self.successResultOf(sender.send_read_receipt(receipt)) + self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) self.pump() mock_send_transaction.assert_not_called() diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 00bb77627..bc0c5aefd 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -16,8 +16,6 @@ from mock import Mock -from twisted.internet import defer - import synapse import synapse.api.errors from synapse.api.constants import EventTypes @@ -26,6 +24,7 @@ from synapse.rest.client.v1 import directory, login, room from synapse.types import RoomAlias, create_requester from tests import unittest +from tests.test_utils import make_awaitable class DirectoryTestCase(unittest.HomeserverTestCase): @@ -71,7 +70,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result) def test_get_remote_association(self): - self.mock_federation.make_query.return_value = defer.succeed( + self.mock_federation.make_query.return_value = make_awaitable( {"room_id": "!8765qwer:test", "servers": ["test", "remote"]} ) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 4f1347cd2..d70e1fc60 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -24,6 +24,7 @@ from synapse.handlers.profile import MasterProfileHandler from synapse.types import UserID from tests import unittest +from tests.test_utils import make_awaitable from tests.utils import setup_test_homeserver @@ -138,7 +139,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_other_name(self): - self.mock_federation.make_query.return_value = defer.succeed( + self.mock_federation.make_query.return_value = make_awaitable( {"displayname": "Alice"} ) diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index fff4f0cbf..ac598249e 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -58,7 +58,9 @@ class FederationClientTests(HomeserverTestCase): @defer.inlineCallbacks def do_request(): with LoggingContext("one") as context: - fetch_d = self.cl.get_json("testserv:8008", "foo/bar") + fetch_d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar") + ) # Nothing happened yet self.assertNoResult(fetch_d) @@ -120,7 +122,9 @@ class FederationClientTests(HomeserverTestCase): """ If the DNS lookup returns an error, it will bubble up. """ - d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000) + ) self.pump() f = self.failureResultOf(d) @@ -128,7 +132,9 @@ class FederationClientTests(HomeserverTestCase): self.assertIsInstance(f.value.inner_exception, DNSLookupError) def test_client_connection_refused(self): - d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -154,7 +160,9 @@ class FederationClientTests(HomeserverTestCase): If the HTTP request is not connected and is timed out, it'll give a ConnectingCancelledError or TimeoutError. """ - d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -184,7 +192,9 @@ class FederationClientTests(HomeserverTestCase): If the HTTP request is connected, but gets no response before being timed out, it'll give a ResponseNeverReceived. """ - d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -226,7 +236,7 @@ class FederationClientTests(HomeserverTestCase): # Try making a GET request to a blacklisted IPv4 address # ------------------------------------------------------ # Make the request - d = cl.get_json("internal:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred(cl.get_json("internal:8008", "foo/bar", timeout=10000)) # Nothing happened yet self.assertNoResult(d) @@ -244,7 +254,9 @@ class FederationClientTests(HomeserverTestCase): # Try making a POST request to a blacklisted IPv6 address # ------------------------------------------------------- # Make the request - d = cl.post_json("internalv6:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + cl.post_json("internalv6:8008", "foo/bar", timeout=10000) + ) # Nothing has happened yet self.assertNoResult(d) @@ -263,7 +275,7 @@ class FederationClientTests(HomeserverTestCase): # Try making a GET request to a non-blacklisted IPv4 address # ---------------------------------------------------------- # Make the request - d = cl.post_json("fine:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred(cl.post_json("fine:8008", "foo/bar", timeout=10000)) # Nothing has happened yet self.assertNoResult(d) @@ -286,7 +298,7 @@ class FederationClientTests(HomeserverTestCase): request = MatrixFederationRequest( method="GET", destination="testserv:8008", path="foo/bar" ) - d = self.cl._send_request(request, timeout=10000) + d = defer.ensureDeferred(self.cl._send_request(request, timeout=10000)) self.pump() @@ -310,7 +322,9 @@ class FederationClientTests(HomeserverTestCase): If the HTTP request is connected, but gets no response before being timed out, it'll give a ResponseNeverReceived. """ - d = self.cl.post_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.post_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -342,7 +356,9 @@ class FederationClientTests(HomeserverTestCase): requiring a trailing slash. We need to retry the request with a trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622. """ - d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + ) # Send the request self.pump() @@ -395,7 +411,9 @@ class FederationClientTests(HomeserverTestCase): See test_client_requires_trailing_slashes() for context. """ - d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + ) # Send the request self.pump() @@ -432,7 +450,11 @@ class FederationClientTests(HomeserverTestCase): self.failureResultOf(d) def test_client_sends_body(self): - self.cl.post_json("testserv:8008", "foo/bar", timeout=10000, data={"a": "b"}) + defer.ensureDeferred( + self.cl.post_json( + "testserv:8008", "foo/bar", timeout=10000, data={"a": "b"} + ) + ) self.pump() @@ -453,7 +475,7 @@ class FederationClientTests(HomeserverTestCase): def test_closes_connection(self): """Check that the client closes unused HTTP connections""" - d = self.cl.get_json("testserv:8008", "foo/bar") + d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar")) self.pump() diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 8d4dbf232..83f9aa291 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -16,8 +16,6 @@ import logging from mock import Mock -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.events.builder import EventBuilderFactory from synapse.rest.admin import register_servlets_for_client_rest_resource @@ -25,6 +23,7 @@ from synapse.rest.client.v1 import login, room from synapse.types import UserID from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.test_utils import make_awaitable logger = logging.getLogger(__name__) @@ -46,7 +45,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): new event. """ mock_client = Mock(spec=["put_json"]) - mock_client.put_json.side_effect = lambda *_, **__: defer.succeed({}) + mock_client.put_json.side_effect = lambda *_, **__: make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", @@ -74,7 +73,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): new events. """ mock_client1 = Mock(spec=["put_json"]) - mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({}) + mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { @@ -86,7 +85,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): ) mock_client2 = Mock(spec=["put_json"]) - mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({}) + mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { @@ -137,7 +136,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): new typing EDUs. """ mock_client1 = Mock(spec=["put_json"]) - mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({}) + mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { @@ -149,7 +148,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): ) mock_client2 = Mock(spec=["put_json"]) - mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({}) + mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index b1a4decce..0f1144fe1 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -178,7 +178,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): self.fetches = [] - def get_file(destination, path, output_stream, args=None, max_size=None): + async def get_file(destination, path, output_stream, args=None, max_size=None): """ Returns tuple[int,dict,str,int] of file length, response headers, absolute URI, and response code. @@ -192,7 +192,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): d = Deferred() d.addCallback(write_to) self.fetches.append((d, destination, path, args)) - return make_deferred_yieldable(d) + return await make_deferred_yieldable(d) client = Mock() client.get_file = get_file diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index 99eb47714..6850c666b 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -53,7 +53,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): Tell the mock http client to expect an outgoing GET request for the given key """ - def get_json(destination, path, ignore_backoff=False, **kwargs): + async def get_json(destination, path, ignore_backoff=False, **kwargs): self.assertTrue(ignore_backoff) self.assertEqual(destination, server_name) key_id = "%s:%s" % (signing_key.alg, signing_key.version) @@ -177,7 +177,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): # wire up outbound POST /key/v2/query requests from hs2 so that they # will be forwarded to hs1 - def post_json(destination, path, data): + async def post_json(destination, path, data): self.assertEqual(destination, self.hs.hostname) self.assertEqual( path, "/_matrix/key/v2/query", diff --git a/tests/test_federation.py b/tests/test_federation.py index 87a16d7d7..c2f12c274 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -95,7 +95,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): prev_events that said event references. """ - def post_json(destination, path, data, headers=None, timeout=0): + async def post_json(destination, path, data, headers=None, timeout=0): # If it asks us for new missing events, give them NOTHING if path.startswith("/_matrix/federation/v1/get_missing_events/"): return {"events": []} From 606805bf0646a487e234c4e63ab434805209816d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 30 Jul 2020 16:28:36 +0100 Subject: [PATCH 23/26] Fix typo in docs/workers.md (#7992) --- changelog.d/7992.doc | 1 + docs/workers.md | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7992.doc diff --git a/changelog.d/7992.doc b/changelog.d/7992.doc new file mode 100644 index 000000000..3368fb591 --- /dev/null +++ b/changelog.d/7992.doc @@ -0,0 +1 @@ +Fix typo in `docs/workers.md`. diff --git a/docs/workers.md b/docs/workers.md index 38bd758e5..05d438240 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -278,7 +278,7 @@ instance_map: host: localhost port: 8034 -streams_writers: +stream_writers: events: event_persister1 ``` From 0a7fb2471685cf4bccfd76f51ddd614a3aeb4536 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 30 Jul 2020 16:58:57 +0100 Subject: [PATCH 24/26] Fix invite rejection when we have no forward-extremeties (#7980) Thanks to some slightly overzealous cleanup in the `delete_old_current_state_events`, it's possible to end up with no `event_forward_extremities` in a room where we have outstanding local invites. The user would then get a "no create event in auth events" when trying to reject the invite. We can hack around it by using the dangling invite as the prev event. --- changelog.d/7980.bugfix | 1 + synapse/handlers/room_member.py | 29 +++++++++++++++++++++-------- 2 files changed, 22 insertions(+), 8 deletions(-) create mode 100644 changelog.d/7980.bugfix diff --git a/changelog.d/7980.bugfix b/changelog.d/7980.bugfix new file mode 100644 index 000000000..fa351b4b7 --- /dev/null +++ b/changelog.d/7980.bugfix @@ -0,0 +1 @@ +Fix "no create event in auth events" when trying to reject invitation after inviter leaves. Bug introduced in Synapse v1.10.0. diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 5a40e8c14..78586a0a1 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -469,26 +469,39 @@ class RoomMemberHandler(object): user_id=target.to_string(), room_id=room_id ) # type: Optional[RoomsForUser] if not invite: + logger.info( + "%s sent a leave request to %s, but that is not an active room " + "on this server, and there is no pending invite", + target, + room_id, + ) + raise SynapseError(404, "Not a known room") logger.info( "%s rejects invite to %s from %s", target, room_id, invite.sender ) - if self.hs.is_mine_id(invite.sender): - # the inviter was on our server, but has now left. Carry on - # with the normal rejection codepath. - # - # This is a bit of a hack, because the room might still be - # active on other servers. - pass - else: + if not self.hs.is_mine_id(invite.sender): # send the rejection to the inviter's HS (with fallback to # local event) return await self.remote_reject_invite( invite.event_id, txn_id, requester, content, ) + # the inviter was on our server, but has now left. Carry on + # with the normal rejection codepath, which will also send the + # rejection out to any other servers we believe are still in the room. + + # thanks to overzealous cleaning up of event_forward_extremities in + # `delete_old_current_state_events`, it's possible to end up with no + # forward extremities here. If that happens, let's just hang the + # rejection off the invite event. + # + # see: https://github.com/matrix-org/synapse/issues/7139 + if len(latest_event_ids) == 0: + latest_event_ids = [invite.event_id] + return await self._local_membership_update( requester=requester, target=target, From 6d4b790021b1452da05443103b35f0e9fc3d846a Mon Sep 17 00:00:00 2001 From: Stuart Mumford Date: Thu, 30 Jul 2020 17:30:11 +0100 Subject: [PATCH 25/26] Update workers docs (#7990) --- changelog.d/7990.doc | 1 + docs/workers.md | 59 +++++++++++++++++++++++++------------------- 2 files changed, 35 insertions(+), 25 deletions(-) create mode 100644 changelog.d/7990.doc diff --git a/changelog.d/7990.doc b/changelog.d/7990.doc new file mode 100644 index 000000000..8d8fd926e --- /dev/null +++ b/changelog.d/7990.doc @@ -0,0 +1 @@ +Improve workers docs. diff --git a/docs/workers.md b/docs/workers.md index 05d438240..80b65a0ce 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -1,10 +1,10 @@ # Scaling synapse via workers -For small instances it recommended to run Synapse in monolith mode (the -default). For larger instances where performance is a concern it can be helpful -to split out functionality into multiple separate python processes. These -processes are called 'workers', and are (eventually) intended to scale -horizontally independently. +For small instances it recommended to run Synapse in the default monolith mode. +For larger instances where performance is a concern it can be helpful to split +out functionality into multiple separate python processes. These processes are +called 'workers', and are (eventually) intended to scale horizontally +independently. Synapse's worker support is under active development and subject to change as we attempt to rapidly scale ever larger Synapse instances. However we are @@ -23,29 +23,30 @@ The processes communicate with each other via a Synapse-specific protocol called feeds streams of newly written data between processes so they can be kept in sync with the database state. -Additionally, processes may make HTTP requests to each other. Typically this is -used for operations which need to wait for a reply - such as sending an event. +When configured to do so, Synapse uses a +[Redis pub/sub channel](https://redis.io/topics/pubsub) to send the replication +stream between all configured Synapse processes. Additionally, processes may +make HTTP requests to each other, primarily for operations which need to wait +for a reply ─ such as sending an event. -As of Synapse v1.13.0, it is possible to configure Synapse to send replication -via a [Redis pub/sub channel](https://redis.io/topics/pubsub), and is now the -recommended way of configuring replication. This is an alternative to the old -direct TCP connections to the main process: rather than all the workers -connecting to the main process, all the workers and the main process connect to -Redis, which relays replication commands between processes. This can give a -significant cpu saving on the main process and will be a prerequisite for -upcoming performance improvements. +Redis support was added in v1.13.0 with it becoming the recommended method in +v1.18.0. It replaced the old direct TCP connections (which is deprecated as of +v1.18.0) to the main process. With Redis, rather than all the workers connecting +to the main process, all the workers and the main process connect to Redis, +which relays replication commands between processes. This can give a significant +cpu saving on the main process and will be a prerequisite for upcoming +performance improvements. -(See the [Architectural diagram](#architectural-diagram) section at the end for -a visualisation of what this looks like) +See the [Architectural diagram](#architectural-diagram) section at the end for +a visualisation of what this looks like. ## Setting up workers A Redis server is required to manage the communication between the processes. -(The older direct TCP connections are now deprecated.) The Redis server -should be installed following the normal procedure for your distribution (e.g. -`apt install redis-server` on Debian). It is safe to use an existing Redis -deployment if you have one. +The Redis server should be installed following the normal procedure for your +distribution (e.g. `apt install redis-server` on Debian). It is safe to use an +existing Redis deployment if you have one. Once installed, check that Redis is running and accessible from the host running Synapse, for example by executing `echo PING | nc -q1 localhost 6379` and seeing @@ -65,8 +66,9 @@ https://hub.docker.com/r/matrixdotorg/synapse/. To make effective use of the workers, you will need to configure an HTTP reverse-proxy such as nginx or haproxy, which will direct incoming requests to -the correct worker, or to the main synapse instance. See [reverse_proxy.md](reverse_proxy.md) -for information on setting up a reverse proxy. +the correct worker, or to the main synapse instance. See +[reverse_proxy.md](reverse_proxy.md) for information on setting up a reverse +proxy. To enable workers you should create a configuration file for each worker process. Each worker configuration file inherits the configuration of the shared @@ -75,8 +77,12 @@ that worker, e.g. the HTTP listener that it provides (if any); logging configuration; etc. You should minimise the number of overrides though to maintain a usable config. -Next you need to add both a HTTP replication listener and redis config to the -shared Synapse configuration file (`homeserver.yaml`). For example: + +### Shared Configuration + +Next you need to add both a HTTP replication listener, used for HTTP requests +between processes, and redis config to the shared Synapse configuration file +(`homeserver.yaml`). For example: ```yaml # extend the existing `listeners` section. This defines the ports that the @@ -98,6 +104,9 @@ See the sample config for the full documentation of each option. Under **no circumstances** should the replication listener be exposed to the public internet; it has no authentication and is unencrypted. + +### Worker Configuration + In the config file for each worker, you must specify the type of worker application (`worker_app`), and you should specify a unqiue name for the worker (`worker_name`). The currently available worker applications are listed below. From e2a4ba6f9baa82d9142e0662871f5e6bfcb3d538 Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Thu, 30 Jul 2020 21:41:44 -0600 Subject: [PATCH 26/26] Add docs for undoing room shutdowns (#7998) These docs were tested successfully in production by a customer, so it's probably fine. --- changelog.d/7998.doc | 1 + docs/admin_api/shutdown_room.md | 22 +++++++++++++++++++++- 2 files changed, 22 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7998.doc diff --git a/changelog.d/7998.doc b/changelog.d/7998.doc new file mode 100644 index 000000000..fc8b3f0c3 --- /dev/null +++ b/changelog.d/7998.doc @@ -0,0 +1 @@ +Add documentation for how to undo a room shutdown. diff --git a/docs/admin_api/shutdown_room.md b/docs/admin_api/shutdown_room.md index 808caeec7..2ff552bcb 100644 --- a/docs/admin_api/shutdown_room.md +++ b/docs/admin_api/shutdown_room.md @@ -33,7 +33,7 @@ You will need to authenticate with an access token for an admin user. * `message` - Optional. A string containing the first message that will be sent as `new_room_user_id` in the new room. Ideally this will clearly convey why the original room was shut down. - + If not specified, the default value of `room_name` is "Content Violation Notification". The default value of `message` is "Sharing illegal content on othis server is not permitted and rooms in violation will be blocked." @@ -72,3 +72,23 @@ Response: "new_room_id": "!newroomid:example.com", }, ``` + +## Undoing room shutdowns + +*Note*: This guide may be outdated by the time you read it. By nature of room shutdowns being performed at the database level, +the structure can and does change without notice. + +First, it's important to understand that a room shutdown is very destructive. Undoing a shutdown is not as simple as pretending it +never happened - work has to be done to move forward instead of resetting the past. + +1. For safety reasons, it is recommended to shut down Synapse prior to continuing. +2. In the database, run `DELETE FROM blocked_rooms WHERE room_id = '!example:example.org';` + * For caution: it's recommended to run this in a transaction: `BEGIN; DELETE ...;`, verify you got 1 result, then `COMMIT;`. + * The room ID is the same one supplied to the shutdown room API, not the Content Violation room. +3. Restart Synapse (required). + +You will have to manually handle, if you so choose, the following: + +* Aliases that would have been redirected to the Content Violation room. +* Users that would have been booted from the room (and will have been force-joined to the Content Violation room). +* Removal of the Content Violation room if desired.