Correctly exclude users when making a room public or private (#11075)

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
This commit is contained in:
David Robertson 2021-10-15 15:53:05 +01:00 committed by GitHub
parent 5573133348
commit e09be0c87a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 148 additions and 83 deletions

1
changelog.d/11075.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a long-standing bug where users excluded from the user directory were added into the directory if they belonged to a room which became public or private.

View File

@ -266,14 +266,17 @@ class UserDirectoryHandler(StateDeltasHandler):
for user_id in users_in_room: for user_id in users_in_room:
await self.store.remove_user_who_share_room(user_id, room_id) await self.store.remove_user_who_share_room(user_id, room_id)
# Then, re-add them to the tables. # Then, re-add all remote users and some local users to the tables.
# NOTE: this is not the most efficient method, as _track_user_joined_room sets # NOTE: this is not the most efficient method, as _track_user_joined_room sets
# up local_user -> other_user and other_user_whos_local -> local_user, # up local_user -> other_user and other_user_whos_local -> local_user,
# which when ran over an entire room, will result in the same values # which when ran over an entire room, will result in the same values
# being added multiple times. The batching upserts shouldn't make this # being added multiple times. The batching upserts shouldn't make this
# too bad, though. # too bad, though.
for user_id in users_in_room: for user_id in users_in_room:
await self._track_user_joined_room(room_id, user_id) if not self.is_mine_id(
user_id
) or await self.store.should_include_local_user_in_dir(user_id):
await self._track_user_joined_room(room_id, user_id)
async def _handle_room_membership_event( async def _handle_room_membership_event(
self, self,
@ -364,8 +367,8 @@ class UserDirectoryHandler(StateDeltasHandler):
"""Someone's just joined a room. Update `users_in_public_rooms` or """Someone's just joined a room. Update `users_in_public_rooms` or
`users_who_share_private_rooms` as appropriate. `users_who_share_private_rooms` as appropriate.
The caller is responsible for ensuring that the given user is not excluded The caller is responsible for ensuring that the given user should be
from the user directory. included in the user directory.
""" """
is_public = await self.store.is_room_world_readable_or_publicly_joinable( is_public = await self.store.is_room_world_readable_or_publicly_joinable(
room_id room_id

View File

@ -109,18 +109,14 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
tok=alice_token, tok=alice_token,
) )
users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) # The user directory should reflect the room memberships above.
in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) users, in_public, in_private = self.get_success(
in_private = self.get_success( self.user_dir_helper.get_tables()
self.user_dir_helper.get_users_who_share_private_rooms()
) )
self.assertEqual(users, {alice, bob}) self.assertEqual(users, {alice, bob})
self.assertEqual(in_public, {(alice, public), (bob, public), (alice, public2)})
self.assertEqual( self.assertEqual(
set(in_public), {(alice, public), (bob, public), (alice, public2)} in_private,
)
self.assertEqual(
self.user_dir_helper._compress_shared(in_private),
{(alice, bob, private), (bob, alice, private)}, {(alice, bob, private), (bob, alice, private)},
) )
@ -209,6 +205,88 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
self.assertEqual(set(in_public), {(user1, room), (user2, room)}) self.assertEqual(set(in_public), {(user1, room), (user2, room)})
def test_excludes_users_when_making_room_public(self) -> None:
# Create a regular user and a support user.
alice = self.register_user("alice", "pass")
alice_token = self.login(alice, "pass")
support = "@support1:test"
self.get_success(
self.store.register_user(
user_id=support, password_hash=None, user_type=UserTypes.SUPPORT
)
)
# Make a public and private room containing Alice and the support user
public, initially_private = self._create_rooms_and_inject_memberships(
alice, alice_token, support
)
self._check_only_one_user_in_directory(alice, public)
# Alice makes the private room public.
self.helper.send_state(
initially_private,
"m.room.join_rules",
{"join_rule": "public"},
tok=alice_token,
)
users, in_public, in_private = self.get_success(
self.user_dir_helper.get_tables()
)
self.assertEqual(users, {alice})
self.assertEqual(in_public, {(alice, public), (alice, initially_private)})
self.assertEqual(in_private, set())
def test_switching_from_private_to_public_to_private(self) -> None:
"""Check we update the room sharing tables when switching a room
from private to public, then back again to private."""
# Alice and Bob share a private room.
alice = self.register_user("alice", "pass")
alice_token = self.login(alice, "pass")
bob = self.register_user("bob", "pass")
bob_token = self.login(bob, "pass")
room = self.helper.create_room_as(alice, is_public=False, tok=alice_token)
self.helper.invite(room, alice, bob, tok=alice_token)
self.helper.join(room, bob, tok=bob_token)
# The user directory should reflect this.
def check_user_dir_for_private_room() -> None:
users, in_public, in_private = self.get_success(
self.user_dir_helper.get_tables()
)
self.assertEqual(users, {alice, bob})
self.assertEqual(in_public, set())
self.assertEqual(in_private, {(alice, bob, room), (bob, alice, room)})
check_user_dir_for_private_room()
# Alice makes the room public.
self.helper.send_state(
room,
"m.room.join_rules",
{"join_rule": "public"},
tok=alice_token,
)
# The user directory should be updated accordingly
users, in_public, in_private = self.get_success(
self.user_dir_helper.get_tables()
)
self.assertEqual(users, {alice, bob})
self.assertEqual(in_public, {(alice, room), (bob, room)})
self.assertEqual(in_private, set())
# Alice makes the room private.
self.helper.send_state(
room,
"m.room.join_rules",
{"join_rule": "invite"},
tok=alice_token,
)
# The user directory should be updated accordingly
check_user_dir_for_private_room()
def _create_rooms_and_inject_memberships( def _create_rooms_and_inject_memberships(
self, creator: str, token: str, joiner: str self, creator: str, token: str, joiner: str
) -> Tuple[str, str]: ) -> Tuple[str, str]:
@ -232,15 +310,18 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
return public_room, private_room return public_room, private_room
def _check_only_one_user_in_directory(self, user: str, public: str) -> None: def _check_only_one_user_in_directory(self, user: str, public: str) -> None:
users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) """Check that the user directory DB tables show that:
in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
in_private = self.get_success(
self.user_dir_helper.get_users_who_share_private_rooms()
)
- only one user is in the user directory
- they belong to exactly one public room
- they don't share a private room with anyone.
"""
users, in_public, in_private = self.get_success(
self.user_dir_helper.get_tables()
)
self.assertEqual(users, {user}) self.assertEqual(users, {user})
self.assertEqual(set(in_public), {(user, public)}) self.assertEqual(in_public, {(user, public)})
self.assertEqual(in_private, []) self.assertEqual(in_private, set())
def test_handle_local_profile_change_with_support_user(self) -> None: def test_handle_local_profile_change_with_support_user(self) -> None:
support_user_id = "@support:test" support_user_id = "@support:test"
@ -581,11 +662,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.user_dir_helper.get_users_in_public_rooms() self.user_dir_helper.get_users_in_public_rooms()
) )
self.assertEqual( self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
self.user_dir_helper._compress_shared(shares_private), self.assertEqual(public_users, set())
{(u1, u2, room), (u2, u1, room)},
)
self.assertEqual(public_users, [])
# We get one search result when searching for user2 by user1. # We get one search result when searching for user2 by user1.
s = self.get_success(self.handler.search_users(u1, "user2", 10)) s = self.get_success(self.handler.search_users(u1, "user2", 10))
@ -610,8 +688,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.user_dir_helper.get_users_in_public_rooms() self.user_dir_helper.get_users_in_public_rooms()
) )
self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set()) self.assertEqual(shares_private, set())
self.assertEqual(public_users, []) self.assertEqual(public_users, set())
# User1 now gets no search results for any of the other users. # User1 now gets no search results for any of the other users.
s = self.get_success(self.handler.search_users(u1, "user2", 10)) s = self.get_success(self.handler.search_users(u1, "user2", 10))
@ -645,11 +723,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.user_dir_helper.get_users_in_public_rooms() self.user_dir_helper.get_users_in_public_rooms()
) )
self.assertEqual( self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
self.user_dir_helper._compress_shared(shares_private), self.assertEqual(public_users, set())
{(u1, u2, room), (u2, u1, room)},
)
self.assertEqual(public_users, [])
# We get one search result when searching for user2 by user1. # We get one search result when searching for user2 by user1.
s = self.get_success(self.handler.search_users(u1, "user2", 10)) s = self.get_success(self.handler.search_users(u1, "user2", 10))
@ -704,11 +779,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.user_dir_helper.get_users_in_public_rooms() self.user_dir_helper.get_users_in_public_rooms()
) )
self.assertEqual( self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
self.user_dir_helper._compress_shared(shares_private), self.assertEqual(public_users, set())
{(u1, u2, room), (u2, u1, room)},
)
self.assertEqual(public_users, [])
# Configure a spam checker. # Configure a spam checker.
spam_checker = self.hs.get_spam_checker() spam_checker = self.hs.get_spam_checker()
@ -740,8 +812,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
) )
# No users share rooms # No users share rooms
self.assertEqual(public_users, []) self.assertEqual(public_users, set())
self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set()) self.assertEqual(shares_private, set())
# Despite not sharing a room, search_all_users means we get a search # Despite not sharing a room, search_all_users means we get a search
# result. # result.

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict, List, Set, Tuple from typing import Any, Dict, Set, Tuple
from unittest import mock from unittest import mock
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
@ -42,18 +42,7 @@ class GetUserDirectoryTables:
def __init__(self, store: DataStore): def __init__(self, store: DataStore):
self.store = store self.store = store
def _compress_shared( async def get_users_in_public_rooms(self) -> Set[Tuple[str, str]]:
self, shared: List[Dict[str, str]]
) -> Set[Tuple[str, str, str]]:
"""
Compress a list of users who share rooms dicts to a list of tuples.
"""
r = set()
for i in shared:
r.add((i["user_id"], i["other_user_id"], i["room_id"]))
return r
async def get_users_in_public_rooms(self) -> List[Tuple[str, str]]:
"""Fetch the entire `users_in_public_rooms` table. """Fetch the entire `users_in_public_rooms` table.
Returns a list of tuples (user_id, room_id) where room_id is public and Returns a list of tuples (user_id, room_id) where room_id is public and
@ -63,24 +52,27 @@ class GetUserDirectoryTables:
"users_in_public_rooms", None, ("user_id", "room_id") "users_in_public_rooms", None, ("user_id", "room_id")
) )
retval = [] retval = set()
for i in r: for i in r:
retval.append((i["user_id"], i["room_id"])) retval.add((i["user_id"], i["room_id"]))
return retval return retval
async def get_users_who_share_private_rooms(self) -> List[Dict[str, str]]: async def get_users_who_share_private_rooms(self) -> Set[Tuple[str, str, str]]:
"""Fetch the entire `users_who_share_private_rooms` table. """Fetch the entire `users_who_share_private_rooms` table.
Returns a dict containing "user_id", "other_user_id" and "room_id" keys. Returns a set of tuples (user_id, other_user_id, room_id) corresponding
The dicts can be flattened to Tuples with the `_compress_shared` method. to the rows of `users_who_share_private_rooms`.
(This seems a little awkward---maybe we could clean this up.)
""" """
return await self.store.db_pool.simple_select_list( rows = await self.store.db_pool.simple_select_list(
"users_who_share_private_rooms", "users_who_share_private_rooms",
None, None,
["user_id", "other_user_id", "room_id"], ["user_id", "other_user_id", "room_id"],
) )
rv = set()
for row in rows:
rv.add((row["user_id"], row["other_user_id"], row["room_id"]))
return rv
async def get_users_in_user_directory(self) -> Set[str]: async def get_users_in_user_directory(self) -> Set[str]:
"""Fetch the set of users in the `user_directory` table. """Fetch the set of users in the `user_directory` table.
@ -113,6 +105,16 @@ class GetUserDirectoryTables:
for row in rows for row in rows
} }
async def get_tables(
self,
) -> Tuple[Set[str], Set[Tuple[str, str]], Set[Tuple[str, str, str]]]:
"""Multiple tests want to inspect these tables, so expose them together."""
return (
await self.get_users_in_user_directory(),
await self.get_users_in_public_rooms(),
await self.get_users_who_share_private_rooms(),
)
class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
"""Ensure that rebuilding the directory writes the correct data to the DB. """Ensure that rebuilding the directory writes the correct data to the DB.
@ -166,8 +168,8 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
) )
# Nothing updated yet # Nothing updated yet
self.assertEqual(shares_private, []) self.assertEqual(shares_private, set())
self.assertEqual(public_users, []) self.assertEqual(public_users, set())
# Ugh, have to reset this flag # Ugh, have to reset this flag
self.store.db_pool.updates._all_done = False self.store.db_pool.updates._all_done = False
@ -236,24 +238,15 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
# Do the initial population of the user directory via the background update # Do the initial population of the user directory via the background update
self._purge_and_rebuild_user_dir() self._purge_and_rebuild_user_dir()
shares_private = self.get_success( users, in_public, in_private = self.get_success(
self.user_dir_helper.get_users_who_share_private_rooms() self.user_dir_helper.get_tables()
)
public_users = self.get_success(
self.user_dir_helper.get_users_in_public_rooms()
) )
# User 1 and User 2 are in the same public room # User 1 and User 2 are in the same public room
self.assertEqual(set(public_users), {(u1, room), (u2, room)}) self.assertEqual(in_public, {(u1, room), (u2, room)})
# User 1 and User 3 share private rooms # User 1 and User 3 share private rooms
self.assertEqual( self.assertEqual(in_private, {(u1, u3, private_room), (u3, u1, private_room)})
self.user_dir_helper._compress_shared(shares_private),
{(u1, u3, private_room), (u3, u1, private_room)},
)
# All three should have entries in the directory # All three should have entries in the directory
users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
self.assertEqual(users, {u1, u2, u3}) self.assertEqual(users, {u1, u2, u3})
# The next four tests (test_population_excludes_*) all set up # The next four tests (test_population_excludes_*) all set up
@ -289,16 +282,12 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
self, normal_user: str, public_room: str, private_room: str self, normal_user: str, public_room: str, private_room: str
) -> None: ) -> None:
# After rebuilding the directory, we should only see the normal user. # After rebuilding the directory, we should only see the normal user.
users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) users, in_public, in_private = self.get_success(
self.user_dir_helper.get_tables()
)
self.assertEqual(users, {normal_user}) self.assertEqual(users, {normal_user})
in_public_rooms = self.get_success( self.assertEqual(in_public, {(normal_user, public_room)})
self.user_dir_helper.get_users_in_public_rooms() self.assertEqual(in_private, set())
)
self.assertEqual(set(in_public_rooms), {(normal_user, public_room)})
in_private_rooms = self.get_success(
self.user_dir_helper.get_users_who_share_private_rooms()
)
self.assertEqual(in_private_rooms, [])
def test_population_excludes_support_user(self) -> None: def test_population_excludes_support_user(self) -> None:
# Create a normal and support user. # Create a normal and support user.