mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Fix handling of public rooms filter with a network tuple. (#14053)
Fixes two related bugs: * The handling of `[null]` for a `room_types` filter was incorrect. * The ordering of arguments when providing both a network tuple and room type field was incorrect.
This commit is contained in:
parent
dcced5a8d7
commit
0b037d6c91
1
changelog.d/14053.bugfix
Normal file
1
changelog.d/14053.bugfix
Normal file
@ -0,0 +1 @@
|
||||
Fix a bug introduced in Synapse 1.53.0 when querying `/publicRooms` with both a `room_type` filter and a `third_party_instance_id`.
|
@ -207,21 +207,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
||||
|
||||
def _construct_room_type_where_clause(
|
||||
self, room_types: Union[List[Union[str, None]], None]
|
||||
) -> Tuple[Union[str, None], List[str]]:
|
||||
) -> Tuple[Union[str, None], list]:
|
||||
if not room_types:
|
||||
return None, []
|
||||
else:
|
||||
# We use None when we want get rooms without a type
|
||||
is_null_clause = ""
|
||||
if None in room_types:
|
||||
is_null_clause = "OR room_type IS NULL"
|
||||
room_types = [value for value in room_types if value is not None]
|
||||
|
||||
# Since None is used to represent a room without a type, care needs to
|
||||
# be taken into account when constructing the where clause.
|
||||
clauses = []
|
||||
args: list = []
|
||||
|
||||
room_types_set = set(room_types)
|
||||
|
||||
# We use None to represent a room without a type.
|
||||
if None in room_types_set:
|
||||
clauses.append("room_type IS NULL")
|
||||
room_types_set.remove(None)
|
||||
|
||||
# If there are other room types, generate the proper clause.
|
||||
if room_types:
|
||||
list_clause, args = make_in_list_sql_clause(
|
||||
self.database_engine, "room_type", room_types
|
||||
self.database_engine, "room_type", room_types_set
|
||||
)
|
||||
clauses.append(list_clause)
|
||||
|
||||
return f"({list_clause} {is_null_clause})", args
|
||||
return f"({' OR '.join(clauses)})", args
|
||||
|
||||
async def count_public_rooms(
|
||||
self,
|
||||
@ -241,14 +250,6 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
||||
def _count_public_rooms_txn(txn: LoggingTransaction) -> int:
|
||||
query_args = []
|
||||
|
||||
room_type_clause, args = self._construct_room_type_where_clause(
|
||||
search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
|
||||
if search_filter
|
||||
else None
|
||||
)
|
||||
room_type_clause = f" AND {room_type_clause}" if room_type_clause else ""
|
||||
query_args += args
|
||||
|
||||
if network_tuple:
|
||||
if network_tuple.appservice_id:
|
||||
published_sql = """
|
||||
@ -268,6 +269,14 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
||||
UNION SELECT room_id from appservice_room_list
|
||||
"""
|
||||
|
||||
room_type_clause, args = self._construct_room_type_where_clause(
|
||||
search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
|
||||
if search_filter
|
||||
else None
|
||||
)
|
||||
room_type_clause = f" AND {room_type_clause}" if room_type_clause else ""
|
||||
query_args += args
|
||||
|
||||
sql = f"""
|
||||
SELECT
|
||||
COUNT(*)
|
||||
|
@ -2213,14 +2213,17 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
def make_public_rooms_request(
|
||||
self, room_types: Union[List[Union[str, None]], None]
|
||||
self,
|
||||
room_types: Optional[List[Union[str, None]]],
|
||||
instance_id: Optional[str] = None,
|
||||
) -> Tuple[List[Dict[str, Any]], int]:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
self.url,
|
||||
{"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}},
|
||||
self.token,
|
||||
)
|
||||
body: JsonDict = {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}}
|
||||
if instance_id:
|
||||
body["third_party_instance_id"] = "test|test"
|
||||
|
||||
channel = self.make_request("POST", self.url, body, self.token)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
chunk = channel.json_body["chunk"]
|
||||
count = channel.json_body["total_room_count_estimate"]
|
||||
|
||||
@ -2230,31 +2233,49 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
def test_returns_both_rooms_and_spaces_if_no_filter(self) -> None:
|
||||
chunk, count = self.make_public_rooms_request(None)
|
||||
|
||||
self.assertEqual(count, 2)
|
||||
|
||||
# Also check if there's no filter property at all in the body.
|
||||
channel = self.make_request("POST", self.url, {}, self.token)
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(len(channel.json_body["chunk"]), 2)
|
||||
self.assertEqual(channel.json_body["total_room_count_estimate"], 2)
|
||||
|
||||
chunk, count = self.make_public_rooms_request(None, "test|test")
|
||||
self.assertEqual(count, 0)
|
||||
|
||||
def test_returns_only_rooms_based_on_filter(self) -> None:
|
||||
chunk, count = self.make_public_rooms_request([None])
|
||||
|
||||
self.assertEqual(count, 1)
|
||||
self.assertEqual(chunk[0].get("room_type", None), None)
|
||||
|
||||
chunk, count = self.make_public_rooms_request([None], "test|test")
|
||||
self.assertEqual(count, 0)
|
||||
|
||||
def test_returns_only_space_based_on_filter(self) -> None:
|
||||
chunk, count = self.make_public_rooms_request(["m.space"])
|
||||
|
||||
self.assertEqual(count, 1)
|
||||
self.assertEqual(chunk[0].get("room_type", None), "m.space")
|
||||
|
||||
chunk, count = self.make_public_rooms_request(["m.space"], "test|test")
|
||||
self.assertEqual(count, 0)
|
||||
|
||||
def test_returns_both_rooms_and_space_based_on_filter(self) -> None:
|
||||
chunk, count = self.make_public_rooms_request(["m.space", None])
|
||||
|
||||
self.assertEqual(count, 2)
|
||||
|
||||
chunk, count = self.make_public_rooms_request(["m.space", None], "test|test")
|
||||
self.assertEqual(count, 0)
|
||||
|
||||
def test_returns_both_rooms_and_spaces_if_array_is_empty(self) -> None:
|
||||
chunk, count = self.make_public_rooms_request([])
|
||||
|
||||
self.assertEqual(count, 2)
|
||||
|
||||
chunk, count = self.make_public_rooms_request([], "test|test")
|
||||
self.assertEqual(count, 0)
|
||||
|
||||
|
||||
class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
|
||||
"""Test that we correctly fallback to local filtering if a remote server
|
||||
|
Loading…
Reference in New Issue
Block a user