mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-10-02 09:38:25 -04:00
Lay some foundation work to allow workers to only subscribe to some kinds of messages, reducing replication traffic. (#12672)
This commit is contained in:
parent
eb4aaa1b4b
commit
177b884ad7
5 changed files with 173 additions and 24 deletions
|
@ -12,7 +12,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from twisted.internet.address import IPv4Address
|
||||
from twisted.internet.protocol import Protocol
|
||||
|
@ -32,6 +33,7 @@ from synapse.server import HomeServer
|
|||
|
||||
from tests import unittest
|
||||
from tests.server import FakeTransport
|
||||
from tests.utils import USE_POSTGRES_FOR_TESTS
|
||||
|
||||
try:
|
||||
import hiredis
|
||||
|
@ -475,22 +477,25 @@ class FakeRedisPubSubServer:
|
|||
"""A fake Redis server for pub/sub."""
|
||||
|
||||
def __init__(self):
|
||||
self._subscribers = set()
|
||||
self._subscribers_by_channel: Dict[
|
||||
bytes, Set["FakeRedisPubSubProtocol"]
|
||||
] = defaultdict(set)
|
||||
|
||||
def add_subscriber(self, conn):
|
||||
def add_subscriber(self, conn, channel: bytes):
|
||||
"""A connection has called SUBSCRIBE"""
|
||||
self._subscribers.add(conn)
|
||||
self._subscribers_by_channel[channel].add(conn)
|
||||
|
||||
def remove_subscriber(self, conn):
|
||||
"""A connection has called UNSUBSCRIBE"""
|
||||
self._subscribers.discard(conn)
|
||||
"""A connection has lost connection"""
|
||||
for subscribers in self._subscribers_by_channel.values():
|
||||
subscribers.discard(conn)
|
||||
|
||||
def publish(self, conn, channel, msg) -> int:
|
||||
def publish(self, conn, channel: bytes, msg) -> int:
|
||||
"""A connection want to publish a message to subscribers."""
|
||||
for sub in self._subscribers:
|
||||
for sub in self._subscribers_by_channel[channel]:
|
||||
sub.send(["message", channel, msg])
|
||||
|
||||
return len(self._subscribers)
|
||||
return len(self._subscribers_by_channel)
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
return FakeRedisPubSubProtocol(self)
|
||||
|
@ -531,9 +536,10 @@ class FakeRedisPubSubProtocol(Protocol):
|
|||
num_subscribers = self._server.publish(self, channel, message)
|
||||
self.send(num_subscribers)
|
||||
elif command == b"SUBSCRIBE":
|
||||
(channel,) = args
|
||||
self._server.add_subscriber(self)
|
||||
self.send(["subscribe", channel, 1])
|
||||
for idx, channel in enumerate(args):
|
||||
num_channels = idx + 1
|
||||
self._server.add_subscriber(self, channel)
|
||||
self.send(["subscribe", channel, num_channels])
|
||||
|
||||
# Since we use SET/GET to cache things we can safely no-op them.
|
||||
elif command == b"SET":
|
||||
|
@ -576,3 +582,27 @@ class FakeRedisPubSubProtocol(Protocol):
|
|||
|
||||
def connectionLost(self, reason):
|
||||
self._server.remove_subscriber(self)
|
||||
|
||||
|
||||
class RedisMultiWorkerStreamTestCase(BaseMultiWorkerStreamTestCase):
|
||||
"""
|
||||
A test case that enables Redis, providing a fake Redis server.
|
||||
"""
|
||||
|
||||
if not hiredis:
|
||||
skip = "Requires hiredis"
|
||||
|
||||
if not USE_POSTGRES_FOR_TESTS:
|
||||
# Redis replication only takes place on Postgres
|
||||
skip = "Requires Postgres"
|
||||
|
||||
def default_config(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Overrides the default config to enable Redis.
|
||||
Even if the test only uses make_worker_hs, the main process needs Redis
|
||||
enabled otherwise it won't create a Fake Redis server to listen on the
|
||||
Redis port and accept fake TCP connections.
|
||||
"""
|
||||
base = super().default_config()
|
||||
base["redis"] = {"enabled": True}
|
||||
return base
|
||||
|
|
73
tests/replication/tcp/test_handler.py
Normal file
73
tests/replication/tcp/test_handler.py
Normal file
|
@ -0,0 +1,73 @@
|
|||
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from tests.replication._base import RedisMultiWorkerStreamTestCase
|
||||
|
||||
|
||||
class ChannelsTestCase(RedisMultiWorkerStreamTestCase):
|
||||
def test_subscribed_to_enough_redis_channels(self) -> None:
|
||||
# The default main process is subscribed to the USER_IP channel.
|
||||
self.assertCountEqual(
|
||||
self.hs.get_replication_command_handler()._channels_to_subscribe_to,
|
||||
["USER_IP"],
|
||||
)
|
||||
|
||||
def test_background_worker_subscribed_to_user_ip(self) -> None:
|
||||
# The default main process is subscribed to the USER_IP channel.
|
||||
worker1 = self.make_worker_hs(
|
||||
"synapse.app.generic_worker",
|
||||
extra_config={
|
||||
"worker_name": "worker1",
|
||||
"run_background_tasks_on": "worker1",
|
||||
"redis": {"enabled": True},
|
||||
},
|
||||
)
|
||||
self.assertIn(
|
||||
"USER_IP",
|
||||
worker1.get_replication_command_handler()._channels_to_subscribe_to,
|
||||
)
|
||||
|
||||
# Advance so the Redis subscription gets processed
|
||||
self.pump(0.1)
|
||||
|
||||
# The counts are 2 because both the main process and the worker are subscribed.
|
||||
self.assertEqual(len(self._redis_server._subscribers_by_channel[b"test"]), 2)
|
||||
self.assertEqual(
|
||||
len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 2
|
||||
)
|
||||
|
||||
def test_non_background_worker_not_subscribed_to_user_ip(self) -> None:
|
||||
# The default main process is subscribed to the USER_IP channel.
|
||||
worker2 = self.make_worker_hs(
|
||||
"synapse.app.generic_worker",
|
||||
extra_config={
|
||||
"worker_name": "worker2",
|
||||
"run_background_tasks_on": "worker1",
|
||||
"redis": {"enabled": True},
|
||||
},
|
||||
)
|
||||
self.assertNotIn(
|
||||
"USER_IP",
|
||||
worker2.get_replication_command_handler()._channels_to_subscribe_to,
|
||||
)
|
||||
|
||||
# Advance so the Redis subscription gets processed
|
||||
self.pump(0.1)
|
||||
|
||||
# The count is 2 because both the main process and the worker are subscribed.
|
||||
self.assertEqual(len(self._redis_server._subscribers_by_channel[b"test"]), 2)
|
||||
# For USER_IP, the count is 1 because only the main process is subscribed.
|
||||
self.assertEqual(
|
||||
len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 1
|
||||
)
|
Loading…
Add table
Add a link
Reference in a new issue