diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index d490a374e..e05465bc1 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -35,6 +35,11 @@ RoomsForUser = namedtuple( class RoomMemberStore(SQLBaseStore): + def __init__(self, *args, **kw): + super(RoomMemberStore, self).__init__(*args, **kw) + + self._user_rooms_cache = {} + def _store_room_member_txn(self, txn, event): """Store a room member in the database. """ @@ -98,6 +103,8 @@ class RoomMemberStore(SQLBaseStore): txn.execute(sql, (event.room_id, domain)) + self.invalidate_rooms_for_user(target_user_id) + @defer.inlineCallbacks def get_room_member(self, user_id, room_id): """Retrieve the current state of a room member. @@ -240,23 +247,43 @@ class RoomMemberStore(SQLBaseStore): results = self._parse_events_txn(txn, rows) return results + # TODO(paul): Create a nice @cached decorator to do this + # @cached + # def get_foo(...) + # ... + # invalidate_foo = get_foo.invalidator + + @defer.inlineCallbacks + def get_rooms_for_user(self, user_id): + # TODO(paul): put some performance counters in here so we can easily + # track what impact this cache is having + if user_id in self._user_rooms_cache: + defer.returnValue(self._user_rooms_cache[user_id]) + + rooms = yield self.get_rooms_for_user_where_membership_is( + user_id, membership_list=[Membership.JOIN], + ) + + self._user_rooms_cache[user_id] = rooms + defer.returnValue(rooms) + + def invalidate_rooms_for_user(self, user_id): + if user_id in self._user_rooms_cache: + del self._user_rooms_cache[user_id] + @defer.inlineCallbacks def user_rooms_intersect(self, user_id_list): """ Checks whether all the users whose IDs are given in a list share a room. This is a "hot path" function that's called a lot, e.g. by presence for - generating the event stream. + generating the event stream. As such, it is implemented locally by + wrapping logic around heavily-cached database queries. """ if len(user_id_list) < 2: defer.returnValue(True) - deferreds = [ - self.get_rooms_for_user_where_membership_is( - u, membership_list=[Membership.JOIN], - ) - for u in user_id_list - ] + deferreds = [self.get_rooms_for_user(u) for u in user_id_list] results = yield defer.DeferredList(deferreds)