Add cache for get_membership_from_event_ids (#12272)

This should speed up push rule calculations for rooms with large numbers of local users when the main push rule cache fails.

Co-authored-by: reivilibre <oliverw@matrix.org>
This commit is contained in:
Erik Johnston 2022-03-25 14:58:56 +00:00 committed by GitHub
parent 38adf14998
commit 7ca8ee67a5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 72 additions and 22 deletions

View file

@ -63,6 +63,14 @@ _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
@attr.s(frozen=True, slots=True, auto_attribs=True)
class EventIdMembership:
"""Returned by `get_membership_from_event_ids`"""
user_id: str
membership: str
class RoomMemberWorkerStore(EventsWorkerStore):
def __init__(
self,
@ -772,7 +780,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
retcols=("user_id", "display_name", "avatar_url", "event_id"),
keyvalues={"membership": Membership.JOIN},
batch_size=500,
desc="_get_membership_from_event_ids",
desc="_get_joined_profiles_from_event_ids",
)
return {
@ -1000,12 +1008,26 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return set(room_ids)
@cached(max_entries=5000)
async def _get_membership_from_event_id(
self, member_event_id: str
) -> Optional[EventIdMembership]:
raise NotImplementedError()
@cachedList(
cached_method_name="_get_membership_from_event_id", list_name="member_event_ids"
)
async def get_membership_from_event_ids(
self, member_event_ids: Iterable[str]
) -> List[dict]:
"""Get user_id and membership of a set of event IDs."""
) -> Dict[str, Optional[EventIdMembership]]:
"""Get user_id and membership of a set of event IDs.
return await self.db_pool.simple_select_many_batch(
Returns:
Mapping from event ID to `EventIdMembership` if the event is a
membership event, otherwise the value is None.
"""
rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,
@ -1015,6 +1037,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
desc="get_membership_from_event_ids",
)
return {
row["event_id"]: EventIdMembership(
membership=row["membership"], user_id=row["user_id"]
)
for row in rows
}
async def is_local_host_in_room_ignoring_users(
self, room_id: str, ignore_users: Collection[str]
) -> bool: