Add is_invite filtering to Sliding Sync /sync (#17335)

Based on [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575): Sliding Sync
This commit is contained in:
Eric Eastwood 2024-06-24 19:07:56 -05:00 committed by GitHub
parent 805e6c9a8f
commit 6e8af83193
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 199 additions and 43 deletions

View File

@ -0,0 +1 @@
Add `is_invite` filtering to experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint.

View File

@ -554,7 +554,7 @@ class SlidingSyncHandler:
# Flatten out the map # Flatten out the map
dm_room_id_set = set() dm_room_id_set = set()
if dm_map: if isinstance(dm_map, dict):
for room_ids in dm_map.values(): for room_ids in dm_map.values():
# Account data should be a list of room IDs. Ignore anything else # Account data should be a list of room IDs. Ignore anything else
if isinstance(room_ids, list): if isinstance(room_ids, list):
@ -593,8 +593,21 @@ class SlidingSyncHandler:
): ):
filtered_room_id_set.remove(room_id) filtered_room_id_set.remove(room_id)
if filters.is_invite: # Filter for rooms that the user has been invited to
raise NotImplementedError() if filters.is_invite is not None:
# Make a copy so we don't run into an error: `Set changed size during
# iteration`, when we filter out and remove items
for room_id in list(filtered_room_id_set):
room_for_user = sync_room_map[room_id]
# If we're looking for invite rooms, filter out rooms that the user is
# not invited to and vice versa
if (
filters.is_invite and room_for_user.membership != Membership.INVITE
) or (
not filters.is_invite
and room_for_user.membership == Membership.INVITE
):
filtered_room_id_set.remove(room_id)
if filters.room_types: if filters.room_types:
raise NotImplementedError() raise NotImplementedError()

View File

@ -1200,11 +1200,7 @@ class FilterRoomsTestCase(HomeserverTestCase):
user2_tok = self.login(user2_id, "pass") user2_tok = self.login(user2_id, "pass")
# Create a normal room # Create a normal room
room_id = self.helper.create_room_as( room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
user1_id,
is_public=False,
tok=user1_tok,
)
# Create a DM room # Create a DM room
dm_room_id = self._create_dm_room( dm_room_id = self._create_dm_room(
@ -1261,18 +1257,10 @@ class FilterRoomsTestCase(HomeserverTestCase):
user1_tok = self.login(user1_id, "pass") user1_tok = self.login(user1_id, "pass")
# Create a normal room # Create a normal room
room_id = self.helper.create_room_as( room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
user1_id,
is_public=False,
tok=user1_tok,
)
# Create an encrypted room # Create an encrypted room
encrypted_room_id = self.helper.create_room_as( encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
user1_id,
is_public=False,
tok=user1_tok,
)
self.helper.send_state( self.helper.send_state(
encrypted_room_id, encrypted_room_id,
EventTypes.RoomEncryption, EventTypes.RoomEncryption,
@ -1319,6 +1307,62 @@ class FilterRoomsTestCase(HomeserverTestCase):
self.assertEqual(falsy_filtered_room_map.keys(), {room_id}) self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
def test_filter_invite_rooms(self) -> None:
"""
Test `filter.is_invite` for rooms that the user has been invited to
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
user2_id = self.register_user("user2", "pass")
user2_tok = self.login(user2_id, "pass")
# Create a normal room
room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
self.helper.join(room_id, user1_id, tok=user1_tok)
# Create a room that user1 is invited to
invite_room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok)
after_rooms_token = self.event_sources.get_current_token()
# Get the rooms the user should be syncing with
sync_room_map = self.get_success(
self.sliding_sync_handler.get_sync_room_ids_for_user(
UserID.from_string(user1_id),
from_token=None,
to_token=after_rooms_token,
)
)
# Try with `is_invite=True`
truthy_filtered_room_map = self.get_success(
self.sliding_sync_handler.filter_rooms(
UserID.from_string(user1_id),
sync_room_map,
SlidingSyncConfig.SlidingSyncList.Filters(
is_invite=True,
),
after_rooms_token,
)
)
self.assertEqual(truthy_filtered_room_map.keys(), {invite_room_id})
# Try with `is_invite=False`
falsy_filtered_room_map = self.get_success(
self.sliding_sync_handler.filter_rooms(
UserID.from_string(user1_id),
sync_room_map,
SlidingSyncConfig.SlidingSyncList.Filters(
is_invite=False,
),
after_rooms_token,
)
)
self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
class SortRoomsTestCase(HomeserverTestCase): class SortRoomsTestCase(HomeserverTestCase):
""" """

View File

@ -19,7 +19,8 @@
# #
# #
import json import json
from typing import List import logging
from typing import Dict, List
from parameterized import parameterized, parameterized_class from parameterized import parameterized, parameterized_class
@ -44,6 +45,8 @@ from tests.federation.transport.test_knocking import (
) )
from tests.server import TimedOutException from tests.server import TimedOutException
logger = logging.getLogger(__name__)
class FilterTestCase(unittest.HomeserverTestCase): class FilterTestCase(unittest.HomeserverTestCase):
user_id = "@apple:test" user_id = "@apple:test"
@ -1234,12 +1237,58 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.event_sources = hs.get_event_sources() self.event_sources = hs.get_event_sources()
def _add_new_dm_to_global_account_data(
self, source_user_id: str, target_user_id: str, target_room_id: str
) -> None:
"""
Helper to handle inserting a new DM for the source user into global account data
(handles all of the list merging).
Args:
source_user_id: The user ID of the DM mapping we're going to update
target_user_id: User ID of the person the DM is with
target_room_id: Room ID of the DM
"""
# Get the current DM map
existing_dm_map = self.get_success(
self.store.get_global_account_data_by_type_for_user(
source_user_id, AccountDataTypes.DIRECT
)
)
# Scrutinize the account data since it has no concrete type. We're just copying
# everything into a known type. It should be a mapping from user ID to a list of
# room IDs. Ignore anything else.
new_dm_map: Dict[str, List[str]] = {}
if isinstance(existing_dm_map, dict):
for user_id, room_ids in existing_dm_map.items():
if isinstance(user_id, str) and isinstance(room_ids, list):
for room_id in room_ids:
if isinstance(room_id, str):
new_dm_map[user_id] = new_dm_map.get(user_id, []) + [
room_id
]
# Add the new DM to the map
new_dm_map[target_user_id] = new_dm_map.get(target_user_id, []) + [
target_room_id
]
# Save the DM map to global account data
self.get_success(
self.store.add_account_data_for_user(
source_user_id,
AccountDataTypes.DIRECT,
new_dm_map,
)
)
def _create_dm_room( def _create_dm_room(
self, self,
inviter_user_id: str, inviter_user_id: str,
inviter_tok: str, inviter_tok: str,
invitee_user_id: str, invitee_user_id: str,
invitee_tok: str, invitee_tok: str,
should_join_room: bool = True,
) -> str: ) -> str:
""" """
Helper to create a DM room as the "inviter" and invite the "invitee" user to the Helper to create a DM room as the "inviter" and invite the "invitee" user to the
@ -1260,24 +1309,17 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
tok=inviter_tok, tok=inviter_tok,
extra_data={"is_direct": True}, extra_data={"is_direct": True},
) )
# Person that was invited joins the room if should_join_room:
self.helper.join(room_id, invitee_user_id, tok=invitee_tok) # Person that was invited joins the room
self.helper.join(room_id, invitee_user_id, tok=invitee_tok)
# Mimic the client setting the room as a direct message in the global account # Mimic the client setting the room as a direct message in the global account
# data # data for both users.
self.get_success( self._add_new_dm_to_global_account_data(
self.store.add_account_data_for_user( invitee_user_id, inviter_user_id, room_id
invitee_user_id,
AccountDataTypes.DIRECT,
{inviter_user_id: [room_id]},
)
) )
self.get_success( self._add_new_dm_to_global_account_data(
self.store.add_account_data_for_user( inviter_user_id, invitee_user_id, room_id
inviter_user_id,
AccountDataTypes.DIRECT,
{invitee_user_id: [room_id]},
)
) )
return room_id return room_id
@ -1397,15 +1439,28 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
user2_tok = self.login(user2_id, "pass") user2_tok = self.login(user2_id, "pass")
# Create a DM room # Create a DM room
dm_room_id = self._create_dm_room( joined_dm_room_id = self._create_dm_room(
inviter_user_id=user1_id, inviter_user_id=user1_id,
inviter_tok=user1_tok, inviter_tok=user1_tok,
invitee_user_id=user2_id, invitee_user_id=user2_id,
invitee_tok=user2_tok, invitee_tok=user2_tok,
should_join_room=True,
)
invited_dm_room_id = self._create_dm_room(
inviter_user_id=user1_id,
inviter_tok=user1_tok,
invitee_user_id=user2_id,
invitee_tok=user2_tok,
should_join_room=False,
) )
# Create a normal room # Create a normal room
room_id = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True) room_id = self.helper.create_room_as(user1_id, tok=user2_tok)
self.helper.join(room_id, user1_id, tok=user1_tok)
# Create a room that user1 is invited to
invite_room_id = self.helper.create_room_as(user1_id, tok=user2_tok)
self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok)
# Make the Sliding Sync request # Make the Sliding Sync request
channel = self.make_request( channel = self.make_request(
@ -1413,18 +1468,34 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
self.sync_endpoint, self.sync_endpoint,
{ {
"lists": { "lists": {
# Absense of filters does not imply "False" values
"all": {
"ranges": [[0, 99]],
"required_state": [],
"timeline_limit": 1,
"filters": {},
},
# Test single truthy filter
"dms": { "dms": {
"ranges": [[0, 99]], "ranges": [[0, 99]],
"required_state": [], "required_state": [],
"timeline_limit": 1, "timeline_limit": 1,
"filters": {"is_dm": True}, "filters": {"is_dm": True},
}, },
"foo-list": { # Test single falsy filter
"non-dms": {
"ranges": [[0, 99]], "ranges": [[0, 99]],
"required_state": [], "required_state": [],
"timeline_limit": 1, "timeline_limit": 1,
"filters": {"is_dm": False}, "filters": {"is_dm": False},
}, },
# Test how multiple filters should stack (AND'd together)
"room-invites": {
"ranges": [[0, 99]],
"required_state": [],
"timeline_limit": 1,
"filters": {"is_dm": False, "is_invite": True},
},
} }
}, },
access_token=user1_tok, access_token=user1_tok,
@ -1434,32 +1505,59 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
# Make sure it has the foo-list we requested # Make sure it has the foo-list we requested
self.assertListEqual( self.assertListEqual(
list(channel.json_body["lists"].keys()), list(channel.json_body["lists"].keys()),
["dms", "foo-list"], ["all", "dms", "non-dms", "room-invites"],
channel.json_body["lists"].keys(), channel.json_body["lists"].keys(),
) )
# Make sure the list includes the room we are joined to # Make sure the lists have the correct rooms
self.assertListEqual(
list(channel.json_body["lists"]["all"]["ops"]),
[
{
"op": "SYNC",
"range": [0, 99],
"room_ids": [
invite_room_id,
room_id,
invited_dm_room_id,
joined_dm_room_id,
],
}
],
list(channel.json_body["lists"]["all"]),
)
self.assertListEqual( self.assertListEqual(
list(channel.json_body["lists"]["dms"]["ops"]), list(channel.json_body["lists"]["dms"]["ops"]),
[ [
{ {
"op": "SYNC", "op": "SYNC",
"range": [0, 99], "range": [0, 99],
"room_ids": [dm_room_id], "room_ids": [invited_dm_room_id, joined_dm_room_id],
} }
], ],
list(channel.json_body["lists"]["dms"]), list(channel.json_body["lists"]["dms"]),
) )
self.assertListEqual( self.assertListEqual(
list(channel.json_body["lists"]["foo-list"]["ops"]), list(channel.json_body["lists"]["non-dms"]["ops"]),
[ [
{ {
"op": "SYNC", "op": "SYNC",
"range": [0, 99], "range": [0, 99],
"room_ids": [room_id], "room_ids": [invite_room_id, room_id],
} }
], ],
list(channel.json_body["lists"]["foo-list"]), list(channel.json_body["lists"]["non-dms"]),
)
self.assertListEqual(
list(channel.json_body["lists"]["room-invites"]["ops"]),
[
{
"op": "SYNC",
"range": [0, 99],
"room_ids": [invite_room_id],
}
],
list(channel.json_body["lists"]["room-invites"]),
) )
def test_sort_list(self) -> None: def test_sort_list(self) -> None: