Merge account data streams (#14826)

Erik Johnston, 2023-01-13 14:57:43 +00:00 (committed by GitHub)
parent 1416096527
commit 73ff493dfb
12 changed files with 75 additions and 83 deletions

changelog.d/14826.misc (new file)

@@ -0,0 +1 @@
+Merge tag and normal account data replication streams.

docs/upgrade.md

@@ -88,6 +88,18 @@ process, for example:
 dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
 ```
 
+# Upgrading to v1.76.0
+
+## Changes to the account data replication streams
+
+Synapse has changed the format of the account data replication streams (between
+workers). This is a forwards- and backwards-incompatible change: v1.75 workers
+cannot process account data replicated by v1.76 workers, and vice versa.
+
+Once all workers are upgraded to v1.76 (or downgraded to v1.75), account data
+replication will resume as normal.
+
 # Upgrading to v1.74.0
 
 ## Unicode support in user search
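To make the incompatibility concrete, here is a rough sketch of what changes on the wire. The old row class is copied from the `TagAccountDataStream` deleted further down this diff; the new row's field names are assumed from the `row.data_type` check in the tags.py hunk below, not quoted from this commit.

```python
# Illustrative sketch only; the real serialization between workers is elided.
from typing import Optional

import attr


# Before v1.76: tags had their own "tag_account_data" stream (class copied
# from the code removed in this commit).
@attr.s(slots=True, frozen=True, auto_attribs=True)
class TagAccountDataStreamRow:
    user_id: str
    room_id: str
    data: dict


# From v1.76: one "account_data" stream carries both; a tag change is just a
# row whose data_type is "m.tag". (Field names assumed from the
# `row.data_type` usage in the tags.py hunk.)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class AccountDataStreamRow:
    user_id: str
    room_id: Optional[str]  # None for global account data
    data_type: str  # e.g. "m.tag" or "m.ignored_user_list"
```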

synapse/api/constants.py

@@ -249,6 +249,7 @@ class RoomEncryptionAlgorithms:
 class AccountDataTypes:
     DIRECT: Final = "m.direct"
     IGNORED_USER_LIST: Final = "m.ignored_user_list"
+    TAG: Final = "m.tag"
 
 
 class HistoryVisibility:

synapse/handlers/account_data.py

@@ -16,6 +16,7 @@ import logging
 import random
 from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple
 
+from synapse.api.constants import AccountDataTypes
 from synapse.replication.http.account_data import (
     ReplicationAddRoomAccountDataRestServlet,
     ReplicationAddTagRestServlet,
@@ -335,7 +336,11 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
         for room_id, room_tags in tags.items():
             results.append(
-                {"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id}
+                {
+                    "type": AccountDataTypes.TAG,
+                    "content": {"tags": room_tags},
+                    "room_id": room_id,
+                }
             )
         (

synapse/handlers/initial_sync.py

@@ -15,7 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, List, Optional, Tuple, cast
 
-from synapse.api.constants import EduTypes, EventTypes, Membership
+from synapse.api.constants import AccountDataTypes, EduTypes, EventTypes, Membership
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase
 from synapse.events.utils import SerializeEventConfig
@@ -239,7 +239,7 @@ class InitialSyncHandler:
             tags = tags_by_room.get(event.room_id)
             if tags:
                 account_data_events.append(
-                    {"type": "m.tag", "content": {"tags": tags}}
+                    {"type": AccountDataTypes.TAG, "content": {"tags": tags}}
                 )
 
             account_data = account_data_by_room.get(event.room_id, {})
@@ -326,7 +326,9 @@ class InitialSyncHandler:
         account_data_events = []
         tags = await self.store.get_tags_for_room(user_id, room_id)
         if tags:
-            account_data_events.append({"type": "m.tag", "content": {"tags": tags}})
+            account_data_events.append(
+                {"type": AccountDataTypes.TAG, "content": {"tags": tags}}
+            )
 
         account_data = await self.store.get_account_data_for_room(user_id, room_id)
         for account_data_type, content in account_data.items():

synapse/handlers/sync.py

@@ -31,7 +31,12 @@ from typing import (
 
 import attr
 from prometheus_client import Counter
 
-from synapse.api.constants import EventContentFields, EventTypes, Membership
+from synapse.api.constants import (
+    AccountDataTypes,
+    EventContentFields,
+    EventTypes,
+    Membership,
+)
 from synapse.api.filtering import FilterCollection
 from synapse.api.presence import UserPresenceState
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@@ -2331,7 +2336,9 @@ class SyncHandler:
         account_data_events = []
 
         if tags is not None:
-            account_data_events.append({"type": "m.tag", "content": {"tags": tags}})
+            account_data_events.append(
+                {"type": AccountDataTypes.TAG, "content": {"tags": tags}}
+            )
 
         for account_data_type, content in account_data.items():
             account_data_events.append(

synapse/replication/tcp/client.py

@@ -33,7 +33,6 @@ from synapse.replication.tcp.streams import (
     PushersStream,
     PushRulesStream,
     ReceiptsStream,
-    TagAccountDataStream,
     ToDeviceStream,
     TypingStream,
     UnPartialStatedEventStream,
@@ -168,7 +167,7 @@ class ReplicationDataHandler:
             self.notifier.on_new_event(
                 StreamKeyType.PUSH_RULES, token, users=[row.user_id for row in rows]
             )
-        elif stream_name in (AccountDataStream.NAME, TagAccountDataStream.NAME):
+        elif stream_name == AccountDataStream.NAME:
             self.notifier.on_new_event(
                 StreamKeyType.ACCOUNT_DATA, token, users=[row.user_id for row in rows]
             )
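Aside: with a single stream left, the membership test collapses to string equality, and it must be `==` rather than `in`, since `in` on a string is a substring test. A minimal self-contained sketch of the dispatch pattern (the notifier class is a stand-in, not Synapse's):

```python
from typing import Any, List


class FakeNotifier:
    # Stand-in for Synapse's Notifier, just to make the sketch runnable.
    def on_new_event(self, stream_key: str, token: int, users: List[str]) -> None:
        print(f"waking {users} for {stream_key} at token {token}")


ACCOUNT_DATA_NAME = "account_data"  # mirrors AccountDataStream.NAME


def on_rdata(stream_name: str, token: int, rows: List[Any]) -> None:
    # `==`, not `in`: `"data" in "account_data"` is True because `in` on
    # strings checks for a substring, which would mis-route other streams.
    if stream_name == ACCOUNT_DATA_NAME:
        FakeNotifier().on_new_event(
            "account_data_key", token, users=[row.user_id for row in rows]
        )
```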

synapse/replication/tcp/handler.py

@@ -58,7 +58,6 @@ from synapse.replication.tcp.streams import (
     PresenceStream,
     ReceiptsStream,
     Stream,
-    TagAccountDataStream,
     ToDeviceStream,
     TypingStream,
 )
@@ -145,7 +144,7 @@ class ReplicationCommandHandler:
                 continue
 
-            if isinstance(stream, (AccountDataStream, TagAccountDataStream)):
-                # Only add AccountDataStream and TagAccountDataStream as a source on the
+            if isinstance(stream, AccountDataStream):
+                # Only add AccountDataStream as a source on the
                 # instance in charge of account_data persistence.
                 if hs.get_instance_name() in hs.config.worker.writers.account_data:

synapse/replication/tcp/streams/__init__.py

@@ -35,7 +35,6 @@ from synapse.replication.tcp.streams._base import (
     PushRulesStream,
     ReceiptsStream,
     Stream,
-    TagAccountDataStream,
     ToDeviceStream,
     TypingStream,
     UserSignatureStream,
@@ -62,7 +61,6 @@ STREAMS_MAP = {
     DeviceListsStream,
     ToDeviceStream,
     FederationStream,
-    TagAccountDataStream,
     AccountDataStream,
     UserSignatureStream,
     UnPartialStatedRoomStream,
@@ -83,7 +81,6 @@ __all__ = [
     "CachesStream",
     "DeviceListsStream",
     "ToDeviceStream",
-    "TagAccountDataStream",
     "AccountDataStream",
     "UserSignatureStream",
     "UnPartialStatedRoomStream",

synapse/replication/tcp/streams/_base.py

@@ -28,8 +28,8 @@ from typing import (
 
 import attr
 
+from synapse.api.constants import AccountDataTypes
 from synapse.replication.http.streams import ReplicationGetStreamUpdates
-from synapse.types import JsonDict
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -495,27 +495,6 @@ class ToDeviceStream(Stream):
         )
 
 
-class TagAccountDataStream(Stream):
-    """Someone added/removed a tag for a room"""
-
-    @attr.s(slots=True, frozen=True, auto_attribs=True)
-    class TagAccountDataStreamRow:
-        user_id: str
-        room_id: str
-        data: JsonDict
-
-    NAME = "tag_account_data"
-    ROW_TYPE = TagAccountDataStreamRow
-
-    def __init__(self, hs: "HomeServer"):
-        store = hs.get_datastores().main
-        super().__init__(
-            hs.get_instance_name(),
-            current_token_without_instance(store.get_max_account_data_stream_id),
-            store.get_all_updated_tags,
-        )
-
-
 class AccountDataStream(Stream):
     """Global or per room account data was changed"""
@@ -560,6 +539,19 @@ class AccountDataStream(Stream):
             to_token = room_results[-1][0]
             limited = True
 
+        tags, tag_to_token, tags_limited = await self.store.get_all_updated_tags(
+            instance_name,
+            from_token,
+            to_token,
+            limit,
+        )
+
+        # again, if the tag results hit the limit, limit the global results to
+        # the same stream token.
+        if tags_limited:
+            to_token = tag_to_token
+            limited = True
+
         # convert the global results to the right format, and limit them to the to_token
         # at the same time
         global_rows = (
@@ -568,11 +560,16 @@ class AccountDataStream(Stream):
             if stream_id <= to_token
         )
 
-        # we know that the room_results are already limited to `to_token` so no need
-        # for a check on `stream_id` here.
         room_rows = (
             (stream_id, (user_id, room_id, account_data_type))
             for stream_id, user_id, room_id, account_data_type in room_results
+            if stream_id <= to_token
+        )
+        tag_rows = (
+            (stream_id, (user_id, room_id, AccountDataTypes.TAG))
+            for stream_id, user_id, room_id in tags
+            if stream_id <= to_token
         )
 
         # We need to return a sorted list, so merge them together.
@@ -582,7 +579,9 @@ class AccountDataStream(Stream):
         # leading to a comparison between the data tuples. The comparison could
         # fail due to attempting to compare the `room_id` which results in a
         # `TypeError` from comparing a `str` vs `None`.
-        updates = list(heapq.merge(room_rows, global_rows, key=lambda row: row[0]))
+        updates = list(
+            heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0])
+        )
 
         return updates, to_token, limited
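The comment in the final hunk deserves a concrete illustration: `heapq.merge` lazily merges several already-sorted iterables, and the `key=` keeps it from ever comparing the payload tuples, whose `room_id` is `None` for global rows. A self-contained sketch with invented stream IDs:

```python
import heapq
from typing import List, Optional, Tuple

# Made-up rows mirroring global_rows / room_rows / tag_rows: each iterable is
# already sorted by stream ID, and room_id is None for global account data.
Row = Tuple[int, Tuple[str, Optional[str], str]]

global_rows: List[Row] = [(1, ("@a:hs", None, "m.ignored_user_list"))]
room_rows: List[Row] = [(2, ("@a:hs", "!r:hs", "m.fully_read"))]
tag_rows: List[Row] = [(3, ("@a:hs", "!r:hs", "m.tag"))]

# heapq.merge consumes the sorted inputs lazily. key=lambda row: row[0]
# compares stream IDs only; without it, equal stream IDs would fall through
# to comparing the data tuples and raise TypeError on `str` vs `None`.
updates = list(heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0]))
assert [stream_id for stream_id, _ in updates] == [1, 2, 3]
```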

synapse/storage/databases/main/account_data.py

@@ -27,7 +27,7 @@ from typing import (
 )
 
 from synapse.api.constants import AccountDataTypes
-from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
+from synapse.replication.tcp.streams import AccountDataStream
 from synapse.storage._base import db_to_json
 from synapse.storage.database import (
     DatabasePool,
@@ -454,9 +454,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore):
     def process_replication_position(
         self, stream_name: str, instance_name: str, token: int
     ) -> None:
-        if stream_name == TagAccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
-        elif stream_name == AccountDataStream.NAME:
+        if stream_name == AccountDataStream.NAME:
             self._account_data_id_gen.advance(instance_name, token)
         super().process_replication_position(stream_name, instance_name, token)

synapse/storage/databases/main/tags.py

@@ -17,7 +17,8 @@
 import logging
 from typing import Any, Dict, Iterable, List, Tuple, cast
 
-from synapse.replication.tcp.streams import TagAccountDataStream
+from synapse.api.constants import AccountDataTypes
+from synapse.replication.tcp.streams import AccountDataStream
 from synapse.storage._base import db_to_json
 from synapse.storage.database import LoggingTransaction
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
@@ -54,7 +55,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
     async def get_all_updated_tags(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> Tuple[List[Tuple[int, Tuple[str, str, str]]], int, bool]:
+    ) -> Tuple[List[Tuple[int, str, str]], int, bool]:
         """Get updates for tags replication stream.
 
         Args:
@@ -73,7 +74,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
             The token returned can be used in a subsequent call to this
             function to get further updates.
 
-            The updates are a list of 2-tuples of stream ID and the row data
+            The updates are a list of tuples of stream ID, user ID and room ID
         """
 
         if last_id == current_id:
@@ -96,38 +97,13 @@ class TagsWorkerStore(AccountDataWorkerStore):
             "get_all_updated_tags", get_all_updated_tags_txn
         )
 
-        def get_tag_content(
-            txn: LoggingTransaction, tag_ids: List[Tuple[int, str, str]]
-        ) -> List[Tuple[int, Tuple[str, str, str]]]:
-            sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
-            results = []
-            for stream_id, user_id, room_id in tag_ids:
-                txn.execute(sql, (user_id, room_id))
-                tags = []
-                for tag, content in txn:
-                    tags.append(json_encoder.encode(tag) + ":" + content)
-                tag_json = "{" + ",".join(tags) + "}"
-                results.append((stream_id, (user_id, room_id, tag_json)))
-            return results
-
-        batch_size = 50
-        results = []
-        for i in range(0, len(tag_ids), batch_size):
-            tags = await self.db_pool.runInteraction(
-                "get_all_updated_tag_content",
-                get_tag_content,
-                tag_ids[i : i + batch_size],
-            )
-            results.extend(tags)
-
         limited = False
         upto_token = current_id
-        if len(results) >= limit:
-            upto_token = results[-1][0]
+        if len(tag_ids) >= limit:
+            upto_token = tag_ids[-1][0]
             limited = True
 
-        return results, upto_token, limited
+        return tag_ids, upto_token, limited
 
     async def get_updated_tags(
         self, user_id: str, stream_id: int
@@ -299,20 +275,16 @@ class TagsWorkerStore(AccountDataWorkerStore):
         token: int,
         rows: Iterable[Any],
     ) -> None:
-        if stream_name == TagAccountDataStream.NAME:
+        if stream_name == AccountDataStream.NAME:
             for row in rows:
-                self.get_tags_for_user.invalidate((row.user_id,))
-                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+                if row.data_type == AccountDataTypes.TAG:
+                    self.get_tags_for_user.invalidate((row.user_id,))
+                    self._account_data_stream_cache.entity_has_changed(
+                        row.user_id, token
+                    )
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
-    def process_replication_position(
-        self, stream_name: str, instance_name: str, token: int
-    ) -> None:
-        if stream_name == TagAccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
-        super().process_replication_position(stream_name, instance_name, token)
-
 
 class TagsStore(TagsWorkerStore):
     pass
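The limiting logic that survives in get_all_updated_tags is a standard token-pagination pattern: when a page comes back full, advance the token only as far as the last row actually returned, so the caller resumes from there on the next call. A minimal sketch of that pattern in isolation (the helper name is invented; a plain list stands in for the database query):

```python
from typing import List, Tuple


def paginate_updates(
    rows: List[Tuple[int, str, str]], current_id: int, limit: int
) -> Tuple[List[Tuple[int, str, str]], int, bool]:
    # rows are (stream_id, user_id, room_id), sorted ascending by stream_id
    # and already capped at `limit` by the query's LIMIT clause.
    limited = False
    upto_token = current_id
    if len(rows) >= limit:
        # A full page may mean more rows exist past the last one we fetched,
        # so only advance the token as far as we have actually seen.
        upto_token = rows[-1][0]
        limited = True
    return rows, upto_token, limited


rows, token, limited = paginate_updates(
    [(5, "@a:hs", "!r:hs"), (7, "@b:hs", "!s:hs")], current_id=10, limit=2
)
assert (token, limited) == (7, True)  # caller re-requests from stream ID 7
```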