Fix limit logic for AccountDataStream (#7384)

Make sure that the AccountDataStream presents complete updates, in the right
order.

This is much the same fix as #7337 and #7358, but applied to a different stream.
This commit is contained in:
Richard van der Hoff 2020-05-15 19:03:25 +01:00 committed by GitHub
parent 34a43f0084
commit 6c1f7c722f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 220 additions and 34 deletions

View file

@ -14,14 +14,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import heapq
import logging
from collections import namedtuple
from typing import Any, Awaitable, Callable, List, Optional, Tuple
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
List,
Optional,
Tuple,
TypeVar,
)
import attr
from synapse.replication.http.streams import ReplicationGetStreamUpdates
if TYPE_CHECKING:
import synapse.server
logger = logging.getLogger(__name__)
# the number of rows to request from an update_function.
@ -37,7 +50,7 @@ Token = int
# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
# just a row from a database query, though this is dependent on the stream in question.
#
StreamRow = Tuple
StreamRow = TypeVar("StreamRow", bound=Tuple)
# The type returned by the update_function of a stream, as well as get_updates(),
# get_updates_since, etc.
@ -533,32 +546,63 @@ class AccountDataStream(Stream):
"""
# Row type for the account_data stream. Fields:
#   user_id:   str
#   room_id:   Optional[str] -- None for global (non-room) account data
#   data_type: str
AccountDataStreamRow = namedtuple(
    "AccountDataStream",
    ("user_id", "room_id", "data_type"),  # str  # Optional[str]  # str
)
NAME = "account_data"
ROW_TYPE = AccountDataStreamRow
def __init__(self, hs: "synapse.server.HomeServer"):
    """Set up the account-data stream on top of the homeserver's datastore.

    Args:
        hs: the homeserver; supplies the datastore and the instance name.
    """
    self.store = hs.get_datastore()
    super().__init__(
        hs.get_instance_name(),
        current_token_without_instance(self.store.get_max_account_data_stream_id),
        # _update_function already takes (instance_name, from_token, to_token,
        # limit) and returns (updates, to_token, limited), so it is passed
        # directly rather than wrapped with db_query_to_update_function.
        self._update_function,
    )
async def _update_function(
    self, instance_name: str, from_token: int, to_token: int, limit: int
) -> StreamUpdateResult:
    """Fetch global and per-room account-data updates as one ordered batch.

    Both kinds of update are queried separately with the same `limit`. If
    either query hits the limit, the returned upper token is pulled back to
    the last stream id that query returned, and the other set of rows is
    restricted to that token, so that the batch is complete up to the
    returned token.

    Args:
        instance_name: not used by this implementation.
        from_token: lower stream-id bound, passed through to the store.
        to_token: upper stream-id bound requested by the caller.
        limit: maximum number of rows requested from each store query.

    Returns:
        A tuple `(updates, to_token, limited)`:
          * `updates`: list of `(stream_id, (user_id, room_id, data_type))`
            sorted by `stream_id`; `room_id` is None for global account data.
          * `to_token`: the (possibly reduced) upper bound covered by `updates`.
          * `limited`: True if either query returned at least `limit` rows.
    """
    limited = False
    global_results = await self.store.get_updated_global_account_data(
        from_token, to_token, limit
    )

    # if the global results hit the limit, we'll need to limit the room results to
    # the same stream token.
    if len(global_results) >= limit:
        to_token = global_results[-1][0]
        limited = True

    room_results = await self.store.get_updated_room_account_data(
        from_token, to_token, limit
    )

    # likewise, if the room results hit the limit, limit the global results to
    # the same stream token.
    if len(room_results) >= limit:
        to_token = room_results[-1][0]
        limited = True

    # convert the global results to the right format, and limit them to the to_token
    # at the same time
    global_rows = (
        (stream_id, (user_id, None, account_data_type))
        for stream_id, user_id, account_data_type in global_results
        if stream_id <= to_token
    )

    # we know that the room_results are already limited to `to_token` so no need
    # for a check on `stream_id` here.
    room_rows = (
        (stream_id, (user_id, room_id, account_data_type))
        for stream_id, user_id, room_id, account_data_type in room_results
    )

    # we need to return a sorted list, so merge them together.
    #
    # Order by stream id only: if the same stream id appears in both inputs, a
    # plain tuple comparison would fall through to the row data and compare a
    # `str` room_id against `None`, raising TypeError.
    updates = list(heapq.merge(room_rows, global_rows, key=lambda row: row[0]))
    return updates, to_token, limited
class GroupServerStream(Stream):