Add ability to wait for replication streams (#7542)

The idea here is that if an instance persists an event via the replication HTTP API it can return before we receive that event over replication, which can lead to races where code assumes that persisting an event immediately updates various caches (e.g. current state of the room).

Most of Synapse doesn't hit such races, so we don't do the waiting automagically, instead we do so where necessary to avoid unnecessary delays. We may decide to change our minds here if it turns out there are a lot of subtle races going on.

People probably want to look at this commit by commit.
This commit is contained in:
Erik Johnston 2020-05-22 14:21:54 +01:00 committed by GitHub
parent 06a02bc1ce
commit 1531b214fc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 304 additions and 112 deletions

View file

@ -14,19 +14,23 @@
# limitations under the License.
"""A replication client for use by synapse workers.
"""
import heapq
import logging
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, Dict, List, Tuple
from twisted.internet.defer import Deferred
from twisted.internet.protocol import ReconnectingClientFactory
from synapse.api.constants import EventTypes
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamEventRow,
EventsStreamRow,
)
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -35,6 +39,10 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# How long we allow callers to wait for replication updates before timing out.
_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS = 30
class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
"""Factory for building connections to the master. Will reconnect if the
connection is lost.
@ -92,6 +100,16 @@ class ReplicationDataHandler:
self.store = hs.get_datastore()
self.pusher_pool = hs.get_pusherpool()
self.notifier = hs.get_notifier()
self._reactor = hs.get_reactor()
self._clock = hs.get_clock()
self._streams = hs.get_replication_streams()
self._instance_name = hs.get_instance_name()
# Map from stream to list of deferreds waiting for the stream to
# arrive at a particular position. The lists are sorted by stream position.
self._streams_to_waiters = (
{}
) # type: Dict[str, List[Tuple[int, Deferred[None]]]]
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
@ -131,8 +149,76 @@ class ReplicationDataHandler:
await self.pusher_pool.on_new_notifications(token, token)
# Notify any waiting deferreds. The list is ordered by position so we
# just iterate through the list until we reach a position that is
# greater than the received row position.
waiting_list = self._streams_to_waiters.get(stream_name, [])
# Index of first item with a position after the current token, i.e we
# have called all deferreds before this index. If not overwritten by
# loop below means either a) no items in list so no-op or b) all items
# in list were called and so the list should be cleared. Setting it to
# `len(list)` works for both cases.
index_of_first_deferred_not_called = len(waiting_list)
for idx, (position, deferred) in enumerate(waiting_list):
if position <= token:
try:
with PreserveLoggingContext():
deferred.callback(None)
except Exception:
# The deferred has been cancelled or timed out.
pass
else:
# The list is sorted by position so we don't need to continue
# checking any futher entries in the list.
index_of_first_deferred_not_called = idx
break
# Drop all entries in the waiting list that were called in the above
# loop. (This maintains the order so no need to resort)
waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
async def on_position(self, stream_name: str, instance_name: str, token: int):
self.store.process_replication_rows(stream_name, instance_name, token, [])
def on_remote_server_up(self, server: str):
"""Called when get a new REMOTE_SERVER_UP command."""
async def wait_for_stream_position(
self, instance_name: str, stream_name: str, position: int
):
"""Wait until this instance has received updates up to and including
the given stream position.
"""
if instance_name == self._instance_name:
# We don't get told about updates written by this process, and
# anyway in that case we don't need to wait.
return
current_position = self._streams[stream_name].current_token(self._instance_name)
if position <= current_position:
# We're already past the position
return
# Create a new deferred that times out after N seconds, as we don't want
# to wedge here forever.
deferred = Deferred()
deferred = timeout_deferred(
deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor
)
waiting_list = self._streams_to_waiters.setdefault(stream_name, [])
# We insert into the list using heapq as it is more efficient than
# pushing then resorting each time.
heapq.heappush(waiting_list, (position, deferred))
# We measure here to get in flight counts and average waiting time.
with Measure(self._clock, "repl.wait_for_stream_position"):
logger.info("Waiting for repl stream %r to reach %s", stream_name, position)
await make_deferred_yieldable(deferred)
logger.info(
"Finished waiting for repl stream %r to reach %s", stream_name, position
)