mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-05-02 11:16:07 -04:00
Fix limit logic for AccountDataStream (#7384)
Make sure that the AccountDataStream presents complete updates, in the right order. This is much the same fix as #7337 and #7358, but applied to a different stream.
This commit is contained in:
parent
34a43f0084
commit
6c1f7c722f
4 changed files with 220 additions and 34 deletions
|
@@ -14,14 +14,27 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import heapq
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import Any, Awaitable, Callable, List, Optional, Tuple
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.replication.http.streams import ReplicationGetStreamUpdates
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import synapse.server
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# the number of rows to request from an update_function.
|
||||
|
@@ -37,7 +50,7 @@ Token = int
|
|||
# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
|
||||
# just a row from a database query, though this is dependent on the stream in question.
|
||||
#
|
||||
StreamRow = Tuple
|
||||
StreamRow = TypeVar("StreamRow", bound=Tuple)
|
||||
|
||||
# The type returned by the update_function of a stream, as well as get_updates(),
|
||||
# get_updates_since, etc.
|
||||
|
@@ -533,32 +546,63 @@ class AccountDataStream(Stream):
|
|||
"""
|
||||
|
||||
AccountDataStreamRow = namedtuple(
|
||||
"AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str
|
||||
"AccountDataStream",
|
||||
("user_id", "room_id", "data_type"), # str # Optional[str] # str
|
||||
)
|
||||
|
||||
NAME = "account_data"
|
||||
ROW_TYPE = AccountDataStreamRow
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "synapse.server.HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
current_token_without_instance(self.store.get_max_account_data_stream_id),
|
||||
db_query_to_update_function(self._update_function),
|
||||
self._update_function,
|
||||
)
|
||||
|
||||
async def _update_function(self, from_token, to_token, limit):
|
||||
global_results, room_results = await self.store.get_all_updated_account_data(
|
||||
from_token, from_token, to_token, limit
|
||||
async def _update_function(
|
||||
self, instance_name: str, from_token: int, to_token: int, limit: int
|
||||
) -> StreamUpdateResult:
|
||||
limited = False
|
||||
global_results = await self.store.get_updated_global_account_data(
|
||||
from_token, to_token, limit
|
||||
)
|
||||
|
||||
results = list(room_results)
|
||||
results.extend(
|
||||
(stream_id, user_id, None, account_data_type)
|
||||
# if the global results hit the limit, we'll need to limit the room results to
|
||||
# the same stream token.
|
||||
if len(global_results) >= limit:
|
||||
to_token = global_results[-1][0]
|
||||
limited = True
|
||||
|
||||
room_results = await self.store.get_updated_room_account_data(
|
||||
from_token, to_token, limit
|
||||
)
|
||||
|
||||
# likewise, if the room results hit the limit, limit the global results to
|
||||
# the same stream token.
|
||||
if len(room_results) >= limit:
|
||||
to_token = room_results[-1][0]
|
||||
limited = True
|
||||
|
||||
# convert the global results to the right format, and limit them to the to_token
|
||||
# at the same time
|
||||
global_rows = (
|
||||
(stream_id, (user_id, None, account_data_type))
|
||||
for stream_id, user_id, account_data_type in global_results
|
||||
if stream_id <= to_token
|
||||
)
|
||||
|
||||
return results
|
||||
# we know that the room_results are already limited to `to_token` so no need
|
||||
# for a check on `stream_id` here.
|
||||
room_rows = (
|
||||
(stream_id, (user_id, room_id, account_data_type))
|
||||
for stream_id, user_id, room_id, account_data_type in room_results
|
||||
)
|
||||
|
||||
# we need to return a sorted list, so merge them together.
|
||||
updates = list(heapq.merge(room_rows, global_rows))
|
||||
return updates, to_token, limited
|
||||
|
||||
|
||||
class GroupServerStream(Stream):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue