diff --git a/changelog.d/14723.bugfix b/changelog.d/14723.bugfix new file mode 100644 index 000000000..e1f89cee3 --- /dev/null +++ b/changelog.d/14723.bugfix @@ -0,0 +1 @@ +Ensure stream IDs are always updated after caches get invalidated with workers. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 658d89210..b5e40da53 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -152,6 +152,9 @@ class ReplicationDataHandler: rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. """ self.store.process_replication_rows(stream_name, instance_name, token, rows) + # NOTE: this must be called after process_replication_rows to ensure any + # cache invalidations are first handled before any stream ID advances. + self.store.process_replication_position(stream_name, instance_name, token) if self.send_handler: await self.send_handler.process_replication_rows(stream_name, token, rows) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 69abf6fa8..41d911101 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -57,7 +57,22 @@ class SQLBaseStore(metaclass=ABCMeta): token: int, rows: Iterable[Any], ) -> None: - pass + """ + Used by storage classes to invalidate caches based on incoming replication data. These + must not update any ID generators, use `process_replication_position`. + """ + + def process_replication_position( # noqa: B027 (no-op by design) + self, + stream_name: str, + instance_name: str, + token: int, + ) -> None: + """ + Used by storage classes to advance ID generators based on incoming replication data. This + is called after process_replication_rows such that caches are invalidated before any token + positions advance. + """ def _invalidate_state_caches( self, room_id: str, members_changed: Collection[str] diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index e59776f43..86032897f 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -436,10 +436,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) token: int, rows: Iterable[Any], ) -> None: - if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) - elif stream_name == AccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) + if stream_name == AccountDataStream.NAME: for row in rows: if not row.room_id: self.get_global_account_data_by_type_for_user.invalidate( @@ -454,6 +451,15 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == TagAccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + elif stream_name == AccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + async def add_account_data_to_room( self, user_id: str, room_id: str, account_data_type: str, content: JsonDict ) -> int: diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index a58668a38..2179a8bf5 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -164,9 +164,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore): backfilled=True, ) elif stream_name == CachesStream.NAME: - if self._cache_id_gen: - self._cache_id_gen.advance(instance_name, token) - for row in rows: if row.cache_func == CURRENT_STATE_CACHE_NAME: if row.keys is None: @@ -182,6 +179,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore): super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == CachesStream.NAME: + if self._cache_id_gen: + self._cache_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None: data = row.data diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 48a54d9cb..713be91c5 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -157,6 +157,13 @@ class DeviceInboxWorkerStore(SQLBaseStore): ) return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == ToDeviceStream.NAME: + self._device_inbox_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def get_to_device_stream_token(self) -> int: return self._device_inbox_id_gen.get_current_token() diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index a5bb4d404..db877e3f1 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -162,14 +162,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] ) -> None: if stream_name == DeviceListsStream.NAME: - self._device_list_id_gen.advance(instance_name, token) self._invalidate_caches_for_devices(token, rows) elif stream_name == UserSignatureStream.NAME: - self._device_list_id_gen.advance(instance_name, token) for row in rows: self._user_signature_stream_cache.entity_has_changed(row.user_id, token) return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == DeviceListsStream.NAME: + self._device_list_id_gen.advance(instance_name, token) + elif stream_name == UserSignatureStream.NAME: + self._device_list_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def _invalidate_caches_for_devices( self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow] ) -> None: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 761b15a81..d150fa8a9 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -388,11 +388,7 @@ class EventsWorkerStore(SQLBaseStore): token: int, rows: Iterable[Any], ) -> None: - if stream_name == EventsStream.NAME: - self._stream_id_gen.advance(instance_name, token) - elif stream_name == BackfillStream.NAME: - self._backfill_id_gen.advance(instance_name, -token) - elif stream_name == UnPartialStatedEventStream.NAME: + if stream_name == UnPartialStatedEventStream.NAME: for row in rows: assert isinstance(row, UnPartialStatedEventStreamRow) @@ -405,6 +401,15 @@ class EventsWorkerStore(SQLBaseStore): super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == EventsStream.NAME: + self._stream_id_gen.advance(instance_name, token) + elif stream_name == BackfillStream.NAME: + self._backfill_id_gen.advance(instance_name, -token) + super().process_replication_position(stream_name, instance_name, token) + async def have_censored_event(self, event_id: str) -> bool: """Check if an event has been censored, i.e. if the content of the event has been erased from the database due to a redaction. diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 9769a18a9..7b6081504 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -439,8 +439,14 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) rows: Iterable[Any], ) -> None: if stream_name == PresenceStream.NAME: - self._presence_id_gen.advance(instance_name, token) for row in rows: self.presence_stream_cache.entity_has_changed(row.user_id, token) self._get_presence_for_user.invalidate((row.user_id,)) return super().process_replication_rows(stream_name, instance_name, token, rows) + + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == PresenceStream.NAME: + self._presence_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index d4c64c46a..d4e4b777d 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -154,6 +154,13 @@ class PushRulesWorkerStore( self.push_rules_stream_cache.entity_has_changed(row.user_id, token) return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == PushRulesStream.NAME: + self._push_rules_stream_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + @cached(max_entries=5000) async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules: rows = await self.db_pool.simple_select_list( diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 40fd781a6..7f24a3b6e 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -111,12 +111,12 @@ class PusherWorkerStore(SQLBaseStore): def get_pushers_stream_token(self) -> int: return self._pushers_id_gen.get_current_token() - def process_replication_rows( - self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + def process_replication_position( + self, stream_name: str, instance_name: str, token: int ) -> None: if stream_name == PushersStream.NAME: self._pushers_id_gen.advance(instance_name, token) - return super().process_replication_rows(stream_name, instance_name, token, rows) + super().process_replication_position(stream_name, instance_name, token) async def get_pushers_by_app_id_and_pushkey( self, app_id: str, pushkey: str diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index e06725f69..86f5bce5f 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -588,6 +588,13 @@ class ReceiptsWorkerStore(SQLBaseStore): return super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == ReceiptsStream.NAME: + self._receipts_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + def _insert_linearized_receipt_txn( self, txn: LoggingTransaction, diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index b0f5de67a..e23c927e0 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -300,13 +300,19 @@ class TagsWorkerStore(AccountDataWorkerStore): rows: Iterable[Any], ) -> None: if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) for row in rows: self.get_tags_for_user.invalidate((row.user_id,)) self._account_data_stream_cache.entity_has_changed(row.user_id, token) super().process_replication_rows(stream_name, instance_name, token, rows) + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == TagAccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + class TagsStore(TagsWorkerStore): pass