Implement handling of HTTP HEAD requests. (#7999)

This commit is contained in:
Patrick Cloke 2020-08-03 08:45:42 -04:00 committed by GitHub
parent 2a89ce8cd4
commit 6812509807
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 54 additions and 8 deletions

1
changelog.d/7999.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a long standing bug where HTTP HEAD requests resulted in a 400 error.

View File

@ -242,10 +242,12 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
no appropriate method exists. Can be overriden in sub classes for no appropriate method exists. Can be overriden in sub classes for
different routing. different routing.
""" """
# Treat HEAD requests as GET requests.
request_method = request.method.decode("ascii")
if request_method == "HEAD":
request_method = "GET"
method_handler = getattr( method_handler = getattr(self, "_async_render_%s" % (request_method,), None)
self, "_async_render_%s" % (request.method.decode("ascii"),), None
)
if method_handler: if method_handler:
raw_callback_return = method_handler(request) raw_callback_return = method_handler(request)
@ -362,11 +364,15 @@ class JsonResource(DirectServeJsonResource):
A tuple of the callback to use, the name of the servlet, and the A tuple of the callback to use, the name of the servlet, and the
key word arguments to pass to the callback key word arguments to pass to the callback
""" """
# Treat HEAD requests as GET requests.
request_path = request.path.decode("ascii") request_path = request.path.decode("ascii")
request_method = request.method
if request_method == b"HEAD":
request_method = b"GET"
# Loop through all the registered callbacks to check if the method # Loop through all the registered callbacks to check if the method
# and path regex match # and path regex match
for path_entry in self.path_regexs.get(request.method, []): for path_entry in self.path_regexs.get(request_method, []):
m = path_entry.pattern.match(request_path) m = path_entry.pattern.match(request_path)
if m: if m:
# We found a match! # We found a match!
@ -579,7 +585,7 @@ def set_cors_headers(request: Request):
""" """
request.setHeader(b"Access-Control-Allow-Origin", b"*") request.setHeader(b"Access-Control-Allow-Origin", b"*")
request.setHeader( request.setHeader(
b"Access-Control-Allow-Methods", b"GET, POST, PUT, DELETE, OPTIONS" b"Access-Control-Allow-Methods", b"GET, HEAD, POST, PUT, DELETE, OPTIONS"
) )
request.setHeader( request.setHeader(
b"Access-Control-Allow-Headers", b"Access-Control-Allow-Headers",

View File

@ -157,6 +157,29 @@ class JsonResourceTests(unittest.TestCase):
self.assertEqual(channel.json_body["error"], "Unrecognized request") self.assertEqual(channel.json_body["error"], "Unrecognized request")
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
def test_head_request(self):
"""
JsonResource.handler_for_request gives correctly decoded URL args to
the callback, while Twisted will give the raw bytes of URL query
arguments.
"""
def _callback(request, **kwargs):
return 200, {"result": True}
res = JsonResource(self.homeserver)
res.register_paths(
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet",
)
# The path was registered as GET, but this is a HEAD request.
request, channel = make_request(self.reactor, b"HEAD", b"/_matrix/foo")
render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"200")
self.assertNotIn("body", channel.result)
self.assertEqual(channel.headers.getRawHeaders(b"Content-Length"), [b"15"])
class OptionsResourceTests(unittest.TestCase): class OptionsResourceTests(unittest.TestCase):
def setUp(self): def setUp(self):
@ -255,7 +278,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
self.reactor = ThreadedMemoryReactorClock() self.reactor = ThreadedMemoryReactorClock()
def test_good_response(self): def test_good_response(self):
def callback(request): async def callback(request):
request.write(b"response") request.write(b"response")
request.finish() request.finish()
@ -275,7 +298,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
with the right location. with the right location.
""" """
def callback(request, **kwargs): async def callback(request, **kwargs):
raise RedirectException(b"/look/an/eagle", 301) raise RedirectException(b"/look/an/eagle", 301)
res = WrapHtmlRequestHandlerTests.TestResource() res = WrapHtmlRequestHandlerTests.TestResource()
@ -295,7 +318,7 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
returned too returned too
""" """
def callback(request, **kwargs): async def callback(request, **kwargs):
e = RedirectException(b"/no/over/there", 304) e = RedirectException(b"/no/over/there", 304)
e.cookies.append(b"session=yespls") e.cookies.append(b"session=yespls")
raise e raise e
@ -312,3 +335,19 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
self.assertEqual(location_headers, [b"/no/over/there"]) self.assertEqual(location_headers, [b"/no/over/there"])
cookies_headers = [v for k, v in headers if k == b"Set-Cookie"] cookies_headers = [v for k, v in headers if k == b"Set-Cookie"]
self.assertEqual(cookies_headers, [b"session=yespls"]) self.assertEqual(cookies_headers, [b"session=yespls"])
def test_head_request(self):
"""A head request should work by being turned into a GET request."""
async def callback(request):
request.write(b"response")
request.finish()
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
request, channel = make_request(self.reactor, b"HEAD", b"/path")
render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b"200")
self.assertNotIn("body", channel.result)