Update all stream IDs after processing replication rows (#14723)

This creates a new store method, `process_replication_position` that
is called after `process_replication_rows`. By moving stream ID advances
here this guarantees any relevant cache invalidations will have been
applied before the stream is advanced.

This avoids race conditions where Python switches between threads mid
way through processing the `process_replication_rows` method where stream
IDs may be advanced before caches are invalidated due to class resolution
ordering.

See this comment/issue for further discussion:
	https://github.com/matrix-org/synapse/issues/14158#issuecomment-1344048703
This commit is contained in:
Nick Mills-Barrett 2023-01-04 11:49:26 +00:00 committed by GitHub
parent c4456114e1
commit db1cfe9c80
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 95 additions and 20 deletions

1
changelog.d/14723.bugfix Normal file
View File

@ -0,0 +1 @@
Ensure stream IDs are always updated after caches get invalidated with workers. Contributed by Nick @ Beeper (@fizzadar).

View File

@ -152,6 +152,9 @@ class ReplicationDataHandler:
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row. rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
""" """
self.store.process_replication_rows(stream_name, instance_name, token, rows) self.store.process_replication_rows(stream_name, instance_name, token, rows)
# NOTE: this must be called after process_replication_rows to ensure any
# cache invalidations are first handled before any stream ID advances.
self.store.process_replication_position(stream_name, instance_name, token)
if self.send_handler: if self.send_handler:
await self.send_handler.process_replication_rows(stream_name, token, rows) await self.send_handler.process_replication_rows(stream_name, token, rows)

View File

@ -57,7 +57,22 @@ class SQLBaseStore(metaclass=ABCMeta):
token: int, token: int,
rows: Iterable[Any], rows: Iterable[Any],
) -> None: ) -> None:
pass """
Used by storage classes to invalidate caches based on incoming replication data. These
must not update any ID generators, use `process_replication_position`.
"""
def process_replication_position( # noqa: B027 (no-op by design)
self,
stream_name: str,
instance_name: str,
token: int,
) -> None:
"""
Used by storage classes to advance ID generators based on incoming replication data. This
is called after process_replication_rows such that caches are invalidated before any token
positions advance.
"""
def _invalidate_state_caches( def _invalidate_state_caches(
self, room_id: str, members_changed: Collection[str] self, room_id: str, members_changed: Collection[str]

View File

@ -436,10 +436,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
token: int, token: int,
rows: Iterable[Any], rows: Iterable[Any],
) -> None: ) -> None:
if stream_name == TagAccountDataStream.NAME: if stream_name == AccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
elif stream_name == AccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
for row in rows: for row in rows:
if not row.room_id: if not row.room_id:
self.get_global_account_data_by_type_for_user.invalidate( self.get_global_account_data_by_type_for_user.invalidate(
@ -454,6 +451,15 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
super().process_replication_rows(stream_name, instance_name, token, rows) super().process_replication_rows(stream_name, instance_name, token, rows)
def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
elif stream_name == AccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)
async def add_account_data_to_room( async def add_account_data_to_room(
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
) -> int: ) -> int:

View File

@ -164,9 +164,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
backfilled=True, backfilled=True,
) )
elif stream_name == CachesStream.NAME: elif stream_name == CachesStream.NAME:
if self._cache_id_gen:
self._cache_id_gen.advance(instance_name, token)
for row in rows: for row in rows:
if row.cache_func == CURRENT_STATE_CACHE_NAME: if row.cache_func == CURRENT_STATE_CACHE_NAME:
if row.keys is None: if row.keys is None:
@ -182,6 +179,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
super().process_replication_rows(stream_name, instance_name, token, rows) super().process_replication_rows(stream_name, instance_name, token, rows)
def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == CachesStream.NAME:
if self._cache_id_gen:
self._cache_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)
def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None: def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
data = row.data data = row.data

View File

@ -157,6 +157,13 @@ class DeviceInboxWorkerStore(SQLBaseStore):
) )
return super().process_replication_rows(stream_name, instance_name, token, rows) return super().process_replication_rows(stream_name, instance_name, token, rows)
def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == ToDeviceStream.NAME:
self._device_inbox_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)
def get_to_device_stream_token(self) -> int: def get_to_device_stream_token(self) -> int:
return self._device_inbox_id_gen.get_current_token() return self._device_inbox_id_gen.get_current_token()

View File

@ -162,14 +162,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None: ) -> None:
if stream_name == DeviceListsStream.NAME: if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(instance_name, token)
self._invalidate_caches_for_devices(token, rows) self._invalidate_caches_for_devices(token, rows)
elif stream_name == UserSignatureStream.NAME: elif stream_name == UserSignatureStream.NAME:
self._device_list_id_gen.advance(instance_name, token)
for row in rows: for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token) self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows) return super().process_replication_rows(stream_name, instance_name, token, rows)
def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(instance_name, token)
elif stream_name == UserSignatureStream.NAME:
self._device_list_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)
def _invalidate_caches_for_devices( def _invalidate_caches_for_devices(
self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow] self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
) -> None: ) -> None:

View File

@ -388,11 +388,7 @@ class EventsWorkerStore(SQLBaseStore):
token: int, token: int,
rows: Iterable[Any], rows: Iterable[Any],
) -> None: ) -> None:
if stream_name == EventsStream.NAME: if stream_name == UnPartialStatedEventStream.NAME:
self._stream_id_gen.advance(instance_name, token)
elif stream_name == BackfillStream.NAME:
self._backfill_id_gen.advance(instance_name, -token)
elif stream_name == UnPartialStatedEventStream.NAME:
for row in rows: for row in rows:
assert isinstance(row, UnPartialStatedEventStreamRow) assert isinstance(row, UnPartialStatedEventStreamRow)
@ -405,6 +401,15 @@ class EventsWorkerStore(SQLBaseStore):
super().process_replication_rows(stream_name, instance_name, token, rows) super().process_replication_rows(stream_name, instance_name, token, rows)
def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(instance_name, token)
elif stream_name == BackfillStream.NAME:
self._backfill_id_gen.advance(instance_name, -token)
super().process_replication_position(stream_name, instance_name, token)
async def have_censored_event(self, event_id: str) -> bool: async def have_censored_event(self, event_id: str) -> bool:
"""Check if an event has been censored, i.e. if the content of the event has been erased """Check if an event has been censored, i.e. if the content of the event has been erased
from the database due to a redaction. from the database due to a redaction.

View File

@ -439,8 +439,14 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
rows: Iterable[Any], rows: Iterable[Any],
) -> None: ) -> None:
if stream_name == PresenceStream.NAME: if stream_name == PresenceStream.NAME:
self._presence_id_gen.advance(instance_name, token)
for row in rows: for row in rows:
self.presence_stream_cache.entity_has_changed(row.user_id, token) self.presence_stream_cache.entity_has_changed(row.user_id, token)
self._get_presence_for_user.invalidate((row.user_id,)) self._get_presence_for_user.invalidate((row.user_id,))
return super().process_replication_rows(stream_name, instance_name, token, rows) return super().process_replication_rows(stream_name, instance_name, token, rows)
def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == PresenceStream.NAME:
self._presence_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)

View File

@ -154,6 +154,13 @@ class PushRulesWorkerStore(
self.push_rules_stream_cache.entity_has_changed(row.user_id, token) self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows) return super().process_replication_rows(stream_name, instance_name, token, rows)
def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == PushRulesStream.NAME:
self._push_rules_stream_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)
@cached(max_entries=5000) @cached(max_entries=5000)
async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules: async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
rows = await self.db_pool.simple_select_list( rows = await self.db_pool.simple_select_list(

View File

@ -111,12 +111,12 @@ class PusherWorkerStore(SQLBaseStore):
def get_pushers_stream_token(self) -> int: def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token() return self._pushers_id_gen.get_current_token()
def process_replication_rows( def process_replication_position(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] self, stream_name: str, instance_name: str, token: int
) -> None: ) -> None:
if stream_name == PushersStream.NAME: if stream_name == PushersStream.NAME:
self._pushers_id_gen.advance(instance_name, token) self._pushers_id_gen.advance(instance_name, token)
return super().process_replication_rows(stream_name, instance_name, token, rows) super().process_replication_position(stream_name, instance_name, token)
async def get_pushers_by_app_id_and_pushkey( async def get_pushers_by_app_id_and_pushkey(
self, app_id: str, pushkey: str self, app_id: str, pushkey: str

View File

@ -588,6 +588,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
return super().process_replication_rows(stream_name, instance_name, token, rows) return super().process_replication_rows(stream_name, instance_name, token, rows)
def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == ReceiptsStream.NAME:
self._receipts_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)
def _insert_linearized_receipt_txn( def _insert_linearized_receipt_txn(
self, self,
txn: LoggingTransaction, txn: LoggingTransaction,

View File

@ -300,13 +300,19 @@ class TagsWorkerStore(AccountDataWorkerStore):
rows: Iterable[Any], rows: Iterable[Any],
) -> None: ) -> None:
if stream_name == TagAccountDataStream.NAME: if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
for row in rows: for row in rows:
self.get_tags_for_user.invalidate((row.user_id,)) self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(row.user_id, token) self._account_data_stream_cache.entity_has_changed(row.user_id, token)
super().process_replication_rows(stream_name, instance_name, token, rows) super().process_replication_rows(stream_name, instance_name, token, rows)
def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)
class TagsStore(TagsWorkerStore): class TagsStore(TagsWorkerStore):
pass pass