Save login tokens in database (#13844)

* Save login tokens in database

Signed-off-by: Quentin Gliech <quenting@element.io>

* Add upgrade notes

* Track login token reuse in a Prometheus metric

Signed-off-by: Quentin Gliech <quenting@element.io>
This commit is contained in:
Quentin Gliech 2022-10-26 12:45:41 +02:00 committed by GitHub
parent d902181de9
commit 8756d5c87e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 338 additions and 228 deletions

1
changelog.d/13844.misc Normal file
View File

@ -0,0 +1 @@
Save login tokens in database and prevent login token reuse.

View File

@ -88,6 +88,15 @@ process, for example:
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
``` ```
# Upgrading to v1.71.0
## Removal of the `generate_short_term_login_token` module API method
As announced with the release of [Synapse 1.69.0](#deprecation-of-the-generate_short_term_login_token-module-api-method), the deprecated `generate_short_term_login_token` module method has been removed.
Modules relying on it can instead use the `create_login_token` method.
# Upgrading to v1.69.0 # Upgrading to v1.69.0
## Changes to the receipts replication streams ## Changes to the receipts replication streams

View File

@ -38,6 +38,7 @@ from typing import (
import attr import attr
import bcrypt import bcrypt
import unpaddedbase64 import unpaddedbase64
from prometheus_client import Counter
from twisted.internet.defer import CancelledError from twisted.internet.defer import CancelledError
from twisted.web.server import Request from twisted.web.server import Request
@ -48,6 +49,7 @@ from synapse.api.errors import (
Codes, Codes,
InteractiveAuthIncompleteError, InteractiveAuthIncompleteError,
LoginError, LoginError,
NotFoundError,
StoreError, StoreError,
SynapseError, SynapseError,
UserDeactivatedError, UserDeactivatedError,
@ -63,10 +65,14 @@ from synapse.http.server import finish_request, respond_with_html
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases.main.registration import (
LoginTokenExpired,
LoginTokenLookupResult,
LoginTokenReused,
)
from synapse.types import JsonDict, Requester, UserID from synapse.types import JsonDict, Requester, UserID
from synapse.util import stringutils as stringutils from synapse.util import stringutils as stringutils
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
from synapse.util.macaroons import LoginTokenAttributes
from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import base62_encode from synapse.util.stringutils import base62_encode
from synapse.util.threepids import canonicalise_email from synapse.util.threepids import canonicalise_email
@ -80,6 +86,12 @@ logger = logging.getLogger(__name__)
INVALID_USERNAME_OR_PASSWORD = "Invalid username or password" INVALID_USERNAME_OR_PASSWORD = "Invalid username or password"
invalid_login_token_counter = Counter(
"synapse_user_login_invalid_login_tokens",
"Counts the number of rejected m.login.token on /login",
["reason"],
)
def convert_client_dict_legacy_fields_to_identifier( def convert_client_dict_legacy_fields_to_identifier(
submission: JsonDict, submission: JsonDict,
@ -883,6 +895,25 @@ class AuthHandler:
return True return True
async def create_login_token_for_user_id(
self,
user_id: str,
duration_ms: int = (2 * 60 * 1000),
auth_provider_id: Optional[str] = None,
auth_provider_session_id: Optional[str] = None,
) -> str:
login_token = self.generate_login_token()
now = self._clock.time_msec()
expiry_ts = now + duration_ms
await self.store.add_login_token_to_user(
user_id=user_id,
token=login_token,
expiry_ts=expiry_ts,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)
return login_token
async def create_refresh_token_for_user_id( async def create_refresh_token_for_user_id(
self, self,
user_id: str, user_id: str,
@ -1401,6 +1432,18 @@ class AuthHandler:
return None return None
return user_id return user_id
def generate_login_token(self) -> str:
"""Generates an opaque string, for use as an short-term login token"""
# we use the following format for access tokens:
# syl_<random string>_<base62 crc check>
random_string = stringutils.random_string(20)
base = f"syl_{random_string}"
crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
return f"{base}_{crc}"
def generate_access_token(self, for_user: UserID) -> str: def generate_access_token(self, for_user: UserID) -> str:
"""Generates an opaque string, for use as an access token""" """Generates an opaque string, for use as an access token"""
@ -1427,16 +1470,17 @@ class AuthHandler:
crc = base62_encode(crc32(base.encode("ascii")), minwidth=6) crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
return f"{base}_{crc}" return f"{base}_{crc}"
async def validate_short_term_login_token( async def consume_login_token(self, login_token: str) -> LoginTokenLookupResult:
self, login_token: str
) -> LoginTokenAttributes:
try: try:
res = self.macaroon_gen.verify_short_term_login_token(login_token) return await self.store.consume_login_token(login_token)
except Exception: except LoginTokenExpired:
raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN) invalid_login_token_counter.labels("expired").inc()
except LoginTokenReused:
invalid_login_token_counter.labels("reused").inc()
except NotFoundError:
invalid_login_token_counter.labels("not found").inc()
await self.auth_blocking.check_auth_blocking(res.user_id) raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN)
return res
async def delete_access_token(self, access_token: str) -> None: async def delete_access_token(self, access_token: str) -> None:
"""Invalidate a single access token """Invalidate a single access token
@ -1711,7 +1755,7 @@ class AuthHandler:
) )
# Create a login token # Create a login token
login_token = self.macaroon_gen.generate_short_term_login_token( login_token = await self.create_login_token_for_user_id(
registered_user_id, registered_user_id,
auth_provider_id=auth_provider_id, auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id, auth_provider_session_id=auth_provider_session_id,

View File

@ -771,50 +771,11 @@ class ModuleApi:
auth_provider_session_id: The session ID got during login from the SSO IdP, auth_provider_session_id: The session ID got during login from the SSO IdP,
if any. if any.
""" """
# The deprecated `generate_short_term_login_token` method defaulted to an empty return await self._hs.get_auth_handler().create_login_token_for_user_id(
# string for the `auth_provider_id` because of how the underlying macaroon was
# generated. This will change to a proper NULL-able field when the tokens get
# moved to the database.
return self._hs.get_macaroon_generator().generate_short_term_login_token(
user_id, user_id,
auth_provider_id or "",
auth_provider_session_id,
duration_in_ms, duration_in_ms,
)
def generate_short_term_login_token(
self,
user_id: str,
duration_in_ms: int = (2 * 60 * 1000),
auth_provider_id: str = "",
auth_provider_session_id: Optional[str] = None,
) -> str:
"""Generate a login token suitable for m.login.token authentication
Added in Synapse v1.9.0.
This was deprecated in Synapse v1.69.0 in favor of create_login_token, and will
be removed in Synapse 1.71.0.
Args:
user_id: gives the ID of the user that the token is for
duration_in_ms: the time that the token will be valid for
auth_provider_id: the ID of the SSO IdP that the user used to authenticate
to get this token, if any. This is encoded in the token so that
/login can report stats on number of successful logins by IdP.
"""
logger.warn(
"A module configured on this server uses ModuleApi.generate_short_term_login_token(), "
"which is deprecated in favor of ModuleApi.create_login_token(), and will be removed in "
"Synapse 1.71.0",
)
return self._hs.get_macaroon_generator().generate_short_term_login_token(
user_id,
auth_provider_id, auth_provider_id,
auth_provider_session_id, auth_provider_session_id,
duration_in_ms,
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -436,8 +436,7 @@ class LoginRestServlet(RestServlet):
The body of the JSON response. The body of the JSON response.
""" """
token = login_submission["token"] token = login_submission["token"]
auth_handler = self.auth_handler res = await self.auth_handler.consume_login_token(token)
res = await auth_handler.validate_short_term_login_token(token)
return await self._complete_login( return await self._complete_login(
res.user_id, res.user_id,

View File

@ -57,7 +57,6 @@ class LoginTokenRequestServlet(RestServlet):
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.server_name = hs.config.server.server_name self.server_name = hs.config.server.server_name
self.macaroon_gen = hs.get_macaroon_generator()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.token_timeout = hs.config.experimental.msc3882_token_timeout self.token_timeout = hs.config.experimental.msc3882_token_timeout
self.ui_auth = hs.config.experimental.msc3882_ui_auth self.ui_auth = hs.config.experimental.msc3882_ui_auth
@ -76,10 +75,10 @@ class LoginTokenRequestServlet(RestServlet):
can_skip_ui_auth=False, # Don't allow skipping of UI auth can_skip_ui_auth=False, # Don't allow skipping of UI auth
) )
login_token = self.macaroon_gen.generate_short_term_login_token( login_token = await self.auth_handler.create_login_token_for_user_id(
user_id=requester.user.to_string(), user_id=requester.user.to_string(),
auth_provider_id="org.matrix.msc3882.login_token_request", auth_provider_id="org.matrix.msc3882.login_token_request",
duration_in_ms=self.token_timeout, duration_ms=self.token_timeout,
) )
return ( return (

View File

@ -21,7 +21,13 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
import attr import attr
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.api.errors import (
Codes,
NotFoundError,
StoreError,
SynapseError,
ThreepidValidationError,
)
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.database import ( from synapse.storage.database import (
@ -50,6 +56,14 @@ class ExternalIDReuseException(Exception):
because this external id is given to an other user.""" because this external id is given to an other user."""
class LoginTokenExpired(Exception):
"""Exception if the login token sent expired"""
class LoginTokenReused(Exception):
"""Exception if the login token sent was already used"""
@attr.s(frozen=True, slots=True, auto_attribs=True) @attr.s(frozen=True, slots=True, auto_attribs=True)
class TokenLookupResult: class TokenLookupResult:
"""Result of looking up an access token. """Result of looking up an access token.
@ -115,6 +129,20 @@ class RefreshTokenLookupResult:
If None, the session can be refreshed indefinitely.""" If None, the session can be refreshed indefinitely."""
@attr.s(auto_attribs=True, frozen=True, slots=True)
class LoginTokenLookupResult:
"""Result of looking up a login token."""
user_id: str
"""The user this token belongs to."""
auth_provider_id: Optional[str]
"""The SSO Identity Provider that the user authenticated with, to get this token."""
auth_provider_session_id: Optional[str]
"""The session ID advertised by the SSO Identity Provider."""
class RegistrationWorkerStore(CacheInvalidationWorkerStore): class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__( def __init__(
self, self,
@ -1789,6 +1817,109 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"replace_refresh_token", _replace_refresh_token_txn "replace_refresh_token", _replace_refresh_token_txn
) )
async def add_login_token_to_user(
self,
user_id: str,
token: str,
expiry_ts: int,
auth_provider_id: Optional[str],
auth_provider_session_id: Optional[str],
) -> None:
"""Adds a short-term login token for the given user.
Args:
user_id: The user ID.
token: The new login token to add.
expiry_ts (milliseconds since the epoch): Time after which the login token
cannot be used.
auth_provider_id: The SSO Identity Provider that the user authenticated with
to get this token, if any
auth_provider_session_id: The session ID advertised by the SSO Identity
Provider, if any.
"""
await self.db_pool.simple_insert(
"login_tokens",
{
"token": token,
"user_id": user_id,
"expiry_ts": expiry_ts,
"auth_provider_id": auth_provider_id,
"auth_provider_session_id": auth_provider_session_id,
},
desc="add_login_token_to_user",
)
def _consume_login_token(
self,
txn: LoggingTransaction,
token: str,
ts: int,
) -> LoginTokenLookupResult:
values = self.db_pool.simple_select_one_txn(
txn,
"login_tokens",
keyvalues={"token": token},
retcols=(
"user_id",
"expiry_ts",
"used_ts",
"auth_provider_id",
"auth_provider_session_id",
),
allow_none=True,
)
if values is None:
raise NotFoundError()
self.db_pool.simple_update_one_txn(
txn,
"login_tokens",
keyvalues={"token": token},
updatevalues={"used_ts": ts},
)
user_id = values["user_id"]
expiry_ts = values["expiry_ts"]
used_ts = values["used_ts"]
auth_provider_id = values["auth_provider_id"]
auth_provider_session_id = values["auth_provider_session_id"]
# Token was already used
if used_ts is not None:
raise LoginTokenReused()
# Token expired
if ts > int(expiry_ts):
raise LoginTokenExpired()
return LoginTokenLookupResult(
user_id=user_id,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)
async def consume_login_token(self, token: str) -> LoginTokenLookupResult:
"""Lookup a login token and consume it.
Args:
token: The login token.
Returns:
The data stored with that token, including the `user_id`. Returns `None` if
the token does not exist or if it expired.
Raises:
NotFound if the login token was not found in database
LoginTokenExpired if the login token expired
LoginTokenReused if the login token was already used
"""
return await self.db_pool.runInteraction(
"consume_login_token",
self._consume_login_token,
token,
self._clock.time_msec(),
)
@cached() @cached()
async def is_guest(self, user_id: str) -> bool: async def is_guest(self, user_id: str) -> bool:
res = await self.db_pool.simple_select_one_onecol( res = await self.db_pool.simple_select_one_onecol(
@ -2019,6 +2150,12 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
and hs.config.experimental.msc3866.require_approval_for_new_accounts and hs.config.experimental.msc3866.require_approval_for_new_accounts
) )
# Create a background job for removing expired login tokens
if hs.config.worker.run_background_tasks:
self._clock.looping_call(
self._delete_expired_login_tokens, THIRTY_MINUTES_IN_MS
)
async def add_access_token_to_user( async def add_access_token_to_user(
self, self,
user_id: str, user_id: str,
@ -2617,6 +2754,23 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
approved, approved,
) )
@wrap_as_background_process("delete_expired_login_tokens")
async def _delete_expired_login_tokens(self) -> None:
"""Remove login tokens with expiry dates that have passed."""
def _delete_expired_login_tokens_txn(txn: LoggingTransaction, ts: int) -> None:
sql = "DELETE FROM login_tokens WHERE expiry_ts <= ?"
txn.execute(sql, (ts,))
# We keep the expired tokens for an extra 5 minutes so we can measure how many
# times a token is being used after its expiry
now = self._clock.time_msec()
await self.db_pool.runInteraction(
"delete_expired_login_tokens",
_delete_expired_login_tokens_txn,
now - (5 * 60 * 1000),
)
def find_max_generated_user_id_localpart(cur: Cursor) -> int: def find_max_generated_user_id_localpart(cur: Cursor) -> int:
""" """

View File

@ -0,0 +1,35 @@
/*
* Copyright 2022 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Login tokens are short-lived tokens that are used for the m.login.token
-- login method, mainly during SSO logins
CREATE TABLE login_tokens (
token TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
expiry_ts BIGINT NOT NULL,
used_ts BIGINT,
auth_provider_id TEXT,
auth_provider_session_id TEXT
);
-- We're sometimes querying them by their session ID we got from their IDP
CREATE INDEX login_tokens_auth_provider_idx
ON login_tokens (auth_provider_id, auth_provider_session_id);
-- We're deleting them by their expiration time
CREATE INDEX login_tokens_expiry_time_idx
ON login_tokens (expiry_ts);

View File

@ -24,7 +24,7 @@ from typing_extensions import Literal
from synapse.util import Clock, stringutils from synapse.util import Clock, stringutils
MacaroonType = Literal["access", "delete_pusher", "session", "login"] MacaroonType = Literal["access", "delete_pusher", "session"]
def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
@ -111,19 +111,6 @@ class OidcSessionData:
"""The session ID of the ongoing UI Auth ("" if this is a login)""" """The session ID of the ongoing UI Auth ("" if this is a login)"""
@attr.s(slots=True, frozen=True, auto_attribs=True)
class LoginTokenAttributes:
"""Data we store in a short-term login token"""
user_id: str
auth_provider_id: str
"""The SSO Identity Provider that the user authenticated with, to get this token."""
auth_provider_session_id: Optional[str]
"""The session ID advertised by the SSO Identity Provider."""
class MacaroonGenerator: class MacaroonGenerator:
def __init__(self, clock: Clock, location: str, secret_key: bytes): def __init__(self, clock: Clock, location: str, secret_key: bytes):
self._clock = clock self._clock = clock
@ -165,35 +152,6 @@ class MacaroonGenerator:
macaroon.add_first_party_caveat(f"pushkey = {pushkey}") macaroon.add_first_party_caveat(f"pushkey = {pushkey}")
return macaroon.serialize() return macaroon.serialize()
def generate_short_term_login_token(
self,
user_id: str,
auth_provider_id: str,
auth_provider_session_id: Optional[str] = None,
duration_in_ms: int = (2 * 60 * 1000),
) -> str:
"""Generate a short-term login token used during SSO logins
Args:
user_id: The user for which the token is valid.
auth_provider_id: The SSO IdP the user used.
auth_provider_session_id: The session ID got during login from the SSO IdP.
Returns:
A signed token valid for using as a ``m.login.token`` token.
"""
now = self._clock.time_msec()
expiry = now + duration_in_ms
macaroon = self._generate_base_macaroon("login")
macaroon.add_first_party_caveat(f"user_id = {user_id}")
macaroon.add_first_party_caveat(f"time < {expiry}")
macaroon.add_first_party_caveat(f"auth_provider_id = {auth_provider_id}")
if auth_provider_session_id is not None:
macaroon.add_first_party_caveat(
f"auth_provider_session_id = {auth_provider_session_id}"
)
return macaroon.serialize()
def generate_oidc_session_token( def generate_oidc_session_token(
self, self,
state: str, state: str,
@ -233,49 +191,6 @@ class MacaroonGenerator:
return macaroon.serialize() return macaroon.serialize()
def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
"""Verify a short-term-login macaroon
Checks that the given token is a valid, unexpired short-term-login token
minted by this server.
Args:
token: The login token to verify.
Returns:
A set of attributes carried by this token, including the
``user_id`` and informations about the SSO IDP used during that
login.
Raises:
MacaroonVerificationFailedException if the verification failed
"""
macaroon = pymacaroons.Macaroon.deserialize(token)
v = self._base_verifier("login")
v.satisfy_general(lambda c: c.startswith("user_id = "))
v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = "))
satisfy_expiry(v, self._clock.time_msec)
v.verify(macaroon, self._secret_key)
user_id = get_value_from_macaroon(macaroon, "user_id")
auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
auth_provider_session_id: Optional[str] = None
try:
auth_provider_session_id = get_value_from_macaroon(
macaroon, "auth_provider_session_id"
)
except MacaroonVerificationFailedException:
pass
return LoginTokenAttributes(
user_id=user_id,
auth_provider_id=auth_provider_id,
auth_provider_session_id=auth_provider_session_id,
)
def verify_guest_token(self, token: str) -> str: def verify_guest_token(self, token: str) -> str:
"""Verify a guest access token macaroon """Verify a guest access token macaroon

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional
from unittest.mock import Mock from unittest.mock import Mock
import pymacaroons import pymacaroons
@ -19,6 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import AuthError, ResourceLimitError from synapse.api.errors import AuthError, ResourceLimitError
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.util import Clock from synapse.util import Clock
@ -29,6 +31,7 @@ from tests.test_utils import make_awaitable
class AuthTestCase(unittest.HomeserverTestCase): class AuthTestCase(unittest.HomeserverTestCase):
servlets = [ servlets = [
admin.register_servlets, admin.register_servlets,
login.register_servlets,
] ]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@ -46,6 +49,23 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.user1 = self.register_user("a_user", "pass") self.user1 = self.register_user("a_user", "pass")
def token_login(self, token: str) -> Optional[str]:
body = {
"type": "m.login.token",
"token": token,
}
channel = self.make_request(
"POST",
"/_matrix/client/v3/login",
body,
)
if channel.code == 200:
return channel.json_body["user_id"]
return None
def test_macaroon_caveats(self) -> None: def test_macaroon_caveats(self) -> None:
token = self.macaroon_generator.generate_guest_access_token("a_user") token = self.macaroon_generator.generate_guest_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token) macaroon = pymacaroons.Macaroon.deserialize(token)
@ -73,48 +93,61 @@ class AuthTestCase(unittest.HomeserverTestCase):
v.satisfy_general(verify_guest) v.satisfy_general(verify_guest)
v.verify(macaroon, self.hs.config.key.macaroon_secret_key) v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
def test_short_term_login_token_gives_user_id(self) -> None: def test_login_token_gives_user_id(self) -> None:
token = self.macaroon_generator.generate_short_term_login_token( token = self.get_success(
self.user1, "", duration_in_ms=5000 self.auth_handler.create_login_token_for_user_id(
self.user1,
duration_ms=(5 * 1000),
)
) )
res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
res = self.get_success(self.auth_handler.consume_login_token(token))
self.assertEqual(self.user1, res.user_id) self.assertEqual(self.user1, res.user_id)
self.assertEqual("", res.auth_provider_id) self.assertEqual(None, res.auth_provider_id)
def test_login_token_reuse_fails(self) -> None:
token = self.get_success(
self.auth_handler.create_login_token_for_user_id(
self.user1,
duration_ms=(5 * 1000),
)
)
self.get_success(self.auth_handler.consume_login_token(token))
self.get_failure(
self.auth_handler.consume_login_token(token),
AuthError,
)
def test_login_token_expires(self) -> None:
token = self.get_success(
self.auth_handler.create_login_token_for_user_id(
self.user1,
duration_ms=(5 * 1000),
)
)
# when we advance the clock, the token should be rejected # when we advance the clock, the token should be rejected
self.reactor.advance(6) self.reactor.advance(6)
self.get_failure( self.get_failure(
self.auth_handler.validate_short_term_login_token(token), self.auth_handler.consume_login_token(token),
AuthError, AuthError,
) )
def test_short_term_login_token_gives_auth_provider(self) -> None: def test_login_token_gives_auth_provider(self) -> None:
token = self.macaroon_generator.generate_short_term_login_token( token = self.get_success(
self.user1, auth_provider_id="my_idp" self.auth_handler.create_login_token_for_user_id(
self.user1,
auth_provider_id="my_idp",
auth_provider_session_id="11-22-33-44",
duration_ms=(5 * 1000),
)
) )
res = self.get_success(self.auth_handler.validate_short_term_login_token(token)) res = self.get_success(self.auth_handler.consume_login_token(token))
self.assertEqual(self.user1, res.user_id) self.assertEqual(self.user1, res.user_id)
self.assertEqual("my_idp", res.auth_provider_id) self.assertEqual("my_idp", res.auth_provider_id)
self.assertEqual("11-22-33-44", res.auth_provider_session_id)
def test_short_term_login_token_cannot_replace_user_id(self) -> None:
token = self.macaroon_generator.generate_short_term_login_token(
self.user1, "", duration_in_ms=5000
)
macaroon = pymacaroons.Macaroon.deserialize(token)
res = self.get_success(
self.auth_handler.validate_short_term_login_token(macaroon.serialize())
)
self.assertEqual(self.user1, res.user_id)
# add another "user_id" caveat, which might allow us to override the
# user_id.
macaroon.add_first_party_caveat("user_id = b_user")
self.get_failure(
self.auth_handler.validate_short_term_login_token(macaroon.serialize()),
AuthError,
)
def test_mau_limits_disabled(self) -> None: def test_mau_limits_disabled(self) -> None:
self.auth_blocking._limit_usage_by_mau = False self.auth_blocking._limit_usage_by_mau = False
@ -125,12 +158,12 @@ class AuthTestCase(unittest.HomeserverTestCase):
) )
) )
self.get_success( token = self.get_success(
self.auth_handler.validate_short_term_login_token( self.auth_handler.create_login_token_for_user_id(self.user1)
self._get_macaroon().serialize()
)
) )
self.assertIsNotNone(self.token_login(token))
def test_mau_limits_exceeded_large(self) -> None: def test_mau_limits_exceeded_large(self) -> None:
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastores().main.get_monthly_active_count = Mock( self.hs.get_datastores().main.get_monthly_active_count = Mock(
@ -147,12 +180,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.hs.get_datastores().main.get_monthly_active_count = Mock( self.hs.get_datastores().main.get_monthly_active_count = Mock(
return_value=make_awaitable(self.large_number_of_users) return_value=make_awaitable(self.large_number_of_users)
) )
self.get_failure( token = self.get_success(
self.auth_handler.validate_short_term_login_token( self.auth_handler.create_login_token_for_user_id(self.user1)
self._get_macaroon().serialize()
),
ResourceLimitError,
) )
self.assertIsNone(self.token_login(token))
def test_mau_limits_parity(self) -> None: def test_mau_limits_parity(self) -> None:
# Ensure we're not at the unix epoch. # Ensure we're not at the unix epoch.
@ -171,12 +202,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
), ),
ResourceLimitError, ResourceLimitError,
) )
self.get_failure( token = self.get_success(
self.auth_handler.validate_short_term_login_token( self.auth_handler.create_login_token_for_user_id(self.user1)
self._get_macaroon().serialize()
),
ResourceLimitError,
) )
self.assertIsNone(self.token_login(token))
# If in monthly active cohort # If in monthly active cohort
self.hs.get_datastores().main.user_last_seen_monthly_active = Mock( self.hs.get_datastores().main.user_last_seen_monthly_active = Mock(
@ -187,11 +216,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.user1, device_id=None, valid_until_ms=None self.user1, device_id=None, valid_until_ms=None
) )
) )
self.get_success( token = self.get_success(
self.auth_handler.validate_short_term_login_token( self.auth_handler.create_login_token_for_user_id(self.user1)
self._get_macaroon().serialize()
)
) )
self.assertIsNotNone(self.token_login(token))
def test_mau_limits_not_exceeded(self) -> None: def test_mau_limits_not_exceeded(self) -> None:
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
@ -209,14 +237,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.hs.get_datastores().main.get_monthly_active_count = Mock( self.hs.get_datastores().main.get_monthly_active_count = Mock(
return_value=make_awaitable(self.small_number_of_users) return_value=make_awaitable(self.small_number_of_users)
) )
self.get_success( token = self.get_success(
self.auth_handler.validate_short_term_login_token( self.auth_handler.create_login_token_for_user_id(self.user1)
self._get_macaroon().serialize()
)
) )
self.assertIsNotNone(self.token_login(token))
def _get_macaroon(self) -> pymacaroons.Macaroon:
token = self.macaroon_generator.generate_short_term_login_token(
self.user1, "", duration_in_ms=5000
)
return pymacaroons.Macaroon.deserialize(token)

View File

@ -84,34 +84,6 @@ class MacaroonGeneratorTestCase(TestCase):
) )
self.assertEqual(user_id, "@user:tesths") self.assertEqual(user_id, "@user:tesths")
def test_short_term_login_token(self):
"""Test the generation and verification of short-term login tokens"""
token = self.macaroon_generator.generate_short_term_login_token(
user_id="@user:tesths",
auth_provider_id="oidc",
auth_provider_session_id="sid",
duration_in_ms=2 * 60 * 1000,
)
info = self.macaroon_generator.verify_short_term_login_token(token)
self.assertEqual(info.user_id, "@user:tesths")
self.assertEqual(info.auth_provider_id, "oidc")
self.assertEqual(info.auth_provider_session_id, "sid")
# Raises with another secret key
with self.assertRaises(MacaroonVerificationFailedException):
self.other_macaroon_generator.verify_short_term_login_token(token)
# Wait a minute
self.reactor.pump([60])
# Shouldn't raise
self.macaroon_generator.verify_short_term_login_token(token)
# Wait another minute
self.reactor.pump([60])
# Should raise since it expired
with self.assertRaises(MacaroonVerificationFailedException):
self.macaroon_generator.verify_short_term_login_token(token)
def test_oidc_session_token(self): def test_oidc_session_token(self):
"""Test the generation and verification of OIDC session cookies""" """Test the generation and verification of OIDC session cookies"""
state = "arandomstate" state = "arandomstate"