diff --git a/changelog.d/8202.misc b/changelog.d/8202.misc new file mode 100644 index 000000000..dfe4c0317 --- /dev/null +++ b/changelog.d/8202.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index ac3418d69..5a1aa7d83 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -14,15 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Any, Dict, Optional from twisted.python.failure import Failure from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError +from synapse.api.filtering import Filter from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.state import StateFilter -from synapse.types import RoomStreamToken +from synapse.streams.config import PaginationConfig +from synapse.types import Requester, RoomStreamToken from synapse.util.async_helpers import ReadWriteLock from synapse.util.stringutils import random_string from synapse.visibility import filter_events_for_client @@ -247,15 +250,16 @@ class PaginationHandler(object): ) return purge_id - async def _purge_history(self, purge_id, room_id, token, delete_local_events): + async def _purge_history( + self, purge_id: str, room_id: str, token: str, delete_local_events: bool + ) -> None: """Carry out a history purge on a room. Args: - purge_id (str): The id for this purge - room_id (str): The room to purge from - token (str): topological token to delete events before - delete_local_events (bool): True to delete local events as well as - remote ones + purge_id: The id for this purge + room_id: The room to purge from + token: topological token to delete events before + delete_local_events: True to delete local events as well as remote ones """ self._purges_in_progress_by_room.add(room_id) try: @@ -291,9 +295,9 @@ class PaginationHandler(object): """ return self._purges_by_id.get(purge_id) - async def purge_room(self, room_id): + async def purge_room(self, room_id: str) -> None: """Purge the given room from the database""" - with (await self.pagination_lock.write(room_id)): + with await self.pagination_lock.write(room_id): # check we know about the room await self.store.get_room_version_id(room_id) @@ -307,23 +311,22 @@ class PaginationHandler(object): async def get_messages( self, - requester, - room_id=None, - pagin_config=None, - as_client_event=True, - event_filter=None, - ): + requester: Requester, + room_id: Optional[str] = None, + pagin_config: Optional[PaginationConfig] = None, + as_client_event: bool = True, + event_filter: Optional[Filter] = None, + ) -> Dict[str, Any]: """Get messages in a room. Args: - requester (Requester): The user requesting messages. - room_id (str): The room they want messages from. - pagin_config (synapse.api.streams.PaginationConfig): The pagination - config rules to apply, if any. - as_client_event (bool): True to get events in client-server format. - event_filter (Filter): Filter to apply to results or None + requester: The user requesting messages. + room_id: The room they want messages from. + pagin_config: The pagination config rules to apply, if any. + as_client_event: True to get events in client-server format. + event_filter: Filter to apply to results or None Returns: - dict: Pagination API results + Pagination API results """ user_id = requester.user.to_string() @@ -343,7 +346,7 @@ class PaginationHandler(object): source_config = pagin_config.get_source_config("room") - with (await self.pagination_lock.read(room_id)): + with await self.pagination_lock.read(room_id): ( membership, member_event_id, diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index f56277092..dfefbd996 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -20,6 +20,7 @@ from contextlib import contextmanager from typing import Dict, Sequence, Set, Union import attr +from typing_extensions import ContextManager from twisted.internet import defer from twisted.internet.defer import CancelledError @@ -338,11 +339,11 @@ class Linearizer(object): class ReadWriteLock(object): - """A deferred style read write lock. + """An async read write lock. Example: - with (yield read_write_lock.read("test_key")): + with await read_write_lock.read("test_key"): # do some work """ @@ -365,8 +366,7 @@ class ReadWriteLock(object): # Latest writer queued self.key_to_current_writer = {} # type: Dict[str, defer.Deferred] - @defer.inlineCallbacks - def read(self, key): + async def read(self, key: str) -> ContextManager: new_defer = defer.Deferred() curr_readers = self.key_to_current_readers.setdefault(key, set()) @@ -376,7 +376,8 @@ class ReadWriteLock(object): # We wait for the latest writer to finish writing. We can safely ignore # any existing readers... as they're readers. - yield make_deferred_yieldable(curr_writer) + if curr_writer: + await make_deferred_yieldable(curr_writer) @contextmanager def _ctx_manager(): @@ -388,8 +389,7 @@ class ReadWriteLock(object): return _ctx_manager() - @defer.inlineCallbacks - def write(self, key): + async def write(self, key: str) -> ContextManager: new_defer = defer.Deferred() curr_readers = self.key_to_current_readers.get(key, set()) @@ -405,7 +405,7 @@ class ReadWriteLock(object): curr_readers.clear() self.key_to_current_writer[key] = new_defer - yield make_deferred_yieldable(defer.gatherResults(to_wait_on)) + await make_deferred_yieldable(defer.gatherResults(to_wait_on)) @contextmanager def _ctx_manager(): diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py index bd32e2cee..d3dea3b52 100644 --- a/tests/util/test_rwlock.py +++ b/tests/util/test_rwlock.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer from synapse.util.async_helpers import ReadWriteLock @@ -43,6 +44,7 @@ class ReadWriteLockTestCase(unittest.TestCase): rwlock.read(key), # 5 rwlock.write(key), # 6 ] + ds = [defer.ensureDeferred(d) for d in ds] self._assert_called_before_not_after(ds, 2) @@ -73,12 +75,12 @@ class ReadWriteLockTestCase(unittest.TestCase): with ds[6].result: pass - d = rwlock.write(key) + d = defer.ensureDeferred(rwlock.write(key)) self.assertTrue(d.called) with d.result: pass - d = rwlock.read(key) + d = defer.ensureDeferred(rwlock.read(key)) self.assertTrue(d.called) with d.result: pass