Make the http server handle coroutine-making REST servlets (#5475)

This commit is contained in:
Amber Brown 2019-06-29 17:06:55 +10:00 committed by GitHub
parent c7ff297dde
commit f40a7dc41f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 162 additions and 174 deletions

View file

@ -16,10 +16,11 @@
import cgi
import collections
import http.client
import logging
from six import PY3
from six.moves import http_client, urllib
import types
import urllib
from io import BytesIO
from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
@ -41,11 +42,6 @@ from synapse.api.errors import (
from synapse.util.caches import intern_dict
from synapse.util.logcontext import preserve_fn
if PY3:
from io import BytesIO
else:
from cStringIO import StringIO as BytesIO
logger = logging.getLogger(__name__)
HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
@ -75,10 +71,9 @@ def wrap_json_request_handler(h):
deferred fails with any other type of error we send a 500 reponse.
"""
@defer.inlineCallbacks
def wrapped_request_handler(self, request):
async def wrapped_request_handler(self, request):
try:
yield h(self, request)
await h(self, request)
except SynapseError as e:
code = e.code
logger.info("%s SynapseError: %s - %s", request, code, e.msg)
@ -142,10 +137,12 @@ def wrap_html_request_handler(h):
where "request" must be a SynapseRequest.
"""
def wrapped_request_handler(self, request):
d = defer.maybeDeferred(h, self, request)
d.addErrback(_return_html_error, request)
return d
async def wrapped_request_handler(self, request):
try:
return await h(self, request)
except Exception:
f = failure.Failure()
return _return_html_error(f, request)
return wrap_async_request_handler(wrapped_request_handler)
@ -171,7 +168,7 @@ def _return_html_error(f, request):
exc_info=(f.type, f.value, f.getTracebackObject()),
)
else:
code = http_client.INTERNAL_SERVER_ERROR
code = http.client.INTERNAL_SERVER_ERROR
msg = "Internal server error"
logger.error(
@ -201,10 +198,9 @@ def wrap_async_request_handler(h):
logged until the deferred completes.
"""
@defer.inlineCallbacks
def wrapped_async_request_handler(self, request):
async def wrapped_async_request_handler(self, request):
with request.processing():
yield h(self, request)
await h(self, request)
# we need to preserve_fn here, because the synchronous render method won't yield for
# us (obviously)
@ -270,12 +266,11 @@ class JsonResource(HttpServer, resource.Resource):
def render(self, request):
""" This gets called by twisted every time someone sends us a request.
"""
self._async_render(request)
defer.ensureDeferred(self._async_render(request))
return NOT_DONE_YET
@wrap_json_request_handler
@defer.inlineCallbacks
def _async_render(self, request):
async def _async_render(self, request):
""" This gets called from render() every time someone sends us a request.
This checks if anyone has registered a callback for that method and
path.
@ -292,26 +287,19 @@ class JsonResource(HttpServer, resource.Resource):
# Now trigger the callback. If it returns a response, we send it
# here. If it throws an exception, that is handled by the wrapper
# installed by @request_handler.
def _unquote(s):
if PY3:
# On Python 3, unquote is unicode -> unicode
return urllib.parse.unquote(s)
else:
# On Python 2, unquote is bytes -> bytes We need to encode the
# URL again (as it was decoded by _get_handler_for request), as
# ASCII because it's a URL, and then decode it to get the UTF-8
# characters that were quoted.
return urllib.parse.unquote(s.encode("ascii")).decode("utf8")
kwargs = intern_dict(
{
name: _unquote(value) if value else value
name: urllib.parse.unquote(value) if value else value
for name, value in group_dict.items()
}
)
callback_return = yield callback(request, **kwargs)
callback_return = callback(request, **kwargs)
# Is it synchronous? We'll allow this for now.
if isinstance(callback_return, (defer.Deferred, types.CoroutineType)):
callback_return = await callback_return
if callback_return is not None:
code, response = callback_return
self._send_response(request, code, response)
@ -360,6 +348,23 @@ class JsonResource(HttpServer, resource.Resource):
)
class DirectServeResource(resource.Resource):
def render(self, request):
"""
Render the request, using an asynchronous render handler if it exists.
"""
render_callback_name = "_async_render_" + request.method.decode("ascii")
if hasattr(self, render_callback_name):
# Call the handler
callback = getattr(self, render_callback_name)
defer.ensureDeferred(callback(request))
return NOT_DONE_YET
else:
super().render(request)
def _options_handler(request):
"""Request handler for OPTIONS requests