mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-01-02 10:56:15 -05:00
Make the http server handle coroutine-making REST servlets (#5475)
This commit is contained in:
parent
c7ff297dde
commit
f40a7dc41f
1
changelog.d/5475.misc
Normal file
1
changelog.d/5475.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Synapse can now handle RestServlets that return coroutines.
|
@ -16,10 +16,11 @@
|
|||||||
|
|
||||||
import cgi
|
import cgi
|
||||||
import collections
|
import collections
|
||||||
|
import http.client
|
||||||
import logging
|
import logging
|
||||||
|
import types
|
||||||
from six import PY3
|
import urllib
|
||||||
from six.moves import http_client, urllib
|
from io import BytesIO
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
|
from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
|
||||||
|
|
||||||
@ -41,11 +42,6 @@ from synapse.api.errors import (
|
|||||||
from synapse.util.caches import intern_dict
|
from synapse.util.caches import intern_dict
|
||||||
from synapse.util.logcontext import preserve_fn
|
from synapse.util.logcontext import preserve_fn
|
||||||
|
|
||||||
if PY3:
|
|
||||||
from io import BytesIO
|
|
||||||
else:
|
|
||||||
from cStringIO import StringIO as BytesIO
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
|
HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
|
||||||
@ -75,10 +71,9 @@ def wrap_json_request_handler(h):
|
|||||||
deferred fails with any other type of error we send a 500 reponse.
|
deferred fails with any other type of error we send a 500 reponse.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def wrapped_request_handler(self, request):
|
||||||
def wrapped_request_handler(self, request):
|
|
||||||
try:
|
try:
|
||||||
yield h(self, request)
|
await h(self, request)
|
||||||
except SynapseError as e:
|
except SynapseError as e:
|
||||||
code = e.code
|
code = e.code
|
||||||
logger.info("%s SynapseError: %s - %s", request, code, e.msg)
|
logger.info("%s SynapseError: %s - %s", request, code, e.msg)
|
||||||
@ -142,10 +137,12 @@ def wrap_html_request_handler(h):
|
|||||||
where "request" must be a SynapseRequest.
|
where "request" must be a SynapseRequest.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def wrapped_request_handler(self, request):
|
async def wrapped_request_handler(self, request):
|
||||||
d = defer.maybeDeferred(h, self, request)
|
try:
|
||||||
d.addErrback(_return_html_error, request)
|
return await h(self, request)
|
||||||
return d
|
except Exception:
|
||||||
|
f = failure.Failure()
|
||||||
|
return _return_html_error(f, request)
|
||||||
|
|
||||||
return wrap_async_request_handler(wrapped_request_handler)
|
return wrap_async_request_handler(wrapped_request_handler)
|
||||||
|
|
||||||
@ -171,7 +168,7 @@ def _return_html_error(f, request):
|
|||||||
exc_info=(f.type, f.value, f.getTracebackObject()),
|
exc_info=(f.type, f.value, f.getTracebackObject()),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
code = http_client.INTERNAL_SERVER_ERROR
|
code = http.client.INTERNAL_SERVER_ERROR
|
||||||
msg = "Internal server error"
|
msg = "Internal server error"
|
||||||
|
|
||||||
logger.error(
|
logger.error(
|
||||||
@ -201,10 +198,9 @@ def wrap_async_request_handler(h):
|
|||||||
logged until the deferred completes.
|
logged until the deferred completes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def wrapped_async_request_handler(self, request):
|
||||||
def wrapped_async_request_handler(self, request):
|
|
||||||
with request.processing():
|
with request.processing():
|
||||||
yield h(self, request)
|
await h(self, request)
|
||||||
|
|
||||||
# we need to preserve_fn here, because the synchronous render method won't yield for
|
# we need to preserve_fn here, because the synchronous render method won't yield for
|
||||||
# us (obviously)
|
# us (obviously)
|
||||||
@ -270,12 +266,11 @@ class JsonResource(HttpServer, resource.Resource):
|
|||||||
def render(self, request):
|
def render(self, request):
|
||||||
""" This gets called by twisted every time someone sends us a request.
|
""" This gets called by twisted every time someone sends us a request.
|
||||||
"""
|
"""
|
||||||
self._async_render(request)
|
defer.ensureDeferred(self._async_render(request))
|
||||||
return NOT_DONE_YET
|
return NOT_DONE_YET
|
||||||
|
|
||||||
@wrap_json_request_handler
|
@wrap_json_request_handler
|
||||||
@defer.inlineCallbacks
|
async def _async_render(self, request):
|
||||||
def _async_render(self, request):
|
|
||||||
""" This gets called from render() every time someone sends us a request.
|
""" This gets called from render() every time someone sends us a request.
|
||||||
This checks if anyone has registered a callback for that method and
|
This checks if anyone has registered a callback for that method and
|
||||||
path.
|
path.
|
||||||
@ -292,26 +287,19 @@ class JsonResource(HttpServer, resource.Resource):
|
|||||||
# Now trigger the callback. If it returns a response, we send it
|
# Now trigger the callback. If it returns a response, we send it
|
||||||
# here. If it throws an exception, that is handled by the wrapper
|
# here. If it throws an exception, that is handled by the wrapper
|
||||||
# installed by @request_handler.
|
# installed by @request_handler.
|
||||||
|
|
||||||
def _unquote(s):
|
|
||||||
if PY3:
|
|
||||||
# On Python 3, unquote is unicode -> unicode
|
|
||||||
return urllib.parse.unquote(s)
|
|
||||||
else:
|
|
||||||
# On Python 2, unquote is bytes -> bytes We need to encode the
|
|
||||||
# URL again (as it was decoded by _get_handler_for request), as
|
|
||||||
# ASCII because it's a URL, and then decode it to get the UTF-8
|
|
||||||
# characters that were quoted.
|
|
||||||
return urllib.parse.unquote(s.encode("ascii")).decode("utf8")
|
|
||||||
|
|
||||||
kwargs = intern_dict(
|
kwargs = intern_dict(
|
||||||
{
|
{
|
||||||
name: _unquote(value) if value else value
|
name: urllib.parse.unquote(value) if value else value
|
||||||
for name, value in group_dict.items()
|
for name, value in group_dict.items()
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
callback_return = yield callback(request, **kwargs)
|
callback_return = callback(request, **kwargs)
|
||||||
|
|
||||||
|
# Is it synchronous? We'll allow this for now.
|
||||||
|
if isinstance(callback_return, (defer.Deferred, types.CoroutineType)):
|
||||||
|
callback_return = await callback_return
|
||||||
|
|
||||||
if callback_return is not None:
|
if callback_return is not None:
|
||||||
code, response = callback_return
|
code, response = callback_return
|
||||||
self._send_response(request, code, response)
|
self._send_response(request, code, response)
|
||||||
@ -360,6 +348,23 @@ class JsonResource(HttpServer, resource.Resource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DirectServeResource(resource.Resource):
|
||||||
|
def render(self, request):
|
||||||
|
"""
|
||||||
|
Render the request, using an asynchronous render handler if it exists.
|
||||||
|
"""
|
||||||
|
render_callback_name = "_async_render_" + request.method.decode("ascii")
|
||||||
|
|
||||||
|
if hasattr(self, render_callback_name):
|
||||||
|
# Call the handler
|
||||||
|
callback = getattr(self, render_callback_name)
|
||||||
|
defer.ensureDeferred(callback(request))
|
||||||
|
|
||||||
|
return NOT_DONE_YET
|
||||||
|
else:
|
||||||
|
super().render(request)
|
||||||
|
|
||||||
|
|
||||||
def _options_handler(request):
|
def _options_handler(request):
|
||||||
"""Request handler for OPTIONS requests
|
"""Request handler for OPTIONS requests
|
||||||
|
|
||||||
|
@ -23,13 +23,13 @@ from six.moves import http_client
|
|||||||
import jinja2
|
import jinja2
|
||||||
from jinja2 import TemplateNotFound
|
from jinja2 import TemplateNotFound
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from twisted.web.resource import Resource
|
|
||||||
from twisted.web.server import NOT_DONE_YET
|
|
||||||
|
|
||||||
from synapse.api.errors import NotFoundError, StoreError, SynapseError
|
from synapse.api.errors import NotFoundError, StoreError, SynapseError
|
||||||
from synapse.config import ConfigError
|
from synapse.config import ConfigError
|
||||||
from synapse.http.server import finish_request, wrap_html_request_handler
|
from synapse.http.server import (
|
||||||
|
DirectServeResource,
|
||||||
|
finish_request,
|
||||||
|
wrap_html_request_handler,
|
||||||
|
)
|
||||||
from synapse.http.servlet import parse_string
|
from synapse.http.servlet import parse_string
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
@ -47,7 +47,7 @@ else:
|
|||||||
return a == b
|
return a == b
|
||||||
|
|
||||||
|
|
||||||
class ConsentResource(Resource):
|
class ConsentResource(DirectServeResource):
|
||||||
"""A twisted Resource to display a privacy policy and gather consent to it
|
"""A twisted Resource to display a privacy policy and gather consent to it
|
||||||
|
|
||||||
When accessed via GET, returns the privacy policy via a template.
|
When accessed via GET, returns the privacy policy via a template.
|
||||||
@ -87,7 +87,7 @@ class ConsentResource(Resource):
|
|||||||
Args:
|
Args:
|
||||||
hs (synapse.server.HomeServer): homeserver
|
hs (synapse.server.HomeServer): homeserver
|
||||||
"""
|
"""
|
||||||
Resource.__init__(self)
|
super().__init__()
|
||||||
|
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
@ -118,18 +118,12 @@ class ConsentResource(Resource):
|
|||||||
|
|
||||||
self._hmac_secret = hs.config.form_secret.encode("utf-8")
|
self._hmac_secret = hs.config.form_secret.encode("utf-8")
|
||||||
|
|
||||||
def render_GET(self, request):
|
|
||||||
self._async_render_GET(request)
|
|
||||||
return NOT_DONE_YET
|
|
||||||
|
|
||||||
@wrap_html_request_handler
|
@wrap_html_request_handler
|
||||||
@defer.inlineCallbacks
|
async def _async_render_GET(self, request):
|
||||||
def _async_render_GET(self, request):
|
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
request (twisted.web.http.Request):
|
request (twisted.web.http.Request):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
version = parse_string(request, "v", default=self._default_consent_version)
|
version = parse_string(request, "v", default=self._default_consent_version)
|
||||||
username = parse_string(request, "u", required=False, default="")
|
username = parse_string(request, "u", required=False, default="")
|
||||||
userhmac = None
|
userhmac = None
|
||||||
@ -145,7 +139,7 @@ class ConsentResource(Resource):
|
|||||||
else:
|
else:
|
||||||
qualified_user_id = UserID(username, self.hs.hostname).to_string()
|
qualified_user_id = UserID(username, self.hs.hostname).to_string()
|
||||||
|
|
||||||
u = yield self.store.get_user_by_id(qualified_user_id)
|
u = await self.store.get_user_by_id(qualified_user_id)
|
||||||
if u is None:
|
if u is None:
|
||||||
raise NotFoundError("Unknown user")
|
raise NotFoundError("Unknown user")
|
||||||
|
|
||||||
@ -165,13 +159,8 @@ class ConsentResource(Resource):
|
|||||||
except TemplateNotFound:
|
except TemplateNotFound:
|
||||||
raise NotFoundError("Unknown policy version")
|
raise NotFoundError("Unknown policy version")
|
||||||
|
|
||||||
def render_POST(self, request):
|
|
||||||
self._async_render_POST(request)
|
|
||||||
return NOT_DONE_YET
|
|
||||||
|
|
||||||
@wrap_html_request_handler
|
@wrap_html_request_handler
|
||||||
@defer.inlineCallbacks
|
async def _async_render_POST(self, request):
|
||||||
def _async_render_POST(self, request):
|
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
request (twisted.web.http.Request):
|
request (twisted.web.http.Request):
|
||||||
@ -188,12 +177,12 @@ class ConsentResource(Resource):
|
|||||||
qualified_user_id = UserID(username, self.hs.hostname).to_string()
|
qualified_user_id = UserID(username, self.hs.hostname).to_string()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self.store.user_set_consent_version(qualified_user_id, version)
|
await self.store.user_set_consent_version(qualified_user_id, version)
|
||||||
except StoreError as e:
|
except StoreError as e:
|
||||||
if e.code != 404:
|
if e.code != 404:
|
||||||
raise
|
raise
|
||||||
raise NotFoundError("Unknown user")
|
raise NotFoundError("Unknown user")
|
||||||
yield self.registration_handler.post_consent_actions(qualified_user_id)
|
await self.registration_handler.post_consent_actions(qualified_user_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._render_template(request, "success.html")
|
self._render_template(request, "success.html")
|
||||||
|
@ -16,18 +16,20 @@ import logging
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.web.resource import Resource
|
|
||||||
from twisted.web.server import NOT_DONE_YET
|
|
||||||
|
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.crypto.keyring import ServerKeyFetcher
|
from synapse.crypto.keyring import ServerKeyFetcher
|
||||||
from synapse.http.server import respond_with_json_bytes, wrap_json_request_handler
|
from synapse.http.server import (
|
||||||
|
DirectServeResource,
|
||||||
|
respond_with_json_bytes,
|
||||||
|
wrap_json_request_handler,
|
||||||
|
)
|
||||||
from synapse.http.servlet import parse_integer, parse_json_object_from_request
|
from synapse.http.servlet import parse_integer, parse_json_object_from_request
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RemoteKey(Resource):
|
class RemoteKey(DirectServeResource):
|
||||||
"""HTTP resource for retreiving the TLS certificate and NACL signature
|
"""HTTP resource for retreiving the TLS certificate and NACL signature
|
||||||
verification keys for a collection of servers. Checks that the reported
|
verification keys for a collection of servers. Checks that the reported
|
||||||
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
|
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
|
||||||
@ -94,13 +96,8 @@ class RemoteKey(Resource):
|
|||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
|
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
|
||||||
|
|
||||||
def render_GET(self, request):
|
|
||||||
self.async_render_GET(request)
|
|
||||||
return NOT_DONE_YET
|
|
||||||
|
|
||||||
@wrap_json_request_handler
|
@wrap_json_request_handler
|
||||||
@defer.inlineCallbacks
|
async def _async_render_GET(self, request):
|
||||||
def async_render_GET(self, request):
|
|
||||||
if len(request.postpath) == 1:
|
if len(request.postpath) == 1:
|
||||||
server, = request.postpath
|
server, = request.postpath
|
||||||
query = {server.decode("ascii"): {}}
|
query = {server.decode("ascii"): {}}
|
||||||
@ -114,20 +111,15 @@ class RemoteKey(Resource):
|
|||||||
else:
|
else:
|
||||||
raise SynapseError(404, "Not found %r" % request.postpath, Codes.NOT_FOUND)
|
raise SynapseError(404, "Not found %r" % request.postpath, Codes.NOT_FOUND)
|
||||||
|
|
||||||
yield self.query_keys(request, query, query_remote_on_cache_miss=True)
|
await self.query_keys(request, query, query_remote_on_cache_miss=True)
|
||||||
|
|
||||||
def render_POST(self, request):
|
|
||||||
self.async_render_POST(request)
|
|
||||||
return NOT_DONE_YET
|
|
||||||
|
|
||||||
@wrap_json_request_handler
|
@wrap_json_request_handler
|
||||||
@defer.inlineCallbacks
|
async def _async_render_POST(self, request):
|
||||||
def async_render_POST(self, request):
|
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
query = content["server_keys"]
|
query = content["server_keys"]
|
||||||
|
|
||||||
yield self.query_keys(request, query, query_remote_on_cache_miss=True)
|
await self.query_keys(request, query, query_remote_on_cache_miss=True)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def query_keys(self, request, query, query_remote_on_cache_miss=False):
|
def query_keys(self, request, query, query_remote_on_cache_miss=False):
|
||||||
|
@ -14,31 +14,28 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from twisted.web.resource import Resource
|
|
||||||
from twisted.web.server import NOT_DONE_YET
|
from twisted.web.server import NOT_DONE_YET
|
||||||
|
|
||||||
from synapse.http.server import respond_with_json, wrap_json_request_handler
|
from synapse.http.server import (
|
||||||
|
DirectServeResource,
|
||||||
|
respond_with_json,
|
||||||
|
wrap_json_request_handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MediaConfigResource(Resource):
|
class MediaConfigResource(DirectServeResource):
|
||||||
isLeaf = True
|
isLeaf = True
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
Resource.__init__(self)
|
super().__init__()
|
||||||
config = hs.get_config()
|
config = hs.get_config()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.limits_dict = {"m.upload.size": config.max_upload_size}
|
self.limits_dict = {"m.upload.size": config.max_upload_size}
|
||||||
|
|
||||||
def render_GET(self, request):
|
|
||||||
self._async_render_GET(request)
|
|
||||||
return NOT_DONE_YET
|
|
||||||
|
|
||||||
@wrap_json_request_handler
|
@wrap_json_request_handler
|
||||||
@defer.inlineCallbacks
|
async def _async_render_GET(self, request):
|
||||||
def _async_render_GET(self, request):
|
await self.auth.get_user_by_req(request)
|
||||||
yield self.auth.get_user_by_req(request)
|
|
||||||
respond_with_json(request, 200, self.limits_dict, send_cors=True)
|
respond_with_json(request, 200, self.limits_dict, send_cors=True)
|
||||||
|
|
||||||
def render_OPTIONS(self, request):
|
def render_OPTIONS(self, request):
|
||||||
|
@ -14,37 +14,31 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from twisted.web.resource import Resource
|
|
||||||
from twisted.web.server import NOT_DONE_YET
|
|
||||||
|
|
||||||
import synapse.http.servlet
|
import synapse.http.servlet
|
||||||
from synapse.http.server import set_cors_headers, wrap_json_request_handler
|
from synapse.http.server import (
|
||||||
|
DirectServeResource,
|
||||||
|
set_cors_headers,
|
||||||
|
wrap_json_request_handler,
|
||||||
|
)
|
||||||
|
|
||||||
from ._base import parse_media_id, respond_404
|
from ._base import parse_media_id, respond_404
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DownloadResource(Resource):
|
class DownloadResource(DirectServeResource):
|
||||||
isLeaf = True
|
isLeaf = True
|
||||||
|
|
||||||
def __init__(self, hs, media_repo):
|
def __init__(self, hs, media_repo):
|
||||||
Resource.__init__(self)
|
super().__init__()
|
||||||
|
|
||||||
self.media_repo = media_repo
|
self.media_repo = media_repo
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
|
|
||||||
# this is expected by @wrap_json_request_handler
|
# this is expected by @wrap_json_request_handler
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
def render_GET(self, request):
|
|
||||||
self._async_render_GET(request)
|
|
||||||
return NOT_DONE_YET
|
|
||||||
|
|
||||||
@wrap_json_request_handler
|
@wrap_json_request_handler
|
||||||
@defer.inlineCallbacks
|
async def _async_render_GET(self, request):
|
||||||
def _async_render_GET(self, request):
|
|
||||||
set_cors_headers(request)
|
set_cors_headers(request)
|
||||||
request.setHeader(
|
request.setHeader(
|
||||||
b"Content-Security-Policy",
|
b"Content-Security-Policy",
|
||||||
@ -58,7 +52,7 @@ class DownloadResource(Resource):
|
|||||||
)
|
)
|
||||||
server_name, media_id, name = parse_media_id(request)
|
server_name, media_id, name = parse_media_id(request)
|
||||||
if server_name == self.server_name:
|
if server_name == self.server_name:
|
||||||
yield self.media_repo.get_local_media(request, media_id, name)
|
await self.media_repo.get_local_media(request, media_id, name)
|
||||||
else:
|
else:
|
||||||
allow_remote = synapse.http.servlet.parse_boolean(
|
allow_remote = synapse.http.servlet.parse_boolean(
|
||||||
request, "allow_remote", default=True
|
request, "allow_remote", default=True
|
||||||
@ -72,4 +66,4 @@ class DownloadResource(Resource):
|
|||||||
respond_404(request)
|
respond_404(request)
|
||||||
return
|
return
|
||||||
|
|
||||||
yield self.media_repo.get_remote_media(request, server_name, media_id, name)
|
await self.media_repo.get_remote_media(request, server_name, media_id, name)
|
||||||
|
@ -32,12 +32,11 @@ from canonicaljson import json
|
|||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.internet.error import DNSLookupError
|
from twisted.internet.error import DNSLookupError
|
||||||
from twisted.web.resource import Resource
|
|
||||||
from twisted.web.server import NOT_DONE_YET
|
|
||||||
|
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.http.client import SimpleHttpClient
|
from synapse.http.client import SimpleHttpClient
|
||||||
from synapse.http.server import (
|
from synapse.http.server import (
|
||||||
|
DirectServeResource,
|
||||||
respond_with_json,
|
respond_with_json,
|
||||||
respond_with_json_bytes,
|
respond_with_json_bytes,
|
||||||
wrap_json_request_handler,
|
wrap_json_request_handler,
|
||||||
@ -58,11 +57,11 @@ _charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=r
|
|||||||
_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
|
_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
|
||||||
|
|
||||||
|
|
||||||
class PreviewUrlResource(Resource):
|
class PreviewUrlResource(DirectServeResource):
|
||||||
isLeaf = True
|
isLeaf = True
|
||||||
|
|
||||||
def __init__(self, hs, media_repo, media_storage):
|
def __init__(self, hs, media_repo, media_storage):
|
||||||
Resource.__init__(self)
|
super().__init__()
|
||||||
|
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
@ -98,16 +97,11 @@ class PreviewUrlResource(Resource):
|
|||||||
def render_OPTIONS(self, request):
|
def render_OPTIONS(self, request):
|
||||||
return respond_with_json(request, 200, {}, send_cors=True)
|
return respond_with_json(request, 200, {}, send_cors=True)
|
||||||
|
|
||||||
def render_GET(self, request):
|
|
||||||
self._async_render_GET(request)
|
|
||||||
return NOT_DONE_YET
|
|
||||||
|
|
||||||
@wrap_json_request_handler
|
@wrap_json_request_handler
|
||||||
@defer.inlineCallbacks
|
async def _async_render_GET(self, request):
|
||||||
def _async_render_GET(self, request):
|
|
||||||
|
|
||||||
# XXX: if get_user_by_req fails, what should we do in an async render?
|
# XXX: if get_user_by_req fails, what should we do in an async render?
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
url = parse_string(request, "url")
|
url = parse_string(request, "url")
|
||||||
if b"ts" in request.args:
|
if b"ts" in request.args:
|
||||||
ts = parse_integer(request, "ts")
|
ts = parse_integer(request, "ts")
|
||||||
@ -159,7 +153,7 @@ class PreviewUrlResource(Resource):
|
|||||||
else:
|
else:
|
||||||
logger.info("Returning cached response")
|
logger.info("Returning cached response")
|
||||||
|
|
||||||
og = yield make_deferred_yieldable(observable.observe())
|
og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe))
|
||||||
respond_with_json_bytes(request, 200, og, send_cors=True)
|
respond_with_json_bytes(request, 200, og, send_cors=True)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -17,10 +17,12 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.web.resource import Resource
|
|
||||||
from twisted.web.server import NOT_DONE_YET
|
|
||||||
|
|
||||||
from synapse.http.server import set_cors_headers, wrap_json_request_handler
|
from synapse.http.server import (
|
||||||
|
DirectServeResource,
|
||||||
|
set_cors_headers,
|
||||||
|
wrap_json_request_handler,
|
||||||
|
)
|
||||||
from synapse.http.servlet import parse_integer, parse_string
|
from synapse.http.servlet import parse_integer, parse_string
|
||||||
|
|
||||||
from ._base import (
|
from ._base import (
|
||||||
@ -34,11 +36,11 @@ from ._base import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ThumbnailResource(Resource):
|
class ThumbnailResource(DirectServeResource):
|
||||||
isLeaf = True
|
isLeaf = True
|
||||||
|
|
||||||
def __init__(self, hs, media_repo, media_storage):
|
def __init__(self, hs, media_repo, media_storage):
|
||||||
Resource.__init__(self)
|
super().__init__()
|
||||||
|
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.media_repo = media_repo
|
self.media_repo = media_repo
|
||||||
@ -47,13 +49,8 @@ class ThumbnailResource(Resource):
|
|||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
def render_GET(self, request):
|
|
||||||
self._async_render_GET(request)
|
|
||||||
return NOT_DONE_YET
|
|
||||||
|
|
||||||
@wrap_json_request_handler
|
@wrap_json_request_handler
|
||||||
@defer.inlineCallbacks
|
async def _async_render_GET(self, request):
|
||||||
def _async_render_GET(self, request):
|
|
||||||
set_cors_headers(request)
|
set_cors_headers(request)
|
||||||
server_name, media_id, _ = parse_media_id(request)
|
server_name, media_id, _ = parse_media_id(request)
|
||||||
width = parse_integer(request, "width", required=True)
|
width = parse_integer(request, "width", required=True)
|
||||||
@ -63,21 +60,21 @@ class ThumbnailResource(Resource):
|
|||||||
|
|
||||||
if server_name == self.server_name:
|
if server_name == self.server_name:
|
||||||
if self.dynamic_thumbnails:
|
if self.dynamic_thumbnails:
|
||||||
yield self._select_or_generate_local_thumbnail(
|
await self._select_or_generate_local_thumbnail(
|
||||||
request, media_id, width, height, method, m_type
|
request, media_id, width, height, method, m_type
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
yield self._respond_local_thumbnail(
|
await self._respond_local_thumbnail(
|
||||||
request, media_id, width, height, method, m_type
|
request, media_id, width, height, method, m_type
|
||||||
)
|
)
|
||||||
self.media_repo.mark_recently_accessed(None, media_id)
|
self.media_repo.mark_recently_accessed(None, media_id)
|
||||||
else:
|
else:
|
||||||
if self.dynamic_thumbnails:
|
if self.dynamic_thumbnails:
|
||||||
yield self._select_or_generate_remote_thumbnail(
|
await self._select_or_generate_remote_thumbnail(
|
||||||
request, server_name, media_id, width, height, method, m_type
|
request, server_name, media_id, width, height, method, m_type
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
yield self._respond_remote_thumbnail(
|
await self._respond_remote_thumbnail(
|
||||||
request, server_name, media_id, width, height, method, m_type
|
request, server_name, media_id, width, height, method, m_type
|
||||||
)
|
)
|
||||||
self.media_repo.mark_recently_accessed(server_name, media_id)
|
self.media_repo.mark_recently_accessed(server_name, media_id)
|
||||||
|
@ -15,22 +15,24 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from twisted.web.resource import Resource
|
|
||||||
from twisted.web.server import NOT_DONE_YET
|
from twisted.web.server import NOT_DONE_YET
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.http.server import respond_with_json, wrap_json_request_handler
|
from synapse.http.server import (
|
||||||
|
DirectServeResource,
|
||||||
|
respond_with_json,
|
||||||
|
wrap_json_request_handler,
|
||||||
|
)
|
||||||
from synapse.http.servlet import parse_string
|
from synapse.http.servlet import parse_string
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UploadResource(Resource):
|
class UploadResource(DirectServeResource):
|
||||||
isLeaf = True
|
isLeaf = True
|
||||||
|
|
||||||
def __init__(self, hs, media_repo):
|
def __init__(self, hs, media_repo):
|
||||||
Resource.__init__(self)
|
super().__init__()
|
||||||
|
|
||||||
self.media_repo = media_repo
|
self.media_repo = media_repo
|
||||||
self.filepaths = media_repo.filepaths
|
self.filepaths = media_repo.filepaths
|
||||||
@ -41,18 +43,13 @@ class UploadResource(Resource):
|
|||||||
self.max_upload_size = hs.config.max_upload_size
|
self.max_upload_size = hs.config.max_upload_size
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
def render_POST(self, request):
|
|
||||||
self._async_render_POST(request)
|
|
||||||
return NOT_DONE_YET
|
|
||||||
|
|
||||||
def render_OPTIONS(self, request):
|
def render_OPTIONS(self, request):
|
||||||
respond_with_json(request, 200, {}, send_cors=True)
|
respond_with_json(request, 200, {}, send_cors=True)
|
||||||
return NOT_DONE_YET
|
return NOT_DONE_YET
|
||||||
|
|
||||||
@wrap_json_request_handler
|
@wrap_json_request_handler
|
||||||
@defer.inlineCallbacks
|
async def _async_render_POST(self, request):
|
||||||
def _async_render_POST(self, request):
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
|
||||||
# TODO: The checks here are a bit late. The content will have
|
# TODO: The checks here are a bit late. The content will have
|
||||||
# already been uploaded to a tmp file at this point
|
# already been uploaded to a tmp file at this point
|
||||||
content_length = request.getHeader(b"Content-Length").decode("ascii")
|
content_length = request.getHeader(b"Content-Length").decode("ascii")
|
||||||
@ -81,7 +78,7 @@ class UploadResource(Resource):
|
|||||||
# disposition = headers.getRawHeaders(b"Content-Disposition")[0]
|
# disposition = headers.getRawHeaders(b"Content-Disposition")[0]
|
||||||
# TODO(markjh): parse content-dispostion
|
# TODO(markjh): parse content-dispostion
|
||||||
|
|
||||||
content_uri = yield self.media_repo.create_content(
|
content_uri = await self.media_repo.create_content(
|
||||||
media_type, upload_name, request.content, content_length, requester.user
|
media_type, upload_name, request.content, content_length, requester.user
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -18,34 +18,27 @@ import logging
|
|||||||
import saml2
|
import saml2
|
||||||
from saml2.client import Saml2Client
|
from saml2.client import Saml2Client
|
||||||
|
|
||||||
from twisted.web.resource import Resource
|
|
||||||
from twisted.web.server import NOT_DONE_YET
|
|
||||||
|
|
||||||
from synapse.api.errors import CodeMessageException
|
from synapse.api.errors import CodeMessageException
|
||||||
from synapse.http.server import wrap_html_request_handler
|
from synapse.http.server import DirectServeResource, wrap_html_request_handler
|
||||||
from synapse.http.servlet import parse_string
|
from synapse.http.servlet import parse_string
|
||||||
from synapse.rest.client.v1.login import SSOAuthHandler
|
from synapse.rest.client.v1.login import SSOAuthHandler
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SAML2ResponseResource(Resource):
|
class SAML2ResponseResource(DirectServeResource):
|
||||||
"""A Twisted web resource which handles the SAML response"""
|
"""A Twisted web resource which handles the SAML response"""
|
||||||
|
|
||||||
isLeaf = 1
|
isLeaf = 1
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
Resource.__init__(self)
|
super().__init__()
|
||||||
|
|
||||||
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
|
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
|
||||||
self._sso_auth_handler = SSOAuthHandler(hs)
|
self._sso_auth_handler = SSOAuthHandler(hs)
|
||||||
|
|
||||||
def render_POST(self, request):
|
|
||||||
self._async_render_POST(request)
|
|
||||||
return NOT_DONE_YET
|
|
||||||
|
|
||||||
@wrap_html_request_handler
|
@wrap_html_request_handler
|
||||||
def _async_render_POST(self, request):
|
async def _async_render_POST(self, request):
|
||||||
resp_bytes = parse_string(request, "SAMLResponse", required=True)
|
resp_bytes = parse_string(request, "SAMLResponse", required=True)
|
||||||
relay_state = parse_string(request, "RelayState", required=True)
|
relay_state = parse_string(request, "RelayState", required=True)
|
||||||
|
|
||||||
|
@ -22,7 +22,6 @@ from binascii import unhexlify
|
|||||||
from mock import Mock
|
from mock import Mock
|
||||||
from six.moves.urllib import parse
|
from six.moves.urllib import parse
|
||||||
|
|
||||||
from twisted.internet import defer, reactor
|
|
||||||
from twisted.internet.defer import Deferred
|
from twisted.internet.defer import Deferred
|
||||||
|
|
||||||
from synapse.rest.media.v1._base import FileInfo
|
from synapse.rest.media.v1._base import FileInfo
|
||||||
@ -34,15 +33,17 @@ from synapse.util.logcontext import make_deferred_yieldable
|
|||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
class MediaStorageTests(unittest.TestCase):
|
class MediaStorageTests(unittest.HomeserverTestCase):
|
||||||
def setUp(self):
|
|
||||||
|
needs_threadpool = True
|
||||||
|
|
||||||
|
def prepare(self, reactor, clock, hs):
|
||||||
self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
|
self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
|
||||||
|
self.addCleanup(shutil.rmtree, self.test_dir)
|
||||||
|
|
||||||
self.primary_base_path = os.path.join(self.test_dir, "primary")
|
self.primary_base_path = os.path.join(self.test_dir, "primary")
|
||||||
self.secondary_base_path = os.path.join(self.test_dir, "secondary")
|
self.secondary_base_path = os.path.join(self.test_dir, "secondary")
|
||||||
|
|
||||||
hs = Mock()
|
|
||||||
hs.get_reactor = Mock(return_value=reactor)
|
|
||||||
hs.config.media_store_path = self.primary_base_path
|
hs.config.media_store_path = self.primary_base_path
|
||||||
|
|
||||||
storage_providers = [FileStorageProviderBackend(hs, self.secondary_base_path)]
|
storage_providers = [FileStorageProviderBackend(hs, self.secondary_base_path)]
|
||||||
@ -52,10 +53,6 @@ class MediaStorageTests(unittest.TestCase):
|
|||||||
hs, self.primary_base_path, self.filepaths, storage_providers
|
hs, self.primary_base_path, self.filepaths, storage_providers
|
||||||
)
|
)
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
shutil.rmtree(self.test_dir)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_ensure_media_is_in_local_cache(self):
|
def test_ensure_media_is_in_local_cache(self):
|
||||||
media_id = "some_media_id"
|
media_id = "some_media_id"
|
||||||
test_body = "Test\n"
|
test_body = "Test\n"
|
||||||
@ -73,7 +70,15 @@ class MediaStorageTests(unittest.TestCase):
|
|||||||
# Now we run ensure_media_is_in_local_cache, which should copy the file
|
# Now we run ensure_media_is_in_local_cache, which should copy the file
|
||||||
# to the local cache.
|
# to the local cache.
|
||||||
file_info = FileInfo(None, media_id)
|
file_info = FileInfo(None, media_id)
|
||||||
local_path = yield self.media_storage.ensure_media_is_in_local_cache(file_info)
|
|
||||||
|
# This uses a real blocking threadpool so we have to wait for it to be
|
||||||
|
# actually done :/
|
||||||
|
x = self.media_storage.ensure_media_is_in_local_cache(file_info)
|
||||||
|
|
||||||
|
# Hotloop until the threadpool does its job...
|
||||||
|
self.wait_on_thread(x)
|
||||||
|
|
||||||
|
local_path = self.get_success(x)
|
||||||
|
|
||||||
self.assertTrue(os.path.exists(local_path))
|
self.assertTrue(os.path.exists(local_path))
|
||||||
|
|
||||||
|
@ -17,6 +17,7 @@ import gc
|
|||||||
import hashlib
|
import hashlib
|
||||||
import hmac
|
import hmac
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
|
|
||||||
@ -24,7 +25,8 @@ from canonicaljson import json
|
|||||||
|
|
||||||
import twisted
|
import twisted
|
||||||
import twisted.logger
|
import twisted.logger
|
||||||
from twisted.internet.defer import Deferred
|
from twisted.internet.defer import Deferred, succeed
|
||||||
|
from twisted.python.threadpool import ThreadPool
|
||||||
from twisted.trial import unittest
|
from twisted.trial import unittest
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
@ -164,6 +166,7 @@ class HomeserverTestCase(TestCase):
|
|||||||
|
|
||||||
servlets = []
|
servlets = []
|
||||||
hijack_auth = True
|
hijack_auth = True
|
||||||
|
needs_threadpool = False
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
"""
|
"""
|
||||||
@ -192,15 +195,19 @@ class HomeserverTestCase(TestCase):
|
|||||||
if self.hijack_auth:
|
if self.hijack_auth:
|
||||||
|
|
||||||
def get_user_by_access_token(token=None, allow_guest=False):
|
def get_user_by_access_token(token=None, allow_guest=False):
|
||||||
return {
|
return succeed(
|
||||||
"user": UserID.from_string(self.helper.auth_user_id),
|
{
|
||||||
"token_id": 1,
|
"user": UserID.from_string(self.helper.auth_user_id),
|
||||||
"is_guest": False,
|
"token_id": 1,
|
||||||
}
|
"is_guest": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
def get_user_by_req(request, allow_guest=False, rights="access"):
|
def get_user_by_req(request, allow_guest=False, rights="access"):
|
||||||
return create_requester(
|
return succeed(
|
||||||
UserID.from_string(self.helper.auth_user_id), 1, False, None
|
create_requester(
|
||||||
|
UserID.from_string(self.helper.auth_user_id), 1, False, None
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.hs.get_auth().get_user_by_req = get_user_by_req
|
self.hs.get_auth().get_user_by_req = get_user_by_req
|
||||||
@ -209,9 +216,26 @@ class HomeserverTestCase(TestCase):
|
|||||||
return_value="1234"
|
return_value="1234"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.needs_threadpool:
|
||||||
|
self.reactor.threadpool = ThreadPool()
|
||||||
|
self.addCleanup(self.reactor.threadpool.stop)
|
||||||
|
self.reactor.threadpool.start()
|
||||||
|
|
||||||
if hasattr(self, "prepare"):
|
if hasattr(self, "prepare"):
|
||||||
self.prepare(self.reactor, self.clock, self.hs)
|
self.prepare(self.reactor, self.clock, self.hs)
|
||||||
|
|
||||||
|
def wait_on_thread(self, deferred, timeout=10):
|
||||||
|
"""
|
||||||
|
Wait until a Deferred is done, where it's waiting on a real thread.
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
while not deferred.called:
|
||||||
|
if start_time + timeout < time.time():
|
||||||
|
raise ValueError("Timed out waiting for threadpool")
|
||||||
|
self.reactor.advance(0.01)
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor, clock):
|
||||||
"""
|
"""
|
||||||
Make and return a homeserver.
|
Make and return a homeserver.
|
||||||
|
Loading…
Reference in New Issue
Block a user