Implement RedirectException (#6687)

Allow REST endpoint implemnentations to raise a RedirectException, which will
redirect the user's browser to a given location.
This commit is contained in:
Richard van der Hoff 2020-01-15 15:58:55 +00:00 committed by GitHub
parent 28c98e51ff
commit 8f5d7302ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 113 additions and 7 deletions

1
changelog.d/6687.misc Normal file
View File

@ -0,0 +1 @@
Allow REST endpoint implementations to raise a RedirectException, which will redirect the user's browser to a given location.

View File

@ -17,13 +17,15 @@
"""Contains exceptions and error codes.""" """Contains exceptions and error codes."""
import logging import logging
from typing import Dict from typing import Dict, List
from six import iteritems from six import iteritems
from six.moves import http_client from six.moves import http_client
from canonicaljson import json from canonicaljson import json
from twisted.web import http
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -80,6 +82,29 @@ class CodeMessageException(RuntimeError):
self.msg = msg self.msg = msg
class RedirectException(CodeMessageException):
"""A pseudo-error indicating that we want to redirect the client to a different
location
Attributes:
cookies: a list of set-cookies values to add to the response. For example:
b"sessionId=a3fWa; Expires=Wed, 21 Oct 2015 07:28:00 GMT"
"""
def __init__(self, location: bytes, http_code: int = http.FOUND):
"""
Args:
location: the URI to redirect to
http_code: the HTTP response code
"""
msg = "Redirect to %s" % (location.decode("utf-8"),)
super().__init__(code=http_code, msg=msg)
self.location = location
self.cookies = [] # type: List[bytes]
class SynapseError(CodeMessageException): class SynapseError(CodeMessageException):
"""A base exception type for matrix errors which have an errcode and error """A base exception type for matrix errors which have an errcode and error
message (as well as an HTTP status code). message (as well as an HTTP status code).

View File

@ -14,8 +14,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import cgi
import collections import collections
import html
import http.client import http.client
import logging import logging
import types import types
@ -36,6 +36,7 @@ import synapse.metrics
from synapse.api.errors import ( from synapse.api.errors import (
CodeMessageException, CodeMessageException,
Codes, Codes,
RedirectException,
SynapseError, SynapseError,
UnrecognizedRequestError, UnrecognizedRequestError,
) )
@ -153,14 +154,18 @@ def _return_html_error(f, request):
Args: Args:
f (twisted.python.failure.Failure): f (twisted.python.failure.Failure):
request (twisted.web.iweb.IRequest): request (twisted.web.server.Request):
""" """
if f.check(CodeMessageException): if f.check(CodeMessageException):
cme = f.value cme = f.value
code = cme.code code = cme.code
msg = cme.msg msg = cme.msg
if isinstance(cme, SynapseError): if isinstance(cme, RedirectException):
logger.info("%s redirect to %s", request, cme.location)
request.setHeader(b"location", cme.location)
request.cookies.extend(cme.cookies)
elif isinstance(cme, SynapseError):
logger.info("%s SynapseError: %s - %s", request, code, msg) logger.info("%s SynapseError: %s - %s", request, code, msg)
else: else:
logger.error( logger.error(
@ -178,7 +183,7 @@ def _return_html_error(f, request):
exc_info=(f.type, f.value, f.getTracebackObject()), exc_info=(f.type, f.value, f.getTracebackObject()),
) )
body = HTML_ERROR_TEMPLATE.format(code=code, msg=cgi.escape(msg)).encode("utf-8") body = HTML_ERROR_TEMPLATE.format(code=code, msg=html.escape(msg)).encode("utf-8")
request.setResponseCode(code) request.setResponseCode(code)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8") request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%i" % (len(body),)) request.setHeader(b"Content-Length", b"%i" % (len(body),))

View File

@ -23,8 +23,12 @@ from twisted.test.proto_helpers import AccumulatingProtocol
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, RedirectException, SynapseError
from synapse.http.server import JsonResource from synapse.http.server import (
DirectServeResource,
JsonResource,
wrap_html_request_handler,
)
from synapse.http.site import SynapseSite, logger from synapse.http.site import SynapseSite, logger
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock from synapse.util import Clock
@ -164,6 +168,77 @@ class JsonResourceTests(unittest.TestCase):
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
class WrapHtmlRequestHandlerTests(unittest.TestCase):
class TestResource(DirectServeResource):
callback = None
@wrap_html_request_handler
async def _async_render_GET(self, request):
return await self.callback(request)
def setUp(self):
self.reactor = ThreadedMemoryReactorClock()
def test_good_response(self):
def callback(request):
request.write(b"response")
request.finish()
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
request, channel = make_request(self.reactor, b"GET", b"/path")
render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"200")
body = channel.result["body"]
self.assertEqual(body, b"response")
def test_redirect_exception(self):
"""
If the callback raises a RedirectException, it is turned into a 30x
with the right location.
"""
def callback(request, **kwargs):
raise RedirectException(b"/look/an/eagle", 301)
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
request, channel = make_request(self.reactor, b"GET", b"/path")
render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"301")
headers = channel.result["headers"]
location_headers = [v for k, v in headers if k == b"Location"]
self.assertEqual(location_headers, [b"/look/an/eagle"])
def test_redirect_exception_with_cookie(self):
"""
If the callback raises a RedirectException which sets a cookie, that is
returned too
"""
def callback(request, **kwargs):
e = RedirectException(b"/no/over/there", 304)
e.cookies.append(b"session=yespls")
raise e
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
request, channel = make_request(self.reactor, b"GET", b"/path")
render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"304")
headers = channel.result["headers"]
location_headers = [v for k, v in headers if k == b"Location"]
self.assertEqual(location_headers, [b"/no/over/there"])
cookies_headers = [v for k, v in headers if k == b"Set-Cookie"]
self.assertEqual(cookies_headers, [b"session=yespls"])
class SiteTestCase(unittest.HomeserverTestCase): class SiteTestCase(unittest.HomeserverTestCase):
def test_lose_connection(self): def test_lose_connection(self):
""" """