Lay some foundation work to allow workers to only subscribe to some kinds of messages, reducing replication traffic. (#12672)

This commit is contained in:
reivilibre 2022-05-19 16:29:08 +01:00 committed by GitHub
parent eb4aaa1b4b
commit 177b884ad7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 173 additions and 24 deletions

View file

@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, List, Optional, Tuple
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple
from twisted.internet.address import IPv4Address
from twisted.internet.protocol import Protocol
@ -32,6 +33,7 @@ from synapse.server import HomeServer
from tests import unittest
from tests.server import FakeTransport
from tests.utils import USE_POSTGRES_FOR_TESTS
try:
import hiredis
@ -475,22 +477,25 @@ class FakeRedisPubSubServer:
"""A fake Redis server for pub/sub."""
def __init__(self):
    """Create a server with no subscribers on any channel."""
    # Maps a channel name to the set of connections subscribed to it.
    # Using a defaultdict means a channel's (initially empty) set is
    # created the first time that channel is looked up with `[...]`.
    self._subscribers_by_channel: Dict[
        bytes, Set["FakeRedisPubSubProtocol"]
    ] = defaultdict(set)
def add_subscriber(self, conn, channel: bytes):
    """A connection has called SUBSCRIBE.

    Args:
        conn: the connection that is subscribing.
        channel: the name of the channel it is subscribing to.
    """
    # Subscribers are stored in a set, so subscribing the same connection
    # to the same channel more than once has no additional effect.
    self._subscribers_by_channel[channel].add(conn)
def remove_subscriber(self, conn):
    """A connection has been lost: unsubscribe it from every channel.

    Args:
        conn: the connection to forget. It is not an error if the
            connection was not subscribed to anything.
    """
    for channel_subscribers in self._subscribers_by_channel.values():
        # `discard` (rather than `remove`) tolerates channels that this
        # connection never subscribed to.
        channel_subscribers.discard(conn)
def publish(self, conn, channel: bytes, msg) -> int:
    """A connection wants to publish a message to subscribers.

    Args:
        conn: the publishing connection (not used for routing).
        channel: the channel to publish on.
        msg: the payload, forwarded unchanged to each subscriber.

    Returns:
        The number of subscribers of `channel` that the message was sent to.
    """
    # Use `.get` so that publishing to a channel nobody has subscribed to
    # does not create an empty entry in the defaultdict as a side effect.
    recipients = self._subscribers_by_channel.get(channel, ())
    for sub in recipients:
        sub.send(["message", channel, msg])

    # Count the recipients on *this* channel, not the number of channels
    # that happen to exist.
    return len(recipients)
def buildProtocol(self, addr):
    """Return a new FakeRedisPubSubProtocol bound to this server.

    `addr` is accepted but not used.
    """
    protocol = FakeRedisPubSubProtocol(self)
    return protocol
@ -531,9 +536,10 @@ class FakeRedisPubSubProtocol(Protocol):
num_subscribers = self._server.publish(self, channel, message)
self.send(num_subscribers)
elif command == b"SUBSCRIBE":
(channel,) = args
self._server.add_subscriber(self)
self.send(["subscribe", channel, 1])
for idx, channel in enumerate(args):
num_channels = idx + 1
self._server.add_subscriber(self, channel)
self.send(["subscribe", channel, num_channels])
# Since we use SET/GET to cache things we can safely no-op them.
elif command == b"SET":
@ -576,3 +582,27 @@ class FakeRedisPubSubProtocol(Protocol):
def connectionLost(self, reason):
    """Tell the server to drop this connection from its subscribers.

    `reason` is accepted but not inspected.
    """
    server = self._server
    server.remove_subscriber(self)
class RedisMultiWorkerStreamTestCase(BaseMultiWorkerStreamTestCase):
    """
    Base class for multi-worker tests that run with Redis enabled, backed by
    a fake Redis server.
    """

    # Both checks assign the same `skip` attribute, so when neither
    # requirement is met the later (Postgres) message is the one kept.
    if not hiredis:
        skip = "Requires hiredis"

    if not USE_POSTGRES_FOR_TESTS:
        # Redis replication only takes place on Postgres
        skip = "Requires Postgres"

    def default_config(self) -> Dict[str, Any]:
        """
        Return the parent class's default config with Redis switched on.

        The main process needs Redis enabled even when a test only calls
        make_worker_hs: otherwise it won't create a Fake Redis server to
        listen on the Redis port and accept fake TCP connections.
        """
        config = super().default_config()
        config["redis"] = {"enabled": True}
        return config

View file

@ -0,0 +1,73 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests.replication._base import RedisMultiWorkerStreamTestCase
class ChannelsTestCase(RedisMultiWorkerStreamTestCase):
    """Checks which Redis channels the main process and workers subscribe to."""

    def test_subscribed_to_enough_redis_channels(self) -> None:
        # The default main process is subscribed to the USER_IP channel.
        handler = self.hs.get_replication_command_handler()
        self.assertCountEqual(handler._channels_to_subscribe_to, ["USER_IP"])

    def test_background_worker_subscribed_to_user_ip(self) -> None:
        # worker1 is configured as the worker that runs background tasks.
        worker_config = {
            "worker_name": "worker1",
            "run_background_tasks_on": "worker1",
            "redis": {"enabled": True},
        }
        worker1 = self.make_worker_hs(
            "synapse.app.generic_worker", extra_config=worker_config
        )

        channels = worker1.get_replication_command_handler()._channels_to_subscribe_to
        self.assertIn("USER_IP", channels)

        # Advance so the Redis subscription gets processed
        self.pump(0.1)

        # The counts are 2 because both the main process and the worker are
        # subscribed.
        subscribers = self._redis_server._subscribers_by_channel
        self.assertEqual(len(subscribers[b"test"]), 2)
        self.assertEqual(len(subscribers[b"test/USER_IP"]), 2)

    def test_non_background_worker_not_subscribed_to_user_ip(self) -> None:
        # worker2 is NOT the background-tasks worker: those run on "worker1".
        worker_config = {
            "worker_name": "worker2",
            "run_background_tasks_on": "worker1",
            "redis": {"enabled": True},
        }
        worker2 = self.make_worker_hs(
            "synapse.app.generic_worker", extra_config=worker_config
        )

        channels = worker2.get_replication_command_handler()._channels_to_subscribe_to
        self.assertNotIn("USER_IP", channels)

        # Advance so the Redis subscription gets processed
        self.pump(0.1)

        subscribers = self._redis_server._subscribers_by_channel
        # The count is 2 because both the main process and the worker are
        # subscribed.
        self.assertEqual(len(subscribers[b"test"]), 2)
        # For USER_IP, the count is 1 because only the main process is
        # subscribed.
        self.assertEqual(len(subscribers[b"test/USER_IP"]), 1)