Make the http server handle coroutine-making REST servlets (#5475)

This commit is contained in:
Amber Brown 2019-06-29 17:06:55 +10:00 committed by GitHub
parent c7ff297dde
commit f40a7dc41f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 162 additions and 174 deletions

1
changelog.d/5475.misc Normal file
View File

@ -0,0 +1 @@
Synapse can now handle RestServlets that return coroutines.

View File

@ -16,10 +16,11 @@
import cgi import cgi
import collections import collections
import http.client
import logging import logging
import types
from six import PY3 import urllib
from six.moves import http_client, urllib from io import BytesIO
from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
@ -41,11 +42,6 @@ from synapse.api.errors import (
from synapse.util.caches import intern_dict from synapse.util.caches import intern_dict
from synapse.util.logcontext import preserve_fn from synapse.util.logcontext import preserve_fn
if PY3:
from io import BytesIO
else:
from cStringIO import StringIO as BytesIO
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
HTML_ERROR_TEMPLATE = """<!DOCTYPE html> HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
@ -75,10 +71,9 @@ def wrap_json_request_handler(h):
deferred fails with any other type of error we send a 500 reponse. deferred fails with any other type of error we send a 500 reponse.
""" """
@defer.inlineCallbacks async def wrapped_request_handler(self, request):
def wrapped_request_handler(self, request):
try: try:
yield h(self, request) await h(self, request)
except SynapseError as e: except SynapseError as e:
code = e.code code = e.code
logger.info("%s SynapseError: %s - %s", request, code, e.msg) logger.info("%s SynapseError: %s - %s", request, code, e.msg)
@ -142,10 +137,12 @@ def wrap_html_request_handler(h):
where "request" must be a SynapseRequest. where "request" must be a SynapseRequest.
""" """
def wrapped_request_handler(self, request): async def wrapped_request_handler(self, request):
d = defer.maybeDeferred(h, self, request) try:
d.addErrback(_return_html_error, request) return await h(self, request)
return d except Exception:
f = failure.Failure()
return _return_html_error(f, request)
return wrap_async_request_handler(wrapped_request_handler) return wrap_async_request_handler(wrapped_request_handler)
@ -171,7 +168,7 @@ def _return_html_error(f, request):
exc_info=(f.type, f.value, f.getTracebackObject()), exc_info=(f.type, f.value, f.getTracebackObject()),
) )
else: else:
code = http_client.INTERNAL_SERVER_ERROR code = http.client.INTERNAL_SERVER_ERROR
msg = "Internal server error" msg = "Internal server error"
logger.error( logger.error(
@ -201,10 +198,9 @@ def wrap_async_request_handler(h):
logged until the deferred completes. logged until the deferred completes.
""" """
@defer.inlineCallbacks async def wrapped_async_request_handler(self, request):
def wrapped_async_request_handler(self, request):
with request.processing(): with request.processing():
yield h(self, request) await h(self, request)
# we need to preserve_fn here, because the synchronous render method won't yield for # we need to preserve_fn here, because the synchronous render method won't yield for
# us (obviously) # us (obviously)
@ -270,12 +266,11 @@ class JsonResource(HttpServer, resource.Resource):
def render(self, request): def render(self, request):
""" This gets called by twisted every time someone sends us a request. """ This gets called by twisted every time someone sends us a request.
""" """
self._async_render(request) defer.ensureDeferred(self._async_render(request))
return NOT_DONE_YET return NOT_DONE_YET
@wrap_json_request_handler @wrap_json_request_handler
@defer.inlineCallbacks async def _async_render(self, request):
def _async_render(self, request):
""" This gets called from render() every time someone sends us a request. """ This gets called from render() every time someone sends us a request.
This checks if anyone has registered a callback for that method and This checks if anyone has registered a callback for that method and
path. path.
@ -292,26 +287,19 @@ class JsonResource(HttpServer, resource.Resource):
# Now trigger the callback. If it returns a response, we send it # Now trigger the callback. If it returns a response, we send it
# here. If it throws an exception, that is handled by the wrapper # here. If it throws an exception, that is handled by the wrapper
# installed by @request_handler. # installed by @request_handler.
def _unquote(s):
if PY3:
# On Python 3, unquote is unicode -> unicode
return urllib.parse.unquote(s)
else:
# On Python 2, unquote is bytes -> bytes We need to encode the
# URL again (as it was decoded by _get_handler_for request), as
# ASCII because it's a URL, and then decode it to get the UTF-8
# characters that were quoted.
return urllib.parse.unquote(s.encode("ascii")).decode("utf8")
kwargs = intern_dict( kwargs = intern_dict(
{ {
name: _unquote(value) if value else value name: urllib.parse.unquote(value) if value else value
for name, value in group_dict.items() for name, value in group_dict.items()
} }
) )
callback_return = yield callback(request, **kwargs) callback_return = callback(request, **kwargs)
# Is it synchronous? We'll allow this for now.
if isinstance(callback_return, (defer.Deferred, types.CoroutineType)):
callback_return = await callback_return
if callback_return is not None: if callback_return is not None:
code, response = callback_return code, response = callback_return
self._send_response(request, code, response) self._send_response(request, code, response)
@ -360,6 +348,23 @@ class JsonResource(HttpServer, resource.Resource):
) )
class DirectServeResource(resource.Resource):
def render(self, request):
"""
Render the request, using an asynchronous render handler if it exists.
"""
render_callback_name = "_async_render_" + request.method.decode("ascii")
if hasattr(self, render_callback_name):
# Call the handler
callback = getattr(self, render_callback_name)
defer.ensureDeferred(callback(request))
return NOT_DONE_YET
else:
super().render(request)
def _options_handler(request): def _options_handler(request):
"""Request handler for OPTIONS requests """Request handler for OPTIONS requests

View File

@ -23,13 +23,13 @@ from six.moves import http_client
import jinja2 import jinja2
from jinja2 import TemplateNotFound from jinja2 import TemplateNotFound
from twisted.internet import defer
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import NotFoundError, StoreError, SynapseError from synapse.api.errors import NotFoundError, StoreError, SynapseError
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.http.server import finish_request, wrap_html_request_handler from synapse.http.server import (
DirectServeResource,
finish_request,
wrap_html_request_handler,
)
from synapse.http.servlet import parse_string from synapse.http.servlet import parse_string
from synapse.types import UserID from synapse.types import UserID
@ -47,7 +47,7 @@ else:
return a == b return a == b
class ConsentResource(Resource): class ConsentResource(DirectServeResource):
"""A twisted Resource to display a privacy policy and gather consent to it """A twisted Resource to display a privacy policy and gather consent to it
When accessed via GET, returns the privacy policy via a template. When accessed via GET, returns the privacy policy via a template.
@ -87,7 +87,7 @@ class ConsentResource(Resource):
Args: Args:
hs (synapse.server.HomeServer): homeserver hs (synapse.server.HomeServer): homeserver
""" """
Resource.__init__(self) super().__init__()
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -118,18 +118,12 @@ class ConsentResource(Resource):
self._hmac_secret = hs.config.form_secret.encode("utf-8") self._hmac_secret = hs.config.form_secret.encode("utf-8")
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
@wrap_html_request_handler @wrap_html_request_handler
@defer.inlineCallbacks async def _async_render_GET(self, request):
def _async_render_GET(self, request):
""" """
Args: Args:
request (twisted.web.http.Request): request (twisted.web.http.Request):
""" """
version = parse_string(request, "v", default=self._default_consent_version) version = parse_string(request, "v", default=self._default_consent_version)
username = parse_string(request, "u", required=False, default="") username = parse_string(request, "u", required=False, default="")
userhmac = None userhmac = None
@ -145,7 +139,7 @@ class ConsentResource(Resource):
else: else:
qualified_user_id = UserID(username, self.hs.hostname).to_string() qualified_user_id = UserID(username, self.hs.hostname).to_string()
u = yield self.store.get_user_by_id(qualified_user_id) u = await self.store.get_user_by_id(qualified_user_id)
if u is None: if u is None:
raise NotFoundError("Unknown user") raise NotFoundError("Unknown user")
@ -165,13 +159,8 @@ class ConsentResource(Resource):
except TemplateNotFound: except TemplateNotFound:
raise NotFoundError("Unknown policy version") raise NotFoundError("Unknown policy version")
def render_POST(self, request):
self._async_render_POST(request)
return NOT_DONE_YET
@wrap_html_request_handler @wrap_html_request_handler
@defer.inlineCallbacks async def _async_render_POST(self, request):
def _async_render_POST(self, request):
""" """
Args: Args:
request (twisted.web.http.Request): request (twisted.web.http.Request):
@ -188,12 +177,12 @@ class ConsentResource(Resource):
qualified_user_id = UserID(username, self.hs.hostname).to_string() qualified_user_id = UserID(username, self.hs.hostname).to_string()
try: try:
yield self.store.user_set_consent_version(qualified_user_id, version) await self.store.user_set_consent_version(qualified_user_id, version)
except StoreError as e: except StoreError as e:
if e.code != 404: if e.code != 404:
raise raise
raise NotFoundError("Unknown user") raise NotFoundError("Unknown user")
yield self.registration_handler.post_consent_actions(qualified_user_id) await self.registration_handler.post_consent_actions(qualified_user_id)
try: try:
self._render_template(request, "success.html") self._render_template(request, "success.html")

View File

@ -16,18 +16,20 @@ import logging
from io import BytesIO from io import BytesIO
from twisted.internet import defer from twisted.internet import defer
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import respond_with_json_bytes, wrap_json_request_handler from synapse.http.server import (
DirectServeResource,
respond_with_json_bytes,
wrap_json_request_handler,
)
from synapse.http.servlet import parse_integer, parse_json_object_from_request from synapse.http.servlet import parse_integer, parse_json_object_from_request
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RemoteKey(Resource): class RemoteKey(DirectServeResource):
"""HTTP resource for retreiving the TLS certificate and NACL signature """HTTP resource for retreiving the TLS certificate and NACL signature
verification keys for a collection of servers. Checks that the reported verification keys for a collection of servers. Checks that the reported
X.509 TLS certificate matches the one used in the HTTPS connection. Checks X.509 TLS certificate matches the one used in the HTTPS connection. Checks
@ -94,13 +96,8 @@ class RemoteKey(Resource):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist self.federation_domain_whitelist = hs.config.federation_domain_whitelist
def render_GET(self, request):
self.async_render_GET(request)
return NOT_DONE_YET
@wrap_json_request_handler @wrap_json_request_handler
@defer.inlineCallbacks async def _async_render_GET(self, request):
def async_render_GET(self, request):
if len(request.postpath) == 1: if len(request.postpath) == 1:
server, = request.postpath server, = request.postpath
query = {server.decode("ascii"): {}} query = {server.decode("ascii"): {}}
@ -114,20 +111,15 @@ class RemoteKey(Resource):
else: else:
raise SynapseError(404, "Not found %r" % request.postpath, Codes.NOT_FOUND) raise SynapseError(404, "Not found %r" % request.postpath, Codes.NOT_FOUND)
yield self.query_keys(request, query, query_remote_on_cache_miss=True) await self.query_keys(request, query, query_remote_on_cache_miss=True)
def render_POST(self, request):
self.async_render_POST(request)
return NOT_DONE_YET
@wrap_json_request_handler @wrap_json_request_handler
@defer.inlineCallbacks async def _async_render_POST(self, request):
def async_render_POST(self, request):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
query = content["server_keys"] query = content["server_keys"]
yield self.query_keys(request, query, query_remote_on_cache_miss=True) await self.query_keys(request, query, query_remote_on_cache_miss=True)
@defer.inlineCallbacks @defer.inlineCallbacks
def query_keys(self, request, query, query_remote_on_cache_miss=False): def query_keys(self, request, query, query_remote_on_cache_miss=False):

View File

@ -14,31 +14,28 @@
# limitations under the License. # limitations under the License.
# #
from twisted.internet import defer
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
from synapse.http.server import respond_with_json, wrap_json_request_handler from synapse.http.server import (
DirectServeResource,
respond_with_json,
wrap_json_request_handler,
)
class MediaConfigResource(Resource): class MediaConfigResource(DirectServeResource):
isLeaf = True isLeaf = True
def __init__(self, hs): def __init__(self, hs):
Resource.__init__(self) super().__init__()
config = hs.get_config() config = hs.get_config()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.max_upload_size} self.limits_dict = {"m.upload.size": config.max_upload_size}
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
@wrap_json_request_handler @wrap_json_request_handler
@defer.inlineCallbacks async def _async_render_GET(self, request):
def _async_render_GET(self, request): await self.auth.get_user_by_req(request)
yield self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True) respond_with_json(request, 200, self.limits_dict, send_cors=True)
def render_OPTIONS(self, request): def render_OPTIONS(self, request):

View File

@ -14,37 +14,31 @@
# limitations under the License. # limitations under the License.
import logging import logging
from twisted.internet import defer
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
import synapse.http.servlet import synapse.http.servlet
from synapse.http.server import set_cors_headers, wrap_json_request_handler from synapse.http.server import (
DirectServeResource,
set_cors_headers,
wrap_json_request_handler,
)
from ._base import parse_media_id, respond_404 from ._base import parse_media_id, respond_404
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DownloadResource(Resource): class DownloadResource(DirectServeResource):
isLeaf = True isLeaf = True
def __init__(self, hs, media_repo): def __init__(self, hs, media_repo):
Resource.__init__(self) super().__init__()
self.media_repo = media_repo self.media_repo = media_repo
self.server_name = hs.hostname self.server_name = hs.hostname
# this is expected by @wrap_json_request_handler # this is expected by @wrap_json_request_handler
self.clock = hs.get_clock() self.clock = hs.get_clock()
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
@wrap_json_request_handler @wrap_json_request_handler
@defer.inlineCallbacks async def _async_render_GET(self, request):
def _async_render_GET(self, request):
set_cors_headers(request) set_cors_headers(request)
request.setHeader( request.setHeader(
b"Content-Security-Policy", b"Content-Security-Policy",
@ -58,7 +52,7 @@ class DownloadResource(Resource):
) )
server_name, media_id, name = parse_media_id(request) server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name: if server_name == self.server_name:
yield self.media_repo.get_local_media(request, media_id, name) await self.media_repo.get_local_media(request, media_id, name)
else: else:
allow_remote = synapse.http.servlet.parse_boolean( allow_remote = synapse.http.servlet.parse_boolean(
request, "allow_remote", default=True request, "allow_remote", default=True
@ -72,4 +66,4 @@ class DownloadResource(Resource):
respond_404(request) respond_404(request)
return return
yield self.media_repo.get_remote_media(request, server_name, media_id, name) await self.media_repo.get_remote_media(request, server_name, media_id, name)

View File

@ -32,12 +32,11 @@ from canonicaljson import json
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
from synapse.http.server import ( from synapse.http.server import (
DirectServeResource,
respond_with_json, respond_with_json,
respond_with_json_bytes, respond_with_json_bytes,
wrap_json_request_handler, wrap_json_request_handler,
@ -58,11 +57,11 @@ _charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=r
_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I) _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
class PreviewUrlResource(Resource): class PreviewUrlResource(DirectServeResource):
isLeaf = True isLeaf = True
def __init__(self, hs, media_repo, media_storage): def __init__(self, hs, media_repo, media_storage):
Resource.__init__(self) super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -98,16 +97,11 @@ class PreviewUrlResource(Resource):
def render_OPTIONS(self, request): def render_OPTIONS(self, request):
return respond_with_json(request, 200, {}, send_cors=True) return respond_with_json(request, 200, {}, send_cors=True)
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
@wrap_json_request_handler @wrap_json_request_handler
@defer.inlineCallbacks async def _async_render_GET(self, request):
def _async_render_GET(self, request):
# XXX: if get_user_by_req fails, what should we do in an async render? # XXX: if get_user_by_req fails, what should we do in an async render?
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
url = parse_string(request, "url") url = parse_string(request, "url")
if b"ts" in request.args: if b"ts" in request.args:
ts = parse_integer(request, "ts") ts = parse_integer(request, "ts")
@ -159,7 +153,7 @@ class PreviewUrlResource(Resource):
else: else:
logger.info("Returning cached response") logger.info("Returning cached response")
og = yield make_deferred_yieldable(observable.observe()) og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe))
respond_with_json_bytes(request, 200, og, send_cors=True) respond_with_json_bytes(request, 200, og, send_cors=True)
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -17,10 +17,12 @@
import logging import logging
from twisted.internet import defer from twisted.internet import defer
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from synapse.http.server import set_cors_headers, wrap_json_request_handler from synapse.http.server import (
DirectServeResource,
set_cors_headers,
wrap_json_request_handler,
)
from synapse.http.servlet import parse_integer, parse_string from synapse.http.servlet import parse_integer, parse_string
from ._base import ( from ._base import (
@ -34,11 +36,11 @@ from ._base import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ThumbnailResource(Resource): class ThumbnailResource(DirectServeResource):
isLeaf = True isLeaf = True
def __init__(self, hs, media_repo, media_storage): def __init__(self, hs, media_repo, media_storage):
Resource.__init__(self) super().__init__()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.media_repo = media_repo self.media_repo = media_repo
@ -47,13 +49,8 @@ class ThumbnailResource(Resource):
self.server_name = hs.hostname self.server_name = hs.hostname
self.clock = hs.get_clock() self.clock = hs.get_clock()
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
@wrap_json_request_handler @wrap_json_request_handler
@defer.inlineCallbacks async def _async_render_GET(self, request):
def _async_render_GET(self, request):
set_cors_headers(request) set_cors_headers(request)
server_name, media_id, _ = parse_media_id(request) server_name, media_id, _ = parse_media_id(request)
width = parse_integer(request, "width", required=True) width = parse_integer(request, "width", required=True)
@ -63,21 +60,21 @@ class ThumbnailResource(Resource):
if server_name == self.server_name: if server_name == self.server_name:
if self.dynamic_thumbnails: if self.dynamic_thumbnails:
yield self._select_or_generate_local_thumbnail( await self._select_or_generate_local_thumbnail(
request, media_id, width, height, method, m_type request, media_id, width, height, method, m_type
) )
else: else:
yield self._respond_local_thumbnail( await self._respond_local_thumbnail(
request, media_id, width, height, method, m_type request, media_id, width, height, method, m_type
) )
self.media_repo.mark_recently_accessed(None, media_id) self.media_repo.mark_recently_accessed(None, media_id)
else: else:
if self.dynamic_thumbnails: if self.dynamic_thumbnails:
yield self._select_or_generate_remote_thumbnail( await self._select_or_generate_remote_thumbnail(
request, server_name, media_id, width, height, method, m_type request, server_name, media_id, width, height, method, m_type
) )
else: else:
yield self._respond_remote_thumbnail( await self._respond_remote_thumbnail(
request, server_name, media_id, width, height, method, m_type request, server_name, media_id, width, height, method, m_type
) )
self.media_repo.mark_recently_accessed(server_name, media_id) self.media_repo.mark_recently_accessed(server_name, media_id)

View File

@ -15,22 +15,24 @@
import logging import logging
from twisted.internet import defer
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.server import respond_with_json, wrap_json_request_handler from synapse.http.server import (
DirectServeResource,
respond_with_json,
wrap_json_request_handler,
)
from synapse.http.servlet import parse_string from synapse.http.servlet import parse_string
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class UploadResource(Resource): class UploadResource(DirectServeResource):
isLeaf = True isLeaf = True
def __init__(self, hs, media_repo): def __init__(self, hs, media_repo):
Resource.__init__(self) super().__init__()
self.media_repo = media_repo self.media_repo = media_repo
self.filepaths = media_repo.filepaths self.filepaths = media_repo.filepaths
@ -41,18 +43,13 @@ class UploadResource(Resource):
self.max_upload_size = hs.config.max_upload_size self.max_upload_size = hs.config.max_upload_size
self.clock = hs.get_clock() self.clock = hs.get_clock()
def render_POST(self, request):
self._async_render_POST(request)
return NOT_DONE_YET
def render_OPTIONS(self, request): def render_OPTIONS(self, request):
respond_with_json(request, 200, {}, send_cors=True) respond_with_json(request, 200, {}, send_cors=True)
return NOT_DONE_YET return NOT_DONE_YET
@wrap_json_request_handler @wrap_json_request_handler
@defer.inlineCallbacks async def _async_render_POST(self, request):
def _async_render_POST(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have # TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point # already been uploaded to a tmp file at this point
content_length = request.getHeader(b"Content-Length").decode("ascii") content_length = request.getHeader(b"Content-Length").decode("ascii")
@ -81,7 +78,7 @@ class UploadResource(Resource):
# disposition = headers.getRawHeaders(b"Content-Disposition")[0] # disposition = headers.getRawHeaders(b"Content-Disposition")[0]
# TODO(markjh): parse content-dispostion # TODO(markjh): parse content-dispostion
content_uri = yield self.media_repo.create_content( content_uri = await self.media_repo.create_content(
media_type, upload_name, request.content, content_length, requester.user media_type, upload_name, request.content, content_length, requester.user
) )

View File

@ -18,34 +18,27 @@ import logging
import saml2 import saml2
from saml2.client import Saml2Client from saml2.client import Saml2Client
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import CodeMessageException from synapse.api.errors import CodeMessageException
from synapse.http.server import wrap_html_request_handler from synapse.http.server import DirectServeResource, wrap_html_request_handler
from synapse.http.servlet import parse_string from synapse.http.servlet import parse_string
from synapse.rest.client.v1.login import SSOAuthHandler from synapse.rest.client.v1.login import SSOAuthHandler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SAML2ResponseResource(Resource): class SAML2ResponseResource(DirectServeResource):
"""A Twisted web resource which handles the SAML response""" """A Twisted web resource which handles the SAML response"""
isLeaf = 1 isLeaf = 1
def __init__(self, hs): def __init__(self, hs):
Resource.__init__(self) super().__init__()
self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._sso_auth_handler = SSOAuthHandler(hs) self._sso_auth_handler = SSOAuthHandler(hs)
def render_POST(self, request):
self._async_render_POST(request)
return NOT_DONE_YET
@wrap_html_request_handler @wrap_html_request_handler
def _async_render_POST(self, request): async def _async_render_POST(self, request):
resp_bytes = parse_string(request, "SAMLResponse", required=True) resp_bytes = parse_string(request, "SAMLResponse", required=True)
relay_state = parse_string(request, "RelayState", required=True) relay_state = parse_string(request, "RelayState", required=True)

View File

@ -22,7 +22,6 @@ from binascii import unhexlify
from mock import Mock from mock import Mock
from six.moves.urllib import parse from six.moves.urllib import parse
from twisted.internet import defer, reactor
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from synapse.rest.media.v1._base import FileInfo from synapse.rest.media.v1._base import FileInfo
@ -34,15 +33,17 @@ from synapse.util.logcontext import make_deferred_yieldable
from tests import unittest from tests import unittest
class MediaStorageTests(unittest.TestCase): class MediaStorageTests(unittest.HomeserverTestCase):
def setUp(self):
needs_threadpool = True
def prepare(self, reactor, clock, hs):
self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-") self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
self.addCleanup(shutil.rmtree, self.test_dir)
self.primary_base_path = os.path.join(self.test_dir, "primary") self.primary_base_path = os.path.join(self.test_dir, "primary")
self.secondary_base_path = os.path.join(self.test_dir, "secondary") self.secondary_base_path = os.path.join(self.test_dir, "secondary")
hs = Mock()
hs.get_reactor = Mock(return_value=reactor)
hs.config.media_store_path = self.primary_base_path hs.config.media_store_path = self.primary_base_path
storage_providers = [FileStorageProviderBackend(hs, self.secondary_base_path)] storage_providers = [FileStorageProviderBackend(hs, self.secondary_base_path)]
@ -52,10 +53,6 @@ class MediaStorageTests(unittest.TestCase):
hs, self.primary_base_path, self.filepaths, storage_providers hs, self.primary_base_path, self.filepaths, storage_providers
) )
def tearDown(self):
shutil.rmtree(self.test_dir)
@defer.inlineCallbacks
def test_ensure_media_is_in_local_cache(self): def test_ensure_media_is_in_local_cache(self):
media_id = "some_media_id" media_id = "some_media_id"
test_body = "Test\n" test_body = "Test\n"
@ -73,7 +70,15 @@ class MediaStorageTests(unittest.TestCase):
# Now we run ensure_media_is_in_local_cache, which should copy the file # Now we run ensure_media_is_in_local_cache, which should copy the file
# to the local cache. # to the local cache.
file_info = FileInfo(None, media_id) file_info = FileInfo(None, media_id)
local_path = yield self.media_storage.ensure_media_is_in_local_cache(file_info)
# This uses a real blocking threadpool so we have to wait for it to be
# actually done :/
x = self.media_storage.ensure_media_is_in_local_cache(file_info)
# Hotloop until the threadpool does its job...
self.wait_on_thread(x)
local_path = self.get_success(x)
self.assertTrue(os.path.exists(local_path)) self.assertTrue(os.path.exists(local_path))

View File

@ -17,6 +17,7 @@ import gc
import hashlib import hashlib
import hmac import hmac
import logging import logging
import time
from mock import Mock from mock import Mock
@ -24,7 +25,8 @@ from canonicaljson import json
import twisted import twisted
import twisted.logger import twisted.logger
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred, succeed
from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest from twisted.trial import unittest
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
@ -164,6 +166,7 @@ class HomeserverTestCase(TestCase):
servlets = [] servlets = []
hijack_auth = True hijack_auth = True
needs_threadpool = False
def setUp(self): def setUp(self):
""" """
@ -192,16 +195,20 @@ class HomeserverTestCase(TestCase):
if self.hijack_auth: if self.hijack_auth:
def get_user_by_access_token(token=None, allow_guest=False): def get_user_by_access_token(token=None, allow_guest=False):
return { return succeed(
{
"user": UserID.from_string(self.helper.auth_user_id), "user": UserID.from_string(self.helper.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False, "is_guest": False,
} }
)
def get_user_by_req(request, allow_guest=False, rights="access"): def get_user_by_req(request, allow_guest=False, rights="access"):
return create_requester( return succeed(
create_requester(
UserID.from_string(self.helper.auth_user_id), 1, False, None UserID.from_string(self.helper.auth_user_id), 1, False, None
) )
)
self.hs.get_auth().get_user_by_req = get_user_by_req self.hs.get_auth().get_user_by_req = get_user_by_req
self.hs.get_auth().get_user_by_access_token = get_user_by_access_token self.hs.get_auth().get_user_by_access_token = get_user_by_access_token
@ -209,9 +216,26 @@ class HomeserverTestCase(TestCase):
return_value="1234" return_value="1234"
) )
if self.needs_threadpool:
self.reactor.threadpool = ThreadPool()
self.addCleanup(self.reactor.threadpool.stop)
self.reactor.threadpool.start()
if hasattr(self, "prepare"): if hasattr(self, "prepare"):
self.prepare(self.reactor, self.clock, self.hs) self.prepare(self.reactor, self.clock, self.hs)
def wait_on_thread(self, deferred, timeout=10):
"""
Wait until a Deferred is done, where it's waiting on a real thread.
"""
start_time = time.time()
while not deferred.called:
if start_time + timeout < time.time():
raise ValueError("Timed out waiting for threadpool")
self.reactor.advance(0.01)
time.sleep(0.01)
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
""" """
Make and return a homeserver. Make and return a homeserver.