mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Merge pull request #8858 from matrix-org/rav/sso_uia
UIA: offer only available auth flows
This commit is contained in:
commit
ed5172852a
1
changelog.d/8858.bugfix
Normal file
1
changelog.d/8858.bugfix
Normal file
@ -0,0 +1 @@
|
|||||||
|
Fix a long-standing bug on Synapse instances supporting Single-Sign-On, where users would be prompted to enter their password to confirm certain actions, even though they have not set a password.
|
@ -1462,7 +1462,7 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
hs (synapse.server.HomeServer): homeserver
|
hs (synapse.server.HomeServer): homeserver
|
||||||
resource (TransportLayerServer): resource class to register to
|
resource (JsonResource): resource class to register to
|
||||||
authenticator (Authenticator): authenticator to use
|
authenticator (Authenticator): authenticator to use
|
||||||
ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use
|
ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use
|
||||||
servlet_groups (list[str], optional): List of servlet groups to register.
|
servlet_groups (list[str], optional): List of servlet groups to register.
|
||||||
|
@ -193,9 +193,7 @@ class AuthHandler(BaseHandler):
|
|||||||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||||
self.macaroon_gen = hs.get_macaroon_generator()
|
self.macaroon_gen = hs.get_macaroon_generator()
|
||||||
self._password_enabled = hs.config.password_enabled
|
self._password_enabled = hs.config.password_enabled
|
||||||
self._sso_enabled = (
|
self._password_localdb_enabled = hs.config.password_localdb_enabled
|
||||||
hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
|
|
||||||
)
|
|
||||||
|
|
||||||
# we keep this as a list despite the O(N^2) implication so that we can
|
# we keep this as a list despite the O(N^2) implication so that we can
|
||||||
# keep PASSWORD first and avoid confusing clients which pick the first
|
# keep PASSWORD first and avoid confusing clients which pick the first
|
||||||
@ -205,7 +203,7 @@ class AuthHandler(BaseHandler):
|
|||||||
|
|
||||||
# start out by assuming PASSWORD is enabled; we will remove it later if not.
|
# start out by assuming PASSWORD is enabled; we will remove it later if not.
|
||||||
login_types = []
|
login_types = []
|
||||||
if hs.config.password_localdb_enabled:
|
if self._password_localdb_enabled:
|
||||||
login_types.append(LoginType.PASSWORD)
|
login_types.append(LoginType.PASSWORD)
|
||||||
|
|
||||||
for provider in self.password_providers:
|
for provider in self.password_providers:
|
||||||
@ -219,14 +217,6 @@ class AuthHandler(BaseHandler):
|
|||||||
|
|
||||||
self._supported_login_types = login_types
|
self._supported_login_types = login_types
|
||||||
|
|
||||||
# Login types and UI Auth types have a heavy overlap, but are not
|
|
||||||
# necessarily identical. Login types have SSO (and other login types)
|
|
||||||
# added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
|
|
||||||
ui_auth_types = login_types.copy()
|
|
||||||
if self._sso_enabled:
|
|
||||||
ui_auth_types.append(LoginType.SSO)
|
|
||||||
self._supported_ui_auth_types = ui_auth_types
|
|
||||||
|
|
||||||
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
|
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
|
||||||
# as per `rc_login.failed_attempts`.
|
# as per `rc_login.failed_attempts`.
|
||||||
self._failed_uia_attempts_ratelimiter = Ratelimiter(
|
self._failed_uia_attempts_ratelimiter = Ratelimiter(
|
||||||
@ -339,7 +329,10 @@ class AuthHandler(BaseHandler):
|
|||||||
self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
|
self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
|
||||||
|
|
||||||
# build a list of supported flows
|
# build a list of supported flows
|
||||||
flows = [[login_type] for login_type in self._supported_ui_auth_types]
|
supported_ui_auth_types = await self._get_available_ui_auth_types(
|
||||||
|
requester.user
|
||||||
|
)
|
||||||
|
flows = [[login_type] for login_type in supported_ui_auth_types]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result, params, session_id = await self.check_ui_auth(
|
result, params, session_id = await self.check_ui_auth(
|
||||||
@ -351,7 +344,7 @@ class AuthHandler(BaseHandler):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
# find the completed login type
|
# find the completed login type
|
||||||
for login_type in self._supported_ui_auth_types:
|
for login_type in supported_ui_auth_types:
|
||||||
if login_type not in result:
|
if login_type not in result:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -367,6 +360,41 @@ class AuthHandler(BaseHandler):
|
|||||||
|
|
||||||
return params, session_id
|
return params, session_id
|
||||||
|
|
||||||
|
async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
|
||||||
|
"""Get a list of the authentication types this user can use
|
||||||
|
"""
|
||||||
|
|
||||||
|
ui_auth_types = set()
|
||||||
|
|
||||||
|
# if the HS supports password auth, and the user has a non-null password, we
|
||||||
|
# support password auth
|
||||||
|
if self._password_localdb_enabled and self._password_enabled:
|
||||||
|
lookupres = await self._find_user_id_and_pwd_hash(user.to_string())
|
||||||
|
if lookupres:
|
||||||
|
_, password_hash = lookupres
|
||||||
|
if password_hash:
|
||||||
|
ui_auth_types.add(LoginType.PASSWORD)
|
||||||
|
|
||||||
|
# also allow auth from password providers
|
||||||
|
for provider in self.password_providers:
|
||||||
|
for t in provider.get_supported_login_types().keys():
|
||||||
|
if t == LoginType.PASSWORD and not self._password_enabled:
|
||||||
|
continue
|
||||||
|
ui_auth_types.add(t)
|
||||||
|
|
||||||
|
# if sso is enabled, allow the user to log in via SSO iff they have a mapping
|
||||||
|
# from sso to mxid.
|
||||||
|
if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
|
||||||
|
if await self.store.get_external_ids_by_user(user.to_string()):
|
||||||
|
ui_auth_types.add(LoginType.SSO)
|
||||||
|
|
||||||
|
# Our CAS impl does not (yet) correctly register users in user_external_ids,
|
||||||
|
# so always offer that if it's available.
|
||||||
|
if self.hs.config.cas.cas_enabled:
|
||||||
|
ui_auth_types.add(LoginType.SSO)
|
||||||
|
|
||||||
|
return ui_auth_types
|
||||||
|
|
||||||
def get_enabled_auth_types(self):
|
def get_enabled_auth_types(self):
|
||||||
"""Return the enabled user-interactive authentication types
|
"""Return the enabled user-interactive authentication types
|
||||||
|
|
||||||
@ -1029,7 +1057,7 @@ class AuthHandler(BaseHandler):
|
|||||||
if result:
|
if result:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
|
if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
|
||||||
known_login_type = True
|
known_login_type = True
|
||||||
|
|
||||||
# we've already checked that there is a (valid) password field
|
# we've already checked that there is a (valid) password field
|
||||||
|
@ -463,6 +463,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||||||
desc="get_user_by_external_id",
|
desc="get_user_by_external_id",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def get_external_ids_by_user(self, mxid: str) -> List[Tuple[str, str]]:
|
||||||
|
"""Look up external ids for the given user
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mxid: the MXID to be looked up
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuples of (auth_provider, external_id)
|
||||||
|
"""
|
||||||
|
res = await self.db_pool.simple_select_list(
|
||||||
|
table="user_external_ids",
|
||||||
|
keyvalues={"user_id": mxid},
|
||||||
|
retcols=("auth_provider", "external_id"),
|
||||||
|
desc="get_external_ids_by_user",
|
||||||
|
)
|
||||||
|
return [(r["auth_provider"], r["external_id"]) for r in res]
|
||||||
|
|
||||||
async def count_all_users(self):
|
async def count_all_users(self):
|
||||||
"""Counts all users registered on the homeserver."""
|
"""Counts all users registered on the homeserver."""
|
||||||
|
|
||||||
@ -963,6 +980,14 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
|||||||
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
|
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.db_pool.updates.register_background_index_update(
|
||||||
|
"user_external_ids_user_id_idx",
|
||||||
|
index_name="user_external_ids_user_id_idx",
|
||||||
|
table="user_external_ids",
|
||||||
|
columns=["user_id"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
|
||||||
async def _background_update_set_deactivated_flag(self, progress, batch_size):
|
async def _background_update_set_deactivated_flag(self, progress, batch_size):
|
||||||
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
|
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
|
||||||
for each of them.
|
for each of them.
|
||||||
|
@ -0,0 +1,17 @@
|
|||||||
|
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
|
||||||
|
(5825, 'user_external_ids_user_id_idx', '{}');
|
@ -17,30 +17,15 @@ from urllib.parse import parse_qs, urlparse
|
|||||||
|
|
||||||
from mock import Mock, patch
|
from mock import Mock, patch
|
||||||
|
|
||||||
import attr
|
|
||||||
import pymacaroons
|
import pymacaroons
|
||||||
|
|
||||||
from twisted.python.failure import Failure
|
|
||||||
from twisted.web._newclient import ResponseDone
|
|
||||||
|
|
||||||
from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
|
from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
|
||||||
from synapse.handlers.sso import MappingException
|
from synapse.handlers.sso import MappingException
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
|
from tests.test_utils import FakeResponse
|
||||||
from tests.unittest import HomeserverTestCase, override_config
|
from tests.unittest import HomeserverTestCase, override_config
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
|
||||||
class FakeResponse:
|
|
||||||
code = attr.ib()
|
|
||||||
body = attr.ib()
|
|
||||||
phrase = attr.ib()
|
|
||||||
|
|
||||||
def deliverBody(self, protocol):
|
|
||||||
protocol.dataReceived(self.body)
|
|
||||||
protocol.connectionLost(Failure(ResponseDone()))
|
|
||||||
|
|
||||||
|
|
||||||
# These are a few constants that are used as config parameters in the tests.
|
# These are a few constants that are used as config parameters in the tests.
|
||||||
ISSUER = "https://issuer/"
|
ISSUER = "https://issuer/"
|
||||||
CLIENT_ID = "test-client-id"
|
CLIENT_ID = "test-client-id"
|
||||||
|
@ -15,18 +15,20 @@
|
|||||||
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
from mock import ANY, Mock, call
|
from mock import ANY, Mock, call
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
from twisted.web.resource import Resource
|
||||||
|
|
||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import AuthError
|
||||||
|
from synapse.federation.transport.server import TransportLayerServer
|
||||||
from synapse.types import UserID, create_requester
|
from synapse.types import UserID, create_requester
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.test_utils import make_awaitable
|
from tests.test_utils import make_awaitable
|
||||||
from tests.unittest import override_config
|
from tests.unittest import override_config
|
||||||
from tests.utils import register_federation_servlets
|
|
||||||
|
|
||||||
# Some local users to test with
|
# Some local users to test with
|
||||||
U_APPLE = UserID.from_string("@apple:test")
|
U_APPLE = UserID.from_string("@apple:test")
|
||||||
@ -53,8 +55,6 @@ def _make_edu_transaction_json(edu_type, content):
|
|||||||
|
|
||||||
|
|
||||||
class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
servlets = [register_federation_servlets]
|
|
||||||
|
|
||||||
def make_homeserver(self, reactor, clock):
|
def make_homeserver(self, reactor, clock):
|
||||||
# we mock out the keyring so as to skip the authentication check on the
|
# we mock out the keyring so as to skip the authentication check on the
|
||||||
# federation API call.
|
# federation API call.
|
||||||
@ -77,6 +77,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
|
def create_resource_dict(self) -> Dict[str, Resource]:
|
||||||
|
d = super().create_resource_dict()
|
||||||
|
d["/_matrix/federation"] = TransportLayerServer(self.hs)
|
||||||
|
return d
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor, clock, hs):
|
||||||
mock_notifier = hs.get_notifier()
|
mock_notifier = hs.get_notifier()
|
||||||
self.on_new_event = mock_notifier.on_new_event
|
self.on_new_event = mock_notifier.on_new_event
|
||||||
|
@ -13,7 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, List, Optional, Tuple
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
@ -21,6 +21,7 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
|
|||||||
from twisted.internet.protocol import Protocol
|
from twisted.internet.protocol import Protocol
|
||||||
from twisted.internet.task import LoopingCall
|
from twisted.internet.task import LoopingCall
|
||||||
from twisted.web.http import HTTPChannel
|
from twisted.web.http import HTTPChannel
|
||||||
|
from twisted.web.resource import Resource
|
||||||
|
|
||||||
from synapse.app.generic_worker import (
|
from synapse.app.generic_worker import (
|
||||||
GenericWorkerReplicationHandler,
|
GenericWorkerReplicationHandler,
|
||||||
@ -28,7 +29,7 @@ from synapse.app.generic_worker import (
|
|||||||
)
|
)
|
||||||
from synapse.http.server import JsonResource
|
from synapse.http.server import JsonResource
|
||||||
from synapse.http.site import SynapseRequest, SynapseSite
|
from synapse.http.site import SynapseRequest, SynapseSite
|
||||||
from synapse.replication.http import ReplicationRestResource, streams
|
from synapse.replication.http import ReplicationRestResource
|
||||||
from synapse.replication.tcp.handler import ReplicationCommandHandler
|
from synapse.replication.tcp.handler import ReplicationCommandHandler
|
||||||
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
|
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
|
||||||
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
||||||
@ -54,10 +55,6 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||||||
if not hiredis:
|
if not hiredis:
|
||||||
skip = "Requires hiredis"
|
skip = "Requires hiredis"
|
||||||
|
|
||||||
servlets = [
|
|
||||||
streams.register_servlets,
|
|
||||||
]
|
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor, clock, hs):
|
||||||
# build a replication server
|
# build a replication server
|
||||||
server_factory = ReplicationStreamProtocolFactory(hs)
|
server_factory = ReplicationStreamProtocolFactory(hs)
|
||||||
@ -88,6 +85,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||||||
self._client_transport = None
|
self._client_transport = None
|
||||||
self._server_transport = None
|
self._server_transport = None
|
||||||
|
|
||||||
|
def create_resource_dict(self) -> Dict[str, Resource]:
|
||||||
|
d = super().create_resource_dict()
|
||||||
|
d["/_synapse/replication"] = ReplicationRestResource(self.hs)
|
||||||
|
return d
|
||||||
|
|
||||||
def _get_worker_hs_config(self) -> dict:
|
def _get_worker_hs_config(self) -> dict:
|
||||||
config = self.default_config()
|
config = self.default_config()
|
||||||
config["worker_app"] = "synapse.app.generic_worker"
|
config["worker_app"] = "synapse.app.generic_worker"
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
# Copyright 2014-2016 OpenMarket Ltd
|
# Copyright 2014-2016 OpenMarket Ltd
|
||||||
# Copyright 2017 Vector Creations Ltd
|
# Copyright 2017 Vector Creations Ltd
|
||||||
# Copyright 2018-2019 New Vector Ltd
|
# Copyright 2018-2019 New Vector Ltd
|
||||||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
# Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@ -17,17 +17,23 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
import time
|
import time
|
||||||
|
import urllib.parse
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from mock import patch
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
from twisted.web.server import Site
|
from twisted.web.server import Site
|
||||||
|
|
||||||
from synapse.api.constants import Membership
|
from synapse.api.constants import Membership
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
from tests.server import FakeSite, make_request
|
from tests.server import FakeSite, make_request
|
||||||
|
from tests.test_utils import FakeResponse
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
@attr.s
|
||||||
@ -344,3 +350,111 @@ class RestHelper:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return channel.json_body
|
return channel.json_body
|
||||||
|
|
||||||
|
def login_via_oidc(self, remote_user_id: str) -> JsonDict:
|
||||||
|
"""Log in (as a new user) via OIDC
|
||||||
|
|
||||||
|
Returns the result of the final token login.
|
||||||
|
|
||||||
|
Requires that "oidc_config" in the homeserver config be set appropriately
|
||||||
|
(TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
|
||||||
|
"public_base_url".
|
||||||
|
|
||||||
|
Also requires the login servlet and the OIDC callback resource to be mounted at
|
||||||
|
the normal places.
|
||||||
|
"""
|
||||||
|
client_redirect_url = "https://x"
|
||||||
|
|
||||||
|
# first hit the redirect url (which will issue a cookie and state)
|
||||||
|
_, channel = make_request(
|
||||||
|
self.hs.get_reactor(),
|
||||||
|
self.site,
|
||||||
|
"GET",
|
||||||
|
"/login/sso/redirect?redirectUrl=" + client_redirect_url,
|
||||||
|
)
|
||||||
|
# that will redirect to the OIDC IdP, but we skip that and go straight
|
||||||
|
# back to synapse's OIDC callback resource. However, we do need the "state"
|
||||||
|
# param that synapse passes to the IdP via query params, and the cookie that
|
||||||
|
# synapse passes to the client.
|
||||||
|
assert channel.code == 302
|
||||||
|
oauth_uri = channel.headers.getRawHeaders("Location")[0]
|
||||||
|
params = urllib.parse.parse_qs(urllib.parse.urlparse(oauth_uri).query)
|
||||||
|
redirect_uri = "%s?%s" % (
|
||||||
|
urllib.parse.urlparse(params["redirect_uri"][0]).path,
|
||||||
|
urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
|
||||||
|
)
|
||||||
|
cookies = {}
|
||||||
|
for h in channel.headers.getRawHeaders("Set-Cookie"):
|
||||||
|
parts = h.split(";")
|
||||||
|
k, v = parts[0].split("=", maxsplit=1)
|
||||||
|
cookies[k] = v
|
||||||
|
|
||||||
|
# before we hit the callback uri, stub out some methods in the http client so
|
||||||
|
# that we don't have to handle full HTTPS requests.
|
||||||
|
|
||||||
|
# (expected url, json response) pairs, in the order we expect them.
|
||||||
|
expected_requests = [
|
||||||
|
# first we get a hit to the token endpoint, which we tell to return
|
||||||
|
# a dummy OIDC access token
|
||||||
|
("https://issuer.test/token", {"access_token": "TEST"}),
|
||||||
|
# and then one to the user_info endpoint, which returns our remote user id.
|
||||||
|
("https://issuer.test/userinfo", {"sub": remote_user_id}),
|
||||||
|
]
|
||||||
|
|
||||||
|
async def mock_req(method: str, uri: str, data=None, headers=None):
|
||||||
|
(expected_uri, resp_obj) = expected_requests.pop(0)
|
||||||
|
assert uri == expected_uri
|
||||||
|
resp = FakeResponse(
|
||||||
|
code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"),
|
||||||
|
)
|
||||||
|
return resp
|
||||||
|
|
||||||
|
with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
|
||||||
|
# now hit the callback URI with the right params and a made-up code
|
||||||
|
_, channel = make_request(
|
||||||
|
self.hs.get_reactor(),
|
||||||
|
self.site,
|
||||||
|
"GET",
|
||||||
|
redirect_uri,
|
||||||
|
custom_headers=[
|
||||||
|
("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# expect a confirmation page
|
||||||
|
assert channel.code == 200
|
||||||
|
|
||||||
|
# fish the matrix login token out of the body of the confirmation page
|
||||||
|
m = re.search(
|
||||||
|
'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
|
||||||
|
channel.result["body"].decode("utf-8"),
|
||||||
|
)
|
||||||
|
assert m
|
||||||
|
login_token = m.group(1)
|
||||||
|
|
||||||
|
# finally, submit the matrix login token to the login API, which gives us our
|
||||||
|
# matrix access token and device id.
|
||||||
|
_, channel = make_request(
|
||||||
|
self.hs.get_reactor(),
|
||||||
|
self.site,
|
||||||
|
"POST",
|
||||||
|
"/login",
|
||||||
|
content={"type": "m.login.token", "token": login_token},
|
||||||
|
)
|
||||||
|
assert channel.code == 200
|
||||||
|
return channel.json_body
|
||||||
|
|
||||||
|
|
||||||
|
# an 'oidc_config' suitable for login_with_oidc.
|
||||||
|
TEST_OIDC_CONFIG = {
|
||||||
|
"enabled": True,
|
||||||
|
"discover": False,
|
||||||
|
"issuer": "https://issuer.test",
|
||||||
|
"client_id": "test-client-id",
|
||||||
|
"client_secret": "test-client-secret",
|
||||||
|
"scopes": ["profile"],
|
||||||
|
"authorization_endpoint": "https://z",
|
||||||
|
"token_endpoint": "https://issuer.test/token",
|
||||||
|
"userinfo_endpoint": "https://issuer.test/userinfo",
|
||||||
|
"user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
|
||||||
|
}
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
from twisted.internet.defer import succeed
|
from twisted.internet.defer import succeed
|
||||||
@ -22,9 +23,11 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
|
|||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.rest.client.v1 import login
|
from synapse.rest.client.v1 import login
|
||||||
from synapse.rest.client.v2_alpha import auth, devices, register
|
from synapse.rest.client.v2_alpha import auth, devices, register
|
||||||
from synapse.types import JsonDict
|
from synapse.rest.oidc import OIDCResource
|
||||||
|
from synapse.types import JsonDict, UserID
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
|
||||||
from tests.server import FakeChannel
|
from tests.server import FakeChannel
|
||||||
|
|
||||||
|
|
||||||
@ -156,27 +159,45 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
|||||||
register.register_servlets,
|
register.register_servlets,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def default_config(self):
|
||||||
|
config = super().default_config()
|
||||||
|
|
||||||
|
# we enable OIDC as a way of testing SSO flows
|
||||||
|
oidc_config = {}
|
||||||
|
oidc_config.update(TEST_OIDC_CONFIG)
|
||||||
|
oidc_config["allow_existing_users"] = True
|
||||||
|
|
||||||
|
config["oidc_config"] = oidc_config
|
||||||
|
config["public_baseurl"] = "https://synapse.test"
|
||||||
|
return config
|
||||||
|
|
||||||
|
def create_resource_dict(self):
|
||||||
|
resource_dict = super().create_resource_dict()
|
||||||
|
# mount the OIDC resource at /_synapse/oidc
|
||||||
|
resource_dict["/_synapse/oidc"] = OIDCResource(self.hs)
|
||||||
|
return resource_dict
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor, clock, hs):
|
||||||
self.user_pass = "pass"
|
self.user_pass = "pass"
|
||||||
self.user = self.register_user("test", self.user_pass)
|
self.user = self.register_user("test", self.user_pass)
|
||||||
self.user_tok = self.login("test", self.user_pass)
|
self.user_tok = self.login("test", self.user_pass)
|
||||||
|
|
||||||
def get_device_ids(self) -> List[str]:
|
def get_device_ids(self, access_token: str) -> List[str]:
|
||||||
# Get the list of devices so one can be deleted.
|
# Get the list of devices so one can be deleted.
|
||||||
request, channel = self.make_request(
|
_, channel = self.make_request("GET", "devices", access_token=access_token,)
|
||||||
"GET", "devices", access_token=self.user_tok,
|
self.assertEqual(channel.code, 200)
|
||||||
) # type: SynapseRequest, FakeChannel
|
|
||||||
|
|
||||||
# Get the ID of the device.
|
|
||||||
self.assertEqual(request.code, 200)
|
|
||||||
return [d["device_id"] for d in channel.json_body["devices"]]
|
return [d["device_id"] for d in channel.json_body["devices"]]
|
||||||
|
|
||||||
def delete_device(
|
def delete_device(
|
||||||
self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b""
|
self,
|
||||||
|
access_token: str,
|
||||||
|
device: str,
|
||||||
|
expected_response: int,
|
||||||
|
body: Union[bytes, JsonDict] = b"",
|
||||||
) -> FakeChannel:
|
) -> FakeChannel:
|
||||||
"""Delete an individual device."""
|
"""Delete an individual device."""
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(
|
||||||
"DELETE", "devices/" + device, body, access_token=self.user_tok
|
"DELETE", "devices/" + device, body, access_token=access_token,
|
||||||
) # type: SynapseRequest, FakeChannel
|
) # type: SynapseRequest, FakeChannel
|
||||||
|
|
||||||
# Ensure the response is sane.
|
# Ensure the response is sane.
|
||||||
@ -201,11 +222,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
|||||||
"""
|
"""
|
||||||
Test user interactive authentication outside of registration.
|
Test user interactive authentication outside of registration.
|
||||||
"""
|
"""
|
||||||
device_id = self.get_device_ids()[0]
|
device_id = self.get_device_ids(self.user_tok)[0]
|
||||||
|
|
||||||
# Attempt to delete this device.
|
# Attempt to delete this device.
|
||||||
# Returns a 401 as per the spec
|
# Returns a 401 as per the spec
|
||||||
channel = self.delete_device(device_id, 401)
|
channel = self.delete_device(self.user_tok, device_id, 401)
|
||||||
|
|
||||||
# Grab the session
|
# Grab the session
|
||||||
session = channel.json_body["session"]
|
session = channel.json_body["session"]
|
||||||
@ -214,6 +235,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
# Make another request providing the UI auth flow.
|
# Make another request providing the UI auth flow.
|
||||||
self.delete_device(
|
self.delete_device(
|
||||||
|
self.user_tok,
|
||||||
device_id,
|
device_id,
|
||||||
200,
|
200,
|
||||||
{
|
{
|
||||||
@ -233,12 +255,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
|||||||
UIA - check that still works.
|
UIA - check that still works.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
device_id = self.get_device_ids()[0]
|
device_id = self.get_device_ids(self.user_tok)[0]
|
||||||
channel = self.delete_device(device_id, 401)
|
channel = self.delete_device(self.user_tok, device_id, 401)
|
||||||
session = channel.json_body["session"]
|
session = channel.json_body["session"]
|
||||||
|
|
||||||
# Make another request providing the UI auth flow.
|
# Make another request providing the UI auth flow.
|
||||||
self.delete_device(
|
self.delete_device(
|
||||||
|
self.user_tok,
|
||||||
device_id,
|
device_id,
|
||||||
200,
|
200,
|
||||||
{
|
{
|
||||||
@ -264,7 +287,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
|||||||
# Create a second login.
|
# Create a second login.
|
||||||
self.login("test", self.user_pass)
|
self.login("test", self.user_pass)
|
||||||
|
|
||||||
device_ids = self.get_device_ids()
|
device_ids = self.get_device_ids(self.user_tok)
|
||||||
self.assertEqual(len(device_ids), 2)
|
self.assertEqual(len(device_ids), 2)
|
||||||
|
|
||||||
# Attempt to delete the first device.
|
# Attempt to delete the first device.
|
||||||
@ -298,12 +321,12 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
|||||||
# Create a second login.
|
# Create a second login.
|
||||||
self.login("test", self.user_pass)
|
self.login("test", self.user_pass)
|
||||||
|
|
||||||
device_ids = self.get_device_ids()
|
device_ids = self.get_device_ids(self.user_tok)
|
||||||
self.assertEqual(len(device_ids), 2)
|
self.assertEqual(len(device_ids), 2)
|
||||||
|
|
||||||
# Attempt to delete the first device.
|
# Attempt to delete the first device.
|
||||||
# Returns a 401 as per the spec
|
# Returns a 401 as per the spec
|
||||||
channel = self.delete_device(device_ids[0], 401)
|
channel = self.delete_device(self.user_tok, device_ids[0], 401)
|
||||||
|
|
||||||
# Grab the session
|
# Grab the session
|
||||||
session = channel.json_body["session"]
|
session = channel.json_body["session"]
|
||||||
@ -313,6 +336,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
|||||||
# Make another request providing the UI auth flow, but try to delete the
|
# Make another request providing the UI auth flow, but try to delete the
|
||||||
# second device. This results in an error.
|
# second device. This results in an error.
|
||||||
self.delete_device(
|
self.delete_device(
|
||||||
|
self.user_tok,
|
||||||
device_ids[1],
|
device_ids[1],
|
||||||
403,
|
403,
|
||||||
{
|
{
|
||||||
@ -324,3 +348,39 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_does_not_offer_password_for_sso_user(self):
|
||||||
|
login_resp = self.helper.login_via_oidc("username")
|
||||||
|
user_tok = login_resp["access_token"]
|
||||||
|
device_id = login_resp["device_id"]
|
||||||
|
|
||||||
|
# now call the device deletion API: we should get the option to auth with SSO
|
||||||
|
# and not password.
|
||||||
|
channel = self.delete_device(user_tok, device_id, 401)
|
||||||
|
|
||||||
|
flows = channel.json_body["flows"]
|
||||||
|
self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
|
||||||
|
|
||||||
|
def test_does_not_offer_sso_for_password_user(self):
|
||||||
|
# now call the device deletion API: we should get the option to auth with SSO
|
||||||
|
# and not password.
|
||||||
|
device_ids = self.get_device_ids(self.user_tok)
|
||||||
|
channel = self.delete_device(self.user_tok, device_ids[0], 401)
|
||||||
|
|
||||||
|
flows = channel.json_body["flows"]
|
||||||
|
self.assertEqual(flows, [{"stages": ["m.login.password"]}])
|
||||||
|
|
||||||
|
def test_offers_both_flows_for_upgraded_user(self):
|
||||||
|
"""A user that had a password and then logged in with SSO should get both flows
|
||||||
|
"""
|
||||||
|
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
|
||||||
|
self.assertEqual(login_resp["user_id"], self.user)
|
||||||
|
|
||||||
|
device_ids = self.get_device_ids(self.user_tok)
|
||||||
|
channel = self.delete_device(self.user_tok, device_ids[0], 401)
|
||||||
|
|
||||||
|
flows = channel.json_body["flows"]
|
||||||
|
# we have no particular expectations of ordering here
|
||||||
|
self.assertIn({"stages": ["m.login.password"]}, flows)
|
||||||
|
self.assertIn({"stages": ["m.login.sso"]}, flows)
|
||||||
|
self.assertEqual(len(flows), 2)
|
||||||
|
@ -216,8 +216,9 @@ def make_request(
|
|||||||
and not path.startswith(b"/_matrix")
|
and not path.startswith(b"/_matrix")
|
||||||
and not path.startswith(b"/_synapse")
|
and not path.startswith(b"/_synapse")
|
||||||
):
|
):
|
||||||
|
if path.startswith(b"/"):
|
||||||
|
path = path[1:]
|
||||||
path = b"/_matrix/client/r0/" + path
|
path = b"/_matrix/client/r0/" + path
|
||||||
path = path.replace(b"//", b"/")
|
|
||||||
|
|
||||||
if not path.startswith(b"/"):
|
if not path.startswith(b"/"):
|
||||||
path = b"/" + path
|
path = b"/" + path
|
||||||
@ -258,6 +259,7 @@ def make_request(
|
|||||||
for k, v in custom_headers:
|
for k, v in custom_headers:
|
||||||
req.requestHeaders.addRawHeader(k, v)
|
req.requestHeaders.addRawHeader(k, v)
|
||||||
|
|
||||||
|
req.parseCookies()
|
||||||
req.requestReceived(method, path, b"1.1")
|
req.requestReceived(method, path, b"1.1")
|
||||||
|
|
||||||
if await_result:
|
if await_result:
|
||||||
|
@ -22,6 +22,11 @@ import warnings
|
|||||||
from asyncio import Future
|
from asyncio import Future
|
||||||
from typing import Any, Awaitable, Callable, TypeVar
|
from typing import Any, Awaitable, Callable, TypeVar
|
||||||
|
|
||||||
|
import attr
|
||||||
|
|
||||||
|
from twisted.python.failure import Failure
|
||||||
|
from twisted.web.client import ResponseDone
|
||||||
|
|
||||||
TV = TypeVar("TV")
|
TV = TypeVar("TV")
|
||||||
|
|
||||||
|
|
||||||
@ -80,3 +85,25 @@ def setup_awaitable_errors() -> Callable[[], None]:
|
|||||||
sys.unraisablehook = unraisablehook # type: ignore
|
sys.unraisablehook = unraisablehook # type: ignore
|
||||||
|
|
||||||
return cleanup
|
return cleanup
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s
|
||||||
|
class FakeResponse:
|
||||||
|
"""A fake twisted.web.IResponse object
|
||||||
|
|
||||||
|
there is a similar class at treq.test.test_response, but it lacks a `phrase`
|
||||||
|
attribute, and didn't support deliverBody until recently.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# HTTP response code
|
||||||
|
code = attr.ib(type=int)
|
||||||
|
|
||||||
|
# HTTP response phrase (eg b'OK' for a 200)
|
||||||
|
phrase = attr.ib(type=bytes)
|
||||||
|
|
||||||
|
# body of the response
|
||||||
|
body = attr.ib(type=bytes)
|
||||||
|
|
||||||
|
def deliverBody(self, protocol):
|
||||||
|
protocol.dataReceived(self.body)
|
||||||
|
protocol.connectionLost(Failure(ResponseDone()))
|
||||||
|
@ -20,7 +20,7 @@ import hmac
|
|||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Optional, Tuple, Type, TypeVar, Union, overload
|
from typing import Dict, Optional, Tuple, Type, TypeVar, Union, overload
|
||||||
|
|
||||||
from mock import Mock, patch
|
from mock import Mock, patch
|
||||||
|
|
||||||
@ -46,6 +46,7 @@ from synapse.logging.context import (
|
|||||||
)
|
)
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import UserID, create_requester
|
from synapse.types import UserID, create_requester
|
||||||
|
from synapse.util.httpresourcetree import create_resource_tree
|
||||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||||
|
|
||||||
from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
|
from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
|
||||||
@ -320,15 +321,28 @@ class HomeserverTestCase(TestCase):
|
|||||||
"""
|
"""
|
||||||
Create a the root resource for the test server.
|
Create a the root resource for the test server.
|
||||||
|
|
||||||
The default implementation creates a JsonResource and calls each function in
|
The default calls `self.create_resource_dict` and builds the resultant dict
|
||||||
`servlets` to register servletes against it
|
into a tree.
|
||||||
"""
|
"""
|
||||||
resource = JsonResource(self.hs)
|
root_resource = Resource()
|
||||||
|
create_resource_tree(self.create_resource_dict(), root_resource)
|
||||||
|
return root_resource
|
||||||
|
|
||||||
|
def create_resource_dict(self) -> Dict[str, Resource]:
|
||||||
|
"""Create a resource tree for the test server
|
||||||
|
|
||||||
|
A resource tree is a mapping from path to twisted.web.resource.
|
||||||
|
|
||||||
|
The default implementation creates a JsonResource and calls each function in
|
||||||
|
`servlets` to register servlets against it.
|
||||||
|
"""
|
||||||
|
servlet_resource = JsonResource(self.hs)
|
||||||
for servlet in self.servlets:
|
for servlet in self.servlets:
|
||||||
servlet(self.hs, resource)
|
servlet(self.hs, servlet_resource)
|
||||||
|
return {
|
||||||
return resource
|
"/_matrix/client": servlet_resource,
|
||||||
|
"/_synapse/admin": servlet_resource,
|
||||||
|
}
|
||||||
|
|
||||||
def default_config(self):
|
def default_config(self):
|
||||||
"""
|
"""
|
||||||
@ -691,13 +705,29 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
|
|||||||
A federating homeserver that authenticates incoming requests as `other.example.com`.
|
A federating homeserver that authenticates incoming requests as `other.example.com`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def prepare(self, reactor, clock, homeserver):
|
def create_resource_dict(self) -> Dict[str, Resource]:
|
||||||
|
d = super().create_resource_dict()
|
||||||
|
d["/_matrix/federation"] = TestTransportLayerServer(self.hs)
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
class TestTransportLayerServer(JsonResource):
|
||||||
|
"""A test implementation of TransportLayerServer
|
||||||
|
|
||||||
|
authenticates incoming requests as `other.example.com`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super().__init__(hs)
|
||||||
|
|
||||||
class Authenticator:
|
class Authenticator:
|
||||||
def authenticate_request(self, request, content):
|
def authenticate_request(self, request, content):
|
||||||
return succeed("other.example.com")
|
return succeed("other.example.com")
|
||||||
|
|
||||||
|
authenticator = Authenticator()
|
||||||
|
|
||||||
ratelimiter = FederationRateLimiter(
|
ratelimiter = FederationRateLimiter(
|
||||||
clock,
|
hs.get_clock(),
|
||||||
FederationRateLimitConfig(
|
FederationRateLimitConfig(
|
||||||
window_size=1,
|
window_size=1,
|
||||||
sleep_limit=1,
|
sleep_limit=1,
|
||||||
@ -706,11 +736,8 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
|
|||||||
concurrent_requests=1000,
|
concurrent_requests=1000,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
federation_server.register_servlets(
|
|
||||||
homeserver, self.resource, Authenticator(), ratelimiter
|
|
||||||
)
|
|
||||||
|
|
||||||
return super().prepare(reactor, clock, homeserver)
|
federation_server.register_servlets(hs, self, authenticator, ratelimiter)
|
||||||
|
|
||||||
|
|
||||||
def override_config(extra_config):
|
def override_config(extra_config):
|
||||||
|
Loading…
Reference in New Issue
Block a user