Move catchup of replication streams to worker. (#7024)

This changes the replication protocol so that the server does not send down `RDATA` for rows that happened before the client connected. Instead, the server will send a `POSITION` and clients then query the database (or master out of band) to get up to date.
This commit is contained in:
Erik Johnston 2020-03-25 14:54:01 +00:00 committed by GitHub
parent 7bab642707
commit 4cff617df1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 635 additions and 487 deletions

1
changelog.d/7024.misc Normal file
View File

@ -0,0 +1 @@
Move catchup of replication streams logic to worker.

View File

@ -14,16 +14,16 @@ example flow would be (where '>' indicates master to worker and
'<' worker to master flows): '<' worker to master flows):
> SERVER example.com > SERVER example.com
< REPLICATE events 53 < REPLICATE
> POSITION events 53
> RDATA events 54 ["$foo1:bar.com", ...] > RDATA events 54 ["$foo1:bar.com", ...]
> RDATA events 55 ["$foo4:bar.com", ...] > RDATA events 55 ["$foo4:bar.com", ...]
The example shows the server accepting a new connection and sending its The example shows the server accepting a new connection and sending its identity
identity with the `SERVER` command, followed by the client asking to with the `SERVER` command, followed by the client server to respond with the
subscribe to the `events` stream from the token `53`. The server then position of all streams. The server then periodically sends `RDATA` commands
periodically sends `RDATA` commands which have the format which have the format `RDATA <stream_name> <token> <row>`, where the format of
`RDATA <stream_name> <token> <row>`, where the format of `<row>` is `<row>` is defined by the individual streams.
defined by the individual streams.
Error reporting happens by either the client or server sending an ERROR Error reporting happens by either the client or server sending an ERROR
command, and usually the connection will be closed. command, and usually the connection will be closed.
@ -32,9 +32,6 @@ Since the protocol is a simple line based, its possible to manually
connect to the server using a tool like netcat. A few things should be connect to the server using a tool like netcat. A few things should be
noted when manually using the protocol: noted when manually using the protocol:
- When subscribing to a stream using `REPLICATE`, the special token
`NOW` can be used to get all future updates. The special stream name
`ALL` can be used with `NOW` to subscribe to all available streams.
- The federation stream is only available if federation sending has - The federation stream is only available if federation sending has
been disabled on the main process. been disabled on the main process.
- The server will only time connections out that have sent a `PING` - The server will only time connections out that have sent a `PING`
@ -91,9 +88,7 @@ The client:
- Sends a `NAME` command, allowing the server to associate a human - Sends a `NAME` command, allowing the server to associate a human
friendly name with the connection. This is optional. friendly name with the connection. This is optional.
- Sends a `PING` as above - Sends a `PING` as above
- For each stream the client wishes to subscribe to it sends a - Sends a `REPLICATE` to get the current position of all streams.
`REPLICATE` with the `stream_name` and token it wants to subscribe
from.
- On receipt of a `SERVER` command, checks that the server name - On receipt of a `SERVER` command, checks that the server name
matches the expected server name. matches the expected server name.
@ -140,9 +135,7 @@ the wire:
> PING 1490197665618 > PING 1490197665618
< NAME synapse.app.appservice < NAME synapse.app.appservice
< PING 1490197665618 < PING 1490197665618
< REPLICATE events 1 < REPLICATE
< REPLICATE backfill 1
< REPLICATE caches 1
> POSITION events 1 > POSITION events 1
> POSITION backfill 1 > POSITION backfill 1
> POSITION caches 1 > POSITION caches 1
@ -181,9 +174,9 @@ client (C):
#### POSITION (S) #### POSITION (S)
The position of the stream has been updated. Sent to the client On receipt of a POSITION command clients should check if they have missed any
after all missing updates for a stream have been sent to the client updates, and if so then fetch them out of band. Sent in response to a
and they're now up to date. REPLICATE command (but can happen at any time).
#### ERROR (S, C) #### ERROR (S, C)
@ -199,20 +192,7 @@ client (C):
#### REPLICATE (C) #### REPLICATE (C)
Asks the server to replicate a given stream. The syntax is: Asks the server for the current position of all streams.
```
REPLICATE <stream_name> <token>
```
Where `<token>` may be either:
* a numeric stream_id to stream updates since (exclusive)
* `NOW` to stream all subsequent updates.
The `<stream_name>` is the name of a replication stream to subscribe
to (see [here](../synapse/replication/tcp/streams/_base.py) for a list
of streams). It can also be `ALL` to subscribe to all known streams,
in which case the `<token>` must be set to `NOW`.
#### USER_SYNC (C) #### USER_SYNC (C)

View File

@ -401,6 +401,9 @@ class GenericWorkerTyping(object):
self._room_serials[row.room_id] = token self._room_serials[row.room_id] = token
self._room_typing[row.room_id] = row.user_ids self._room_typing[row.room_id] = row.user_ids
def get_current_token(self) -> int:
return self._latest_room_serial
class GenericWorkerSlavedStore( class GenericWorkerSlavedStore(
# FIXME(#3714): We need to add UserDirectoryStore as we write directly # FIXME(#3714): We need to add UserDirectoryStore as we write directly

View File

@ -499,4 +499,13 @@ class FederationSender(object):
self._get_per_destination_queue(destination).attempt_new_transaction() self._get_per_destination_queue(destination).attempt_new_transaction()
def get_current_token(self) -> int: def get_current_token(self) -> int:
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
return 0 return 0
async def get_replication_rows(
self, from_token, to_token, limit, federation_ack=None
):
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
return []

View File

@ -21,6 +21,7 @@ from synapse.replication.http import (
membership, membership,
register, register,
send_event, send_event,
streams,
) )
REPLICATION_PREFIX = "/_synapse/replication" REPLICATION_PREFIX = "/_synapse/replication"
@ -38,3 +39,4 @@ class ReplicationRestResource(JsonResource):
login.register_servlets(hs, self) login.register_servlets(hs, self)
register.register_servlets(hs, self) register.register_servlets(hs, self)
devices.register_servlets(hs, self) devices.register_servlets(hs, self)
streams.register_servlets(hs, self)

View File

@ -0,0 +1,78 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_integer
from synapse.replication.http._base import ReplicationEndpoint
logger = logging.getLogger(__name__)
class ReplicationGetStreamUpdates(ReplicationEndpoint):
"""Fetches stream updates from a server. Used for streams not persisted to
the database, e.g. typing notifications.
The API looks like:
GET /_synapse/replication/get_repl_stream_updates/events?from_token=0&to_token=10&limit=100
200 OK
{
updates: [ ... ],
upto_token: 10,
limited: False,
}
"""
NAME = "get_repl_stream_updates"
PATH_ARGS = ("stream_name",)
METHOD = "GET"
def __init__(self, hs):
super().__init__(hs)
# We pull the streams from the replication steamer (if we try and make
# them ourselves we end up in an import loop).
self.streams = hs.get_replication_streamer().get_streams()
@staticmethod
def _serialize_payload(stream_name, from_token, upto_token, limit):
return {"from_token": from_token, "upto_token": upto_token, "limit": limit}
async def _handle_request(self, request, stream_name):
stream = self.streams.get(stream_name)
if stream is None:
raise SynapseError(400, "Unknown stream")
from_token = parse_integer(request, "from_token", required=True)
upto_token = parse_integer(request, "upto_token", required=True)
limit = parse_integer(request, "limit", required=True)
updates, upto_token, limited = await stream.get_updates_since(
from_token, upto_token, limit
)
return (
200,
{"updates": updates, "upto_token": upto_token, "limited": limited},
)
def register_servlets(hs, http_server):
ReplicationGetStreamUpdates(hs).register(http_server)

View File

@ -18,8 +18,10 @@ from typing import Dict, Optional
import six import six
from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.cache import (
from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME CURRENT_STATE_CACHE_NAME,
CacheInvalidationWorkerStore,
)
from synapse.storage.database import Database from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
@ -35,7 +37,7 @@ def __func__(inp):
return inp.__func__ return inp.__func__
class BaseSlavedStore(SQLBaseStore): class BaseSlavedStore(CacheInvalidationWorkerStore):
def __init__(self, database: Database, db_conn, hs): def __init__(self, database: Database, db_conn, hs):
super(BaseSlavedStore, self).__init__(database, db_conn, hs) super(BaseSlavedStore, self).__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
@ -60,6 +62,12 @@ class BaseSlavedStore(SQLBaseStore):
pos["caches"] = self._cache_id_gen.get_current_token() pos["caches"] = self._cache_id_gen.get_current_token()
return pos return pos
def get_cache_stream_token(self):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token()
else:
return 0
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, token, rows):
if stream_name == "caches": if stream_name == "caches":
if self._cache_id_gen: if self._cache_id_gen:

View File

@ -33,6 +33,9 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
result["pushers"] = self._pushers_id_gen.get_current_token() result["pushers"] = self._pushers_id_gen.get_current_token()
return result return result
def get_pushers_stream_token(self):
return self._pushers_id_gen.get_current_token()
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, token, rows):
if stream_name == "pushers": if stream_name == "pushers":
self._pushers_id_gen.advance(token) self._pushers_id_gen.advance(token)

View File

@ -55,6 +55,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
self.client_name = client_name self.client_name = client_name
self.handler = handler self.handler = handler
self.server_name = hs.config.server_name self.server_name = hs.config.server_name
self.hs = hs
self._clock = hs.get_clock() # As self.clock is defined in super class self._clock = hs.get_clock() # As self.clock is defined in super class
hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying) hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
@ -65,7 +66,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
def buildProtocol(self, addr): def buildProtocol(self, addr):
logger.info("Connected to replication: %r", addr) logger.info("Connected to replication: %r", addr)
return ClientReplicationStreamProtocol( return ClientReplicationStreamProtocol(
self.client_name, self.server_name, self._clock, self.handler self.hs, self.client_name, self.server_name, self._clock, self.handler,
) )
def clientConnectionLost(self, connector, reason): def clientConnectionLost(self, connector, reason):

View File

@ -136,8 +136,8 @@ class PositionCommand(Command):
"""Sent by the server to tell the client the stream postition without """Sent by the server to tell the client the stream postition without
needing to send an RDATA. needing to send an RDATA.
Sent to the client after all missing updates for a stream have been sent On receipt of a POSITION command clients should check if they have missed
to the client and they're now up to date. any updates, and if so then fetch them out of band.
""" """
NAME = "POSITION" NAME = "POSITION"
@ -179,42 +179,24 @@ class NameCommand(Command):
class ReplicateCommand(Command): class ReplicateCommand(Command):
"""Sent by the client to subscribe to the stream. """Sent by the client to subscribe to streams.
Format:: Format::
REPLICATE <stream_name> <token> REPLICATE
Where <token> may be either:
* a numeric stream_id to stream updates from
* "NOW" to stream all subsequent updates.
The <stream_name> can be "ALL" to subscribe to all known streams, in which
case the <token> must be set to "NOW", i.e.::
REPLICATE ALL NOW
""" """
NAME = "REPLICATE" NAME = "REPLICATE"
def __init__(self, stream_name, token): def __init__(self):
self.stream_name = stream_name pass
self.token = token
@classmethod @classmethod
def from_line(cls, line): def from_line(cls, line):
stream_name, token = line.split(" ", 1) return cls()
if token in ("NOW", "now"):
token = "NOW"
else:
token = int(token)
return cls(stream_name, token)
def to_line(self): def to_line(self):
return " ".join((self.stream_name, str(self.token))) return ""
def get_logcontext_id(self):
return "REPLICATE-" + self.stream_name
class UserSyncCommand(Command): class UserSyncCommand(Command):

View File

@ -35,9 +35,7 @@ indicate which side is sending, these are *not* included on the wire::
> PING 1490197665618 > PING 1490197665618
< NAME synapse.app.appservice < NAME synapse.app.appservice
< PING 1490197665618 < PING 1490197665618
< REPLICATE events 1 < REPLICATE
< REPLICATE backfill 1
< REPLICATE caches 1
> POSITION events 1 > POSITION events 1
> POSITION backfill 1 > POSITION backfill 1
> POSITION caches 1 > POSITION caches 1
@ -53,17 +51,15 @@ import fcntl
import logging import logging
import struct import struct
from collections import defaultdict from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Set, Tuple from typing import Any, DefaultDict, Dict, List, Set
from six import iteritems, iterkeys from six import iteritems
from prometheus_client import Counter from prometheus_client import Counter
from twisted.internet import defer
from twisted.protocols.basic import LineOnlyReceiver from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.commands import ( from synapse.replication.tcp.commands import (
@ -82,11 +78,16 @@ from synapse.replication.tcp.commands import (
SyncCommand, SyncCommand,
UserSyncCommand, UserSyncCommand,
) )
from synapse.replication.tcp.streams import STREAMS_MAP from synapse.replication.tcp.streams import STREAMS_MAP, Stream
from synapse.types import Collection from synapse.types import Collection
from synapse.util import Clock from synapse.util import Clock
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
MYPY = False
if MYPY:
from synapse.server import HomeServer
connection_close_counter = Counter( connection_close_counter = Counter(
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"] "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
) )
@ -411,16 +412,6 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.server_name = server_name self.server_name = server_name
self.streamer = streamer self.streamer = streamer
# The streams the client has subscribed to and is up to date with
self.replication_streams = set() # type: Set[str]
# The streams the client is currently subscribing to.
self.connecting_streams = set() # type: Set[str]
# Map from stream name to list of updates to send once we've finished
# subscribing the client to the stream.
self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]]
def connectionMade(self): def connectionMade(self):
self.send_command(ServerCommand(self.server_name)) self.send_command(ServerCommand(self.server_name))
BaseReplicationStreamProtocol.connectionMade(self) BaseReplicationStreamProtocol.connectionMade(self)
@ -436,21 +427,10 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
) )
async def on_REPLICATE(self, cmd): async def on_REPLICATE(self, cmd):
stream_name = cmd.stream_name # Subscribe to all streams we're publishing to.
token = cmd.token for stream_name in self.streamer.streams_by_name:
current_token = self.streamer.get_stream_token(stream_name)
if stream_name == "ALL": self.send_command(PositionCommand(stream_name, current_token))
# Subscribe to all streams we're publishing to.
deferreds = [
run_in_background(self.subscribe_to_stream, stream, token)
for stream in iterkeys(self.streamer.streams_by_name)
]
await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
else:
await self.subscribe_to_stream(stream_name, token)
async def on_FEDERATION_ACK(self, cmd): async def on_FEDERATION_ACK(self, cmd):
self.streamer.federation_ack(cmd.token) self.streamer.federation_ack(cmd.token)
@ -474,87 +454,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
cmd.last_seen, cmd.last_seen,
) )
async def subscribe_to_stream(self, stream_name, token):
"""Subscribe the remote to a stream.
This invloves checking if they've missed anything and sending those
updates down if they have. During that time new updates for the stream
are queued and sent once we've sent down any missed updates.
"""
self.replication_streams.discard(stream_name)
self.connecting_streams.add(stream_name)
try:
# Get missing updates
updates, current_token = await self.streamer.get_stream_updates(
stream_name, token
)
# Send all the missing updates
for update in updates:
token, row = update[0], update[1]
self.send_command(RdataCommand(stream_name, token, row))
# We send a POSITION command to ensure that they have an up to
# date token (especially useful if we didn't send any updates
# above)
self.send_command(PositionCommand(stream_name, current_token))
# Now we can send any updates that came in while we were subscribing
pending_rdata = self.pending_rdata.pop(stream_name, [])
updates = []
for token, update in pending_rdata:
# If the token is null, it is part of a batch update. Batches
# are multiple updates that share a single token. To denote
# this, the token is set to None for all tokens in the batch
# except for the last. If we find a None token, we keep looking
# through tokens until we find one that is not None and then
# process all previous updates in the batch as if they had the
# final token.
if token is None:
# Store this update as part of a batch
updates.append(update)
continue
if token <= current_token:
# This update or batch of updates is older than
# current_token, dismiss it
updates = []
continue
updates.append(update)
# Send all updates that are part of this batch with the
# found token
for update in updates:
self.send_command(RdataCommand(stream_name, token, update))
# Clear stored updates
updates = []
# They're now fully subscribed
self.replication_streams.add(stream_name)
except Exception as e:
logger.exception("[%s] Failed to handle REPLICATE command", self.id())
self.send_error("failed to handle replicate: %r", e)
finally:
self.connecting_streams.discard(stream_name)
def stream_update(self, stream_name, token, data): def stream_update(self, stream_name, token, data):
"""Called when a new update is available to stream to clients. """Called when a new update is available to stream to clients.
We need to check if the client is interested in the stream or not We need to check if the client is interested in the stream or not
""" """
if stream_name in self.replication_streams: self.send_command(RdataCommand(stream_name, token, data))
# The client is subscribed to the stream
self.send_command(RdataCommand(stream_name, token, data))
elif stream_name in self.connecting_streams:
# The client is being subscribed to the stream
logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token)
self.pending_rdata.setdefault(stream_name, []).append((token, data))
else:
# The client isn't subscribed
logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token)
def send_sync(self, data): def send_sync(self, data):
self.send_command(SyncCommand(data)) self.send_command(SyncCommand(data))
@ -638,6 +543,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
def __init__( def __init__(
self, self,
hs: "HomeServer",
client_name: str, client_name: str,
server_name: str, server_name: str,
clock: Clock, clock: Clock,
@ -649,22 +555,25 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.server_name = server_name self.server_name = server_name
self.handler = handler self.handler = handler
self.streams = {
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
} # type: Dict[str, Stream]
# Set of stream names that have been subscribe to, but haven't yet # Set of stream names that have been subscribe to, but haven't yet
# caught up with. This is used to track when the client has been fully # caught up with. This is used to track when the client has been fully
# connected to the remote. # connected to the remote.
self.streams_connecting = set() # type: Set[str] self.streams_connecting = set(STREAMS_MAP) # type: Set[str]
# Map of stream to batched updates. See RdataCommand for info on how # Map of stream to batched updates. See RdataCommand for info on how
# batching works. # batching works.
self.pending_batches = {} # type: Dict[str, Any] self.pending_batches = {} # type: Dict[str, List[Any]]
def connectionMade(self): def connectionMade(self):
self.send_command(NameCommand(self.client_name)) self.send_command(NameCommand(self.client_name))
BaseReplicationStreamProtocol.connectionMade(self) BaseReplicationStreamProtocol.connectionMade(self)
# Once we've connected subscribe to the necessary streams # Once we've connected subscribe to the necessary streams
for stream_name, token in iteritems(self.handler.get_streams_to_replicate()): self.replicate()
self.replicate(stream_name, token)
# Tell the server if we have any users currently syncing (should only # Tell the server if we have any users currently syncing (should only
# happen on synchrotrons) # happen on synchrotrons)
@ -676,10 +585,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# We've now finished connecting to so inform the client handler # We've now finished connecting to so inform the client handler
self.handler.update_connection(self) self.handler.update_connection(self)
# This will happen if we don't actually subscribe to any streams
if not self.streams_connecting:
self.handler.finished_connecting()
async def on_SERVER(self, cmd): async def on_SERVER(self, cmd):
if cmd.data != self.server_name: if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
@ -697,7 +602,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
) )
raise raise
if cmd.token is None: if cmd.token is None or stream_name in self.streams_connecting:
# I.e. this is part of a batch of updates for this stream. Batch # I.e. this is part of a batch of updates for this stream. Batch
# until we get an update for the stream with a non None token # until we get an update for the stream with a non None token
self.pending_batches.setdefault(stream_name, []).append(row) self.pending_batches.setdefault(stream_name, []).append(row)
@ -707,14 +612,55 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
rows.append(row) rows.append(row)
await self.handler.on_rdata(stream_name, cmd.token, rows) await self.handler.on_rdata(stream_name, cmd.token, rows)
async def on_POSITION(self, cmd): async def on_POSITION(self, cmd: PositionCommand):
# When we get a `POSITION` command it means we've finished getting stream = self.streams.get(cmd.stream_name)
# missing updates for the given stream, and are now up to date. if not stream:
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
return
# Find where we previously streamed up to.
current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name)
if current_token is None:
logger.warning(
"Got POSITION for stream we're not subscribed to: %s", cmd.stream_name
)
return
# Fetch all updates between then and now.
limited = True
while limited:
updates, current_token, limited = await stream.get_updates_since(
current_token, cmd.token
)
# Check if the connection was closed underneath us, if so we bail
# rather than risk having concurrent catch ups going on.
if self.state == ConnectionStates.CLOSED:
return
if updates:
await self.handler.on_rdata(
cmd.stream_name,
current_token,
[stream.parse_row(update[1]) for update in updates],
)
# We've now caught up to position sent to us, notify handler.
await self.handler.on_position(cmd.stream_name, cmd.token)
self.streams_connecting.discard(cmd.stream_name) self.streams_connecting.discard(cmd.stream_name)
if not self.streams_connecting: if not self.streams_connecting:
self.handler.finished_connecting() self.handler.finished_connecting()
await self.handler.on_position(cmd.stream_name, cmd.token) # Check if the connection was closed underneath us, if so we bail
# rather than risk having concurrent catch ups going on.
if self.state == ConnectionStates.CLOSED:
return
# Handle any RDATA that came in while we were catching up.
rows = self.pending_batches.pop(cmd.stream_name, [])
if rows:
await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows)
async def on_SYNC(self, cmd): async def on_SYNC(self, cmd):
self.handler.on_sync(cmd.data) self.handler.on_sync(cmd.data)
@ -722,22 +668,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
self.handler.on_remote_server_up(cmd.data) self.handler.on_remote_server_up(cmd.data)
def replicate(self, stream_name, token): def replicate(self):
"""Send the subscription request to the server """Send the subscription request to the server
""" """
if stream_name not in STREAMS_MAP: logger.info("[%s] Subscribing to replication streams", self.id())
raise Exception("Invalid stream name %r" % (stream_name,))
logger.info( self.send_command(ReplicateCommand())
"[%s] Subscribing to replication stream: %r from %r",
self.id(),
stream_name,
token,
)
self.streams_connecting.add(stream_name)
self.send_command(ReplicateCommand(stream_name, token))
def on_connection_closed(self): def on_connection_closed(self):
BaseReplicationStreamProtocol.on_connection_closed(self) BaseReplicationStreamProtocol.on_connection_closed(self)

View File

@ -17,7 +17,7 @@
import logging import logging
import random import random
from typing import Any, List from typing import Any, Dict, List
from six import itervalues from six import itervalues
@ -30,7 +30,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.metrics import Measure, measure_func from synapse.util.metrics import Measure, measure_func
from .protocol import ServerReplicationStreamProtocol from .protocol import ServerReplicationStreamProtocol
from .streams import STREAMS_MAP from .streams import STREAMS_MAP, Stream
from .streams.federation import FederationStream from .streams.federation import FederationStream
stream_updates_counter = Counter( stream_updates_counter = Counter(
@ -52,7 +52,7 @@ class ReplicationStreamProtocolFactory(Factory):
""" """
def __init__(self, hs): def __init__(self, hs):
self.streamer = ReplicationStreamer(hs) self.streamer = hs.get_replication_streamer()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.server_name = hs.config.server_name self.server_name = hs.config.server_name
@ -133,6 +133,11 @@ class ReplicationStreamer(object):
for conn in self.connections: for conn in self.connections:
conn.send_error("server shutting down") conn.send_error("server shutting down")
def get_streams(self) -> Dict[str, Stream]:
"""Get a mapp from stream name to stream instance.
"""
return self.streams_by_name
def on_notifier_poke(self): def on_notifier_poke(self):
"""Checks if there is actually any new data and sends it to the """Checks if there is actually any new data and sends it to the
connections if there are. connections if there are.
@ -190,7 +195,8 @@ class ReplicationStreamer(object):
stream.current_token(), stream.current_token(),
) )
try: try:
updates, current_token = await stream.get_updates() updates, current_token, limited = await stream.get_updates()
self.pending_updates |= limited
except Exception: except Exception:
logger.info("Failed to handle stream %s", stream.NAME) logger.info("Failed to handle stream %s", stream.NAME)
raise raise
@ -226,8 +232,7 @@ class ReplicationStreamer(object):
self.pending_updates = False self.pending_updates = False
self.is_looping = False self.is_looping = False
@measure_func("repl.get_stream_updates") def get_stream_token(self, stream_name):
async def get_stream_updates(self, stream_name, token):
"""For a given stream get all updates since token. This is called when """For a given stream get all updates since token. This is called when
a client first subscribes to a stream. a client first subscribes to a stream.
""" """
@ -235,7 +240,7 @@ class ReplicationStreamer(object):
if not stream: if not stream:
raise Exception("unknown stream %s", stream_name) raise Exception("unknown stream %s", stream_name)
return await stream.get_updates_since(token) return stream.current_token()
@measure_func("repl.federation_ack") @measure_func("repl.federation_ack")
def federation_ack(self, token): def federation_ack(self, token):

View File

@ -24,6 +24,9 @@ Each stream is defined by the following information:
current_token: The function that returns the current token for the stream current_token: The function that returns the current token for the stream
update_function: The function that returns a list of updates between two tokens update_function: The function that returns a list of updates between two tokens
""" """
from typing import Dict, Type
from synapse.replication.tcp.streams._base import ( from synapse.replication.tcp.streams._base import (
AccountDataStream, AccountDataStream,
BackfillStream, BackfillStream,
@ -35,6 +38,7 @@ from synapse.replication.tcp.streams._base import (
PushersStream, PushersStream,
PushRulesStream, PushRulesStream,
ReceiptsStream, ReceiptsStream,
Stream,
TagAccountDataStream, TagAccountDataStream,
ToDeviceStream, ToDeviceStream,
TypingStream, TypingStream,
@ -63,10 +67,12 @@ STREAMS_MAP = {
GroupServerStream, GroupServerStream,
UserSignatureStream, UserSignatureStream,
) )
} } # type: Dict[str, Type[Stream]]
__all__ = [ __all__ = [
"STREAMS_MAP", "STREAMS_MAP",
"Stream",
"BackfillStream", "BackfillStream",
"PresenceStream", "PresenceStream",
"TypingStream", "TypingStream",

View File

@ -14,13 +14,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import itertools
import logging import logging
from collections import namedtuple from collections import namedtuple
from typing import Any, List, Optional, Tuple from typing import Any, Awaitable, Callable, List, Optional, Tuple
import attr import attr
from synapse.replication.http.streams import ReplicationGetStreamUpdates
from synapse.types import JsonDict from synapse.types import JsonDict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -29,6 +29,15 @@ logger = logging.getLogger(__name__)
MAX_EVENTS_BEHIND = 500000 MAX_EVENTS_BEHIND = 500000
# Some type aliases to make things a bit easier.
# A stream position token
Token = int
# A pair of position in stream and args used to create an instance of `ROW_TYPE`.
StreamRow = Tuple[Token, tuple]
class Stream(object): class Stream(object):
"""Base class for the streams. """Base class for the streams.
@ -56,6 +65,7 @@ class Stream(object):
return cls.ROW_TYPE(*row) return cls.ROW_TYPE(*row)
def __init__(self, hs): def __init__(self, hs):
# The token from which we last asked for updates # The token from which we last asked for updates
self.last_token = self.current_token() self.last_token = self.current_token()
@ -65,61 +75,46 @@ class Stream(object):
""" """
self.last_token = self.current_token() self.last_token = self.current_token()
async def get_updates(self): async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
"""Gets all updates since the last time this function was called (or """Gets all updates since the last time this function was called (or
since the stream was constructed if it hadn't been called before). since the stream was constructed if it hadn't been called before).
Returns: Returns:
Deferred[Tuple[List[Tuple[int, Any]], int]: A triplet `(updates, new_last_token, limited)`, where `updates` is
Resolves to a pair ``(updates, current_token)``, where ``updates`` is a a list of `(token, row)` entries, `new_last_token` is the new
list of ``(token, row)`` entries. ``row`` will be json-serialised and position in stream, and `limited` is whether there are more updates
sent over the replication steam. to fetch.
""" """
updates, current_token = await self.get_updates_since(self.last_token) current_token = self.current_token()
updates, current_token, limited = await self.get_updates_since(
self.last_token, current_token
)
self.last_token = current_token self.last_token = current_token
return updates, current_token return updates, current_token, limited
async def get_updates_since( async def get_updates_since(
self, from_token: int self, from_token: Token, upto_token: Token, limit: int = 100
) -> Tuple[List[Tuple[int, JsonDict]], int]: ) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
"""Like get_updates except allows specifying from when we should """Like get_updates except allows specifying from when we should
stream updates stream updates
Returns: Returns:
Resolves to a pair `(updates, new_last_token)`, where `updates` is A triplet `(updates, new_last_token, limited)`, where `updates` is
a list of `(token, row)` entries and `new_last_token` is the new a list of `(token, row)` entries, `new_last_token` is the new
position in stream. position in stream, and `limited` is whether there are more updates
to fetch.
""" """
if from_token in ("NOW", "now"):
return [], self.current_token()
current_token = self.current_token()
from_token = int(from_token) from_token = int(from_token)
if from_token == current_token: if from_token == upto_token:
return [], current_token return [], upto_token, False
rows = await self.update_function( updates, upto_token, limited = await self.update_function(
from_token, current_token, limit=MAX_EVENTS_BEHIND + 1 from_token, upto_token, limit=limit,
) )
return updates, upto_token, limited
# never turn more than MAX_EVENTS_BEHIND + 1 into updates.
rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
updates = [(row[0], row[1:]) for row in rows]
# check we didn't get more rows than the limit.
# doing it like this allows the update_function to be a generator.
if len(updates) >= MAX_EVENTS_BEHIND:
raise Exception("stream %s has fallen behind" % (self.NAME))
# The update function didn't hit the limit, so we must have got all
# the updates to `current_token`, and can return that as our new
# stream position.
return updates, current_token
def current_token(self): def current_token(self):
"""Gets the current token of the underlying streams. Should be provided """Gets the current token of the underlying streams. Should be provided
@ -141,6 +136,48 @@ class Stream(object):
raise NotImplementedError() raise NotImplementedError()
def db_query_to_update_function(
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
"""Wraps a db query function which returns a list of rows to make it
suitable for use as an `update_function` for the Stream class
"""
async def update_function(from_token, upto_token, limit):
rows = await query_function(from_token, upto_token, limit)
updates = [(row[0], row[1:]) for row in rows]
limited = False
if len(updates) == limit:
upto_token = rows[-1][0]
limited = True
return updates, upto_token, limited
return update_function
def make_http_update_function(
hs, stream_name: str
) -> Callable[[Token, Token, Token], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
"""Makes a suitable function for use as an `update_function` that queries
the master process for updates.
"""
client = ReplicationGetStreamUpdates.make_client(hs)
async def update_function(
from_token: int, upto_token: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
return await client(
stream_name=stream_name,
from_token=from_token,
upto_token=upto_token,
limit=limit,
)
return update_function
class BackfillStream(Stream): class BackfillStream(Stream):
"""We fetched some old events and either we had never seen that event before """We fetched some old events and either we had never seen that event before
or it went from being an outlier to not. or it went from being an outlier to not.
@ -164,7 +201,7 @@ class BackfillStream(Stream):
def __init__(self, hs): def __init__(self, hs):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_current_backfill_token # type: ignore self.current_token = store.get_current_backfill_token # type: ignore
self.update_function = store.get_all_new_backfill_event_rows # type: ignore self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows) # type: ignore
super(BackfillStream, self).__init__(hs) super(BackfillStream, self).__init__(hs)
@ -190,8 +227,15 @@ class PresenceStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
presence_handler = hs.get_presence_handler() presence_handler = hs.get_presence_handler()
self._is_worker = hs.config.worker_app is not None
self.current_token = store.get_current_presence_token # type: ignore self.current_token = store.get_current_presence_token # type: ignore
self.update_function = presence_handler.get_all_presence_updates # type: ignore
if hs.config.worker_app is None:
self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore
else:
# Query master process
self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
super(PresenceStream, self).__init__(hs) super(PresenceStream, self).__init__(hs)
@ -208,7 +252,12 @@ class TypingStream(Stream):
typing_handler = hs.get_typing_handler() typing_handler = hs.get_typing_handler()
self.current_token = typing_handler.get_current_token # type: ignore self.current_token = typing_handler.get_current_token # type: ignore
self.update_function = typing_handler.get_all_typing_updates # type: ignore
if hs.config.worker_app is None:
self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore
else:
# Query master process
self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
super(TypingStream, self).__init__(hs) super(TypingStream, self).__init__(hs)
@ -232,7 +281,7 @@ class ReceiptsStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_max_receipt_stream_id # type: ignore self.current_token = store.get_max_receipt_stream_id # type: ignore
self.update_function = store.get_all_updated_receipts # type: ignore self.update_function = db_query_to_update_function(store.get_all_updated_receipts) # type: ignore
super(ReceiptsStream, self).__init__(hs) super(ReceiptsStream, self).__init__(hs)
@ -256,7 +305,13 @@ class PushRulesStream(Stream):
async def update_function(self, from_token, to_token, limit): async def update_function(self, from_token, to_token, limit):
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit) rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
return [(row[0], row[2]) for row in rows]
limited = False
if len(rows) == limit:
to_token = rows[-1][0]
limited = True
return [(row[0], (row[2],)) for row in rows], to_token, limited
class PushersStream(Stream): class PushersStream(Stream):
@ -275,7 +330,7 @@ class PushersStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_pushers_stream_token # type: ignore self.current_token = store.get_pushers_stream_token # type: ignore
self.update_function = store.get_all_updated_pushers_rows # type: ignore self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows) # type: ignore
super(PushersStream, self).__init__(hs) super(PushersStream, self).__init__(hs)
@ -307,7 +362,7 @@ class CachesStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_cache_stream_token # type: ignore self.current_token = store.get_cache_stream_token # type: ignore
self.update_function = store.get_all_updated_caches # type: ignore self.update_function = db_query_to_update_function(store.get_all_updated_caches) # type: ignore
super(CachesStream, self).__init__(hs) super(CachesStream, self).__init__(hs)
@ -333,7 +388,7 @@ class PublicRoomsStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_current_public_room_stream_id # type: ignore self.current_token = store.get_current_public_room_stream_id # type: ignore
self.update_function = store.get_all_new_public_rooms # type: ignore self.update_function = db_query_to_update_function(store.get_all_new_public_rooms) # type: ignore
super(PublicRoomsStream, self).__init__(hs) super(PublicRoomsStream, self).__init__(hs)
@ -354,7 +409,7 @@ class DeviceListsStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_device_stream_token # type: ignore self.current_token = store.get_device_stream_token # type: ignore
self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes) # type: ignore
super(DeviceListsStream, self).__init__(hs) super(DeviceListsStream, self).__init__(hs)
@ -372,7 +427,7 @@ class ToDeviceStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_to_device_stream_token # type: ignore self.current_token = store.get_to_device_stream_token # type: ignore
self.update_function = store.get_all_new_device_messages # type: ignore self.update_function = db_query_to_update_function(store.get_all_new_device_messages) # type: ignore
super(ToDeviceStream, self).__init__(hs) super(ToDeviceStream, self).__init__(hs)
@ -392,7 +447,7 @@ class TagAccountDataStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_max_account_data_stream_id # type: ignore self.current_token = store.get_max_account_data_stream_id # type: ignore
self.update_function = store.get_all_updated_tags # type: ignore self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore
super(TagAccountDataStream, self).__init__(hs) super(TagAccountDataStream, self).__init__(hs)
@ -412,10 +467,11 @@ class AccountDataStream(Stream):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.current_token = self.store.get_max_account_data_stream_id # type: ignore self.current_token = self.store.get_max_account_data_stream_id # type: ignore
self.update_function = db_query_to_update_function(self._update_function) # type: ignore
super(AccountDataStream, self).__init__(hs) super(AccountDataStream, self).__init__(hs)
async def update_function(self, from_token, to_token, limit): async def _update_function(self, from_token, to_token, limit):
global_results, room_results = await self.store.get_all_updated_account_data( global_results, room_results = await self.store.get_all_updated_account_data(
from_token, from_token, to_token, limit from_token, from_token, to_token, limit
) )
@ -442,7 +498,7 @@ class GroupServerStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_group_stream_token # type: ignore self.current_token = store.get_group_stream_token # type: ignore
self.update_function = store.get_all_groups_changes # type: ignore self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore
super(GroupServerStream, self).__init__(hs) super(GroupServerStream, self).__init__(hs)
@ -460,6 +516,6 @@ class UserSignatureStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_device_stream_token # type: ignore self.current_token = store.get_device_stream_token # type: ignore
self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes) # type: ignore
super(UserSignatureStream, self).__init__(hs) super(UserSignatureStream, self).__init__(hs)

View File

@ -19,7 +19,7 @@ from typing import Tuple, Type
import attr import attr
from ._base import Stream from ._base import Stream, db_query_to_update_function
"""Handling of the 'events' replication stream """Handling of the 'events' replication stream
@ -117,10 +117,11 @@ class EventsStream(Stream):
def __init__(self, hs): def __init__(self, hs):
self._store = hs.get_datastore() self._store = hs.get_datastore()
self.current_token = self._store.get_current_events_token # type: ignore self.current_token = self._store.get_current_events_token # type: ignore
self.update_function = db_query_to_update_function(self._update_function) # type: ignore
super(EventsStream, self).__init__(hs) super(EventsStream, self).__init__(hs)
async def update_function(self, from_token, current_token, limit=None): async def _update_function(self, from_token, current_token, limit=None):
event_rows = await self._store.get_all_new_forward_event_rows( event_rows = await self._store.get_all_new_forward_event_rows(
from_token, current_token, limit from_token, current_token, limit
) )

View File

@ -15,7 +15,9 @@
# limitations under the License. # limitations under the License.
from collections import namedtuple from collections import namedtuple
from ._base import Stream from twisted.internet import defer
from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
class FederationStream(Stream): class FederationStream(Stream):
@ -33,11 +35,18 @@ class FederationStream(Stream):
NAME = "federation" NAME = "federation"
ROW_TYPE = FederationStreamRow ROW_TYPE = FederationStreamRow
_QUERY_MASTER = True
def __init__(self, hs): def __init__(self, hs):
federation_sender = hs.get_federation_sender() # Not all synapse instances will have a federation sender instance,
# whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
self.current_token = federation_sender.get_current_token # type: ignore # so we stub the stream out when that is the case.
self.update_function = federation_sender.get_replication_rows # type: ignore if hs.config.worker_app is None or hs.should_send_federation():
federation_sender = hs.get_federation_sender()
self.current_token = federation_sender.get_current_token # type: ignore
self.update_function = db_query_to_update_function(federation_sender.get_replication_rows) # type: ignore
else:
self.current_token = lambda: 0 # type: ignore
self.update_function = lambda from_token, upto_token, limit: defer.succeed(([], upto_token, bool)) # type: ignore
super(FederationStream, self).__init__(hs) super(FederationStream, self).__init__(hs)

View File

@ -85,6 +85,7 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier from synapse.notifier import Notifier
from synapse.push.action_generator import ActionGenerator from synapse.push.action_generator import ActionGenerator
from synapse.push.pusherpool import PusherPool from synapse.push.pusherpool import PusherPool
from synapse.replication.tcp.resource import ReplicationStreamer
from synapse.rest.media.v1.media_repository import ( from synapse.rest.media.v1.media_repository import (
MediaRepository, MediaRepository,
MediaRepositoryResource, MediaRepositoryResource,
@ -199,6 +200,7 @@ class HomeServer(object):
"saml_handler", "saml_handler",
"event_client_serializer", "event_client_serializer",
"storage", "storage",
"replication_streamer",
] ]
REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"] REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
@ -536,6 +538,9 @@ class HomeServer(object):
def build_storage(self) -> Storage: def build_storage(self) -> Storage:
return Storage(self, self.datastores) return Storage(self, self.datastores)
def build_replication_streamer(self) -> ReplicationStreamer:
return ReplicationStreamer(self)
def remove_pusher(self, app_id, push_key, user_id): def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)

View File

@ -32,7 +32,29 @@ logger = logging.getLogger(__name__)
CURRENT_STATE_CACHE_NAME = "cs_cache_fake" CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
class CacheInvalidationStore(SQLBaseStore): class CacheInvalidationWorkerStore(SQLBaseStore):
def get_all_updated_caches(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_updated_caches_txn(txn):
# We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
sql = (
"SELECT stream_id, cache_func, keys, invalidation_ts"
" FROM cache_invalidation_stream"
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, limit))
return txn.fetchall()
return self.db.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn
)
class CacheInvalidationStore(CacheInvalidationWorkerStore):
async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
"""Invalidates the cache and adds it to the cache stream so slaves """Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches. will know to invalidate their caches.
@ -145,26 +167,6 @@ class CacheInvalidationStore(SQLBaseStore):
}, },
) )
def get_all_updated_caches(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_updated_caches_txn(txn):
# We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
sql = (
"SELECT stream_id, cache_func, keys, invalidation_ts"
" FROM cache_invalidation_stream"
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, limit))
return txn.fetchall()
return self.db.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn
)
def get_cache_stream_token(self): def get_cache_stream_token(self):
if self._cache_id_gen: if self._cache_id_gen:
return self._cache_id_gen.get_current_token() return self._cache_id_gen.get_current_token()

View File

@ -207,6 +207,50 @@ class DeviceInboxWorkerStore(SQLBaseStore):
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
) )
def get_all_new_device_messages(self, last_pos, current_pos, limit):
"""
Args:
last_pos(int):
current_pos(int):
limit(int):
Returns:
A deferred list of rows from the device inbox
"""
if last_pos == current_pos:
return defer.succeed([])
def get_all_new_device_messages_txn(txn):
# We limit like this as we might have multiple rows per stream_id, and
# we want to make sure we always get all entries for any stream_id
# we return.
upper_pos = min(current_pos, last_pos + limit)
sql = (
"SELECT max(stream_id), user_id"
" FROM device_inbox"
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY user_id"
)
txn.execute(sql, (last_pos, upper_pos))
rows = txn.fetchall()
sql = (
"SELECT max(stream_id), destination"
" FROM device_federation_outbox"
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY destination"
)
txn.execute(sql, (last_pos, upper_pos))
rows.extend(txn)
# Order by ascending stream ordering
rows.sort()
return rows
return self.db.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn
)
class DeviceInboxBackgroundUpdateStore(SQLBaseStore): class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
@ -411,47 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
rows.append((user_id, device_id, stream_id, message_json)) rows.append((user_id, device_id, stream_id, message_json))
txn.executemany(sql, rows) txn.executemany(sql, rows)
def get_all_new_device_messages(self, last_pos, current_pos, limit):
"""
Args:
last_pos(int):
current_pos(int):
limit(int):
Returns:
A deferred list of rows from the device inbox
"""
if last_pos == current_pos:
return defer.succeed([])
def get_all_new_device_messages_txn(txn):
# We limit like this as we might have multiple rows per stream_id, and
# we want to make sure we always get all entries for any stream_id
# we return.
upper_pos = min(current_pos, last_pos + limit)
sql = (
"SELECT max(stream_id), user_id"
" FROM device_inbox"
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY user_id"
)
txn.execute(sql, (last_pos, upper_pos))
rows = txn.fetchall()
sql = (
"SELECT max(stream_id), destination"
" FROM device_federation_outbox"
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY destination"
)
txn.execute(sql, (last_pos, upper_pos))
rows.extend(txn)
# Order by ascending stream ordering
rows.sort()
return rows
return self.db.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn
)

View File

@ -1267,104 +1267,6 @@ class EventsStore(
ret = yield self.db.runInteraction("count_daily_active_rooms", _count) ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
return ret return ret
def get_current_backfill_token(self):
"""The current minimum token that backfilled events have reached"""
return -self._backfill_id_gen.get_current_token()
def get_current_events_token(self):
"""The current maximum token that events have reached"""
return self._stream_id_gen.get_current_token()
def get_all_new_forward_event_rows(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_new_forward_event_rows(txn):
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? < stream_ordering AND stream_ordering <= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
new_event_updates = txn.fetchall()
if len(new_event_updates) == limit:
upper_bound = new_event_updates[-1][0]
else:
upper_bound = current_id
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (last_id, upper_bound))
new_event_updates.extend(txn)
return new_event_updates
return self.db.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
)
def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_new_backfill_event_rows(txn):
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > stream_ordering AND stream_ordering >= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
)
txn.execute(sql, (-last_id, -current_id, limit))
new_event_updates = txn.fetchall()
if len(new_event_updates) == limit:
upper_bound = new_event_updates[-1][0]
else:
upper_bound = current_id
sql = (
"SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (-last_id, -upper_bound))
new_event_updates.extend(txn.fetchall())
return new_event_updates
return self.db.runInteraction(
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
)
@cached(num_args=5, max_entries=10) @cached(num_args=5, max_entries=10)
def get_all_new_events( def get_all_new_events(
self, self,
@ -1850,22 +1752,6 @@ class EventsStore(
return (int(res["topological_ordering"]), int(res["stream_ordering"])) return (int(res["topological_ordering"]), int(res["stream_ordering"]))
def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
def get_all_updated_current_state_deltas_txn(txn):
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC LIMIT ?
"""
txn.execute(sql, (from_token, to_token, limit))
return txn.fetchall()
return self.db.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
)
def insert_labels_for_event_txn( def insert_labels_for_event_txn(
self, txn, event_id, labels, room_id, topological_ordering self, txn, event_id, labels, room_id, topological_ordering
): ):

View File

@ -963,3 +963,117 @@ class EventsWorkerStore(SQLBaseStore):
complexity_v1 = round(state_events / 500, 2) complexity_v1 = round(state_events / 500, 2)
return {"v1": complexity_v1} return {"v1": complexity_v1}
def get_current_backfill_token(self):
"""The current minimum token that backfilled events have reached"""
return -self._backfill_id_gen.get_current_token()
def get_current_events_token(self):
"""The current maximum token that events have reached"""
return self._stream_id_gen.get_current_token()
def get_all_new_forward_event_rows(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_new_forward_event_rows(txn):
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? < stream_ordering AND stream_ordering <= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
new_event_updates = txn.fetchall()
if len(new_event_updates) == limit:
upper_bound = new_event_updates[-1][0]
else:
upper_bound = current_id
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (last_id, upper_bound))
new_event_updates.extend(txn)
return new_event_updates
return self.db.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
)
def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_new_backfill_event_rows(txn):
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > stream_ordering AND stream_ordering >= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
)
txn.execute(sql, (-last_id, -current_id, limit))
new_event_updates = txn.fetchall()
if len(new_event_updates) == limit:
upper_bound = new_event_updates[-1][0]
else:
upper_bound = current_id
sql = (
"SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (-last_id, -upper_bound))
new_event_updates.extend(txn.fetchall())
return new_event_updates
return self.db.runInteraction(
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
)
def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
def get_all_updated_current_state_deltas_txn(txn):
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC LIMIT ?
"""
txn.execute(sql, (from_token, to_token, limit))
return txn.fetchall()
return self.db.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
)

View File

@ -732,6 +732,26 @@ class RoomWorkerStore(SQLBaseStore):
return total_media_quarantined return total_media_quarantined
def get_all_new_public_rooms(self, prev_id, current_id, limit):
def get_all_new_public_rooms(txn):
sql = """
SELECT stream_id, room_id, visibility, appservice_id, network_id
FROM public_room_list_stream
WHERE stream_id > ? AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (prev_id, current_id, limit))
return txn.fetchall()
if prev_id == current_id:
return defer.succeed([])
return self.db.runInteraction(
"get_all_new_public_rooms", get_all_new_public_rooms
)
class RoomBackgroundUpdateStore(SQLBaseStore): class RoomBackgroundUpdateStore(SQLBaseStore):
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@ -1249,26 +1269,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def get_current_public_room_stream_id(self): def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token() return self._public_room_id_gen.get_current_token()
def get_all_new_public_rooms(self, prev_id, current_id, limit):
def get_all_new_public_rooms(txn):
sql = """
SELECT stream_id, room_id, visibility, appservice_id, network_id
FROM public_room_list_stream
WHERE stream_id > ? AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (prev_id, current_id, limit))
return txn.fetchall()
if prev_id == current_id:
return defer.succeed([])
return self.db.runInteraction(
"get_all_new_public_rooms", get_all_new_public_rooms
)
@defer.inlineCallbacks @defer.inlineCallbacks
def block_room(self, room_id, user_id): def block_room(self, room_id, user_id):
"""Marks the room as blocked. Can be called multiple times. """Marks the room as blocked. Can be called multiple times.

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from mock import Mock from mock import Mock
from synapse.replication.tcp.commands import ReplicateCommand from synapse.replication.tcp.commands import ReplicateCommand
@ -29,19 +30,37 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# build a replication server # build a replication server
server_factory = ReplicationStreamProtocolFactory(self.hs) server_factory = ReplicationStreamProtocolFactory(self.hs)
self.streamer = server_factory.streamer self.streamer = server_factory.streamer
server = server_factory.buildProtocol(None) self.server = server_factory.buildProtocol(None)
# build a replication client, with a dummy handler self.test_handler = Mock(wraps=TestReplicationClientHandler())
handler_factory = Mock()
self.test_handler = TestReplicationClientHandler()
self.test_handler.factory = handler_factory
self.client = ClientReplicationStreamProtocol( self.client = ClientReplicationStreamProtocol(
"client", "test", clock, self.test_handler hs, "client", "test", clock, self.test_handler,
) )
# wire them together self._client_transport = None
self.client.makeConnection(FakeTransport(server, reactor)) self._server_transport = None
server.makeConnection(FakeTransport(self.client, reactor))
def reconnect(self):
if self._client_transport:
self.client.close()
if self._server_transport:
self.server.close()
self._client_transport = FakeTransport(self.server, self.reactor)
self.client.makeConnection(self._client_transport)
self._server_transport = FakeTransport(self.client, self.reactor)
self.server.makeConnection(self._server_transport)
def disconnect(self):
if self._client_transport:
self._client_transport = None
self.client.close()
if self._server_transport:
self._server_transport = None
self.server.close()
def replicate(self): def replicate(self):
"""Tell the master side of replication that something has happened, and then """Tell the master side of replication that something has happened, and then
@ -50,19 +69,24 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.streamer.on_notifier_poke() self.streamer.on_notifier_poke()
self.pump(0.1) self.pump(0.1)
def replicate_stream(self, stream, token="NOW"): def replicate_stream(self):
"""Make the client end a REPLICATE command to set up a subscription to a stream""" """Make the client end a REPLICATE command to set up a subscription to a stream"""
self.client.send_command(ReplicateCommand(stream, token)) self.client.send_command(ReplicateCommand())
class TestReplicationClientHandler(object): class TestReplicationClientHandler(object):
"""Drop-in for ReplicationClientHandler which just collects RDATA rows""" """Drop-in for ReplicationClientHandler which just collects RDATA rows"""
def __init__(self): def __init__(self):
self.received_rdata_rows = [] self.streams = set()
self._received_rdata_rows = []
def get_streams_to_replicate(self): def get_streams_to_replicate(self):
return {} positions = {s: 0 for s in self.streams}
for stream, token, _ in self._received_rdata_rows:
if stream in self.streams:
positions[stream] = max(token, positions.get(stream, 0))
return positions
def get_currently_syncing_users(self): def get_currently_syncing_users(self):
return [] return []
@ -73,6 +97,9 @@ class TestReplicationClientHandler(object):
def finished_connecting(self): def finished_connecting(self):
pass pass
async def on_position(self, stream_name, token):
"""Called when we get new position data."""
async def on_rdata(self, stream_name, token, rows): async def on_rdata(self, stream_name, token, rows):
for r in rows: for r in rows:
self.received_rdata_rows.append((stream_name, token, r)) self._received_rdata_rows.append((stream_name, token, r))

View File

@ -17,30 +17,64 @@ from synapse.replication.tcp.streams._base import ReceiptsStream
from tests.replication.tcp.streams._base import BaseStreamTestCase from tests.replication.tcp.streams._base import BaseStreamTestCase
USER_ID = "@feeling:blue" USER_ID = "@feeling:blue"
ROOM_ID = "!room:blue"
EVENT_ID = "$event:blue"
class ReceiptsStreamTestCase(BaseStreamTestCase): class ReceiptsStreamTestCase(BaseStreamTestCase):
def test_receipt(self): def test_receipt(self):
self.reconnect()
# make the client subscribe to the receipts stream # make the client subscribe to the receipts stream
self.replicate_stream("receipts", "NOW") self.replicate_stream()
self.test_handler.streams.add("receipts")
# tell the master to send a new receipt # tell the master to send a new receipt
self.get_success( self.get_success(
self.hs.get_datastore().insert_receipt( self.hs.get_datastore().insert_receipt(
ROOM_ID, "m.read", USER_ID, [EVENT_ID], {"a": 1} "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1}
) )
) )
self.replicate() self.replicate()
# there should be one RDATA command # there should be one RDATA command
rdata_rows = self.test_handler.received_rdata_rows self.test_handler.on_rdata.assert_called_once()
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "receipts")
self.assertEqual(1, len(rdata_rows)) self.assertEqual(1, len(rdata_rows))
self.assertEqual(rdata_rows[0][0], "receipts") row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
row = rdata_rows[0][2] # type: ReceiptsStream.ReceiptsStreamRow self.assertEqual("!room:blue", row.room_id)
self.assertEqual(ROOM_ID, row.room_id)
self.assertEqual("m.read", row.receipt_type) self.assertEqual("m.read", row.receipt_type)
self.assertEqual(USER_ID, row.user_id) self.assertEqual(USER_ID, row.user_id)
self.assertEqual(EVENT_ID, row.event_id) self.assertEqual("$event:blue", row.event_id)
self.assertEqual({"a": 1}, row.data) self.assertEqual({"a": 1}, row.data)
# Now let's disconnect and insert some data.
self.disconnect()
self.test_handler.on_rdata.reset_mock()
self.get_success(
self.hs.get_datastore().insert_receipt(
"!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2}
)
)
self.replicate()
# Nothing should have happened as we are disconnected
self.test_handler.on_rdata.assert_not_called()
self.reconnect()
self.pump(0.1)
# We should now have caught up and get the missing data
self.test_handler.on_rdata.assert_called_once()
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "receipts")
self.assertEqual(token, 3)
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
self.assertEqual("!room2:blue", row.room_id)
self.assertEqual("m.read", row.receipt_type)
self.assertEqual(USER_ID, row.user_id)
self.assertEqual("$event2:foo", row.event_id)
self.assertEqual({"a": 2}, row.data)