diff --git a/changelog.d/7678.misc b/changelog.d/7678.misc new file mode 100644 index 000000000..ab612200c --- /dev/null +++ b/changelog.d/7678.misc @@ -0,0 +1 @@ +Convert the device message and pagination handlers to async/await. diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 05c4b3eec..610b08d00 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -18,8 +18,6 @@ from typing import Any, Dict from canonicaljson import json -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.logging.context import run_in_background from synapse.logging.opentracing import ( @@ -51,8 +49,7 @@ class DeviceMessageHandler(object): self._device_list_updater = hs.get_device_handler().device_list_updater - @defer.inlineCallbacks - def on_direct_to_device_edu(self, origin, content): + async def on_direct_to_device_edu(self, origin, content): local_messages = {} sender_user_id = content["sender"] if origin != get_domain_from_id(sender_user_id): @@ -82,11 +79,11 @@ class DeviceMessageHandler(object): } local_messages[user_id] = messages_by_device - yield self._check_for_unknown_devices( + await self._check_for_unknown_devices( message_type, sender_user_id, by_device ) - stream_id = yield self.store.add_messages_from_remote_to_device_inbox( + stream_id = await self.store.add_messages_from_remote_to_device_inbox( origin, message_id, local_messages ) @@ -94,14 +91,13 @@ class DeviceMessageHandler(object): "to_device_key", stream_id, users=local_messages.keys() ) - @defer.inlineCallbacks - def _check_for_unknown_devices( + async def _check_for_unknown_devices( self, message_type: str, sender_user_id: str, by_device: Dict[str, Dict[str, Any]], ): - """Checks inbound device messages for unkown remote devices, and if + """Checks inbound device messages for unknown remote devices, and if found marks the remote cache for the user as stale. """ @@ -115,7 +111,7 @@ class DeviceMessageHandler(object): requesting_device_ids.add(device_id) # Check if we are tracking the devices of the remote user. - room_ids = yield self.store.get_rooms_for_user(sender_user_id) + room_ids = await self.store.get_rooms_for_user(sender_user_id) if not room_ids: logger.info( "Received device message from remote device we don't" @@ -127,7 +123,7 @@ class DeviceMessageHandler(object): # If we are tracking check that we know about the sending # devices. - cached_devices = yield self.store.get_cached_devices_for_user(sender_user_id) + cached_devices = await self.store.get_cached_devices_for_user(sender_user_id) unknown_devices = requesting_device_ids - set(cached_devices) if unknown_devices: @@ -136,15 +132,14 @@ class DeviceMessageHandler(object): sender_user_id, unknown_devices, ) - yield self.store.mark_remote_user_device_cache_as_stale(sender_user_id) + await self.store.mark_remote_user_device_cache_as_stale(sender_user_id) # Immediately attempt a resync in the background run_in_background( self._device_list_updater.user_device_resync, sender_user_id ) - @defer.inlineCallbacks - def send_device_message(self, sender_user_id, message_type, messages): + async def send_device_message(self, sender_user_id, message_type, messages): set_tag("number_of_messages", len(messages)) set_tag("sender", sender_user_id) local_messages = {} @@ -183,7 +178,7 @@ class DeviceMessageHandler(object): } log_kv({"local_messages": local_messages}) - stream_id = yield self.store.add_messages_to_device_inbox( + stream_id = await self.store.add_messages_to_device_inbox( local_messages, remote_edu_contents ) diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 7fbc22950..da06582d4 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -15,7 +15,6 @@ # limitations under the License. import logging -from twisted.internet import defer from twisted.python.failure import Failure from synapse.api.constants import EventTypes, Membership @@ -97,8 +96,7 @@ class PaginationHandler(object): job["longest_max_lifetime"], ) - @defer.inlineCallbacks - def purge_history_for_rooms_in_range(self, min_ms, max_ms): + async def purge_history_for_rooms_in_range(self, min_ms, max_ms): """Purge outdated events from rooms within the given retention range. If a default retention policy is defined in the server's configuration and its @@ -137,7 +135,7 @@ class PaginationHandler(object): include_null, ) - rooms = yield self.store.get_rooms_for_retention_period_in_range( + rooms = await self.store.get_rooms_for_retention_period_in_range( min_ms, max_ms, include_null ) @@ -165,9 +163,9 @@ class PaginationHandler(object): # Figure out what token we should start purging at. ts = self.clock.time_msec() - max_lifetime - stream_ordering = yield self.store.find_first_stream_ordering_after_ts(ts) + stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts) - r = yield self.store.get_room_event_before_stream_ordering( + r = await self.store.get_room_event_before_stream_ordering( room_id, stream_ordering, ) if not r: @@ -227,8 +225,7 @@ class PaginationHandler(object): ) return purge_id - @defer.inlineCallbacks - def _purge_history(self, purge_id, room_id, token, delete_local_events): + async def _purge_history(self, purge_id, room_id, token, delete_local_events): """Carry out a history purge on a room. Args: @@ -237,14 +234,11 @@ class PaginationHandler(object): token (str): topological token to delete events before delete_local_events (bool): True to delete local events as well as remote ones - - Returns: - Deferred """ self._purges_in_progress_by_room.add(room_id) try: - with (yield self.pagination_lock.write(room_id)): - yield self.storage.purge_events.purge_history( + with await self.pagination_lock.write(room_id): + await self.storage.purge_events.purge_history( room_id, token, delete_local_events ) logger.info("[purge] complete") @@ -282,9 +276,7 @@ class PaginationHandler(object): await self.store.get_room_version_id(room_id) # first check that we have no users in this room - joined = await defer.maybeDeferred( - self.store.is_host_joined, room_id, self._server_name - ) + joined = await self.store.is_host_joined(room_id, self._server_name) if joined: raise SynapseError(400, "Users are still joined to this room")