mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Fix limit logic for AccountDataStream (#7384)
Make sure that the AccountDataStream presents complete updates, in the right order. This is much the same fix as #7337 and #7358, but applied to a different stream.
This commit is contained in:
parent
34a43f0084
commit
6c1f7c722f
1
changelog.d/7384.bugfix
Normal file
1
changelog.d/7384.bugfix
Normal file
@ -0,0 +1 @@
|
||||
Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind.
|
@ -14,14 +14,27 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import heapq
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import Any, Awaitable, Callable, List, Optional, Tuple
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.replication.http.streams import ReplicationGetStreamUpdates
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import synapse.server
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# the number of rows to request from an update_function.
|
||||
@ -37,7 +50,7 @@ Token = int
|
||||
# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
|
||||
# just a row from a database query, though this is dependent on the stream in question.
|
||||
#
|
||||
StreamRow = Tuple
|
||||
StreamRow = TypeVar("StreamRow", bound=Tuple)
|
||||
|
||||
# The type returned by the update_function of a stream, as well as get_updates(),
|
||||
# get_updates_since, etc.
|
||||
@ -533,32 +546,63 @@ class AccountDataStream(Stream):
|
||||
"""
|
||||
|
||||
AccountDataStreamRow = namedtuple(
|
||||
"AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str
|
||||
"AccountDataStream",
|
||||
("user_id", "room_id", "data_type"), # str # Optional[str] # str
|
||||
)
|
||||
|
||||
NAME = "account_data"
|
||||
ROW_TYPE = AccountDataStreamRow
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "synapse.server.HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
current_token_without_instance(self.store.get_max_account_data_stream_id),
|
||||
db_query_to_update_function(self._update_function),
|
||||
self._update_function,
|
||||
)
|
||||
|
||||
async def _update_function(self, from_token, to_token, limit):
|
||||
global_results, room_results = await self.store.get_all_updated_account_data(
|
||||
from_token, from_token, to_token, limit
|
||||
async def _update_function(
|
||||
self, instance_name: str, from_token: int, to_token: int, limit: int
|
||||
) -> StreamUpdateResult:
|
||||
limited = False
|
||||
global_results = await self.store.get_updated_global_account_data(
|
||||
from_token, to_token, limit
|
||||
)
|
||||
|
||||
results = list(room_results)
|
||||
results.extend(
|
||||
(stream_id, user_id, None, account_data_type)
|
||||
# if the global results hit the limit, we'll need to limit the room results to
|
||||
# the same stream token.
|
||||
if len(global_results) >= limit:
|
||||
to_token = global_results[-1][0]
|
||||
limited = True
|
||||
|
||||
room_results = await self.store.get_updated_room_account_data(
|
||||
from_token, to_token, limit
|
||||
)
|
||||
|
||||
# likewise, if the room results hit the limit, limit the global results to
|
||||
# the same stream token.
|
||||
if len(room_results) >= limit:
|
||||
to_token = room_results[-1][0]
|
||||
limited = True
|
||||
|
||||
# convert the global results to the right format, and limit them to the to_token
|
||||
# at the same time
|
||||
global_rows = (
|
||||
(stream_id, (user_id, None, account_data_type))
|
||||
for stream_id, user_id, account_data_type in global_results
|
||||
if stream_id <= to_token
|
||||
)
|
||||
|
||||
return results
|
||||
# we know that the room_results are already limited to `to_token` so no need
|
||||
# for a check on `stream_id` here.
|
||||
room_rows = (
|
||||
(stream_id, (user_id, room_id, account_data_type))
|
||||
for stream_id, user_id, room_id, account_data_type in room_results
|
||||
)
|
||||
|
||||
# we need to return a sorted list, so merge them together.
|
||||
updates = list(heapq.merge(room_rows, global_rows))
|
||||
return updates, to_token, limited
|
||||
|
||||
|
||||
class GroupServerStream(Stream):
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
import abc
|
||||
import logging
|
||||
from typing import List, Tuple
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
@ -175,41 +176,64 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
|
||||
)
|
||||
|
||||
def get_all_updated_account_data(
|
||||
self, last_global_id, last_room_id, current_id, limit
|
||||
):
|
||||
"""Get all the client account_data that has changed on the server
|
||||
Args:
|
||||
last_global_id(int): The position to fetch from for top level data
|
||||
last_room_id(int): The position to fetch from for per room data
|
||||
current_id(int): The position to fetch up to.
|
||||
Returns:
|
||||
A deferred pair of lists of tuples of stream_id int, user_id string,
|
||||
room_id string, and type string.
|
||||
"""
|
||||
if last_room_id == current_id and last_global_id == current_id:
|
||||
return defer.succeed(([], []))
|
||||
async def get_updated_global_account_data(
|
||||
self, last_id: int, current_id: int, limit: int
|
||||
) -> List[Tuple[int, str, str]]:
|
||||
"""Get the global account_data that has changed, for the account_data stream
|
||||
|
||||
def get_updated_account_data_txn(txn):
|
||||
Args:
|
||||
last_id: the last stream_id from the previous batch.
|
||||
current_id: the maximum stream_id to return up to
|
||||
limit: the maximum number of rows to return
|
||||
|
||||
Returns:
|
||||
A list of tuples of stream_id int, user_id string,
|
||||
and type string.
|
||||
"""
|
||||
if last_id == current_id:
|
||||
return []
|
||||
|
||||
def get_updated_global_account_data_txn(txn):
|
||||
sql = (
|
||||
"SELECT stream_id, user_id, account_data_type"
|
||||
" FROM account_data WHERE ? < stream_id AND stream_id <= ?"
|
||||
" ORDER BY stream_id ASC LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (last_global_id, current_id, limit))
|
||||
global_results = txn.fetchall()
|
||||
txn.execute(sql, (last_id, current_id, limit))
|
||||
return txn.fetchall()
|
||||
|
||||
return await self.db.runInteraction(
|
||||
"get_updated_global_account_data", get_updated_global_account_data_txn
|
||||
)
|
||||
|
||||
async def get_updated_room_account_data(
|
||||
self, last_id: int, current_id: int, limit: int
|
||||
) -> List[Tuple[int, str, str, str]]:
|
||||
"""Get the global account_data that has changed, for the account_data stream
|
||||
|
||||
Args:
|
||||
last_id: the last stream_id from the previous batch.
|
||||
current_id: the maximum stream_id to return up to
|
||||
limit: the maximum number of rows to return
|
||||
|
||||
Returns:
|
||||
A list of tuples of stream_id int, user_id string,
|
||||
room_id string and type string.
|
||||
"""
|
||||
if last_id == current_id:
|
||||
return []
|
||||
|
||||
def get_updated_room_account_data_txn(txn):
|
||||
sql = (
|
||||
"SELECT stream_id, user_id, room_id, account_data_type"
|
||||
" FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
|
||||
" ORDER BY stream_id ASC LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (last_room_id, current_id, limit))
|
||||
room_results = txn.fetchall()
|
||||
return global_results, room_results
|
||||
txn.execute(sql, (last_id, current_id, limit))
|
||||
return txn.fetchall()
|
||||
|
||||
return self.db.runInteraction(
|
||||
"get_all_updated_account_data_txn", get_updated_account_data_txn
|
||||
return await self.db.runInteraction(
|
||||
"get_updated_room_account_data", get_updated_room_account_data_txn
|
||||
)
|
||||
|
||||
def get_updated_account_data_for_user(self, user_id, stream_id):
|
||||
|
117
tests/replication/tcp/streams/test_account_data.py
Normal file
117
tests/replication/tcp/streams/test_account_data.py
Normal file
@ -0,0 +1,117 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.replication.tcp.streams._base import (
|
||||
_STREAM_UPDATE_TARGET_ROW_COUNT,
|
||||
AccountDataStream,
|
||||
)
|
||||
|
||||
from tests.replication._base import BaseStreamTestCase
|
||||
|
||||
|
||||
class AccountDataStreamTestCase(BaseStreamTestCase):
|
||||
def test_update_function_room_account_data_limit(self):
|
||||
"""Test replication with many room account data updates
|
||||
"""
|
||||
store = self.hs.get_datastore()
|
||||
|
||||
# generate lots of account data updates
|
||||
updates = []
|
||||
for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
|
||||
update = "m.test_type.%i" % (i,)
|
||||
self.get_success(
|
||||
store.add_account_data_to_room("test_user", "test_room", update, {})
|
||||
)
|
||||
updates.append(update)
|
||||
|
||||
# also one global update
|
||||
self.get_success(store.add_account_data_for_user("test_user", "m.global", {}))
|
||||
|
||||
# tell the notifier to catch up to avoid duplicate rows.
|
||||
# workaround for https://github.com/matrix-org/synapse/issues/7360
|
||||
# FIXME remove this when the above is fixed
|
||||
self.replicate()
|
||||
|
||||
# check we're testing what we think we are: no rows should yet have been
|
||||
# received
|
||||
self.assertEqual([], self.test_handler.received_rdata_rows)
|
||||
|
||||
# now reconnect to pull the updates
|
||||
self.reconnect()
|
||||
self.replicate()
|
||||
|
||||
# we should have received all the expected rows in the right order
|
||||
received_rows = self.test_handler.received_rdata_rows
|
||||
|
||||
for t in updates:
|
||||
(stream_name, token, row) = received_rows.pop(0)
|
||||
self.assertEqual(stream_name, AccountDataStream.NAME)
|
||||
self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
|
||||
self.assertEqual(row.data_type, t)
|
||||
self.assertEqual(row.room_id, "test_room")
|
||||
|
||||
(stream_name, token, row) = received_rows.pop(0)
|
||||
self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
|
||||
self.assertEqual(row.data_type, "m.global")
|
||||
self.assertIsNone(row.room_id)
|
||||
|
||||
self.assertEqual([], received_rows)
|
||||
|
||||
def test_update_function_global_account_data_limit(self):
|
||||
"""Test replication with many global account data updates
|
||||
"""
|
||||
store = self.hs.get_datastore()
|
||||
|
||||
# generate lots of account data updates
|
||||
updates = []
|
||||
for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
|
||||
update = "m.test_type.%i" % (i,)
|
||||
self.get_success(store.add_account_data_for_user("test_user", update, {}))
|
||||
updates.append(update)
|
||||
|
||||
# also one per-room update
|
||||
self.get_success(
|
||||
store.add_account_data_to_room("test_user", "test_room", "m.per_room", {})
|
||||
)
|
||||
|
||||
# tell the notifier to catch up to avoid duplicate rows.
|
||||
# workaround for https://github.com/matrix-org/synapse/issues/7360
|
||||
# FIXME remove this when the above is fixed
|
||||
self.replicate()
|
||||
|
||||
# check we're testing what we think we are: no rows should yet have been
|
||||
# received
|
||||
self.assertEqual([], self.test_handler.received_rdata_rows)
|
||||
|
||||
# now reconnect to pull the updates
|
||||
self.reconnect()
|
||||
self.replicate()
|
||||
|
||||
# we should have received all the expected rows in the right order
|
||||
received_rows = self.test_handler.received_rdata_rows
|
||||
|
||||
for t in updates:
|
||||
(stream_name, token, row) = received_rows.pop(0)
|
||||
self.assertEqual(stream_name, AccountDataStream.NAME)
|
||||
self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
|
||||
self.assertEqual(row.data_type, t)
|
||||
self.assertIsNone(row.room_id)
|
||||
|
||||
(stream_name, token, row) = received_rows.pop(0)
|
||||
self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
|
||||
self.assertEqual(row.data_type, "m.per_room")
|
||||
self.assertEqual(row.room_id, "test_room")
|
||||
|
||||
self.assertEqual([], received_rows)
|
Loading…
Reference in New Issue
Block a user