Add unit test for event persister sharding (#8433)

This commit is contained in:
Erik Johnston 2020-10-02 09:57:12 +01:00 committed by GitHub
parent 05ee048f2c
commit 6c5d5e507e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 371 additions and 27 deletions

1
changelog.d/8433.misc Normal file
View File

@ -0,0 +1 @@
Add unit test for event persister sharding.

View File

@ -143,3 +143,6 @@ ignore_missing_imports = True
[mypy-nacl.*] [mypy-nacl.*]
ignore_missing_imports = True ignore_missing_imports = True
[mypy-hiredis]
ignore_missing_imports = True

View File

@ -16,7 +16,7 @@
"""Contains *incomplete* type hints for txredisapi. """Contains *incomplete* type hints for txredisapi.
""" """
from typing import List, Optional, Union from typing import List, Optional, Union, Type
class RedisProtocol: class RedisProtocol:
def publish(self, channel: str, message: bytes): ... def publish(self, channel: str, message: bytes): ...
@ -42,3 +42,21 @@ def lazyConnection(
class SubscriberFactory: class SubscriberFactory:
def buildProtocol(self, addr): ... def buildProtocol(self, addr): ...
class ConnectionHandler: ...
class RedisFactory:
continueTrying: bool
handler: RedisProtocol
def __init__(
self,
uuid: str,
dbid: Optional[int],
poolsize: int,
isLazy: bool = False,
handler: Type = ConnectionHandler,
charset: str = "utf-8",
password: Optional[str] = None,
replyTimeout: Optional[int] = None,
convertNumbers: Optional[int] = True,
): ...

View File

@ -251,10 +251,9 @@ class ReplicationCommandHandler:
using TCP. using TCP.
""" """
if hs.config.redis.redis_enabled: if hs.config.redis.redis_enabled:
import txredisapi
from synapse.replication.tcp.redis import ( from synapse.replication.tcp.redis import (
RedisDirectTcpReplicationClientFactory, RedisDirectTcpReplicationClientFactory,
lazyConnection,
) )
logger.info( logger.info(
@ -271,7 +270,8 @@ class ReplicationCommandHandler:
# connection after SUBSCRIBE is called). # connection after SUBSCRIBE is called).
# First create the connection for sending commands. # First create the connection for sending commands.
outbound_redis_connection = txredisapi.lazyConnection( outbound_redis_connection = lazyConnection(
reactor=hs.get_reactor(),
host=hs.config.redis_host, host=hs.config.redis_host,
port=hs.config.redis_port, port=hs.config.redis_port,
password=hs.config.redis.redis_password, password=hs.config.redis.redis_password,

View File

@ -15,7 +15,7 @@
import logging import logging
from inspect import isawaitable from inspect import isawaitable
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Optional
import txredisapi import txredisapi
@ -228,3 +228,41 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
p.password = self.password p.password = self.password
return p return p
def lazyConnection(
reactor,
host: str = "localhost",
port: int = 6379,
dbid: Optional[int] = None,
reconnect: bool = True,
charset: str = "utf-8",
password: Optional[str] = None,
connectTimeout: Optional[int] = None,
replyTimeout: Optional[int] = None,
convertNumbers: bool = True,
) -> txredisapi.RedisProtocol:
"""Equivalent to `txredisapi.lazyConnection`, except allows specifying a
reactor.
"""
isLazy = True
poolsize = 1
uuid = "%s:%d" % (host, port)
factory = txredisapi.RedisFactory(
uuid,
dbid,
poolsize,
isLazy,
txredisapi.ConnectionHandler,
charset,
password,
replyTimeout,
convertNumbers,
)
factory.continueTrying = reconnect
for x in range(poolsize):
reactor.connectTCP(host, port, factory, connectTimeout)
return factory.handler

View File

@ -12,13 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, Callable, List, Optional, Tuple from typing import Any, Callable, List, Optional, Tuple
import attr import attr
import hiredis
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
from twisted.internet.task import LoopingCall from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel from twisted.web.http import HTTPChannel
@ -27,7 +28,7 @@ from synapse.app.generic_worker import (
GenericWorkerServer, GenericWorkerServer,
) )
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource, streams from synapse.replication.http import ReplicationRestResource, streams
from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
@ -197,19 +198,37 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.server_factory = ReplicationStreamProtocolFactory(self.hs) self.server_factory = ReplicationStreamProtocolFactory(self.hs)
self.streamer = self.hs.get_replication_streamer() self.streamer = self.hs.get_replication_streamer()
# Fake in memory Redis server that servers can connect to.
self._redis_server = FakeRedisPubSubServer()
store = self.hs.get_datastore() store = self.hs.get_datastore()
self.database_pool = store.db_pool self.database_pool = store.db_pool
self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["localhost"] = "127.0.0.1"
self._worker_hs_to_resource = {} # A map from a HS instance to the associated HTTP Site to use for
# handling inbound HTTP requests to that instance.
self._hs_to_site = {self.hs: self.site}
if self.hs.config.redis.redis_enabled:
# Handle attempts to connect to fake redis server.
self.reactor.add_tcp_client_callback(
"localhost", 6379, self.connect_any_redis_attempts,
)
self.hs.get_tcp_replication().start_replication(self.hs)
# When we see a connection attempt to the master replication listener we # When we see a connection attempt to the master replication listener we
# automatically set up the connection. This is so that tests don't # automatically set up the connection. This is so that tests don't
# manually have to go and explicitly set it up each time (plus sometimes # manually have to go and explicitly set it up each time (plus sometimes
# it is impossible to write the handling explicitly in the tests). # it is impossible to write the handling explicitly in the tests).
#
# Register the master replication listener:
self.reactor.add_tcp_client_callback( self.reactor.add_tcp_client_callback(
"1.2.3.4", 8765, self._handle_http_replication_attempt "1.2.3.4",
8765,
lambda: self._handle_http_replication_attempt(self.hs, 8765),
) )
def create_test_json_resource(self): def create_test_json_resource(self):
@ -253,28 +272,63 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
**kwargs **kwargs
) )
# If the instance is in the `instance_map` config then workers may try
# and send HTTP requests to it, so we register it with
# `_handle_http_replication_attempt` like we do with the master HS.
instance_name = worker_hs.get_instance_name()
instance_loc = worker_hs.config.worker.instance_map.get(instance_name)
if instance_loc:
# Ensure the host is one that has a fake DNS entry.
if instance_loc.host not in self.reactor.lookups:
raise Exception(
"Host does not have an IP for instance_map[%r].host = %r"
% (instance_name, instance_loc.host,)
)
self.reactor.add_tcp_client_callback(
self.reactor.lookups[instance_loc.host],
instance_loc.port,
lambda: self._handle_http_replication_attempt(
worker_hs, instance_loc.port
),
)
store = worker_hs.get_datastore() store = worker_hs.get_datastore()
store.db_pool._db_pool = self.database_pool._db_pool store.db_pool._db_pool = self.database_pool._db_pool
repl_handler = ReplicationCommandHandler(worker_hs) # Set up TCP replication between master and the new worker if we don't
client = ClientReplicationStreamProtocol( # have Redis support enabled.
worker_hs, "client", "test", self.clock, repl_handler, if not worker_hs.config.redis_enabled:
) repl_handler = ReplicationCommandHandler(worker_hs)
server = self.server_factory.buildProtocol(None) client = ClientReplicationStreamProtocol(
worker_hs, "client", "test", self.clock, repl_handler,
)
server = self.server_factory.buildProtocol(None)
client_transport = FakeTransport(server, self.reactor) client_transport = FakeTransport(server, self.reactor)
client.makeConnection(client_transport) client.makeConnection(client_transport)
server_transport = FakeTransport(client, self.reactor) server_transport = FakeTransport(client, self.reactor)
server.makeConnection(server_transport) server.makeConnection(server_transport)
# Set up a resource for the worker # Set up a resource for the worker
resource = ReplicationRestResource(self.hs) resource = ReplicationRestResource(worker_hs)
for servlet in self.servlets: for servlet in self.servlets:
servlet(worker_hs, resource) servlet(worker_hs, resource)
self._worker_hs_to_resource[worker_hs] = resource self._hs_to_site[worker_hs] = SynapseSite(
logger_name="synapse.access.http.fake",
site_tag="{}-{}".format(
worker_hs.config.server.server_name, worker_hs.get_instance_name()
),
config=worker_hs.config.server.listeners[0],
resource=resource,
server_version_string="1",
)
if worker_hs.config.redis.redis_enabled:
worker_hs.get_tcp_replication().start_replication(worker_hs)
return worker_hs return worker_hs
@ -285,7 +339,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
return config return config
def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest): def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest):
render(request, self._worker_hs_to_resource[worker_hs], self.reactor) render(request, self._hs_to_site[worker_hs].resource, self.reactor)
def replicate(self): def replicate(self):
"""Tell the master side of replication that something has happened, and then """Tell the master side of replication that something has happened, and then
@ -294,9 +348,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.streamer.on_notifier_poke() self.streamer.on_notifier_poke()
self.pump() self.pump()
def _handle_http_replication_attempt(self): def _handle_http_replication_attempt(self, hs, repl_port):
"""Handles a connection attempt to the master replication HTTP """Handles a connection attempt to the given HS replication HTTP
listener. listener on the given port.
""" """
# We should have at least one outbound connection attempt, where the # We should have at least one outbound connection attempt, where the
@ -305,7 +359,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
self.assertGreaterEqual(len(clients), 1) self.assertGreaterEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop() (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
self.assertEqual(host, "1.2.3.4") self.assertEqual(host, "1.2.3.4")
self.assertEqual(port, 8765) self.assertEqual(port, repl_port)
# Set up client side protocol # Set up client side protocol
client_protocol = client_factory.buildProtocol(None) client_protocol = client_factory.buildProtocol(None)
@ -315,7 +369,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Set up the server side protocol # Set up the server side protocol
channel = _PushHTTPChannel(self.reactor) channel = _PushHTTPChannel(self.reactor)
channel.requestFactory = request_factory channel.requestFactory = request_factory
channel.site = self.site channel.site = self._hs_to_site[hs]
# Connect client to server and vice versa. # Connect client to server and vice versa.
client_to_server_transport = FakeTransport( client_to_server_transport = FakeTransport(
@ -333,6 +387,32 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# inside `connecTCP` before the connection has been passed back to the # inside `connecTCP` before the connection has been passed back to the
# code that requested the TCP connection. # code that requested the TCP connection.
def connect_any_redis_attempts(self):
"""If redis is enabled we need to deal with workers connecting to a
redis server. We don't want to use a real Redis server so we use a
fake one.
"""
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
self.assertEqual(host, "localhost")
self.assertEqual(port, 6379)
client_protocol = client_factory.buildProtocol(None)
server_protocol = self._redis_server.buildProtocol(None)
client_to_server_transport = FakeTransport(
server_protocol, self.reactor, client_protocol
)
client_protocol.makeConnection(client_to_server_transport)
server_to_client_transport = FakeTransport(
client_protocol, self.reactor, server_protocol
)
server_protocol.makeConnection(server_to_client_transport)
return client_to_server_transport, server_to_client_transport
class TestReplicationDataHandler(GenericWorkerReplicationHandler): class TestReplicationDataHandler(GenericWorkerReplicationHandler):
"""Drop-in for ReplicationDataHandler which just collects RDATA rows""" """Drop-in for ReplicationDataHandler which just collects RDATA rows"""
@ -467,3 +547,105 @@ class _PullToPushProducer:
pass pass
self.stopProducing() self.stopProducing()
class FakeRedisPubSubServer:
"""A fake Redis server for pub/sub.
"""
def __init__(self):
self._subscribers = set()
def add_subscriber(self, conn):
"""A connection has called SUBSCRIBE
"""
self._subscribers.add(conn)
def remove_subscriber(self, conn):
"""A connection has called UNSUBSCRIBE
"""
self._subscribers.discard(conn)
def publish(self, conn, channel, msg) -> int:
"""A connection want to publish a message to subscribers.
"""
for sub in self._subscribers:
sub.send(["message", channel, msg])
return len(self._subscribers)
def buildProtocol(self, addr):
return FakeRedisPubSubProtocol(self)
class FakeRedisPubSubProtocol(Protocol):
"""A connection from a client talking to the fake Redis server.
"""
def __init__(self, server: FakeRedisPubSubServer):
self._server = server
self._reader = hiredis.Reader()
def dataReceived(self, data):
self._reader.feed(data)
# We might get multiple messages in one packet.
while True:
msg = self._reader.gets()
if msg is False:
# No more messages.
return
if not isinstance(msg, list):
# Inbound commands should always be a list
raise Exception("Expected redis list")
self.handle_command(msg[0], *msg[1:])
def handle_command(self, command, *args):
"""Received a Redis command from the client.
"""
# We currently only support pub/sub.
if command == b"PUBLISH":
channel, message = args
num_subscribers = self._server.publish(self, channel, message)
self.send(num_subscribers)
elif command == b"SUBSCRIBE":
(channel,) = args
self._server.add_subscriber(self)
self.send(["subscribe", channel, 1])
else:
raise Exception("Unknown command")
def send(self, msg):
"""Send a message back to the client.
"""
raw = self.encode(msg).encode("utf-8")
self.transport.write(raw)
self.transport.flush()
def encode(self, obj):
"""Encode an object to its Redis format.
Supports: strings/bytes, integers and list/tuples.
"""
if isinstance(obj, bytes):
# We assume bytes are just unicode strings.
obj = obj.decode("utf-8")
if isinstance(obj, str):
return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
if isinstance(obj, int):
return ":{val}\r\n".format(val=obj)
if isinstance(obj, (list, tuple)):
items = "".join(self.encode(a) for a in obj)
return "*{len}\r\n{items}".format(len=len(obj), items=items)
raise Exception("Unrecognized type for encoding redis: %r: %r", type(obj), obj)
def connectionLost(self, reason):
self._server.remove_subscriber(self)

View File

@ -0,0 +1,102 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.utils import USE_POSTGRES_FOR_TESTS
logger = logging.getLogger(__name__)
class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
"""Checks event persisting sharding works
"""
# Event persister sharding requires postgres (due to needing
# `MutliWriterIdGenerator`).
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
servlets = [
admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
]
def prepare(self, reactor, clock, hs):
# Register a user who sends a message that we'll get notified about
self.other_user_id = self.register_user("otheruser", "pass")
self.other_access_token = self.login("otheruser", "pass")
def default_config(self):
conf = super().default_config()
conf["redis"] = {"enabled": "true"}
conf["stream_writers"] = {"events": ["worker1", "worker2"]}
conf["instance_map"] = {
"worker1": {"host": "testserv", "port": 1001},
"worker2": {"host": "testserv", "port": 1002},
}
return conf
def test_basic(self):
"""Simple test to ensure that multiple rooms can be created and joined,
and that different rooms get handled by different instances.
"""
self.make_worker_hs(
"synapse.app.generic_worker", {"worker_name": "worker1"},
)
self.make_worker_hs(
"synapse.app.generic_worker", {"worker_name": "worker2"},
)
persisted_on_1 = False
persisted_on_2 = False
store = self.hs.get_datastore()
user_id = self.register_user("user", "pass")
access_token = self.login("user", "pass")
# Keep making new rooms until we see rooms being persisted on both
# workers.
for _ in range(10):
# Create a room
room = self.helper.create_room_as(user_id, tok=access_token)
# The other user joins
self.helper.join(
room=room, user=self.other_user_id, tok=self.other_access_token
)
# The other user sends some messages
rseponse = self.helper.send(room, body="Hi!", tok=self.other_access_token)
event_id = rseponse["event_id"]
# The event position includes which instance persisted the event.
pos = self.get_success(store.get_position_for_event(event_id))
persisted_on_1 |= pos.instance_name == "worker1"
persisted_on_2 |= pos.instance_name == "worker2"
if persisted_on_1 and persisted_on_2:
break
self.assertTrue(persisted_on_1)
self.assertTrue(persisted_on_2)

View File

@ -241,7 +241,7 @@ class HomeserverTestCase(TestCase):
# create a site to wrap the resource. # create a site to wrap the resource.
self.site = SynapseSite( self.site = SynapseSite(
logger_name="synapse.access.http.fake", logger_name="synapse.access.http.fake",
site_tag="test", site_tag=self.hs.config.server.server_name,
config=self.hs.config.server.listeners[0], config=self.hs.config.server.listeners[0],
resource=self.resource, resource=self.resource,
server_version_string="1", server_version_string="1",