Faster room joins: fix race in recalculation of current room state (#13151)

Bounce recalculation of current state to the correct event persister and
move recalculation of current state into the event persistence queue, to
avoid concurrent updates to a room's current state.

Also give recalculation of a room's current state a real stream
ordering.

Signed-off-by: Sean Quah <seanq@matrix.org>
This commit is contained in:
Sean Quah 2022-07-07 13:19:31 +01:00 committed by GitHub
parent 2b5ab8e367
commit 1391a76cd2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 214 additions and 55 deletions

1
changelog.d/13151.misc Normal file
View File

@ -0,0 +1 @@
Faster room joins: fix race in recalculation of current room state.

View File

@ -1559,14 +1559,9 @@ class FederationHandler:
# all the events are updated, so we can update current state and # all the events are updated, so we can update current state and
# clear the lazy-loading flag. # clear the lazy-loading flag.
logger.info("Updating current state for %s", room_id) logger.info("Updating current state for %s", room_id)
# TODO(faster_joins): support workers # TODO(faster_joins): notify workers in notify_room_un_partial_stated
# https://github.com/matrix-org/synapse/issues/12994 # https://github.com/matrix-org/synapse/issues/12994
assert ( await self.state_handler.update_current_state(room_id)
self._storage_controllers.persistence is not None
), "worker-mode deployments not currently supported here"
await self._storage_controllers.persistence.update_current_state(
room_id
)
logger.info("Clearing partial-state flag for %s", room_id) logger.info("Clearing partial-state flag for %s", room_id)
success = await self.store.clear_partial_state_room(room_id) success = await self.store.clear_partial_state_room(room_id)

View File

@ -25,6 +25,7 @@ from synapse.replication.http import (
push, push,
register, register,
send_event, send_event,
state,
streams, streams,
) )
@ -48,6 +49,7 @@ class ReplicationRestResource(JsonResource):
streams.register_servlets(hs, self) streams.register_servlets(hs, self)
account_data.register_servlets(hs, self) account_data.register_servlets(hs, self)
push.register_servlets(hs, self) push.register_servlets(hs, self)
state.register_servlets(hs, self)
# The following can't currently be instantiated on workers. # The following can't currently be instantiated on workers.
if hs.config.worker.worker_app is None: if hs.config.worker.worker_app is None:

View File

@ -0,0 +1,75 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.api.errors import SynapseError
from synapse.http.server import HttpServer
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class ReplicationUpdateCurrentStateRestServlet(ReplicationEndpoint):
"""Recalculates the current state for a room, and persists it.
The API looks like:
POST /_synapse/replication/update_current_state/:room_id
{}
200 OK
{}
"""
NAME = "update_current_state"
PATH_ARGS = ("room_id",)
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._state_handler = hs.get_state_handler()
self._events_shard_config = hs.config.worker.events_shard_config
self._instance_name = hs.get_instance_name()
@staticmethod
async def _serialize_payload(room_id: str) -> JsonDict: # type: ignore[override]
return {}
async def _handle_request( # type: ignore[override]
self, request: Request, room_id: str
) -> Tuple[int, JsonDict]:
writer_instance = self._events_shard_config.get_instance(room_id)
if writer_instance != self._instance_name:
raise SynapseError(
400, "/update_current_state request was routed to the wrong worker"
)
await self._state_handler.update_current_state(room_id)
return 200, {}
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
if hs.get_instance_name() in hs.config.worker.writers.events:
ReplicationUpdateCurrentStateRestServlet(hs).register(http_server)

View File

@ -43,6 +43,7 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersio
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.logging.context import ContextResourceUsage from synapse.logging.context import ContextResourceUsage
from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
from synapse.state import v1, v2 from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo from synapse.storage.roommember import ProfileInfo
@ -129,6 +130,12 @@ class StateHandler:
self.hs = hs self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler() self._state_resolution_handler = hs.get_state_resolution_handler()
self._storage_controllers = hs.get_storage_controllers() self._storage_controllers = hs.get_storage_controllers()
self._events_shard_config = hs.config.worker.events_shard_config
self._instance_name = hs.get_instance_name()
self._update_current_state_client = (
ReplicationUpdateCurrentStateRestServlet.make_client(hs)
)
async def get_current_state_ids( async def get_current_state_ids(
self, self,
@ -423,6 +430,24 @@ class StateHandler:
return {key: state_map[ev_id] for key, ev_id in new_state.items()} return {key: state_map[ev_id] for key, ev_id in new_state.items()}
async def update_current_state(self, room_id: str) -> None:
"""Recalculates the current state for a room, and persists it.
Raises:
SynapseError(502): if all attempts to connect to the event persister worker
fail
"""
writer_instance = self._events_shard_config.get_instance(room_id)
if writer_instance != self._instance_name:
await self._update_current_state_client(
instance_name=writer_instance,
room_id=room_id,
)
return
assert self._storage_controllers.persistence is not None
await self._storage_controllers.persistence.update_current_state(room_id)
@attr.s(slots=True, auto_attribs=True) @attr.s(slots=True, auto_attribs=True)
class _StateResMetrics: class _StateResMetrics:

View File

@ -22,6 +22,7 @@ from typing import (
Any, Any,
Awaitable, Awaitable,
Callable, Callable,
ClassVar,
Collection, Collection,
Deque, Deque,
Dict, Dict,
@ -33,6 +34,7 @@ from typing import (
Set, Set,
Tuple, Tuple,
TypeVar, TypeVar,
Union,
) )
import attr import attr
@ -111,9 +113,43 @@ times_pruned_extremities = Counter(
@attr.s(auto_attribs=True, slots=True) @attr.s(auto_attribs=True, slots=True)
class _EventPersistQueueItem: class _PersistEventsTask:
"""A batch of events to persist."""
name: ClassVar[str] = "persist_event_batch" # used for opentracing
events_and_contexts: List[Tuple[EventBase, EventContext]] events_and_contexts: List[Tuple[EventBase, EventContext]]
backfilled: bool backfilled: bool
def try_merge(self, task: "_EventPersistQueueTask") -> bool:
"""Batches events with the same backfilled option together."""
if (
not isinstance(task, _PersistEventsTask)
or self.backfilled != task.backfilled
):
return False
self.events_and_contexts.extend(task.events_and_contexts)
return True
@attr.s(auto_attribs=True, slots=True)
class _UpdateCurrentStateTask:
"""A room whose current state needs recalculating."""
name: ClassVar[str] = "update_current_state" # used for opentracing
def try_merge(self, task: "_EventPersistQueueTask") -> bool:
"""Deduplicates consecutive recalculations of current state."""
return isinstance(task, _UpdateCurrentStateTask)
_EventPersistQueueTask = Union[_PersistEventsTask, _UpdateCurrentStateTask]
@attr.s(auto_attribs=True, slots=True)
class _EventPersistQueueItem:
task: _EventPersistQueueTask
deferred: ObservableDeferred deferred: ObservableDeferred
parent_opentracing_span_contexts: List = attr.ib(factory=list) parent_opentracing_span_contexts: List = attr.ib(factory=list)
@ -127,14 +163,16 @@ _PersistResult = TypeVar("_PersistResult")
class _EventPeristenceQueue(Generic[_PersistResult]): class _EventPeristenceQueue(Generic[_PersistResult]):
"""Queues up events so that they can be persisted in bulk with only one """Queues up tasks so that they can be processed with only one concurrent
concurrent transaction per room. transaction per room.
Tasks can be bulk persistence of events or recalculation of a room's current state.
""" """
def __init__( def __init__(
self, self,
per_item_callback: Callable[ per_item_callback: Callable[
[List[Tuple[EventBase, EventContext]], bool], [str, _EventPersistQueueTask],
Awaitable[_PersistResult], Awaitable[_PersistResult],
], ],
): ):
@ -150,18 +188,17 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
async def add_to_queue( async def add_to_queue(
self, self,
room_id: str, room_id: str,
events_and_contexts: Iterable[Tuple[EventBase, EventContext]], task: _EventPersistQueueTask,
backfilled: bool,
) -> _PersistResult: ) -> _PersistResult:
"""Add events to the queue, with the given persist_event options. """Add a task to the queue.
If we are not already processing events in this room, starts off a background If we are not already processing tasks in this room, starts off a background
process to to so, calling the per_item_callback for each item. process to to so, calling the per_item_callback for each item.
Args: Args:
room_id (str): room_id (str):
events_and_contexts (list[(EventBase, EventContext)]): task (_EventPersistQueueTask): A _PersistEventsTask or
backfilled (bool): _UpdateCurrentStateTask to process.
Returns: Returns:
the result returned by the `_per_item_callback` passed to the result returned by the `_per_item_callback` passed to
@ -169,26 +206,20 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
""" """
queue = self._event_persist_queues.setdefault(room_id, deque()) queue = self._event_persist_queues.setdefault(room_id, deque())
# if the last item in the queue has the same `backfilled` setting, if queue and queue[-1].task.try_merge(task):
# we can just add these new events to that item. # the new task has been merged into the last task in the queue
if queue and queue[-1].backfilled == backfilled:
end_item = queue[-1] end_item = queue[-1]
else: else:
# need to make a new queue item
deferred: ObservableDeferred[_PersistResult] = ObservableDeferred( deferred: ObservableDeferred[_PersistResult] = ObservableDeferred(
defer.Deferred(), consumeErrors=True defer.Deferred(), consumeErrors=True
) )
end_item = _EventPersistQueueItem( end_item = _EventPersistQueueItem(
events_and_contexts=[], task=task,
backfilled=backfilled,
deferred=deferred, deferred=deferred,
) )
queue.append(end_item) queue.append(end_item)
# add our events to the queue item
end_item.events_and_contexts.extend(events_and_contexts)
# also add our active opentracing span to the item so that we get a link back # also add our active opentracing span to the item so that we get a link back
span = opentracing.active_span() span = opentracing.active_span()
if span: if span:
@ -202,7 +233,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
# add another opentracing span which links to the persist trace. # add another opentracing span which links to the persist trace.
with opentracing.start_active_span_follows_from( with opentracing.start_active_span_follows_from(
"persist_event_batch_complete", (end_item.opentracing_span_context,) f"{task.name}_complete", (end_item.opentracing_span_context,)
): ):
pass pass
@ -234,16 +265,14 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
for item in queue: for item in queue:
try: try:
with opentracing.start_active_span_follows_from( with opentracing.start_active_span_follows_from(
"persist_event_batch", item.task.name,
item.parent_opentracing_span_contexts, item.parent_opentracing_span_contexts,
inherit_force_tracing=True, inherit_force_tracing=True,
) as scope: ) as scope:
if scope: if scope:
item.opentracing_span_context = scope.span.context item.opentracing_span_context = scope.span.context
ret = await self._per_item_callback( ret = await self._per_item_callback(room_id, item.task)
item.events_and_contexts, item.backfilled
)
except Exception: except Exception:
with PreserveLoggingContext(): with PreserveLoggingContext():
item.deferred.errback() item.deferred.errback()
@ -292,9 +321,32 @@ class EventsPersistenceStorageController:
self._clock = hs.get_clock() self._clock = hs.get_clock()
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self._event_persist_queue = _EventPeristenceQueue(self._persist_event_batch) self._event_persist_queue = _EventPeristenceQueue(
self._process_event_persist_queue_task
)
self._state_resolution_handler = hs.get_state_resolution_handler() self._state_resolution_handler = hs.get_state_resolution_handler()
async def _process_event_persist_queue_task(
self,
room_id: str,
task: _EventPersistQueueTask,
) -> Dict[str, str]:
"""Callback for the _event_persist_queue
Returns:
A dictionary of event ID to event ID we didn't persist as we already
had another event persisted with the same TXN ID.
"""
if isinstance(task, _PersistEventsTask):
return await self._persist_event_batch(room_id, task)
elif isinstance(task, _UpdateCurrentStateTask):
await self._update_current_state(room_id, task)
return {}
else:
raise AssertionError(
f"Found an unexpected task type in event persistence queue: {task}"
)
@opentracing.trace @opentracing.trace
async def persist_events( async def persist_events(
self, self,
@ -329,7 +381,8 @@ class EventsPersistenceStorageController:
) -> Dict[str, str]: ) -> Dict[str, str]:
room_id, evs_ctxs = item room_id, evs_ctxs = item
return await self._event_persist_queue.add_to_queue( return await self._event_persist_queue.add_to_queue(
room_id, evs_ctxs, backfilled=backfilled room_id,
_PersistEventsTask(events_and_contexts=evs_ctxs, backfilled=backfilled),
) )
ret_vals = await yieldable_gather_results(enqueue, partitioned.items()) ret_vals = await yieldable_gather_results(enqueue, partitioned.items())
@ -376,7 +429,10 @@ class EventsPersistenceStorageController:
# event was deduplicated. (The dict may also include other entries if # event was deduplicated. (The dict may also include other entries if
# the event was persisted in a batch with other events.) # the event was persisted in a batch with other events.)
replaced_events = await self._event_persist_queue.add_to_queue( replaced_events = await self._event_persist_queue.add_to_queue(
event.room_id, [(event, context)], backfilled=backfilled event.room_id,
_PersistEventsTask(
events_and_contexts=[(event, context)], backfilled=backfilled
),
) )
replaced_event = replaced_events.get(event.event_id) replaced_event = replaced_events.get(event.event_id)
if replaced_event: if replaced_event:
@ -391,20 +447,22 @@ class EventsPersistenceStorageController:
async def update_current_state(self, room_id: str) -> None: async def update_current_state(self, room_id: str) -> None:
"""Recalculate the current state for a room, and persist it""" """Recalculate the current state for a room, and persist it"""
await self._event_persist_queue.add_to_queue(
room_id,
_UpdateCurrentStateTask(),
)
async def _update_current_state(
self, room_id: str, _task: _UpdateCurrentStateTask
) -> None:
"""Callback for the _event_persist_queue
Recalculates the current state for a room, and persists it.
"""
state = await self._calculate_current_state(room_id) state = await self._calculate_current_state(room_id)
delta = await self._calculate_state_delta(room_id, state) delta = await self._calculate_state_delta(room_id, state)
# TODO(faster_joins): get a real stream ordering, to make this work correctly await self.persist_events_store.update_current_state(room_id, delta)
# across workers.
# https://github.com/matrix-org/synapse/issues/12994
#
# TODO(faster_joins): this can race against event persistence, in which case we
# will end up with incorrect state. Perhaps we should make this a job we
# farm out to the event persister thread, somehow.
# https://github.com/matrix-org/synapse/issues/13007
#
stream_id = self.main_store.get_room_max_stream_ordering()
await self.persist_events_store.update_current_state(room_id, delta, stream_id)
async def _calculate_current_state(self, room_id: str) -> StateMap[str]: async def _calculate_current_state(self, room_id: str) -> StateMap[str]:
"""Calculate the current state of a room, based on the forward extremities """Calculate the current state of a room, based on the forward extremities
@ -449,9 +507,7 @@ class EventsPersistenceStorageController:
return res.state return res.state
async def _persist_event_batch( async def _persist_event_batch(
self, self, _room_id: str, task: _PersistEventsTask
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool = False,
) -> Dict[str, str]: ) -> Dict[str, str]:
"""Callback for the _event_persist_queue """Callback for the _event_persist_queue
@ -466,6 +522,9 @@ class EventsPersistenceStorageController:
PartialStateConflictError: if attempting to persist a partial state event in PartialStateConflictError: if attempting to persist a partial state event in
a room that has been un-partial stated. a room that has been un-partial stated.
""" """
events_and_contexts = task.events_and_contexts
backfilled = task.backfilled
replaced_events: Dict[str, str] = {} replaced_events: Dict[str, str] = {}
if not events_and_contexts: if not events_and_contexts:
return replaced_events return replaced_events

View File

@ -1007,15 +1007,15 @@ class PersistEventsStore:
self, self,
room_id: str, room_id: str,
state_delta: DeltaState, state_delta: DeltaState,
stream_id: int,
) -> None: ) -> None:
"""Update the current state stored in the datatabase for the given room""" """Update the current state stored in the datatabase for the given room"""
async with self._stream_id_gen.get_next() as stream_ordering:
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"update_current_state", "update_current_state",
self._update_current_state_txn, self._update_current_state_txn,
state_delta_by_room={room_id: state_delta}, state_delta_by_room={room_id: state_delta},
stream_id=stream_id, stream_id=stream_ordering,
) )
def _update_current_state_txn( def _update_current_state_txn(

View File

@ -195,6 +195,8 @@ class StateTestCase(unittest.TestCase):
"get_state_resolution_handler", "get_state_resolution_handler",
"get_account_validity_handler", "get_account_validity_handler",
"get_macaroon_generator", "get_macaroon_generator",
"get_instance_name",
"get_simple_http_client",
"hostname", "hostname",
] ]
) )