Cache user IDs instead of profile objects (#13573)

The profile objects are never used and increase cache size significantly.
This commit is contained in:
Nick Mills-Barrett 2022-08-23 10:49:59 +01:00 committed by GitHub
parent 37f329c9ad
commit 5e7847dc92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 57 additions and 54 deletions

1
changelog.d/13573.misc Normal file
View File

@ -0,0 +1 @@
Cache user IDs instead of profiles to reduce cache memory usage. Contributed by Nick @ Beeper (@fizzadar).

View File

@ -2421,10 +2421,10 @@ class SyncHandler:
joined_room.room_id, joined_room.event_pos.stream joined_room.room_id, joined_room.event_pos.stream
) )
) )
users_in_room = await self.state.get_current_users_in_room( user_ids_in_room = await self.state.get_current_user_ids_in_room(
joined_room.room_id, extrems joined_room.room_id, extrems
) )
if user_id in users_in_room: if user_id in user_ids_in_room:
joined_room_ids.add(joined_room.room_id) joined_room_ids.add(joined_room.room_id)
return frozenset(joined_room_ids) return frozenset(joined_room_ids)

View File

@ -44,7 +44,6 @@ from synapse.logging.context import ContextResourceUsage
from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
from synapse.state import v1, v2 from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import StateMap from synapse.types import StateMap
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
@ -210,11 +209,11 @@ class StateHandler:
ret = await self.resolve_state_groups_for_events(room_id, event_ids) ret = await self.resolve_state_groups_for_events(room_id, event_ids)
return await ret.get_state(self._state_storage_controller, state_filter) return await ret.get_state(self._state_storage_controller, state_filter)
async def get_current_users_in_room( async def get_current_user_ids_in_room(
self, room_id: str, latest_event_ids: List[str] self, room_id: str, latest_event_ids: List[str]
) -> Dict[str, ProfileInfo]: ) -> Set[str]:
""" """
Get the users who are currently in a room. Get the user IDs who are currently in a room.
Note: This is much slower than using the equivalent method Note: This is much slower than using the equivalent method
`DataStore.get_users_in_room` or `DataStore.get_users_in_room_with_profiles`, `DataStore.get_users_in_room` or `DataStore.get_users_in_room_with_profiles`,
@ -225,15 +224,15 @@ class StateHandler:
room_id: The ID of the room. room_id: The ID of the room.
latest_event_ids: Precomputed list of latest event IDs. Will be computed if None. latest_event_ids: Precomputed list of latest event IDs. Will be computed if None.
Returns: Returns:
Dictionary of user IDs to their profileinfo. Set of user IDs in the room.
""" """
assert latest_event_ids is not None assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_users_in_room") logger.debug("calling resolve_state_groups from get_current_user_ids_in_room")
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids) entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = await entry.get_state(self._state_storage_controller, StateFilter.all()) state = await entry.get_state(self._state_storage_controller, StateFilter.all())
return await self.store.get_joined_users_from_state(room_id, state, entry) return await self.store.get_joined_user_ids_from_state(room_id, state, entry)
async def get_hosts_in_room_at_events( async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Collection[str] self, room_id: str, event_ids: Collection[str]

View File

@ -835,9 +835,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return shared_room_ids or frozenset() return shared_room_ids or frozenset()
async def get_joined_users_from_state( async def get_joined_user_ids_from_state(
self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry" self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
) -> Dict[str, ProfileInfo]: ) -> Set[str]:
state_group: Union[object, int] = state_entry.state_group state_group: Union[object, int] = state_entry.state_group
if not state_group: if not state_group:
# If state_group is None it means it has yet to be assigned a # If state_group is None it means it has yet to be assigned a
@ -848,25 +848,25 @@ class RoomMemberWorkerStore(EventsWorkerStore):
assert state_group is not None assert state_group is not None
with Measure(self._clock, "get_joined_users_from_state"): with Measure(self._clock, "get_joined_users_from_state"):
return await self._get_joined_users_from_context( return await self._get_joined_user_ids_from_context(
room_id, state_group, state, context=state_entry room_id, state_group, state, context=state_entry
) )
@cached(num_args=2, iterable=True, max_entries=100000) @cached(num_args=2, iterable=True, max_entries=100000)
async def _get_joined_users_from_context( async def _get_joined_user_ids_from_context(
self, self,
room_id: str, room_id: str,
state_group: Union[object, int], state_group: Union[object, int],
current_state_ids: StateMap[str], current_state_ids: StateMap[str],
event: Optional[EventBase] = None, event: Optional[EventBase] = None,
context: Optional["_StateCacheEntry"] = None, context: Optional["_StateCacheEntry"] = None,
) -> Dict[str, ProfileInfo]: ) -> Set[str]:
# We don't use `state_group`, it's there so that we can cache based # We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states # on it. However, it's important that it's never None, since two current_states
# with a state_group of None are likely to be different. # with a state_group of None are likely to be different.
assert state_group is not None assert state_group is not None
users_in_room = {} users_in_room = set()
member_event_ids = [ member_event_ids = [
e_id e_id
for key, e_id in current_state_ids.items() for key, e_id in current_state_ids.items()
@ -879,11 +879,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# If we do then we can reuse that result and simply update it with # If we do then we can reuse that result and simply update it with
# any membership changes in `delta_ids` # any membership changes in `delta_ids`
if context.prev_group and context.delta_ids: if context.prev_group and context.delta_ids:
prev_res = self._get_joined_users_from_context.cache.get_immediate( prev_res = self._get_joined_user_ids_from_context.cache.get_immediate(
(room_id, context.prev_group), None (room_id, context.prev_group), None
) )
if prev_res and isinstance(prev_res, dict): if prev_res and isinstance(prev_res, set):
users_in_room = dict(prev_res) users_in_room = prev_res
member_event_ids = [ member_event_ids = [
e_id e_id
for key, e_id in context.delta_ids.items() for key, e_id in context.delta_ids.items()
@ -891,7 +891,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
] ]
for etype, state_key in context.delta_ids: for etype, state_key in context.delta_ids:
if etype == EventTypes.Member: if etype == EventTypes.Member:
users_in_room.pop(state_key, None) users_in_room.discard(state_key)
# We check if we have any of the member event ids in the event cache # We check if we have any of the member event ids in the event cache
# before we ask the DB # before we ask the DB
@ -908,42 +908,41 @@ class RoomMemberWorkerStore(EventsWorkerStore):
ev_entry = event_map.get(event_id) ev_entry = event_map.get(event_id)
if ev_entry and not ev_entry.event.rejected_reason: if ev_entry and not ev_entry.event.rejected_reason:
if ev_entry.event.membership == Membership.JOIN: if ev_entry.event.membership == Membership.JOIN:
users_in_room[ev_entry.event.state_key] = ProfileInfo( users_in_room.add(ev_entry.event.state_key)
display_name=ev_entry.event.content.get("displayname", None),
avatar_url=ev_entry.event.content.get("avatar_url", None),
)
else: else:
missing_member_event_ids.append(event_id) missing_member_event_ids.append(event_id)
if missing_member_event_ids: if missing_member_event_ids:
event_to_memberships = await self._get_joined_profiles_from_event_ids( event_to_memberships = await self._get_user_ids_from_membership_event_ids(
missing_member_event_ids missing_member_event_ids
) )
users_in_room.update(row for row in event_to_memberships.values() if row) users_in_room.update(event_to_memberships.values())
if event is not None and event.type == EventTypes.Member: if event is not None and event.type == EventTypes.Member:
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
if event.event_id in member_event_ids: if event.event_id in member_event_ids:
users_in_room[event.state_key] = ProfileInfo( users_in_room.add(event.state_key)
display_name=event.content.get("displayname", None),
avatar_url=event.content.get("avatar_url", None),
)
return users_in_room return users_in_room
@cached(max_entries=10000) @cached(
def _get_joined_profile_from_event_id( max_entries=10000,
# This name matches the old function that has been replaced - the cache name
# is kept here to maintain backwards compatibility.
name="_get_joined_profile_from_event_id",
)
def _get_user_id_from_membership_event_id(
self, event_id: str self, event_id: str
) -> Optional[Tuple[str, ProfileInfo]]: ) -> Optional[Tuple[str, ProfileInfo]]:
raise NotImplementedError() raise NotImplementedError()
@cachedList( @cachedList(
cached_method_name="_get_joined_profile_from_event_id", cached_method_name="_get_user_id_from_membership_event_id",
list_name="event_ids", list_name="event_ids",
) )
async def _get_joined_profiles_from_event_ids( async def _get_user_ids_from_membership_event_ids(
self, event_ids: Iterable[str] self, event_ids: Iterable[str]
) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]: ) -> Dict[str, str]:
"""For given set of member event_ids check if they point to a join """For given set of member event_ids check if they point to a join
event and if so return the associated user and profile info. event and if so return the associated user and profile info.
@ -958,21 +957,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
table="room_memberships", table="room_memberships",
column="event_id", column="event_id",
iterable=event_ids, iterable=event_ids,
retcols=("user_id", "display_name", "avatar_url", "event_id"), retcols=("user_id", "event_id"),
keyvalues={"membership": Membership.JOIN}, keyvalues={"membership": Membership.JOIN},
batch_size=1000, batch_size=1000,
desc="_get_joined_profiles_from_event_ids", desc="_get_user_ids_from_membership_event_ids",
) )
return { return {row["event_id"]: row["user_id"] for row in rows}
row["event_id"]: (
row["user_id"],
ProfileInfo(
avatar_url=row["avatar_url"], display_name=row["display_name"]
),
)
for row in rows
}
@cached(max_entries=10000) @cached(max_entries=10000)
async def is_host_joined(self, room_id: str, host: str) -> bool: async def is_host_joined(self, room_id: str, host: str) -> bool:
@ -1131,12 +1122,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
else: else:
# The cache doesn't match the state group or prev state group, # The cache doesn't match the state group or prev state group,
# so we calculate the result from first principles. # so we calculate the result from first principles.
joined_users = await self.get_joined_users_from_state( joined_user_ids = await self.get_joined_user_ids_from_state(
room_id, state, state_entry room_id, state, state_entry
) )
cache.hosts_to_joined_users = {} cache.hosts_to_joined_users = {}
for user_id in joined_users: for user_id in joined_user_ids:
host = intern_string(get_domain_from_id(user_id)) host = intern_string(get_domain_from_id(user_id))
cache.hosts_to_joined_users.setdefault(host, set()).add(user_id) cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)

View File

@ -73,8 +73,10 @@ class _CacheDescriptorBase:
num_args: Optional[int], num_args: Optional[int],
uncached_args: Optional[Collection[str]] = None, uncached_args: Optional[Collection[str]] = None,
cache_context: bool = False, cache_context: bool = False,
name: Optional[str] = None,
): ):
self.orig = orig self.orig = orig
self.name = name or orig.__name__
arg_spec = inspect.getfullargspec(orig) arg_spec = inspect.getfullargspec(orig)
all_args = arg_spec.args all_args = arg_spec.args
@ -211,7 +213,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
cache: LruCache[CacheKey, Any] = LruCache( cache: LruCache[CacheKey, Any] = LruCache(
cache_name=self.orig.__name__, cache_name=self.name,
max_size=self.max_entries, max_size=self.max_entries,
) )
@ -241,7 +243,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
wrapped = cast(_CachedFunction, _wrapped) wrapped = cast(_CachedFunction, _wrapped)
wrapped.cache = cache wrapped.cache = cache
obj.__dict__[self.orig.__name__] = wrapped obj.__dict__[self.name] = wrapped
return wrapped return wrapped
@ -301,12 +303,14 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
cache_context: bool = False, cache_context: bool = False,
iterable: bool = False, iterable: bool = False,
prune_unread_entries: bool = True, prune_unread_entries: bool = True,
name: Optional[str] = None,
): ):
super().__init__( super().__init__(
orig, orig,
num_args=num_args, num_args=num_args,
uncached_args=uncached_args, uncached_args=uncached_args,
cache_context=cache_context, cache_context=cache_context,
name=name,
) )
if tree and self.num_args < 2: if tree and self.num_args < 2:
@ -321,7 +325,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
cache: DeferredCache[CacheKey, Any] = DeferredCache( cache: DeferredCache[CacheKey, Any] = DeferredCache(
name=self.orig.__name__, name=self.name,
max_entries=self.max_entries, max_entries=self.max_entries,
tree=self.tree, tree=self.tree,
iterable=self.iterable, iterable=self.iterable,
@ -372,7 +376,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
wrapped.cache = cache wrapped.cache = cache
wrapped.num_args = self.num_args wrapped.num_args = self.num_args
obj.__dict__[self.orig.__name__] = wrapped obj.__dict__[self.name] = wrapped
return wrapped return wrapped
@ -393,6 +397,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
cached_method_name: str, cached_method_name: str,
list_name: str, list_name: str,
num_args: Optional[int] = None, num_args: Optional[int] = None,
name: Optional[str] = None,
): ):
""" """
Args: Args:
@ -403,7 +408,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
but including list_name) to use as cache keys. Defaults to all but including list_name) to use as cache keys. Defaults to all
named args of the function. named args of the function.
""" """
super().__init__(orig, num_args=num_args, uncached_args=None) super().__init__(orig, num_args=num_args, uncached_args=None, name=name)
self.list_name = list_name self.list_name = list_name
@ -525,7 +530,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
else: else:
return defer.succeed(results) return defer.succeed(results)
obj.__dict__[self.orig.__name__] = wrapped obj.__dict__[self.name] = wrapped
return wrapped return wrapped
@ -577,6 +582,7 @@ def cached(
cache_context: bool = False, cache_context: bool = False,
iterable: bool = False, iterable: bool = False,
prune_unread_entries: bool = True, prune_unread_entries: bool = True,
name: Optional[str] = None,
) -> Callable[[F], _CachedFunction[F]]: ) -> Callable[[F], _CachedFunction[F]]:
func = lambda orig: DeferredCacheDescriptor( func = lambda orig: DeferredCacheDescriptor(
orig, orig,
@ -587,13 +593,18 @@ def cached(
cache_context=cache_context, cache_context=cache_context,
iterable=iterable, iterable=iterable,
prune_unread_entries=prune_unread_entries, prune_unread_entries=prune_unread_entries,
name=name,
) )
return cast(Callable[[F], _CachedFunction[F]], func) return cast(Callable[[F], _CachedFunction[F]], func)
def cachedList( def cachedList(
*, cached_method_name: str, list_name: str, num_args: Optional[int] = None *,
cached_method_name: str,
list_name: str,
num_args: Optional[int] = None,
name: Optional[str] = None,
) -> Callable[[F], _CachedFunction[F]]: ) -> Callable[[F], _CachedFunction[F]]:
"""Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`. """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
@ -628,6 +639,7 @@ def cachedList(
cached_method_name=cached_method_name, cached_method_name=cached_method_name,
list_name=list_name, list_name=list_name,
num_args=num_args, num_args=num_args,
name=name,
) )
return cast(Callable[[F], _CachedFunction[F]], func) return cast(Callable[[F], _CachedFunction[F]], func)