diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 6928d9d3e..795c655ae 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -50,16 +50,14 @@ from twisted.cred import checkers, portal from twisted.internet import reactor, task, defer from twisted.application import service -from twisted.enterprise import adbapi from twisted.web.resource import Resource, EncodingResourceWrapper from twisted.web.static import File from twisted.web.server import Site, GzipEncoderFactory, Request -from synapse.http.server import JsonResource, RootRedirect +from synapse.http.server import RootRedirect from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.rest.key.v1.server_key_resource import LocalKey from synapse.rest.key.v2 import KeyApiV2Resource -from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.api.urls import ( FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX, SERVER_KEY_PREFIX, MEDIA_PREFIX, STATIC_PREFIX, @@ -69,6 +67,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.util.logcontext import LoggingContext from synapse.metrics.resource import MetricsResource, METRICS_PREFIX +from synapse.federation.transport.server import TransportLayerServer from synapse import events @@ -95,80 +94,37 @@ def gz_wrap(r): return EncodingResourceWrapper(r, [GzipEncoderFactory()]) +def build_resource_for_web_client(hs): + webclient_path = hs.get_config().web_client_location + if not webclient_path: + try: + import syweb + except ImportError: + quit_with_error( + "Could not find a webclient.\n\n" + "Please either install the matrix-angular-sdk or configure\n" + "the location of the source to serve via the configuration\n" + "option `web_client_location`\n\n" + "To install the `matrix-angular-sdk` via pip, run:\n\n" + " pip install '%(dep)s'\n" + "\n" + "You can also disable hosting of the webclient via the\n" + "configuration option `web_client`\n" + % {"dep": DEPENDENCY_LINKS["matrix-angular-sdk"]} + ) + syweb_path = os.path.dirname(syweb.__file__) + webclient_path = os.path.join(syweb_path, "webclient") + # GZip is disabled here due to + # https://twistedmatrix.com/trac/ticket/7678 + # (It can stay enabled for the API resources: they call + # write() with the whole body and then finish() straight + # after and so do not trigger the bug. + # GzipFile was removed in commit 184ba09 + # return GzipFile(webclient_path) # TODO configurable? + return File(webclient_path) # TODO configurable? + + class SynapseHomeServer(HomeServer): - - def build_http_client(self): - return MatrixFederationHttpClient(self) - - def build_client_resource(self): - return ClientRestResource(self) - - def build_resource_for_federation(self): - return JsonResource(self) - - def build_resource_for_web_client(self): - webclient_path = self.get_config().web_client_location - if not webclient_path: - try: - import syweb - except ImportError: - quit_with_error( - "Could not find a webclient.\n\n" - "Please either install the matrix-angular-sdk or configure\n" - "the location of the source to serve via the configuration\n" - "option `web_client_location`\n\n" - "To install the `matrix-angular-sdk` via pip, run:\n\n" - " pip install '%(dep)s'\n" - "\n" - "You can also disable hosting of the webclient via the\n" - "configuration option `web_client`\n" - % {"dep": DEPENDENCY_LINKS["matrix-angular-sdk"]} - ) - syweb_path = os.path.dirname(syweb.__file__) - webclient_path = os.path.join(syweb_path, "webclient") - # GZip is disabled here due to - # https://twistedmatrix.com/trac/ticket/7678 - # (It can stay enabled for the API resources: they call - # write() with the whole body and then finish() straight - # after and so do not trigger the bug. - # GzipFile was removed in commit 184ba09 - # return GzipFile(webclient_path) # TODO configurable? - return File(webclient_path) # TODO configurable? - - def build_resource_for_static_content(self): - # This is old and should go away: not going to bother adding gzip - return File( - os.path.join(os.path.dirname(synapse.__file__), "static") - ) - - def build_resource_for_content_repo(self): - return ContentRepoResource( - self, self.config.uploads_path, self.auth, self.content_addr - ) - - def build_resource_for_media_repository(self): - return MediaRepositoryResource(self) - - def build_resource_for_server_key(self): - return LocalKey(self) - - def build_resource_for_server_key_v2(self): - return KeyApiV2Resource(self) - - def build_resource_for_metrics(self): - if self.get_config().enable_metrics: - return MetricsResource(self) - else: - return None - - def build_db_pool(self): - name = self.db_config["name"] - - return adbapi.ConnectionPool( - name, - **self.db_config.get("args", {}) - ) - def _listener_http(self, config, listener_config): port = listener_config["port"] bind_address = listener_config.get("bind_address", "") @@ -178,13 +134,11 @@ class SynapseHomeServer(HomeServer): if tls and config.no_tls: return - metrics_resource = self.get_resource_for_metrics() - resources = {} for res in listener_config["resources"]: for name in res["names"]: if name == "client": - client_resource = self.get_client_resource() + client_resource = ClientRestResource(self) if res["compress"]: client_resource = gz_wrap(client_resource) @@ -198,31 +152,35 @@ class SynapseHomeServer(HomeServer): if name == "federation": resources.update({ - FEDERATION_PREFIX: self.get_resource_for_federation(), + FEDERATION_PREFIX: TransportLayerServer(self), }) if name in ["static", "client"]: resources.update({ - STATIC_PREFIX: self.get_resource_for_static_content(), + STATIC_PREFIX: File( + os.path.join(os.path.dirname(synapse.__file__), "static") + ), }) if name in ["media", "federation", "client"]: resources.update({ - MEDIA_PREFIX: self.get_resource_for_media_repository(), - CONTENT_REPO_PREFIX: self.get_resource_for_content_repo(), + MEDIA_PREFIX: MediaRepositoryResource(self), + CONTENT_REPO_PREFIX: ContentRepoResource( + self, self.config.uploads_path, self.auth, self.content_addr + ), }) if name in ["keys", "federation"]: resources.update({ - SERVER_KEY_PREFIX: self.get_resource_for_server_key(), - SERVER_KEY_V2_PREFIX: self.get_resource_for_server_key_v2(), + SERVER_KEY_PREFIX: LocalKey(self), + SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self), }) if name == "webclient": - resources[WEB_CLIENT_PREFIX] = self.get_resource_for_web_client() + resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self) - if name == "metrics" and metrics_resource: - resources[METRICS_PREFIX] = metrics_resource + if name == "metrics" and self.get_config().enable_metrics: + resources[METRICS_PREFIX] = MetricsResource(self) root_resource = create_resource_tree(resources) if tls: @@ -675,7 +633,7 @@ def _resource_id(resource, path_seg): the mapping should looks like _resource_id(A,C) = B. Args: - resource (Resource): The *parent* Resource + resource (Resource): The *parent* Resourceb path_seg (str): The name of the child Resource to be attached. Returns: str: A unique string which can be a key to the child Resource. @@ -684,7 +642,7 @@ def _resource_id(resource, path_seg): def run(hs): - PROFILE_SYNAPSE = False + PROFILE_SYNAPSE = True if PROFILE_SYNAPSE: def profile(func): from cProfile import Profile @@ -761,6 +719,7 @@ def run(hs): auto_close_fds=False, verbose=True, logger=logger, + chdir=os.path.dirname(os.path.abspath(__file__)), ) daemon.start() diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py index 0bfb79d09..979fdf243 100644 --- a/synapse/federation/__init__.py +++ b/synapse/federation/__init__.py @@ -17,15 +17,10 @@ """ from .replication import ReplicationLayer -from .transport import TransportLayer +from .transport.client import TransportLayerClient def initialize_http_replication(homeserver): - transport = TransportLayer( - homeserver, - homeserver.hostname, - server=homeserver.get_resource_for_federation(), - client=homeserver.get_http_client() - ) + transport = TransportLayerClient(homeserver) return ReplicationLayer(homeserver, transport) diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py index 6e0be8ef1..3e062a5ea 100644 --- a/synapse/federation/replication.py +++ b/synapse/federation/replication.py @@ -54,8 +54,6 @@ class ReplicationLayer(FederationClient, FederationServer): self.keyring = hs.get_keyring() self.transport_layer = transport_layer - self.transport_layer.register_received_handler(self) - self.transport_layer.register_request_handler(self) self.federation_client = self diff --git a/synapse/federation/transport/__init__.py b/synapse/federation/transport/__init__.py index 155a7d587..d9fcc520a 100644 --- a/synapse/federation/transport/__init__.py +++ b/synapse/federation/transport/__init__.py @@ -20,55 +20,3 @@ By default this is done over HTTPS (and all home servers are required to support HTTPS), however individual pairings of servers may decide to communicate over a different (albeit still reliable) protocol. """ - -from .server import TransportLayerServer -from .client import TransportLayerClient - -from synapse.util.ratelimitutils import FederationRateLimiter - - -class TransportLayer(TransportLayerServer, TransportLayerClient): - """This is a basic implementation of the transport layer that translates - transactions and other requests to/from HTTP. - - Attributes: - server_name (str): Local home server host - - server (synapse.http.server.HttpServer): the http server to - register listeners on - - client (synapse.http.client.HttpClient): the http client used to - send requests - - request_handler (TransportRequestHandler): The handler to fire when we - receive requests for data. - - received_handler (TransportReceivedHandler): The handler to fire when - we receive data. - """ - - def __init__(self, homeserver, server_name, server, client): - """ - Args: - server_name (str): Local home server host - server (synapse.protocol.http.HttpServer): the http server to - register listeners on - client (synapse.protocol.http.HttpClient): the http client used to - send requests - """ - self.keyring = homeserver.get_keyring() - self.clock = homeserver.get_clock() - self.server_name = server_name - self.server = server - self.client = client - self.request_handler = None - self.received_handler = None - - self.ratelimiter = FederationRateLimiter( - self.clock, - window_size=homeserver.config.federation_rc_window_size, - sleep_limit=homeserver.config.federation_rc_sleep_limit, - sleep_msec=homeserver.config.federation_rc_sleep_delay, - reject_limit=homeserver.config.federation_rc_reject_limit, - concurrent_requests=homeserver.config.federation_rc_concurrent, - ) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 949d01dea..2b5d40ea7 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -28,6 +28,10 @@ logger = logging.getLogger(__name__) class TransportLayerClient(object): """Sends federation HTTP requests to other servers""" + def __init__(self, hs): + self.server_name = hs.hostname + self.client = hs.get_http_client() + @log_function def get_room_state(self, destination, room_id, event_id): """ Requests all state for a given room from the given server at the diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 8dca0a7f6..65e054f7d 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -17,7 +17,8 @@ from twisted.internet import defer from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.errors import Codes, SynapseError -from synapse.util.logutils import log_function +from synapse.http.server import JsonResource +from synapse.util.ratelimitutils import FederationRateLimiter import functools import logging @@ -28,9 +29,41 @@ import re logger = logging.getLogger(__name__) -class TransportLayerServer(object): +class TransportLayerServer(JsonResource): """Handles incoming federation HTTP requests""" + def __init__(self, hs): + self.hs = hs + self.clock = hs.get_clock() + + super(TransportLayerServer, self).__init__(hs) + + self.authenticator = Authenticator(hs) + self.ratelimiter = FederationRateLimiter( + self.clock, + window_size=hs.config.federation_rc_window_size, + sleep_limit=hs.config.federation_rc_sleep_limit, + sleep_msec=hs.config.federation_rc_sleep_delay, + reject_limit=hs.config.federation_rc_reject_limit, + concurrent_requests=hs.config.federation_rc_concurrent, + ) + + self.register_servlets() + + def register_servlets(self): + register_servlets( + self.hs, + resource=self, + ratelimiter=self.ratelimiter, + authenticator=self.authenticator, + ) + + +class Authenticator(object): + def __init__(self, hs): + self.keyring = hs.get_keyring() + self.server_name = hs.hostname + # A method just so we can pass 'self' as the authenticator to the Servlets @defer.inlineCallbacks def authenticate_request(self, request): @@ -98,37 +131,9 @@ class TransportLayerServer(object): defer.returnValue((origin, content)) - @log_function - def register_received_handler(self, handler): - """ Register a handler that will be fired when we receive data. - - Args: - handler (TransportReceivedHandler) - """ - FederationSendServlet( - handler, - authenticator=self, - ratelimiter=self.ratelimiter, - server_name=self.server_name, - ).register(self.server) - - @log_function - def register_request_handler(self, handler): - """ Register a handler that will be fired when we get asked for data. - - Args: - handler (TransportRequestHandler) - """ - for servletclass in SERVLET_CLASSES: - servletclass( - handler, - authenticator=self, - ratelimiter=self.ratelimiter, - ).register(self.server) - class BaseFederationServlet(object): - def __init__(self, handler, authenticator, ratelimiter): + def __init__(self, handler, authenticator, ratelimiter, server_name): self.handler = handler self.authenticator = authenticator self.ratelimiter = ratelimiter @@ -172,7 +177,9 @@ class FederationSendServlet(BaseFederationServlet): PATH = "/send/([^/]*)/" def __init__(self, handler, server_name, **kwargs): - super(FederationSendServlet, self).__init__(handler, **kwargs) + super(FederationSendServlet, self).__init__( + handler, server_name=server_name, **kwargs + ) self.server_name = server_name # This is when someone is trying to send us a bunch of data. @@ -432,6 +439,7 @@ class On3pidBindServlet(BaseFederationServlet): SERVLET_CLASSES = ( + FederationSendServlet, FederationPullServlet, FederationEventServlet, FederationStateServlet, @@ -451,3 +459,13 @@ SERVLET_CLASSES = ( FederationThirdPartyInviteExchangeServlet, On3pidBindServlet, ) + + +def register_servlets(hs, resource, authenticator, ratelimiter): + for servletclass in SERVLET_CLASSES: + servletclass( + handler=hs.get_replication_layer(), + authenticator=authenticator, + ratelimiter=ratelimiter, + server_name=hs.hostname, + ).register(resource) diff --git a/synapse/server.py b/synapse/server.py index 4a5796b98..a59e46ca2 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -20,6 +20,8 @@ # Imports required for the default HomeServer() implementation from twisted.web.client import BrowserLikePolicyForHTTPS +from twisted.enterprise import adbapi + from synapse.federation import initialize_http_replication from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.notifier import Notifier @@ -36,8 +38,10 @@ from synapse.push.pusherpool import PusherPool from synapse.events.builder import EventBuilderFactory from synapse.api.filtering import Filtering +from synapse.http.matrixfederationclient import MatrixFederationHttpClient -class BaseHomeServer(object): + +class HomeServer(object): """A basic homeserver object without lazy component builders. This will need all of the components it requires to either be passed as @@ -102,36 +106,6 @@ class BaseHomeServer(object): for depname in kwargs: setattr(self, depname, kwargs[depname]) - @classmethod - def _make_dependency_method(cls, depname): - def _get(self): - if hasattr(self, depname): - return getattr(self, depname) - - if hasattr(self, "build_%s" % (depname)): - # Prevent cyclic dependencies from deadlocking - if depname in self._building: - raise ValueError("Cyclic dependency while building %s" % ( - depname, - )) - self._building[depname] = 1 - - builder = getattr(self, "build_%s" % (depname)) - dep = builder() - setattr(self, depname, dep) - - del self._building[depname] - - return dep - - raise NotImplementedError( - "%s has no %s nor a builder for it" % ( - type(self).__name__, depname, - ) - ) - - setattr(BaseHomeServer, "get_%s" % (depname), _get) - def get_ip_from_request(self, request): # X-Forwarded-For is handled by our custom request type. return request.getClientIP() @@ -142,24 +116,6 @@ class BaseHomeServer(object): def is_mine_id(self, string): return string.split(":", 1)[1] == self.hostname -# Build magic accessors for every dependency -for depname in BaseHomeServer.DEPENDENCIES: - BaseHomeServer._make_dependency_method(depname) - - -class HomeServer(BaseHomeServer): - """A homeserver object that will construct most of its dependencies as - required. - - It still requires the following to be specified by the caller: - resource_for_client - resource_for_web_client - resource_for_federation - resource_for_content_repo - http_client - db_pool - """ - def build_clock(self): return Clock() @@ -224,3 +180,55 @@ class HomeServer(BaseHomeServer): def build_pusherpool(self): return PusherPool(self) + + def build_http_client(self): + return MatrixFederationHttpClient(self) + + def build_db_pool(self): + name = self.db_config["name"] + + return adbapi.ConnectionPool( + name, + **self.db_config.get("args", {}) + ) + + +def _make_dependency_method(depname): + def _get(hs): + try: + return getattr(hs, depname) + except AttributeError: + pass + + try: + builder = getattr(hs, "build_%s" % (depname)) + except AttributeError: + builder = None + + if builder: + # Prevent cyclic dependencies from deadlocking + if depname in hs._building: + raise ValueError("Cyclic dependency while building %s" % ( + depname, + )) + hs._building[depname] = 1 + + dep = builder() + setattr(hs, depname, dep) + + del hs._building[depname] + + return dep + + raise NotImplementedError( + "%s has no %s nor a builder for it" % ( + type(hs).__name__, depname, + ) + ) + + setattr(HomeServer, "get_%s" % (depname), _get) + + +# Build magic accessors for every dependency +for depname in HomeServer.DEPENDENCIES: + _make_dependency_method(depname) diff --git a/tests/federation/__init__.py b/tests/federation/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/federation/test_federation.py b/tests/federation/test_federation.py deleted file mode 100644 index f2c2ee412..000000000 --- a/tests/federation/test_federation.py +++ /dev/null @@ -1,303 +0,0 @@ -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# trial imports -from twisted.internet import defer -from tests import unittest - -# python imports -from mock import Mock, ANY - -from ..utils import MockHttpResource, MockClock, setup_test_homeserver - -from synapse.federation import initialize_http_replication -from synapse.events import FrozenEvent - - -def make_pdu(prev_pdus=[], **kwargs): - """Provide some default fields for making a PduTuple.""" - pdu_fields = { - "state_key": None, - "prev_events": prev_pdus, - } - pdu_fields.update(kwargs) - - return FrozenEvent(pdu_fields) - - -class FederationTestCase(unittest.TestCase): - @defer.inlineCallbacks - def setUp(self): - self.mock_resource = MockHttpResource() - self.mock_http_client = Mock(spec=[ - "get_json", - "put_json", - ]) - self.mock_persistence = Mock(spec=[ - "prep_send_transaction", - "delivered_txn", - "get_received_txn_response", - "set_received_txn_response", - "get_destination_retry_timings", - "get_auth_chain", - ]) - self.mock_persistence.get_received_txn_response.return_value = ( - defer.succeed(None) - ) - - retry_timings_res = { - "destination": "", - "retry_last_ts": 0, - "retry_interval": 0, - } - self.mock_persistence.get_destination_retry_timings.return_value = ( - defer.succeed(retry_timings_res) - ) - self.mock_persistence.get_auth_chain.return_value = [] - self.clock = MockClock() - hs = yield setup_test_homeserver( - resource_for_federation=self.mock_resource, - http_client=self.mock_http_client, - datastore=self.mock_persistence, - clock=self.clock, - keyring=Mock(), - ) - self.federation = initialize_http_replication(hs) - self.distributor = hs.get_distributor() - - @defer.inlineCallbacks - def test_get_state(self): - mock_handler = Mock(spec=[ - "get_state_for_pdu", - ]) - - self.federation.set_handler(mock_handler) - - mock_handler.get_state_for_pdu.return_value = defer.succeed([]) - - # Empty context initially - (code, response) = yield self.mock_resource.trigger( - "GET", - "/_matrix/federation/v1/state/my-context/", - None - ) - self.assertEquals(200, code) - self.assertFalse(response["pdus"]) - - # Now lets give the context some state - mock_handler.get_state_for_pdu.return_value = ( - defer.succeed([ - make_pdu( - event_id="the-pdu-id", - origin="red", - user_id="@a:red", - room_id="my-context", - type="m.topic", - origin_server_ts=123456789000, - depth=1, - content={"topic": "The topic"}, - state_key="", - power_level=1000, - prev_state="last-pdu-id", - ), - ]) - ) - - (code, response) = yield self.mock_resource.trigger( - "GET", - "/_matrix/federation/v1/state/my-context/", - None - ) - self.assertEquals(200, code) - self.assertEquals(1, len(response["pdus"])) - - @defer.inlineCallbacks - def test_get_pdu(self): - mock_handler = Mock(spec=[ - "get_persisted_pdu", - ]) - - self.federation.set_handler(mock_handler) - - mock_handler.get_persisted_pdu.return_value = ( - defer.succeed(None) - ) - - (code, response) = yield self.mock_resource.trigger( - "GET", - "/_matrix/federation/v1/event/abc123def456/", - None - ) - self.assertEquals(404, code) - - # Now insert such a PDU - mock_handler.get_persisted_pdu.return_value = ( - defer.succeed( - make_pdu( - event_id="abc123def456", - origin="red", - user_id="@a:red", - room_id="my-context", - type="m.text", - origin_server_ts=123456789001, - depth=1, - content={"text": "Here is the message"}, - ) - ) - ) - - (code, response) = yield self.mock_resource.trigger( - "GET", - "/_matrix/federation/v1/event/abc123def456/", - None - ) - self.assertEquals(200, code) - self.assertEquals(1, len(response["pdus"])) - self.assertEquals("m.text", response["pdus"][0]["type"]) - - @defer.inlineCallbacks - def test_send_pdu(self): - self.mock_http_client.put_json.return_value = defer.succeed( - (200, "OK") - ) - - pdu = make_pdu( - event_id="abc123def456", - origin="red", - user_id="@a:red", - room_id="my-context", - type="m.text", - origin_server_ts=123456789001, - depth=1, - content={"text": "Here is the message"}, - ) - - yield self.federation.send_pdu(pdu, ["remote"]) - - self.mock_http_client.put_json.assert_called_with( - "remote", - path="/_matrix/federation/v1/send/1000000/", - data={ - "origin_server_ts": 1000000, - "origin": "test", - "pdus": [ - pdu.get_pdu_json(), - ], - 'pdu_failures': [], - }, - json_data_callback=ANY, - long_retries=True, - ) - - @defer.inlineCallbacks - def test_send_edu(self): - self.mock_http_client.put_json.return_value = defer.succeed( - (200, "OK") - ) - - yield self.federation.send_edu( - destination="remote", - edu_type="m.test", - content={"testing": "content here"}, - ) - - # MockClock ensures we can guess these timestamps - self.mock_http_client.put_json.assert_called_with( - "remote", - path="/_matrix/federation/v1/send/1000000/", - data={ - "origin": "test", - "origin_server_ts": 1000000, - "pdus": [], - "edus": [ - { - "edu_type": "m.test", - "content": {"testing": "content here"}, - } - ], - 'pdu_failures': [], - }, - json_data_callback=ANY, - long_retries=True, - ) - - @defer.inlineCallbacks - def test_recv_edu(self): - recv_observer = Mock() - recv_observer.return_value = defer.succeed(()) - - self.federation.register_edu_handler("m.test", recv_observer) - - yield self.mock_resource.trigger( - "PUT", - "/_matrix/federation/v1/send/1001000/", - """{ - "origin": "remote", - "origin_server_ts": 1001000, - "pdus": [], - "edus": [ - { - "origin": "remote", - "destination": "test", - "edu_type": "m.test", - "content": {"testing": "reply here"} - } - ] - }""" - ) - - recv_observer.assert_called_with( - "remote", {"testing": "reply here"} - ) - - @defer.inlineCallbacks - def test_send_query(self): - self.mock_http_client.get_json.return_value = defer.succeed( - {"your": "response"} - ) - - response = yield self.federation.make_query( - destination="remote", - query_type="a-question", - args={"one": "1", "two": "2"}, - ) - - self.assertEquals({"your": "response"}, response) - - self.mock_http_client.get_json.assert_called_with( - destination="remote", - path="/_matrix/federation/v1/query/a-question", - args={"one": "1", "two": "2"}, - retry_on_dns_fail=True, - ) - - @defer.inlineCallbacks - def test_recv_query(self): - recv_handler = Mock() - recv_handler.return_value = defer.succeed({"another": "response"}) - - self.federation.register_query_handler("a-question", recv_handler) - - code, response = yield self.mock_resource.trigger( - "GET", - "/_matrix/federation/v1/query/a-question?three=3&four=4", - None - ) - - self.assertEquals(200, code) - self.assertEquals({"another": "response"}, response) - - recv_handler.assert_called_with( - {"three": "3", "four": "4"} - ) diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 90b911f87..8d7cfd79a 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -280,6 +280,15 @@ class PresenceEventStreamTestCase(unittest.TestCase): } EventSources.SOURCE_TYPES["presence"] = PresenceEventSource + clock = Mock(spec=[ + "call_later", + "cancel_call_later", + "time_msec", + "looping_call", + ]) + + clock.time_msec.return_value = 1000000 + hs = yield setup_test_homeserver( http_client=None, resource_for_client=self.mock_resource, @@ -289,16 +298,9 @@ class PresenceEventStreamTestCase(unittest.TestCase): "get_presence_list", "get_rooms_for_user", ]), - clock=Mock(spec=[ - "call_later", - "cancel_call_later", - "time_msec", - "looping_call", - ]), + clock=clock, ) - hs.get_clock().time_msec.return_value = 1000000 - def _get_user_by_req(req=None, allow_guest=False): return Requester(UserID.from_string(myid), "", False) diff --git a/tests/test_types.py b/tests/test_types.py index b9534329e..24d61dbe5 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -16,10 +16,10 @@ from tests import unittest from synapse.api.errors import SynapseError -from synapse.server import BaseHomeServer +from synapse.server import HomeServer from synapse.types import UserID, RoomAlias -mock_homeserver = BaseHomeServer(hostname="my.domain") +mock_homeserver = HomeServer(hostname="my.domain") class UserIDTestCase(unittest.TestCase): @@ -34,7 +34,6 @@ class UserIDTestCase(unittest.TestCase): with self.assertRaises(SynapseError): UserID.from_string("") - def test_build(self): user = UserID("5678efgh", "my.domain") diff --git a/tests/utils.py b/tests/utils.py index 358b5b72b..d75d492cb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -19,6 +19,8 @@ from synapse.api.constants import EventTypes from synapse.storage.prepare_database import prepare_database from synapse.storage.engines import create_engine from synapse.server import HomeServer +from synapse.federation.transport import server +from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.logcontext import LoggingContext @@ -80,6 +82,22 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): hs.build_handlers = swap_out_hash_for_testing(hs.build_handlers) + fed = kargs.get("resource_for_federation", None) + if fed: + server.register_servlets( + hs, + resource=fed, + authenticator=server.Authenticator(hs), + ratelimiter=FederationRateLimiter( + hs.get_clock(), + window_size=hs.config.federation_rc_window_size, + sleep_limit=hs.config.federation_rc_sleep_limit, + sleep_msec=hs.config.federation_rc_sleep_delay, + reject_limit=hs.config.federation_rc_reject_limit, + concurrent_requests=hs.config.federation_rc_concurrent + ), + ) + defer.returnValue(hs)