mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-08-06 10:04:11 -04:00
Support filtering by relations per MSC3440 (#11236)
Adds experimental support for `relation_types` and `relation_senders` fields for filters.
This commit is contained in:
parent
4b3e30c276
commit
a19d01c3d9
15 changed files with 680 additions and 110 deletions
|
@ -1,7 +1,7 @@
|
|||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
# Copyright 2018-2019 New Vector Ltd
|
||||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -86,6 +86,9 @@ ROOM_EVENT_FILTER_SCHEMA = {
|
|||
# cf https://github.com/matrix-org/matrix-doc/pull/2326
|
||||
"org.matrix.labels": {"type": "array", "items": {"type": "string"}},
|
||||
"org.matrix.not_labels": {"type": "array", "items": {"type": "string"}},
|
||||
# MSC3440, filtering by event relations.
|
||||
"io.element.relation_senders": {"type": "array", "items": {"type": "string"}},
|
||||
"io.element.relation_types": {"type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -146,14 +149,16 @@ def matrix_user_id_validator(user_id_str: str) -> UserID:
|
|||
|
||||
class Filtering:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self._hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {})
|
||||
|
||||
async def get_user_filter(
|
||||
self, user_localpart: str, filter_id: Union[int, str]
|
||||
) -> "FilterCollection":
|
||||
result = await self.store.get_user_filter(user_localpart, filter_id)
|
||||
return FilterCollection(result)
|
||||
return FilterCollection(self._hs, result)
|
||||
|
||||
def add_user_filter(
|
||||
self, user_localpart: str, user_filter: JsonDict
|
||||
|
@ -191,21 +196,22 @@ FilterEvent = TypeVar("FilterEvent", EventBase, UserPresenceState, JsonDict)
|
|||
|
||||
|
||||
class FilterCollection:
|
||||
def __init__(self, filter_json: JsonDict):
|
||||
def __init__(self, hs: "HomeServer", filter_json: JsonDict):
|
||||
self._filter_json = filter_json
|
||||
|
||||
room_filter_json = self._filter_json.get("room", {})
|
||||
|
||||
self._room_filter = Filter(
|
||||
{k: v for k, v in room_filter_json.items() if k in ("rooms", "not_rooms")}
|
||||
hs,
|
||||
{k: v for k, v in room_filter_json.items() if k in ("rooms", "not_rooms")},
|
||||
)
|
||||
|
||||
self._room_timeline_filter = Filter(room_filter_json.get("timeline", {}))
|
||||
self._room_state_filter = Filter(room_filter_json.get("state", {}))
|
||||
self._room_ephemeral_filter = Filter(room_filter_json.get("ephemeral", {}))
|
||||
self._room_account_data = Filter(room_filter_json.get("account_data", {}))
|
||||
self._presence_filter = Filter(filter_json.get("presence", {}))
|
||||
self._account_data = Filter(filter_json.get("account_data", {}))
|
||||
self._room_timeline_filter = Filter(hs, room_filter_json.get("timeline", {}))
|
||||
self._room_state_filter = Filter(hs, room_filter_json.get("state", {}))
|
||||
self._room_ephemeral_filter = Filter(hs, room_filter_json.get("ephemeral", {}))
|
||||
self._room_account_data = Filter(hs, room_filter_json.get("account_data", {}))
|
||||
self._presence_filter = Filter(hs, filter_json.get("presence", {}))
|
||||
self._account_data = Filter(hs, filter_json.get("account_data", {}))
|
||||
|
||||
self.include_leave = filter_json.get("room", {}).get("include_leave", False)
|
||||
self.event_fields = filter_json.get("event_fields", [])
|
||||
|
@ -232,25 +238,37 @@ class FilterCollection:
|
|||
def include_redundant_members(self) -> bool:
|
||||
return self._room_state_filter.include_redundant_members
|
||||
|
||||
def filter_presence(
|
||||
async def filter_presence(
|
||||
self, events: Iterable[UserPresenceState]
|
||||
) -> List[UserPresenceState]:
|
||||
return self._presence_filter.filter(events)
|
||||
return await self._presence_filter.filter(events)
|
||||
|
||||
def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
|
||||
return self._account_data.filter(events)
|
||||
async def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
|
||||
return await self._account_data.filter(events)
|
||||
|
||||
def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]:
|
||||
return self._room_state_filter.filter(self._room_filter.filter(events))
|
||||
async def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]:
|
||||
return await self._room_state_filter.filter(
|
||||
await self._room_filter.filter(events)
|
||||
)
|
||||
|
||||
def filter_room_timeline(self, events: Iterable[EventBase]) -> List[EventBase]:
|
||||
return self._room_timeline_filter.filter(self._room_filter.filter(events))
|
||||
async def filter_room_timeline(
|
||||
self, events: Iterable[EventBase]
|
||||
) -> List[EventBase]:
|
||||
return await self._room_timeline_filter.filter(
|
||||
await self._room_filter.filter(events)
|
||||
)
|
||||
|
||||
def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> List[JsonDict]:
|
||||
return self._room_ephemeral_filter.filter(self._room_filter.filter(events))
|
||||
async def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> List[JsonDict]:
|
||||
return await self._room_ephemeral_filter.filter(
|
||||
await self._room_filter.filter(events)
|
||||
)
|
||||
|
||||
def filter_room_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
|
||||
return self._room_account_data.filter(self._room_filter.filter(events))
|
||||
async def filter_room_account_data(
|
||||
self, events: Iterable[JsonDict]
|
||||
) -> List[JsonDict]:
|
||||
return await self._room_account_data.filter(
|
||||
await self._room_filter.filter(events)
|
||||
)
|
||||
|
||||
def blocks_all_presence(self) -> bool:
|
||||
return (
|
||||
|
@ -274,7 +292,9 @@ class FilterCollection:
|
|||
|
||||
|
||||
class Filter:
|
||||
def __init__(self, filter_json: JsonDict):
|
||||
def __init__(self, hs: "HomeServer", filter_json: JsonDict):
|
||||
self._hs = hs
|
||||
self._store = hs.get_datastore()
|
||||
self.filter_json = filter_json
|
||||
|
||||
self.limit = filter_json.get("limit", 10)
|
||||
|
@ -297,6 +317,20 @@ class Filter:
|
|||
self.labels = filter_json.get("org.matrix.labels", None)
|
||||
self.not_labels = filter_json.get("org.matrix.not_labels", [])
|
||||
|
||||
# Ideally these would be rejected at the endpoint if they were provided
|
||||
# and not supported, but that would involve modifying the JSON schema
|
||||
# based on the homeserver configuration.
|
||||
if hs.config.experimental.msc3440_enabled:
|
||||
self.relation_senders = self.filter_json.get(
|
||||
"io.element.relation_senders", None
|
||||
)
|
||||
self.relation_types = self.filter_json.get(
|
||||
"io.element.relation_types", None
|
||||
)
|
||||
else:
|
||||
self.relation_senders = None
|
||||
self.relation_types = None
|
||||
|
||||
def filters_all_types(self) -> bool:
|
||||
return "*" in self.not_types
|
||||
|
||||
|
@ -306,7 +340,7 @@ class Filter:
|
|||
def filters_all_rooms(self) -> bool:
|
||||
return "*" in self.not_rooms
|
||||
|
||||
def check(self, event: FilterEvent) -> bool:
|
||||
def _check(self, event: FilterEvent) -> bool:
|
||||
"""Checks whether the filter matches the given event.
|
||||
|
||||
Args:
|
||||
|
@ -420,8 +454,30 @@ class Filter:
|
|||
|
||||
return room_ids
|
||||
|
||||
def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
|
||||
return list(filter(self.check, events))
|
||||
async def _check_event_relations(
|
||||
self, events: Iterable[FilterEvent]
|
||||
) -> List[FilterEvent]:
|
||||
# The event IDs to check, mypy doesn't understand the ifinstance check.
|
||||
event_ids = [event.event_id for event in events if isinstance(event, EventBase)] # type: ignore[attr-defined]
|
||||
event_ids_to_keep = set(
|
||||
await self._store.events_have_relations(
|
||||
event_ids, self.relation_senders, self.relation_types
|
||||
)
|
||||
)
|
||||
|
||||
return [
|
||||
event
|
||||
for event in events
|
||||
if not isinstance(event, EventBase) or event.event_id in event_ids_to_keep
|
||||
]
|
||||
|
||||
async def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
|
||||
result = [event for event in events if self._check(event)]
|
||||
|
||||
if self.relation_senders or self.relation_types:
|
||||
return await self._check_event_relations(result)
|
||||
|
||||
return result
|
||||
|
||||
def with_room_ids(self, room_ids: Iterable[str]) -> "Filter":
|
||||
"""Returns a new filter with the given room IDs appended.
|
||||
|
@ -433,7 +489,7 @@ class Filter:
|
|||
filter: A new filter including the given rooms and the old
|
||||
filter's rooms.
|
||||
"""
|
||||
newFilter = Filter(self.filter_json)
|
||||
newFilter = Filter(self._hs, self.filter_json)
|
||||
newFilter.rooms += room_ids
|
||||
return newFilter
|
||||
|
||||
|
@ -444,6 +500,3 @@ def _matches_wildcard(actual_value: Optional[str], filter_value: str) -> bool:
|
|||
return actual_value.startswith(type_prefix)
|
||||
else:
|
||||
return actual_value == filter_value
|
||||
|
||||
|
||||
DEFAULT_FILTER_COLLECTION = FilterCollection({})
|
||||
|
|
|
@ -424,7 +424,7 @@ class PaginationHandler:
|
|||
|
||||
if events:
|
||||
if event_filter:
|
||||
events = event_filter.filter(events)
|
||||
events = await event_filter.filter(events)
|
||||
|
||||
events = await filter_events_for_client(
|
||||
self.storage, user_id, events, is_peeking=(member_event_id is None)
|
||||
|
|
|
@ -1158,8 +1158,10 @@ class RoomContextHandler:
|
|||
)
|
||||
|
||||
if event_filter:
|
||||
results["events_before"] = event_filter.filter(results["events_before"])
|
||||
results["events_after"] = event_filter.filter(results["events_after"])
|
||||
results["events_before"] = await event_filter.filter(
|
||||
results["events_before"]
|
||||
)
|
||||
results["events_after"] = await event_filter.filter(results["events_after"])
|
||||
|
||||
results["events_before"] = await filter_evts(results["events_before"])
|
||||
results["events_after"] = await filter_evts(results["events_after"])
|
||||
|
@ -1195,7 +1197,7 @@ class RoomContextHandler:
|
|||
|
||||
state_events = list(state[last_event_id].values())
|
||||
if event_filter:
|
||||
state_events = event_filter.filter(state_events)
|
||||
state_events = await event_filter.filter(state_events)
|
||||
|
||||
results["state"] = await filter_evts(state_events)
|
||||
|
||||
|
|
|
@ -180,7 +180,7 @@ class SearchHandler:
|
|||
% (set(group_keys) - {"room_id", "sender"},),
|
||||
)
|
||||
|
||||
search_filter = Filter(filter_dict)
|
||||
search_filter = Filter(self.hs, filter_dict)
|
||||
|
||||
# TODO: Search through left rooms too
|
||||
rooms = await self.store.get_rooms_for_local_user_where_membership_is(
|
||||
|
@ -242,7 +242,7 @@ class SearchHandler:
|
|||
|
||||
rank_map.update({r["event"].event_id: r["rank"] for r in results})
|
||||
|
||||
filtered_events = search_filter.filter([r["event"] for r in results])
|
||||
filtered_events = await search_filter.filter([r["event"] for r in results])
|
||||
|
||||
events = await filter_events_for_client(
|
||||
self.storage, user.to_string(), filtered_events
|
||||
|
@ -292,7 +292,9 @@ class SearchHandler:
|
|||
|
||||
rank_map.update({r["event"].event_id: r["rank"] for r in results})
|
||||
|
||||
filtered_events = search_filter.filter([r["event"] for r in results])
|
||||
filtered_events = await search_filter.filter(
|
||||
[r["event"] for r in results]
|
||||
)
|
||||
|
||||
events = await filter_events_for_client(
|
||||
self.storage, user.to_string(), filtered_events
|
||||
|
|
|
@ -510,7 +510,7 @@ class SyncHandler:
|
|||
log_kv({"limited": limited})
|
||||
|
||||
if potential_recents:
|
||||
recents = sync_config.filter_collection.filter_room_timeline(
|
||||
recents = await sync_config.filter_collection.filter_room_timeline(
|
||||
potential_recents
|
||||
)
|
||||
log_kv({"recents_after_sync_filtering": len(recents)})
|
||||
|
@ -575,8 +575,8 @@ class SyncHandler:
|
|||
|
||||
log_kv({"loaded_recents": len(events)})
|
||||
|
||||
loaded_recents = sync_config.filter_collection.filter_room_timeline(
|
||||
events
|
||||
loaded_recents = (
|
||||
await sync_config.filter_collection.filter_room_timeline(events)
|
||||
)
|
||||
|
||||
log_kv({"loaded_recents_after_sync_filtering": len(loaded_recents)})
|
||||
|
@ -1015,7 +1015,7 @@ class SyncHandler:
|
|||
|
||||
return {
|
||||
(e.type, e.state_key): e
|
||||
for e in sync_config.filter_collection.filter_room_state(
|
||||
for e in await sync_config.filter_collection.filter_room_state(
|
||||
list(state.values())
|
||||
)
|
||||
if e.type != EventTypes.Aliases # until MSC2261 or alternative solution
|
||||
|
@ -1383,7 +1383,7 @@ class SyncHandler:
|
|||
sync_config.user
|
||||
)
|
||||
|
||||
account_data_for_user = sync_config.filter_collection.filter_account_data(
|
||||
account_data_for_user = await sync_config.filter_collection.filter_account_data(
|
||||
[
|
||||
{"type": account_data_type, "content": content}
|
||||
for account_data_type, content in account_data.items()
|
||||
|
@ -1448,7 +1448,7 @@ class SyncHandler:
|
|||
# Deduplicate the presence entries so that there's at most one per user
|
||||
presence = list({p.user_id: p for p in presence}.values())
|
||||
|
||||
presence = sync_config.filter_collection.filter_presence(presence)
|
||||
presence = await sync_config.filter_collection.filter_presence(presence)
|
||||
|
||||
sync_result_builder.presence = presence
|
||||
|
||||
|
@ -2021,12 +2021,14 @@ class SyncHandler:
|
|||
)
|
||||
|
||||
account_data_events = (
|
||||
sync_config.filter_collection.filter_room_account_data(
|
||||
await sync_config.filter_collection.filter_room_account_data(
|
||||
account_data_events
|
||||
)
|
||||
)
|
||||
|
||||
ephemeral = sync_config.filter_collection.filter_room_ephemeral(ephemeral)
|
||||
ephemeral = await sync_config.filter_collection.filter_room_ephemeral(
|
||||
ephemeral
|
||||
)
|
||||
|
||||
if not (
|
||||
always_include
|
||||
|
|
|
@ -583,6 +583,7 @@ class RoomEventContextServlet(RestServlet):
|
|||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self._hs = hs
|
||||
self.clock = hs.get_clock()
|
||||
self.room_context_handler = hs.get_room_context_handler()
|
||||
self._event_serializer = hs.get_event_client_serializer()
|
||||
|
@ -600,7 +601,9 @@ class RoomEventContextServlet(RestServlet):
|
|||
filter_str = parse_string(request, "filter", encoding="utf-8")
|
||||
if filter_str:
|
||||
filter_json = urlparse.unquote(filter_str)
|
||||
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
|
||||
event_filter: Optional[Filter] = Filter(
|
||||
self._hs, json_decoder.decode(filter_json)
|
||||
)
|
||||
else:
|
||||
event_filter = None
|
||||
|
||||
|
|
|
@ -550,6 +550,7 @@ class RoomMessageListRestServlet(RestServlet):
|
|||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self._hs = hs
|
||||
self.pagination_handler = hs.get_pagination_handler()
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
@ -567,7 +568,9 @@ class RoomMessageListRestServlet(RestServlet):
|
|||
filter_str = parse_string(request, "filter", encoding="utf-8")
|
||||
if filter_str:
|
||||
filter_json = urlparse.unquote(filter_str)
|
||||
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
|
||||
event_filter: Optional[Filter] = Filter(
|
||||
self._hs, json_decoder.decode(filter_json)
|
||||
)
|
||||
if (
|
||||
event_filter
|
||||
and event_filter.filter_json.get("event_format", "client")
|
||||
|
@ -672,6 +675,7 @@ class RoomEventContextServlet(RestServlet):
|
|||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self._hs = hs
|
||||
self.clock = hs.get_clock()
|
||||
self.room_context_handler = hs.get_room_context_handler()
|
||||
self._event_serializer = hs.get_event_client_serializer()
|
||||
|
@ -688,7 +692,9 @@ class RoomEventContextServlet(RestServlet):
|
|||
filter_str = parse_string(request, "filter", encoding="utf-8")
|
||||
if filter_str:
|
||||
filter_json = urlparse.unquote(filter_str)
|
||||
event_filter: Optional[Filter] = Filter(json_decoder.decode(filter_json))
|
||||
event_filter: Optional[Filter] = Filter(
|
||||
self._hs, json_decoder.decode(filter_json)
|
||||
)
|
||||
else:
|
||||
event_filter = None
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ from typing import (
|
|||
|
||||
from synapse.api.constants import Membership, PresenceState
|
||||
from synapse.api.errors import Codes, StoreError, SynapseError
|
||||
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
|
||||
from synapse.api.filtering import FilterCollection
|
||||
from synapse.api.presence import UserPresenceState
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.utils import (
|
||||
|
@ -150,7 +150,7 @@ class SyncRestServlet(RestServlet):
|
|||
request_key = (user, timeout, since, filter_id, full_state, device_id)
|
||||
|
||||
if filter_id is None:
|
||||
filter_collection = DEFAULT_FILTER_COLLECTION
|
||||
filter_collection = self.filtering.DEFAULT_FILTER_COLLECTION
|
||||
elif filter_id.startswith("{"):
|
||||
try:
|
||||
filter_object = json_decoder.decode(filter_id)
|
||||
|
@ -160,7 +160,7 @@ class SyncRestServlet(RestServlet):
|
|||
except Exception:
|
||||
raise SynapseError(400, "Invalid filter JSON")
|
||||
self.filtering.check_valid_filter(filter_object)
|
||||
filter_collection = FilterCollection(filter_object)
|
||||
filter_collection = FilterCollection(self.hs, filter_object)
|
||||
else:
|
||||
try:
|
||||
filter_collection = await self.filtering.get_user_filter(
|
||||
|
|
|
@ -20,7 +20,7 @@ import attr
|
|||
from synapse.api.constants import RelationTypes
|
||||
from synapse.events import EventBase
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
|
||||
from synapse.storage.databases.main.stream import generate_pagination_where_clause
|
||||
from synapse.storage.relations import (
|
||||
AggregationPaginationToken,
|
||||
|
@ -334,6 +334,62 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
|
||||
return count, latest_event
|
||||
|
||||
async def events_have_relations(
|
||||
self,
|
||||
parent_ids: List[str],
|
||||
relation_senders: Optional[List[str]],
|
||||
relation_types: Optional[List[str]],
|
||||
) -> List[str]:
|
||||
"""Check which events have a relationship from the given senders of the
|
||||
given types.
|
||||
|
||||
Args:
|
||||
parent_ids: The events being annotated
|
||||
relation_senders: The relation senders to check.
|
||||
relation_types: The relation types to check.
|
||||
|
||||
Returns:
|
||||
True if the event has at least one relationship from one of the given senders of the given type.
|
||||
"""
|
||||
# If no restrictions are given then the event has the required relations.
|
||||
if not relation_senders and not relation_types:
|
||||
return parent_ids
|
||||
|
||||
sql = """
|
||||
SELECT relates_to_id FROM event_relations
|
||||
INNER JOIN events USING (event_id)
|
||||
WHERE
|
||||
%s;
|
||||
"""
|
||||
|
||||
def _get_if_event_has_relations(txn) -> List[str]:
|
||||
clauses: List[str] = []
|
||||
clause, args = make_in_list_sql_clause(
|
||||
txn.database_engine, "relates_to_id", parent_ids
|
||||
)
|
||||
clauses.append(clause)
|
||||
|
||||
if relation_senders:
|
||||
clause, temp_args = make_in_list_sql_clause(
|
||||
txn.database_engine, "sender", relation_senders
|
||||
)
|
||||
clauses.append(clause)
|
||||
args.extend(temp_args)
|
||||
if relation_types:
|
||||
clause, temp_args = make_in_list_sql_clause(
|
||||
txn.database_engine, "relation_type", relation_types
|
||||
)
|
||||
clauses.append(clause)
|
||||
args.extend(temp_args)
|
||||
|
||||
txn.execute(sql % " AND ".join(clauses), args)
|
||||
|
||||
return [row[0] for row in txn]
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_if_event_has_relations", _get_if_event_has_relations
|
||||
)
|
||||
|
||||
async def has_user_annotated_event(
|
||||
self, parent_id: str, event_type: str, aggregation_key: str, sender: str
|
||||
) -> bool:
|
||||
|
|
|
@ -272,31 +272,37 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
|
|||
args = []
|
||||
|
||||
if event_filter.types:
|
||||
clauses.append("(%s)" % " OR ".join("type = ?" for _ in event_filter.types))
|
||||
clauses.append(
|
||||
"(%s)" % " OR ".join("event.type = ?" for _ in event_filter.types)
|
||||
)
|
||||
args.extend(event_filter.types)
|
||||
|
||||
for typ in event_filter.not_types:
|
||||
clauses.append("type != ?")
|
||||
clauses.append("event.type != ?")
|
||||
args.append(typ)
|
||||
|
||||
if event_filter.senders:
|
||||
clauses.append("(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders))
|
||||
clauses.append(
|
||||
"(%s)" % " OR ".join("event.sender = ?" for _ in event_filter.senders)
|
||||
)
|
||||
args.extend(event_filter.senders)
|
||||
|
||||
for sender in event_filter.not_senders:
|
||||
clauses.append("sender != ?")
|
||||
clauses.append("event.sender != ?")
|
||||
args.append(sender)
|
||||
|
||||
if event_filter.rooms:
|
||||
clauses.append("(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms))
|
||||
clauses.append(
|
||||
"(%s)" % " OR ".join("event.room_id = ?" for _ in event_filter.rooms)
|
||||
)
|
||||
args.extend(event_filter.rooms)
|
||||
|
||||
for room_id in event_filter.not_rooms:
|
||||
clauses.append("room_id != ?")
|
||||
clauses.append("event.room_id != ?")
|
||||
args.append(room_id)
|
||||
|
||||
if event_filter.contains_url:
|
||||
clauses.append("contains_url = ?")
|
||||
clauses.append("event.contains_url = ?")
|
||||
args.append(event_filter.contains_url)
|
||||
|
||||
# We're only applying the "labels" filter on the database query, because applying the
|
||||
|
@ -307,6 +313,23 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
|
|||
clauses.append("(%s)" % " OR ".join("label = ?" for _ in event_filter.labels))
|
||||
args.extend(event_filter.labels)
|
||||
|
||||
# Filter on relation_senders / relation types from the joined tables.
|
||||
if event_filter.relation_senders:
|
||||
clauses.append(
|
||||
"(%s)"
|
||||
% " OR ".join(
|
||||
"related_event.sender = ?" for _ in event_filter.relation_senders
|
||||
)
|
||||
)
|
||||
args.extend(event_filter.relation_senders)
|
||||
|
||||
if event_filter.relation_types:
|
||||
clauses.append(
|
||||
"(%s)"
|
||||
% " OR ".join("relation_type = ?" for _ in event_filter.relation_types)
|
||||
)
|
||||
args.extend(event_filter.relation_types)
|
||||
|
||||
return " AND ".join(clauses), args
|
||||
|
||||
|
||||
|
@ -1116,7 +1139,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
|||
|
||||
bounds = generate_pagination_where_clause(
|
||||
direction=direction,
|
||||
column_names=("topological_ordering", "stream_ordering"),
|
||||
column_names=("event.topological_ordering", "event.stream_ordering"),
|
||||
from_token=from_bound,
|
||||
to_token=to_bound,
|
||||
engine=self.database_engine,
|
||||
|
@ -1133,32 +1156,51 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
|||
|
||||
select_keywords = "SELECT"
|
||||
join_clause = ""
|
||||
# Using DISTINCT in this SELECT query is quite expensive, because it
|
||||
# requires the engine to sort on the entire (not limited) result set,
|
||||
# i.e. the entire events table. Only use it in scenarios that could result
|
||||
# in the same event ID occurring multiple times in the results.
|
||||
needs_distinct = False
|
||||
if event_filter and event_filter.labels:
|
||||
# If we're not filtering on a label, then joining on event_labels will
|
||||
# return as many row for a single event as the number of labels it has. To
|
||||
# avoid this, only join if we're filtering on at least one label.
|
||||
join_clause = """
|
||||
join_clause += """
|
||||
LEFT JOIN event_labels
|
||||
USING (event_id, room_id, topological_ordering)
|
||||
"""
|
||||
if len(event_filter.labels) > 1:
|
||||
# Using DISTINCT in this SELECT query is quite expensive, because it
|
||||
# requires the engine to sort on the entire (not limited) result set,
|
||||
# i.e. the entire events table. We only need to use it when we're
|
||||
# filtering on more than two labels, because that's the only scenario
|
||||
# in which we can possibly to get multiple times the same event ID in
|
||||
# the results.
|
||||
select_keywords += "DISTINCT"
|
||||
# Multiple labels could cause the same event to appear multiple times.
|
||||
needs_distinct = True
|
||||
|
||||
# If there is a filter on relation_senders and relation_types join to the
|
||||
# relations table.
|
||||
if event_filter and (
|
||||
event_filter.relation_senders or event_filter.relation_types
|
||||
):
|
||||
# Filtering by relations could cause the same event to appear multiple
|
||||
# times (since there's no limit on the number of relations to an event).
|
||||
needs_distinct = True
|
||||
join_clause += """
|
||||
LEFT JOIN event_relations AS relation ON (event.event_id = relation.relates_to_id)
|
||||
"""
|
||||
if event_filter.relation_senders:
|
||||
join_clause += """
|
||||
LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id)
|
||||
"""
|
||||
|
||||
if needs_distinct:
|
||||
select_keywords += " DISTINCT"
|
||||
|
||||
sql = """
|
||||
%(select_keywords)s
|
||||
event_id, instance_name,
|
||||
topological_ordering, stream_ordering
|
||||
FROM events
|
||||
event.event_id, event.instance_name,
|
||||
event.topological_ordering, event.stream_ordering
|
||||
FROM events AS event
|
||||
%(join_clause)s
|
||||
WHERE outlier = ? AND room_id = ? AND %(bounds)s
|
||||
ORDER BY topological_ordering %(order)s,
|
||||
stream_ordering %(order)s LIMIT ?
|
||||
WHERE event.outlier = ? AND event.room_id = ? AND %(bounds)s
|
||||
ORDER BY event.topological_ordering %(order)s,
|
||||
event.stream_ordering %(order)s LIMIT ?
|
||||
""" % {
|
||||
"select_keywords": select_keywords,
|
||||
"join_clause": join_clause,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue