Merge different Resource implementation classes (#7732)

This commit is contained in:
Erik Johnston 2020-07-03 19:02:19 +01:00 committed by GitHub
parent 21a212f8e5
commit 5cdca53aa0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 330 additions and 326 deletions

1
changelog.d/7732.bugfix Normal file
View File

@ -0,0 +1 @@
Fix "Tried to close a non-active scope!" error messages when opentracing is enabled.

View File

@ -361,11 +361,7 @@ class BaseFederationServlet(object):
continue continue
server.register_paths( server.register_paths(
method, method, (pattern,), self._wrap(code), self.__class__.__name__,
(pattern,),
self._wrap(code),
self.__class__.__name__,
trace=False,
) )

View File

@ -13,13 +13,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.web.resource import Resource from synapse.http.server import DirectServeJsonResource
from twisted.web.server import NOT_DONE_YET
from synapse.http.server import wrap_json_request_handler
class AdditionalResource(Resource): class AdditionalResource(DirectServeJsonResource):
"""Resource wrapper for additional_resources """Resource wrapper for additional_resources
If the user has configured additional_resources, we need to wrap the If the user has configured additional_resources, we need to wrap the
@ -41,16 +38,10 @@ class AdditionalResource(Resource):
handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred): handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
function to be called to handle the request. function to be called to handle the request.
""" """
Resource.__init__(self) super().__init__()
self._handler = handler self._handler = handler
# required by the request_handler wrapper
self.clock = hs.get_clock()
def render(self, request):
self._async_render(request)
return NOT_DONE_YET
@wrap_json_request_handler
def _async_render(self, request): def _async_render(self, request):
# Cheekily pass the result straight through, so we don't need to worry
# if its an awaitable or not.
return self._handler(request) return self._handler(request)

View File

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import abc
import collections import collections
import html import html
import logging import logging
@ -21,7 +22,7 @@ import types
import urllib import urllib
from http import HTTPStatus from http import HTTPStatus
from io import BytesIO from io import BytesIO
from typing import Awaitable, Callable, TypeVar, Union from typing import Any, Callable, Dict, Tuple, Union
import jinja2 import jinja2
from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
@ -62,99 +63,43 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
""" """
def wrap_json_request_handler(h): def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
"""Wraps a request handler method with exception handling. """Sends a JSON error response to clients.
Also does the wrapping with request.processing as per wrap_async_request_handler.
The handler method must have a signature of "handle_foo(self, request)",
where "request" must be a SynapseRequest.
The handler must return a deferred or a coroutine. If the deferred succeeds
we assume that a response has been sent. If the deferred fails with a SynapseError we use
it to send a JSON response with the appropriate HTTP reponse code. If the
deferred fails with any other type of error we send a 500 reponse.
""" """
async def wrapped_request_handler(self, request): if f.check(SynapseError):
try: error_code = f.value.code
await h(self, request) error_dict = f.value.error_dict()
except SynapseError as e:
code = e.code
logger.info("%s SynapseError: %s - %s", request, code, e.msg)
# Only respond with an error response if we haven't already started logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg)
# writing, otherwise lets just kill the connection else:
if request.startedWriting: error_code = 500
if request.transport: error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
try:
request.transport.abortConnection()
except Exception:
# abortConnection throws if the connection is already closed
pass
else:
respond_with_json(
request,
code,
e.error_dict(),
send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
)
except Exception: logger.error(
# failure.Failure() fishes the original Failure out "Failed handle request via %r: %r",
# of our stack, and thus gives us a sensible stack request.request_metrics.name,
# trace. request,
f = failure.Failure() exc_info=(f.type, f.value, f.getTracebackObject()),
logger.error( )
"Failed handle request via %r: %r",
request.request_metrics.name,
request,
exc_info=(f.type, f.value, f.getTracebackObject()),
)
# Only respond with an error response if we haven't already started
# writing, otherwise lets just kill the connection
if request.startedWriting:
if request.transport:
try:
request.transport.abortConnection()
except Exception:
# abortConnection throws if the connection is already closed
pass
else:
respond_with_json(
request,
500,
{"error": "Internal server error", "errcode": Codes.UNKNOWN},
send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
)
return wrap_async_request_handler(wrapped_request_handler) # Only respond with an error response if we haven't already started writing,
# otherwise lets just kill the connection
if request.startedWriting:
TV = TypeVar("TV") if request.transport:
try:
request.transport.abortConnection()
def wrap_html_request_handler( except Exception:
h: Callable[[TV, SynapseRequest], Awaitable] # abortConnection throws if the connection is already closed
) -> Callable[[TV, SynapseRequest], Awaitable[None]]: pass
"""Wraps a request handler method with exception handling. else:
respond_with_json(
Also does the wrapping with request.processing as per wrap_async_request_handler. request,
error_code,
The handler method must have a signature of "handle_foo(self, request)", error_dict,
where "request" must be a SynapseRequest. send_cors=True,
""" pretty_print=_request_user_agent_is_curl(request),
)
async def wrapped_request_handler(self, request):
try:
await h(self, request)
except Exception:
f = failure.Failure()
return_html_error(f, request, HTML_ERROR_TEMPLATE)
return wrap_async_request_handler(wrapped_request_handler)
def return_html_error( def return_html_error(
@ -249,7 +194,113 @@ class HttpServer(object):
pass pass
class JsonResource(HttpServer, resource.Resource): class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
"""Base class for resources that have async handlers.
Sub classes can either implement `_async_render_<METHOD>` to handle
requests by method, or override `_async_render` to handle all requests.
Args:
extract_context: Whether to attempt to extract the opentracing
context from the request the servlet is handling.
"""
def __init__(self, extract_context=False):
super().__init__()
self._extract_context = extract_context
def render(self, request):
""" This gets called by twisted every time someone sends us a request.
"""
defer.ensureDeferred(self._async_render_wrapper(request))
return NOT_DONE_YET
@wrap_async_request_handler
async def _async_render_wrapper(self, request):
"""This is a wrapper that delegates to `_async_render` and handles
exceptions, return values, metrics, etc.
"""
try:
request.request_metrics.name = self.__class__.__name__
with trace_servlet(request, self._extract_context):
callback_return = await self._async_render(request)
if callback_return is not None:
code, response = callback_return
self._send_response(request, code, response)
except Exception:
# failure.Failure() fishes the original Failure out
# of our stack, and thus gives us a sensible stack
# trace.
f = failure.Failure()
self._send_error_response(f, request)
async def _async_render(self, request):
"""Delegates to `_async_render_<METHOD>` methods, or returns a 400 if
no appropriate method exists. Can be overriden in sub classes for
different routing.
"""
method_handler = getattr(
self, "_async_render_%s" % (request.method.decode("ascii"),), None
)
if method_handler:
raw_callback_return = method_handler(request)
# Is it synchronous? We'll allow this for now.
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
callback_return = await raw_callback_return
else:
callback_return = raw_callback_return
return callback_return
_unrecognised_request_handler(request)
@abc.abstractmethod
def _send_response(
self, request: SynapseRequest, code: int, response_object: Any,
) -> None:
raise NotImplementedError()
@abc.abstractmethod
def _send_error_response(
self, f: failure.Failure, request: SynapseRequest,
) -> None:
raise NotImplementedError()
class DirectServeJsonResource(_AsyncResource):
"""A resource that will call `self._async_on_<METHOD>` on new requests,
formatting responses and errors as JSON.
"""
def _send_response(
self, request, code, response_object,
):
"""Implements _AsyncResource._send_response
"""
# TODO: Only enable CORS for the requests that need it.
respond_with_json(
request,
code,
response_object,
send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
canonical_json=self.canonical_json,
)
def _send_error_response(
self, f: failure.Failure, request: SynapseRequest,
) -> None:
"""Implements _AsyncResource._send_error_response
"""
return_json_error(f, request)
class JsonResource(DirectServeJsonResource):
""" This implements the HttpServer interface and provides JSON support for """ This implements the HttpServer interface and provides JSON support for
Resources. Resources.
@ -269,17 +320,15 @@ class JsonResource(HttpServer, resource.Resource):
"_PathEntry", ["pattern", "callback", "servlet_classname"] "_PathEntry", ["pattern", "callback", "servlet_classname"]
) )
def __init__(self, hs, canonical_json=True): def __init__(self, hs, canonical_json=True, extract_context=False):
resource.Resource.__init__(self) super().__init__(extract_context)
self.canonical_json = canonical_json self.canonical_json = canonical_json
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.path_regexs = {} self.path_regexs = {}
self.hs = hs self.hs = hs
def register_paths( def register_paths(self, method, path_patterns, callback, servlet_classname):
self, method, path_patterns, callback, servlet_classname, trace=True
):
""" """
Registers a request handler against a regular expression. Later request URLs are Registers a request handler against a regular expression. Later request URLs are
checked against these regular expressions in order to identify an appropriate checked against these regular expressions in order to identify an appropriate
@ -295,74 +344,23 @@ class JsonResource(HttpServer, resource.Resource):
servlet_classname (str): The name of the handler to be used in prometheus servlet_classname (str): The name of the handler to be used in prometheus
and opentracing logs. and opentracing logs.
trace (bool): Whether we should start a span to trace the servlet.
""" """
method = method.encode("utf-8") # method is bytes on py3 method = method.encode("utf-8") # method is bytes on py3
if trace:
# We don't extract the context from the servlet because we can't
# trust the sender
callback = trace_servlet(servlet_classname)(callback)
for path_pattern in path_patterns: for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern) logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append( self.path_regexs.setdefault(method, []).append(
self._PathEntry(path_pattern, callback, servlet_classname) self._PathEntry(path_pattern, callback, servlet_classname)
) )
def render(self, request): def _get_handler_for_request(
""" This gets called by twisted every time someone sends us a request. self, request: SynapseRequest
""" ) -> Tuple[Callable, str, Dict[str, str]]:
defer.ensureDeferred(self._async_render(request)) """Finds a callback method to handle the given request.
return NOT_DONE_YET
@wrap_json_request_handler
async def _async_render(self, request):
""" This gets called from render() every time someone sends us a request.
This checks if anyone has registered a callback for that method and
path.
"""
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
# Make sure we have a name for this handler in prometheus.
request.request_metrics.name = servlet_classname
# Now trigger the callback. If it returns a response, we send it
# here. If it throws an exception, that is handled by the wrapper
# installed by @request_handler.
kwargs = intern_dict(
{
name: urllib.parse.unquote(value) if value else value
for name, value in group_dict.items()
}
)
callback_return = callback(request, **kwargs)
# Is it synchronous? We'll allow this for now.
if isinstance(callback_return, (defer.Deferred, types.CoroutineType)):
callback_return = await callback_return
if callback_return is not None:
code, response = callback_return
self._send_response(request, code, response)
def _get_handler_for_request(self, request):
"""Finds a callback method to handle the given request
Args:
request (twisted.web.http.Request):
Returns: Returns:
Tuple[Callable, str, dict[unicode, unicode]]: callback method, the A tuple of the callback to use, the name of the servlet, and the
label to use for that method in prometheus metrics, and the key word arguments to pass to the callback
dict mapping keys to path components as specified in the
handler's path match regexp.
The callback will normally be a method registered via
register_paths, so will return (possibly via Deferred) either
None, or a tuple of (http code, response body).
""" """
request_path = request.path.decode("ascii") request_path = request.path.decode("ascii")
@ -377,42 +375,59 @@ class JsonResource(HttpServer, resource.Resource):
# Huh. No one wanted to handle that? Fiiiiiine. Send 400. # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
return _unrecognised_request_handler, "unrecognised_request_handler", {} return _unrecognised_request_handler, "unrecognised_request_handler", {}
def _send_response( async def _async_render(self, request):
self, request, code, response_json_object, response_code_message=None callback, servlet_classname, group_dict = self._get_handler_for_request(request)
):
# TODO: Only enable CORS for the requests that need it. # Make sure we have an appopriate name for this handler in prometheus
respond_with_json( # (rather than the default of JsonResource).
request, request.request_metrics.name = servlet_classname
code,
response_json_object, # Now trigger the callback. If it returns a response, we send it
send_cors=True, # here. If it throws an exception, that is handled by the wrapper
response_code_message=response_code_message, # installed by @request_handler.
pretty_print=_request_user_agent_is_curl(request), kwargs = intern_dict(
canonical_json=self.canonical_json, {
name: urllib.parse.unquote(value) if value else value
for name, value in group_dict.items()
}
) )
raw_callback_return = callback(request, **kwargs)
class DirectServeResource(resource.Resource): # Is it synchronous? We'll allow this for now.
def render(self, request): if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
callback_return = await raw_callback_return
else:
callback_return = raw_callback_return
return callback_return
class DirectServeHtmlResource(_AsyncResource):
"""A resource that will call `self._async_on_<METHOD>` on new requests,
formatting responses and errors as HTML.
"""
# The error template to use for this resource
ERROR_TEMPLATE = HTML_ERROR_TEMPLATE
def _send_response(
self, request: SynapseRequest, code: int, response_object: Any,
):
"""Implements _AsyncResource._send_response
""" """
Render the request, using an asynchronous render handler if it exists. # We expect to get bytes for us to write
assert isinstance(response_object, bytes)
html_bytes = response_object
respond_with_html_bytes(request, 200, html_bytes)
def _send_error_response(
self, f: failure.Failure, request: SynapseRequest,
) -> None:
"""Implements _AsyncResource._send_error_response
""" """
async_render_callback_name = "_async_render_" + request.method.decode("ascii") return_html_error(f, request, self.ERROR_TEMPLATE)
# Try and get the async renderer
callback = getattr(self, async_render_callback_name, None)
# No async renderer for this request method.
if not callback:
return super().render(request)
resp = trace_servlet(self.__class__.__name__)(callback)(request)
# If it's a coroutine, turn it into a Deferred
if isinstance(resp, types.CoroutineType):
defer.ensureDeferred(resp)
return NOT_DONE_YET
class StaticResource(File): class StaticResource(File):

View File

@ -169,7 +169,6 @@ import contextlib
import inspect import inspect
import logging import logging
import re import re
import types
from functools import wraps from functools import wraps
from typing import TYPE_CHECKING, Dict, Optional, Type from typing import TYPE_CHECKING, Dict, Optional, Type
@ -182,6 +181,7 @@ from synapse.config import ConfigError
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.http.site import SynapseRequest
# Helper class # Helper class
@ -793,48 +793,42 @@ def tag_args(func):
return _tag_args_inner return _tag_args_inner
def trace_servlet(servlet_name, extract_context=False): @contextlib.contextmanager
"""Decorator which traces a serlet. It starts a span with some servlet specific def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
tags such as the servlet_name and request information """Returns a context manager which traces a request. It starts a span
with some servlet specific tags such as the request metrics name and
request information.
Args: Args:
servlet_name (str): The name to be used for the span's operation_name request
extract_context (bool): Whether to attempt to extract the opentracing extract_context: Whether to attempt to extract the opentracing
context from the request the servlet is handling. context from the request the servlet is handling.
""" """
def _trace_servlet_inner_1(func): if opentracing is None:
if not opentracing: yield
return func return
@wraps(func) request_tags = {
async def _trace_servlet_inner(request, *args, **kwargs): "request_id": request.get_request_id(),
request_tags = { tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
"request_id": request.get_request_id(), tags.HTTP_METHOD: request.get_method(),
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER, tags.HTTP_URL: request.get_redacted_uri(),
tags.HTTP_METHOD: request.get_method(), tags.PEER_HOST_IPV6: request.getClientIP(),
tags.HTTP_URL: request.get_redacted_uri(), }
tags.PEER_HOST_IPV6: request.getClientIP(),
}
if extract_context: request_name = request.request_metrics.name
scope = start_active_span_from_request( if extract_context:
request, servlet_name, tags=request_tags scope = start_active_span_from_request(request, request_name, tags=request_tags)
) else:
else: scope = start_active_span(request_name, tags=request_tags)
scope = start_active_span(servlet_name, tags=request_tags)
with scope: with scope:
result = func(request, *args, **kwargs) try:
yield
finally:
# We set the operation name again in case its changed (which happens
# with JsonResource).
scope.span.set_operation_name(request.request_metrics.name)
if not isinstance(result, (types.CoroutineType, defer.Deferred)): scope.span.set_tag("request_tag", request.request_metrics.start_context.tag)
# Some servlets aren't async and just return results
# directly, so we handle that here.
return result
return await result
return _trace_servlet_inner
return _trace_servlet_inner_1

View File

@ -30,7 +30,8 @@ REPLICATION_PREFIX = "/_synapse/replication"
class ReplicationRestResource(JsonResource): class ReplicationRestResource(JsonResource):
def __init__(self, hs): def __init__(self, hs):
JsonResource.__init__(self, hs, canonical_json=False) # We enable extracting jaeger contexts here as these are internal APIs.
super().__init__(hs, canonical_json=False, extract_context=True)
self.register_servlets(hs) self.register_servlets(hs)
def register_servlets(self, hs): def register_servlets(self, hs):

View File

@ -28,11 +28,7 @@ from synapse.api.errors import (
RequestSendFailed, RequestSendFailed,
SynapseError, SynapseError,
) )
from synapse.logging.opentracing import ( from synapse.logging.opentracing import inject_active_span_byte_dict, trace
inject_active_span_byte_dict,
trace,
trace_servlet,
)
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
@ -240,11 +236,8 @@ class ReplicationEndpoint(object):
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args) args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args)) pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
handler = trace_servlet(self.__class__.__name__, extract_context=True)(handler)
# We don't let register paths trace this servlet using the default tracing
# options because we wish to extract the context explicitly.
http_server.register_paths( http_server.register_paths(
method, [pattern], handler, self.__class__.__name__, trace=False method, [pattern], handler, self.__class__.__name__,
) )
def _cached_handler(self, request, txn_id, **kwargs): def _cached_handler(self, request, txn_id, **kwargs):

View File

@ -26,11 +26,7 @@ from twisted.internet import defer
from synapse.api.errors import NotFoundError, StoreError, SynapseError from synapse.api.errors import NotFoundError, StoreError, SynapseError
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.http.server import ( from synapse.http.server import DirectServeHtmlResource, respond_with_html
DirectServeResource,
respond_with_html,
wrap_html_request_handler,
)
from synapse.http.servlet import parse_string from synapse.http.servlet import parse_string
from synapse.types import UserID from synapse.types import UserID
@ -48,7 +44,7 @@ else:
return a == b return a == b
class ConsentResource(DirectServeResource): class ConsentResource(DirectServeHtmlResource):
"""A twisted Resource to display a privacy policy and gather consent to it """A twisted Resource to display a privacy policy and gather consent to it
When accessed via GET, returns the privacy policy via a template. When accessed via GET, returns the privacy policy via a template.
@ -119,7 +115,6 @@ class ConsentResource(DirectServeResource):
self._hmac_secret = hs.config.form_secret.encode("utf-8") self._hmac_secret = hs.config.form_secret.encode("utf-8")
@wrap_html_request_handler
async def _async_render_GET(self, request): async def _async_render_GET(self, request):
""" """
Args: Args:
@ -160,7 +155,6 @@ class ConsentResource(DirectServeResource):
except TemplateNotFound: except TemplateNotFound:
raise NotFoundError("Unknown policy version") raise NotFoundError("Unknown policy version")
@wrap_html_request_handler
async def _async_render_POST(self, request): async def _async_render_POST(self, request):
""" """
Args: Args:

View File

@ -20,17 +20,13 @@ from signedjson.sign import sign_json
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import ( from synapse.http.server import DirectServeJsonResource, respond_with_json_bytes
DirectServeResource,
respond_with_json_bytes,
wrap_json_request_handler,
)
from synapse.http.servlet import parse_integer, parse_json_object_from_request from synapse.http.servlet import parse_integer, parse_json_object_from_request
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RemoteKey(DirectServeResource): class RemoteKey(DirectServeJsonResource):
"""HTTP resource for retreiving the TLS certificate and NACL signature """HTTP resource for retreiving the TLS certificate and NACL signature
verification keys for a collection of servers. Checks that the reported verification keys for a collection of servers. Checks that the reported
X.509 TLS certificate matches the one used in the HTTPS connection. Checks X.509 TLS certificate matches the one used in the HTTPS connection. Checks
@ -92,13 +88,14 @@ class RemoteKey(DirectServeResource):
isLeaf = True isLeaf = True
def __init__(self, hs): def __init__(self, hs):
super().__init__()
self.fetcher = ServerKeyFetcher(hs) self.fetcher = ServerKeyFetcher(hs)
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist self.federation_domain_whitelist = hs.config.federation_domain_whitelist
self.config = hs.config self.config = hs.config
@wrap_json_request_handler
async def _async_render_GET(self, request): async def _async_render_GET(self, request):
if len(request.postpath) == 1: if len(request.postpath) == 1:
(server,) = request.postpath (server,) = request.postpath
@ -115,7 +112,6 @@ class RemoteKey(DirectServeResource):
await self.query_keys(request, query, query_remote_on_cache_miss=True) await self.query_keys(request, query, query_remote_on_cache_miss=True)
@wrap_json_request_handler
async def _async_render_POST(self, request): async def _async_render_POST(self, request):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)

View File

@ -14,16 +14,10 @@
# limitations under the License. # limitations under the License.
# #
from twisted.web.server import NOT_DONE_YET from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.server import (
DirectServeResource,
respond_with_json,
wrap_json_request_handler,
)
class MediaConfigResource(DirectServeResource): class MediaConfigResource(DirectServeJsonResource):
isLeaf = True isLeaf = True
def __init__(self, hs): def __init__(self, hs):
@ -33,11 +27,9 @@ class MediaConfigResource(DirectServeResource):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.max_upload_size} self.limits_dict = {"m.upload.size": config.max_upload_size}
@wrap_json_request_handler
async def _async_render_GET(self, request): async def _async_render_GET(self, request):
await self.auth.get_user_by_req(request) await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True) respond_with_json(request, 200, self.limits_dict, send_cors=True)
def render_OPTIONS(self, request): async def _async_render_OPTIONS(self, request):
respond_with_json(request, 200, {}, send_cors=True) respond_with_json(request, 200, {}, send_cors=True)
return NOT_DONE_YET

View File

@ -15,18 +15,14 @@
import logging import logging
import synapse.http.servlet import synapse.http.servlet
from synapse.http.server import ( from synapse.http.server import DirectServeJsonResource, set_cors_headers
DirectServeResource,
set_cors_headers,
wrap_json_request_handler,
)
from ._base import parse_media_id, respond_404 from ._base import parse_media_id, respond_404
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DownloadResource(DirectServeResource): class DownloadResource(DirectServeJsonResource):
isLeaf = True isLeaf = True
def __init__(self, hs, media_repo): def __init__(self, hs, media_repo):
@ -34,10 +30,6 @@ class DownloadResource(DirectServeResource):
self.media_repo = media_repo self.media_repo = media_repo
self.server_name = hs.hostname self.server_name = hs.hostname
# this is expected by @wrap_json_request_handler
self.clock = hs.get_clock()
@wrap_json_request_handler
async def _async_render_GET(self, request): async def _async_render_GET(self, request):
set_cors_headers(request) set_cors_headers(request)
request.setHeader( request.setHeader(

View File

@ -34,10 +34,9 @@ from twisted.internet.error import DNSLookupError
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
from synapse.http.server import ( from synapse.http.server import (
DirectServeResource, DirectServeJsonResource,
respond_with_json, respond_with_json,
respond_with_json_bytes, respond_with_json_bytes,
wrap_json_request_handler,
) )
from synapse.http.servlet import parse_integer, parse_string from synapse.http.servlet import parse_integer, parse_string
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
@ -58,7 +57,7 @@ OG_TAG_NAME_MAXLEN = 50
OG_TAG_VALUE_MAXLEN = 1000 OG_TAG_VALUE_MAXLEN = 1000
class PreviewUrlResource(DirectServeResource): class PreviewUrlResource(DirectServeJsonResource):
isLeaf = True isLeaf = True
def __init__(self, hs, media_repo, media_storage): def __init__(self, hs, media_repo, media_storage):
@ -108,11 +107,10 @@ class PreviewUrlResource(DirectServeResource):
self._start_expire_url_cache_data, 10 * 1000 self._start_expire_url_cache_data, 10 * 1000
) )
def render_OPTIONS(self, request): async def _async_render_OPTIONS(self, request):
request.setHeader(b"Allow", b"OPTIONS, GET") request.setHeader(b"Allow", b"OPTIONS, GET")
return respond_with_json(request, 200, {}, send_cors=True) respond_with_json(request, 200, {}, send_cors=True)
@wrap_json_request_handler
async def _async_render_GET(self, request): async def _async_render_GET(self, request):
# XXX: if get_user_by_req fails, what should we do in an async render? # XXX: if get_user_by_req fails, what should we do in an async render?

View File

@ -16,11 +16,7 @@
import logging import logging
from synapse.http.server import ( from synapse.http.server import DirectServeJsonResource, set_cors_headers
DirectServeResource,
set_cors_headers,
wrap_json_request_handler,
)
from synapse.http.servlet import parse_integer, parse_string from synapse.http.servlet import parse_integer, parse_string
from ._base import ( from ._base import (
@ -34,7 +30,7 @@ from ._base import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ThumbnailResource(DirectServeResource): class ThumbnailResource(DirectServeJsonResource):
isLeaf = True isLeaf = True
def __init__(self, hs, media_repo, media_storage): def __init__(self, hs, media_repo, media_storage):
@ -45,9 +41,7 @@ class ThumbnailResource(DirectServeResource):
self.media_storage = media_storage self.media_storage = media_storage
self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname self.server_name = hs.hostname
self.clock = hs.get_clock()
@wrap_json_request_handler
async def _async_render_GET(self, request): async def _async_render_GET(self, request):
set_cors_headers(request) set_cors_headers(request)
server_name, media_id, _ = parse_media_id(request) server_name, media_id, _ = parse_media_id(request)

View File

@ -15,20 +15,14 @@
import logging import logging
from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import ( from synapse.http.server import DirectServeJsonResource, respond_with_json
DirectServeResource,
respond_with_json,
wrap_json_request_handler,
)
from synapse.http.servlet import parse_string from synapse.http.servlet import parse_string
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class UploadResource(DirectServeResource): class UploadResource(DirectServeJsonResource):
isLeaf = True isLeaf = True
def __init__(self, hs, media_repo): def __init__(self, hs, media_repo):
@ -43,11 +37,9 @@ class UploadResource(DirectServeResource):
self.max_upload_size = hs.config.max_upload_size self.max_upload_size = hs.config.max_upload_size
self.clock = hs.get_clock() self.clock = hs.get_clock()
def render_OPTIONS(self, request): async def _async_render_OPTIONS(self, request):
respond_with_json(request, 200, {}, send_cors=True) respond_with_json(request, 200, {}, send_cors=True)
return NOT_DONE_YET
@wrap_json_request_handler
async def _async_render_POST(self, request): async def _async_render_POST(self, request):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have # TODO: The checks here are a bit late. The content will have

View File

@ -14,18 +14,17 @@
# limitations under the License. # limitations under the License.
import logging import logging
from synapse.http.server import DirectServeResource, wrap_html_request_handler from synapse.http.server import DirectServeHtmlResource
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class OIDCCallbackResource(DirectServeResource): class OIDCCallbackResource(DirectServeHtmlResource):
isLeaf = 1 isLeaf = 1
def __init__(self, hs): def __init__(self, hs):
super().__init__() super().__init__()
self._oidc_handler = hs.get_oidc_handler() self._oidc_handler = hs.get_oidc_handler()
@wrap_html_request_handler
async def _async_render_GET(self, request): async def _async_render_GET(self, request):
return await self._oidc_handler.handle_oidc_callback(request) await self._oidc_handler.handle_oidc_callback(request)

View File

@ -16,10 +16,10 @@
from twisted.python import failure from twisted.python import failure
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.server import DirectServeResource, return_html_error from synapse.http.server import DirectServeHtmlResource, return_html_error
class SAML2ResponseResource(DirectServeResource): class SAML2ResponseResource(DirectServeHtmlResource):
"""A Twisted web resource which handles the SAML response""" """A Twisted web resource which handles the SAML response"""
isLeaf = 1 isLeaf = 1

View File

@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import respond_with_json
from tests.unittest import HomeserverTestCase
class _AsyncTestCustomEndpoint:
def __init__(self, config, module_api):
pass
async def handle_request(self, request):
respond_with_json(request, 200, {"some_key": "some_value_async"})
class _SyncTestCustomEndpoint:
def __init__(self, config, module_api):
pass
async def handle_request(self, request):
respond_with_json(request, 200, {"some_key": "some_value_sync"})
class AdditionalResourceTests(HomeserverTestCase):
"""Very basic tests that `AdditionalResource` works correctly with sync
and async handlers.
"""
def test_async(self):
handler = _AsyncTestCustomEndpoint({}, None).handle_request
self.resource = AdditionalResource(self.hs, handler)
request, channel = self.make_request("GET", "/")
self.render(request)
self.assertEqual(request.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
def test_sync(self):
handler = _SyncTestCustomEndpoint({}, None).handle_request
self.resource = AdditionalResource(self.hs, handler)
request, channel = self.make_request("GET", "/")
self.render(request)
self.assertEqual(request.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})

View File

@ -24,12 +24,7 @@ from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, RedirectException, SynapseError from synapse.api.errors import Codes, RedirectException, SynapseError
from synapse.config.server import parse_listener_def from synapse.config.server import parse_listener_def
from synapse.http.server import ( from synapse.http.server import DirectServeHtmlResource, JsonResource, OptionsResource
DirectServeResource,
JsonResource,
OptionsResource,
wrap_html_request_handler,
)
from synapse.http.site import SynapseSite, logger from synapse.http.site import SynapseSite, logger
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock from synapse.util import Clock
@ -256,12 +251,11 @@ class OptionsResourceTests(unittest.TestCase):
class WrapHtmlRequestHandlerTests(unittest.TestCase): class WrapHtmlRequestHandlerTests(unittest.TestCase):
class TestResource(DirectServeResource): class TestResource(DirectServeHtmlResource):
callback = None callback = None
@wrap_html_request_handler
async def _async_render_GET(self, request): async def _async_render_GET(self, request):
return await self.callback(request) await self.callback(request)
def setUp(self): def setUp(self):
self.reactor = ThreadedMemoryReactorClock() self.reactor = ThreadedMemoryReactorClock()