Merge pull request #1098 from matrix-org/markjh/bearer_token

Allow clients to supply access_tokens as headers
This commit is contained in:
Mark Haines 2016-10-25 17:33:15 +01:00 committed by GitHub
commit 177f104432
6 changed files with 67 additions and 24 deletions

View File

@ -1170,7 +1170,8 @@ def has_access_token(request):
bool: False if no access_token was given, True otherwise.
"""
query_params = request.args.get("access_token")
return bool(query_params)
auth_headers = request.requestHeaders.getRawHeaders("Authorization")
return bool(query_params) or bool(auth_headers)
def get_access_token_from_request(request, token_not_found_http_status=401):
@ -1188,7 +1189,34 @@ def get_access_token_from_request(request, token_not_found_http_status=401):
Raises:
AuthError: If there isn't an access_token in the request.
"""
auth_headers = request.requestHeaders.getRawHeaders("Authorization")
query_params = request.args.get("access_token")
if auth_headers:
# Try the get the access_token from a "Authorization: Bearer"
# header
if query_params is not None:
raise AuthError(
token_not_found_http_status,
"Mixing Authorization headers and access_token query parameters.",
errcode=Codes.MISSING_TOKEN,
)
if len(auth_headers) > 1:
raise AuthError(
token_not_found_http_status,
"Too many Authorization headers.",
errcode=Codes.MISSING_TOKEN,
)
parts = auth_headers[0].split(" ")
if parts[0] == "Bearer" and len(parts) == 2:
return parts[1]
else:
raise AuthError(
token_not_found_http_status,
"Invalid Authorization header.",
errcode=Codes.MISSING_TOKEN,
)
else:
# Try to get the access_token from the query params.
if not query_params:
raise AuthError(

View File

@ -20,7 +20,7 @@ from mock import Mock
from synapse.api.auth import Auth
from synapse.api.errors import AuthError
from synapse.types import UserID
from tests.utils import setup_test_homeserver
from tests.utils import setup_test_homeserver, mock_getRawHeaders
import pymacaroons
@ -51,7 +51,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user)
@ -61,7 +61,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)
@ -74,7 +74,7 @@ class AuthTestCase(unittest.TestCase):
self.store.get_user_by_access_token = Mock(return_value=user_info)
request = Mock(args={})
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)
@ -86,7 +86,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), self.test_user)
@ -96,7 +96,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)
@ -106,7 +106,7 @@ class AuthTestCase(unittest.TestCase):
self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={})
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)
@ -121,7 +121,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = yield self.auth.get_user_by_req(request)
self.assertEquals(requester.user.to_string(), masquerading_user_id)
@ -135,7 +135,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.args["access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError)

View File

@ -219,7 +219,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
"user_id": self.u_onion.to_string(),
"typing": True,
}
)
),
federation_auth=True,
)
self.on_new_event.assert_has_calls([

View File

@ -17,6 +17,7 @@ from synapse.rest.client.v1.register import CreateUserRestServlet
from twisted.internet import defer
from mock import Mock
from tests import unittest
from tests.utils import mock_getRawHeaders
import json
@ -30,6 +31,7 @@ class CreateUserServletTestCase(unittest.TestCase):
path='/_matrix/client/api/v1/createUser'
)
self.request.args = {}
self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.registration_handler = Mock()

View File

@ -3,6 +3,7 @@ from synapse.api.errors import SynapseError
from twisted.internet import defer
from mock import Mock
from tests import unittest
from tests.utils import mock_getRawHeaders
import json
@ -16,6 +17,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
path='/_matrix/api/v2_alpha/register'
)
self.request.args = {}
self.request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.appservice = None
self.auth = Mock(get_appservice_by_req=Mock(

View File

@ -116,6 +116,15 @@ def get_mock_call_args(pattern_func, mock_func):
return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
def mock_getRawHeaders(headers=None):
headers = headers if headers is not None else {}
def getRawHeaders(name, default=None):
return headers.get(name, default)
return getRawHeaders
# This is a mock /resource/ not an entire server
class MockHttpResource(HttpServer):
@ -128,7 +137,7 @@ class MockHttpResource(HttpServer):
@patch('twisted.web.http.Request')
@defer.inlineCallbacks
def trigger(self, http_method, path, content, mock_request):
def trigger(self, http_method, path, content, mock_request, federation_auth=False):
""" Fire an HTTP event.
Args:
@ -156,9 +165,10 @@ class MockHttpResource(HttpServer):
mock_request.getClientIP.return_value = "-"
mock_request.requestHeaders.getRawHeaders.return_value = [
"X-Matrix origin=test,key=,sig="
]
headers = {}
if federation_auth:
headers["Authorization"] = ["X-Matrix origin=test,key=,sig="]
mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
# return the right path if the event requires it
mock_request.path = path