diff --git a/changelog.d/12272.misc b/changelog.d/12272.misc new file mode 100644 index 000000000..95589f336 --- /dev/null +++ b/changelog.d/12272.misc @@ -0,0 +1 @@ +Add a new cache `_get_membership_from_event_id` to speed up push rule calculations in large rooms. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 030898e4d..a402a3e40 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -24,6 +24,7 @@ from synapse.event_auth import get_user_power_level from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY +from synapse.storage.databases.main.roommember import EventIdMembership from synapse.util.async_helpers import Linearizer from synapse.util.caches import CacheMetric, register_cache from synapse.util.caches.descriptors import lru_cache @@ -292,7 +293,7 @@ def _condition_checker( return True -MemberMap = Dict[str, Tuple[str, str]] +MemberMap = Dict[str, Optional[EventIdMembership]] Rule = Dict[str, dict] RulesByUser = Dict[str, List[Rule]] StateGroup = Union[object, int] @@ -306,7 +307,7 @@ class RulesForRoomData: *only* include data, and not references to e.g. the data stores. """ - # event_id -> (user_id, state) + # event_id -> EventIdMembership member_map: MemberMap = attr.Factory(dict) # user_id -> rules rules_by_user: RulesByUser = attr.Factory(dict) @@ -447,11 +448,10 @@ class RulesForRoom: res = self.data.member_map.get(event_id, None) if res: - user_id, state = res - if state == Membership.JOIN: - rules = self.data.rules_by_user.get(user_id, None) + if res.membership == Membership.JOIN: + rules = self.data.rules_by_user.get(res.user_id, None) if rules: - ret_rules_by_user[user_id] = rules + ret_rules_by_user[res.user_id] = rules continue # If a user has left a room we remove their push rule. If they @@ -502,24 +502,26 @@ class RulesForRoom: """ sequence = self.data.sequence - rows = await self.store.get_membership_from_event_ids(member_event_ids.values()) + members = await self.store.get_membership_from_event_ids( + member_event_ids.values() + ) - members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows} - - # If the event is a join event then it will be in current state evnts + # If the event is a join event then it will be in current state events # map but not in the DB, so we have to explicitly insert it. if event.type == EventTypes.Member: for event_id in member_event_ids.values(): if event_id == event.event_id: - members[event_id] = (event.state_key, event.membership) + members[event_id] = EventIdMembership( + user_id=event.state_key, membership=event.membership + ) if logger.isEnabledFor(logging.DEBUG): logger.debug("Found members %r: %r", self.room_id, members.values()) joined_user_ids = { - user_id - for user_id, membership in members.values() - if membership == Membership.JOIN + entry.user_id + for entry in members.values() + if entry and entry.membership == Membership.JOIN } logger.debug("Joined: %r", joined_user_ids) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 2d7511d61..dd4e83a2a 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -192,6 +192,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,)) + # The `_get_membership_from_event_id` is immutable, except for the + # case where we look up an event *before* persisting it. + self._get_membership_from_event_id.invalidate((event_id,)) + if not backfilled: self._events_stream_cache.entity_has_changed(room_id, stream_ordering) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 1f60aef18..d25324312 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1745,6 +1745,13 @@ class PersistEventsStore: (event.state_key,), ) + # The `_get_membership_from_event_id` is immutable, except for the + # case where we look up an event *before* persisting it. + txn.call_after( + self.store._get_membership_from_event_id.invalidate, + (event.event_id,), + ) + # We update the local_current_membership table only if the event is # "current", i.e., its something that has just happened. # diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index bef675b84..3248da535 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -63,6 +63,14 @@ _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" +@attr.s(frozen=True, slots=True, auto_attribs=True) +class EventIdMembership: + """Returned by `get_membership_from_event_ids`""" + + user_id: str + membership: str + + class RoomMemberWorkerStore(EventsWorkerStore): def __init__( self, @@ -772,7 +780,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): retcols=("user_id", "display_name", "avatar_url", "event_id"), keyvalues={"membership": Membership.JOIN}, batch_size=500, - desc="_get_membership_from_event_ids", + desc="_get_joined_profiles_from_event_ids", ) return { @@ -1000,12 +1008,26 @@ class RoomMemberWorkerStore(EventsWorkerStore): return set(room_ids) + @cached(max_entries=5000) + async def _get_membership_from_event_id( + self, member_event_id: str + ) -> Optional[EventIdMembership]: + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_membership_from_event_id", list_name="member_event_ids" + ) async def get_membership_from_event_ids( self, member_event_ids: Iterable[str] - ) -> List[dict]: - """Get user_id and membership of a set of event IDs.""" + ) -> Dict[str, Optional[EventIdMembership]]: + """Get user_id and membership of a set of event IDs. - return await self.db_pool.simple_select_many_batch( + Returns: + Mapping from event ID to `EventIdMembership` if the event is a + membership event, otherwise the value is None. + """ + + rows = await self.db_pool.simple_select_many_batch( table="room_memberships", column="event_id", iterable=member_event_ids, @@ -1015,6 +1037,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): desc="get_membership_from_event_ids", ) + return { + row["event_id"]: EventIdMembership( + membership=row["membership"], user_id=row["user_id"] + ) + for row in rows + } + async def is_local_host_in_room_ignoring_users( self, room_id: str, ignore_users: Collection[str] ) -> bool: diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 7d543fdbe..b40292281 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -1023,8 +1023,13 @@ class EventsPersistenceStorage: # Check if any of the changes that we don't have events for are joins. if events_to_check: - rows = await self.main_store.get_membership_from_event_ids(events_to_check) - is_still_joined = any(row["membership"] == Membership.JOIN for row in rows) + members = await self.main_store.get_membership_from_event_ids( + events_to_check + ) + is_still_joined = any( + member and member.membership == Membership.JOIN + for member in members.values() + ) if is_still_joined: return True @@ -1060,9 +1065,11 @@ class EventsPersistenceStorage: ), event_id in current_state.items() if typ == EventTypes.Member and not self.is_mine_id(state_key) ] - rows = await self.main_store.get_membership_from_event_ids(remote_event_ids) + members = await self.main_store.get_membership_from_event_ids(remote_event_ids) potentially_left_users.update( - row["user_id"] for row in rows if row["membership"] == Membership.JOIN + member.user_id + for member in members.values() + if member and member.membership == Membership.JOIN ) return False