Correctly exclude users when making a room public or private (#11075)

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
This commit is contained in:
David Robertson 2021-10-15 15:53:05 +01:00 committed by GitHub
parent 5573133348
commit e09be0c87a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 148 additions and 83 deletions

View file

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Set, Tuple
from typing import Any, Dict, Set, Tuple
from unittest import mock
from unittest.mock import Mock, patch
@ -42,18 +42,7 @@ class GetUserDirectoryTables:
def __init__(self, store: DataStore):
self.store = store
def _compress_shared(
self, shared: List[Dict[str, str]]
) -> Set[Tuple[str, str, str]]:
"""
Compress a list of users who share rooms dicts to a list of tuples.
"""
r = set()
for i in shared:
r.add((i["user_id"], i["other_user_id"], i["room_id"]))
return r
async def get_users_in_public_rooms(self) -> List[Tuple[str, str]]:
async def get_users_in_public_rooms(self) -> Set[Tuple[str, str]]:
"""Fetch the entire `users_in_public_rooms` table.
Returns a list of tuples (user_id, room_id) where room_id is public and
@ -63,24 +52,27 @@ class GetUserDirectoryTables:
"users_in_public_rooms", None, ("user_id", "room_id")
)
retval = []
retval = set()
for i in r:
retval.append((i["user_id"], i["room_id"]))
retval.add((i["user_id"], i["room_id"]))
return retval
async def get_users_who_share_private_rooms(self) -> List[Dict[str, str]]:
async def get_users_who_share_private_rooms(self) -> Set[Tuple[str, str, str]]:
"""Fetch the entire `users_who_share_private_rooms` table.
Returns a dict containing "user_id", "other_user_id" and "room_id" keys.
The dicts can be flattened to Tuples with the `_compress_shared` method.
(This seems a little awkward---maybe we could clean this up.)
Returns a set of tuples (user_id, other_user_id, room_id) corresponding
to the rows of `users_who_share_private_rooms`.
"""
return await self.store.db_pool.simple_select_list(
rows = await self.store.db_pool.simple_select_list(
"users_who_share_private_rooms",
None,
["user_id", "other_user_id", "room_id"],
)
rv = set()
for row in rows:
rv.add((row["user_id"], row["other_user_id"], row["room_id"]))
return rv
async def get_users_in_user_directory(self) -> Set[str]:
"""Fetch the set of users in the `user_directory` table.
@ -113,6 +105,16 @@ class GetUserDirectoryTables:
for row in rows
}
async def get_tables(
self,
) -> Tuple[Set[str], Set[Tuple[str, str]], Set[Tuple[str, str, str]]]:
"""Multiple tests want to inspect these tables, so expose them together."""
return (
await self.get_users_in_user_directory(),
await self.get_users_in_public_rooms(),
await self.get_users_who_share_private_rooms(),
)
class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
"""Ensure that rebuilding the directory writes the correct data to the DB.
@ -166,8 +168,8 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
)
# Nothing updated yet
self.assertEqual(shares_private, [])
self.assertEqual(public_users, [])
self.assertEqual(shares_private, set())
self.assertEqual(public_users, set())
# Ugh, have to reset this flag
self.store.db_pool.updates._all_done = False
@ -236,24 +238,15 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
# Do the initial population of the user directory via the background update
self._purge_and_rebuild_user_dir()
shares_private = self.get_success(
self.user_dir_helper.get_users_who_share_private_rooms()
)
public_users = self.get_success(
self.user_dir_helper.get_users_in_public_rooms()
users, in_public, in_private = self.get_success(
self.user_dir_helper.get_tables()
)
# User 1 and User 2 are in the same public room
self.assertEqual(set(public_users), {(u1, room), (u2, room)})
self.assertEqual(in_public, {(u1, room), (u2, room)})
# User 1 and User 3 share private rooms
self.assertEqual(
self.user_dir_helper._compress_shared(shares_private),
{(u1, u3, private_room), (u3, u1, private_room)},
)
self.assertEqual(in_private, {(u1, u3, private_room), (u3, u1, private_room)})
# All three should have entries in the directory
users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
self.assertEqual(users, {u1, u2, u3})
# The next four tests (test_population_excludes_*) all set up
@ -289,16 +282,12 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
self, normal_user: str, public_room: str, private_room: str
) -> None:
# After rebuilding the directory, we should only see the normal user.
users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
users, in_public, in_private = self.get_success(
self.user_dir_helper.get_tables()
)
self.assertEqual(users, {normal_user})
in_public_rooms = self.get_success(
self.user_dir_helper.get_users_in_public_rooms()
)
self.assertEqual(set(in_public_rooms), {(normal_user, public_room)})
in_private_rooms = self.get_success(
self.user_dir_helper.get_users_who_share_private_rooms()
)
self.assertEqual(in_private_rooms, [])
self.assertEqual(in_public, {(normal_user, public_room)})
self.assertEqual(in_private, set())
def test_population_excludes_support_user(self) -> None:
# Create a normal and support user.