Fix client reader sharding tests (#7853)

* Fix client reader sharding tests

* Newsfile

* Fix typing

* Update changelog.d/7853.misc

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>

* Move mocking of http_client to tests

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
This commit is contained in:
Erik Johnston 2020-07-15 15:27:35 +01:00 committed by GitHub
parent b11450dedc
commit f13061d515
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 302 additions and 176 deletions

1
changelog.d/7853.misc Normal file
View File

@ -0,0 +1 @@
Add support for handling registration requests across multiple client reader workers.

View File

@ -31,6 +31,7 @@ from twisted.internet.interfaces import (
IReactorPluggableNameResolver, IReactorPluggableNameResolver,
IResolutionReceiver, IResolutionReceiver,
) )
from twisted.internet.task import Cooperator
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone from twisted.web._newclient import ResponseDone
from twisted.web.client import Agent, HTTPConnectionPool, readBody from twisted.web.client import Agent, HTTPConnectionPool, readBody
@ -69,6 +70,21 @@ def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
return False return False
_EPSILON = 0.00000001
def _make_scheduler(reactor):
"""Makes a schedular suitable for a Cooperator using the given reactor.
(This is effectively just a copy from `twisted.internet.task`)
"""
def _scheduler(x):
return reactor.callLater(_EPSILON, x)
return _scheduler
class IPBlacklistingResolver(object): class IPBlacklistingResolver(object):
""" """
A proxy for reactor.nameResolver which only produces non-blacklisted IP A proxy for reactor.nameResolver which only produces non-blacklisted IP
@ -212,6 +228,10 @@ class SimpleHttpClient(object):
if hs.config.user_agent_suffix: if hs.config.user_agent_suffix:
self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix) self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix)
# We use this for our body producers to ensure that they use the correct
# reactor.
self._cooperator = Cooperator(scheduler=_make_scheduler(hs.get_reactor()))
self.user_agent = self.user_agent.encode("ascii") self.user_agent = self.user_agent.encode("ascii")
if self._ip_blacklist: if self._ip_blacklist:
@ -292,7 +312,9 @@ class SimpleHttpClient(object):
try: try:
body_producer = None body_producer = None
if data is not None: if data is not None:
body_producer = QuieterFileBodyProducer(BytesIO(data)) body_producer = QuieterFileBodyProducer(
BytesIO(data), cooperator=self._cooperator,
)
request_deferred = treq.request( request_deferred = treq.request(
method, method,

View File

@ -20,6 +20,7 @@ import synapse.handlers.room
import synapse.handlers.room_member import synapse.handlers.room_member
import synapse.handlers.set_password import synapse.handlers.set_password
import synapse.http.client import synapse.http.client
import synapse.http.matrixfederationclient
import synapse.notifier import synapse.notifier
import synapse.push.pusherpool import synapse.push.pusherpool
import synapse.replication.tcp.client import synapse.replication.tcp.client
@ -143,3 +144,7 @@ class HomeServer(object):
pass pass
def get_replication_streams(self) -> Dict[str, Stream]: def get_replication_streams(self) -> Dict[str, Stream]:
pass pass
def get_http_client(
self,
) -> synapse.http.matrixfederationclient.MatrixFederationHttpClient:
pass

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, List, Optional, Tuple from typing import Any, Callable, List, Optional, Tuple
import attr import attr
@ -26,8 +26,9 @@ from synapse.app.generic_worker import (
GenericWorkerReplicationHandler, GenericWorkerReplicationHandler,
GenericWorkerServer, GenericWorkerServer,
) )
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.replication.http import streams from synapse.replication.http import ReplicationRestResource, streams
from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@ -35,7 +36,7 @@ from synapse.server import HomeServer
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import FakeTransport from tests.server import FakeTransport, render
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -180,6 +181,159 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.assertEqual(request.method, b"GET") self.assertEqual(request.method, b"GET")
class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests running multiple workers.
Automatically handle HTTP replication requests from workers to master,
unlike `BaseStreamTestCase`.
"""
servlets = [] # type: List[Callable[[HomeServer, JsonResource], None]]
def setUp(self):
super().setUp()
# build a replication server
self.server_factory = ReplicationStreamProtocolFactory(self.hs)
self.streamer = self.hs.get_replication_streamer()
store = self.hs.get_datastore()
self.database = store.db
self.reactor.lookups["testserv"] = "1.2.3.4"
self._worker_hs_to_resource = {}
# When we see a connection attempt to the master replication listener we
# automatically set up the connection. This is so that tests don't
# manually have to go and explicitly set it up each time (plus sometimes
# it is impossible to write the handling explicitly in the tests).
self.reactor.add_tcp_client_callback(
"1.2.3.4", 8765, self._handle_http_replication_attempt
)
def create_test_json_resource(self):
"""Overrides `HomeserverTestCase.create_test_json_resource`.
"""
# We override this so that it automatically registers all the HTTP
# replication servlets, without having to explicitly do that in all
# subclassses.
resource = ReplicationRestResource(self.hs)
for servlet in self.servlets:
servlet(self.hs, resource)
return resource
def make_worker_hs(
self, worker_app: str, extra_config: dict = {}, **kwargs
) -> HomeServer:
"""Make a new worker HS instance, correctly connecting replcation
stream to the master HS.
Args:
worker_app: Type of worker, e.g. `synapse.app.federation_sender`.
extra_config: Any extra config to use for this instances.
**kwargs: Options that get passed to `self.setup_test_homeserver`,
useful to e.g. pass some mocks for things like `http_client`
Returns:
The new worker HomeServer instance.
"""
config = self._get_worker_hs_config()
config["worker_app"] = worker_app
config.update(extra_config)
worker_hs = self.setup_test_homeserver(
homeserverToUse=GenericWorkerServer,
config=config,
reactor=self.reactor,
**kwargs
)
store = worker_hs.get_datastore()
store.db._db_pool = self.database._db_pool
repl_handler = ReplicationCommandHandler(worker_hs)
client = ClientReplicationStreamProtocol(
worker_hs, "client", "test", self.clock, repl_handler,
)
server = self.server_factory.buildProtocol(None)
client_transport = FakeTransport(server, self.reactor)
client.makeConnection(client_transport)
server_transport = FakeTransport(client, self.reactor)
server.makeConnection(server_transport)
# Set up a resource for the worker
resource = ReplicationRestResource(self.hs)
for servlet in self.servlets:
servlet(worker_hs, resource)
self._worker_hs_to_resource[worker_hs] = resource
return worker_hs
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_replication_host"] = "testserv"
config["worker_replication_http_port"] = "8765"
return config
def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest):
render(request, self._worker_hs_to_resource[worker_hs], self.reactor)
def replicate(self):
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
self.streamer.on_notifier_poke()
self.pump()
def _handle_http_replication_attempt(self):
"""Handles a connection attempt to the master replication HTTP
listener.
"""
# We should have at least one outbound connection attempt, where the
# last is one to the HTTP repication IP/port.
clients = self.reactor.tcpClients
self.assertGreaterEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop()
self.assertEqual(host, "1.2.3.4")
self.assertEqual(port, 8765)
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
request_factory = OneShotRequestFactory()
# Set up the server side protocol
channel = _PushHTTPChannel(self.reactor)
channel.requestFactory = request_factory
channel.site = self.site
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
channel, self.reactor, client_protocol
)
client_protocol.makeConnection(client_to_server_transport)
server_to_client_transport = FakeTransport(
client_protocol, self.reactor, channel
)
channel.makeConnection(server_to_client_transport)
# Note: at this point we've wired everything up, but we need to return
# before the data starts flowing over the connections as this is called
# inside `connecTCP` before the connection has been passed back to the
# code that requested the TCP connection.
class TestReplicationDataHandler(GenericWorkerReplicationHandler): class TestReplicationDataHandler(GenericWorkerReplicationHandler):
"""Drop-in for ReplicationDataHandler which just collects RDATA rows""" """Drop-in for ReplicationDataHandler which just collects RDATA rows"""
@ -241,6 +395,14 @@ class _PushHTTPChannel(HTTPChannel):
# We need to manually stop the _PullToPushProducer. # We need to manually stop the _PullToPushProducer.
self._pull_to_push_producer.stop() self._pull_to_push_producer.stop()
def checkPersistence(self, request, version):
"""Check whether the connection can be re-used
"""
# We hijack this to always say no for ease of wiring stuff up in
# `handle_http_replication_attempt`.
request.responseHeaders.setRawHeaders(b"connection", [b"close"])
return False
class _PullToPushProducer: class _PullToPushProducer:
"""A push producer that wraps a pull producer. """A push producer that wraps a pull producer.

View File

@ -15,63 +15,26 @@
import logging import logging
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.app.generic_worker import GenericWorkerServer
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.rest.client.v2_alpha import register from synapse.rest.client.v2_alpha import register
from tests import unittest from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
from tests.server import FakeChannel, render from tests.server import FakeChannel
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ClientReaderTestCase(unittest.HomeserverTestCase): class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
"""Base class for tests of the replication streams""" """Base class for tests of the replication streams"""
servlets = [ servlets = [register.register_servlets]
register.register_servlets,
]
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
# build a replication server
self.server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()
store = hs.get_datastore()
self.database = store.db
self.recaptcha_checker = DummyRecaptchaChecker(hs) self.recaptcha_checker = DummyRecaptchaChecker(hs)
auth_handler = hs.get_auth_handler() auth_handler = hs.get_auth_handler()
auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
self.reactor.lookups["testserv"] = "1.2.3.4"
def make_worker_hs(self, extra_config={}):
config = self._get_worker_hs_config()
config.update(extra_config)
worker_hs = self.setup_test_homeserver(
homeserverToUse=GenericWorkerServer, config=config, reactor=self.reactor,
)
store = worker_hs.get_datastore()
store.db._db_pool = self.database._db_pool
# Register the expected servlets, essentially this is HomeserverTestCase.create_test_json_resource.
resource = JsonResource(self.hs)
for servlet in self.servlets:
servlet(worker_hs, resource)
# Essentially HomeserverTestCase.render.
def _render(request):
render(request, self.resource, self.reactor)
return worker_hs, _render
def _get_worker_hs_config(self) -> dict: def _get_worker_hs_config(self) -> dict:
config = self.default_config() config = self.default_config()
config["worker_app"] = "synapse.app.client_reader" config["worker_app"] = "synapse.app.client_reader"
@ -82,14 +45,14 @@ class ClientReaderTestCase(unittest.HomeserverTestCase):
def test_register_single_worker(self): def test_register_single_worker(self):
"""Test that registration works when using a single client reader worker. """Test that registration works when using a single client reader worker.
""" """
_, worker_render = self.make_worker_hs() worker_hs = self.make_worker_hs("synapse.app.client_reader")
request_1, channel_1 = self.make_request( request_1, channel_1 = self.make_request(
"POST", "POST",
"register", "register",
{"username": "user", "type": "m.login.password", "password": "bar"}, {"username": "user", "type": "m.login.password", "password": "bar"},
) # type: SynapseRequest, FakeChannel ) # type: SynapseRequest, FakeChannel
worker_render(request_1) self.render_on_worker(worker_hs, request_1)
self.assertEqual(request_1.code, 401) self.assertEqual(request_1.code, 401)
# Grab the session # Grab the session
@ -99,7 +62,7 @@ class ClientReaderTestCase(unittest.HomeserverTestCase):
request_2, channel_2 = self.make_request( request_2, channel_2 = self.make_request(
"POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
) # type: SynapseRequest, FakeChannel ) # type: SynapseRequest, FakeChannel
worker_render(request_2) self.render_on_worker(worker_hs, request_2)
self.assertEqual(request_2.code, 200) self.assertEqual(request_2.code, 200)
# We're given a registered user. # We're given a registered user.
@ -108,15 +71,15 @@ class ClientReaderTestCase(unittest.HomeserverTestCase):
def test_register_multi_worker(self): def test_register_multi_worker(self):
"""Test that registration works when using multiple client reader workers. """Test that registration works when using multiple client reader workers.
""" """
_, worker_render_1 = self.make_worker_hs() worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
_, worker_render_2 = self.make_worker_hs() worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
request_1, channel_1 = self.make_request( request_1, channel_1 = self.make_request(
"POST", "POST",
"register", "register",
{"username": "user", "type": "m.login.password", "password": "bar"}, {"username": "user", "type": "m.login.password", "password": "bar"},
) # type: SynapseRequest, FakeChannel ) # type: SynapseRequest, FakeChannel
worker_render_1(request_1) self.render_on_worker(worker_hs_1, request_1)
self.assertEqual(request_1.code, 401) self.assertEqual(request_1.code, 401)
# Grab the session # Grab the session
@ -126,7 +89,7 @@ class ClientReaderTestCase(unittest.HomeserverTestCase):
request_2, channel_2 = self.make_request( request_2, channel_2 = self.make_request(
"POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
) # type: SynapseRequest, FakeChannel ) # type: SynapseRequest, FakeChannel
worker_render_2(request_2) self.render_on_worker(worker_hs_2, request_2)
self.assertEqual(request_2.code, 200) self.assertEqual(request_2.code, 200)
# We're given a registered user. # We're given a registered user.

View File

@ -19,132 +19,40 @@ from mock import Mock
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.app.generic_worker import GenericWorkerServer
from synapse.events.builder import EventBuilderFactory from synapse.events.builder import EventBuilderFactory
from synapse.replication.http import streams
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.rest.admin import register_servlets_for_client_rest_resource from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client.v1 import login, room from synapse.rest.client.v1 import login, room
from synapse.types import UserID from synapse.types import UserID
from tests import unittest from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeTransport
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BaseStreamTestCase(unittest.HomeserverTestCase): class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
"""Base class for tests of the replication streams"""
servlets = [
streams.register_servlets,
]
def prepare(self, reactor, clock, hs):
# build a replication server
self.server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()
store = hs.get_datastore()
self.database = store.db
self.reactor.lookups["testserv"] = "1.2.3.4"
def default_config(self):
conf = super().default_config()
conf["send_federation"] = False
return conf
def make_worker_hs(self, extra_config={}):
config = self._get_worker_hs_config()
config.update(extra_config)
mock_federation_client = Mock(spec=["put_json"])
mock_federation_client.put_json.side_effect = lambda *_, **__: defer.succeed({})
worker_hs = self.setup_test_homeserver(
http_client=mock_federation_client,
homeserverToUse=GenericWorkerServer,
config=config,
reactor=self.reactor,
)
store = worker_hs.get_datastore()
store.db._db_pool = self.database._db_pool
repl_handler = ReplicationCommandHandler(worker_hs)
client = ClientReplicationStreamProtocol(
worker_hs, "client", "test", self.clock, repl_handler,
)
server = self.server_factory.buildProtocol(None)
client_transport = FakeTransport(server, self.reactor)
client.makeConnection(client_transport)
server_transport = FakeTransport(client, self.reactor)
server.makeConnection(server_transport)
return worker_hs
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.federation_sender"
config["worker_replication_host"] = "testserv"
config["worker_replication_http_port"] = "8765"
return config
def replicate(self):
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
self.streamer.on_notifier_poke()
self.pump()
def create_room_with_remote_server(self, user, token, remote_server="other_server"):
room = self.helper.create_room_as(user, tok=token)
store = self.hs.get_datastore()
federation = self.hs.get_handlers().federation_handler
prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
room_version = self.get_success(store.get_room_version(room))
factory = EventBuilderFactory(self.hs)
factory.hostname = remote_server
user_id = UserID("user", remote_server).to_string()
event_dict = {
"type": EventTypes.Member,
"state_key": user_id,
"content": {"membership": Membership.JOIN},
"sender": user_id,
"room_id": room,
}
builder = factory.for_room_version(room_version, event_dict)
join_event = self.get_success(builder.build(prev_event_ids))
self.get_success(federation.on_send_join_request(remote_server, join_event))
self.replicate()
return room
class FederationSenderTestCase(BaseStreamTestCase):
servlets = [ servlets = [
login.register_servlets, login.register_servlets,
register_servlets_for_client_rest_resource, register_servlets_for_client_rest_resource,
room.register_servlets, room.register_servlets,
] ]
def default_config(self):
conf = super().default_config()
conf["send_federation"] = False
return conf
def test_send_event_single_sender(self): def test_send_event_single_sender(self):
"""Test that using a single federation sender worker correctly sends a """Test that using a single federation sender worker correctly sends a
new event. new event.
""" """
worker_hs = self.make_worker_hs({"send_federation": True}) mock_client = Mock(spec=["put_json"])
mock_client = worker_hs.get_http_client() mock_client.put_json.side_effect = lambda *_, **__: defer.succeed({})
self.make_worker_hs(
"synapse.app.federation_sender",
{"send_federation": True},
http_client=mock_client,
)
user = self.register_user("user", "pass") user = self.register_user("user", "pass")
token = self.login("user", "pass") token = self.login("user", "pass")
@ -165,23 +73,29 @@ class FederationSenderTestCase(BaseStreamTestCase):
"""Test that using two federation sender workers correctly sends """Test that using two federation sender workers correctly sends
new events. new events.
""" """
worker1 = self.make_worker_hs( mock_client1 = Mock(spec=["put_json"])
mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({})
self.make_worker_hs(
"synapse.app.federation_sender",
{ {
"send_federation": True, "send_federation": True,
"worker_name": "sender1", "worker_name": "sender1",
"federation_sender_instances": ["sender1", "sender2"], "federation_sender_instances": ["sender1", "sender2"],
} },
http_client=mock_client1,
) )
mock_client1 = worker1.get_http_client()
worker2 = self.make_worker_hs( mock_client2 = Mock(spec=["put_json"])
mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({})
self.make_worker_hs(
"synapse.app.federation_sender",
{ {
"send_federation": True, "send_federation": True,
"worker_name": "sender2", "worker_name": "sender2",
"federation_sender_instances": ["sender1", "sender2"], "federation_sender_instances": ["sender1", "sender2"],
} },
http_client=mock_client2,
) )
mock_client2 = worker2.get_http_client()
user = self.register_user("user2", "pass") user = self.register_user("user2", "pass")
token = self.login("user2", "pass") token = self.login("user2", "pass")
@ -191,8 +105,8 @@ class FederationSenderTestCase(BaseStreamTestCase):
for i in range(20): for i in range(20):
server_name = "other_server_%d" % (i,) server_name = "other_server_%d" % (i,)
room = self.create_room_with_remote_server(user, token, server_name) room = self.create_room_with_remote_server(user, token, server_name)
mock_client1.reset_mock() mock_client1.reset_mock() # type: ignore[attr-defined]
mock_client2.reset_mock() mock_client2.reset_mock() # type: ignore[attr-defined]
self.create_and_send_event(room, UserID.from_string(user)) self.create_and_send_event(room, UserID.from_string(user))
self.replicate() self.replicate()
@ -222,23 +136,29 @@ class FederationSenderTestCase(BaseStreamTestCase):
"""Test that using two federation sender workers correctly sends """Test that using two federation sender workers correctly sends
new typing EDUs. new typing EDUs.
""" """
worker1 = self.make_worker_hs( mock_client1 = Mock(spec=["put_json"])
mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({})
self.make_worker_hs(
"synapse.app.federation_sender",
{ {
"send_federation": True, "send_federation": True,
"worker_name": "sender1", "worker_name": "sender1",
"federation_sender_instances": ["sender1", "sender2"], "federation_sender_instances": ["sender1", "sender2"],
} },
http_client=mock_client1,
) )
mock_client1 = worker1.get_http_client()
worker2 = self.make_worker_hs( mock_client2 = Mock(spec=["put_json"])
mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({})
self.make_worker_hs(
"synapse.app.federation_sender",
{ {
"send_federation": True, "send_federation": True,
"worker_name": "sender2", "worker_name": "sender2",
"federation_sender_instances": ["sender1", "sender2"], "federation_sender_instances": ["sender1", "sender2"],
} },
http_client=mock_client2,
) )
mock_client2 = worker2.get_http_client()
user = self.register_user("user3", "pass") user = self.register_user("user3", "pass")
token = self.login("user3", "pass") token = self.login("user3", "pass")
@ -250,8 +170,8 @@ class FederationSenderTestCase(BaseStreamTestCase):
for i in range(20): for i in range(20):
server_name = "other_server_%d" % (i,) server_name = "other_server_%d" % (i,)
room = self.create_room_with_remote_server(user, token, server_name) room = self.create_room_with_remote_server(user, token, server_name)
mock_client1.reset_mock() mock_client1.reset_mock() # type: ignore[attr-defined]
mock_client2.reset_mock() mock_client2.reset_mock() # type: ignore[attr-defined]
self.get_success( self.get_success(
typing_handler.started_typing( typing_handler.started_typing(
@ -284,3 +204,32 @@ class FederationSenderTestCase(BaseStreamTestCase):
self.assertTrue(sent_on_1) self.assertTrue(sent_on_1)
self.assertTrue(sent_on_2) self.assertTrue(sent_on_2)
def create_room_with_remote_server(self, user, token, remote_server="other_server"):
room = self.helper.create_room_as(user, tok=token)
store = self.hs.get_datastore()
federation = self.hs.get_handlers().federation_handler
prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
room_version = self.get_success(store.get_room_version(room))
factory = EventBuilderFactory(self.hs)
factory.hostname = remote_server
user_id = UserID("user", remote_server).to_string()
event_dict = {
"type": EventTypes.Member,
"state_key": user_id,
"content": {"membership": Membership.JOIN},
"sender": user_id,
"room_id": room,
}
builder = factory.for_room_version(room_version, event_dict)
join_event = self.get_success(builder.build(prev_event_ids))
self.get_success(federation.on_send_join_request(remote_server, join_event))
self.replicate()
return room

View File

@ -237,6 +237,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def __init__(self): def __init__(self):
self.threadpool = ThreadPool(self) self.threadpool = ThreadPool(self)
self._tcp_callbacks = {}
self._udp = [] self._udp = []
lookups = self.lookups = {} lookups = self.lookups = {}
@ -268,6 +269,29 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def getThreadPool(self): def getThreadPool(self):
return self.threadpool return self.threadpool
def add_tcp_client_callback(self, host, port, callback):
"""Add a callback that will be invoked when we receive a connection
attempt to the given IP/port using `connectTCP`.
Note that the callback gets run before we return the connection to the
client, which means callbacks cannot block while waiting for writes.
"""
self._tcp_callbacks[(host, port)] = callback
def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
"""Fake L{IReactorTCP.connectTCP}.
"""
conn = super().connectTCP(
host, port, factory, timeout=timeout, bindAddress=None
)
callback = self._tcp_callbacks.get((host, port))
if callback:
callback()
return conn
class ThreadPool: class ThreadPool:
""" """
@ -486,7 +510,7 @@ class FakeTransport(object):
try: try:
self.other.dataReceived(to_write) self.other.dataReceived(to_write)
except Exception as e: except Exception as e:
logger.warning("Exception writing to protocol: %s", e) logger.exception("Exception writing to protocol: %s", e)
return return
self.buffer = self.buffer[len(to_write) :] self.buffer = self.buffer[len(to_write) :]