diff --git a/changelog.d/12357.misc b/changelog.d/12357.misc new file mode 100644 index 000000000..d571ae034 --- /dev/null +++ b/changelog.d/12357.misc @@ -0,0 +1 @@ +Refactor `Linearizer`, convert methods to async and use an async context manager. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index c7400c737..69d833585 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -188,7 +188,7 @@ class FederationServer(FederationBase): async def on_backfill_request( self, origin: str, room_id: str, versions: List[str], limit: int ) -> Tuple[int, Dict[str, Any]]: - with (await self._server_linearizer.queue((origin, room_id))): + async with self._server_linearizer.queue((origin, room_id)): origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) @@ -218,7 +218,7 @@ class FederationServer(FederationBase): Tuple indicating the response status code and dictionary response body including `event_id`. """ - with (await self._server_linearizer.queue((origin, room_id))): + async with self._server_linearizer.queue((origin, room_id)): origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) @@ -529,7 +529,7 @@ class FederationServer(FederationBase): # in the cache so we could return it without waiting for the linearizer # - but that's non-trivial to get right, and anyway somewhat defeats # the point of the linearizer. - with (await self._server_linearizer.queue((origin, room_id))): + async with self._server_linearizer.queue((origin, room_id)): resp: JsonDict = dict( await self._state_resp_cache.wrap( (room_id, event_id), @@ -883,7 +883,7 @@ class FederationServer(FederationBase): async def on_event_auth( self, origin: str, room_id: str, event_id: str ) -> Tuple[int, Dict[str, Any]]: - with (await self._server_linearizer.queue((origin, room_id))): + async with self._server_linearizer.queue((origin, room_id)): origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) @@ -945,7 +945,7 @@ class FederationServer(FederationBase): latest_events: List[str], limit: int, ) -> Dict[str, list]: - with (await self._server_linearizer.queue((origin, room_id))): + async with self._server_linearizer.queue((origin, room_id)): origin_host, _ = parse_server_name(origin) await self.check_server_matches_acl(origin_host, room_id) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 316c4b677..1b5784050 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -330,10 +330,8 @@ class ApplicationServicesHandler: continue # Since we read/update the stream position for this AS/stream - with ( - await self._ephemeral_events_linearizer.queue( - (service.id, stream_key) - ) + async with self._ephemeral_events_linearizer.queue( + (service.id, stream_key) ): if stream_key == "receipt_key": events = await self._handle_receipts(service, new_token) diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index c710c02cf..ffa28b2a3 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -833,7 +833,7 @@ class DeviceListUpdater: async def _handle_device_updates(self, user_id: str) -> None: "Actually handle pending updates." - with (await self._remote_edu_linearizer.queue(user_id)): + async with self._remote_edu_linearizer.queue(user_id): pending_updates = self._pending_updates.pop(user_id, []) if not pending_updates: # This can happen since we batch updates diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index d96456cd4..d6714228e 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -118,7 +118,7 @@ class E2eKeysHandler: from_device_id: the device making the query. This is used to limit the number of in-flight queries at a time. """ - with await self._query_devices_linearizer.queue((from_user_id, from_device_id)): + async with self._query_devices_linearizer.queue((from_user_id, from_device_id)): device_keys_query: Dict[str, Iterable[str]] = query_body.get( "device_keys", {} ) @@ -1386,7 +1386,7 @@ class SigningKeyEduUpdater: device_handler = self.e2e_keys_handler.device_handler device_list_updater = device_handler.device_list_updater - with (await self._remote_edu_linearizer.queue(user_id)): + async with self._remote_edu_linearizer.queue(user_id): pending_updates = self._pending_updates.pop(user_id, []) if not pending_updates: # This can happen since we batch updates diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 52e44a2d4..446f509bd 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -83,7 +83,7 @@ class E2eRoomKeysHandler: # we deliberately take the lock to get keys so that changing the version # works atomically - with (await self._upload_linearizer.queue(user_id)): + async with self._upload_linearizer.queue(user_id): # make sure the backup version exists try: await self.store.get_e2e_room_keys_version_info(user_id, version) @@ -126,7 +126,7 @@ class E2eRoomKeysHandler: """ # lock for consistency with uploading - with (await self._upload_linearizer.queue(user_id)): + async with self._upload_linearizer.queue(user_id): # make sure the backup version exists try: version_info = await self.store.get_e2e_room_keys_version_info( @@ -187,7 +187,7 @@ class E2eRoomKeysHandler: # TODO: Validate the JSON to make sure it has the right keys. # XXX: perhaps we should use a finer grained lock here? - with (await self._upload_linearizer.queue(user_id)): + async with self._upload_linearizer.queue(user_id): # Check that the version we're trying to upload is the current version try: @@ -332,7 +332,7 @@ class E2eRoomKeysHandler: # TODO: Validate the JSON to make sure it has the right keys. # lock everyone out until we've switched version - with (await self._upload_linearizer.queue(user_id)): + async with self._upload_linearizer.queue(user_id): new_version = await self.store.create_e2e_room_keys_version( user_id, version_info ) @@ -359,7 +359,7 @@ class E2eRoomKeysHandler: } """ - with (await self._upload_linearizer.queue(user_id)): + async with self._upload_linearizer.queue(user_id): try: res = await self.store.get_e2e_room_keys_version_info(user_id, version) except StoreError as e: @@ -383,7 +383,7 @@ class E2eRoomKeysHandler: NotFoundError: if this backup version doesn't exist """ - with (await self._upload_linearizer.queue(user_id)): + async with self._upload_linearizer.queue(user_id): try: await self.store.delete_e2e_room_keys_version(user_id, version) except StoreError as e: @@ -413,7 +413,7 @@ class E2eRoomKeysHandler: raise SynapseError( 400, "Version in body does not match", Codes.INVALID_PARAM ) - with (await self._upload_linearizer.queue(user_id)): + async with self._upload_linearizer.queue(user_id): try: old_info = await self.store.get_e2e_room_keys_version_info( user_id, version diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 350ec9c03..78d149905 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -151,7 +151,7 @@ class FederationHandler: return. This is used as part of the heuristic to decide if we should back paginate. """ - with (await self._room_backfill.queue(room_id)): + async with self._room_backfill.queue(room_id): return await self._maybe_backfill_inner(room_id, current_depth, limit) async def _maybe_backfill_inner( diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index e7b9f15e1..03c1197c9 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -224,7 +224,7 @@ class FederationEventHandler: len(missing_prevs), shortstr(missing_prevs), ) - with (await self._room_pdu_linearizer.queue(pdu.room_id)): + async with self._room_pdu_linearizer.queue(pdu.room_id): logger.info( "Acquired room lock to fetch %d missing prev_events", len(missing_prevs), diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 766f597a5..7db6905c6 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -851,7 +851,7 @@ class EventCreationHandler: # a situation where event persistence can't keep up, causing # extremities to pile up, which in turn leads to state resolution # taking longer. - with (await self.limiter.queue(event_dict["room_id"])): + async with self.limiter.queue(event_dict["room_id"]): if txn_id and requester.access_token_id: existing_event_id = await self.store.get_event_id_from_transaction_id( event_dict["room_id"], diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index dace31d87..209a4b0e5 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1030,7 +1030,7 @@ class PresenceHandler(BasePresenceHandler): is_syncing: Whether or not the user is now syncing sync_time_msec: Time in ms when the user was last syncing """ - with (await self.external_sync_linearizer.queue(process_id)): + async with self.external_sync_linearizer.queue(process_id): prev_state = await self.current_state_for_user(user_id) process_presence = self.external_process_to_current_syncs.setdefault( @@ -1071,7 +1071,7 @@ class PresenceHandler(BasePresenceHandler): Used when the process has stopped/disappeared. """ - with (await self.external_sync_linearizer.queue(process_id)): + async with self.external_sync_linearizer.queue(process_id): process_presence = self.external_process_to_current_syncs.pop( process_id, set() ) diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index bad1acc63..05122fd5a 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -40,7 +40,7 @@ class ReadMarkerHandler: the read marker has changed. """ - with await self.read_marker_linearizer.queue((room_id, user_id)): + async with self.read_marker_linearizer.queue((room_id, user_id)): existing_read_marker = await self.store.get_account_data_for_room_and_type( user_id, room_id, "m.fully_read" ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 51a08fd2c..65d4aea9a 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -883,7 +883,7 @@ class RoomCreationHandler: # # we also don't need to check the requester's shadow-ban here, as we # have already done so above (and potentially emptied invite_list). - with (await self.room_member_handler.member_linearizer.queue((room_id,))): + async with self.room_member_handler.member_linearizer.queue((room_id,)): content = {} is_direct = config.get("is_direct", None) if is_direct: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 0785e3111..802e57c4d 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -515,8 +515,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # We first linearise by the application service (to try to limit concurrent joins # by application services), and then by room ID. - with (await self.member_as_limiter.queue(as_id)): - with (await self.member_linearizer.queue(key)): + async with self.member_as_limiter.queue(as_id): + async with self.member_linearizer.queue(key): result = await self.update_membership_locked( requester, target, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 4f02a060d..e4fe94e55 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -430,7 +430,7 @@ class SsoHandler: # grab a lock while we try to find a mapping for this user. This seems... # optimistic, especially for implementations that end up redirecting to # interstitial pages. - with await self._mapping_lock.queue(auth_provider_id): + async with self._mapping_lock.queue(auth_provider_id): # first of all, check if we already have a mapping for this user user_id = await self.get_sso_user_by_remote_user_id( auth_provider_id, diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index a402a3e40..b07cf2eee 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -397,7 +397,7 @@ class RulesForRoom: self.room_push_rule_cache_metrics.inc_hits() return self.data.rules_by_user - with (await self.linearizer.queue(self.room_id)): + async with self.linearizer.queue(self.room_id): if state_group and self.data.state_group == state_group: logger.debug("Using cached rules for %r", self.room_id) self.room_push_rule_cache_metrics.inc_hits() diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index deeaaec4e..122892c7b 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -451,7 +451,7 @@ class FederationSenderHandler: # service for robustness? Or could we replace it with an assertion that # we're not being re-entered? - with (await self._fed_position_linearizer.queue(None)): + async with self._fed_position_linearizer.queue(None): # We persist and ack the same position, so we take a copy of it # here as otherwise it can get modified from underneath us. current_position = self.federation_position diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 6c414402b..3e5d6c629 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -258,7 +258,7 @@ class MediaRepository: # We linearize here to ensure that we don't try and download remote # media multiple times concurrently key = (server_name, media_id) - with (await self.remote_media_linearizer.queue(key)): + async with self.remote_media_linearizer.queue(key): responder, media_info = await self._get_remote_media_impl( server_name, media_id ) @@ -294,7 +294,7 @@ class MediaRepository: # We linearize here to ensure that we don't try and download remote # media multiple times concurrently key = (server_name, media_id) - with (await self.remote_media_linearizer.queue(key)): + async with self.remote_media_linearizer.queue(key): responder, media_info = await self._get_remote_media_impl( server_name, media_id ) @@ -850,7 +850,7 @@ class MediaRepository: # TODO: Should we delete from the backup store - with (await self.remote_media_linearizer.queue(key)): + async with self.remote_media_linearizer.queue(key): full_path = self.filepaths.remote_media_filepath(origin, file_id) try: os.remove(full_path) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 21888cc8c..fbf7ba460 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -573,7 +573,7 @@ class StateResolutionHandler: """ group_names = frozenset(state_groups_ids.keys()) - with (await self.resolve_linearizer.queue(group_names)): + async with self.resolve_linearizer.queue(group_names): cache = self._state_cache.get(group_names, None) if cache: return cache diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 98d09b373..48e83592e 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -888,7 +888,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): return frozenset(cache.hosts_to_joined_users) # Since we'll mutate the cache we need to lock. - with (await self._joined_host_linearizer.queue(room_id)): + async with self._joined_host_linearizer.queue(room_id): if state_entry.state_group == cache.state_group: # Same state group, so nothing to do. We've already checked for # this above, but the cache may have changed while waiting on diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 6a8e844d6..4b2a16a6a 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -18,7 +18,7 @@ import collections import inspect import itertools import logging -from contextlib import asynccontextmanager, contextmanager +from contextlib import asynccontextmanager from typing import ( Any, AsyncIterator, @@ -29,7 +29,6 @@ from typing import ( Generic, Hashable, Iterable, - Iterator, List, Optional, Set, @@ -342,7 +341,7 @@ class Linearizer: Example: - with await limiter.queue("test_key"): + async with limiter.queue("test_key"): # do some work. """ @@ -383,95 +382,53 @@ class Linearizer: # non-empty. return bool(entry.deferreds) - def queue(self, key: Hashable) -> defer.Deferred: - # we avoid doing defer.inlineCallbacks here, so that cancellation works correctly. - # (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not - # propagated inside inlineCallbacks until Twisted 18.7) + def queue(self, key: Hashable) -> AsyncContextManager[None]: + @asynccontextmanager + async def _ctx_manager() -> AsyncIterator[None]: + entry = await self._acquire_lock(key) + try: + yield + finally: + self._release_lock(key, entry) + + return _ctx_manager() + + async def _acquire_lock(self, key: Hashable) -> _LinearizerEntry: + """Acquires a linearizer lock, waiting if necessary. + + Returns once we have secured the lock. + """ entry = self.key_to_defer.setdefault( key, _LinearizerEntry(0, collections.OrderedDict()) ) - # If the number of things executing is greater than the maximum - # then add a deferred to the list of blocked items - # When one of the things currently executing finishes it will callback - # this item so that it can continue executing. - if entry.count >= self.max_count: - res = self._await_lock(key) - else: + if entry.count < self.max_count: + # The number of things executing is less than the maximum. logger.debug( "Acquired uncontended linearizer lock %r for key %r", self.name, key ) entry.count += 1 - res = defer.succeed(None) - - # once we successfully get the lock, we need to return a context manager which - # will release the lock. - - @contextmanager - def _ctx_manager(_: None) -> Iterator[None]: - try: - yield - finally: - logger.debug("Releasing linearizer lock %r for key %r", self.name, key) - - # We've finished executing so check if there are any things - # blocked waiting to execute and start one of them - entry.count -= 1 - - if entry.deferreds: - (next_def, _) = entry.deferreds.popitem(last=False) - - # we need to run the next thing in the sentinel context. - with PreserveLoggingContext(): - next_def.callback(None) - elif entry.count == 0: - # We were the last thing for this key: remove it from the - # map. - del self.key_to_defer[key] - - res.addCallback(_ctx_manager) - return res - - def _await_lock(self, key: Hashable) -> defer.Deferred: - """Helper for queue: adds a deferred to the queue - - Assumes that we've already checked that we've reached the limit of the number - of lock-holders we allow. Creates a new deferred which is added to the list, and - adds some management around cancellations. - - Returns the deferred, which will callback once we have secured the lock. - - """ - entry = self.key_to_defer[key] + return entry + # Otherwise, the number of things executing is at the maximum and we have to + # add a deferred to the list of blocked items. + # When one of the things currently executing finishes it will callback + # this item so that it can continue executing. logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key) new_defer: "defer.Deferred[None]" = make_deferred_yieldable(defer.Deferred()) entry.deferreds[new_defer] = 1 - def cb(_r: None) -> "defer.Deferred[None]": - logger.debug("Acquired linearizer lock %r for key %r", self.name, key) - entry.count += 1 - - # if the code holding the lock completes synchronously, then it - # will recursively run the next claimant on the list. That can - # relatively rapidly lead to stack exhaustion. This is essentially - # the same problem as http://twistedmatrix.com/trac/ticket/9304. - # - # In order to break the cycle, we add a cheeky sleep(0) here to - # ensure that we fall back to the reactor between each iteration. - # - # (This needs to happen while we hold the lock, and the context manager's exit - # code must be synchronous, so this is the only sensible place.) - return self._clock.sleep(0) - - def eb(e: Failure) -> Failure: + try: + await new_defer + except Exception as e: logger.info("defer %r got err %r", new_defer, e) if isinstance(e, CancelledError): logger.debug( - "Cancelling wait for linearizer lock %r for key %r", self.name, key + "Cancelling wait for linearizer lock %r for key %r", + self.name, + key, ) - else: logger.warning( "Unexpected exception waiting for linearizer lock %r for key %r", @@ -481,10 +438,43 @@ class Linearizer: # we just have to take ourselves back out of the queue. del entry.deferreds[new_defer] - return e + raise - new_defer.addCallbacks(cb, eb) - return new_defer + logger.debug("Acquired linearizer lock %r for key %r", self.name, key) + entry.count += 1 + + # if the code holding the lock completes synchronously, then it + # will recursively run the next claimant on the list. That can + # relatively rapidly lead to stack exhaustion. This is essentially + # the same problem as http://twistedmatrix.com/trac/ticket/9304. + # + # In order to break the cycle, we add a cheeky sleep(0) here to + # ensure that we fall back to the reactor between each iteration. + # + # This needs to happen while we hold the lock. We could put it on the + # exit path, but that would slow down the uncontended case. + await self._clock.sleep(0) + + return entry + + def _release_lock(self, key: Hashable, entry: _LinearizerEntry) -> None: + """Releases a held linearizer lock.""" + logger.debug("Releasing linearizer lock %r for key %r", self.name, key) + + # We've finished executing so check if there are any things + # blocked waiting to execute and start one of them + entry.count -= 1 + + if entry.deferreds: + (next_def, _) = entry.deferreds.popitem(last=False) + + # we need to run the next thing in the sentinel context. + with PreserveLoggingContext(): + next_def.callback(None) + elif entry.count == 0: + # We were the last thing for this key: remove it from the + # map. + del self.key_to_defer[key] class ReadWriteLock: diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py index fa132391a..c2a209e63 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py @@ -46,7 +46,7 @@ class LinearizerTestCase(unittest.TestCase): unblock_d: "Deferred[None]" = Deferred() async def task() -> None: - with await linearizer.queue(key): + async with linearizer.queue(key): acquired_d.callback(None) await unblock_d @@ -125,7 +125,7 @@ class LinearizerTestCase(unittest.TestCase): async def func(i: int) -> None: with LoggingContext("func(%s)" % i) as lc: - with (await linearizer.queue(key)): + async with linearizer.queue(key): self.assertEqual(current_context(), lc) self.assertEqual(current_context(), lc)