mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Merge different Resource implementation classes (#7732)
This commit is contained in:
parent
21a212f8e5
commit
5cdca53aa0
1
changelog.d/7732.bugfix
Normal file
1
changelog.d/7732.bugfix
Normal file
@ -0,0 +1 @@
|
|||||||
|
Fix "Tried to close a non-active scope!" error messages when opentracing is enabled.
|
@ -361,11 +361,7 @@ class BaseFederationServlet(object):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
server.register_paths(
|
server.register_paths(
|
||||||
method,
|
method, (pattern,), self._wrap(code), self.__class__.__name__,
|
||||||
(pattern,),
|
|
||||||
self._wrap(code),
|
|
||||||
self.__class__.__name__,
|
|
||||||
trace=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -13,13 +13,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from twisted.web.resource import Resource
|
from synapse.http.server import DirectServeJsonResource
|
||||||
from twisted.web.server import NOT_DONE_YET
|
|
||||||
|
|
||||||
from synapse.http.server import wrap_json_request_handler
|
|
||||||
|
|
||||||
|
|
||||||
class AdditionalResource(Resource):
|
class AdditionalResource(DirectServeJsonResource):
|
||||||
"""Resource wrapper for additional_resources
|
"""Resource wrapper for additional_resources
|
||||||
|
|
||||||
If the user has configured additional_resources, we need to wrap the
|
If the user has configured additional_resources, we need to wrap the
|
||||||
@ -41,16 +38,10 @@ class AdditionalResource(Resource):
|
|||||||
handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
|
handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
|
||||||
function to be called to handle the request.
|
function to be called to handle the request.
|
||||||
"""
|
"""
|
||||||
Resource.__init__(self)
|
super().__init__()
|
||||||
self._handler = handler
|
self._handler = handler
|
||||||
|
|
||||||
# required by the request_handler wrapper
|
|
||||||
self.clock = hs.get_clock()
|
|
||||||
|
|
||||||
def render(self, request):
|
|
||||||
self._async_render(request)
|
|
||||||
return NOT_DONE_YET
|
|
||||||
|
|
||||||
@wrap_json_request_handler
|
|
||||||
def _async_render(self, request):
|
def _async_render(self, request):
|
||||||
|
# Cheekily pass the result straight through, so we don't need to worry
|
||||||
|
# if its an awaitable or not.
|
||||||
return self._handler(request)
|
return self._handler(request)
|
||||||
|
@ -14,6 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import abc
|
||||||
import collections
|
import collections
|
||||||
import html
|
import html
|
||||||
import logging
|
import logging
|
||||||
@ -21,7 +22,7 @@ import types
|
|||||||
import urllib
|
import urllib
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Awaitable, Callable, TypeVar, Union
|
from typing import Any, Callable, Dict, Tuple, Union
|
||||||
|
|
||||||
import jinja2
|
import jinja2
|
||||||
from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
|
from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
|
||||||
@ -62,99 +63,43 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def wrap_json_request_handler(h):
|
def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
|
||||||
"""Wraps a request handler method with exception handling.
|
"""Sends a JSON error response to clients.
|
||||||
|
|
||||||
Also does the wrapping with request.processing as per wrap_async_request_handler.
|
|
||||||
|
|
||||||
The handler method must have a signature of "handle_foo(self, request)",
|
|
||||||
where "request" must be a SynapseRequest.
|
|
||||||
|
|
||||||
The handler must return a deferred or a coroutine. If the deferred succeeds
|
|
||||||
we assume that a response has been sent. If the deferred fails with a SynapseError we use
|
|
||||||
it to send a JSON response with the appropriate HTTP reponse code. If the
|
|
||||||
deferred fails with any other type of error we send a 500 reponse.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def wrapped_request_handler(self, request):
|
if f.check(SynapseError):
|
||||||
try:
|
error_code = f.value.code
|
||||||
await h(self, request)
|
error_dict = f.value.error_dict()
|
||||||
except SynapseError as e:
|
|
||||||
code = e.code
|
|
||||||
logger.info("%s SynapseError: %s - %s", request, code, e.msg)
|
|
||||||
|
|
||||||
# Only respond with an error response if we haven't already started
|
logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg)
|
||||||
# writing, otherwise lets just kill the connection
|
else:
|
||||||
if request.startedWriting:
|
error_code = 500
|
||||||
if request.transport:
|
error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
|
||||||
try:
|
|
||||||
request.transport.abortConnection()
|
|
||||||
except Exception:
|
|
||||||
# abortConnection throws if the connection is already closed
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
respond_with_json(
|
|
||||||
request,
|
|
||||||
code,
|
|
||||||
e.error_dict(),
|
|
||||||
send_cors=True,
|
|
||||||
pretty_print=_request_user_agent_is_curl(request),
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception:
|
logger.error(
|
||||||
# failure.Failure() fishes the original Failure out
|
"Failed handle request via %r: %r",
|
||||||
# of our stack, and thus gives us a sensible stack
|
request.request_metrics.name,
|
||||||
# trace.
|
request,
|
||||||
f = failure.Failure()
|
exc_info=(f.type, f.value, f.getTracebackObject()),
|
||||||
logger.error(
|
)
|
||||||
"Failed handle request via %r: %r",
|
|
||||||
request.request_metrics.name,
|
|
||||||
request,
|
|
||||||
exc_info=(f.type, f.value, f.getTracebackObject()),
|
|
||||||
)
|
|
||||||
# Only respond with an error response if we haven't already started
|
|
||||||
# writing, otherwise lets just kill the connection
|
|
||||||
if request.startedWriting:
|
|
||||||
if request.transport:
|
|
||||||
try:
|
|
||||||
request.transport.abortConnection()
|
|
||||||
except Exception:
|
|
||||||
# abortConnection throws if the connection is already closed
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
respond_with_json(
|
|
||||||
request,
|
|
||||||
500,
|
|
||||||
{"error": "Internal server error", "errcode": Codes.UNKNOWN},
|
|
||||||
send_cors=True,
|
|
||||||
pretty_print=_request_user_agent_is_curl(request),
|
|
||||||
)
|
|
||||||
|
|
||||||
return wrap_async_request_handler(wrapped_request_handler)
|
# Only respond with an error response if we haven't already started writing,
|
||||||
|
# otherwise lets just kill the connection
|
||||||
|
if request.startedWriting:
|
||||||
TV = TypeVar("TV")
|
if request.transport:
|
||||||
|
try:
|
||||||
|
request.transport.abortConnection()
|
||||||
def wrap_html_request_handler(
|
except Exception:
|
||||||
h: Callable[[TV, SynapseRequest], Awaitable]
|
# abortConnection throws if the connection is already closed
|
||||||
) -> Callable[[TV, SynapseRequest], Awaitable[None]]:
|
pass
|
||||||
"""Wraps a request handler method with exception handling.
|
else:
|
||||||
|
respond_with_json(
|
||||||
Also does the wrapping with request.processing as per wrap_async_request_handler.
|
request,
|
||||||
|
error_code,
|
||||||
The handler method must have a signature of "handle_foo(self, request)",
|
error_dict,
|
||||||
where "request" must be a SynapseRequest.
|
send_cors=True,
|
||||||
"""
|
pretty_print=_request_user_agent_is_curl(request),
|
||||||
|
)
|
||||||
async def wrapped_request_handler(self, request):
|
|
||||||
try:
|
|
||||||
await h(self, request)
|
|
||||||
except Exception:
|
|
||||||
f = failure.Failure()
|
|
||||||
return_html_error(f, request, HTML_ERROR_TEMPLATE)
|
|
||||||
|
|
||||||
return wrap_async_request_handler(wrapped_request_handler)
|
|
||||||
|
|
||||||
|
|
||||||
def return_html_error(
|
def return_html_error(
|
||||||
@ -249,7 +194,113 @@ class HttpServer(object):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class JsonResource(HttpServer, resource.Resource):
|
class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
|
||||||
|
"""Base class for resources that have async handlers.
|
||||||
|
|
||||||
|
Sub classes can either implement `_async_render_<METHOD>` to handle
|
||||||
|
requests by method, or override `_async_render` to handle all requests.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
extract_context: Whether to attempt to extract the opentracing
|
||||||
|
context from the request the servlet is handling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, extract_context=False):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self._extract_context = extract_context
|
||||||
|
|
||||||
|
def render(self, request):
|
||||||
|
""" This gets called by twisted every time someone sends us a request.
|
||||||
|
"""
|
||||||
|
defer.ensureDeferred(self._async_render_wrapper(request))
|
||||||
|
return NOT_DONE_YET
|
||||||
|
|
||||||
|
@wrap_async_request_handler
|
||||||
|
async def _async_render_wrapper(self, request):
|
||||||
|
"""This is a wrapper that delegates to `_async_render` and handles
|
||||||
|
exceptions, return values, metrics, etc.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
request.request_metrics.name = self.__class__.__name__
|
||||||
|
|
||||||
|
with trace_servlet(request, self._extract_context):
|
||||||
|
callback_return = await self._async_render(request)
|
||||||
|
|
||||||
|
if callback_return is not None:
|
||||||
|
code, response = callback_return
|
||||||
|
self._send_response(request, code, response)
|
||||||
|
except Exception:
|
||||||
|
# failure.Failure() fishes the original Failure out
|
||||||
|
# of our stack, and thus gives us a sensible stack
|
||||||
|
# trace.
|
||||||
|
f = failure.Failure()
|
||||||
|
self._send_error_response(f, request)
|
||||||
|
|
||||||
|
async def _async_render(self, request):
|
||||||
|
"""Delegates to `_async_render_<METHOD>` methods, or returns a 400 if
|
||||||
|
no appropriate method exists. Can be overriden in sub classes for
|
||||||
|
different routing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
method_handler = getattr(
|
||||||
|
self, "_async_render_%s" % (request.method.decode("ascii"),), None
|
||||||
|
)
|
||||||
|
if method_handler:
|
||||||
|
raw_callback_return = method_handler(request)
|
||||||
|
|
||||||
|
# Is it synchronous? We'll allow this for now.
|
||||||
|
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
|
||||||
|
callback_return = await raw_callback_return
|
||||||
|
else:
|
||||||
|
callback_return = raw_callback_return
|
||||||
|
|
||||||
|
return callback_return
|
||||||
|
|
||||||
|
_unrecognised_request_handler(request)
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _send_response(
|
||||||
|
self, request: SynapseRequest, code: int, response_object: Any,
|
||||||
|
) -> None:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def _send_error_response(
|
||||||
|
self, f: failure.Failure, request: SynapseRequest,
|
||||||
|
) -> None:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class DirectServeJsonResource(_AsyncResource):
|
||||||
|
"""A resource that will call `self._async_on_<METHOD>` on new requests,
|
||||||
|
formatting responses and errors as JSON.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _send_response(
|
||||||
|
self, request, code, response_object,
|
||||||
|
):
|
||||||
|
"""Implements _AsyncResource._send_response
|
||||||
|
"""
|
||||||
|
# TODO: Only enable CORS for the requests that need it.
|
||||||
|
respond_with_json(
|
||||||
|
request,
|
||||||
|
code,
|
||||||
|
response_object,
|
||||||
|
send_cors=True,
|
||||||
|
pretty_print=_request_user_agent_is_curl(request),
|
||||||
|
canonical_json=self.canonical_json,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _send_error_response(
|
||||||
|
self, f: failure.Failure, request: SynapseRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Implements _AsyncResource._send_error_response
|
||||||
|
"""
|
||||||
|
return_json_error(f, request)
|
||||||
|
|
||||||
|
|
||||||
|
class JsonResource(DirectServeJsonResource):
|
||||||
""" This implements the HttpServer interface and provides JSON support for
|
""" This implements the HttpServer interface and provides JSON support for
|
||||||
Resources.
|
Resources.
|
||||||
|
|
||||||
@ -269,17 +320,15 @@ class JsonResource(HttpServer, resource.Resource):
|
|||||||
"_PathEntry", ["pattern", "callback", "servlet_classname"]
|
"_PathEntry", ["pattern", "callback", "servlet_classname"]
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, hs, canonical_json=True):
|
def __init__(self, hs, canonical_json=True, extract_context=False):
|
||||||
resource.Resource.__init__(self)
|
super().__init__(extract_context)
|
||||||
|
|
||||||
self.canonical_json = canonical_json
|
self.canonical_json = canonical_json
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.path_regexs = {}
|
self.path_regexs = {}
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
||||||
def register_paths(
|
def register_paths(self, method, path_patterns, callback, servlet_classname):
|
||||||
self, method, path_patterns, callback, servlet_classname, trace=True
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Registers a request handler against a regular expression. Later request URLs are
|
Registers a request handler against a regular expression. Later request URLs are
|
||||||
checked against these regular expressions in order to identify an appropriate
|
checked against these regular expressions in order to identify an appropriate
|
||||||
@ -295,74 +344,23 @@ class JsonResource(HttpServer, resource.Resource):
|
|||||||
|
|
||||||
servlet_classname (str): The name of the handler to be used in prometheus
|
servlet_classname (str): The name of the handler to be used in prometheus
|
||||||
and opentracing logs.
|
and opentracing logs.
|
||||||
|
|
||||||
trace (bool): Whether we should start a span to trace the servlet.
|
|
||||||
"""
|
"""
|
||||||
method = method.encode("utf-8") # method is bytes on py3
|
method = method.encode("utf-8") # method is bytes on py3
|
||||||
|
|
||||||
if trace:
|
|
||||||
# We don't extract the context from the servlet because we can't
|
|
||||||
# trust the sender
|
|
||||||
callback = trace_servlet(servlet_classname)(callback)
|
|
||||||
|
|
||||||
for path_pattern in path_patterns:
|
for path_pattern in path_patterns:
|
||||||
logger.debug("Registering for %s %s", method, path_pattern.pattern)
|
logger.debug("Registering for %s %s", method, path_pattern.pattern)
|
||||||
self.path_regexs.setdefault(method, []).append(
|
self.path_regexs.setdefault(method, []).append(
|
||||||
self._PathEntry(path_pattern, callback, servlet_classname)
|
self._PathEntry(path_pattern, callback, servlet_classname)
|
||||||
)
|
)
|
||||||
|
|
||||||
def render(self, request):
|
def _get_handler_for_request(
|
||||||
""" This gets called by twisted every time someone sends us a request.
|
self, request: SynapseRequest
|
||||||
"""
|
) -> Tuple[Callable, str, Dict[str, str]]:
|
||||||
defer.ensureDeferred(self._async_render(request))
|
"""Finds a callback method to handle the given request.
|
||||||
return NOT_DONE_YET
|
|
||||||
|
|
||||||
@wrap_json_request_handler
|
|
||||||
async def _async_render(self, request):
|
|
||||||
""" This gets called from render() every time someone sends us a request.
|
|
||||||
This checks if anyone has registered a callback for that method and
|
|
||||||
path.
|
|
||||||
"""
|
|
||||||
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
|
|
||||||
|
|
||||||
# Make sure we have a name for this handler in prometheus.
|
|
||||||
request.request_metrics.name = servlet_classname
|
|
||||||
|
|
||||||
# Now trigger the callback. If it returns a response, we send it
|
|
||||||
# here. If it throws an exception, that is handled by the wrapper
|
|
||||||
# installed by @request_handler.
|
|
||||||
kwargs = intern_dict(
|
|
||||||
{
|
|
||||||
name: urllib.parse.unquote(value) if value else value
|
|
||||||
for name, value in group_dict.items()
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
callback_return = callback(request, **kwargs)
|
|
||||||
|
|
||||||
# Is it synchronous? We'll allow this for now.
|
|
||||||
if isinstance(callback_return, (defer.Deferred, types.CoroutineType)):
|
|
||||||
callback_return = await callback_return
|
|
||||||
|
|
||||||
if callback_return is not None:
|
|
||||||
code, response = callback_return
|
|
||||||
self._send_response(request, code, response)
|
|
||||||
|
|
||||||
def _get_handler_for_request(self, request):
|
|
||||||
"""Finds a callback method to handle the given request
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request (twisted.web.http.Request):
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[Callable, str, dict[unicode, unicode]]: callback method, the
|
A tuple of the callback to use, the name of the servlet, and the
|
||||||
label to use for that method in prometheus metrics, and the
|
key word arguments to pass to the callback
|
||||||
dict mapping keys to path components as specified in the
|
|
||||||
handler's path match regexp.
|
|
||||||
|
|
||||||
The callback will normally be a method registered via
|
|
||||||
register_paths, so will return (possibly via Deferred) either
|
|
||||||
None, or a tuple of (http code, response body).
|
|
||||||
"""
|
"""
|
||||||
request_path = request.path.decode("ascii")
|
request_path = request.path.decode("ascii")
|
||||||
|
|
||||||
@ -377,42 +375,59 @@ class JsonResource(HttpServer, resource.Resource):
|
|||||||
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
|
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
|
||||||
return _unrecognised_request_handler, "unrecognised_request_handler", {}
|
return _unrecognised_request_handler, "unrecognised_request_handler", {}
|
||||||
|
|
||||||
def _send_response(
|
async def _async_render(self, request):
|
||||||
self, request, code, response_json_object, response_code_message=None
|
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
|
||||||
):
|
|
||||||
# TODO: Only enable CORS for the requests that need it.
|
# Make sure we have an appopriate name for this handler in prometheus
|
||||||
respond_with_json(
|
# (rather than the default of JsonResource).
|
||||||
request,
|
request.request_metrics.name = servlet_classname
|
||||||
code,
|
|
||||||
response_json_object,
|
# Now trigger the callback. If it returns a response, we send it
|
||||||
send_cors=True,
|
# here. If it throws an exception, that is handled by the wrapper
|
||||||
response_code_message=response_code_message,
|
# installed by @request_handler.
|
||||||
pretty_print=_request_user_agent_is_curl(request),
|
kwargs = intern_dict(
|
||||||
canonical_json=self.canonical_json,
|
{
|
||||||
|
name: urllib.parse.unquote(value) if value else value
|
||||||
|
for name, value in group_dict.items()
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
raw_callback_return = callback(request, **kwargs)
|
||||||
|
|
||||||
class DirectServeResource(resource.Resource):
|
# Is it synchronous? We'll allow this for now.
|
||||||
def render(self, request):
|
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
|
||||||
|
callback_return = await raw_callback_return
|
||||||
|
else:
|
||||||
|
callback_return = raw_callback_return
|
||||||
|
|
||||||
|
return callback_return
|
||||||
|
|
||||||
|
|
||||||
|
class DirectServeHtmlResource(_AsyncResource):
|
||||||
|
"""A resource that will call `self._async_on_<METHOD>` on new requests,
|
||||||
|
formatting responses and errors as HTML.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# The error template to use for this resource
|
||||||
|
ERROR_TEMPLATE = HTML_ERROR_TEMPLATE
|
||||||
|
|
||||||
|
def _send_response(
|
||||||
|
self, request: SynapseRequest, code: int, response_object: Any,
|
||||||
|
):
|
||||||
|
"""Implements _AsyncResource._send_response
|
||||||
"""
|
"""
|
||||||
Render the request, using an asynchronous render handler if it exists.
|
# We expect to get bytes for us to write
|
||||||
|
assert isinstance(response_object, bytes)
|
||||||
|
html_bytes = response_object
|
||||||
|
|
||||||
|
respond_with_html_bytes(request, 200, html_bytes)
|
||||||
|
|
||||||
|
def _send_error_response(
|
||||||
|
self, f: failure.Failure, request: SynapseRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Implements _AsyncResource._send_error_response
|
||||||
"""
|
"""
|
||||||
async_render_callback_name = "_async_render_" + request.method.decode("ascii")
|
return_html_error(f, request, self.ERROR_TEMPLATE)
|
||||||
|
|
||||||
# Try and get the async renderer
|
|
||||||
callback = getattr(self, async_render_callback_name, None)
|
|
||||||
|
|
||||||
# No async renderer for this request method.
|
|
||||||
if not callback:
|
|
||||||
return super().render(request)
|
|
||||||
|
|
||||||
resp = trace_servlet(self.__class__.__name__)(callback)(request)
|
|
||||||
|
|
||||||
# If it's a coroutine, turn it into a Deferred
|
|
||||||
if isinstance(resp, types.CoroutineType):
|
|
||||||
defer.ensureDeferred(resp)
|
|
||||||
|
|
||||||
return NOT_DONE_YET
|
|
||||||
|
|
||||||
|
|
||||||
class StaticResource(File):
|
class StaticResource(File):
|
||||||
|
@ -169,7 +169,6 @@ import contextlib
|
|||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import types
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import TYPE_CHECKING, Dict, Optional, Type
|
from typing import TYPE_CHECKING, Dict, Optional, Type
|
||||||
|
|
||||||
@ -182,6 +181,7 @@ from synapse.config import ConfigError
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
from synapse.http.site import SynapseRequest
|
||||||
|
|
||||||
# Helper class
|
# Helper class
|
||||||
|
|
||||||
@ -793,48 +793,42 @@ def tag_args(func):
|
|||||||
return _tag_args_inner
|
return _tag_args_inner
|
||||||
|
|
||||||
|
|
||||||
def trace_servlet(servlet_name, extract_context=False):
|
@contextlib.contextmanager
|
||||||
"""Decorator which traces a serlet. It starts a span with some servlet specific
|
def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
|
||||||
tags such as the servlet_name and request information
|
"""Returns a context manager which traces a request. It starts a span
|
||||||
|
with some servlet specific tags such as the request metrics name and
|
||||||
|
request information.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
servlet_name (str): The name to be used for the span's operation_name
|
request
|
||||||
extract_context (bool): Whether to attempt to extract the opentracing
|
extract_context: Whether to attempt to extract the opentracing
|
||||||
context from the request the servlet is handling.
|
context from the request the servlet is handling.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _trace_servlet_inner_1(func):
|
if opentracing is None:
|
||||||
if not opentracing:
|
yield
|
||||||
return func
|
return
|
||||||
|
|
||||||
@wraps(func)
|
request_tags = {
|
||||||
async def _trace_servlet_inner(request, *args, **kwargs):
|
"request_id": request.get_request_id(),
|
||||||
request_tags = {
|
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
|
||||||
"request_id": request.get_request_id(),
|
tags.HTTP_METHOD: request.get_method(),
|
||||||
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
|
tags.HTTP_URL: request.get_redacted_uri(),
|
||||||
tags.HTTP_METHOD: request.get_method(),
|
tags.PEER_HOST_IPV6: request.getClientIP(),
|
||||||
tags.HTTP_URL: request.get_redacted_uri(),
|
}
|
||||||
tags.PEER_HOST_IPV6: request.getClientIP(),
|
|
||||||
}
|
|
||||||
|
|
||||||
if extract_context:
|
request_name = request.request_metrics.name
|
||||||
scope = start_active_span_from_request(
|
if extract_context:
|
||||||
request, servlet_name, tags=request_tags
|
scope = start_active_span_from_request(request, request_name, tags=request_tags)
|
||||||
)
|
else:
|
||||||
else:
|
scope = start_active_span(request_name, tags=request_tags)
|
||||||
scope = start_active_span(servlet_name, tags=request_tags)
|
|
||||||
|
|
||||||
with scope:
|
with scope:
|
||||||
result = func(request, *args, **kwargs)
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
# We set the operation name again in case its changed (which happens
|
||||||
|
# with JsonResource).
|
||||||
|
scope.span.set_operation_name(request.request_metrics.name)
|
||||||
|
|
||||||
if not isinstance(result, (types.CoroutineType, defer.Deferred)):
|
scope.span.set_tag("request_tag", request.request_metrics.start_context.tag)
|
||||||
# Some servlets aren't async and just return results
|
|
||||||
# directly, so we handle that here.
|
|
||||||
return result
|
|
||||||
|
|
||||||
return await result
|
|
||||||
|
|
||||||
return _trace_servlet_inner
|
|
||||||
|
|
||||||
return _trace_servlet_inner_1
|
|
||||||
|
@ -30,7 +30,8 @@ REPLICATION_PREFIX = "/_synapse/replication"
|
|||||||
|
|
||||||
class ReplicationRestResource(JsonResource):
|
class ReplicationRestResource(JsonResource):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
JsonResource.__init__(self, hs, canonical_json=False)
|
# We enable extracting jaeger contexts here as these are internal APIs.
|
||||||
|
super().__init__(hs, canonical_json=False, extract_context=True)
|
||||||
self.register_servlets(hs)
|
self.register_servlets(hs)
|
||||||
|
|
||||||
def register_servlets(self, hs):
|
def register_servlets(self, hs):
|
||||||
|
@ -28,11 +28,7 @@ from synapse.api.errors import (
|
|||||||
RequestSendFailed,
|
RequestSendFailed,
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
from synapse.logging.opentracing import (
|
from synapse.logging.opentracing import inject_active_span_byte_dict, trace
|
||||||
inject_active_span_byte_dict,
|
|
||||||
trace,
|
|
||||||
trace_servlet,
|
|
||||||
)
|
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
|
|
||||||
@ -240,11 +236,8 @@ class ReplicationEndpoint(object):
|
|||||||
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
|
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
|
||||||
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
|
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
|
||||||
|
|
||||||
handler = trace_servlet(self.__class__.__name__, extract_context=True)(handler)
|
|
||||||
# We don't let register paths trace this servlet using the default tracing
|
|
||||||
# options because we wish to extract the context explicitly.
|
|
||||||
http_server.register_paths(
|
http_server.register_paths(
|
||||||
method, [pattern], handler, self.__class__.__name__, trace=False
|
method, [pattern], handler, self.__class__.__name__,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _cached_handler(self, request, txn_id, **kwargs):
|
def _cached_handler(self, request, txn_id, **kwargs):
|
||||||
|
@ -26,11 +26,7 @@ from twisted.internet import defer
|
|||||||
|
|
||||||
from synapse.api.errors import NotFoundError, StoreError, SynapseError
|
from synapse.api.errors import NotFoundError, StoreError, SynapseError
|
||||||
from synapse.config import ConfigError
|
from synapse.config import ConfigError
|
||||||
from synapse.http.server import (
|
from synapse.http.server import DirectServeHtmlResource, respond_with_html
|
||||||
DirectServeResource,
|
|
||||||
respond_with_html,
|
|
||||||
wrap_html_request_handler,
|
|
||||||
)
|
|
||||||
from synapse.http.servlet import parse_string
|
from synapse.http.servlet import parse_string
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
@ -48,7 +44,7 @@ else:
|
|||||||
return a == b
|
return a == b
|
||||||
|
|
||||||
|
|
||||||
class ConsentResource(DirectServeResource):
|
class ConsentResource(DirectServeHtmlResource):
|
||||||
"""A twisted Resource to display a privacy policy and gather consent to it
|
"""A twisted Resource to display a privacy policy and gather consent to it
|
||||||
|
|
||||||
When accessed via GET, returns the privacy policy via a template.
|
When accessed via GET, returns the privacy policy via a template.
|
||||||
@ -119,7 +115,6 @@ class ConsentResource(DirectServeResource):
|
|||||||
|
|
||||||
self._hmac_secret = hs.config.form_secret.encode("utf-8")
|
self._hmac_secret = hs.config.form_secret.encode("utf-8")
|
||||||
|
|
||||||
@wrap_html_request_handler
|
|
||||||
async def _async_render_GET(self, request):
|
async def _async_render_GET(self, request):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -160,7 +155,6 @@ class ConsentResource(DirectServeResource):
|
|||||||
except TemplateNotFound:
|
except TemplateNotFound:
|
||||||
raise NotFoundError("Unknown policy version")
|
raise NotFoundError("Unknown policy version")
|
||||||
|
|
||||||
@wrap_html_request_handler
|
|
||||||
async def _async_render_POST(self, request):
|
async def _async_render_POST(self, request):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -20,17 +20,13 @@ from signedjson.sign import sign_json
|
|||||||
|
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.crypto.keyring import ServerKeyFetcher
|
from synapse.crypto.keyring import ServerKeyFetcher
|
||||||
from synapse.http.server import (
|
from synapse.http.server import DirectServeJsonResource, respond_with_json_bytes
|
||||||
DirectServeResource,
|
|
||||||
respond_with_json_bytes,
|
|
||||||
wrap_json_request_handler,
|
|
||||||
)
|
|
||||||
from synapse.http.servlet import parse_integer, parse_json_object_from_request
|
from synapse.http.servlet import parse_integer, parse_json_object_from_request
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RemoteKey(DirectServeResource):
|
class RemoteKey(DirectServeJsonResource):
|
||||||
"""HTTP resource for retreiving the TLS certificate and NACL signature
|
"""HTTP resource for retreiving the TLS certificate and NACL signature
|
||||||
verification keys for a collection of servers. Checks that the reported
|
verification keys for a collection of servers. Checks that the reported
|
||||||
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
|
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
|
||||||
@ -92,13 +88,14 @@ class RemoteKey(DirectServeResource):
|
|||||||
isLeaf = True
|
isLeaf = True
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
self.fetcher = ServerKeyFetcher(hs)
|
self.fetcher = ServerKeyFetcher(hs)
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
|
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
|
||||||
self.config = hs.config
|
self.config = hs.config
|
||||||
|
|
||||||
@wrap_json_request_handler
|
|
||||||
async def _async_render_GET(self, request):
|
async def _async_render_GET(self, request):
|
||||||
if len(request.postpath) == 1:
|
if len(request.postpath) == 1:
|
||||||
(server,) = request.postpath
|
(server,) = request.postpath
|
||||||
@ -115,7 +112,6 @@ class RemoteKey(DirectServeResource):
|
|||||||
|
|
||||||
await self.query_keys(request, query, query_remote_on_cache_miss=True)
|
await self.query_keys(request, query, query_remote_on_cache_miss=True)
|
||||||
|
|
||||||
@wrap_json_request_handler
|
|
||||||
async def _async_render_POST(self, request):
|
async def _async_render_POST(self, request):
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
@ -14,16 +14,10 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
from twisted.web.server import NOT_DONE_YET
|
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
||||||
|
|
||||||
from synapse.http.server import (
|
|
||||||
DirectServeResource,
|
|
||||||
respond_with_json,
|
|
||||||
wrap_json_request_handler,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MediaConfigResource(DirectServeResource):
|
class MediaConfigResource(DirectServeJsonResource):
|
||||||
isLeaf = True
|
isLeaf = True
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
@ -33,11 +27,9 @@ class MediaConfigResource(DirectServeResource):
|
|||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.limits_dict = {"m.upload.size": config.max_upload_size}
|
self.limits_dict = {"m.upload.size": config.max_upload_size}
|
||||||
|
|
||||||
@wrap_json_request_handler
|
|
||||||
async def _async_render_GET(self, request):
|
async def _async_render_GET(self, request):
|
||||||
await self.auth.get_user_by_req(request)
|
await self.auth.get_user_by_req(request)
|
||||||
respond_with_json(request, 200, self.limits_dict, send_cors=True)
|
respond_with_json(request, 200, self.limits_dict, send_cors=True)
|
||||||
|
|
||||||
def render_OPTIONS(self, request):
|
async def _async_render_OPTIONS(self, request):
|
||||||
respond_with_json(request, 200, {}, send_cors=True)
|
respond_with_json(request, 200, {}, send_cors=True)
|
||||||
return NOT_DONE_YET
|
|
||||||
|
@ -15,18 +15,14 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import synapse.http.servlet
|
import synapse.http.servlet
|
||||||
from synapse.http.server import (
|
from synapse.http.server import DirectServeJsonResource, set_cors_headers
|
||||||
DirectServeResource,
|
|
||||||
set_cors_headers,
|
|
||||||
wrap_json_request_handler,
|
|
||||||
)
|
|
||||||
|
|
||||||
from ._base import parse_media_id, respond_404
|
from ._base import parse_media_id, respond_404
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DownloadResource(DirectServeResource):
|
class DownloadResource(DirectServeJsonResource):
|
||||||
isLeaf = True
|
isLeaf = True
|
||||||
|
|
||||||
def __init__(self, hs, media_repo):
|
def __init__(self, hs, media_repo):
|
||||||
@ -34,10 +30,6 @@ class DownloadResource(DirectServeResource):
|
|||||||
self.media_repo = media_repo
|
self.media_repo = media_repo
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
|
|
||||||
# this is expected by @wrap_json_request_handler
|
|
||||||
self.clock = hs.get_clock()
|
|
||||||
|
|
||||||
@wrap_json_request_handler
|
|
||||||
async def _async_render_GET(self, request):
|
async def _async_render_GET(self, request):
|
||||||
set_cors_headers(request)
|
set_cors_headers(request)
|
||||||
request.setHeader(
|
request.setHeader(
|
||||||
|
@ -34,10 +34,9 @@ from twisted.internet.error import DNSLookupError
|
|||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.http.client import SimpleHttpClient
|
from synapse.http.client import SimpleHttpClient
|
||||||
from synapse.http.server import (
|
from synapse.http.server import (
|
||||||
DirectServeResource,
|
DirectServeJsonResource,
|
||||||
respond_with_json,
|
respond_with_json,
|
||||||
respond_with_json_bytes,
|
respond_with_json_bytes,
|
||||||
wrap_json_request_handler,
|
|
||||||
)
|
)
|
||||||
from synapse.http.servlet import parse_integer, parse_string
|
from synapse.http.servlet import parse_integer, parse_string
|
||||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||||
@ -58,7 +57,7 @@ OG_TAG_NAME_MAXLEN = 50
|
|||||||
OG_TAG_VALUE_MAXLEN = 1000
|
OG_TAG_VALUE_MAXLEN = 1000
|
||||||
|
|
||||||
|
|
||||||
class PreviewUrlResource(DirectServeResource):
|
class PreviewUrlResource(DirectServeJsonResource):
|
||||||
isLeaf = True
|
isLeaf = True
|
||||||
|
|
||||||
def __init__(self, hs, media_repo, media_storage):
|
def __init__(self, hs, media_repo, media_storage):
|
||||||
@ -108,11 +107,10 @@ class PreviewUrlResource(DirectServeResource):
|
|||||||
self._start_expire_url_cache_data, 10 * 1000
|
self._start_expire_url_cache_data, 10 * 1000
|
||||||
)
|
)
|
||||||
|
|
||||||
def render_OPTIONS(self, request):
|
async def _async_render_OPTIONS(self, request):
|
||||||
request.setHeader(b"Allow", b"OPTIONS, GET")
|
request.setHeader(b"Allow", b"OPTIONS, GET")
|
||||||
return respond_with_json(request, 200, {}, send_cors=True)
|
respond_with_json(request, 200, {}, send_cors=True)
|
||||||
|
|
||||||
@wrap_json_request_handler
|
|
||||||
async def _async_render_GET(self, request):
|
async def _async_render_GET(self, request):
|
||||||
|
|
||||||
# XXX: if get_user_by_req fails, what should we do in an async render?
|
# XXX: if get_user_by_req fails, what should we do in an async render?
|
||||||
|
@ -16,11 +16,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from synapse.http.server import (
|
from synapse.http.server import DirectServeJsonResource, set_cors_headers
|
||||||
DirectServeResource,
|
|
||||||
set_cors_headers,
|
|
||||||
wrap_json_request_handler,
|
|
||||||
)
|
|
||||||
from synapse.http.servlet import parse_integer, parse_string
|
from synapse.http.servlet import parse_integer, parse_string
|
||||||
|
|
||||||
from ._base import (
|
from ._base import (
|
||||||
@ -34,7 +30,7 @@ from ._base import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ThumbnailResource(DirectServeResource):
|
class ThumbnailResource(DirectServeJsonResource):
|
||||||
isLeaf = True
|
isLeaf = True
|
||||||
|
|
||||||
def __init__(self, hs, media_repo, media_storage):
|
def __init__(self, hs, media_repo, media_storage):
|
||||||
@ -45,9 +41,7 @@ class ThumbnailResource(DirectServeResource):
|
|||||||
self.media_storage = media_storage
|
self.media_storage = media_storage
|
||||||
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
self.clock = hs.get_clock()
|
|
||||||
|
|
||||||
@wrap_json_request_handler
|
|
||||||
async def _async_render_GET(self, request):
|
async def _async_render_GET(self, request):
|
||||||
set_cors_headers(request)
|
set_cors_headers(request)
|
||||||
server_name, media_id, _ = parse_media_id(request)
|
server_name, media_id, _ = parse_media_id(request)
|
||||||
|
@ -15,20 +15,14 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from twisted.web.server import NOT_DONE_YET
|
|
||||||
|
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.http.server import (
|
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
||||||
DirectServeResource,
|
|
||||||
respond_with_json,
|
|
||||||
wrap_json_request_handler,
|
|
||||||
)
|
|
||||||
from synapse.http.servlet import parse_string
|
from synapse.http.servlet import parse_string
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UploadResource(DirectServeResource):
|
class UploadResource(DirectServeJsonResource):
|
||||||
isLeaf = True
|
isLeaf = True
|
||||||
|
|
||||||
def __init__(self, hs, media_repo):
|
def __init__(self, hs, media_repo):
|
||||||
@ -43,11 +37,9 @@ class UploadResource(DirectServeResource):
|
|||||||
self.max_upload_size = hs.config.max_upload_size
|
self.max_upload_size = hs.config.max_upload_size
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
def render_OPTIONS(self, request):
|
async def _async_render_OPTIONS(self, request):
|
||||||
respond_with_json(request, 200, {}, send_cors=True)
|
respond_with_json(request, 200, {}, send_cors=True)
|
||||||
return NOT_DONE_YET
|
|
||||||
|
|
||||||
@wrap_json_request_handler
|
|
||||||
async def _async_render_POST(self, request):
|
async def _async_render_POST(self, request):
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
# TODO: The checks here are a bit late. The content will have
|
# TODO: The checks here are a bit late. The content will have
|
||||||
|
@ -14,18 +14,17 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from synapse.http.server import DirectServeResource, wrap_html_request_handler
|
from synapse.http.server import DirectServeHtmlResource
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class OIDCCallbackResource(DirectServeResource):
|
class OIDCCallbackResource(DirectServeHtmlResource):
|
||||||
isLeaf = 1
|
isLeaf = 1
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._oidc_handler = hs.get_oidc_handler()
|
self._oidc_handler = hs.get_oidc_handler()
|
||||||
|
|
||||||
@wrap_html_request_handler
|
|
||||||
async def _async_render_GET(self, request):
|
async def _async_render_GET(self, request):
|
||||||
return await self._oidc_handler.handle_oidc_callback(request)
|
await self._oidc_handler.handle_oidc_callback(request)
|
||||||
|
@ -16,10 +16,10 @@
|
|||||||
from twisted.python import failure
|
from twisted.python import failure
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.http.server import DirectServeResource, return_html_error
|
from synapse.http.server import DirectServeHtmlResource, return_html_error
|
||||||
|
|
||||||
|
|
||||||
class SAML2ResponseResource(DirectServeResource):
|
class SAML2ResponseResource(DirectServeHtmlResource):
|
||||||
"""A Twisted web resource which handles the SAML response"""
|
"""A Twisted web resource which handles the SAML response"""
|
||||||
|
|
||||||
isLeaf = 1
|
isLeaf = 1
|
||||||
|
62
tests/http/test_additional_resource.py
Normal file
62
tests/http/test_additional_resource.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from synapse.http.additional_resource import AdditionalResource
|
||||||
|
from synapse.http.server import respond_with_json
|
||||||
|
|
||||||
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class _AsyncTestCustomEndpoint:
|
||||||
|
def __init__(self, config, module_api):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def handle_request(self, request):
|
||||||
|
respond_with_json(request, 200, {"some_key": "some_value_async"})
|
||||||
|
|
||||||
|
|
||||||
|
class _SyncTestCustomEndpoint:
|
||||||
|
def __init__(self, config, module_api):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def handle_request(self, request):
|
||||||
|
respond_with_json(request, 200, {"some_key": "some_value_sync"})
|
||||||
|
|
||||||
|
|
||||||
|
class AdditionalResourceTests(HomeserverTestCase):
|
||||||
|
"""Very basic tests that `AdditionalResource` works correctly with sync
|
||||||
|
and async handlers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_async(self):
|
||||||
|
handler = _AsyncTestCustomEndpoint({}, None).handle_request
|
||||||
|
self.resource = AdditionalResource(self.hs, handler)
|
||||||
|
|
||||||
|
request, channel = self.make_request("GET", "/")
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
self.assertEqual(request.code, 200)
|
||||||
|
self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
|
||||||
|
|
||||||
|
def test_sync(self):
|
||||||
|
handler = _SyncTestCustomEndpoint({}, None).handle_request
|
||||||
|
self.resource = AdditionalResource(self.hs, handler)
|
||||||
|
|
||||||
|
request, channel = self.make_request("GET", "/")
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
self.assertEqual(request.code, 200)
|
||||||
|
self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})
|
@ -24,12 +24,7 @@ from twisted.web.server import NOT_DONE_YET
|
|||||||
|
|
||||||
from synapse.api.errors import Codes, RedirectException, SynapseError
|
from synapse.api.errors import Codes, RedirectException, SynapseError
|
||||||
from synapse.config.server import parse_listener_def
|
from synapse.config.server import parse_listener_def
|
||||||
from synapse.http.server import (
|
from synapse.http.server import DirectServeHtmlResource, JsonResource, OptionsResource
|
||||||
DirectServeResource,
|
|
||||||
JsonResource,
|
|
||||||
OptionsResource,
|
|
||||||
wrap_html_request_handler,
|
|
||||||
)
|
|
||||||
from synapse.http.site import SynapseSite, logger
|
from synapse.http.site import SynapseSite, logger
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
@ -256,12 +251,11 @@ class OptionsResourceTests(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
class WrapHtmlRequestHandlerTests(unittest.TestCase):
|
class WrapHtmlRequestHandlerTests(unittest.TestCase):
|
||||||
class TestResource(DirectServeResource):
|
class TestResource(DirectServeHtmlResource):
|
||||||
callback = None
|
callback = None
|
||||||
|
|
||||||
@wrap_html_request_handler
|
|
||||||
async def _async_render_GET(self, request):
|
async def _async_render_GET(self, request):
|
||||||
return await self.callback(request)
|
await self.callback(request)
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.reactor = ThreadedMemoryReactorClock()
|
self.reactor = ThreadedMemoryReactorClock()
|
||||||
|
Loading…
Reference in New Issue
Block a user