add a filtered_types param to limit filtering to specific types

This commit is contained in:
Matthew Hodgson 2018-07-19 18:32:02 +01:00
parent be3adfc331
commit 924eb34d94
2 changed files with 96 additions and 82 deletions

View File

@ -417,38 +417,44 @@ class SyncHandler(object):
)) ))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_after_event(self, event, types=None): def get_state_after_event(self, event, types=None, filtered_types=None):
""" """
Get the room state after the given event Get the room state after the given event
Args: Args:
event(synapse.events.EventBase): event of interest event(synapse.events.EventBase): event of interest
types(list[(str|None, str|None)]|None): List of (type, state_key) tuples types(list[(str, str|None)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. If `state_key` is None, which are used to filter the state fetched. If `state_key` is None,
all events are returned of the given type. Presence of type of `None` all events are returned of the given type.
indicates that types not in the list should not be filtered out.
May be None, which matches any key. May be None, which matches any key.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns: Returns:
A Deferred map from ((type, state_key)->Event) A Deferred map from ((type, state_key)->Event)
""" """
state_ids = yield self.store.get_state_ids_for_event(event.event_id, types) state_ids = yield self.store.get_state_ids_for_event(
event.event_id, types, filtered_types=filtered_types
)
if event.is_state(): if event.is_state():
state_ids = state_ids.copy() state_ids = state_ids.copy()
state_ids[(event.type, event.state_key)] = event.event_id state_ids[(event.type, event.state_key)] = event.event_id
defer.returnValue(state_ids) defer.returnValue(state_ids)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_at(self, room_id, stream_position, types=None): def get_state_at(self, room_id, stream_position, types=None, filtered_types=None):
""" Get the room state at a particular stream position """ Get the room state at a particular stream position
Args: Args:
room_id(str): room for which to get state room_id(str): room for which to get state
stream_position(StreamToken): point at which to get state stream_position(StreamToken): point at which to get state
types(list[(str|None, str|None)]|None): List of (type, state_key) tuples types(list[(str, str|None)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. If `state_key` is None, which are used to filter the state fetched. If `state_key` is None,
all events are returned of the given type. Presence of type of `None` all events are returned of the given type.
indicates that types not in the list should not be filtered out. filtered_types(list[str]|None): Only apply filtering via `types` to this
May be None, which matches any key. list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns: Returns:
A Deferred map from ((type, state_key)->Event) A Deferred map from ((type, state_key)->Event)
@ -463,7 +469,9 @@ class SyncHandler(object):
if last_events: if last_events:
last_event = last_events[-1] last_event = last_events[-1]
state = yield self.get_state_after_event(last_event, types) state = yield self.get_state_after_event(
last_event, types, filtered_types=filtered_types
)
else: else:
# no events in this room - so presumably no state # no events in this room - so presumably no state
@ -499,6 +507,7 @@ class SyncHandler(object):
types = None types = None
member_state_ids = {} member_state_ids = {}
lazy_load_members = sync_config.filter_collection.lazy_load_members() lazy_load_members = sync_config.filter_collection.lazy_load_members()
filtered_types = None
if lazy_load_members: if lazy_load_members:
# We only request state for the members needed to display the # We only request state for the members needed to display the
@ -516,29 +525,25 @@ class SyncHandler(object):
# to be done based on event_id, and we don't have the member # to be done based on event_id, and we don't have the member
# event ids until we've pulled them out of the DB. # event ids until we've pulled them out of the DB.
if not types: # only apply the filtering to room members
# an optimisation to stop needlessly trying to calculate filtered_types = [EventTypes.Member]
# member_state_ids
#
# XXX: i can't remember what this trying to do. why would
# types ever be []? --matthew
lazy_load_members = False
types.append((None, None)) # don't just filter to room members
if full_state: if full_state:
if batch: if batch:
current_state_ids = yield self.store.get_state_ids_for_event( current_state_ids = yield self.store.get_state_ids_for_event(
batch.events[-1].event_id, types=types batch.events[-1].event_id, types=types,
filtered_types=filtered_types
) )
state_ids = yield self.store.get_state_ids_for_event( state_ids = yield self.store.get_state_ids_for_event(
batch.events[0].event_id, types=types batch.events[0].event_id, types=types,
filtered_types=filtered_types
) )
else: else:
current_state_ids = yield self.get_state_at( current_state_ids = yield self.get_state_at(
room_id, stream_position=now_token, types=types room_id, stream_position=now_token, types=types,
filtered_types=filtered_types
) )
state_ids = current_state_ids state_ids = current_state_ids
@ -563,15 +568,18 @@ class SyncHandler(object):
) )
elif batch.limited: elif batch.limited:
state_at_previous_sync = yield self.get_state_at( state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token, types=types room_id, stream_position=since_token, types=types,
filtered_types=filtered_types
) )
current_state_ids = yield self.store.get_state_ids_for_event( current_state_ids = yield self.store.get_state_ids_for_event(
batch.events[-1].event_id, types=types batch.events[-1].event_id, types=types,
filtered_types=filtered_types
) )
state_at_timeline_start = yield self.store.get_state_ids_for_event( state_at_timeline_start = yield self.store.get_state_ids_for_event(
batch.events[0].event_id, types=types batch.events[0].event_id, types=types,
filtered_types=filtered_types
) )
if lazy_load_members: if lazy_load_members:
@ -603,11 +611,10 @@ class SyncHandler(object):
# event_ids) at this point. We know we can do it based on mxid as this # event_ids) at this point. We know we can do it based on mxid as this
# is an non-gappy incremental sync. # is an non-gappy incremental sync.
# strip off the (None, None) and filter to just room members
types = types[:-1]
if types: if types:
state_ids = yield self.store.get_state_ids_for_event( state_ids = yield self.store.get_state_ids_for_event(
batch.events[0].event_id, types=types batch.events[0].event_id, types=types,
filtered_types=filtered_types
) )
state = {} state = {}

View File

@ -185,7 +185,7 @@ class StateGroupWorkerStore(SQLBaseStore):
}) })
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_state_groups_from_groups(self, groups, types): def _get_state_groups_from_groups(self, groups, types, filtered_types=None):
"""Returns the state groups for a given set of groups, filtering on """Returns the state groups for a given set of groups, filtering on
types of state events. types of state events.
@ -193,9 +193,10 @@ class StateGroupWorkerStore(SQLBaseStore):
groups(list[int]): list of state group IDs to query groups(list[int]): list of state group IDs to query
types(list[str|None, str|None])|None: List of 2-tuples of the form types(list[str|None, str|None])|None: List of 2-tuples of the form
(`type`, `state_key`), where a `state_key` of `None` matches all (`type`, `state_key`), where a `state_key` of `None` matches all
state_keys for the `type`. Presence of type of `None` indicates state_keys for the `type`. If None, all types are returned.
that types not in the list should not be filtered out. If None, filtered_types(list[str]|None): Only apply filtering via `types` to this
all types are returned. list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns: Returns:
dictionary state_group -> (dict of (type, state_key) -> event id) dictionary state_group -> (dict of (type, state_key) -> event id)
@ -206,26 +207,21 @@ class StateGroupWorkerStore(SQLBaseStore):
for chunk in chunks: for chunk in chunks:
res = yield self.runInteraction( res = yield self.runInteraction(
"_get_state_groups_from_groups", "_get_state_groups_from_groups",
self._get_state_groups_from_groups_txn, chunk, types, self._get_state_groups_from_groups_txn, chunk, types, filtered_types
) )
results.update(res) results.update(res)
defer.returnValue(results) defer.returnValue(results)
def _get_state_groups_from_groups_txn(self, txn, groups, types=None): def _get_state_groups_from_groups_txn(
self, txn, groups, types=None, filtered_types=None
):
results = {group: {} for group in groups} results = {group: {} for group in groups}
include_other_types = False include_other_types = False if filtered_types is None else True
if types is not None: if types is not None:
type_set = set(types) types = list(set(types)) # deduplicate types list
if (None, None) in type_set:
# special case (None, None) to mean that other types should be
# returned - i.e. we were just filtering down the state keys
# for particular types.
include_other_types = True
type_set.remove((None, None))
types = list(type_set) # deduplicate types list
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
# Temporarily disable sequential scans in this transaction. This is # Temporarily disable sequential scans in this transaction. This is
@ -276,7 +272,7 @@ class StateGroupWorkerStore(SQLBaseStore):
if include_other_types: if include_other_types:
# XXX: check whether this slows postgres down like a list of # XXX: check whether this slows postgres down like a list of
# ORs does too? # ORs does too?
unique_types = set([t for (t, _) in types]) unique_types = set(filtered_types)
clause_to_args.append( clause_to_args.append(
( (
"AND type <> ? " * len(unique_types), "AND type <> ? " * len(unique_types),
@ -313,7 +309,7 @@ class StateGroupWorkerStore(SQLBaseStore):
where_args.extend([typ[0], typ[1]]) where_args.extend([typ[0], typ[1]])
if include_other_types: if include_other_types:
unique_types = set([t for (t, _) in types]) unique_types = set(filtered_types)
where_clauses.append( where_clauses.append(
"(" + " AND ".join(["type <> ?"] * len(unique_types)) + ")" "(" + " AND ".join(["type <> ?"] * len(unique_types)) + ")"
) )
@ -373,18 +369,20 @@ class StateGroupWorkerStore(SQLBaseStore):
return results return results
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_for_events(self, event_ids, types): def get_state_for_events(self, event_ids, types, filtered_types):
"""Given a list of event_ids and type tuples, return a list of state """Given a list of event_ids and type tuples, return a list of state
dicts for each event. The state dicts will only have the type/state_keys dicts for each event. The state dicts will only have the type/state_keys
that are in the `types` list. that are in the `types` list.
Args: Args:
event_ids (list[string]) event_ids (list[string])
types (list[(str|None, str|None)]|None): List of (type, state_key) tuples types (list[(str, str|None)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. If `state_key` is None, which are used to filter the state fetched. If `state_key` is None,
all events are returned of the given type. Presence of type of `None` all events are returned of the given type.
indicates that types not in the list should not be filtered out.
May be None, which matches any key. May be None, which matches any key.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns: Returns:
deferred: A list of dicts corresponding to the event_ids given. deferred: A list of dicts corresponding to the event_ids given.
@ -395,7 +393,7 @@ class StateGroupWorkerStore(SQLBaseStore):
) )
groups = set(itervalues(event_to_groups)) groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups, types) group_to_state = yield self._get_state_for_groups(groups, types, filtered_types)
state_event_map = yield self.get_events( state_event_map = yield self.get_events(
[ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
@ -414,17 +412,19 @@ class StateGroupWorkerStore(SQLBaseStore):
defer.returnValue({event: event_to_state[event] for event in event_ids}) defer.returnValue({event: event_to_state[event] for event in event_ids})
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_ids_for_events(self, event_ids, types=None): def get_state_ids_for_events(self, event_ids, types=None, filtered_types=None):
""" """
Get the state dicts corresponding to a list of events Get the state dicts corresponding to a list of events
Args: Args:
event_ids(list(str)): events whose state should be returned event_ids(list(str)): events whose state should be returned
types(list[(str|None, str|None)]|None): List of (type, state_key) tuples types(list[(str, str|None)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. If `state_key` is None, which are used to filter the state fetched. If `state_key` is None,
all events are returned of the given type. Presence of type of `None` all events are returned of the given type.
indicates that types not in the list should not be filtered out.
May be None, which matches any key. May be None, which matches any key.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns: Returns:
A deferred dict from event_id -> (type, state_key) -> state_event A deferred dict from event_id -> (type, state_key) -> state_event
@ -434,7 +434,7 @@ class StateGroupWorkerStore(SQLBaseStore):
) )
groups = set(itervalues(event_to_groups)) groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups, types) group_to_state = yield self._get_state_for_groups(groups, types, filtered_types)
event_to_state = { event_to_state = {
event_id: group_to_state[group] event_id: group_to_state[group]
@ -444,41 +444,45 @@ class StateGroupWorkerStore(SQLBaseStore):
defer.returnValue({event: event_to_state[event] for event in event_ids}) defer.returnValue({event: event_to_state[event] for event in event_ids})
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_for_event(self, event_id, types=None): def get_state_for_event(self, event_id, types=None, filtered_types=None):
""" """
Get the state dict corresponding to a particular event Get the state dict corresponding to a particular event
Args: Args:
event_id(str): event whose state should be returned event_id(str): event whose state should be returned
types(list[(str|None, str|None)]|None): List of (type, state_key) tuples types(list[(str, str|None)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. If `state_key` is None, which are used to filter the state fetched. If `state_key` is None,
all events are returned of the given type. Presence of type of `None` all events are returned of the given type.
indicates that types not in the list should not be filtered out.
May be None, which matches any key. May be None, which matches any key.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns: Returns:
A deferred dict from (type, state_key) -> state_event A deferred dict from (type, state_key) -> state_event
""" """
state_map = yield self.get_state_for_events([event_id], types) state_map = yield self.get_state_for_events([event_id], types, filtered_types)
defer.returnValue(state_map[event_id]) defer.returnValue(state_map[event_id])
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_ids_for_event(self, event_id, types=None): def get_state_ids_for_event(self, event_id, types=None, filtered_types=None):
""" """
Get the state dict corresponding to a particular event Get the state dict corresponding to a particular event
Args: Args:
event_id(str): event whose state should be returned event_id(str): event whose state should be returned
types(list[(str|None, str|None)]|None): List of (type, state_key) tuples types(list[(str, str|None)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. If `state_key` is None, which are used to filter the state fetched. If `state_key` is None,
all events are returned of the given type. Presence of type of `None` all events are returned of the given type.
indicates that types not in the list should not be filtered out.
May be None, which matches any key. May be None, which matches any key.
filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns: Returns:
A deferred dict from (type, state_key) -> state_event A deferred dict from (type, state_key) -> state_event
""" """
state_map = yield self.get_state_ids_for_events([event_id], types) state_map = yield self.get_state_ids_for_events([event_id], types, filtered_types)
defer.returnValue(state_map[event_id]) defer.returnValue(state_map[event_id])
@cached(max_entries=50000) @cached(max_entries=50000)
@ -509,7 +513,7 @@ class StateGroupWorkerStore(SQLBaseStore):
defer.returnValue({row["event_id"]: row["state_group"] for row in rows}) defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
def _get_some_state_from_cache(self, group, types): def _get_some_state_from_cache(self, group, types, filtered_types=None):
"""Checks if group is in cache. See `_get_state_for_groups` """Checks if group is in cache. See `_get_state_for_groups`
Returns 3-tuple (`state_dict`, `missing_types`, `got_all`). Returns 3-tuple (`state_dict`, `missing_types`, `got_all`).
@ -520,29 +524,30 @@ class StateGroupWorkerStore(SQLBaseStore):
Args: Args:
group(int): The state group to lookup group(int): The state group to lookup
types(list[str|None, str|None]): List of 2-tuples of the form types(list[str, str|None]): List of 2-tuples of the form
(`type`, `state_key`), where a `state_key` of `None` matches all (`type`, `state_key`), where a `state_key` of `None` matches all
state_keys for the `type`. Presence of type of `None` indicates state_keys for the `type`.
that types not in the list should not be filtered out. filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
""" """
is_all, known_absent, state_dict_ids = self._state_group_cache.get(group) is_all, known_absent, state_dict_ids = self._state_group_cache.get(group)
type_to_key = {} type_to_key = {}
# tracks which of the requested types are missing from our cache
missing_types = set() missing_types = set()
include_other_types = False include_other_types = True if filtered_types is None else False
for typ, state_key in types: for typ, state_key in types:
key = (typ, state_key) key = (typ, state_key)
if typ is None:
include_other_types = True
next
if state_key is None: if state_key is None:
type_to_key[typ] = None type_to_key[typ] = None
# XXX: why do we mark the type as missing from our cache just # XXX: why do we mark the type as missing from our cache just
# because we weren't filtering on a specific value of state_key? # because we weren't filtering on a specific value of state_key?
# is it because the cache doesn't handle wildcards?
missing_types.add(key) missing_types.add(key)
else: else:
if type_to_key.get(typ, object()) is not None: if type_to_key.get(typ, object()) is not None:
@ -556,7 +561,7 @@ class StateGroupWorkerStore(SQLBaseStore):
def include(typ, state_key): def include(typ, state_key):
valid_state_keys = type_to_key.get(typ, sentinel) valid_state_keys = type_to_key.get(typ, sentinel)
if valid_state_keys is sentinel: if valid_state_keys is sentinel:
return include_other_types return include_other_types and typ not in filtered_types
if valid_state_keys is None: if valid_state_keys is None:
return True return True
if state_key in valid_state_keys: if state_key in valid_state_keys:
@ -585,21 +590,23 @@ class StateGroupWorkerStore(SQLBaseStore):
return state_dict_ids, is_all return state_dict_ids, is_all
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_state_for_groups(self, groups, types=None): def _get_state_for_groups(self, groups, types=None, filtered_types=None):
"""Gets the state at each of a list of state groups, optionally """Gets the state at each of a list of state groups, optionally
filtering by type/state_key filtering by type/state_key
Args: Args:
groups (iterable[int]): list of state groups for which we want groups (iterable[int]): list of state groups for which we want
to get the state. to get the state.
types (None|iterable[(None|str, None|str)]): types (None|iterable[(None, None|str)]):
indicates the state type/keys required. If None, the whole indicates the state type/keys required. If None, the whole
state is fetched and returned. state is fetched and returned.
Otherwise, each entry should be a `(type, state_key)` tuple to Otherwise, each entry should be a `(type, state_key)` tuple to
include in the response. A `state_key` of None is a wildcard include in the response. A `state_key` of None is a wildcard
meaning that we require all state with that type. A `type` of None meaning that we require all state with that type.
indicates that types not in the list should not be filtered out. filtered_types(list[str]|None): Only apply filtering via `types` to this
list of event types. Other types of events are returned unfiltered.
If None, `types` filtering is applied to all events.
Returns: Returns:
Deferred[dict[int, dict[(type, state_key), EventBase]]] Deferred[dict[int, dict[(type, state_key), EventBase]]]
@ -612,7 +619,7 @@ class StateGroupWorkerStore(SQLBaseStore):
if types is not None: if types is not None:
for group in set(groups): for group in set(groups):
state_dict_ids, _, got_all = self._get_some_state_from_cache( state_dict_ids, _, got_all = self._get_some_state_from_cache(
group, types, group, types, filtered_types
) )
results[group] = state_dict_ids results[group] = state_dict_ids
@ -645,7 +652,7 @@ class StateGroupWorkerStore(SQLBaseStore):
types_to_fetch = types types_to_fetch = types
group_to_state_dict = yield self._get_state_groups_from_groups( group_to_state_dict = yield self._get_state_groups_from_groups(
missing_groups, types_to_fetch, missing_groups, types_to_fetch, filtered_types
) )
for group, group_state_dict in iteritems(group_to_state_dict): for group, group_state_dict in iteritems(group_to_state_dict):