Support any process writing to cache invalidation stream. (#7436)

This commit is contained in:
Erik Johnston 2020-05-07 13:51:08 +01:00 committed by GitHub
parent 2929ce29d6
commit d7983b63a6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 225 additions and 230 deletions

View file

@ -100,10 +100,10 @@ class ReplicationDataHandler:
token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
"""
self.store.process_replication_rows(stream_name, token, rows)
self.store.process_replication_rows(stream_name, instance_name, token, rows)
async def on_position(self, stream_name: str, token: int):
self.store.process_replication_rows(stream_name, token, [])
async def on_position(self, stream_name: str, instance_name: str, token: int):
self.store.process_replication_rows(stream_name, instance_name, token, [])
def on_remote_server_up(self, server: str):
"""Called when get a new REMOTE_SERVER_UP command."""

View file

@ -341,37 +341,6 @@ class RemovePusherCommand(Command):
return " ".join((self.app_id, self.push_key, self.user_id))
class InvalidateCacheCommand(Command):
"""Sent by the client to invalidate an upstream cache.
THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED ACCEPT FOR THINGS THAT ARE
NOT DISASTROUS IF WE DROP ON THE FLOOR.
Mainly used to invalidate destination retry timing caches.
Format::
INVALIDATE_CACHE <cache_func> <keys_json>
Where <keys_json> is a json list.
"""
NAME = "INVALIDATE_CACHE"
def __init__(self, cache_func, keys):
self.cache_func = cache_func
self.keys = keys
@classmethod
def from_line(cls, line):
cache_func, keys_json = line.split(" ", 1)
return cls(cache_func, json.loads(keys_json))
def to_line(self):
return " ".join((self.cache_func, _json_encoder.encode(self.keys)))
class UserIpCommand(Command):
"""Sent periodically when a worker sees activity from a client.
@ -439,7 +408,6 @@ _COMMANDS = (
UserSyncCommand,
FederationAckCommand,
RemovePusherCommand,
InvalidateCacheCommand,
UserIpCommand,
RemoteServerUpCommand,
ClearUserSyncsCommand,
@ -467,7 +435,6 @@ VALID_CLIENT_COMMANDS = (
ClearUserSyncsCommand.NAME,
FederationAckCommand.NAME,
RemovePusherCommand.NAME,
InvalidateCacheCommand.NAME,
UserIpCommand.NAME,
ErrorCommand.NAME,
RemoteServerUpCommand.NAME,

View file

@ -15,18 +15,7 @@
# limitations under the License.
import logging
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Set,
Tuple,
TypeVar,
)
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
from prometheus_client import Counter
@ -38,7 +27,6 @@ from synapse.replication.tcp.commands import (
ClearUserSyncsCommand,
Command,
FederationAckCommand,
InvalidateCacheCommand,
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
@ -171,7 +159,7 @@ class ReplicationCommandHandler:
return
for stream_name, stream in self._streams.items():
current_token = stream.current_token()
current_token = stream.current_token(self._instance_name)
self.send_command(
PositionCommand(stream_name, self._instance_name, current_token)
)
@ -210,18 +198,6 @@ class ReplicationCommandHandler:
self._notifier.on_new_replication_data()
async def on_INVALIDATE_CACHE(
self, conn: AbstractConnection, cmd: InvalidateCacheCommand
):
invalidate_cache_counter.inc()
if self._is_master:
# We invalidate the cache locally, but then also stream that to other
# workers.
await self._store.invalidate_cache_and_stream(
cmd.cache_func, tuple(cmd.keys)
)
async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
user_ip_cache_counter.inc()
@ -295,7 +271,7 @@ class ReplicationCommandHandler:
rows: a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
logger.debug("Received rdata %s -> %s", stream_name, token)
logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token)
await self._replication_data_handler.on_rdata(
stream_name, instance_name, token, rows
)
@ -326,7 +302,7 @@ class ReplicationCommandHandler:
self._pending_batches.pop(stream_name, [])
# Find where we previously streamed up to.
current_token = stream.current_token()
current_token = stream.current_token(cmd.instance_name)
# If the position token matches our current token then we're up to
# date and there's nothing to do. Otherwise, fetch all updates
@ -363,7 +339,9 @@ class ReplicationCommandHandler:
logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
# We've now caught up to position sent to us, notify handler.
await self._replication_data_handler.on_position(stream_name, cmd.token)
await self._replication_data_handler.on_position(
cmd.stream_name, cmd.instance_name, cmd.token
)
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
@ -491,12 +469,6 @@ class ReplicationCommandHandler:
cmd = RemovePusherCommand(app_id, push_key, user_id)
self.send_command(cmd)
def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
"""Poke the master to invalidate a cache.
"""
cmd = InvalidateCacheCommand(cache_func.__name__, keys)
self.send_command(cmd)
def send_user_ip(
self,
user_id: str,

View file

@ -25,7 +25,12 @@ from twisted.internet.protocol import Factory
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream
from synapse.replication.tcp.streams import (
STREAMS_MAP,
CachesStream,
FederationStream,
Stream,
)
from synapse.util.metrics import Measure
stream_updates_counter = Counter(
@ -71,11 +76,16 @@ class ReplicationStreamer(object):
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
self._instance_name = hs.get_instance_name()
self._replication_torture_level = hs.config.replication_torture_level
# Work out list of streams that this instance is the source of.
self.streams = [] # type: List[Stream]
# All workers can write to the cache invalidation stream.
self.streams.append(CachesStream(hs))
if hs.config.worker_app is None:
for stream in STREAMS_MAP.values():
if stream == FederationStream and hs.config.send_federation:
@ -83,6 +93,10 @@ class ReplicationStreamer(object):
# has been disabled on the master.
continue
if stream == CachesStream:
# We've already added it above.
continue
self.streams.append(stream(hs))
self.streams_by_name = {stream.NAME: stream for stream in self.streams}
@ -145,7 +159,9 @@ class ReplicationStreamer(object):
random.shuffle(all_streams)
for stream in all_streams:
if stream.last_token == stream.current_token():
if stream.last_token == stream.current_token(
self._instance_name
):
continue
if self._replication_torture_level:
@ -157,7 +173,7 @@ class ReplicationStreamer(object):
"Getting stream: %s: %s -> %s",
stream.NAME,
stream.last_token,
stream.current_token(),
stream.current_token(self._instance_name),
)
try:
updates, current_token, limited = await stream.get_updates()

View file

@ -95,20 +95,25 @@ class Stream(object):
def __init__(
self,
local_instance_name: str,
current_token_function: Callable[[], Token],
current_token_function: Callable[[str], Token],
update_function: UpdateFunction,
):
"""Instantiate a Stream
current_token_function and update_function are callbacks which should be
implemented by subclasses.
`current_token_function` and `update_function` are callbacks which
should be implemented by subclasses.
current_token_function is called to get the current token of the underlying
stream. It is only meaningful on the process that is the source of the
replication stream (ie, usually the master).
`current_token_function` takes an instance name, which is a writer to
the stream, and returns the position in the stream of the writer (as
viewed from the current process). On the writer process this is where
the writer has successfully written up to, whereas on other processes
this is the position which we have received updates up to over
replication. (Note that most streams have a single writer and so their
implementations ignore the instance name passed in).
update_function is called to get updates for this stream between a pair of
stream tokens. See the UpdateFunction type definition for more info.
`update_function` is called to get updates for this stream between a
pair of stream tokens. See the `UpdateFunction` type definition for more
info.
Args:
local_instance_name: The instance name of the current process
@ -120,13 +125,13 @@ class Stream(object):
self.update_function = update_function
# The token from which we last asked for updates
self.last_token = self.current_token()
self.last_token = self.current_token(self.local_instance_name)
def discard_updates_and_advance(self):
"""Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers.
"""
self.last_token = self.current_token()
self.last_token = self.current_token(self.local_instance_name)
async def get_updates(self) -> StreamUpdateResult:
"""Gets all updates since the last time this function was called (or
@ -138,7 +143,7 @@ class Stream(object):
position in stream, and `limited` is whether there are more updates
to fetch.
"""
current_token = self.current_token()
current_token = self.current_token(self.local_instance_name)
updates, current_token, limited = await self.get_updates_since(
self.local_instance_name, self.last_token, current_token
)
@ -170,6 +175,16 @@ class Stream(object):
return updates, upto_token, limited
def current_token_without_instance(
current_token: Callable[[], int]
) -> Callable[[str], int]:
"""Takes a current token callback function for a single writer stream
that doesn't take an instance name parameter and wraps it in a function that
does accept an instance name parameter but ignores it.
"""
return lambda instance_name: current_token()
def db_query_to_update_function(
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> UpdateFunction:
@ -235,7 +250,7 @@ class BackfillStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_current_backfill_token,
current_token_without_instance(store.get_current_backfill_token),
db_query_to_update_function(store.get_all_new_backfill_event_rows),
)
@ -271,7 +286,9 @@ class PresenceStream(Stream):
update_function = make_http_update_function(hs, self.NAME)
super().__init__(
hs.get_instance_name(), store.get_current_presence_token, update_function
hs.get_instance_name(),
current_token_without_instance(store.get_current_presence_token),
update_function,
)
@ -296,7 +313,9 @@ class TypingStream(Stream):
update_function = make_http_update_function(hs, self.NAME)
super().__init__(
hs.get_instance_name(), typing_handler.get_current_token, update_function
hs.get_instance_name(),
current_token_without_instance(typing_handler.get_current_token),
update_function,
)
@ -319,7 +338,7 @@ class ReceiptsStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_max_receipt_stream_id,
current_token_without_instance(store.get_max_receipt_stream_id),
db_query_to_update_function(store.get_all_updated_receipts),
)
@ -339,7 +358,7 @@ class PushRulesStream(Stream):
hs.get_instance_name(), self._current_token, self._update_function
)
def _current_token(self) -> int:
def _current_token(self, instance_name: str) -> int:
push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token
@ -373,7 +392,7 @@ class PushersStream(Stream):
super().__init__(
hs.get_instance_name(),
store.get_pushers_stream_token,
current_token_without_instance(store.get_pushers_stream_token),
db_query_to_update_function(store.get_all_updated_pushers_rows),
)
@ -402,13 +421,27 @@ class CachesStream(Stream):
ROW_TYPE = CachesStreamRow
def __init__(self, hs):
store = hs.get_datastore()
self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_cache_stream_token,
db_query_to_update_function(store.get_all_updated_caches),
self.store.get_cache_stream_token,
self._update_function,
)
async def _update_function(
self, instance_name: str, from_token: int, upto_token: int, limit: int
):
rows = await self.store.get_all_updated_caches(
instance_name, from_token, upto_token, limit
)
updates = [(row[0], row[1:]) for row in rows]
limited = False
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
return updates, upto_token, limited
class PublicRoomsStream(Stream):
"""The public rooms list changed
@ -431,7 +464,7 @@ class PublicRoomsStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_current_public_room_stream_id,
current_token_without_instance(store.get_current_public_room_stream_id),
db_query_to_update_function(store.get_all_new_public_rooms),
)
@ -452,7 +485,7 @@ class DeviceListsStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_device_stream_token,
current_token_without_instance(store.get_device_stream_token),
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
)
@ -470,7 +503,7 @@ class ToDeviceStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_to_device_stream_token,
current_token_without_instance(store.get_to_device_stream_token),
db_query_to_update_function(store.get_all_new_device_messages),
)
@ -490,7 +523,7 @@ class TagAccountDataStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_max_account_data_stream_id,
current_token_without_instance(store.get_max_account_data_stream_id),
db_query_to_update_function(store.get_all_updated_tags),
)
@ -510,7 +543,7 @@ class AccountDataStream(Stream):
self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
self.store.get_max_account_data_stream_id,
current_token_without_instance(self.store.get_max_account_data_stream_id),
db_query_to_update_function(self._update_function),
)
@ -541,7 +574,7 @@ class GroupServerStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_group_stream_token,
current_token_without_instance(store.get_group_stream_token),
db_query_to_update_function(store.get_all_groups_changes),
)
@ -559,7 +592,7 @@ class UserSignatureStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_device_stream_token,
current_token_without_instance(store.get_device_stream_token),
db_query_to_update_function(
store.get_all_user_signature_changes_for_remotes
),

View file

@ -20,7 +20,7 @@ from typing import List, Tuple, Type
import attr
from ._base import Stream, StreamUpdateResult, Token
from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
"""Handling of the 'events' replication stream
@ -119,7 +119,7 @@ class EventsStream(Stream):
self._store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
self._store.get_current_events_token,
current_token_without_instance(self._store.get_current_events_token),
self._update_function,
)

View file

@ -15,7 +15,11 @@
# limitations under the License.
from collections import namedtuple
from synapse.replication.tcp.streams._base import Stream, make_http_update_function
from synapse.replication.tcp.streams._base import (
Stream,
current_token_without_instance,
make_http_update_function,
)
class FederationStream(Stream):
@ -41,7 +45,9 @@ class FederationStream(Stream):
# will be a real FederationSender, which has stubs for current_token and
# get_replication_rows.)
federation_sender = hs.get_federation_sender()
current_token = federation_sender.get_current_token
current_token = current_token_without_instance(
federation_sender.get_current_token
)
update_function = federation_sender.get_replication_rows
elif hs.should_send_federation():
@ -58,7 +64,7 @@ class FederationStream(Stream):
super().__init__(hs.get_instance_name(), current_token, update_function)
@staticmethod
def _stub_current_token():
def _stub_current_token(instance_name: str) -> int:
# dummy current-token method for use on workers
return 0