mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-05-02 16:44:49 -04:00
Support any process writing to cache invalidation stream. (#7436)
This commit is contained in:
parent
2929ce29d6
commit
d7983b63a6
26 changed files with 225 additions and 230 deletions
|
@ -100,10 +100,10 @@ class ReplicationDataHandler:
|
|||
token: stream token for this batch of rows
|
||||
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
|
||||
"""
|
||||
self.store.process_replication_rows(stream_name, token, rows)
|
||||
self.store.process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
async def on_position(self, stream_name: str, token: int):
|
||||
self.store.process_replication_rows(stream_name, token, [])
|
||||
async def on_position(self, stream_name: str, instance_name: str, token: int):
|
||||
self.store.process_replication_rows(stream_name, instance_name, token, [])
|
||||
|
||||
def on_remote_server_up(self, server: str):
|
||||
"""Called when get a new REMOTE_SERVER_UP command."""
|
||||
|
|
|
@ -341,37 +341,6 @@ class RemovePusherCommand(Command):
|
|||
return " ".join((self.app_id, self.push_key, self.user_id))
|
||||
|
||||
|
||||
class InvalidateCacheCommand(Command):
|
||||
"""Sent by the client to invalidate an upstream cache.
|
||||
|
||||
THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED ACCEPT FOR THINGS THAT ARE
|
||||
NOT DISASTROUS IF WE DROP ON THE FLOOR.
|
||||
|
||||
Mainly used to invalidate destination retry timing caches.
|
||||
|
||||
Format::
|
||||
|
||||
INVALIDATE_CACHE <cache_func> <keys_json>
|
||||
|
||||
Where <keys_json> is a json list.
|
||||
"""
|
||||
|
||||
NAME = "INVALIDATE_CACHE"
|
||||
|
||||
def __init__(self, cache_func, keys):
|
||||
self.cache_func = cache_func
|
||||
self.keys = keys
|
||||
|
||||
@classmethod
|
||||
def from_line(cls, line):
|
||||
cache_func, keys_json = line.split(" ", 1)
|
||||
|
||||
return cls(cache_func, json.loads(keys_json))
|
||||
|
||||
def to_line(self):
|
||||
return " ".join((self.cache_func, _json_encoder.encode(self.keys)))
|
||||
|
||||
|
||||
class UserIpCommand(Command):
|
||||
"""Sent periodically when a worker sees activity from a client.
|
||||
|
||||
|
@ -439,7 +408,6 @@ _COMMANDS = (
|
|||
UserSyncCommand,
|
||||
FederationAckCommand,
|
||||
RemovePusherCommand,
|
||||
InvalidateCacheCommand,
|
||||
UserIpCommand,
|
||||
RemoteServerUpCommand,
|
||||
ClearUserSyncsCommand,
|
||||
|
@ -467,7 +435,6 @@ VALID_CLIENT_COMMANDS = (
|
|||
ClearUserSyncsCommand.NAME,
|
||||
FederationAckCommand.NAME,
|
||||
RemovePusherCommand.NAME,
|
||||
InvalidateCacheCommand.NAME,
|
||||
UserIpCommand.NAME,
|
||||
ErrorCommand.NAME,
|
||||
RemoteServerUpCommand.NAME,
|
||||
|
|
|
@ -15,18 +15,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
)
|
||||
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
|
||||
|
||||
from prometheus_client import Counter
|
||||
|
||||
|
@ -38,7 +27,6 @@ from synapse.replication.tcp.commands import (
|
|||
ClearUserSyncsCommand,
|
||||
Command,
|
||||
FederationAckCommand,
|
||||
InvalidateCacheCommand,
|
||||
PositionCommand,
|
||||
RdataCommand,
|
||||
RemoteServerUpCommand,
|
||||
|
@ -171,7 +159,7 @@ class ReplicationCommandHandler:
|
|||
return
|
||||
|
||||
for stream_name, stream in self._streams.items():
|
||||
current_token = stream.current_token()
|
||||
current_token = stream.current_token(self._instance_name)
|
||||
self.send_command(
|
||||
PositionCommand(stream_name, self._instance_name, current_token)
|
||||
)
|
||||
|
@ -210,18 +198,6 @@ class ReplicationCommandHandler:
|
|||
|
||||
self._notifier.on_new_replication_data()
|
||||
|
||||
async def on_INVALIDATE_CACHE(
|
||||
self, conn: AbstractConnection, cmd: InvalidateCacheCommand
|
||||
):
|
||||
invalidate_cache_counter.inc()
|
||||
|
||||
if self._is_master:
|
||||
# We invalidate the cache locally, but then also stream that to other
|
||||
# workers.
|
||||
await self._store.invalidate_cache_and_stream(
|
||||
cmd.cache_func, tuple(cmd.keys)
|
||||
)
|
||||
|
||||
async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
|
||||
user_ip_cache_counter.inc()
|
||||
|
||||
|
@ -295,7 +271,7 @@ class ReplicationCommandHandler:
|
|||
rows: a list of Stream.ROW_TYPE objects as returned by
|
||||
Stream.parse_row.
|
||||
"""
|
||||
logger.debug("Received rdata %s -> %s", stream_name, token)
|
||||
logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token)
|
||||
await self._replication_data_handler.on_rdata(
|
||||
stream_name, instance_name, token, rows
|
||||
)
|
||||
|
@ -326,7 +302,7 @@ class ReplicationCommandHandler:
|
|||
self._pending_batches.pop(stream_name, [])
|
||||
|
||||
# Find where we previously streamed up to.
|
||||
current_token = stream.current_token()
|
||||
current_token = stream.current_token(cmd.instance_name)
|
||||
|
||||
# If the position token matches our current token then we're up to
|
||||
# date and there's nothing to do. Otherwise, fetch all updates
|
||||
|
@ -363,7 +339,9 @@ class ReplicationCommandHandler:
|
|||
logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
|
||||
|
||||
# We've now caught up to position sent to us, notify handler.
|
||||
await self._replication_data_handler.on_position(stream_name, cmd.token)
|
||||
await self._replication_data_handler.on_position(
|
||||
cmd.stream_name, cmd.instance_name, cmd.token
|
||||
)
|
||||
|
||||
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
|
||||
|
||||
|
@ -491,12 +469,6 @@ class ReplicationCommandHandler:
|
|||
cmd = RemovePusherCommand(app_id, push_key, user_id)
|
||||
self.send_command(cmd)
|
||||
|
||||
def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
|
||||
"""Poke the master to invalidate a cache.
|
||||
"""
|
||||
cmd = InvalidateCacheCommand(cache_func.__name__, keys)
|
||||
self.send_command(cmd)
|
||||
|
||||
def send_user_ip(
|
||||
self,
|
||||
user_id: str,
|
||||
|
|
|
@ -25,7 +25,12 @@ from twisted.internet.protocol import Factory
|
|||
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
|
||||
from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream
|
||||
from synapse.replication.tcp.streams import (
|
||||
STREAMS_MAP,
|
||||
CachesStream,
|
||||
FederationStream,
|
||||
Stream,
|
||||
)
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
stream_updates_counter = Counter(
|
||||
|
@ -71,11 +76,16 @@ class ReplicationStreamer(object):
|
|||
self.store = hs.get_datastore()
|
||||
self.clock = hs.get_clock()
|
||||
self.notifier = hs.get_notifier()
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
||||
self._replication_torture_level = hs.config.replication_torture_level
|
||||
|
||||
# Work out list of streams that this instance is the source of.
|
||||
self.streams = [] # type: List[Stream]
|
||||
|
||||
# All workers can write to the cache invalidation stream.
|
||||
self.streams.append(CachesStream(hs))
|
||||
|
||||
if hs.config.worker_app is None:
|
||||
for stream in STREAMS_MAP.values():
|
||||
if stream == FederationStream and hs.config.send_federation:
|
||||
|
@ -83,6 +93,10 @@ class ReplicationStreamer(object):
|
|||
# has been disabled on the master.
|
||||
continue
|
||||
|
||||
if stream == CachesStream:
|
||||
# We've already added it above.
|
||||
continue
|
||||
|
||||
self.streams.append(stream(hs))
|
||||
|
||||
self.streams_by_name = {stream.NAME: stream for stream in self.streams}
|
||||
|
@ -145,7 +159,9 @@ class ReplicationStreamer(object):
|
|||
random.shuffle(all_streams)
|
||||
|
||||
for stream in all_streams:
|
||||
if stream.last_token == stream.current_token():
|
||||
if stream.last_token == stream.current_token(
|
||||
self._instance_name
|
||||
):
|
||||
continue
|
||||
|
||||
if self._replication_torture_level:
|
||||
|
@ -157,7 +173,7 @@ class ReplicationStreamer(object):
|
|||
"Getting stream: %s: %s -> %s",
|
||||
stream.NAME,
|
||||
stream.last_token,
|
||||
stream.current_token(),
|
||||
stream.current_token(self._instance_name),
|
||||
)
|
||||
try:
|
||||
updates, current_token, limited = await stream.get_updates()
|
||||
|
|
|
@ -95,20 +95,25 @@ class Stream(object):
|
|||
def __init__(
|
||||
self,
|
||||
local_instance_name: str,
|
||||
current_token_function: Callable[[], Token],
|
||||
current_token_function: Callable[[str], Token],
|
||||
update_function: UpdateFunction,
|
||||
):
|
||||
"""Instantiate a Stream
|
||||
|
||||
current_token_function and update_function are callbacks which should be
|
||||
implemented by subclasses.
|
||||
`current_token_function` and `update_function` are callbacks which
|
||||
should be implemented by subclasses.
|
||||
|
||||
current_token_function is called to get the current token of the underlying
|
||||
stream. It is only meaningful on the process that is the source of the
|
||||
replication stream (ie, usually the master).
|
||||
`current_token_function` takes an instance name, which is a writer to
|
||||
the stream, and returns the position in the stream of the writer (as
|
||||
viewed from the current process). On the writer process this is where
|
||||
the writer has successfully written up to, whereas on other processes
|
||||
this is the position which we have received updates up to over
|
||||
replication. (Note that most streams have a single writer and so their
|
||||
implementations ignore the instance name passed in).
|
||||
|
||||
update_function is called to get updates for this stream between a pair of
|
||||
stream tokens. See the UpdateFunction type definition for more info.
|
||||
`update_function` is called to get updates for this stream between a
|
||||
pair of stream tokens. See the `UpdateFunction` type definition for more
|
||||
info.
|
||||
|
||||
Args:
|
||||
local_instance_name: The instance name of the current process
|
||||
|
@ -120,13 +125,13 @@ class Stream(object):
|
|||
self.update_function = update_function
|
||||
|
||||
# The token from which we last asked for updates
|
||||
self.last_token = self.current_token()
|
||||
self.last_token = self.current_token(self.local_instance_name)
|
||||
|
||||
def discard_updates_and_advance(self):
|
||||
"""Called when the stream should advance but the updates would be discarded,
|
||||
e.g. when there are no currently connected workers.
|
||||
"""
|
||||
self.last_token = self.current_token()
|
||||
self.last_token = self.current_token(self.local_instance_name)
|
||||
|
||||
async def get_updates(self) -> StreamUpdateResult:
|
||||
"""Gets all updates since the last time this function was called (or
|
||||
|
@ -138,7 +143,7 @@ class Stream(object):
|
|||
position in stream, and `limited` is whether there are more updates
|
||||
to fetch.
|
||||
"""
|
||||
current_token = self.current_token()
|
||||
current_token = self.current_token(self.local_instance_name)
|
||||
updates, current_token, limited = await self.get_updates_since(
|
||||
self.local_instance_name, self.last_token, current_token
|
||||
)
|
||||
|
@ -170,6 +175,16 @@ class Stream(object):
|
|||
return updates, upto_token, limited
|
||||
|
||||
|
||||
def current_token_without_instance(
|
||||
current_token: Callable[[], int]
|
||||
) -> Callable[[str], int]:
|
||||
"""Takes a current token callback function for a single writer stream
|
||||
that doesn't take an instance name parameter and wraps it in a function that
|
||||
does accept an instance name parameter but ignores it.
|
||||
"""
|
||||
return lambda instance_name: current_token()
|
||||
|
||||
|
||||
def db_query_to_update_function(
|
||||
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
|
||||
) -> UpdateFunction:
|
||||
|
@ -235,7 +250,7 @@ class BackfillStream(Stream):
|
|||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_current_backfill_token,
|
||||
current_token_without_instance(store.get_current_backfill_token),
|
||||
db_query_to_update_function(store.get_all_new_backfill_event_rows),
|
||||
)
|
||||
|
||||
|
@ -271,7 +286,9 @@ class PresenceStream(Stream):
|
|||
update_function = make_http_update_function(hs, self.NAME)
|
||||
|
||||
super().__init__(
|
||||
hs.get_instance_name(), store.get_current_presence_token, update_function
|
||||
hs.get_instance_name(),
|
||||
current_token_without_instance(store.get_current_presence_token),
|
||||
update_function,
|
||||
)
|
||||
|
||||
|
||||
|
@ -296,7 +313,9 @@ class TypingStream(Stream):
|
|||
update_function = make_http_update_function(hs, self.NAME)
|
||||
|
||||
super().__init__(
|
||||
hs.get_instance_name(), typing_handler.get_current_token, update_function
|
||||
hs.get_instance_name(),
|
||||
current_token_without_instance(typing_handler.get_current_token),
|
||||
update_function,
|
||||
)
|
||||
|
||||
|
||||
|
@ -319,7 +338,7 @@ class ReceiptsStream(Stream):
|
|||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_max_receipt_stream_id,
|
||||
current_token_without_instance(store.get_max_receipt_stream_id),
|
||||
db_query_to_update_function(store.get_all_updated_receipts),
|
||||
)
|
||||
|
||||
|
@ -339,7 +358,7 @@ class PushRulesStream(Stream):
|
|||
hs.get_instance_name(), self._current_token, self._update_function
|
||||
)
|
||||
|
||||
def _current_token(self) -> int:
|
||||
def _current_token(self, instance_name: str) -> int:
|
||||
push_rules_token, _ = self.store.get_push_rules_stream_token()
|
||||
return push_rules_token
|
||||
|
||||
|
@ -373,7 +392,7 @@ class PushersStream(Stream):
|
|||
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_pushers_stream_token,
|
||||
current_token_without_instance(store.get_pushers_stream_token),
|
||||
db_query_to_update_function(store.get_all_updated_pushers_rows),
|
||||
)
|
||||
|
||||
|
@ -402,13 +421,27 @@ class CachesStream(Stream):
|
|||
ROW_TYPE = CachesStreamRow
|
||||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
self.store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_cache_stream_token,
|
||||
db_query_to_update_function(store.get_all_updated_caches),
|
||||
self.store.get_cache_stream_token,
|
||||
self._update_function,
|
||||
)
|
||||
|
||||
async def _update_function(
|
||||
self, instance_name: str, from_token: int, upto_token: int, limit: int
|
||||
):
|
||||
rows = await self.store.get_all_updated_caches(
|
||||
instance_name, from_token, upto_token, limit
|
||||
)
|
||||
updates = [(row[0], row[1:]) for row in rows]
|
||||
limited = False
|
||||
if len(updates) >= limit:
|
||||
upto_token = updates[-1][0]
|
||||
limited = True
|
||||
|
||||
return updates, upto_token, limited
|
||||
|
||||
|
||||
class PublicRoomsStream(Stream):
|
||||
"""The public rooms list changed
|
||||
|
@ -431,7 +464,7 @@ class PublicRoomsStream(Stream):
|
|||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_current_public_room_stream_id,
|
||||
current_token_without_instance(store.get_current_public_room_stream_id),
|
||||
db_query_to_update_function(store.get_all_new_public_rooms),
|
||||
)
|
||||
|
||||
|
@ -452,7 +485,7 @@ class DeviceListsStream(Stream):
|
|||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_device_stream_token,
|
||||
current_token_without_instance(store.get_device_stream_token),
|
||||
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
|
||||
)
|
||||
|
||||
|
@ -470,7 +503,7 @@ class ToDeviceStream(Stream):
|
|||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_to_device_stream_token,
|
||||
current_token_without_instance(store.get_to_device_stream_token),
|
||||
db_query_to_update_function(store.get_all_new_device_messages),
|
||||
)
|
||||
|
||||
|
@ -490,7 +523,7 @@ class TagAccountDataStream(Stream):
|
|||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_max_account_data_stream_id,
|
||||
current_token_without_instance(store.get_max_account_data_stream_id),
|
||||
db_query_to_update_function(store.get_all_updated_tags),
|
||||
)
|
||||
|
||||
|
@ -510,7 +543,7 @@ class AccountDataStream(Stream):
|
|||
self.store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
self.store.get_max_account_data_stream_id,
|
||||
current_token_without_instance(self.store.get_max_account_data_stream_id),
|
||||
db_query_to_update_function(self._update_function),
|
||||
)
|
||||
|
||||
|
@ -541,7 +574,7 @@ class GroupServerStream(Stream):
|
|||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_group_stream_token,
|
||||
current_token_without_instance(store.get_group_stream_token),
|
||||
db_query_to_update_function(store.get_all_groups_changes),
|
||||
)
|
||||
|
||||
|
@ -559,7 +592,7 @@ class UserSignatureStream(Stream):
|
|||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_device_stream_token,
|
||||
current_token_without_instance(store.get_device_stream_token),
|
||||
db_query_to_update_function(
|
||||
store.get_all_user_signature_changes_for_remotes
|
||||
),
|
||||
|
|
|
@ -20,7 +20,7 @@ from typing import List, Tuple, Type
|
|||
|
||||
import attr
|
||||
|
||||
from ._base import Stream, StreamUpdateResult, Token
|
||||
from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
|
||||
|
||||
|
||||
"""Handling of the 'events' replication stream
|
||||
|
@ -119,7 +119,7 @@ class EventsStream(Stream):
|
|||
self._store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
self._store.get_current_events_token,
|
||||
current_token_without_instance(self._store.get_current_events_token),
|
||||
self._update_function,
|
||||
)
|
||||
|
||||
|
|
|
@ -15,7 +15,11 @@
|
|||
# limitations under the License.
|
||||
from collections import namedtuple
|
||||
|
||||
from synapse.replication.tcp.streams._base import Stream, make_http_update_function
|
||||
from synapse.replication.tcp.streams._base import (
|
||||
Stream,
|
||||
current_token_without_instance,
|
||||
make_http_update_function,
|
||||
)
|
||||
|
||||
|
||||
class FederationStream(Stream):
|
||||
|
@ -41,7 +45,9 @@ class FederationStream(Stream):
|
|||
# will be a real FederationSender, which has stubs for current_token and
|
||||
# get_replication_rows.)
|
||||
federation_sender = hs.get_federation_sender()
|
||||
current_token = federation_sender.get_current_token
|
||||
current_token = current_token_without_instance(
|
||||
federation_sender.get_current_token
|
||||
)
|
||||
update_function = federation_sender.get_replication_rows
|
||||
|
||||
elif hs.should_send_federation():
|
||||
|
@ -58,7 +64,7 @@ class FederationStream(Stream):
|
|||
super().__init__(hs.get_instance_name(), current_token, update_function)
|
||||
|
||||
@staticmethod
|
||||
def _stub_current_token():
|
||||
def _stub_current_token(instance_name: str) -> int:
|
||||
# dummy current-token method for use on workers
|
||||
return 0
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue