Merge account data streams (#14826)

This commit is contained in:
Erik Johnston 2023-01-13 14:57:43 +00:00 committed by GitHub
parent 1416096527
commit 73ff493dfb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 75 additions and 83 deletions

View file

@ -27,7 +27,7 @@ from typing import (
)
from synapse.api.constants import AccountDataTypes
from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
from synapse.replication.tcp.streams import AccountDataStream
from synapse.storage._base import db_to_json
from synapse.storage.database import (
DatabasePool,
@ -454,9 +454,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
elif stream_name == AccountDataStream.NAME:
if stream_name == AccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)

View file

@ -17,7 +17,8 @@
import logging
from typing import Any, Dict, Iterable, List, Tuple, cast
from synapse.replication.tcp.streams import TagAccountDataStream
from synapse.api.constants import AccountDataTypes
from synapse.replication.tcp.streams import AccountDataStream
from synapse.storage._base import db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
@ -54,7 +55,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
async def get_all_updated_tags(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, Tuple[str, str, str]]], int, bool]:
) -> Tuple[List[Tuple[int, str, str]], int, bool]:
"""Get updates for tags replication stream.
Args:
@ -73,7 +74,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
The token returned can be used in a subsequent call to this
function to get further updatees.
The updates are a list of 2-tuples of stream ID and the row data
The updates are a list of tuples of stream ID, user ID and room ID
"""
if last_id == current_id:
@ -96,38 +97,13 @@ class TagsWorkerStore(AccountDataWorkerStore):
"get_all_updated_tags", get_all_updated_tags_txn
)
def get_tag_content(
txn: LoggingTransaction, tag_ids: List[Tuple[int, str, str]]
) -> List[Tuple[int, Tuple[str, str, str]]]:
sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
results = []
for stream_id, user_id, room_id in tag_ids:
txn.execute(sql, (user_id, room_id))
tags = []
for tag, content in txn:
tags.append(json_encoder.encode(tag) + ":" + content)
tag_json = "{" + ",".join(tags) + "}"
results.append((stream_id, (user_id, room_id, tag_json)))
return results
batch_size = 50
results = []
for i in range(0, len(tag_ids), batch_size):
tags = await self.db_pool.runInteraction(
"get_all_updated_tag_content",
get_tag_content,
tag_ids[i : i + batch_size],
)
results.extend(tags)
limited = False
upto_token = current_id
if len(results) >= limit:
upto_token = results[-1][0]
if len(tag_ids) >= limit:
upto_token = tag_ids[-1][0]
limited = True
return results, upto_token, limited
return tag_ids, upto_token, limited
async def get_updated_tags(
self, user_id: str, stream_id: int
@ -299,20 +275,16 @@ class TagsWorkerStore(AccountDataWorkerStore):
token: int,
rows: Iterable[Any],
) -> None:
if stream_name == TagAccountDataStream.NAME:
if stream_name == AccountDataStream.NAME:
for row in rows:
self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
if row.data_type == AccountDataTypes.TAG:
self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(
row.user_id, token
)
super().process_replication_rows(stream_name, instance_name, token, rows)
def process_replication_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)
class TagsStore(TagsWorkerStore):
pass