Merge different Resource implementation classes (#7732)

This commit is contained in:
Erik Johnston 2020-07-03 19:02:19 +01:00 committed by GitHub
parent 21a212f8e5
commit 5cdca53aa0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 330 additions and 326 deletions

View file

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import collections
import html
import logging
@ -21,7 +22,7 @@ import types
import urllib
from http import HTTPStatus
from io import BytesIO
from typing import Awaitable, Callable, TypeVar, Union
from typing import Any, Callable, Dict, Tuple, Union
import jinja2
from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
@ -62,99 +63,43 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
"""
def wrap_json_request_handler(h):
"""Wraps a request handler method with exception handling.
Also does the wrapping with request.processing as per wrap_async_request_handler.
The handler method must have a signature of "handle_foo(self, request)",
where "request" must be a SynapseRequest.
The handler must return a deferred or a coroutine. If the deferred succeeds
we assume that a response has been sent. If the deferred fails with a SynapseError we use
it to send a JSON response with the appropriate HTTP reponse code. If the
deferred fails with any other type of error we send a 500 reponse.
def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
"""Sends a JSON error response to clients.
"""
async def wrapped_request_handler(self, request):
try:
await h(self, request)
except SynapseError as e:
code = e.code
logger.info("%s SynapseError: %s - %s", request, code, e.msg)
if f.check(SynapseError):
error_code = f.value.code
error_dict = f.value.error_dict()
# Only respond with an error response if we haven't already started
# writing, otherwise lets just kill the connection
if request.startedWriting:
if request.transport:
try:
request.transport.abortConnection()
except Exception:
# abortConnection throws if the connection is already closed
pass
else:
respond_with_json(
request,
code,
e.error_dict(),
send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
)
logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg)
else:
error_code = 500
error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
except Exception:
# failure.Failure() fishes the original Failure out
# of our stack, and thus gives us a sensible stack
# trace.
f = failure.Failure()
logger.error(
"Failed handle request via %r: %r",
request.request_metrics.name,
request,
exc_info=(f.type, f.value, f.getTracebackObject()),
)
# Only respond with an error response if we haven't already started
# writing, otherwise lets just kill the connection
if request.startedWriting:
if request.transport:
try:
request.transport.abortConnection()
except Exception:
# abortConnection throws if the connection is already closed
pass
else:
respond_with_json(
request,
500,
{"error": "Internal server error", "errcode": Codes.UNKNOWN},
send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
)
logger.error(
"Failed handle request via %r: %r",
request.request_metrics.name,
request,
exc_info=(f.type, f.value, f.getTracebackObject()),
)
return wrap_async_request_handler(wrapped_request_handler)
TV = TypeVar("TV")
def wrap_html_request_handler(
h: Callable[[TV, SynapseRequest], Awaitable]
) -> Callable[[TV, SynapseRequest], Awaitable[None]]:
"""Wraps a request handler method with exception handling.
Also does the wrapping with request.processing as per wrap_async_request_handler.
The handler method must have a signature of "handle_foo(self, request)",
where "request" must be a SynapseRequest.
"""
async def wrapped_request_handler(self, request):
try:
await h(self, request)
except Exception:
f = failure.Failure()
return_html_error(f, request, HTML_ERROR_TEMPLATE)
return wrap_async_request_handler(wrapped_request_handler)
# Only respond with an error response if we haven't already started writing,
# otherwise lets just kill the connection
if request.startedWriting:
if request.transport:
try:
request.transport.abortConnection()
except Exception:
# abortConnection throws if the connection is already closed
pass
else:
respond_with_json(
request,
error_code,
error_dict,
send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
)
def return_html_error(
@ -249,7 +194,113 @@ class HttpServer(object):
pass
class JsonResource(HttpServer, resource.Resource):
class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
"""Base class for resources that have async handlers.
Sub classes can either implement `_async_render_<METHOD>` to handle
requests by method, or override `_async_render` to handle all requests.
Args:
extract_context: Whether to attempt to extract the opentracing
context from the request the servlet is handling.
"""
def __init__(self, extract_context=False):
super().__init__()
self._extract_context = extract_context
def render(self, request):
""" This gets called by twisted every time someone sends us a request.
"""
defer.ensureDeferred(self._async_render_wrapper(request))
return NOT_DONE_YET
@wrap_async_request_handler
async def _async_render_wrapper(self, request):
"""This is a wrapper that delegates to `_async_render` and handles
exceptions, return values, metrics, etc.
"""
try:
request.request_metrics.name = self.__class__.__name__
with trace_servlet(request, self._extract_context):
callback_return = await self._async_render(request)
if callback_return is not None:
code, response = callback_return
self._send_response(request, code, response)
except Exception:
# failure.Failure() fishes the original Failure out
# of our stack, and thus gives us a sensible stack
# trace.
f = failure.Failure()
self._send_error_response(f, request)
async def _async_render(self, request):
"""Delegates to `_async_render_<METHOD>` methods, or returns a 400 if
no appropriate method exists. Can be overriden in sub classes for
different routing.
"""
method_handler = getattr(
self, "_async_render_%s" % (request.method.decode("ascii"),), None
)
if method_handler:
raw_callback_return = method_handler(request)
# Is it synchronous? We'll allow this for now.
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
callback_return = await raw_callback_return
else:
callback_return = raw_callback_return
return callback_return
_unrecognised_request_handler(request)
@abc.abstractmethod
def _send_response(
self, request: SynapseRequest, code: int, response_object: Any,
) -> None:
raise NotImplementedError()
@abc.abstractmethod
def _send_error_response(
self, f: failure.Failure, request: SynapseRequest,
) -> None:
raise NotImplementedError()
class DirectServeJsonResource(_AsyncResource):
"""A resource that will call `self._async_on_<METHOD>` on new requests,
formatting responses and errors as JSON.
"""
def _send_response(
self, request, code, response_object,
):
"""Implements _AsyncResource._send_response
"""
# TODO: Only enable CORS for the requests that need it.
respond_with_json(
request,
code,
response_object,
send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
canonical_json=self.canonical_json,
)
def _send_error_response(
self, f: failure.Failure, request: SynapseRequest,
) -> None:
"""Implements _AsyncResource._send_error_response
"""
return_json_error(f, request)
class JsonResource(DirectServeJsonResource):
""" This implements the HttpServer interface and provides JSON support for
Resources.
@ -269,17 +320,15 @@ class JsonResource(HttpServer, resource.Resource):
"_PathEntry", ["pattern", "callback", "servlet_classname"]
)
def __init__(self, hs, canonical_json=True):
resource.Resource.__init__(self)
def __init__(self, hs, canonical_json=True, extract_context=False):
super().__init__(extract_context)
self.canonical_json = canonical_json
self.clock = hs.get_clock()
self.path_regexs = {}
self.hs = hs
def register_paths(
self, method, path_patterns, callback, servlet_classname, trace=True
):
def register_paths(self, method, path_patterns, callback, servlet_classname):
"""
Registers a request handler against a regular expression. Later request URLs are
checked against these regular expressions in order to identify an appropriate
@ -295,74 +344,23 @@ class JsonResource(HttpServer, resource.Resource):
servlet_classname (str): The name of the handler to be used in prometheus
and opentracing logs.
trace (bool): Whether we should start a span to trace the servlet.
"""
method = method.encode("utf-8") # method is bytes on py3
if trace:
# We don't extract the context from the servlet because we can't
# trust the sender
callback = trace_servlet(servlet_classname)(callback)
for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append(
self._PathEntry(path_pattern, callback, servlet_classname)
)
def render(self, request):
""" This gets called by twisted every time someone sends us a request.
"""
defer.ensureDeferred(self._async_render(request))
return NOT_DONE_YET
@wrap_json_request_handler
async def _async_render(self, request):
""" This gets called from render() every time someone sends us a request.
This checks if anyone has registered a callback for that method and
path.
"""
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
# Make sure we have a name for this handler in prometheus.
request.request_metrics.name = servlet_classname
# Now trigger the callback. If it returns a response, we send it
# here. If it throws an exception, that is handled by the wrapper
# installed by @request_handler.
kwargs = intern_dict(
{
name: urllib.parse.unquote(value) if value else value
for name, value in group_dict.items()
}
)
callback_return = callback(request, **kwargs)
# Is it synchronous? We'll allow this for now.
if isinstance(callback_return, (defer.Deferred, types.CoroutineType)):
callback_return = await callback_return
if callback_return is not None:
code, response = callback_return
self._send_response(request, code, response)
def _get_handler_for_request(self, request):
"""Finds a callback method to handle the given request
Args:
request (twisted.web.http.Request):
def _get_handler_for_request(
self, request: SynapseRequest
) -> Tuple[Callable, str, Dict[str, str]]:
"""Finds a callback method to handle the given request.
Returns:
Tuple[Callable, str, dict[unicode, unicode]]: callback method, the
label to use for that method in prometheus metrics, and the
dict mapping keys to path components as specified in the
handler's path match regexp.
The callback will normally be a method registered via
register_paths, so will return (possibly via Deferred) either
None, or a tuple of (http code, response body).
A tuple of the callback to use, the name of the servlet, and the
key word arguments to pass to the callback
"""
request_path = request.path.decode("ascii")
@ -377,42 +375,59 @@ class JsonResource(HttpServer, resource.Resource):
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
return _unrecognised_request_handler, "unrecognised_request_handler", {}
def _send_response(
self, request, code, response_json_object, response_code_message=None
):
# TODO: Only enable CORS for the requests that need it.
respond_with_json(
request,
code,
response_json_object,
send_cors=True,
response_code_message=response_code_message,
pretty_print=_request_user_agent_is_curl(request),
canonical_json=self.canonical_json,
async def _async_render(self, request):
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
# Make sure we have an appopriate name for this handler in prometheus
# (rather than the default of JsonResource).
request.request_metrics.name = servlet_classname
# Now trigger the callback. If it returns a response, we send it
# here. If it throws an exception, that is handled by the wrapper
# installed by @request_handler.
kwargs = intern_dict(
{
name: urllib.parse.unquote(value) if value else value
for name, value in group_dict.items()
}
)
raw_callback_return = callback(request, **kwargs)
class DirectServeResource(resource.Resource):
def render(self, request):
# Is it synchronous? We'll allow this for now.
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
callback_return = await raw_callback_return
else:
callback_return = raw_callback_return
return callback_return
class DirectServeHtmlResource(_AsyncResource):
"""A resource that will call `self._async_on_<METHOD>` on new requests,
formatting responses and errors as HTML.
"""
# The error template to use for this resource
ERROR_TEMPLATE = HTML_ERROR_TEMPLATE
def _send_response(
self, request: SynapseRequest, code: int, response_object: Any,
):
"""Implements _AsyncResource._send_response
"""
Render the request, using an asynchronous render handler if it exists.
# We expect to get bytes for us to write
assert isinstance(response_object, bytes)
html_bytes = response_object
respond_with_html_bytes(request, 200, html_bytes)
def _send_error_response(
self, f: failure.Failure, request: SynapseRequest,
) -> None:
"""Implements _AsyncResource._send_error_response
"""
async_render_callback_name = "_async_render_" + request.method.decode("ascii")
# Try and get the async renderer
callback = getattr(self, async_render_callback_name, None)
# No async renderer for this request method.
if not callback:
return super().render(request)
resp = trace_servlet(self.__class__.__name__)(callback)(request)
# If it's a coroutine, turn it into a Deferred
if isinstance(resp, types.CoroutineType):
defer.ensureDeferred(resp)
return NOT_DONE_YET
return_html_error(f, request, self.ERROR_TEMPLATE)
class StaticResource(File):