#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2023 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# .
#
# Originally licensed under the Apache License, Version 2.0:
# .
#
# [This file includes modifications made by New Vector Limited]
#
#

import random
from types import TracebackType
from typing import (
    TYPE_CHECKING,
    AsyncContextManager,
    Collection,
    Dict,
    Optional,
    Tuple,
    Type,
    Union,
)
from weakref import WeakSet

import attr

from twisted.internet import defer
from twisted.internet.interfaces import IReactorTime

from synapse.logging.context import PreserveLoggingContext
from synapse.logging.opentracing import start_active_span
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.databases.main.lock import Lock, LockStore
from synapse.util.async_helpers import timeout_deferred

if TYPE_CHECKING:
    from synapse.logging.opentracing import opentracing
    from synapse.server import HomeServer


# This lock is used to avoid creating an event while we are purging the room.
# We take a read lock when creating an event, and a write one when purging a room.
# This is because it is fine to create several events concurrently, since referenced events
# will not disappear under our feet as long as we don't delete the room.
NEW_EVENT_DURING_PURGE_LOCK_NAME = "new_event_during_purge_lock"

# Ceiling for the fallback polling interval used by `WaitingLock` and
# `WaitingMultiLock`, in the same units as their `_retry_interval` (presumably
# seconds, since it is passed as the `timeout` to `timeout_deferred` — confirm).
# Without a ceiling the interval doubles forever, so a waiter that misses a
# release notification could sleep for an arbitrarily long time.
_MAX_RETRY_INTERVAL = 15 * 60


class WorkerLocksHandler:
    """A class for waiting on taking out locks, rather than using the storage
    functions directly (which don't support awaiting).
    """

    def __init__(self, hs: "HomeServer") -> None:
        self._reactor = hs.get_reactor()
        self._store = hs.get_datastores().main
        self._clock = hs.get_clock()
        self._notifier = hs.get_notifier()
        self._instance_name = hs.get_instance_name()

        # Map from lock name/key to set of `WaitingLock` that are active for
        # that lock. The sets are weak, so a waiter that is no longer
        # referenced anywhere else drops out on its own.
        self._locks: Dict[
            Tuple[str, str], WeakSet[Union[WaitingLock, WaitingMultiLock]]
        ] = {}

        # Periodically drop map entries whose weak set has become empty.
        self._clock.looping_call(self._cleanup_locks, 30_000)

        # Get told whenever a lock release is announced via the notifier.
        self._notifier.add_lock_released_callback(self._on_lock_released)

    def acquire_lock(self, lock_name: str, lock_key: str) -> "WaitingLock":
        """Acquire a standard lock, returns a context manager that will block
        until the lock is acquired.

        Note: Care must be taken to avoid deadlocks. In particular, this
        function does *not* timeout.

        Usage:
            async with handler.acquire_lock(name, key):
                # Do work while holding the lock...
        """

        lock = WaitingLock(
            reactor=self._reactor,
            store=self._store,
            handler=self,
            lock_name=lock_name,
            lock_key=lock_key,
            # `None` selects the plain (non read/write) lock in `WaitingLock`.
            write=None,
        )

        # Register the waiter so `_on_lock_released` can wake it.
        self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)

        return lock

    def acquire_read_write_lock(
        self,
        lock_name: str,
        lock_key: str,
        *,
        write: bool,
    ) -> "WaitingLock":
        """Acquire a read/write lock, returns a context manager that will block
        until the lock is acquired.

        Note: Care must be taken to avoid deadlocks. In particular, this
        function does *not* timeout.

        Usage:
            async with handler.acquire_read_write_lock(name, key, write=True):
                # Do work while holding the lock...
        """

        lock = WaitingLock(
            reactor=self._reactor,
            store=self._store,
            handler=self,
            lock_name=lock_name,
            lock_key=lock_key,
            write=write,
        )

        # Register the waiter so `_on_lock_released` can wake it.
        self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)

        return lock

    def acquire_multi_read_write_lock(
        self,
        lock_names: Collection[Tuple[str, str]],
        *,
        write: bool,
    ) -> "WaitingMultiLock":
        """Acquires multi read/write locks at once, returns a context manager
        that will block until all the locks are acquired.

        This will try and acquire all locks at once, and will never hold on to a
        subset of the locks. (This avoids accidentally creating deadlocks).

        Note: Care must be taken to avoid deadlocks. In particular, this
        function does *not* timeout.
        """

        lock = WaitingMultiLock(
            lock_names=lock_names,
            write=write,
            reactor=self._reactor,
            store=self._store,
            handler=self,
        )

        # The same waiter is registered under every name/key pair, so a
        # release of any one of them wakes it up to retry.
        for lock_name, lock_key in lock_names:
            self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)

        return lock

    def notify_lock_released(self, lock_name: str, lock_key: str) -> None:
        """Notify that a lock has been released.

        Pokes both the notifier and replication.
        """

        self._notifier.notify_lock_released(self._instance_name, lock_name, lock_key)

    def _on_lock_released(
        self, instance_name: str, lock_name: str, lock_key: str
    ) -> None:
        """Called when a lock has been released.

        Wakes up any locks that might be waiting on this.
        """
        locks = self._locks.get((lock_name, lock_key))
        if not locks:
            return

        def _wake_all_locks(
            locks: Collection[Union[WaitingLock, WaitingMultiLock]],
        ) -> None:
            for lock in locks:
                # Read the waiter's *current* deferred: waiters replace it on
                # every acquisition attempt.
                deferred = lock.deferred
                if not deferred.called:
                    deferred.callback(None)

        # Wake the waiters from a separate reactor call rather than inline in
        # this callback.
        self._clock.call_later(0, _wake_all_locks, locks)

    @wrap_as_background_process("_cleanup_locks")
    async def _cleanup_locks(self) -> None:
        """Periodically cleans out stale entries in the locks map"""
        # An empty WeakSet is falsy, so this keeps only keys that still have
        # at least one live waiter.
        self._locks = {key: value for key, value in self._locks.items() if value}


@attr.s(auto_attribs=True, eq=False)
class WaitingLock:
    """Async context manager that keeps retrying until a single lock is
    acquired from the store, then holds it for the duration of the `async with`
    body.

    Instances are created by `WorkerLocksHandler.acquire_lock` and
    `WorkerLocksHandler.acquire_read_write_lock`.
    """

    reactor: IReactorTime
    store: LockStore
    handler: WorkerLocksHandler
    lock_name: str
    lock_key: str
    # `None` for a plain lock; otherwise whether to take the read/write lock in
    # write mode.
    write: Optional[bool]
    # Resolved by `WorkerLocksHandler._on_lock_released` to cut the wait short.
    deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred)
    # The lock returned by the store once acquisition succeeds.
    _inner_lock: Optional[Lock] = None
    # How long to wait before the next acquisition attempt if not notified.
    _retry_interval: float = 0.1
    # Tracing span; created when the instance is constructed, entered in
    # `__aenter__` and exited in `__aexit__`.
    _lock_span: "opentracing.Scope" = attr.Factory(
        lambda: start_active_span("WaitingLock.lock")
    )

    async def __aenter__(self) -> None:
        """Block until the lock has been acquired, then enter it."""
        self._lock_span.__enter__()

        with start_active_span("WaitingLock.waiting_for_lock"):
            while self._inner_lock is None:
                # Fresh deferred per attempt; it is created *before* trying to
                # acquire so a release announced while the attempt is in
                # flight resolves the deferred we are about to wait on.
                self.deferred = defer.Deferred()

                if self.write is not None:
                    lock = await self.store.try_acquire_read_write_lock(
                        self.lock_name, self.lock_key, write=self.write
                    )
                else:
                    lock = await self.store.try_acquire_lock(
                        self.lock_name, self.lock_key
                    )

                if lock:
                    self._inner_lock = lock
                    break

                try:
                    # Wait until we get notified the lock might have been
                    # released (by the deferred being resolved). We also
                    # periodically wake up in case the lock was released but we
                    # weren't notified.
                    with PreserveLoggingContext():
                        await timeout_deferred(
                            deferred=self.deferred,
                            timeout=self._get_next_retry_interval(),
                            reactor=self.reactor,
                        )
                except Exception:
                    # Deliberately best-effort: whatever ended the wait
                    # (presumably including the timeout firing — confirm
                    # against `timeout_deferred`), just go round and retry.
                    pass

        return await self._inner_lock.__aenter__()

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> Optional[bool]:
        """Announce the release, exit the inner lock and close the span.

        Returns whatever the inner lock's `__aexit__` returns.
        """
        assert self._inner_lock

        # NOTE(review): the release is announced *before* the inner lock's
        # `__aexit__` has run, so a woken waiter may retry before the lock is
        # actually free — confirm this ordering is intended.
        self.handler.notify_lock_released(self.lock_name, self.lock_key)

        try:
            r = await self._inner_lock.__aexit__(exc_type, exc, tb)
        finally:
            # Always close the tracing span, even if releasing raised.
            self._lock_span.__exit__(exc_type, exc, tb)

        return r

    def _get_next_retry_interval(self) -> float:
        """Return how long to wait before the next attempt, and back off.

        The first wait is around 0.1; after that the stored interval becomes 5
        and doubles on each call, up to `_MAX_RETRY_INTERVAL`. The returned
        value is the interval *before* backing off, with +/-10% random jitter
        so that concurrent waiters do not retry in lockstep.
        """
        current = self._retry_interval
        # Previously this was `max(5, current * 2)` with no ceiling, so the
        # fallback poll interval grew without bound.
        self._retry_interval = min(_MAX_RETRY_INTERVAL, max(5, current * 2))
        return current * random.uniform(0.9, 1.1)


@attr.s(auto_attribs=True, eq=False)
class WaitingMultiLock:
    """Async context manager that keeps retrying until a whole set of
    read/write locks is acquired from the store in one go, then holds them for
    the duration of the `async with` body.

    Instances are created by `WorkerLocksHandler.acquire_multi_read_write_lock`.
    """

    # The (lock name, lock key) pairs to acquire together.
    lock_names: Collection[Tuple[str, str]]

    # Whether to take the locks in write mode.
    write: bool

    reactor: IReactorTime
    store: LockStore
    handler: WorkerLocksHandler

    # Resolved by `WorkerLocksHandler._on_lock_released` to cut the wait short.
    deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred)

    # The context manager returned by the store once acquisition succeeds.
    _inner_lock_cm: Optional[AsyncContextManager] = None
    # How long to wait before the next acquisition attempt if not notified.
    _retry_interval: float = 0.1
    # Tracing span; created when the instance is constructed, entered in
    # `__aenter__` and exited in `__aexit__`.
    _lock_span: "opentracing.Scope" = attr.Factory(
        lambda: start_active_span("WaitingLock.lock")
    )

    async def __aenter__(self) -> None:
        """Block until all the locks have been acquired, then enter them."""
        self._lock_span.__enter__()

        with start_active_span("WaitingLock.waiting_for_lock"):
            while self._inner_lock_cm is None:
                # Fresh deferred per attempt; it is created *before* trying to
                # acquire so a release announced while the attempt is in
                # flight resolves the deferred we are about to wait on.
                self.deferred = defer.Deferred()

                lock_cm = await self.store.try_acquire_multi_read_write_lock(
                    self.lock_names, write=self.write
                )

                if lock_cm:
                    self._inner_lock_cm = lock_cm
                    break

                try:
                    # Wait until we get notified the lock might have been
                    # released (by the deferred being resolved). We also
                    # periodically wake up in case the lock was released but we
                    # weren't notified.
                    with PreserveLoggingContext():
                        await timeout_deferred(
                            deferred=self.deferred,
                            timeout=self._get_next_retry_interval(),
                            reactor=self.reactor,
                        )
                except Exception:
                    # Deliberately best-effort: whatever ended the wait
                    # (presumably including the timeout firing — confirm
                    # against `timeout_deferred`), just go round and retry.
                    pass

        assert self._inner_lock_cm
        await self._inner_lock_cm.__aenter__()
        return

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> Optional[bool]:
        """Announce the release of every lock, exit the inner context manager
        and close the span.

        Returns whatever the inner context manager's `__aexit__` returns.
        """
        assert self._inner_lock_cm

        # NOTE(review): the releases are announced *before* the inner context
        # manager's `__aexit__` has run, so a woken waiter may retry before the
        # locks are actually free — confirm this ordering is intended.
        for lock_name, lock_key in self.lock_names:
            self.handler.notify_lock_released(lock_name, lock_key)

        try:
            r = await self._inner_lock_cm.__aexit__(exc_type, exc, tb)
        finally:
            # Always close the tracing span, even if releasing raised.
            self._lock_span.__exit__(exc_type, exc, tb)

        return r

    def _get_next_retry_interval(self) -> float:
        """Return how long to wait before the next attempt, and back off.

        The first wait is around 0.1; after that the stored interval becomes 5
        and doubles on each call, up to `_MAX_RETRY_INTERVAL`. The returned
        value is the interval *before* backing off, with +/-10% random jitter
        so that concurrent waiters do not retry in lockstep.
        """
        current = self._retry_interval
        # Previously this was `max(5, current * 2)` with no ceiling, so the
        # fallback poll interval grew without bound.
        self._retry_interval = min(_MAX_RETRY_INTERVAL, max(5, current * 2))
        return current * random.uniform(0.9, 1.1)