Support filtering by relations per MSC3440 (#11236)

Adds experimental support for `relation_types` and `relation_senders`
fields for filters.
This commit is contained in:
Patrick Cloke 2021-11-09 08:10:58 -05:00 committed by GitHub
parent 4b3e30c276
commit a19d01c3d9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 680 additions and 110 deletions

View File

@ -0,0 +1 @@
Support filtering by relation senders & types per [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440).

View File

@ -1,7 +1,7 @@
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd # Copyright 2017 Vector Creations Ltd
# Copyright 2018-2019 New Vector Ltd # Copyright 2018-2019 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C. # Copyright 2019-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -86,6 +86,9 @@ ROOM_EVENT_FILTER_SCHEMA = {
# cf https://github.com/matrix-org/matrix-doc/pull/2326 # cf https://github.com/matrix-org/matrix-doc/pull/2326
"org.matrix.labels": {"type": "array", "items": {"type": "string"}}, "org.matrix.labels": {"type": "array", "items": {"type": "string"}},
"org.matrix.not_labels": {"type": "array", "items": {"type": "string"}}, "org.matrix.not_labels": {"type": "array", "items": {"type": "string"}},
# MSC3440, filtering by event relations.
"io.element.relation_senders": {"type": "array", "items": {"type": "string"}},
"io.element.relation_types": {"type": "array", "items": {"type": "string"}},
}, },
} }
@ -146,14 +149,16 @@ def matrix_user_id_validator(user_id_str: str) -> UserID:
class Filtering: class Filtering:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() self._hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {})
async def get_user_filter( async def get_user_filter(
self, user_localpart: str, filter_id: Union[int, str] self, user_localpart: str, filter_id: Union[int, str]
) -> "FilterCollection": ) -> "FilterCollection":
result = await self.store.get_user_filter(user_localpart, filter_id) result = await self.store.get_user_filter(user_localpart, filter_id)
return FilterCollection(result) return FilterCollection(self._hs, result)
def add_user_filter( def add_user_filter(
self, user_localpart: str, user_filter: JsonDict self, user_localpart: str, user_filter: JsonDict
@ -191,21 +196,22 @@ FilterEvent = TypeVar("FilterEvent", EventBase, UserPresenceState, JsonDict)
class FilterCollection: class FilterCollection:
def __init__(self, filter_json: JsonDict): def __init__(self, hs: "HomeServer", filter_json: JsonDict):
self._filter_json = filter_json self._filter_json = filter_json
room_filter_json = self._filter_json.get("room", {}) room_filter_json = self._filter_json.get("room", {})
self._room_filter = Filter( self._room_filter = Filter(
{k: v for k, v in room_filter_json.items() if k in ("rooms", "not_rooms")} hs,
{k: v for k, v in room_filter_json.items() if k in ("rooms", "not_rooms")},
) )
self._room_timeline_filter = Filter(room_filter_json.get("timeline", {})) self._room_timeline_filter = Filter(hs, room_filter_json.get("timeline", {}))
self._room_state_filter = Filter(room_filter_json.get("state", {})) self._room_state_filter = Filter(hs, room_filter_json.get("state", {}))
self._room_ephemeral_filter = Filter(room_filter_json.get("ephemeral", {})) self._room_ephemeral_filter = Filter(hs, room_filter_json.get("ephemeral", {}))
self._room_account_data = Filter(room_filter_json.get("account_data", {})) self._room_account_data = Filter(hs, room_filter_json.get("account_data", {}))
self._presence_filter = Filter(filter_json.get("presence", {})) self._presence_filter = Filter(hs, filter_json.get("presence", {}))
self._account_data = Filter(filter_json.get("account_data", {})) self._account_data = Filter(hs, filter_json.get("account_data", {}))
self.include_leave = filter_json.get("room", {}).get("include_leave", False) self.include_leave = filter_json.get("room", {}).get("include_leave", False)
self.event_fields = filter_json.get("event_fields", []) self.event_fields = filter_json.get("event_fields", [])
@ -232,25 +238,37 @@ class FilterCollection:
def include_redundant_members(self) -> bool: def include_redundant_members(self) -> bool:
return self._room_state_filter.include_redundant_members return self._room_state_filter.include_redundant_members
def filter_presence( async def filter_presence(
self, events: Iterable[UserPresenceState] self, events: Iterable[UserPresenceState]
) -> List[UserPresenceState]: ) -> List[UserPresenceState]:
return self._presence_filter.filter(events) return await self._presence_filter.filter(events)
def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]: async def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
return self._account_data.filter(events) return await self._account_data.filter(events)
def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]: async def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]:
return self._room_state_filter.filter(self._room_filter.filter(events)) return await self._room_state_filter.filter(
await self._room_filter.filter(events)
)
def filter_room_timeline(self, events: Iterable[EventBase]) -> List[EventBase]: async def filter_room_timeline(
return self._room_timeline_filter.filter(self._room_filter.filter(events)) self, events: Iterable[EventBase]
) -> List[EventBase]:
return await self._room_timeline_filter.filter(
await self._room_filter.filter(events)
)
def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> List[JsonDict]: async def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> List[JsonDict]:
return self._room_ephemeral_filter.filter(self._room_filter.filter(events)) return await self._room_ephemeral_filter.filter(
await self._room_filter.filter(events)
)
def filter_room_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]: async def filter_room_account_data(
return self._room_account_data.filter(self._room_filter.filter(events)) self, events: Iterable[JsonDict]
) -> List[JsonDict]:
return await self._room_account_data.filter(
await self._room_filter.filter(events)
)
def blocks_all_presence(self) -> bool: def blocks_all_presence(self) -> bool:
return ( return (
@ -274,7 +292,9 @@ class FilterCollection:
class Filter: class Filter:
def __init__(self, filter_json: JsonDict): def __init__(self, hs: "HomeServer", filter_json: JsonDict):
self._hs = hs
self._store = hs.get_datastore()
self.filter_json = filter_json self.filter_json = filter_json
self.limit = filter_json.get("limit", 10) self.limit = filter_json.get("limit", 10)
@ -297,6 +317,20 @@ class Filter:
self.labels = filter_json.get("org.matrix.labels", None) self.labels = filter_json.get("org.matrix.labels", None)
self.not_labels = filter_json.get("org.matrix.not_labels", []) self.not_labels = filter_json.get("org.matrix.not_labels", [])
# Ideally these would be rejected at the endpoint if they were provided
# and not supported, but that would involve modifying the JSON schema
# based on the homeserver configuration.
if hs.config.experimental.msc3440_enabled:
self.relation_senders = self.filter_json.get(
"io.element.relation_senders", None
)
self.relation_types = self.filter_json.get(
"io.element.relation_types", None
)
else:
self.relation_senders = None
self.relation_types = None
def filters_all_types(self) -> bool: def filters_all_types(self) -> bool:
return "*" in self.not_types return "*" in self.not_types
@ -306,7 +340,7 @@ class Filter:
def filters_all_rooms(self) -> bool: def filters_all_rooms(self) -> bool:
return "*" in self.not_rooms return "*" in self.not_rooms
def check(self, event: FilterEvent) -> bool: def _check(self, event: FilterEvent) -> bool:
"""Checks whether the filter matches the given event. """Checks whether the filter matches the given event.
Args: Args:
@ -420,8 +454,30 @@ class Filter:
return room_ids return room_ids
def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]: async def _check_event_relations(
return list(filter(self.check, events)) self, events: Iterable[FilterEvent]
) -> List[FilterEvent]:
# The event IDs to check, mypy doesn't understand the ifinstance check.
event_ids = [event.event_id for event in events if isinstance(event, EventBase)] # type: ignore[attr-defined]
event_ids_to_keep = set(
await self._store.events_have_relations(
event_ids, self.relation_senders, self.relation_types
)
)
return [
event
for event in events
if not isinstance(event, EventBase) or event.event_id in event_ids_to_keep
]
async def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
result = [event for event in events if self._check(event)]
if self.relation_senders or self.relation_types:
return await self._check_event_relations(result)
return result
def with_room_ids(self, room_ids: Iterable[str]) -> "Filter": def with_room_ids(self, room_ids: Iterable[str]) -> "Filter":
"""Returns a new filter with the given room IDs appended. """Returns a new filter with the given room IDs appended.
@ -433,7 +489,7 @@ class Filter:
filter: A new filter including the given rooms and the old filter: A new filter including the given rooms and the old
filter's rooms. filter's rooms.
""" """
newFilter = Filter(self.filter_json) newFilter = Filter(self._hs, self.filter_json)
newFilter.rooms += room_ids newFilter.rooms += room_ids
return newFilter return newFilter
@ -444,6 +500,3 @@ def _matches_wildcard(actual_value: Optional[str], filter_value: str) -> bool:
return actual_value.startswith(type_prefix) return actual_value.startswith(type_prefix)
else: else:
return actual_value == filter_value return actual_value == filter_value
DEFAULT_FILTER_COLLECTION = FilterCollection({})

View File

@ -424,7 +424,7 @@ class PaginationHandler:
if events: if events:
if event_filter: if event_filter:
events = event_filter.filter(events) events = await event_filter.filter(events)
events = await filter_events_for_client( events = await filter_events_for_client(
self.storage, user_id, events, is_peeking=(member_event_id is None) self.storage, user_id, events, is_peeking=(member_event_id is None)

View File

@ -1158,8 +1158,10 @@ class RoomContextHandler:
) )
if event_filter: if event_filter:
results["events_before"] = event_filter.filter(results["events_before"]) results["events_before"] = await event_filter.filter(
results["events_after"] = event_filter.filter(results["events_after"]) results["events_before"]
)
results["events_after"] = await event_filter.filter(results["events_after"])
results["events_before"] = await filter_evts(results["events_before"]) results["events_before"] = await filter_evts(results["events_before"])
results["events_after"] = await filter_evts(results["events_after"]) results["events_after"] = await filter_evts(results["events_after"])
@ -1195,7 +1197,7 @@ class RoomContextHandler:
state_events = list(state[last_event_id].values()) state_events = list(state[last_event_id].values())
if event_filter: if event_filter:
state_events = event_filter.filter(state_events) state_events = await event_filter.filter(state_events)
results["state"] = await filter_evts(state_events) results["state"] = await filter_evts(state_events)

View File

@ -180,7 +180,7 @@ class SearchHandler:
% (set(group_keys) - {"room_id", "sender"},), % (set(group_keys) - {"room_id", "sender"},),
) )
search_filter = Filter(filter_dict) search_filter = Filter(self.hs, filter_dict)
# TODO: Search through left rooms too # TODO: Search through left rooms too
rooms = await self.store.get_rooms_for_local_user_where_membership_is( rooms = await self.store.get_rooms_for_local_user_where_membership_is(
@ -242,7 +242,7 @@ class SearchHandler:
rank_map.update({r["event"].event_id: r["rank"] for r in results}) rank_map.update({r["event"].event_id: r["rank"] for r in results})
filtered_events = search_filter.filter([r["event"] for r in results]) filtered_events = await search_filter.filter([r["event"] for r in results])
events = await filter_events_for_client( events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events self.storage, user.to_string(), filtered_events
@ -292,7 +292,9 @@ class SearchHandler:
rank_map.update({r["event"].event_id: r["rank"] for r in results}) rank_map.update({r["event"].event_id: r["rank"] for r in results})
filtered_events = search_filter.filter([r["event"] for r in results]) filtered_events = await search_filter.filter(
[r["event"] for r in results]
)
events = await filter_events_for_client( events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events self.storage, user.to_string(), filtered_events

View File

@ -510,7 +510,7 @@ class SyncHandler:
log_kv({"limited": limited}) log_kv({"limited": limited})
if potential_recents: if potential_recents:
recents = sync_config.filter_collection.filter_room_timeline( recents = await sync_config.filter_collection.filter_room_timeline(
potential_recents potential_recents
) )
log_kv({"recents_after_sync_filtering": len(recents)}) log_kv({"recents_after_sync_filtering": len(recents)})
@ -575,8 +575,8 @@ class SyncHandler:
log_kv({"loaded_recents": len(events)}) log_kv({"loaded_recents": len(events)})
loaded_recents = sync_config.filter_collection.filter_room_timeline( loaded_recents = (
events await sync_config.filter_collection.filter_room_timeline(events)
) )
log_kv({"loaded_recents_after_sync_filtering": len(loaded_recents)}) log_kv({"loaded_recents_after_sync_filtering": len(loaded_recents)})
@ -1015,7 +1015,7 @@ class SyncHandler:
return { return {
(e.type, e.state_key): e (e.type, e.state_key): e
for e in sync_config.filter_collection.filter_room_state( for e in await sync_config.filter_collection.filter_room_state(
list(state.values()) list(state.values())
) )
if e.type != EventTypes.Aliases # until MSC2261 or alternative solution if e.type != EventTypes.Aliases # until MSC2261 or alternative solution
@ -1383,7 +1383,7 @@ class SyncHandler:
sync_config.user sync_config.user
) )
account_data_for_user = sync_config.filter_collection.filter_account_data( account_data_for_user = await sync_config.filter_collection.filter_account_data(
[ [
{"type": account_data_type, "content": content} {"type": account_data_type, "content": content}
for account_data_type, content in account_data.items() for account_data_type, content in account_data.items()
@ -1448,7 +1448,7 @@ class SyncHandler:
# Deduplicate the presence entries so that there's at most one per user # Deduplicate the presence entries so that there's at most one per user
presence = list({p.user_id: p for p in presence}.values()) presence = list({p.user_id: p for p in presence}.values())
presence = sync_config.filter_collection.filter_presence(presence) presence = await sync_config.filter_collection.filter_presence(presence)
sync_result_builder.presence = presence sync_result_builder.presence = presence
@ -2021,12 +2021,14 @@ class SyncHandler:
) )
account_data_events = ( account_data_events = (
sync_config.filter_collection.filter_room_account_data( await sync_config.filter_collection.filter_room_account_data(
account_data_events account_data_events
) )
) )
ephemeral = sync_config.filter_collection.filter_room_ephemeral(ephemeral) ephemeral = await sync_config.filter_collection.filter_room_ephemeral(
ephemeral
)
if not ( if not (
always_include always_include

View File

@ -583,6 +583,7 @@ class RoomEventContextServlet(RestServlet):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self._hs = hs
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.room_context_handler = hs.get_room_context_handler() self.room_context_handler = hs.get_room_context_handler()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
@ -600,7 +601,9 @@ class RoomEventContextServlet(RestServlet):
filter_str = parse_string(request, "filter", encoding="utf-8") filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str: if filter_str:
filter_json = urlparse.unquote(filter_str) filter_json = urlparse.unquote(filter_str)
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json)) event_filter: Optional[Filter] = Filter(
self._hs, json_decoder.decode(filter_json)
)
else: else:
event_filter = None event_filter = None

View File

@ -550,6 +550,7 @@ class RoomMessageListRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self._hs = hs
self.pagination_handler = hs.get_pagination_handler() self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -567,7 +568,9 @@ class RoomMessageListRestServlet(RestServlet):
filter_str = parse_string(request, "filter", encoding="utf-8") filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str: if filter_str:
filter_json = urlparse.unquote(filter_str) filter_json = urlparse.unquote(filter_str)
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json)) event_filter: Optional[Filter] = Filter(
self._hs, json_decoder.decode(filter_json)
)
if ( if (
event_filter event_filter
and event_filter.filter_json.get("event_format", "client") and event_filter.filter_json.get("event_format", "client")
@ -672,6 +675,7 @@ class RoomEventContextServlet(RestServlet):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self._hs = hs
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.room_context_handler = hs.get_room_context_handler() self.room_context_handler = hs.get_room_context_handler()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
@ -688,7 +692,9 @@ class RoomEventContextServlet(RestServlet):
filter_str = parse_string(request, "filter", encoding="utf-8") filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str: if filter_str:
filter_json = urlparse.unquote(filter_str) filter_json = urlparse.unquote(filter_str)
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json)) event_filter: Optional[Filter] = Filter(
self._hs, json_decoder.decode(filter_json)
)
else: else:
event_filter = None event_filter = None

View File

@ -29,7 +29,7 @@ from typing import (
from synapse.api.constants import Membership, PresenceState from synapse.api.constants import Membership, PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState from synapse.api.presence import UserPresenceState
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.utils import ( from synapse.events.utils import (
@ -150,7 +150,7 @@ class SyncRestServlet(RestServlet):
request_key = (user, timeout, since, filter_id, full_state, device_id) request_key = (user, timeout, since, filter_id, full_state, device_id)
if filter_id is None: if filter_id is None:
filter_collection = DEFAULT_FILTER_COLLECTION filter_collection = self.filtering.DEFAULT_FILTER_COLLECTION
elif filter_id.startswith("{"): elif filter_id.startswith("{"):
try: try:
filter_object = json_decoder.decode(filter_id) filter_object = json_decoder.decode(filter_id)
@ -160,7 +160,7 @@ class SyncRestServlet(RestServlet):
except Exception: except Exception:
raise SynapseError(400, "Invalid filter JSON") raise SynapseError(400, "Invalid filter JSON")
self.filtering.check_valid_filter(filter_object) self.filtering.check_valid_filter(filter_object)
filter_collection = FilterCollection(filter_object) filter_collection = FilterCollection(self.hs, filter_object)
else: else:
try: try:
filter_collection = await self.filtering.get_user_filter( filter_collection = await self.filtering.get_user_filter(

View File

@ -20,7 +20,7 @@ import attr
from synapse.api.constants import RelationTypes from synapse.api.constants import RelationTypes
from synapse.events import EventBase from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.relations import ( from synapse.storage.relations import (
AggregationPaginationToken, AggregationPaginationToken,
@ -334,6 +334,62 @@ class RelationsWorkerStore(SQLBaseStore):
return count, latest_event return count, latest_event
async def events_have_relations(
self,
parent_ids: List[str],
relation_senders: Optional[List[str]],
relation_types: Optional[List[str]],
) -> List[str]:
"""Check which events have a relationship from the given senders of the
given types.
Args:
parent_ids: The events being annotated
relation_senders: The relation senders to check.
relation_types: The relation types to check.
Returns:
True if the event has at least one relationship from one of the given senders of the given type.
"""
# If no restrictions are given then the event has the required relations.
if not relation_senders and not relation_types:
return parent_ids
sql = """
SELECT relates_to_id FROM event_relations
INNER JOIN events USING (event_id)
WHERE
%s;
"""
def _get_if_event_has_relations(txn) -> List[str]:
clauses: List[str] = []
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", parent_ids
)
clauses.append(clause)
if relation_senders:
clause, temp_args = make_in_list_sql_clause(
txn.database_engine, "sender", relation_senders
)
clauses.append(clause)
args.extend(temp_args)
if relation_types:
clause, temp_args = make_in_list_sql_clause(
txn.database_engine, "relation_type", relation_types
)
clauses.append(clause)
args.extend(temp_args)
txn.execute(sql % " AND ".join(clauses), args)
return [row[0] for row in txn]
return await self.db_pool.runInteraction(
"get_if_event_has_relations", _get_if_event_has_relations
)
async def has_user_annotated_event( async def has_user_annotated_event(
self, parent_id: str, event_type: str, aggregation_key: str, sender: str self, parent_id: str, event_type: str, aggregation_key: str, sender: str
) -> bool: ) -> bool:

View File

@ -272,31 +272,37 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
args = [] args = []
if event_filter.types: if event_filter.types:
clauses.append("(%s)" % " OR ".join("type = ?" for _ in event_filter.types)) clauses.append(
"(%s)" % " OR ".join("event.type = ?" for _ in event_filter.types)
)
args.extend(event_filter.types) args.extend(event_filter.types)
for typ in event_filter.not_types: for typ in event_filter.not_types:
clauses.append("type != ?") clauses.append("event.type != ?")
args.append(typ) args.append(typ)
if event_filter.senders: if event_filter.senders:
clauses.append("(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders)) clauses.append(
"(%s)" % " OR ".join("event.sender = ?" for _ in event_filter.senders)
)
args.extend(event_filter.senders) args.extend(event_filter.senders)
for sender in event_filter.not_senders: for sender in event_filter.not_senders:
clauses.append("sender != ?") clauses.append("event.sender != ?")
args.append(sender) args.append(sender)
if event_filter.rooms: if event_filter.rooms:
clauses.append("(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms)) clauses.append(
"(%s)" % " OR ".join("event.room_id = ?" for _ in event_filter.rooms)
)
args.extend(event_filter.rooms) args.extend(event_filter.rooms)
for room_id in event_filter.not_rooms: for room_id in event_filter.not_rooms:
clauses.append("room_id != ?") clauses.append("event.room_id != ?")
args.append(room_id) args.append(room_id)
if event_filter.contains_url: if event_filter.contains_url:
clauses.append("contains_url = ?") clauses.append("event.contains_url = ?")
args.append(event_filter.contains_url) args.append(event_filter.contains_url)
# We're only applying the "labels" filter on the database query, because applying the # We're only applying the "labels" filter on the database query, because applying the
@ -307,6 +313,23 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
clauses.append("(%s)" % " OR ".join("label = ?" for _ in event_filter.labels)) clauses.append("(%s)" % " OR ".join("label = ?" for _ in event_filter.labels))
args.extend(event_filter.labels) args.extend(event_filter.labels)
# Filter on relation_senders / relation types from the joined tables.
if event_filter.relation_senders:
clauses.append(
"(%s)"
% " OR ".join(
"related_event.sender = ?" for _ in event_filter.relation_senders
)
)
args.extend(event_filter.relation_senders)
if event_filter.relation_types:
clauses.append(
"(%s)"
% " OR ".join("relation_type = ?" for _ in event_filter.relation_types)
)
args.extend(event_filter.relation_types)
return " AND ".join(clauses), args return " AND ".join(clauses), args
@ -1116,7 +1139,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
bounds = generate_pagination_where_clause( bounds = generate_pagination_where_clause(
direction=direction, direction=direction,
column_names=("topological_ordering", "stream_ordering"), column_names=("event.topological_ordering", "event.stream_ordering"),
from_token=from_bound, from_token=from_bound,
to_token=to_bound, to_token=to_bound,
engine=self.database_engine, engine=self.database_engine,
@ -1133,32 +1156,51 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
select_keywords = "SELECT" select_keywords = "SELECT"
join_clause = "" join_clause = ""
# Using DISTINCT in this SELECT query is quite expensive, because it
# requires the engine to sort on the entire (not limited) result set,
# i.e. the entire events table. Only use it in scenarios that could result
# in the same event ID occurring multiple times in the results.
needs_distinct = False
if event_filter and event_filter.labels: if event_filter and event_filter.labels:
# If we're not filtering on a label, then joining on event_labels will # If we're not filtering on a label, then joining on event_labels will
# return as many row for a single event as the number of labels it has. To # return as many row for a single event as the number of labels it has. To
# avoid this, only join if we're filtering on at least one label. # avoid this, only join if we're filtering on at least one label.
join_clause = """ join_clause += """
LEFT JOIN event_labels LEFT JOIN event_labels
USING (event_id, room_id, topological_ordering) USING (event_id, room_id, topological_ordering)
""" """
if len(event_filter.labels) > 1: if len(event_filter.labels) > 1:
# Using DISTINCT in this SELECT query is quite expensive, because it # Multiple labels could cause the same event to appear multiple times.
# requires the engine to sort on the entire (not limited) result set, needs_distinct = True
# i.e. the entire events table. We only need to use it when we're
# filtering on more than two labels, because that's the only scenario # If there is a filter on relation_senders and relation_types join to the
# in which we can possibly to get multiple times the same event ID in # relations table.
# the results. if event_filter and (
event_filter.relation_senders or event_filter.relation_types
):
# Filtering by relations could cause the same event to appear multiple
# times (since there's no limit on the number of relations to an event).
needs_distinct = True
join_clause += """
LEFT JOIN event_relations AS relation ON (event.event_id = relation.relates_to_id)
"""
if event_filter.relation_senders:
join_clause += """
LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id)
"""
if needs_distinct:
select_keywords += " DISTINCT" select_keywords += " DISTINCT"
sql = """ sql = """
%(select_keywords)s %(select_keywords)s
event_id, instance_name, event.event_id, event.instance_name,
topological_ordering, stream_ordering event.topological_ordering, event.stream_ordering
FROM events FROM events AS event
%(join_clause)s %(join_clause)s
WHERE outlier = ? AND room_id = ? AND %(bounds)s WHERE event.outlier = ? AND event.room_id = ? AND %(bounds)s
ORDER BY topological_ordering %(order)s, ORDER BY event.topological_ordering %(order)s,
stream_ordering %(order)s LIMIT ? event.stream_ordering %(order)s LIMIT ?
""" % { """ % {
"select_keywords": select_keywords, "select_keywords": select_keywords,
"join_clause": join_clause, "join_clause": join_clause,

View File

@ -15,6 +15,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from unittest.mock import patch
import jsonschema import jsonschema
from synapse.api.constants import EventContentFields from synapse.api.constants import EventContentFields
@ -51,9 +53,8 @@ class FilteringTestCase(unittest.HomeserverTestCase):
{"presence": {"senders": ["@bar;pik.test.com"]}}, {"presence": {"senders": ["@bar;pik.test.com"]}},
] ]
for filter in invalid_filters: for filter in invalid_filters:
with self.assertRaises(SynapseError) as check_filter_error: with self.assertRaises(SynapseError):
self.filtering.check_valid_filter(filter) self.filtering.check_valid_filter(filter)
self.assertIsInstance(check_filter_error.exception, SynapseError)
def test_valid_filters(self): def test_valid_filters(self):
valid_filters = [ valid_filters = [
@ -119,12 +120,12 @@ class FilteringTestCase(unittest.HomeserverTestCase):
definition = {"types": ["m.room.message", "org.matrix.foo.bar"]} definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar") event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
self.assertTrue(Filter(definition).check(event)) self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_types_works_with_wildcards(self): def test_definition_types_works_with_wildcards(self):
definition = {"types": ["m.*", "org.matrix.foo.bar"]} definition = {"types": ["m.*", "org.matrix.foo.bar"]}
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar") event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
self.assertTrue(Filter(definition).check(event)) self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_types_works_with_unknowns(self): def test_definition_types_works_with_unknowns(self):
definition = {"types": ["m.room.message", "org.matrix.foo.bar"]} definition = {"types": ["m.room.message", "org.matrix.foo.bar"]}
@ -133,24 +134,24 @@ class FilteringTestCase(unittest.HomeserverTestCase):
type="now.for.something.completely.different", type="now.for.something.completely.different",
room_id="!foo:bar", room_id="!foo:bar",
) )
self.assertFalse(Filter(definition).check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_types_works_with_literals(self): def test_definition_not_types_works_with_literals(self):
definition = {"not_types": ["m.room.message", "org.matrix.foo.bar"]} definition = {"not_types": ["m.room.message", "org.matrix.foo.bar"]}
event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar") event = MockEvent(sender="@foo:bar", type="m.room.message", room_id="!foo:bar")
self.assertFalse(Filter(definition).check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_types_works_with_wildcards(self): def test_definition_not_types_works_with_wildcards(self):
definition = {"not_types": ["m.room.message", "org.matrix.*"]} definition = {"not_types": ["m.room.message", "org.matrix.*"]}
event = MockEvent( event = MockEvent(
sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar" sender="@foo:bar", type="org.matrix.custom.event", room_id="!foo:bar"
) )
self.assertFalse(Filter(definition).check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_types_works_with_unknowns(self): def test_definition_not_types_works_with_unknowns(self):
definition = {"not_types": ["m.*", "org.*"]} definition = {"not_types": ["m.*", "org.*"]}
event = MockEvent(sender="@foo:bar", type="com.nom.nom.nom", room_id="!foo:bar") event = MockEvent(sender="@foo:bar", type="com.nom.nom.nom", room_id="!foo:bar")
self.assertTrue(Filter(definition).check(event)) self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_not_types_takes_priority_over_types(self): def test_definition_not_types_takes_priority_over_types(self):
definition = { definition = {
@ -158,35 +159,35 @@ class FilteringTestCase(unittest.HomeserverTestCase):
"types": ["m.room.message", "m.room.topic"], "types": ["m.room.message", "m.room.topic"],
} }
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar") event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
self.assertFalse(Filter(definition).check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_senders_works_with_literals(self): def test_definition_senders_works_with_literals(self):
definition = {"senders": ["@flibble:wibble"]} definition = {"senders": ["@flibble:wibble"]}
event = MockEvent( event = MockEvent(
sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar" sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
) )
self.assertTrue(Filter(definition).check(event)) self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_senders_works_with_unknowns(self): def test_definition_senders_works_with_unknowns(self):
definition = {"senders": ["@flibble:wibble"]} definition = {"senders": ["@flibble:wibble"]}
event = MockEvent( event = MockEvent(
sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar" sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
) )
self.assertFalse(Filter(definition).check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_senders_works_with_literals(self): def test_definition_not_senders_works_with_literals(self):
definition = {"not_senders": ["@flibble:wibble"]} definition = {"not_senders": ["@flibble:wibble"]}
event = MockEvent( event = MockEvent(
sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar" sender="@flibble:wibble", type="com.nom.nom.nom", room_id="!foo:bar"
) )
self.assertFalse(Filter(definition).check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_senders_works_with_unknowns(self): def test_definition_not_senders_works_with_unknowns(self):
definition = {"not_senders": ["@flibble:wibble"]} definition = {"not_senders": ["@flibble:wibble"]}
event = MockEvent( event = MockEvent(
sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar" sender="@challenger:appears", type="com.nom.nom.nom", room_id="!foo:bar"
) )
self.assertTrue(Filter(definition).check(event)) self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_not_senders_takes_priority_over_senders(self): def test_definition_not_senders_takes_priority_over_senders(self):
definition = { definition = {
@ -196,14 +197,14 @@ class FilteringTestCase(unittest.HomeserverTestCase):
event = MockEvent( event = MockEvent(
sender="@misspiggy:muppets", type="m.room.topic", room_id="!foo:bar" sender="@misspiggy:muppets", type="m.room.topic", room_id="!foo:bar"
) )
self.assertFalse(Filter(definition).check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_rooms_works_with_literals(self): def test_definition_rooms_works_with_literals(self):
definition = {"rooms": ["!secretbase:unknown"]} definition = {"rooms": ["!secretbase:unknown"]}
event = MockEvent( event = MockEvent(
sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown" sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown"
) )
self.assertTrue(Filter(definition).check(event)) self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_rooms_works_with_unknowns(self): def test_definition_rooms_works_with_unknowns(self):
definition = {"rooms": ["!secretbase:unknown"]} definition = {"rooms": ["!secretbase:unknown"]}
@ -212,7 +213,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
type="m.room.message", type="m.room.message",
room_id="!anothersecretbase:unknown", room_id="!anothersecretbase:unknown",
) )
self.assertFalse(Filter(definition).check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_rooms_works_with_literals(self): def test_definition_not_rooms_works_with_literals(self):
definition = {"not_rooms": ["!anothersecretbase:unknown"]} definition = {"not_rooms": ["!anothersecretbase:unknown"]}
@ -221,7 +222,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
type="m.room.message", type="m.room.message",
room_id="!anothersecretbase:unknown", room_id="!anothersecretbase:unknown",
) )
self.assertFalse(Filter(definition).check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_not_rooms_works_with_unknowns(self): def test_definition_not_rooms_works_with_unknowns(self):
definition = {"not_rooms": ["!secretbase:unknown"]} definition = {"not_rooms": ["!secretbase:unknown"]}
@ -230,7 +231,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
type="m.room.message", type="m.room.message",
room_id="!anothersecretbase:unknown", room_id="!anothersecretbase:unknown",
) )
self.assertTrue(Filter(definition).check(event)) self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_not_rooms_takes_priority_over_rooms(self): def test_definition_not_rooms_takes_priority_over_rooms(self):
definition = { definition = {
@ -240,7 +241,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
event = MockEvent( event = MockEvent(
sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown" sender="@foo:bar", type="m.room.message", room_id="!secretbase:unknown"
) )
self.assertFalse(Filter(definition).check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_combined_event(self): def test_definition_combined_event(self):
definition = { definition = {
@ -256,7 +257,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
type="m.room.message", # yup type="m.room.message", # yup
room_id="!stage:unknown", # yup room_id="!stage:unknown", # yup
) )
self.assertTrue(Filter(definition).check(event)) self.assertTrue(Filter(self.hs, definition)._check(event))
def test_definition_combined_event_bad_sender(self): def test_definition_combined_event_bad_sender(self):
definition = { definition = {
@ -272,7 +273,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
type="m.room.message", # yup type="m.room.message", # yup
room_id="!stage:unknown", # yup room_id="!stage:unknown", # yup
) )
self.assertFalse(Filter(definition).check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_combined_event_bad_room(self): def test_definition_combined_event_bad_room(self):
definition = { definition = {
@ -288,7 +289,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
type="m.room.message", # yup type="m.room.message", # yup
room_id="!piggyshouse:muppets", # nope room_id="!piggyshouse:muppets", # nope
) )
self.assertFalse(Filter(definition).check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
def test_definition_combined_event_bad_type(self): def test_definition_combined_event_bad_type(self):
definition = { definition = {
@ -304,7 +305,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
type="muppets.misspiggy.kisses", # nope type="muppets.misspiggy.kisses", # nope
room_id="!stage:unknown", # yup room_id="!stage:unknown", # yup
) )
self.assertFalse(Filter(definition).check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
def test_filter_labels(self): def test_filter_labels(self):
definition = {"org.matrix.labels": ["#fun"]} definition = {"org.matrix.labels": ["#fun"]}
@ -315,7 +316,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
content={EventContentFields.LABELS: ["#fun"]}, content={EventContentFields.LABELS: ["#fun"]},
) )
self.assertTrue(Filter(definition).check(event)) self.assertTrue(Filter(self.hs, definition)._check(event))
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar",
@ -324,7 +325,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
content={EventContentFields.LABELS: ["#notfun"]}, content={EventContentFields.LABELS: ["#notfun"]},
) )
self.assertFalse(Filter(definition).check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
def test_filter_not_labels(self): def test_filter_not_labels(self):
definition = {"org.matrix.not_labels": ["#fun"]} definition = {"org.matrix.not_labels": ["#fun"]}
@ -335,7 +336,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
content={EventContentFields.LABELS: ["#fun"]}, content={EventContentFields.LABELS: ["#fun"]},
) )
self.assertFalse(Filter(definition).check(event)) self.assertFalse(Filter(self.hs, definition)._check(event))
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar",
@ -344,7 +345,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
content={EventContentFields.LABELS: ["#notfun"]}, content={EventContentFields.LABELS: ["#notfun"]},
) )
self.assertTrue(Filter(definition).check(event)) self.assertTrue(Filter(self.hs, definition)._check(event))
def test_filter_presence_match(self): def test_filter_presence_match(self):
user_filter_json = {"presence": {"types": ["m.*"]}} user_filter_json = {"presence": {"types": ["m.*"]}}
@ -362,7 +363,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
) )
) )
results = user_filter.filter_presence(events=events) results = self.get_success(user_filter.filter_presence(events=events))
self.assertEquals(events, results) self.assertEquals(events, results)
def test_filter_presence_no_match(self): def test_filter_presence_no_match(self):
@ -386,7 +387,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
) )
) )
results = user_filter.filter_presence(events=events) results = self.get_success(user_filter.filter_presence(events=events))
self.assertEquals([], results) self.assertEquals([], results)
def test_filter_room_state_match(self): def test_filter_room_state_match(self):
@ -405,7 +406,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
) )
) )
results = user_filter.filter_room_state(events=events) results = self.get_success(user_filter.filter_room_state(events=events))
self.assertEquals(events, results) self.assertEquals(events, results)
def test_filter_room_state_no_match(self): def test_filter_room_state_no_match(self):
@ -426,7 +427,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
) )
) )
results = user_filter.filter_room_state(events) results = self.get_success(user_filter.filter_room_state(events))
self.assertEquals([], results) self.assertEquals([], results)
def test_filter_rooms(self): def test_filter_rooms(self):
@ -441,10 +442,52 @@ class FilteringTestCase(unittest.HomeserverTestCase):
"!not_included:example.com", # Disallowed because not in rooms. "!not_included:example.com", # Disallowed because not in rooms.
] ]
filtered_room_ids = list(Filter(definition).filter_rooms(room_ids)) filtered_room_ids = list(Filter(self.hs, definition).filter_rooms(room_ids))
self.assertEquals(filtered_room_ids, ["!allowed:example.com"]) self.assertEquals(filtered_room_ids, ["!allowed:example.com"])
@unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_filter_relations(self):
events = [
# An event without a relation.
MockEvent(
event_id="$no_relation",
sender="@foo:bar",
type="org.matrix.custom.event",
room_id="!foo:bar",
),
# An event with a relation.
MockEvent(
event_id="$with_relation",
sender="@foo:bar",
type="org.matrix.custom.event",
room_id="!foo:bar",
),
# Non-EventBase objects get passed through.
{},
]
# For the following tests we patch the datastore method (intead of injecting
# events). This is a bit cheeky, but tests the logic of _check_event_relations.
# Filter for a particular sender.
definition = {
"io.element.relation_senders": ["@foo:bar"],
}
async def events_have_relations(*args, **kwargs):
return ["$with_relation"]
with patch.object(
self.datastore, "events_have_relations", new=events_have_relations
):
filtered_events = list(
self.get_success(
Filter(self.hs, definition)._check_event_relations(events)
)
)
self.assertEquals(filtered_events, events[1:])
def test_add_filter(self): def test_add_filter(self):
user_filter_json = {"room": {"state": {"types": ["m.*"]}}} user_filter_json = {"room": {"state": {"types": ["m.*"]}}}

View File

@ -13,10 +13,11 @@
# limitations under the License. # limitations under the License.
from typing import Optional from typing import Optional
from unittest.mock import Mock
from synapse.api.constants import EventTypes, JoinRules from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, ResourceLimitError from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION from synapse.api.filtering import Filtering
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.handlers.sync import SyncConfig from synapse.handlers.sync import SyncConfig
from synapse.rest import admin from synapse.rest import admin
@ -197,7 +198,7 @@ def generate_sync_config(
_request_key += 1 _request_key += 1
return SyncConfig( return SyncConfig(
user=UserID.from_string(user_id), user=UserID.from_string(user_id),
filter_collection=DEFAULT_FILTER_COLLECTION, filter_collection=Filtering(Mock()).DEFAULT_FILTER_COLLECTION,
is_guest=False, is_guest=False,
request_key=("request_key", _request_key), request_key=("request_key", _request_key),
device_id=device_id, device_id=device_id,

View File

@ -25,7 +25,12 @@ from urllib import parse as urlparse
from twisted.internet import defer from twisted.internet import defer
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.constants import (
EventContentFields,
EventTypes,
Membership,
RelationTypes,
)
from synapse.api.errors import Codes, HttpResponseException from synapse.api.errors import Codes, HttpResponseException
from synapse.handlers.pagination import PurgeStatus from synapse.handlers.pagination import PurgeStatus
from synapse.rest import admin from synapse.rest import admin
@ -2157,6 +2162,153 @@ class LabelsTestCase(unittest.HomeserverTestCase):
return event_id return event_id
class RelationsTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
]
def default_config(self):
config = super().default_config()
config["experimental_features"] = {"msc3440_enabled": True}
return config
def prepare(self, reactor, clock, homeserver):
self.user_id = self.register_user("test", "test")
self.tok = self.login("test", "test")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
self.second_user_id = self.register_user("second", "test")
self.second_tok = self.login("second", "test")
self.helper.join(
room=self.room_id, user=self.second_user_id, tok=self.second_tok
)
self.third_user_id = self.register_user("third", "test")
self.third_tok = self.login("third", "test")
self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok)
# An initial event with a relation from second user.
res = self.helper.send_event(
room_id=self.room_id,
type=EventTypes.Message,
content={"msgtype": "m.text", "body": "Message 1"},
tok=self.tok,
)
self.event_id_1 = res["event_id"]
self.helper.send_event(
room_id=self.room_id,
type="m.reaction",
content={
"m.relates_to": {
"rel_type": RelationTypes.ANNOTATION,
"event_id": self.event_id_1,
"key": "👍",
}
},
tok=self.second_tok,
)
# Another event with a relation from third user.
res = self.helper.send_event(
room_id=self.room_id,
type=EventTypes.Message,
content={"msgtype": "m.text", "body": "Message 2"},
tok=self.tok,
)
self.event_id_2 = res["event_id"]
self.helper.send_event(
room_id=self.room_id,
type="m.reaction",
content={
"m.relates_to": {
"rel_type": RelationTypes.REFERENCE,
"event_id": self.event_id_2,
}
},
tok=self.third_tok,
)
# An event with no relations.
self.helper.send_event(
room_id=self.room_id,
type=EventTypes.Message,
content={"msgtype": "m.text", "body": "No relations"},
tok=self.tok,
)
def _filter_messages(self, filter: JsonDict) -> List[JsonDict]:
"""Make a request to /messages with a filter, returns the chunk of events."""
channel = self.make_request(
"GET",
"/rooms/%s/messages?filter=%s&dir=b" % (self.room_id, json.dumps(filter)),
access_token=self.tok,
)
self.assertEqual(channel.code, 200, channel.result)
return channel.json_body["chunk"]
def test_filter_relation_senders(self):
# Messages which second user reacted to.
filter = {"io.element.relation_senders": [self.second_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0]["event_id"], self.event_id_1)
# Messages which third user reacted to.
filter = {"io.element.relation_senders": [self.third_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0]["event_id"], self.event_id_2)
# Messages which either user reacted to.
filter = {
"io.element.relation_senders": [self.second_user_id, self.third_user_id]
}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 2, chunk)
self.assertCountEqual(
[c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
)
def test_filter_relation_type(self):
# Messages which have annotations.
filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0]["event_id"], self.event_id_1)
# Messages which have references.
filter = {"io.element.relation_types": [RelationTypes.REFERENCE]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0]["event_id"], self.event_id_2)
# Messages which have either annotations or references.
filter = {
"io.element.relation_types": [
RelationTypes.ANNOTATION,
RelationTypes.REFERENCE,
]
}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 2, chunk)
self.assertCountEqual(
[c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
)
def test_filter_relation_senders_and_type(self):
# Messages which second user reacted to.
filter = {
"io.element.relation_senders": [self.second_user_id],
"io.element.relation_types": [RelationTypes.ANNOTATION],
}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0]["event_id"], self.event_id_1)
class ContextTestCase(unittest.HomeserverTestCase): class ContextTestCase(unittest.HomeserverTestCase):
servlets = [ servlets = [

View File

@ -0,0 +1,207 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.types import JsonDict
from tests.unittest import HomeserverTestCase
class PaginationTestCase(HomeserverTestCase):
"""
Test the pre-filtering done in the pagination code.
This is similar to some of the tests in tests.rest.client.test_rooms but here
we ensure that the filtering done in the database is applied successfully.
"""
servlets = [
admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
]
def default_config(self):
config = super().default_config()
config["experimental_features"] = {"msc3440_enabled": True}
return config
def prepare(self, reactor, clock, homeserver):
self.user_id = self.register_user("test", "test")
self.tok = self.login("test", "test")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
self.second_user_id = self.register_user("second", "test")
self.second_tok = self.login("second", "test")
self.helper.join(
room=self.room_id, user=self.second_user_id, tok=self.second_tok
)
self.third_user_id = self.register_user("third", "test")
self.third_tok = self.login("third", "test")
self.helper.join(room=self.room_id, user=self.third_user_id, tok=self.third_tok)
# An initial event with a relation from second user.
res = self.helper.send_event(
room_id=self.room_id,
type=EventTypes.Message,
content={"msgtype": "m.text", "body": "Message 1"},
tok=self.tok,
)
self.event_id_1 = res["event_id"]
self.helper.send_event(
room_id=self.room_id,
type="m.reaction",
content={
"m.relates_to": {
"rel_type": RelationTypes.ANNOTATION,
"event_id": self.event_id_1,
"key": "👍",
}
},
tok=self.second_tok,
)
# Another event with a relation from third user.
res = self.helper.send_event(
room_id=self.room_id,
type=EventTypes.Message,
content={"msgtype": "m.text", "body": "Message 2"},
tok=self.tok,
)
self.event_id_2 = res["event_id"]
self.helper.send_event(
room_id=self.room_id,
type="m.reaction",
content={
"m.relates_to": {
"rel_type": RelationTypes.REFERENCE,
"event_id": self.event_id_2,
}
},
tok=self.third_tok,
)
# An event with no relations.
self.helper.send_event(
room_id=self.room_id,
type=EventTypes.Message,
content={"msgtype": "m.text", "body": "No relations"},
tok=self.tok,
)
def _filter_messages(self, filter: JsonDict) -> List[EventBase]:
"""Make a request to /messages with a filter, returns the chunk of events."""
from_token = self.get_success(
self.hs.get_event_sources().get_current_token_for_pagination()
)
events, next_key = self.get_success(
self.hs.get_datastore().paginate_room_events(
room_id=self.room_id,
from_key=from_token.room_key,
to_key=None,
direction="b",
limit=10,
event_filter=Filter(self.hs, filter),
)
)
return events
def test_filter_relation_senders(self):
# Messages which second user reacted to.
filter = {"io.element.relation_senders": [self.second_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0].event_id, self.event_id_1)
# Messages which third user reacted to.
filter = {"io.element.relation_senders": [self.third_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0].event_id, self.event_id_2)
# Messages which either user reacted to.
filter = {
"io.element.relation_senders": [self.second_user_id, self.third_user_id]
}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 2, chunk)
self.assertCountEqual(
[c.event_id for c in chunk], [self.event_id_1, self.event_id_2]
)
def test_filter_relation_type(self):
# Messages which have annotations.
filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0].event_id, self.event_id_1)
# Messages which have references.
filter = {"io.element.relation_types": [RelationTypes.REFERENCE]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0].event_id, self.event_id_2)
# Messages which have either annotations or references.
filter = {
"io.element.relation_types": [
RelationTypes.ANNOTATION,
RelationTypes.REFERENCE,
]
}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 2, chunk)
self.assertCountEqual(
[c.event_id for c in chunk], [self.event_id_1, self.event_id_2]
)
def test_filter_relation_senders_and_type(self):
# Messages which second user reacted to.
filter = {
"io.element.relation_senders": [self.second_user_id],
"io.element.relation_types": [RelationTypes.ANNOTATION],
}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0].event_id, self.event_id_1)
def test_duplicate_relation(self):
"""An event should only be returned once if there are multiple relations to it."""
self.helper.send_event(
room_id=self.room_id,
type="m.reaction",
content={
"m.relates_to": {
"rel_type": RelationTypes.ANNOTATION,
"event_id": self.event_id_1,
"key": "A",
}
},
tok=self.second_tok,
)
filter = {"io.element.relation_senders": [self.second_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0].event_id, self.event_id_1)