diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index f59b0eabb..735c03c7e 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -15,7 +15,10 @@ from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker +from synapse.storage import DataStore from synapse.storage.account_data import AccountDataStore +from synapse.storage.tags import TagsStore +from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedAccountDataStore(BaseSlavedStore): @@ -25,6 +28,14 @@ class SlavedAccountDataStore(BaseSlavedStore): self._account_data_id_gen = SlavedIdTracker( db_conn, "account_data_max_stream_id", "stream_id", ) + self._account_data_stream_cache = StreamChangeCache( + "AccountDataAndTagsChangeCache", + self._account_data_id_gen.get_current_token(), + ) + + get_account_data_for_user = ( + AccountDataStore.__dict__["get_account_data_for_user"] + ) get_global_account_data_by_type_for_users = ( AccountDataStore.__dict__["get_global_account_data_by_type_for_users"] @@ -34,6 +45,16 @@ class SlavedAccountDataStore(BaseSlavedStore): AccountDataStore.__dict__["get_global_account_data_by_type_for_user"] ) + get_tags_for_user = TagsStore.__dict__["get_tags_for_user"] + + get_updated_tags = DataStore.get_updated_tags.__func__ + get_updated_account_data_for_user = ( + DataStore.get_updated_account_data_for_user.__func__ + ) + + def get_max_account_data_stream_id(self): + return self._account_data_id_gen.get_current_token() + def stream_positions(self): result = super(SlavedAccountDataStore, self).stream_positions() position = self._account_data_id_gen.get_current_token() @@ -47,15 +68,33 @@ class SlavedAccountDataStore(BaseSlavedStore): if stream: self._account_data_id_gen.advance(int(stream["position"])) for row in stream["rows"]: - user_id, data_type = row[1:3] + position, user_id, data_type = row[:3] self.get_global_account_data_by_type_for_user.invalidate( (data_type, user_id,) ) + self.get_account_data_for_user.invalidate((user_id,)) + self._account_data_stream_cache.entity_has_changed( + user_id, position + ) stream = result.get("room_account_data") if stream: self._account_data_id_gen.advance(int(stream["position"])) + for row in stream["rows"]: + position, user_id = row[:2] + self.get_account_data_for_user.invalidate((user_id,)) + self._account_data_stream_cache.entity_has_changed( + user_id, position + ) stream = result.get("tag_account_data") if stream: self._account_data_id_gen.advance(int(stream["position"])) + for row in stream["rows"]: + position, user_id = row[:2] + self.get_tags_for_user.invalidate((user_id,)) + self._account_data_stream_cache.entity_has_changed( + user_id, position + ) + + return super(SlavedAccountDataStore, self).process_replication(result) diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index c0d741452..cbc1ae419 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -23,6 +23,7 @@ from synapse.storage.roommember import RoomMemberStore from synapse.storage.event_federation import EventFederationStore from synapse.storage.event_push_actions import EventPushActionsStore from synapse.storage.state import StateStore +from synapse.storage.stream import StreamStore from synapse.util.caches.stream_change_cache import StreamChangeCache import ujson as json @@ -57,6 +58,9 @@ class SlavedEventStore(BaseSlavedStore): "EventsRoomStreamChangeCache", min_event_val, prefilled_cache=event_cache_prefill, ) + self._membership_stream_cache = StreamChangeCache( + "MembershipStreamChangeCache", events_max, + ) # Cached functions can't be accessed through a class instance so we need # to reach inside the __dict__ to extract them. @@ -87,6 +91,9 @@ class SlavedEventStore(BaseSlavedStore): _get_state_group_from_group = ( StateStore.__dict__["_get_state_group_from_group"] ) + get_recent_event_ids_for_room = ( + StreamStore.__dict__["get_recent_event_ids_for_room"] + ) get_unread_push_actions_for_user_in_range = ( DataStore.get_unread_push_actions_for_user_in_range.__func__ @@ -109,10 +116,16 @@ class SlavedEventStore(BaseSlavedStore): DataStore.get_room_events_stream_for_room.__func__ ) get_events_around = DataStore.get_events_around.__func__ + get_state_for_event = DataStore.get_state_for_event.__func__ get_state_for_events = DataStore.get_state_for_events.__func__ get_state_groups = DataStore.get_state_groups.__func__ + get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__ + get_room_events_stream_for_rooms = ( + DataStore.get_room_events_stream_for_rooms.__func__ + ) + get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__ - _set_before_and_after = DataStore._set_before_and_after + _set_before_and_after = staticmethod(DataStore._set_before_and_after) _get_events = DataStore._get_events.__func__ _get_events_from_cache = DataStore._get_events_from_cache.__func__ @@ -220,9 +233,9 @@ class SlavedEventStore(BaseSlavedStore): self.get_rooms_for_user.invalidate((event.state_key,)) # self.get_joined_hosts_for_room.invalidate((event.room_id,)) self.get_users_in_room.invalidate((event.room_id,)) - # self._membership_stream_cache.entity_has_changed( - # event.state_key, event.internal_metadata.stream_ordering - # ) + self._membership_stream_cache.entity_has_changed( + event.state_key, event.internal_metadata.stream_ordering + ) self.get_invited_rooms_for_user.invalidate((event.state_key,)) if not event.is_state(): diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index ec007516d..ac9662d39 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -18,6 +18,7 @@ from ._slaved_id_tracker import SlavedIdTracker from synapse.storage import DataStore from synapse.storage.receipts import ReceiptsStore +from synapse.util.caches.stream_change_cache import StreamChangeCache # So, um, we want to borrow a load of functions intended for reading from # a DataStore, but we don't want to take functions that either write to the @@ -37,11 +38,28 @@ class SlavedReceiptsStore(BaseSlavedStore): db_conn, "receipts_linearized", "stream_id" ) + self._receipts_stream_cache = StreamChangeCache( + "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token() + ) + get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"] + get_linearized_receipts_for_room = ( + ReceiptsStore.__dict__["get_linearized_receipts_for_room"] + ) + _get_linearized_receipts_for_rooms = ( + ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"] + ) + get_last_receipt_event_id_for_user = ( + ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"] + ) get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__ get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__ + get_linearized_receipts_for_rooms = ( + DataStore.get_linearized_receipts_for_rooms.__func__ + ) + def stream_positions(self): result = super(SlavedReceiptsStore, self).stream_positions() result["receipts"] = self._receipts_id_gen.get_current_token() @@ -52,10 +70,15 @@ class SlavedReceiptsStore(BaseSlavedStore): if stream: self._receipts_id_gen.advance(int(stream["position"])) for row in stream["rows"]: - room_id, receipt_type, user_id = row[1:4] + position, room_id, receipt_type, user_id = row[:4] self.invalidate_caches_for_receipt(room_id, receipt_type, user_id) + self._receipts_stream_cache.entity_has_changed(room_id, position) return super(SlavedReceiptsStore, self).process_replication(result) def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): self.get_receipts_for_user.invalidate((user_id, receipt_type)) + self.get_linearized_receipts_for_room.invalidate_many((room_id,)) + self.get_last_receipt_event_id_for_user.invalidate( + (user_id, room_id, receipt_type) + )