Support filtering by relations per MSC3440 (#11236)

Adds experimental support for `relation_types` and `relation_senders`
fields for filters.
This commit is contained in:
Patrick Cloke 2021-11-09 08:10:58 -05:00 committed by GitHub
parent 4b3e30c276
commit a19d01c3d9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 680 additions and 110 deletions

View file

@ -1,7 +1,7 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# Copyright 2018-2019 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -86,6 +86,9 @@ ROOM_EVENT_FILTER_SCHEMA = {
# cf https://github.com/matrix-org/matrix-doc/pull/2326
"org.matrix.labels": {"type": "array", "items": {"type": "string"}},
"org.matrix.not_labels": {"type": "array", "items": {"type": "string"}},
# MSC3440, filtering by event relations.
"io.element.relation_senders": {"type": "array", "items": {"type": "string"}},
"io.element.relation_types": {"type": "array", "items": {"type": "string"}},
},
}
@ -146,14 +149,16 @@ def matrix_user_id_validator(user_id_str: str) -> UserID:
class Filtering:
def __init__(self, hs: "HomeServer"):
super().__init__()
self._hs = hs
self.store = hs.get_datastore()
self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {})
async def get_user_filter(
self, user_localpart: str, filter_id: Union[int, str]
) -> "FilterCollection":
result = await self.store.get_user_filter(user_localpart, filter_id)
return FilterCollection(result)
return FilterCollection(self._hs, result)
def add_user_filter(
self, user_localpart: str, user_filter: JsonDict
@ -191,21 +196,22 @@ FilterEvent = TypeVar("FilterEvent", EventBase, UserPresenceState, JsonDict)
class FilterCollection:
def __init__(self, filter_json: JsonDict):
def __init__(self, hs: "HomeServer", filter_json: JsonDict):
self._filter_json = filter_json
room_filter_json = self._filter_json.get("room", {})
self._room_filter = Filter(
{k: v for k, v in room_filter_json.items() if k in ("rooms", "not_rooms")}
hs,
{k: v for k, v in room_filter_json.items() if k in ("rooms", "not_rooms")},
)
self._room_timeline_filter = Filter(room_filter_json.get("timeline", {}))
self._room_state_filter = Filter(room_filter_json.get("state", {}))
self._room_ephemeral_filter = Filter(room_filter_json.get("ephemeral", {}))
self._room_account_data = Filter(room_filter_json.get("account_data", {}))
self._presence_filter = Filter(filter_json.get("presence", {}))
self._account_data = Filter(filter_json.get("account_data", {}))
self._room_timeline_filter = Filter(hs, room_filter_json.get("timeline", {}))
self._room_state_filter = Filter(hs, room_filter_json.get("state", {}))
self._room_ephemeral_filter = Filter(hs, room_filter_json.get("ephemeral", {}))
self._room_account_data = Filter(hs, room_filter_json.get("account_data", {}))
self._presence_filter = Filter(hs, filter_json.get("presence", {}))
self._account_data = Filter(hs, filter_json.get("account_data", {}))
self.include_leave = filter_json.get("room", {}).get("include_leave", False)
self.event_fields = filter_json.get("event_fields", [])
@ -232,25 +238,37 @@ class FilterCollection:
def include_redundant_members(self) -> bool:
return self._room_state_filter.include_redundant_members
def filter_presence(
async def filter_presence(
self, events: Iterable[UserPresenceState]
) -> List[UserPresenceState]:
return self._presence_filter.filter(events)
return await self._presence_filter.filter(events)
def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
return self._account_data.filter(events)
async def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
return await self._account_data.filter(events)
def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]:
return self._room_state_filter.filter(self._room_filter.filter(events))
async def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]:
return await self._room_state_filter.filter(
await self._room_filter.filter(events)
)
def filter_room_timeline(self, events: Iterable[EventBase]) -> List[EventBase]:
return self._room_timeline_filter.filter(self._room_filter.filter(events))
async def filter_room_timeline(
self, events: Iterable[EventBase]
) -> List[EventBase]:
return await self._room_timeline_filter.filter(
await self._room_filter.filter(events)
)
def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> List[JsonDict]:
return self._room_ephemeral_filter.filter(self._room_filter.filter(events))
async def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> List[JsonDict]:
return await self._room_ephemeral_filter.filter(
await self._room_filter.filter(events)
)
def filter_room_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
return self._room_account_data.filter(self._room_filter.filter(events))
async def filter_room_account_data(
self, events: Iterable[JsonDict]
) -> List[JsonDict]:
return await self._room_account_data.filter(
await self._room_filter.filter(events)
)
def blocks_all_presence(self) -> bool:
return (
@ -274,7 +292,9 @@ class FilterCollection:
class Filter:
def __init__(self, filter_json: JsonDict):
def __init__(self, hs: "HomeServer", filter_json: JsonDict):
self._hs = hs
self._store = hs.get_datastore()
self.filter_json = filter_json
self.limit = filter_json.get("limit", 10)
@ -297,6 +317,20 @@ class Filter:
self.labels = filter_json.get("org.matrix.labels", None)
self.not_labels = filter_json.get("org.matrix.not_labels", [])
# Ideally these would be rejected at the endpoint if they were provided
# and not supported, but that would involve modifying the JSON schema
# based on the homeserver configuration.
if hs.config.experimental.msc3440_enabled:
self.relation_senders = self.filter_json.get(
"io.element.relation_senders", None
)
self.relation_types = self.filter_json.get(
"io.element.relation_types", None
)
else:
self.relation_senders = None
self.relation_types = None
def filters_all_types(self) -> bool:
return "*" in self.not_types
@ -306,7 +340,7 @@ class Filter:
def filters_all_rooms(self) -> bool:
return "*" in self.not_rooms
def check(self, event: FilterEvent) -> bool:
def _check(self, event: FilterEvent) -> bool:
"""Checks whether the filter matches the given event.
Args:
@ -420,8 +454,30 @@ class Filter:
return room_ids
def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
return list(filter(self.check, events))
async def _check_event_relations(
self, events: Iterable[FilterEvent]
) -> List[FilterEvent]:
# The event IDs to check, mypy doesn't understand the ifinstance check.
event_ids = [event.event_id for event in events if isinstance(event, EventBase)] # type: ignore[attr-defined]
event_ids_to_keep = set(
await self._store.events_have_relations(
event_ids, self.relation_senders, self.relation_types
)
)
return [
event
for event in events
if not isinstance(event, EventBase) or event.event_id in event_ids_to_keep
]
async def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
result = [event for event in events if self._check(event)]
if self.relation_senders or self.relation_types:
return await self._check_event_relations(result)
return result
def with_room_ids(self, room_ids: Iterable[str]) -> "Filter":
"""Returns a new filter with the given room IDs appended.
@ -433,7 +489,7 @@ class Filter:
filter: A new filter including the given rooms and the old
filter's rooms.
"""
newFilter = Filter(self.filter_json)
newFilter = Filter(self._hs, self.filter_json)
newFilter.rooms += room_ids
return newFilter
@ -444,6 +500,3 @@ def _matches_wildcard(actual_value: Optional[str], filter_value: str) -> bool:
return actual_value.startswith(type_prefix)
else:
return actual_value == filter_value
DEFAULT_FILTER_COLLECTION = FilterCollection({})