Clean up exception handling for access_tokens (#5656)

First of all, let's get rid of `TOKEN_NOT_FOUND_HTTP_STATUS`. It was a hack we
did at one point when it was possible to return either a 403 or a 401 if the
creds were missing. We always return a 401 in these cases now (thankfully), so
it's not needed.

Let's also stop abusing `AuthError` for these cases. Honestly they have nothing
that relates them to the other places that `AuthError` is used, other than the
fact that they are loosely under the 'Auth' banner. It makes no sense for them
to share exception classes.

Instead, let's add a couple of new exception classes: `InvalidClientTokenError`
and `MissingClientTokenError`, for the `M_UNKNOWN_TOKEN` and `M_MISSING_TOKEN`
cases respectively - and an `InvalidClientCredentialsError` base class for the
two of them.
This commit is contained in:
Richard van der Hoff 2019-07-11 11:06:23 +01:00 committed by GitHub
parent 38a6d3eea7
commit 0a4001eba1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 111 additions and 100 deletions

1
changelog.d/5656.misc Normal file
View File

@ -0,0 +1 @@
Clean up exception handling around client access tokens.

View File

@ -25,7 +25,13 @@ from twisted.internet import defer
import synapse.types import synapse.types
from synapse import event_auth from synapse import event_auth
from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.errors import AuthError, Codes, ResourceLimitError from synapse.api.errors import (
AuthError,
Codes,
InvalidClientTokenError,
MissingClientTokenError,
ResourceLimitError,
)
from synapse.config.server import is_threepid_reserved from synapse.config.server import is_threepid_reserved
from synapse.types import UserID from synapse.types import UserID
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
@ -63,7 +69,6 @@ class Auth(object):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000) self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000)
register_cache("cache", "token_cache", self.token_cache) register_cache("cache", "token_cache", self.token_cache)
@ -189,18 +194,17 @@ class Auth(object):
Returns: Returns:
defer.Deferred: resolves to a ``synapse.types.Requester`` object defer.Deferred: resolves to a ``synapse.types.Requester`` object
Raises: Raises:
AuthError if no user by that token exists or the token is invalid. InvalidClientCredentialsError if no user by that token exists or the token
is invalid.
AuthError if access is denied for the user in the access token
""" """
# Can optionally look elsewhere in the request (e.g. headers)
try: try:
ip_addr = self.hs.get_ip_from_request(request) ip_addr = self.hs.get_ip_from_request(request)
user_agent = request.requestHeaders.getRawHeaders( user_agent = request.requestHeaders.getRawHeaders(
b"User-Agent", default=[b""] b"User-Agent", default=[b""]
)[0].decode("ascii", "surrogateescape") )[0].decode("ascii", "surrogateescape")
access_token = self.get_access_token_from_request( access_token = self.get_access_token_from_request(request)
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
user_id, app_service = yield self._get_appservice_user_id(request) user_id, app_service = yield self._get_appservice_user_id(request)
if user_id: if user_id:
@ -264,18 +268,12 @@ class Auth(object):
) )
) )
except KeyError: except KeyError:
raise AuthError( raise MissingClientTokenError()
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Missing access token.",
errcode=Codes.MISSING_TOKEN,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_appservice_user_id(self, request): def _get_appservice_user_id(self, request):
app_service = self.store.get_app_service_by_token( app_service = self.store.get_app_service_by_token(
self.get_access_token_from_request( self.get_access_token_from_request(request)
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
) )
if app_service is None: if app_service is None:
defer.returnValue((None, None)) defer.returnValue((None, None))
@ -313,7 +311,8 @@ class Auth(object):
`token_id` (int|None): access token id. May be None if guest `token_id` (int|None): access token id. May be None if guest
`device_id` (str|None): device corresponding to access token `device_id` (str|None): device corresponding to access token
Raises: Raises:
AuthError if no user by that token exists or the token is invalid. InvalidClientCredentialsError if no user by that token exists or the token
is invalid.
""" """
if rights == "access": if rights == "access":
@ -331,11 +330,7 @@ class Auth(object):
if not guest: if not guest:
# non-guest access tokens must be in the database # non-guest access tokens must be in the database
logger.warning("Unrecognised access token - not in store.") logger.warning("Unrecognised access token - not in store.")
raise AuthError( raise InvalidClientTokenError()
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN,
)
# Guest access tokens are not stored in the database (there can # Guest access tokens are not stored in the database (there can
# only be one access token per guest, anyway). # only be one access token per guest, anyway).
@ -350,16 +345,10 @@ class Auth(object):
# guest tokens. # guest tokens.
stored_user = yield self.store.get_user_by_id(user_id) stored_user = yield self.store.get_user_by_id(user_id)
if not stored_user: if not stored_user:
raise AuthError( raise InvalidClientTokenError("Unknown user_id %s" % user_id)
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Unknown user_id %s" % user_id,
errcode=Codes.UNKNOWN_TOKEN,
)
if not stored_user["is_guest"]: if not stored_user["is_guest"]:
raise AuthError( raise InvalidClientTokenError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Guest access token used for regular user"
"Guest access token used for regular user",
errcode=Codes.UNKNOWN_TOKEN,
) )
ret = { ret = {
"user": user, "user": user,
@ -386,11 +375,7 @@ class Auth(object):
ValueError, ValueError,
) as e: ) as e:
logger.warning("Invalid macaroon in auth: %s %s", type(e), e) logger.warning("Invalid macaroon in auth: %s %s", type(e), e)
raise AuthError( raise InvalidClientTokenError("Invalid macaroon passed.")
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Invalid macaroon passed.",
errcode=Codes.UNKNOWN_TOKEN,
)
def _parse_and_validate_macaroon(self, token, rights="access"): def _parse_and_validate_macaroon(self, token, rights="access"):
"""Takes a macaroon and tries to parse and validate it. This is cached """Takes a macaroon and tries to parse and validate it. This is cached
@ -430,11 +415,7 @@ class Auth(object):
macaroon, rights, self.hs.config.expire_access_token, user_id=user_id macaroon, rights, self.hs.config.expire_access_token, user_id=user_id
) )
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
raise AuthError( raise InvalidClientTokenError("Invalid macaroon passed.")
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Invalid macaroon passed.",
errcode=Codes.UNKNOWN_TOKEN,
)
if not has_expiry and rights == "access": if not has_expiry and rights == "access":
self.token_cache[token] = (user_id, guest) self.token_cache[token] = (user_id, guest)
@ -453,17 +434,14 @@ class Auth(object):
(str) user id (str) user id
Raises: Raises:
AuthError if there is no user_id caveat in the macaroon InvalidClientCredentialsError if there is no user_id caveat in the
macaroon
""" """
user_prefix = "user_id = " user_prefix = "user_id = "
for caveat in macaroon.caveats: for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix): if caveat.caveat_id.startswith(user_prefix):
return caveat.caveat_id[len(user_prefix) :] return caveat.caveat_id[len(user_prefix) :]
raise AuthError( raise InvalidClientTokenError("No user caveat in macaroon")
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"No user caveat in macaroon",
errcode=Codes.UNKNOWN_TOKEN,
)
def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id): def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
""" """
@ -531,22 +509,13 @@ class Auth(object):
defer.returnValue(user_info) defer.returnValue(user_info)
def get_appservice_by_req(self, request): def get_appservice_by_req(self, request):
try: token = self.get_access_token_from_request(request)
token = self.get_access_token_from_request( service = self.store.get_app_service_by_token(token)
request, self.TOKEN_NOT_FOUND_HTTP_STATUS if not service:
) logger.warn("Unrecognised appservice access token.")
service = self.store.get_app_service_by_token(token) raise InvalidClientTokenError()
if not service: request.authenticated_entity = service.sender
logger.warn("Unrecognised appservice access token.") return defer.succeed(service)
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN,
)
request.authenticated_entity = service.sender
return defer.succeed(service)
except KeyError:
raise AuthError(self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.")
def is_server_admin(self, user): def is_server_admin(self, user):
""" Check if the given user is a local server admin. """ Check if the given user is a local server admin.
@ -692,20 +661,16 @@ class Auth(object):
return bool(query_params) or bool(auth_headers) return bool(query_params) or bool(auth_headers)
@staticmethod @staticmethod
def get_access_token_from_request(request, token_not_found_http_status=401): def get_access_token_from_request(request):
"""Extracts the access_token from the request. """Extracts the access_token from the request.
Args: Args:
request: The http request. request: The http request.
token_not_found_http_status(int): The HTTP status code to set in the
AuthError if the token isn't found. This is used in some of the
legacy APIs to change the status code to 403 from the default of
401 since some of the old clients depended on auth errors returning
403.
Returns: Returns:
unicode: The access_token unicode: The access_token
Raises: Raises:
AuthError: If there isn't an access_token in the request. MissingClientTokenError: If there isn't a single access_token in the
request
""" """
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
@ -714,34 +679,20 @@ class Auth(object):
# Try the get the access_token from a "Authorization: Bearer" # Try the get the access_token from a "Authorization: Bearer"
# header # header
if query_params is not None: if query_params is not None:
raise AuthError( raise MissingClientTokenError(
token_not_found_http_status, "Mixing Authorization headers and access_token query parameters."
"Mixing Authorization headers and access_token query parameters.",
errcode=Codes.MISSING_TOKEN,
) )
if len(auth_headers) > 1: if len(auth_headers) > 1:
raise AuthError( raise MissingClientTokenError("Too many Authorization headers.")
token_not_found_http_status,
"Too many Authorization headers.",
errcode=Codes.MISSING_TOKEN,
)
parts = auth_headers[0].split(b" ") parts = auth_headers[0].split(b" ")
if parts[0] == b"Bearer" and len(parts) == 2: if parts[0] == b"Bearer" and len(parts) == 2:
return parts[1].decode("ascii") return parts[1].decode("ascii")
else: else:
raise AuthError( raise MissingClientTokenError("Invalid Authorization header.")
token_not_found_http_status,
"Invalid Authorization header.",
errcode=Codes.MISSING_TOKEN,
)
else: else:
# Try to get the access_token from the query params. # Try to get the access_token from the query params.
if not query_params: if not query_params:
raise AuthError( raise MissingClientTokenError()
token_not_found_http_status,
"Missing access token.",
errcode=Codes.MISSING_TOKEN,
)
return query_params[0].decode("ascii") return query_params[0].decode("ascii")

View File

@ -210,7 +210,9 @@ class NotFoundError(SynapseError):
class AuthError(SynapseError): class AuthError(SynapseError):
"""An error raised when there was a problem authorising an event.""" """An error raised when there was a problem authorising an event, and at various
other poorly-defined times.
"""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
if "errcode" not in kwargs: if "errcode" not in kwargs:
@ -218,6 +220,35 @@ class AuthError(SynapseError):
super(AuthError, self).__init__(*args, **kwargs) super(AuthError, self).__init__(*args, **kwargs)
class InvalidClientCredentialsError(SynapseError):
"""An error raised when there was a problem with the authorisation credentials
in a client request.
https://matrix.org/docs/spec/client_server/r0.5.0#using-access-tokens:
When credentials are required but missing or invalid, the HTTP call will
return with a status of 401 and the error code, M_MISSING_TOKEN or
M_UNKNOWN_TOKEN respectively.
"""
def __init__(self, msg, errcode):
super().__init__(code=401, msg=msg, errcode=errcode)
class MissingClientTokenError(InvalidClientCredentialsError):
"""Raised when we couldn't find the access token in a request"""
def __init__(self, msg="Missing access token"):
super().__init__(msg=msg, errcode="M_MISSING_TOKEN")
class InvalidClientTokenError(InvalidClientCredentialsError):
"""Raised when we didn't understand the access token in a request"""
def __init__(self, msg="Unrecognised access token"):
super().__init__(msg=msg, errcode="M_UNKNOWN_TOKEN")
class ResourceLimitError(SynapseError): class ResourceLimitError(SynapseError):
""" """
Any error raised when there is a problem with resource usage. Any error raised when there is a problem with resource usage.

View File

@ -18,7 +18,13 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.api.errors import (
AuthError,
Codes,
InvalidClientCredentialsError,
NotFoundError,
SynapseError,
)
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import RoomAlias from synapse.types import RoomAlias
@ -97,7 +103,7 @@ class ClientDirectoryServer(RestServlet):
room_alias.to_string(), room_alias.to_string(),
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
except AuthError: except InvalidClientCredentialsError:
# fallback to default user behaviour if they aren't an AS # fallback to default user behaviour if they aren't an AS
pass pass

View File

@ -24,7 +24,12 @@ from canonicaljson import json
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError from synapse.api.errors import (
AuthError,
Codes,
InvalidClientCredentialsError,
SynapseError,
)
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.events.utils import format_event_for_client_v2 from synapse.events.utils import format_event_for_client_v2
from synapse.http.servlet import ( from synapse.http.servlet import (
@ -307,7 +312,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
try: try:
yield self.auth.get_user_by_req(request, allow_guest=True) yield self.auth.get_user_by_req(request, allow_guest=True)
except AuthError as e: except InvalidClientCredentialsError as e:
# Option to allow servers to require auth when accessing # Option to allow servers to require auth when accessing
# /publicRooms via CS API. This is especially helpful in private # /publicRooms via CS API. This is especially helpful in private
# federations. # federations.

View File

@ -21,7 +21,14 @@ from twisted.internet import defer
import synapse.handlers.auth import synapse.handlers.auth
from synapse.api.auth import Auth from synapse.api.auth import Auth
from synapse.api.errors import AuthError, Codes, ResourceLimitError from synapse.api.errors import (
AuthError,
Codes,
InvalidClientCredentialsError,
InvalidClientTokenError,
MissingClientTokenError,
ResourceLimitError,
)
from synapse.types import UserID from synapse.types import UserID
from tests import unittest from tests import unittest
@ -70,7 +77,9 @@ class AuthTestCase(unittest.TestCase):
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) f = self.failureResultOf(d, InvalidClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_user_missing_token(self): def test_get_user_by_req_user_missing_token(self):
user_info = {"name": self.test_user, "token_id": "ditto"} user_info = {"name": self.test_user, "token_id": "ditto"}
@ -79,7 +88,9 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) f = self.failureResultOf(d, MissingClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token(self): def test_get_user_by_req_appservice_valid_token(self):
@ -133,7 +144,9 @@ class AuthTestCase(unittest.TestCase):
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) f = self.failureResultOf(d, InvalidClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_appservice_bad_token(self): def test_get_user_by_req_appservice_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None) self.store.get_app_service_by_token = Mock(return_value=None)
@ -143,7 +156,9 @@ class AuthTestCase(unittest.TestCase):
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) f = self.failureResultOf(d, InvalidClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_appservice_missing_token(self): def test_get_user_by_req_appservice_missing_token(self):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user) app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
@ -153,7 +168,9 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
d = self.auth.get_user_by_req(request) d = self.auth.get_user_by_req(request)
self.failureResultOf(d, AuthError) f = self.failureResultOf(d, MissingClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_user_by_req_appservice_valid_token_valid_user_id(self): def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
@ -280,7 +297,7 @@ class AuthTestCase(unittest.TestCase):
request.args[b"access_token"] = [guest_tok.encode("ascii")] request.args[b"access_token"] = [guest_tok.encode("ascii")]
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
with self.assertRaises(AuthError) as cm: with self.assertRaises(InvalidClientCredentialsError) as cm:
yield self.auth.get_user_by_req(request, allow_guest=True) yield self.auth.get_user_by_req(request, allow_guest=True)
self.assertEqual(401, cm.exception.code) self.assertEqual(401, cm.exception.code)