Fix handling of public rooms filter with a network tuple. (#14053)

Fixes two related bugs:

* The handling of `[null]` for a `room_types` filter was incorrect.
* The ordering of arguments when providing both a network tuple
  and room type field was incorrect.
This commit is contained in:
Patrick Cloke 2022-10-05 08:49:52 -04:00 committed by GitHub
parent dcced5a8d7
commit 0b037d6c91
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 58 additions and 27 deletions

1
changelog.d/14053.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a bug introduced in Synapse 1.53.0 when querying `/publicRooms` with both a `room_type` filter and a `third_party_instance_id`.

View File

@ -207,21 +207,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def _construct_room_type_where_clause( def _construct_room_type_where_clause(
self, room_types: Union[List[Union[str, None]], None] self, room_types: Union[List[Union[str, None]], None]
) -> Tuple[Union[str, None], List[str]]: ) -> Tuple[Union[str, None], list]:
if not room_types: if not room_types:
return None, [] return None, []
else:
# We use None when we want get rooms without a type
is_null_clause = ""
if None in room_types:
is_null_clause = "OR room_type IS NULL"
room_types = [value for value in room_types if value is not None]
# Since None is used to represent a room without a type, care needs to
# be taken into account when constructing the where clause.
clauses = []
args: list = []
room_types_set = set(room_types)
# We use None to represent a room without a type.
if None in room_types_set:
clauses.append("room_type IS NULL")
room_types_set.remove(None)
# If there are other room types, generate the proper clause.
if room_types:
list_clause, args = make_in_list_sql_clause( list_clause, args = make_in_list_sql_clause(
self.database_engine, "room_type", room_types self.database_engine, "room_type", room_types_set
) )
clauses.append(list_clause)
return f"({list_clause} {is_null_clause})", args return f"({' OR '.join(clauses)})", args
async def count_public_rooms( async def count_public_rooms(
self, self,
@ -241,14 +250,6 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def _count_public_rooms_txn(txn: LoggingTransaction) -> int: def _count_public_rooms_txn(txn: LoggingTransaction) -> int:
query_args = [] query_args = []
room_type_clause, args = self._construct_room_type_where_clause(
search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
if search_filter
else None
)
room_type_clause = f" AND {room_type_clause}" if room_type_clause else ""
query_args += args
if network_tuple: if network_tuple:
if network_tuple.appservice_id: if network_tuple.appservice_id:
published_sql = """ published_sql = """
@ -268,6 +269,14 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
UNION SELECT room_id from appservice_room_list UNION SELECT room_id from appservice_room_list
""" """
room_type_clause, args = self._construct_room_type_where_clause(
search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
if search_filter
else None
)
room_type_clause = f" AND {room_type_clause}" if room_type_clause else ""
query_args += args
sql = f""" sql = f"""
SELECT SELECT
COUNT(*) COUNT(*)

View File

@ -2213,14 +2213,17 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
) )
def make_public_rooms_request( def make_public_rooms_request(
self, room_types: Union[List[Union[str, None]], None] self,
room_types: Optional[List[Union[str, None]]],
instance_id: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], int]: ) -> Tuple[List[Dict[str, Any]], int]:
channel = self.make_request( body: JsonDict = {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}}
"POST", if instance_id:
self.url, body["third_party_instance_id"] = "test|test"
{"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}},
self.token, channel = self.make_request("POST", self.url, body, self.token)
) self.assertEqual(channel.code, 200)
chunk = channel.json_body["chunk"] chunk = channel.json_body["chunk"]
count = channel.json_body["total_room_count_estimate"] count = channel.json_body["total_room_count_estimate"]
@ -2230,31 +2233,49 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
def test_returns_both_rooms_and_spaces_if_no_filter(self) -> None: def test_returns_both_rooms_and_spaces_if_no_filter(self) -> None:
chunk, count = self.make_public_rooms_request(None) chunk, count = self.make_public_rooms_request(None)
self.assertEqual(count, 2) self.assertEqual(count, 2)
# Also check if there's no filter property at all in the body.
channel = self.make_request("POST", self.url, {}, self.token)
self.assertEqual(channel.code, 200)
self.assertEqual(len(channel.json_body["chunk"]), 2)
self.assertEqual(channel.json_body["total_room_count_estimate"], 2)
chunk, count = self.make_public_rooms_request(None, "test|test")
self.assertEqual(count, 0)
def test_returns_only_rooms_based_on_filter(self) -> None: def test_returns_only_rooms_based_on_filter(self) -> None:
chunk, count = self.make_public_rooms_request([None]) chunk, count = self.make_public_rooms_request([None])
self.assertEqual(count, 1) self.assertEqual(count, 1)
self.assertEqual(chunk[0].get("room_type", None), None) self.assertEqual(chunk[0].get("room_type", None), None)
chunk, count = self.make_public_rooms_request([None], "test|test")
self.assertEqual(count, 0)
def test_returns_only_space_based_on_filter(self) -> None: def test_returns_only_space_based_on_filter(self) -> None:
chunk, count = self.make_public_rooms_request(["m.space"]) chunk, count = self.make_public_rooms_request(["m.space"])
self.assertEqual(count, 1) self.assertEqual(count, 1)
self.assertEqual(chunk[0].get("room_type", None), "m.space") self.assertEqual(chunk[0].get("room_type", None), "m.space")
chunk, count = self.make_public_rooms_request(["m.space"], "test|test")
self.assertEqual(count, 0)
def test_returns_both_rooms_and_space_based_on_filter(self) -> None: def test_returns_both_rooms_and_space_based_on_filter(self) -> None:
chunk, count = self.make_public_rooms_request(["m.space", None]) chunk, count = self.make_public_rooms_request(["m.space", None])
self.assertEqual(count, 2) self.assertEqual(count, 2)
chunk, count = self.make_public_rooms_request(["m.space", None], "test|test")
self.assertEqual(count, 0)
def test_returns_both_rooms_and_spaces_if_array_is_empty(self) -> None: def test_returns_both_rooms_and_spaces_if_array_is_empty(self) -> None:
chunk, count = self.make_public_rooms_request([]) chunk, count = self.make_public_rooms_request([])
self.assertEqual(count, 2) self.assertEqual(count, 2)
chunk, count = self.make_public_rooms_request([], "test|test")
self.assertEqual(count, 0)
class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
"""Test that we correctly fallback to local filtering if a remote server """Test that we correctly fallback to local filtering if a remote server