Merge pull request #2586 from matrix-org/rav/frontend_proxy_auth_header

Front-end proxy: pass through auth header
This commit is contained in:
Richard van der Hoff 2017-10-27 11:01:50 +01:00 committed by GitHub
commit 8b56977b6f
2 changed files with 87 additions and 28 deletions

View File

@ -88,9 +88,16 @@ class KeyUploadServlet(RestServlet):
if body: if body:
# They're actually trying to upload something, proxy to main synapse. # They're actually trying to upload something, proxy to main synapse.
# Pass through the auth headers, if any, in case the access token
# is there.
auth_headers = request.requestHeaders.getRawHeaders("Authorization", [])
headers = {
"Authorization": auth_headers,
}
result = yield self.http_client.post_json_get_json( result = yield self.http_client.post_json_get_json(
self.main_uri + request.uri, self.main_uri + request.uri,
body, body,
headers=headers,
) )
defer.returnValue((200, result)) defer.returnValue((200, result))

View File

@ -114,19 +114,34 @@ class SimpleHttpClient(object):
raise e raise e
@defer.inlineCallbacks @defer.inlineCallbacks
def post_urlencoded_get_json(self, uri, args={}): def post_urlencoded_get_json(self, uri, args={}, headers=None):
"""
Args:
uri (str):
args (dict[str, str|List[str]]): query params
headers (dict[str, List[str]]|None): If not None, a map from
header name to a list of values for that header
Returns:
Deferred[object]: parsed json
"""
# TODO: Do we ever want to log message contents? # TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args) logger.debug("post_urlencoded_get_json args: %s", args)
query_bytes = urllib.urlencode(encode_urlencode_args(args), True) query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
actual_headers = {
b"Content-Type": [b"application/x-www-form-urlencoded"],
b"User-Agent": [self.user_agent],
}
if headers:
actual_headers.update(headers)
response = yield self.request( response = yield self.request(
"POST", "POST",
uri.encode("ascii"), uri.encode("ascii"),
headers=Headers({ headers=Headers(actual_headers),
b"Content-Type": [b"application/x-www-form-urlencoded"],
b"User-Agent": [self.user_agent],
}),
bodyProducer=FileBodyProducer(StringIO(query_bytes)) bodyProducer=FileBodyProducer(StringIO(query_bytes))
) )
@ -135,18 +150,33 @@ class SimpleHttpClient(object):
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@defer.inlineCallbacks @defer.inlineCallbacks
def post_json_get_json(self, uri, post_json): def post_json_get_json(self, uri, post_json, headers=None):
"""
Args:
uri (str):
post_json (object):
headers (dict[str, List[str]]|None): If not None, a map from
header name to a list of values for that header
Returns:
Deferred[object]: parsed json
"""
json_str = encode_canonical_json(post_json) json_str = encode_canonical_json(post_json)
logger.debug("HTTP POST %s -> %s", json_str, uri) logger.debug("HTTP POST %s -> %s", json_str, uri)
actual_headers = {
b"Content-Type": [b"application/json"],
b"User-Agent": [self.user_agent],
}
if headers:
actual_headers.update(headers)
response = yield self.request( response = yield self.request(
"POST", "POST",
uri.encode("ascii"), uri.encode("ascii"),
headers=Headers({ headers=Headers(actual_headers),
b"Content-Type": [b"application/json"],
b"User-Agent": [self.user_agent],
}),
bodyProducer=FileBodyProducer(StringIO(json_str)) bodyProducer=FileBodyProducer(StringIO(json_str))
) )
@ -160,7 +190,7 @@ class SimpleHttpClient(object):
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_json(self, uri, args={}): def get_json(self, uri, args={}, headers=None):
""" Gets some json from the given URI. """ Gets some json from the given URI.
Args: Args:
@ -169,6 +199,8 @@ class SimpleHttpClient(object):
None. None.
**Note**: The value of each key is assumed to be an iterable **Note**: The value of each key is assumed to be an iterable
and *not* a string. and *not* a string.
headers (dict[str, List[str]]|None): If not None, a map from
header name to a list of values for that header
Returns: Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON. HTTP body as JSON.
@ -177,13 +209,13 @@ class SimpleHttpClient(object):
error message. error message.
""" """
try: try:
body = yield self.get_raw(uri, args) body = yield self.get_raw(uri, args, headers=headers)
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
except CodeMessageException as e: except CodeMessageException as e:
raise self._exceptionFromFailedRequest(e.code, e.msg) raise self._exceptionFromFailedRequest(e.code, e.msg)
@defer.inlineCallbacks @defer.inlineCallbacks
def put_json(self, uri, json_body, args={}): def put_json(self, uri, json_body, args={}, headers=None):
""" Puts some json to the given URI. """ Puts some json to the given URI.
Args: Args:
@ -193,6 +225,8 @@ class SimpleHttpClient(object):
None. None.
**Note**: The value of each key is assumed to be an iterable **Note**: The value of each key is assumed to be an iterable
and *not* a string. and *not* a string.
headers (dict[str, List[str]]|None): If not None, a map from
header name to a list of values for that header
Returns: Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON. HTTP body as JSON.
@ -205,13 +239,17 @@ class SimpleHttpClient(object):
json_str = encode_canonical_json(json_body) json_str = encode_canonical_json(json_body)
actual_headers = {
b"Content-Type": [b"application/json"],
b"User-Agent": [self.user_agent],
}
if headers:
actual_headers.update(headers)
response = yield self.request( response = yield self.request(
"PUT", "PUT",
uri.encode("ascii"), uri.encode("ascii"),
headers=Headers({ headers=Headers(actual_headers),
b"User-Agent": [self.user_agent],
"Content-Type": ["application/json"]
}),
bodyProducer=FileBodyProducer(StringIO(json_str)) bodyProducer=FileBodyProducer(StringIO(json_str))
) )
@ -226,7 +264,7 @@ class SimpleHttpClient(object):
raise CodeMessageException(response.code, body) raise CodeMessageException(response.code, body)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_raw(self, uri, args={}): def get_raw(self, uri, args={}, headers=None):
""" Gets raw text from the given URI. """ Gets raw text from the given URI.
Args: Args:
@ -235,6 +273,8 @@ class SimpleHttpClient(object):
None. None.
**Note**: The value of each key is assumed to be an iterable **Note**: The value of each key is assumed to be an iterable
and *not* a string. and *not* a string.
headers (dict[str, List[str]]|None): If not None, a map from
header name to a list of values for that header
Returns: Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body at text. HTTP body at text.
@ -246,12 +286,16 @@ class SimpleHttpClient(object):
query_bytes = urllib.urlencode(args, True) query_bytes = urllib.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes) uri = "%s?%s" % (uri, query_bytes)
actual_headers = {
b"User-Agent": [self.user_agent],
}
if headers:
actual_headers.update(headers)
response = yield self.request( response = yield self.request(
"GET", "GET",
uri.encode("ascii"), uri.encode("ascii"),
headers=Headers({ headers=Headers(actual_headers),
b"User-Agent": [self.user_agent],
})
) )
body = yield make_deferred_yieldable(readBody(response)) body = yield make_deferred_yieldable(readBody(response))
@ -274,27 +318,33 @@ class SimpleHttpClient(object):
# The two should be factored out. # The two should be factored out.
@defer.inlineCallbacks @defer.inlineCallbacks
def get_file(self, url, output_stream, max_size=None): def get_file(self, url, output_stream, max_size=None, headers=None):
"""GETs a file from a given URL """GETs a file from a given URL
Args: Args:
url (str): The URL to GET url (str): The URL to GET
output_stream (file): File to write the response body to. output_stream (file): File to write the response body to.
headers (dict[str, List[str]]|None): If not None, a map from
header name to a list of values for that header
Returns: Returns:
A (int,dict,string,int) tuple of the file length, dict of the response A (int,dict,string,int) tuple of the file length, dict of the response
headers, absolute URI of the response and HTTP response code. headers, absolute URI of the response and HTTP response code.
""" """
actual_headers = {
b"User-Agent": [self.user_agent],
}
if headers:
actual_headers.update(headers)
response = yield self.request( response = yield self.request(
"GET", "GET",
url.encode("ascii"), url.encode("ascii"),
headers=Headers({ headers=Headers(actual_headers),
b"User-Agent": [self.user_agent],
})
) )
headers = dict(response.headers.getAllRawHeaders()) resp_headers = dict(response.headers.getAllRawHeaders())
if 'Content-Length' in headers and headers['Content-Length'] > max_size: if 'Content-Length' in resp_headers and resp_headers['Content-Length'] > max_size:
logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
raise SynapseError( raise SynapseError(
502, 502,
@ -326,7 +376,9 @@ class SimpleHttpClient(object):
Codes.UNKNOWN, Codes.UNKNOWN,
) )
defer.returnValue((length, headers, response.request.absoluteURI, response.code)) defer.returnValue(
(length, resp_headers, response.request.absoluteURI, response.code),
)
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.