Merge pull request #8757 from matrix-org/rav/pass_site_to_make_request

Pass a Site into `make_request`
This commit is contained in:
Richard van der Hoff 2020-11-16 18:22:24 +00:00 committed by GitHub
commit 3dc1871219
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 228 additions and 88 deletions

1
changelog.d/8757.misc Normal file
View File

@ -0,0 +1 @@
Refactor test utilities for injecting HTTP requests.

View File

@ -15,6 +15,7 @@
from synapse.app.generic_worker import GenericWorkerServer from synapse.app.generic_worker import GenericWorkerServer
from tests.server import make_request, render
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -55,10 +56,10 @@ class FrontendProxyTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen # Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1) self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1] site = self.reactor.tcpServers[0][1]
self.resource = site.resource.children[b"_matrix"].children[b"client"] resource = site.resource.children[b"_matrix"].children[b"client"]
request, channel = self.make_request("PUT", "presence/a/status") request, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
self.render(request) render(request, resource, self.reactor)
# 400 + unrecognised, because nothing is registered # 400 + unrecognised, because nothing is registered
self.assertEqual(channel.code, 400) self.assertEqual(channel.code, 400)
@ -77,10 +78,10 @@ class FrontendProxyTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen # Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1) self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1] site = self.reactor.tcpServers[0][1]
self.resource = site.resource.children[b"_matrix"].children[b"client"] resource = site.resource.children[b"_matrix"].children[b"client"]
request, channel = self.make_request("PUT", "presence/a/status") request, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
self.render(request) render(request, resource, self.reactor)
# 401, because the stub servlet still checks authentication # 401, because the stub servlet still checks authentication
self.assertEqual(channel.code, 401) self.assertEqual(channel.code, 401)

View File

@ -20,6 +20,7 @@ from synapse.app.generic_worker import GenericWorkerServer
from synapse.app.homeserver import SynapseHomeServer from synapse.app.homeserver import SynapseHomeServer
from synapse.config.server import parse_listener_def from synapse.config.server import parse_listener_def
from tests.server import make_request, render
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -66,16 +67,16 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen # Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1] site = self.reactor.tcpServers[0][1]
try: try:
self.resource = site.resource.children[b"_matrix"].children[b"federation"] resource = site.resource.children[b"_matrix"].children[b"federation"]
except KeyError: except KeyError:
if expectation == "no_resource": if expectation == "no_resource":
return return
raise raise
request, channel = self.make_request( request, channel = make_request(
"GET", "/_matrix/federation/v1/openid/userinfo" self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
) )
self.render(request) render(request, resource, self.reactor)
self.assertEqual(channel.code, 401) self.assertEqual(channel.code, 401)
@ -115,15 +116,15 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
# Grab the resource from the site that was told to listen # Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1] site = self.reactor.tcpServers[0][1]
try: try:
self.resource = site.resource.children[b"_matrix"].children[b"federation"] resource = site.resource.children[b"_matrix"].children[b"federation"]
except KeyError: except KeyError:
if expectation == "no_resource": if expectation == "no_resource":
return return
raise raise
request, channel = self.make_request( request, channel = make_request(
"GET", "/_matrix/federation/v1/openid/userinfo" self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
) )
self.render(request) render(request, resource, self.reactor)
self.assertEqual(channel.code, 401) self.assertEqual(channel.code, 401)

View File

@ -17,6 +17,7 @@
from synapse.http.additional_resource import AdditionalResource from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import respond_with_json from synapse.http.server import respond_with_json
from tests.server import FakeSite, make_request, render
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -43,20 +44,20 @@ class AdditionalResourceTests(HomeserverTestCase):
def test_async(self): def test_async(self):
handler = _AsyncTestCustomEndpoint({}, None).handle_request handler = _AsyncTestCustomEndpoint({}, None).handle_request
self.resource = AdditionalResource(self.hs, handler) resource = AdditionalResource(self.hs, handler)
request, channel = self.make_request("GET", "/") request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
self.render(request) render(request, resource, self.reactor)
self.assertEqual(request.code, 200) self.assertEqual(request.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_async"}) self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
def test_sync(self): def test_sync(self):
handler = _SyncTestCustomEndpoint({}, None).handle_request handler = _SyncTestCustomEndpoint({}, None).handle_request
self.resource = AdditionalResource(self.hs, handler) resource = AdditionalResource(self.hs, handler)
request, channel = self.make_request("GET", "/") request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
self.render(request) render(request, resource, self.reactor)
self.assertEqual(request.code, 200) self.assertEqual(request.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_sync"}) self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})

View File

@ -20,7 +20,7 @@ from synapse.rest.client.v2_alpha import register
from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
from tests.server import FakeChannel from tests.server import FakeChannel, make_request
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -46,8 +46,11 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
"""Test that registration works when using a single client reader worker. """Test that registration works when using a single client reader worker.
""" """
worker_hs = self.make_worker_hs("synapse.app.client_reader") worker_hs = self.make_worker_hs("synapse.app.client_reader")
site = self._hs_to_site[worker_hs]
request_1, channel_1 = self.make_request( request_1, channel_1 = make_request(
self.reactor,
site,
"POST", "POST",
"register", "register",
{"username": "user", "type": "m.login.password", "password": "bar"}, {"username": "user", "type": "m.login.password", "password": "bar"},
@ -59,8 +62,12 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
session = channel_1.json_body["session"] session = channel_1.json_body["session"]
# also complete the dummy auth # also complete the dummy auth
request_2, channel_2 = self.make_request( request_2, channel_2 = make_request(
"POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} self.reactor,
site,
"POST",
"register",
{"auth": {"session": session, "type": "m.login.dummy"}},
) # type: SynapseRequest, FakeChannel ) # type: SynapseRequest, FakeChannel
self.render_on_worker(worker_hs, request_2) self.render_on_worker(worker_hs, request_2)
self.assertEqual(request_2.code, 200) self.assertEqual(request_2.code, 200)
@ -74,7 +81,10 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
worker_hs_1 = self.make_worker_hs("synapse.app.client_reader") worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
worker_hs_2 = self.make_worker_hs("synapse.app.client_reader") worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
request_1, channel_1 = self.make_request( site_1 = self._hs_to_site[worker_hs_1]
request_1, channel_1 = make_request(
self.reactor,
site_1,
"POST", "POST",
"register", "register",
{"username": "user", "type": "m.login.password", "password": "bar"}, {"username": "user", "type": "m.login.password", "password": "bar"},
@ -86,8 +96,13 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
session = channel_1.json_body["session"] session = channel_1.json_body["session"]
# also complete the dummy auth # also complete the dummy auth
request_2, channel_2 = self.make_request( site_2 = self._hs_to_site[worker_hs_2]
"POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}} request_2, channel_2 = make_request(
self.reactor,
site_2,
"POST",
"register",
{"auth": {"session": session, "type": "m.login.dummy"}},
) # type: SynapseRequest, FakeChannel ) # type: SynapseRequest, FakeChannel
self.render_on_worker(worker_hs_2, request_2) self.render_on_worker(worker_hs_2, request_2)
self.assertEqual(request_2.code, 200) self.assertEqual(request_2.code, 200)

View File

@ -28,7 +28,7 @@ from synapse.server import HomeServer
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeChannel, FakeTransport from tests.server import FakeChannel, FakeSite, FakeTransport, make_request
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -67,14 +67,16 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
The channel for the *client* request and the *outbound* request for The channel for the *client* request and the *outbound* request for
the media which the caller should respond to. the media which the caller should respond to.
""" """
resource = hs.get_media_repository_resource().children[b"download"]
request, channel = self.make_request( request, channel = make_request(
self.reactor,
FakeSite(resource),
"GET", "GET",
"/{}/{}".format(target, media_id), "/{}/{}".format(target, media_id),
shorthand=False, shorthand=False,
access_token=self.access_token, access_token=self.access_token,
) )
request.render(hs.get_media_repository_resource().children[b"download"]) request.render(resource)
self.pump() self.pump()
clients = self.reactor.tcpClients clients = self.reactor.tcpClients

View File

@ -22,6 +22,7 @@ from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import sync from synapse.rest.client.v2_alpha import sync
from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
from tests.utils import USE_POSTGRES_FOR_TESTS from tests.utils import USE_POSTGRES_FOR_TESTS
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -148,6 +149,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
sync_hs = self.make_worker_hs( sync_hs = self.make_worker_hs(
"synapse.app.generic_worker", {"worker_name": "sync"}, "synapse.app.generic_worker", {"worker_name": "sync"},
) )
sync_hs_site = self._hs_to_site[sync_hs]
# Specially selected room IDs that get persisted on different workers. # Specially selected room IDs that get persisted on different workers.
room_id1 = "!foo:test" room_id1 = "!foo:test"
@ -178,7 +180,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
) )
# Do an initial sync so that we're up to date. # Do an initial sync so that we're up to date.
request, channel = self.make_request("GET", "/sync", access_token=access_token) request, channel = make_request(
self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token
)
self.render_on_worker(sync_hs, request) self.render_on_worker(sync_hs, request)
next_batch = channel.json_body["next_batch"] next_batch = channel.json_body["next_batch"]
@ -203,8 +207,12 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# Check that syncing still gets the new event, despite the gap in the # Check that syncing still gets the new event, despite the gap in the
# stream IDs. # stream IDs.
request, channel = self.make_request( request, channel = make_request(
"GET", "/sync?since={}".format(next_batch), access_token=access_token self.reactor,
sync_hs_site,
"GET",
"/sync?since={}".format(next_batch),
access_token=access_token,
) )
self.render_on_worker(sync_hs, request) self.render_on_worker(sync_hs, request)
@ -230,7 +238,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token) response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token)
first_event_in_room2 = response["event_id"] first_event_in_room2 = response["event_id"]
request, channel = self.make_request( request, channel = make_request(
self.reactor,
sync_hs_site,
"GET", "GET",
"/sync?since={}".format(vector_clock_token), "/sync?since={}".format(vector_clock_token),
access_token=access_token, access_token=access_token,
@ -254,8 +264,12 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token) self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token)
self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token) self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token)
request, channel = self.make_request( request, channel = make_request(
"GET", "/sync?since={}".format(next_batch), access_token=access_token self.reactor,
sync_hs_site,
"GET",
"/sync?since={}".format(next_batch),
access_token=access_token,
) )
self.render_on_worker(sync_hs, request) self.render_on_worker(sync_hs, request)
@ -269,7 +283,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# Paginating back in the first room should not produce any results, as # Paginating back in the first room should not produce any results, as
# no events have happened in it. This tests that we are correctly # no events have happened in it. This tests that we are correctly
# filtering results based on the vector clock portion. # filtering results based on the vector clock portion.
request, channel = self.make_request( request, channel = make_request(
self.reactor,
sync_hs_site,
"GET", "GET",
"/rooms/{}/messages?from={}&to={}&dir=b".format( "/rooms/{}/messages?from={}&to={}&dir=b".format(
room_id1, prev_batch1, vector_clock_token room_id1, prev_batch1, vector_clock_token
@ -281,7 +297,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# Paginating back on the second room should produce the first event # Paginating back on the second room should produce the first event
# again. This tests that pagination isn't completely broken. # again. This tests that pagination isn't completely broken.
request, channel = self.make_request( request, channel = make_request(
self.reactor,
sync_hs_site,
"GET", "GET",
"/rooms/{}/messages?from={}&to={}&dir=b".format( "/rooms/{}/messages?from={}&to={}&dir=b".format(
room_id2, prev_batch2, vector_clock_token room_id2, prev_batch2, vector_clock_token
@ -295,7 +313,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
) )
# Paginating forwards should give the same results # Paginating forwards should give the same results
request, channel = self.make_request( request, channel = make_request(
self.reactor,
sync_hs_site,
"GET", "GET",
"/rooms/{}/messages?from={}&to={}&dir=f".format( "/rooms/{}/messages?from={}&to={}&dir=f".format(
room_id1, vector_clock_token, prev_batch1 room_id1, vector_clock_token, prev_batch1
@ -305,7 +325,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.render_on_worker(sync_hs, request) self.render_on_worker(sync_hs, request)
self.assertListEqual([], channel.json_body["chunk"]) self.assertListEqual([], channel.json_body["chunk"])
request, channel = self.make_request( request, channel = make_request(
self.reactor,
sync_hs_site,
"GET", "GET",
"/rooms/{}/messages?from={}&to={}&dir=f".format( "/rooms/{}/messages?from={}&to={}&dir=f".format(
room_id2, vector_clock_token, prev_batch2, room_id2, vector_clock_token, prev_batch2,

View File

@ -30,6 +30,7 @@ from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import groups from synapse.rest.client.v2_alpha import groups
from tests import unittest from tests import unittest
from tests.server import FakeSite, make_request
class VersionTestCase(unittest.HomeserverTestCase): class VersionTestCase(unittest.HomeserverTestCase):
@ -222,8 +223,13 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
def _ensure_quarantined(self, admin_user_tok, server_and_media_id): def _ensure_quarantined(self, admin_user_tok, server_and_media_id):
"""Ensure a piece of media is quarantined when trying to access it.""" """Ensure a piece of media is quarantined when trying to access it."""
request, channel = self.make_request( request, channel = make_request(
"GET", server_and_media_id, shorthand=False, access_token=admin_user_tok, self.reactor,
FakeSite(self.download_resource),
"GET",
server_and_media_id,
shorthand=False,
access_token=admin_user_tok,
) )
request.render(self.download_resource) request.render(self.download_resource)
self.pump(1.0) self.pump(1.0)
@ -287,7 +293,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
server_name, media_id = server_name_and_media_id.split("/") server_name, media_id = server_name_and_media_id.split("/")
# Attempt to access the media # Attempt to access the media
request, channel = self.make_request( request, channel = make_request(
self.reactor,
FakeSite(self.download_resource),
"GET", "GET",
server_name_and_media_id, server_name_and_media_id,
shorthand=False, shorthand=False,
@ -462,7 +470,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self._ensure_quarantined(admin_user_tok, server_and_media_id_1) self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
# Attempt to access each piece of media # Attempt to access each piece of media
request, channel = self.make_request( request, channel = make_request(
self.reactor,
FakeSite(self.download_resource),
"GET", "GET",
server_and_media_id_2, server_and_media_id_2,
shorthand=False, shorthand=False,

View File

@ -23,6 +23,7 @@ from synapse.rest.client.v1 import login, profile, room
from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.rest.media.v1.filepath import MediaFilePaths
from tests import unittest from tests import unittest
from tests.server import FakeSite, make_request
class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
@ -124,7 +125,9 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertEqual(server_name, self.server_name) self.assertEqual(server_name, self.server_name)
# Attempt to access media # Attempt to access media
request, channel = self.make_request( request, channel = make_request(
self.reactor,
FakeSite(download_resource),
"GET", "GET",
server_and_media_id, server_and_media_id,
shorthand=False, shorthand=False,
@ -161,7 +164,9 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
) )
# Attempt to access media # Attempt to access media
request, channel = self.make_request( request, channel = make_request(
self.reactor,
FakeSite(download_resource),
"GET", "GET",
server_and_media_id, server_and_media_id,
shorthand=False, shorthand=False,
@ -535,7 +540,9 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
media_id = server_and_media_id.split("/")[1] media_id = server_and_media_id.split("/")[1]
local_path = self.filepaths.local_media_filepath(media_id) local_path = self.filepaths.local_media_filepath(media_id)
request, channel = self.make_request( request, channel = make_request(
self.reactor,
FakeSite(download_resource),
"GET", "GET",
server_and_media_id, server_and_media_id,
shorthand=False, shorthand=False,

View File

@ -21,7 +21,7 @@ from synapse.rest.client.v1 import login, room
from synapse.rest.consent import consent_resource from synapse.rest.consent import consent_resource
from tests import unittest from tests import unittest
from tests.server import render from tests.server import FakeSite, make_request, render
class ConsentResourceTestCase(unittest.HomeserverTestCase): class ConsentResourceTestCase(unittest.HomeserverTestCase):
@ -61,7 +61,9 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
def test_render_public_consent(self): def test_render_public_consent(self):
"""You can observe the terms form without specifying a user""" """You can observe the terms form without specifying a user"""
resource = consent_resource.ConsentResource(self.hs) resource = consent_resource.ConsentResource(self.hs)
request, channel = self.make_request("GET", "/consent?v=1", shorthand=False) request, channel = make_request(
self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False
)
render(request, resource, self.reactor) render(request, resource, self.reactor)
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
@ -81,8 +83,13 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
uri_builder.build_user_consent_uri(user_id).replace("_matrix/", "") uri_builder.build_user_consent_uri(user_id).replace("_matrix/", "")
+ "&u=user" + "&u=user"
) )
request, channel = self.make_request( request, channel = make_request(
"GET", consent_uri, access_token=access_token, shorthand=False self.reactor,
FakeSite(resource),
"GET",
consent_uri,
access_token=access_token,
shorthand=False,
) )
render(request, resource, self.reactor) render(request, resource, self.reactor)
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
@ -92,7 +99,9 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
self.assertEqual(consented, "False") self.assertEqual(consented, "False")
# POST to the consent page, saying we've agreed # POST to the consent page, saying we've agreed
request, channel = self.make_request( request, channel = make_request(
self.reactor,
FakeSite(resource),
"POST", "POST",
consent_uri + "&v=" + version, consent_uri + "&v=" + version,
access_token=access_token, access_token=access_token,
@ -103,8 +112,13 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
# Fetch the consent page, to get the consent version -- it should have # Fetch the consent page, to get the consent version -- it should have
# changed # changed
request, channel = self.make_request( request, channel = make_request(
"GET", consent_uri, access_token=access_token, shorthand=False self.reactor,
FakeSite(resource),
"GET",
consent_uri,
access_token=access_token,
shorthand=False,
) )
render(request, resource, self.reactor) render(request, resource, self.reactor)
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)

View File

@ -23,10 +23,11 @@ from typing import Any, Dict, Optional
import attr import attr
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import Site
from synapse.api.constants import Membership from synapse.api.constants import Membership
from tests.server import make_request, render from tests.server import FakeSite, make_request, render
@attr.s @attr.s
@ -36,7 +37,7 @@ class RestHelper:
""" """
hs = attr.ib() hs = attr.ib()
resource = attr.ib() site = attr.ib(type=Site)
auth_user_id = attr.ib() auth_user_id = attr.ib()
def create_room_as( def create_room_as(
@ -52,9 +53,13 @@ class RestHelper:
path = path + "?access_token=%s" % tok path = path + "?access_token=%s" % tok
request, channel = make_request( request, channel = make_request(
self.hs.get_reactor(), "POST", path, json.dumps(content).encode("utf8") self.hs.get_reactor(),
self.site,
"POST",
path,
json.dumps(content).encode("utf8"),
) )
render(request, self.resource, self.hs.get_reactor()) render(request, self.site.resource, self.hs.get_reactor())
assert channel.result["code"] == b"%d" % expect_code, channel.result assert channel.result["code"] == b"%d" % expect_code, channel.result
self.auth_user_id = temp_id self.auth_user_id = temp_id
@ -125,10 +130,14 @@ class RestHelper:
data.update(extra_data) data.update(extra_data)
request, channel = make_request( request, channel = make_request(
self.hs.get_reactor(), "PUT", path, json.dumps(data).encode("utf8") self.hs.get_reactor(),
self.site,
"PUT",
path,
json.dumps(data).encode("utf8"),
) )
render(request, self.resource, self.hs.get_reactor()) render(request, self.site.resource, self.hs.get_reactor())
assert int(channel.result["code"]) == expect_code, ( assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r" "Expected: %d, got: %d, resp: %r"
@ -158,9 +167,13 @@ class RestHelper:
path = path + "?access_token=%s" % tok path = path + "?access_token=%s" % tok
request, channel = make_request( request, channel = make_request(
self.hs.get_reactor(), "PUT", path, json.dumps(content).encode("utf8") self.hs.get_reactor(),
self.site,
"PUT",
path,
json.dumps(content).encode("utf8"),
) )
render(request, self.resource, self.hs.get_reactor()) render(request, self.site.resource, self.hs.get_reactor())
assert int(channel.result["code"]) == expect_code, ( assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r" "Expected: %d, got: %d, resp: %r"
@ -210,9 +223,11 @@ class RestHelper:
if body is not None: if body is not None:
content = json.dumps(body).encode("utf8") content = json.dumps(body).encode("utf8")
request, channel = make_request(self.hs.get_reactor(), method, path, content) request, channel = make_request(
self.hs.get_reactor(), self.site, method, path, content
)
render(request, self.resource, self.hs.get_reactor()) render(request, self.site.resource, self.hs.get_reactor())
assert int(channel.result["code"]) == expect_code, ( assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r" "Expected: %d, got: %d, resp: %r"
@ -297,6 +312,7 @@ class RestHelper:
path = "/_matrix/media/r0/upload?filename=%s" % (filename,) path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
request, channel = make_request( request, channel = make_request(
self.hs.get_reactor(), self.hs.get_reactor(),
FakeSite(resource),
"POST", "POST",
path, path,
content=image_data, content=image_data,

View File

@ -31,6 +31,7 @@ from synapse.rest.client.v2_alpha import account, register
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from tests import unittest from tests import unittest
from tests.server import FakeSite, make_request
from tests.unittest import override_config from tests.unittest import override_config
@ -255,9 +256,16 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
path = link.replace("https://example.com", "") path = link.replace("https://example.com", "")
# Load the password reset confirmation page # Load the password reset confirmation page
request, channel = self.make_request("GET", path, shorthand=False) request, channel = make_request(
self.reactor,
FakeSite(self.submit_token_resource),
"GET",
path,
shorthand=False,
)
request.render(self.submit_token_resource) request.render(self.submit_token_resource)
self.pump() self.pump()
self.assertEquals(200, channel.code, channel.result) self.assertEquals(200, channel.code, channel.result)
# Now POST to the same endpoint, mimicking the same behaviour as clicking the # Now POST to the same endpoint, mimicking the same behaviour as clicking the
@ -271,7 +279,9 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
form_args.append(arg) form_args.append(arg)
# Confirm the password reset # Confirm the password reset
request, channel = self.make_request( request, channel = make_request(
self.reactor,
FakeSite(self.submit_token_resource),
"POST", "POST",
path, path,
content=urlencode(form_args).encode("utf8"), content=urlencode(form_args).encode("utf8"),

View File

@ -36,6 +36,7 @@ from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
from tests import unittest from tests import unittest
from tests.server import FakeSite, make_request
class MediaStorageTests(unittest.HomeserverTestCase): class MediaStorageTests(unittest.HomeserverTestCase):
@ -227,7 +228,13 @@ class MediaRepoTests(unittest.HomeserverTestCase):
def _req(self, content_disposition): def _req(self, content_disposition):
request, channel = self.make_request("GET", self.media_id, shorthand=False) request, channel = make_request(
self.reactor,
FakeSite(self.download_resource),
"GET",
self.media_id,
shorthand=False,
)
request.render(self.download_resource) request.render(self.download_resource)
self.pump() self.pump()
@ -317,8 +324,12 @@ class MediaRepoTests(unittest.HomeserverTestCase):
def _test_thumbnail(self, method, expected_body, expected_found): def _test_thumbnail(self, method, expected_body, expected_found):
params = "?width=32&height=32&method=" + method params = "?width=32&height=32&method=" + method
request, channel = self.make_request( request, channel = make_request(
"GET", self.media_id + params, shorthand=False self.reactor,
FakeSite(self.thumbnail_resource),
"GET",
self.media_id + params,
shorthand=False,
) )
request.render(self.thumbnail_resource) request.render(self.thumbnail_resource)
self.pump() self.pump()

View File

@ -21,6 +21,7 @@ from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
from twisted.web.http import unquote from twisted.web.http import unquote
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web.resource import IResource
from twisted.web.server import Site from twisted.web.server import Site
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
@ -147,9 +148,21 @@ class FakeSite:
site_tag = "test" site_tag = "test"
access_logger = logging.getLogger("synapse.access.http.fake") access_logger = logging.getLogger("synapse.access.http.fake")
def __init__(self, resource: IResource):
"""
Args:
resource: the resource to be used for rendering all requests
"""
self._resource = resource
def getResourceFor(self, request):
return self._resource
def make_request( def make_request(
reactor, reactor,
site: Site,
method, method,
path, path,
content=b"", content=b"",
@ -167,6 +180,8 @@ def make_request(
content, and return the Request and the Channel underneath. content, and return the Request and the Channel underneath.
Args: Args:
site: The twisted Site to associate with the Channel
method (bytes/unicode): The HTTP request method ("verb"). method (bytes/unicode): The HTTP request method ("verb").
path (bytes/unicode): The HTTP path, suitably URL encoded (e.g. path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
escaped UTF-8 & spaces and such). escaped UTF-8 & spaces and such).
@ -202,10 +217,11 @@ def make_request(
if not path.startswith(b"/"): if not path.startswith(b"/"):
path = b"/" + path path = b"/" + path
if isinstance(content, dict):
content = json.dumps(content).encode("utf8")
if isinstance(content, str): if isinstance(content, str):
content = content.encode("utf8") content = content.encode("utf8")
site = FakeSite()
channel = FakeChannel(site, reactor) channel = FakeChannel(site, reactor)
req = request(channel) req = request(channel)

View File

@ -414,6 +414,7 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase):
request, channel = make_request( request, channel = make_request(
self.reactor, self.reactor,
self.site,
"GET", "GET",
"/_matrix/client/r0/admin/users/" + self.user_id, "/_matrix/client/r0/admin/users/" + self.user_id,
access_token=access_token, access_token=access_token,

View File

@ -26,6 +26,7 @@ from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import ( from tests.server import (
FakeSite,
ThreadedMemoryReactorClock, ThreadedMemoryReactorClock,
make_request, make_request,
render, render,
@ -62,7 +63,7 @@ class JsonResourceTests(unittest.TestCase):
) )
request, channel = make_request( request, channel = make_request(
self.reactor, b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83" self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
) )
render(request, res, self.reactor) render(request, res, self.reactor)
@ -83,7 +84,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
) )
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") request, channel = make_request(
self.reactor, FakeSite(res), b"GET", b"/_matrix/foo"
)
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"500") self.assertEqual(channel.result["code"], b"500")
@ -108,7 +111,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
) )
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") request, channel = make_request(
self.reactor, FakeSite(res), b"GET", b"/_matrix/foo"
)
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"500") self.assertEqual(channel.result["code"], b"500")
@ -127,7 +132,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
) )
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo") request, channel = make_request(
self.reactor, FakeSite(res), b"GET", b"/_matrix/foo"
)
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.result["code"], b"403")
@ -150,7 +157,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
) )
request, channel = make_request(self.reactor, b"GET", b"/_matrix/foobar") request, channel = make_request(
self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar"
)
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.result["code"], b"400")
@ -173,7 +182,9 @@ class JsonResourceTests(unittest.TestCase):
) )
# The path was registered as GET, but this is a HEAD request. # The path was registered as GET, but this is a HEAD request.
request, channel = make_request(self.reactor, b"HEAD", b"/_matrix/foo") request, channel = make_request(
self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo"
)
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.result["code"], b"200")
@ -196,9 +207,6 @@ class OptionsResourceTests(unittest.TestCase):
def _make_request(self, method, path): def _make_request(self, method, path):
"""Create a request from the method/path and return a channel with the response.""" """Create a request from the method/path and return a channel with the response."""
request, channel = make_request(self.reactor, method, path, shorthand=False)
request.prepath = [] # This doesn't get set properly by make_request.
# Create a site and query for the resource. # Create a site and query for the resource.
site = SynapseSite( site = SynapseSite(
"test", "test",
@ -207,6 +215,12 @@ class OptionsResourceTests(unittest.TestCase):
self.resource, self.resource,
"1.0", "1.0",
) )
request, channel = make_request(
self.reactor, site, method, path, shorthand=False
)
request.prepath = [] # This doesn't get set properly by make_request.
request.site = site request.site = site
resource = site.getResourceFor(request) resource = site.getResourceFor(request)
@ -284,7 +298,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource() res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback res.callback = callback
request, channel = make_request(self.reactor, b"GET", b"/path") request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.result["code"], b"200")
@ -303,7 +317,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource() res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback res.callback = callback
request, channel = make_request(self.reactor, b"GET", b"/path") request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"301") self.assertEqual(channel.result["code"], b"301")
@ -325,7 +339,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource() res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback res.callback = callback
request, channel = make_request(self.reactor, b"GET", b"/path") request, channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"304") self.assertEqual(channel.result["code"], b"304")
@ -345,7 +359,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource() res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback res.callback = callback
request, channel = make_request(self.reactor, b"HEAD", b"/path") request, channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path")
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.result["code"], b"200")

View File

@ -252,7 +252,7 @@ class HomeserverTestCase(TestCase):
from tests.rest.client.v1.utils import RestHelper from tests.rest.client.v1.utils import RestHelper
self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None)) self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None))
if hasattr(self, "user_id"): if hasattr(self, "user_id"):
if self.hijack_auth: if self.hijack_auth:
@ -425,11 +425,9 @@ class HomeserverTestCase(TestCase):
Returns: Returns:
Tuple[synapse.http.site.SynapseRequest, channel] Tuple[synapse.http.site.SynapseRequest, channel]
""" """
if isinstance(content, dict):
content = json.dumps(content).encode("utf8")
return make_request( return make_request(
self.reactor, self.reactor,
self.site,
method, method,
path, path,
content, content,