diff --git a/changelog.d/14844.misc b/changelog.d/14844.misc
new file mode 100644
index 000000000..30ce86630
--- /dev/null
+++ b/changelog.d/14844.misc
@@ -0,0 +1 @@
+Add check to avoid starting duplicate partial state syncs.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index eca75f110..e386f77de 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -27,6 +27,7 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Set,
     Tuple,
     Union,
 )
@@ -171,12 +172,23 @@ class FederationHandler:
 
         self.third_party_event_rules = hs.get_third_party_event_rules()
 
+        # Tracks running partial state syncs by room ID.
+        # Partial state syncs currently only run on the main process, so it's okay to
+        # track them in-memory for now.
+        self._active_partial_state_syncs: Set[str] = set()
+        # Tracks partial state syncs we may want to restart.
+        # A dictionary mapping room IDs to (initial destination, other destinations)
+        # tuples.
+        self._partial_state_syncs_maybe_needing_restart: Dict[
+            str, Tuple[Optional[str], Collection[str]]
+        ] = {}
+
         # if this is the main process, fire off a background process to resume
         # any partial-state-resync operations which were in flight when we
         # were shut down.
         if not hs.config.worker.worker_app:
             run_as_background_process(
-                "resume_sync_partial_state_room", self._resume_sync_partial_state_room
+                "resume_sync_partial_state_room", self._resume_partial_state_room_sync
             )
 
     @trace
@@ -679,9 +691,7 @@ class FederationHandler:
             if ret.partial_state:
                 # Kick off the process of asynchronously fetching the state for this
                 # room.
-                run_as_background_process(
-                    desc="sync_partial_state_room",
-                    func=self._sync_partial_state_room,
+                self._start_partial_state_room_sync(
                     initial_destination=origin,
                     other_destinations=ret.servers_in_room,
                     room_id=room_id,
@@ -1660,20 +1670,100 @@ class FederationHandler:
         # well.
         return None
 
-    async def _resume_sync_partial_state_room(self) -> None:
+    async def _resume_partial_state_room_sync(self) -> None:
         """Resumes resyncing of all partial-state rooms after a restart."""
         assert not self.config.worker.worker_app
 
         partial_state_rooms = await self.store.get_partial_state_room_resync_info()
         for room_id, resync_info in partial_state_rooms.items():
-            run_as_background_process(
-                desc="sync_partial_state_room",
-                func=self._sync_partial_state_room,
+            self._start_partial_state_room_sync(
                 initial_destination=resync_info.joined_via,
                 other_destinations=resync_info.servers_in_room,
                 room_id=room_id,
             )
 
+    def _start_partial_state_room_sync(
+        self,
+        initial_destination: Optional[str],
+        other_destinations: Collection[str],
+        room_id: str,
+    ) -> None:
+        """Starts the background process to resync the state of a partial state room,
+        if it is not already running.
+
+        Args:
+            initial_destination: the initial homeserver to pull the state from
+            other_destinations: other homeservers to try to pull the state from, if
+                `initial_destination` is unavailable
+            room_id: room to be resynced
+        """
+
+        async def _sync_partial_state_room_wrapper() -> None:
+            if room_id in self._active_partial_state_syncs:
+                # Another local user has joined the room while there is already a
+                # partial state sync running. This implies that there is a new join
+                # event to un-partial state. We might find ourselves in one of a few
+                # scenarios:
+                #  1. There is an existing partial state sync. The partial state sync
+                #     un-partial states the new join event before completing and all is
+                #     well.
+                #  2. Before the latest join, the homeserver was no longer in the room
+                #     and there is an existing partial state sync from our previous
+                #     membership of the room. The partial state sync may have:
+                #      a) succeeded, but not yet terminated. The room will not be
+                #         un-partial stated again unless we restart the partial state
+                #         sync.
+                #      b) failed, because we were no longer in the room and remote
+                #         homeservers were refusing our requests, but not yet
+                #         terminated. After the latest join, remote homeservers may
+                #         start answering our requests again, so we should restart the
+                #         partial state sync.
+                # In the cases where we would want to restart the partial state sync,
+                # the room would have the partial state flag when the partial state sync
+                # terminates.
+                self._partial_state_syncs_maybe_needing_restart[room_id] = (
+                    initial_destination,
+                    other_destinations,
+                )
+                return
+
+            self._active_partial_state_syncs.add(room_id)
+
+            try:
+                await self._sync_partial_state_room(
+                    initial_destination=initial_destination,
+                    other_destinations=other_destinations,
+                    room_id=room_id,
+                )
+            finally:
+                # Read the room's partial state flag while we still hold the claim to
+                # being the active partial state sync (so that another partial state
+                # sync can't come along and mess with it under us).
+                # Normally, the partial state flag will be gone. If it isn't, then we
+                # may find ourselves in scenario 2a or 2b as described in the comment
+                # above, where we want to restart the partial state sync.
+                is_still_partial_state_room = await self.store.is_partial_state_room(
+                    room_id
+                )
+                self._active_partial_state_syncs.remove(room_id)
+
+                if room_id in self._partial_state_syncs_maybe_needing_restart:
+                    (
+                        restart_initial_destination,
+                        restart_other_destinations,
+                    ) = self._partial_state_syncs_maybe_needing_restart.pop(room_id)
+
+                    if is_still_partial_state_room:
+                        self._start_partial_state_room_sync(
+                            initial_destination=restart_initial_destination,
+                            other_destinations=restart_other_destinations,
+                            room_id=room_id,
+                        )
+
+        run_as_background_process(
+            desc="sync_partial_state_room", func=_sync_partial_state_room_wrapper
+        )
+
     async def _sync_partial_state_room(
         self,
         initial_destination: Optional[str],
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index cedbb9faf..c1558c40c 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import cast
+from typing import Collection, Optional, cast
 from unittest import TestCase
 from unittest.mock import Mock, patch
 
+from twisted.internet.defer import Deferred
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EventTypes
@@ -679,3 +680,112 @@
             f"Stale partial-stated room flag left over for {room_id} after a"
             f" failed do_invite_join!",
         )
+
+    def test_duplicate_partial_state_room_syncs(self) -> None:
+        """
+        Tests that concurrent partial state syncs are not started for the same room.
+ """ + is_partial_state = True + end_sync: "Deferred[None]" = Deferred() + + async def is_partial_state_room(room_id: str) -> bool: + return is_partial_state + + async def sync_partial_state_room( + initial_destination: Optional[str], + other_destinations: Collection[str], + room_id: str, + ) -> None: + nonlocal end_sync + try: + await end_sync + finally: + end_sync = Deferred() + + mock_is_partial_state_room = Mock(side_effect=is_partial_state_room) + mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room) + + fed_handler = self.hs.get_federation_handler() + store = self.hs.get_datastores().main + + with patch.object( + fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room + ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room): + # Start the partial state sync. + fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 1) + + # Try to start another partial state sync. + # Nothing should happen. + fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 1) + + # End the partial state sync + is_partial_state = False + end_sync.callback(None) + + # The partial state sync should not be restarted. + self.assertEqual(mock_sync_partial_state_room.call_count, 1) + + # The next attempt to start the partial state sync should work. + is_partial_state = True + fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 2) + + def test_partial_state_room_sync_restart(self) -> None: + """ + Tests that partial state syncs are restarted when a second partial state sync + was deduplicated and the first partial state sync fails. + """ + is_partial_state = True + end_sync: "Deferred[None]" = Deferred() + + async def is_partial_state_room(room_id: str) -> bool: + return is_partial_state + + async def sync_partial_state_room( + initial_destination: Optional[str], + other_destinations: Collection[str], + room_id: str, + ) -> None: + nonlocal end_sync + try: + await end_sync + finally: + end_sync = Deferred() + + mock_is_partial_state_room = Mock(side_effect=is_partial_state_room) + mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room) + + fed_handler = self.hs.get_federation_handler() + store = self.hs.get_datastores().main + + with patch.object( + fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room + ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room): + # Start the partial state sync. + fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 1) + + # Fail the partial state sync. + # The partial state sync should not be restarted. + end_sync.errback(Exception("Failed to request /state_ids")) + self.assertEqual(mock_sync_partial_state_room.call_count, 1) + + # Start the partial state sync again. + fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 2) + + # Deduplicate another partial state sync. + fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id") + self.assertEqual(mock_sync_partial_state_room.call_count, 2) + + # Fail the partial state sync. + # It should restart with the latest parameters. 
+            end_sync.errback(Exception("Failed to request /state_ids"))
+            self.assertEqual(mock_sync_partial_state_room.call_count, 3)
+            mock_sync_partial_state_room.assert_called_with(
+                initial_destination="hs3",
+                other_destinations=["hs2"],
+                room_id="room_id",
+            )