Refactor Filter to handle fields according to data being filtered. (#11194)

This avoids filtering against fields which cannot exist on an
event source. E.g. presence updates don't have a room.
This commit is contained in:
Patrick Cloke 2021-10-27 11:26:30 -04:00 committed by GitHub
parent 8d46fac98e
commit 19d5dc6931
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 86 additions and 68 deletions

1
changelog.d/11194.misc Normal file
View File

@ -0,0 +1 @@
Refactor `Filter` to check different fields depending on the data type.

View File

@ -18,7 +18,8 @@ import json
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Awaitable, Awaitable,
Container, Callable,
Dict,
Iterable, Iterable,
List, List,
Optional, Optional,
@ -217,19 +218,19 @@ class FilterCollection:
return self._filter_json return self._filter_json
def timeline_limit(self) -> int: def timeline_limit(self) -> int:
return self._room_timeline_filter.limit() return self._room_timeline_filter.limit
def presence_limit(self) -> int: def presence_limit(self) -> int:
return self._presence_filter.limit() return self._presence_filter.limit
def ephemeral_limit(self) -> int: def ephemeral_limit(self) -> int:
return self._room_ephemeral_filter.limit() return self._room_ephemeral_filter.limit
def lazy_load_members(self) -> bool: def lazy_load_members(self) -> bool:
return self._room_state_filter.lazy_load_members() return self._room_state_filter.lazy_load_members
def include_redundant_members(self) -> bool: def include_redundant_members(self) -> bool:
return self._room_state_filter.include_redundant_members() return self._room_state_filter.include_redundant_members
def filter_presence( def filter_presence(
self, events: Iterable[UserPresenceState] self, events: Iterable[UserPresenceState]
@ -276,19 +277,25 @@ class Filter:
def __init__(self, filter_json: JsonDict): def __init__(self, filter_json: JsonDict):
self.filter_json = filter_json self.filter_json = filter_json
self.types = self.filter_json.get("types", None) self.limit = filter_json.get("limit", 10)
self.not_types = self.filter_json.get("not_types", []) self.lazy_load_members = filter_json.get("lazy_load_members", False)
self.include_redundant_members = filter_json.get(
"include_redundant_members", False
)
self.rooms = self.filter_json.get("rooms", None) self.types = filter_json.get("types", None)
self.not_rooms = self.filter_json.get("not_rooms", []) self.not_types = filter_json.get("not_types", [])
self.senders = self.filter_json.get("senders", None) self.rooms = filter_json.get("rooms", None)
self.not_senders = self.filter_json.get("not_senders", []) self.not_rooms = filter_json.get("not_rooms", [])
self.contains_url = self.filter_json.get("contains_url", None) self.senders = filter_json.get("senders", None)
self.not_senders = filter_json.get("not_senders", [])
self.labels = self.filter_json.get("org.matrix.labels", None) self.contains_url = filter_json.get("contains_url", None)
self.not_labels = self.filter_json.get("org.matrix.not_labels", [])
self.labels = filter_json.get("org.matrix.labels", None)
self.not_labels = filter_json.get("org.matrix.not_labels", [])
def filters_all_types(self) -> bool: def filters_all_types(self) -> bool:
return "*" in self.not_types return "*" in self.not_types
@ -302,76 +309,95 @@ class Filter:
def check(self, event: FilterEvent) -> bool: def check(self, event: FilterEvent) -> bool:
"""Checks whether the filter matches the given event. """Checks whether the filter matches the given event.
Args:
event: The event, account data, or presence to check against this
filter.
Returns: Returns:
True if the event matches True if the event matches the filter.
""" """
# We usually get the full "events" as dictionaries coming through, # We usually get the full "events" as dictionaries coming through,
# except for presence which actually gets passed around as its own # except for presence which actually gets passed around as its own
# namedtuple type. # namedtuple type.
if isinstance(event, UserPresenceState): if isinstance(event, UserPresenceState):
sender: Optional[str] = event.user_id user_id = event.user_id
room_id = None field_matchers = {
ev_type = "m.presence" "senders": lambda v: user_id == v,
contains_url = False "types": lambda v: "m.presence" == v,
labels: List[str] = [] }
return self._check_fields(field_matchers)
else: else:
content = event.get("content")
# Content is assumed to be a dict below, so ensure it is. This should
# always be true for events, but account_data has been allowed to
# have non-dict content.
if not isinstance(content, dict):
content = {}
sender = event.get("sender", None) sender = event.get("sender", None)
if not sender: if not sender:
# Presence events had their 'sender' in content.user_id, but are # Presence events had their 'sender' in content.user_id, but are
# now handled above. We don't know if anything else uses this # now handled above. We don't know if anything else uses this
# form. TODO: Check this and probably remove it. # form. TODO: Check this and probably remove it.
content = event.get("content") sender = content.get("user_id")
# account_data has been allowed to have non-dict content, so
# check type first
if isinstance(content, dict):
sender = content.get("user_id")
room_id = event.get("room_id", None) room_id = event.get("room_id", None)
ev_type = event.get("type", None) ev_type = event.get("type", None)
content = event.get("content") or {}
# check if there is a string url field in the content for filtering purposes # check if there is a string url field in the content for filtering purposes
contains_url = isinstance(content.get("url"), str)
labels = content.get(EventContentFields.LABELS, []) labels = content.get(EventContentFields.LABELS, [])
return self.check_fields(room_id, sender, ev_type, labels, contains_url) field_matchers = {
"rooms": lambda v: room_id == v,
"senders": lambda v: sender == v,
"types": lambda v: _matches_wildcard(ev_type, v),
"labels": lambda v: v in labels,
}
def check_fields( result = self._check_fields(field_matchers)
self, if not result:
room_id: Optional[str], return result
sender: Optional[str],
event_type: Optional[str], contains_url_filter = self.contains_url
labels: Container[str], if contains_url_filter is not None:
contains_url: bool, contains_url = isinstance(content.get("url"), str)
) -> bool: if contains_url_filter != contains_url:
return False
return True
def _check_fields(self, field_matchers: Dict[str, Callable[[str], bool]]) -> bool:
"""Checks whether the filter matches the given event fields. """Checks whether the filter matches the given event fields.
Args:
field_matchers: A map of attribute name to callable to use for checking
particular fields.
The attribute name and an inverse (not_<attribute name>) must
exist on the Filter.
The callable should return true if the event's value matches the
filter's value.
Returns: Returns:
True if the event fields match True if the event fields match
""" """
literal_keys = {
"rooms": lambda v: room_id == v,
"senders": lambda v: sender == v,
"types": lambda v: _matches_wildcard(event_type, v),
"labels": lambda v: v in labels,
}
for name, match_func in literal_keys.items(): for name, match_func in field_matchers.items():
# If the event matches one of the disallowed values, reject it.
not_name = "not_%s" % (name,) not_name = "not_%s" % (name,)
disallowed_values = getattr(self, not_name) disallowed_values = getattr(self, not_name)
if any(map(match_func, disallowed_values)): if any(map(match_func, disallowed_values)):
return False return False
# If the event does not match at least one of the allowed values,
# reject it.
allowed_values = getattr(self, name) allowed_values = getattr(self, name)
if allowed_values is not None: if allowed_values is not None:
if not any(map(match_func, allowed_values)): if not any(map(match_func, allowed_values)):
return False return False
contains_url_filter = self.filter_json.get("contains_url") # Otherwise, accept it.
if contains_url_filter is not None:
if contains_url_filter != contains_url:
return False
return True return True
def filter_rooms(self, room_ids: Iterable[str]) -> Set[str]: def filter_rooms(self, room_ids: Iterable[str]) -> Set[str]:
@ -385,10 +411,10 @@ class Filter:
""" """
room_ids = set(room_ids) room_ids = set(room_ids)
disallowed_rooms = set(self.filter_json.get("not_rooms", [])) disallowed_rooms = set(self.not_rooms)
room_ids -= disallowed_rooms room_ids -= disallowed_rooms
allowed_rooms = self.filter_json.get("rooms", None) allowed_rooms = self.rooms
if allowed_rooms is not None: if allowed_rooms is not None:
room_ids &= set(allowed_rooms) room_ids &= set(allowed_rooms)
@ -397,15 +423,6 @@ class Filter:
def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]: def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
return list(filter(self.check, events)) return list(filter(self.check, events))
def limit(self) -> int:
return self.filter_json.get("limit", 10)
def lazy_load_members(self) -> bool:
return self.filter_json.get("lazy_load_members", False)
def include_redundant_members(self) -> bool:
return self.filter_json.get("include_redundant_members", False)
def with_room_ids(self, room_ids: Iterable[str]) -> "Filter": def with_room_ids(self, room_ids: Iterable[str]) -> "Filter":
"""Returns a new filter with the given room IDs appended. """Returns a new filter with the given room IDs appended.

View File

@ -438,7 +438,7 @@ class PaginationHandler:
} }
state = None state = None
if event_filter and event_filter.lazy_load_members() and len(events) > 0: if event_filter and event_filter.lazy_load_members and len(events) > 0:
# TODO: remove redundant members # TODO: remove redundant members
# FIXME: we also care about invite targets etc. # FIXME: we also care about invite targets etc.

View File

@ -1173,7 +1173,7 @@ class RoomContextHandler:
else: else:
last_event_id = event_id last_event_id = event_id
if event_filter and event_filter.lazy_load_members(): if event_filter and event_filter.lazy_load_members:
state_filter = StateFilter.from_lazy_load_member_list( state_filter = StateFilter.from_lazy_load_member_list(
ev.sender ev.sender
for ev in itertools.chain( for ev in itertools.chain(

View File

@ -249,7 +249,7 @@ class SearchHandler:
) )
events.sort(key=lambda e: -rank_map[e.event_id]) events.sort(key=lambda e: -rank_map[e.event_id])
allowed_events = events[: search_filter.limit()] allowed_events = events[: search_filter.limit]
for e in allowed_events: for e in allowed_events:
rm = room_groups.setdefault( rm = room_groups.setdefault(
@ -271,13 +271,13 @@ class SearchHandler:
# We keep looping and we keep filtering until we reach the limit # We keep looping and we keep filtering until we reach the limit
# or we run out of things. # or we run out of things.
# But only go around 5 times since otherwise synapse will be sad. # But only go around 5 times since otherwise synapse will be sad.
while len(room_events) < search_filter.limit() and i < 5: while len(room_events) < search_filter.limit and i < 5:
i += 1 i += 1
search_result = await self.store.search_rooms( search_result = await self.store.search_rooms(
room_ids, room_ids,
search_term, search_term,
keys, keys,
search_filter.limit() * 2, search_filter.limit * 2,
pagination_token=pagination_token, pagination_token=pagination_token,
) )
@ -299,9 +299,9 @@ class SearchHandler:
) )
room_events.extend(events) room_events.extend(events)
room_events = room_events[: search_filter.limit()] room_events = room_events[: search_filter.limit]
if len(results) < search_filter.limit() * 2: if len(results) < search_filter.limit * 2:
pagination_token = None pagination_token = None
break break
else: else:
@ -311,7 +311,7 @@ class SearchHandler:
group = room_groups.setdefault(event.room_id, {"results": []}) group = room_groups.setdefault(event.room_id, {"results": []})
group["results"].append(event.event_id) group["results"].append(event.event_id)
if room_events and len(room_events) >= search_filter.limit(): if room_events and len(room_events) >= search_filter.limit:
last_event_id = room_events[-1].event_id last_event_id = room_events[-1].event_id
pagination_token = results_map[last_event_id]["pagination_token"] pagination_token = results_map[last_event_id]["pagination_token"]