diff --git a/changelog.d/12310.feature b/changelog.d/12310.feature new file mode 100644 index 000000000..f3fbb298f --- /dev/null +++ b/changelog.d/12310.feature @@ -0,0 +1 @@ +Add a configuration option to remove a specific set of rooms from sync responses. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index a21b48ab2..b8d8c0dbf 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -539,6 +539,15 @@ templates: # #custom_template_directory: /path/to/custom/templates/ +# List of rooms to exclude from sync responses. This is useful for server +# administrators wishing to group users into a room without these users being able +# to see it from their client. +# +# By default, no room is excluded. +# +#exclude_rooms_from_sync: +# - !foo:example.com + # Message retention policy at the server level. # diff --git a/synapse/config/server.py b/synapse/config/server.py index 38de4b800..0f90302c9 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -680,6 +680,10 @@ class ServerConfig(Config): config.get("use_account_validity_in_account_status") or False ) + self.rooms_to_exclude_from_sync: List[str] = ( + config.get("exclude_rooms_from_sync") or [] + ) + def has_tls_listener(self) -> bool: return any(listener.tls for listener in self.listeners) @@ -1234,6 +1238,15 @@ class ServerConfig(Config): # information about using custom templates. # #custom_template_directory: /path/to/custom/templates/ + + # List of rooms to exclude from sync responses. This is useful for server + # administrators wishing to group users into a room without these users being able + # to see it from their client. + # + # By default, no room is excluded. + # + #exclude_rooms_from_sync: + # - !foo:example.com """ % locals() ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 6c569cfb1..bceafca3b 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -298,6 +298,8 @@ class SyncHandler: expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, ) + self.rooms_to_exclude = hs.config.server.rooms_to_exclude_from_sync + async def wait_for_sync_for_user( self, requester: Requester, @@ -1607,13 +1609,15 @@ class SyncHandler: ignored_users = await self.store.ignored_users(user_id) if since_token: room_changes = await self._get_rooms_changed( - sync_result_builder, ignored_users + sync_result_builder, ignored_users, self.rooms_to_exclude ) tags_by_room = await self.store.get_updated_tags( user_id, since_token.account_data_key ) else: - room_changes = await self._get_all_rooms(sync_result_builder, ignored_users) + room_changes = await self._get_all_rooms( + sync_result_builder, ignored_users, self.rooms_to_exclude + ) tags_by_room = await self.store.get_tags_for_user(user_id) log_kv({"rooms_changed": len(room_changes.room_entries)}) @@ -1689,7 +1693,10 @@ class SyncHandler: return False async def _get_rooms_changed( - self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str] + self, + sync_result_builder: "SyncResultBuilder", + ignored_users: FrozenSet[str], + excluded_rooms: List[str], ) -> _RoomChanges: """Determine the changes in rooms to report to the user. @@ -1721,7 +1728,7 @@ class SyncHandler: # _have_rooms_changed. We could keep the results in memory to avoid a # second query, at the cost of more complicated source code. membership_change_events = await self.store.get_membership_changes_for_user( - user_id, since_token.room_key, now_token.room_key + user_id, since_token.room_key, now_token.room_key, excluded_rooms ) mem_change_events_by_room_id: Dict[str, List[EventBase]] = {} @@ -1922,7 +1929,10 @@ class SyncHandler: ) async def _get_all_rooms( - self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str] + self, + sync_result_builder: "SyncResultBuilder", + ignored_users: FrozenSet[str], + ignored_rooms: List[str], ) -> _RoomChanges: """Returns entries for all rooms for the user. @@ -1933,7 +1943,7 @@ class SyncHandler: Args: sync_result_builder ignored_users: Set of users ignored by user. - + ignored_rooms: List of rooms to ignore. """ user_id = sync_result_builder.sync_config.user.to_string() @@ -1944,6 +1954,7 @@ class SyncHandler: room_list = await self.store.get_rooms_for_local_user_where_membership_is( user_id=user_id, membership_list=Membership.LIST, + excluded_rooms=ignored_rooms, ) room_entries = [] diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 3248da535..98d09b373 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -361,7 +361,10 @@ class RoomMemberWorkerStore(EventsWorkerStore): return None async def get_rooms_for_local_user_where_membership_is( - self, user_id: str, membership_list: Collection[str] + self, + user_id: str, + membership_list: Collection[str], + excluded_rooms: Optional[List[str]] = None, ) -> List[RoomsForUser]: """Get all the rooms for this *local* user where the membership for this user matches one in the membership list. @@ -372,6 +375,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): user_id: The user ID. membership_list: A list of synapse.api.constants.Membership values which the user must be in. + excluded_rooms: A list of rooms to ignore. Returns: The RoomsForUser that the user matches the membership types. @@ -386,12 +390,19 @@ class RoomMemberWorkerStore(EventsWorkerStore): membership_list, ) - # Now we filter out forgotten rooms - forgotten_rooms = await self.get_forgotten_rooms_for_user(user_id) - return [room for room in rooms if room.room_id not in forgotten_rooms] + # Now we filter out forgotten and excluded rooms + rooms_to_exclude: Set[str] = await self.get_forgotten_rooms_for_user(user_id) + + if excluded_rooms is not None: + rooms_to_exclude.update(set(excluded_rooms)) + + return [room for room in rooms if room.room_id not in rooms_to_exclude] def _get_rooms_for_local_user_where_membership_is_txn( - self, txn, user_id: str, membership_list: List[str] + self, + txn, + user_id: str, + membership_list: List[str], ) -> List[RoomsForUser]: # Paranoia check. if not self.hs.is_mine_id(user_id): diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 39e1efe37..8e764790d 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -36,7 +36,7 @@ what sort order was used: """ import logging -from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set, Tuple import attr from frozendict import frozendict @@ -585,7 +585,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): return ret, key async def get_membership_changes_for_user( - self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken + self, + user_id: str, + from_key: RoomStreamToken, + to_key: RoomStreamToken, + excluded_rooms: Optional[List[str]] = None, ) -> List[EventBase]: """Fetch membership events for a given user. @@ -610,23 +614,29 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): min_from_id = from_key.stream max_to_id = to_key.get_max_stream_pos() + args: List[Any] = [user_id, min_from_id, max_to_id] + + ignore_room_clause = "" + if excluded_rooms is not None and len(excluded_rooms) > 0: + ignore_room_clause = "AND e.room_id NOT IN (%s)" % ",".join( + "?" for _ in excluded_rooms + ) + args = args + excluded_rooms + sql = """ SELECT m.event_id, instance_name, topological_ordering, stream_ordering FROM events AS e, room_memberships AS m WHERE e.event_id = m.event_id AND m.user_id = ? AND e.stream_ordering > ? AND e.stream_ordering <= ? + %s ORDER BY e.stream_ordering ASC - """ - txn.execute( - sql, - ( - user_id, - min_from_id, - max_to_id, - ), + """ % ( + ignore_room_clause, ) + txn.execute(sql, args) + rows = [ _EventDictReturn(event_id, None, stream_ordering) for event_id, instance_name, topological_ordering, stream_ordering in txn diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index 435101395..f0f3a54f8 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -772,3 +772,65 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase): self.assertIn( self.user_id, device_list_changes, incremental_sync_channel.json_body ) + + +class ExcludeRoomTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + sync.register_servlets, + room.register_servlets, + ] + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self.user_id = self.register_user("user", "password") + self.tok = self.login("user", "password") + + self.excluded_room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + self.included_room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + # We need to manually append the room ID, because we can't know the ID before + # creating the room, and we can't set the config after starting the homeserver. + self.hs.get_sync_handler().rooms_to_exclude.append(self.excluded_room_id) + + def test_join_leave(self) -> None: + """Tests that rooms are correctly excluded from the 'join' and 'leave' sections of + sync responses. + """ + channel = self.make_request("GET", "/sync", access_token=self.tok) + self.assertEqual(channel.code, 200, channel.result) + + self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["join"]) + self.assertIn(self.included_room_id, channel.json_body["rooms"]["join"]) + + self.helper.leave(self.excluded_room_id, self.user_id, tok=self.tok) + self.helper.leave(self.included_room_id, self.user_id, tok=self.tok) + + channel = self.make_request( + "GET", + "/sync?since=" + channel.json_body["next_batch"], + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["leave"]) + self.assertIn(self.included_room_id, channel.json_body["rooms"]["leave"]) + + def test_invite(self) -> None: + """Tests that rooms are correctly excluded from the 'invite' section of sync + responses. + """ + invitee = self.register_user("invitee", "password") + invitee_tok = self.login("invitee", "password") + + self.helper.invite(self.excluded_room_id, self.user_id, invitee, tok=self.tok) + self.helper.invite(self.included_room_id, self.user_id, invitee, tok=self.tok) + + channel = self.make_request("GET", "/sync", access_token=invitee_tok) + self.assertEqual(channel.code, 200, channel.result) + + self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["invite"]) + self.assertIn(self.included_room_id, channel.json_body["rooms"]["invite"])