Mirror of https://git.anonymousland.org/anonymousland/synapse.git (synced 2024-10-01 11:49:51 -04:00)
Merge remote-tracking branch 'origin/develop' into rav/remove_unused_mocks
Commit 269ba1bc84

changelog.d/8858.bugfix (new file, 1 line)
@@ -0,0 +1 @@
+Fix a long-standing bug on Synapse instances supporting Single-Sign-On, where users would be prompted to enter their password to confirm certain actions, even though they have not set a password.
changelog.d/8864.misc (new file, 1 line)
@@ -0,0 +1 @@
+Remove unused `FakeResponse` class from unit tests.
@@ -1462,7 +1462,7 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N

     Args:
         hs (synapse.server.HomeServer): homeserver
-        resource (TransportLayerServer): resource class to register to
+        resource (JsonResource): resource class to register to
         authenticator (Authenticator): authenticator to use
         ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use
         servlet_groups (list[str], optional): List of servlet groups to register.
@@ -193,9 +193,7 @@ class AuthHandler(BaseHandler):
         self.hs = hs  # FIXME better possibility to access registrationHandler later?
         self.macaroon_gen = hs.get_macaroon_generator()
         self._password_enabled = hs.config.password_enabled
-        self._sso_enabled = (
-            hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
-        )
+        self._password_localdb_enabled = hs.config.password_localdb_enabled

         # we keep this as a list despite the O(N^2) implication so that we can
         # keep PASSWORD first and avoid confusing clients which pick the first
@@ -205,7 +203,7 @@ class AuthHandler(BaseHandler):

         # start out by assuming PASSWORD is enabled; we will remove it later if not.
         login_types = []
-        if hs.config.password_localdb_enabled:
+        if self._password_localdb_enabled:
             login_types.append(LoginType.PASSWORD)

         for provider in self.password_providers:
@@ -219,14 +217,6 @@ class AuthHandler(BaseHandler):

         self._supported_login_types = login_types

-        # Login types and UI Auth types have a heavy overlap, but are not
-        # necessarily identical. Login types have SSO (and other login types)
-        # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
-        ui_auth_types = login_types.copy()
-        if self._sso_enabled:
-            ui_auth_types.append(LoginType.SSO)
-        self._supported_ui_auth_types = ui_auth_types
-
         # Ratelimiter for failed auth during UIA. Uses same ratelimit config
         # as per `rc_login.failed_attempts`.
         self._failed_uia_attempts_ratelimiter = Ratelimiter(
@@ -339,7 +329,10 @@ class AuthHandler(BaseHandler):
         self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)

         # build a list of supported flows
-        flows = [[login_type] for login_type in self._supported_ui_auth_types]
+        supported_ui_auth_types = await self._get_available_ui_auth_types(
+            requester.user
+        )
+        flows = [[login_type] for login_type in supported_ui_auth_types]

         try:
             result, params, session_id = await self.check_ui_auth(
@@ -351,7 +344,7 @@ class AuthHandler(BaseHandler):
             raise

         # find the completed login type
-        for login_type in self._supported_ui_auth_types:
+        for login_type in supported_ui_auth_types:
             if login_type not in result:
                 continue

@@ -367,6 +360,41 @@ class AuthHandler(BaseHandler):

         return params, session_id

+    async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
+        """Get a list of the authentication types this user can use
+        """
+
+        ui_auth_types = set()
+
+        # if the HS supports password auth, and the user has a non-null password, we
+        # support password auth
+        if self._password_localdb_enabled and self._password_enabled:
+            lookupres = await self._find_user_id_and_pwd_hash(user.to_string())
+            if lookupres:
+                _, password_hash = lookupres
+                if password_hash:
+                    ui_auth_types.add(LoginType.PASSWORD)
+
+        # also allow auth from password providers
+        for provider in self.password_providers:
+            for t in provider.get_supported_login_types().keys():
+                if t == LoginType.PASSWORD and not self._password_enabled:
+                    continue
+                ui_auth_types.add(t)
+
+        # if sso is enabled, allow the user to log in via SSO iff they have a mapping
+        # from sso to mxid.
+        if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
+            if await self.store.get_external_ids_by_user(user.to_string()):
+                ui_auth_types.add(LoginType.SSO)
+
+        # Our CAS impl does not (yet) correctly register users in user_external_ids,
+        # so always offer that if it's available.
+        if self.hs.config.cas.cas_enabled:
+            ui_auth_types.add(LoginType.SSO)
+
+        return ui_auth_types
+
     def get_enabled_auth_types(self):
         """Return the enabled user-interactive authentication types

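The hunk above replaces the homeserver-wide `_supported_ui_auth_types` list with a per-user lookup. As a rough, self-contained sketch of that selection logic (simplified stand-in names, not the actual Synapse code), the flow set is assembled like this:

from typing import Iterable, List, Optional

# Stand-ins for synapse.api.constants.LoginType values (assumed for illustration).
PASSWORD = "m.login.password"
SSO = "m.login.sso"


def available_ui_auth_types(
    password_localdb_enabled: bool,
    password_enabled: bool,
    password_hash: Optional[str],
    provider_types: Iterable[str],
    sso_enabled: bool,
    has_external_ids: bool,
) -> List[List[str]]:
    """Return the UIA flows ([[stage], ...]) this particular user can complete."""
    types = set()

    # password auth is only offered if the user actually has a password hash stored
    if password_localdb_enabled and password_enabled and password_hash:
        types.add(PASSWORD)

    # login types from password providers are offered too, unless PASSWORD is disabled
    for t in provider_types:
        if t == PASSWORD and not password_enabled:
            continue
        types.add(t)

    # SSO is only offered if the user has a mapping in user_external_ids
    if sso_enabled and has_external_ids:
        types.add(SSO)

    return [[t] for t in sorted(types)]


# An SSO-only user with no password hash is offered only the SSO flow:
assert available_ui_auth_types(True, True, None, [], True, True) == [[SSO]]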
@@ -1029,7 +1057,7 @@ class AuthHandler(BaseHandler):
             if result:
                 return result

-        if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
+        if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
             known_login_type = True

             # we've already checked that there is a (valid) password field
@@ -463,6 +463,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             desc="get_user_by_external_id",
         )

+    async def get_external_ids_by_user(self, mxid: str) -> List[Tuple[str, str]]:
+        """Look up external ids for the given user
+
+        Args:
+            mxid: the MXID to be looked up
+
+        Returns:
+            Tuples of (auth_provider, external_id)
+        """
+        res = await self.db_pool.simple_select_list(
+            table="user_external_ids",
+            keyvalues={"user_id": mxid},
+            retcols=("auth_provider", "external_id"),
+            desc="get_external_ids_by_user",
+        )
+        return [(r["auth_provider"], r["external_id"]) for r in res]
+
     async def count_all_users(self):
         """Counts all users registered on the homeserver."""

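For context, a hedged sketch of how a caller consumes the new store method (assumes a HomeServer object `hs` inside an async function; the helper name is illustrative, not part of the diff):

async def user_has_sso_mapping(hs, user_id: str) -> bool:
    # get_external_ids_by_user returns e.g. [("oidc", "remote_id"), ("saml", "uid=alice")]
    external_ids = await hs.get_datastore().get_external_ids_by_user(user_id)
    return bool(external_ids)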
@@ -963,6 +980,14 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
             "users_set_deactivated_flag", self._background_update_set_deactivated_flag
         )

+        self.db_pool.updates.register_background_index_update(
+            "user_external_ids_user_id_idx",
+            index_name="user_external_ids_user_id_idx",
+            table="user_external_ids",
+            columns=["user_id"],
+            unique=False,
+        )
+
     async def _background_update_set_deactivated_flag(self, progress, batch_size):
         """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
         for each of them.
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (5825, 'user_external_ids_user_id_idx', '{}');
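This delta only queues the `user_external_ids_user_id_idx` background update; the index itself is built later by the `register_background_index_update` call added above. As an inference (not the literal code path), the statement it eventually runs should be roughly:

# Illustrative only: approximate SQL issued once the queued background update runs
# (Postgres may build it concurrently; unique=False above means a plain, non-unique index).
EFFECTIVE_INDEX_SQL = (
    "CREATE INDEX user_external_ids_user_id_idx"
    " ON user_external_ids (user_id)"
)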
@@ -17,30 +17,15 @@ from urllib.parse import parse_qs, urlparse

 from mock import Mock, patch

-import attr
 import pymacaroons

-from twisted.python.failure import Failure
-from twisted.web._newclient import ResponseDone
-
 from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
 from synapse.handlers.sso import MappingException
 from synapse.types import UserID

+from tests.test_utils import FakeResponse
 from tests.unittest import HomeserverTestCase, override_config


-@attr.s
-class FakeResponse:
-    code = attr.ib()
-    body = attr.ib()
-    phrase = attr.ib()
-
-    def deliverBody(self, protocol):
-        protocol.dataReceived(self.body)
-        protocol.connectionLost(Failure(ResponseDone()))
-
-
 # These are a few constants that are used as config parameters in the tests.
 ISSUER = "https://issuer/"
 CLIENT_ID = "test-client-id"
@@ -15,15 +15,16 @@


 import json
+from typing import Dict

 from mock import ANY, Mock, call

 from twisted.internet import defer
+from twisted.web.resource import Resource

 from synapse.api.errors import AuthError
-from synapse.federation.transport import server as federation_server
+from synapse.federation.transport.server import TransportLayerServer
 from synapse.types import UserID, create_requester
-from synapse.util.ratelimitutils import FederationRateLimiter

 from tests import unittest
 from tests.test_utils import make_awaitable
@@ -53,20 +54,7 @@ def _make_edu_transaction_json(edu_type, content):
     return json.dumps(_expect_edu_transaction(edu_type, content)).encode("utf8")


-def register_federation_servlets(hs, resource):
-    federation_server.register_servlets(
-        hs,
-        resource=resource,
-        authenticator=federation_server.Authenticator(hs),
-        ratelimiter=FederationRateLimiter(
-            hs.get_clock(), config=hs.config.rc_federation
-        ),
-    )
-
-
 class TypingNotificationsTestCase(unittest.HomeserverTestCase):
-    servlets = [register_federation_servlets]
-
     def make_homeserver(self, reactor, clock):
         # we mock out the keyring so as to skip the authentication check on the
         # federation API call.
@@ -89,6 +77,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):

         return hs

+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d["/_matrix/federation"] = TransportLayerServer(self.hs)
+        return d
+
     def prepare(self, reactor, clock, hs):
         mock_notifier = hs.get_notifier()
         self.on_new_event = mock_notifier.on_new_event
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import attr

@@ -21,6 +21,7 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
 from twisted.internet.protocol import Protocol
 from twisted.internet.task import LoopingCall
 from twisted.web.http import HTTPChannel
+from twisted.web.resource import Resource

 from synapse.app.generic_worker import (
     GenericWorkerReplicationHandler,
@@ -28,7 +29,7 @@ from synapse.app.generic_worker import (
 )
 from synapse.http.server import JsonResource
 from synapse.http.site import SynapseRequest, SynapseSite
-from synapse.replication.http import ReplicationRestResource, streams
+from synapse.replication.http import ReplicationRestResource
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -54,10 +55,6 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
     if not hiredis:
         skip = "Requires hiredis"

-    servlets = [
-        streams.register_servlets,
-    ]
-
     def prepare(self, reactor, clock, hs):
         # build a replication server
         server_factory = ReplicationStreamProtocolFactory(hs)
@@ -88,6 +85,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self._client_transport = None
         self._server_transport = None

+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d["/_synapse/replication"] = ReplicationRestResource(self.hs)
+        return d
+
     def _get_worker_hs_config(self) -> dict:
         config = self.default_config()
         config["worker_app"] = "synapse.app.generic_worker"
@@ -2,7 +2,7 @@
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2017 Vector Creations Ltd
 # Copyright 2018-2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,17 +17,23 @@
 # limitations under the License.

+import json
+import re
 import time
 import urllib.parse
 from typing import Any, Dict, Optional

+from mock import patch
+
 import attr

 from twisted.web.resource import Resource
 from twisted.web.server import Site

 from synapse.api.constants import Membership
+from synapse.types import JsonDict

 from tests.server import FakeSite, make_request
+from tests.test_utils import FakeResponse


 @attr.s
@@ -344,3 +350,111 @@ class RestHelper:
         )

         return channel.json_body

+    def login_via_oidc(self, remote_user_id: str) -> JsonDict:
+        """Log in (as a new user) via OIDC
+
+        Returns the result of the final token login.
+
+        Requires that "oidc_config" in the homeserver config be set appropriately
+        (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+        "public_base_url".
+
+        Also requires the login servlet and the OIDC callback resource to be mounted at
+        the normal places.
+        """
+        client_redirect_url = "https://x"
+
+        # first hit the redirect url (which will issue a cookie and state)
+        _, channel = make_request(
+            self.hs.get_reactor(),
+            self.site,
+            "GET",
+            "/login/sso/redirect?redirectUrl=" + client_redirect_url,
+        )
+        # that will redirect to the OIDC IdP, but we skip that and go straight
+        # back to synapse's OIDC callback resource. However, we do need the "state"
+        # param that synapse passes to the IdP via query params, and the cookie that
+        # synapse passes to the client.
+        assert channel.code == 302
+        oauth_uri = channel.headers.getRawHeaders("Location")[0]
+        params = urllib.parse.parse_qs(urllib.parse.urlparse(oauth_uri).query)
+        redirect_uri = "%s?%s" % (
+            urllib.parse.urlparse(params["redirect_uri"][0]).path,
+            urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
+        )
+        cookies = {}
+        for h in channel.headers.getRawHeaders("Set-Cookie"):
+            parts = h.split(";")
+            k, v = parts[0].split("=", maxsplit=1)
+            cookies[k] = v
+
+        # before we hit the callback uri, stub out some methods in the http client so
+        # that we don't have to handle full HTTPS requests.
+
+        # (expected url, json response) pairs, in the order we expect them.
+        expected_requests = [
+            # first we get a hit to the token endpoint, which we tell to return
+            # a dummy OIDC access token
+            ("https://issuer.test/token", {"access_token": "TEST"}),
+            # and then one to the user_info endpoint, which returns our remote user id.
+            ("https://issuer.test/userinfo", {"sub": remote_user_id}),
+        ]
+
+        async def mock_req(method: str, uri: str, data=None, headers=None):
+            (expected_uri, resp_obj) = expected_requests.pop(0)
+            assert uri == expected_uri
+            resp = FakeResponse(
+                code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"),
+            )
+            return resp
+
+        with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
+            # now hit the callback URI with the right params and a made-up code
+            _, channel = make_request(
+                self.hs.get_reactor(),
+                self.site,
+                "GET",
+                redirect_uri,
+                custom_headers=[
+                    ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
+                ],
+            )

+            # expect a confirmation page
+            assert channel.code == 200
+
+            # fish the matrix login token out of the body of the confirmation page
+            m = re.search(
+                'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
+                channel.result["body"].decode("utf-8"),
+            )
+            assert m
+            login_token = m.group(1)
+
+            # finally, submit the matrix login token to the login API, which gives us our
+            # matrix access token and device id.
+            _, channel = make_request(
+                self.hs.get_reactor(),
+                self.site,
+                "POST",
+                "/login",
+                content={"type": "m.login.token", "token": login_token},
+            )
+            assert channel.code == 200
+            return channel.json_body
+
+
+# an 'oidc_config' suitable for login_with_oidc.
+TEST_OIDC_CONFIG = {
+    "enabled": True,
+    "discover": False,
+    "issuer": "https://issuer.test",
+    "client_id": "test-client-id",
+    "client_secret": "test-client-secret",
+    "scopes": ["profile"],
+    "authorization_endpoint": "https://z",
+    "token_endpoint": "https://issuer.test/token",
+    "userinfo_endpoint": "https://issuer.test/userinfo",
+    "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
+}
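A hedged usage sketch of the new helper: a test case that mounts the OIDC callback resource and points "oidc_config" at TEST_OIDC_CONFIG can log a user in without a password. The class and test names below are illustrative; the pattern mirrors the test_auth.py changes later in this diff.

from synapse.rest.client.v1 import login
from synapse.rest.oidc import OIDCResource

from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
from tests.unittest import HomeserverTestCase


class ExampleOidcLoginTest(HomeserverTestCase):
    servlets = [login.register_servlets]

    def default_config(self):
        config = super().default_config()
        config["oidc_config"] = TEST_OIDC_CONFIG
        config["public_baseurl"] = "https://synapse.test"
        return config

    def create_resource_dict(self):
        d = super().create_resource_dict()
        d["/_synapse/oidc"] = OIDCResource(self.hs)  # the OIDC callback resource
        return d

    def test_oidc_login(self):
        # drives /login/sso/redirect, the stubbed IdP exchange, and /login (m.login.token)
        login_resp = self.helper.login_via_oidc("remote_user_id")
        self.assertIn("access_token", login_resp)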
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import List, Union

 from twisted.internet.defer import succeed
@@ -22,9 +23,11 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
 from synapse.http.site import SynapseRequest
 from synapse.rest.client.v1 import login
 from synapse.rest.client.v2_alpha import auth, devices, register
-from synapse.types import JsonDict
+from synapse.rest.oidc import OIDCResource
+from synapse.types import JsonDict, UserID

 from tests import unittest
+from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
 from tests.server import FakeChannel


@@ -156,27 +159,45 @@ class UIAuthTests(unittest.HomeserverTestCase):
         register.register_servlets,
     ]

+    def default_config(self):
+        config = super().default_config()
+
+        # we enable OIDC as a way of testing SSO flows
+        oidc_config = {}
+        oidc_config.update(TEST_OIDC_CONFIG)
+        oidc_config["allow_existing_users"] = True
+
+        config["oidc_config"] = oidc_config
+        config["public_baseurl"] = "https://synapse.test"
+        return config
+
+    def create_resource_dict(self):
+        resource_dict = super().create_resource_dict()
+        # mount the OIDC resource at /_synapse/oidc
+        resource_dict["/_synapse/oidc"] = OIDCResource(self.hs)
+        return resource_dict
+
     def prepare(self, reactor, clock, hs):
         self.user_pass = "pass"
         self.user = self.register_user("test", self.user_pass)
         self.user_tok = self.login("test", self.user_pass)

-    def get_device_ids(self) -> List[str]:
+    def get_device_ids(self, access_token: str) -> List[str]:
         # Get the list of devices so one can be deleted.
-        request, channel = self.make_request(
-            "GET", "devices", access_token=self.user_tok,
-        )  # type: SynapseRequest, FakeChannel
-
-        # Get the ID of the device.
-        self.assertEqual(request.code, 200)
+        _, channel = self.make_request("GET", "devices", access_token=access_token,)
+        self.assertEqual(channel.code, 200)
         return [d["device_id"] for d in channel.json_body["devices"]]

     def delete_device(
-        self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b""
+        self,
+        access_token: str,
+        device: str,
+        expected_response: int,
+        body: Union[bytes, JsonDict] = b"",
     ) -> FakeChannel:
         """Delete an individual device."""
         request, channel = self.make_request(
-            "DELETE", "devices/" + device, body, access_token=self.user_tok
+            "DELETE", "devices/" + device, body, access_token=access_token,
        )  # type: SynapseRequest, FakeChannel

         # Ensure the response is sane.
@@ -201,11 +222,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
         """
         Test user interactive authentication outside of registration.
         """
-        device_id = self.get_device_ids()[0]
+        device_id = self.get_device_ids(self.user_tok)[0]

         # Attempt to delete this device.
         # Returns a 401 as per the spec
-        channel = self.delete_device(device_id, 401)
+        channel = self.delete_device(self.user_tok, device_id, 401)

         # Grab the session
         session = channel.json_body["session"]
@@ -214,6 +235,7 @@ class UIAuthTests(unittest.HomeserverTestCase):

         # Make another request providing the UI auth flow.
         self.delete_device(
+            self.user_tok,
             device_id,
             200,
             {
|
||||
UIA - check that still works.
|
||||
"""
|
||||
|
||||
device_id = self.get_device_ids()[0]
|
||||
channel = self.delete_device(device_id, 401)
|
||||
device_id = self.get_device_ids(self.user_tok)[0]
|
||||
channel = self.delete_device(self.user_tok, device_id, 401)
|
||||
session = channel.json_body["session"]
|
||||
|
||||
# Make another request providing the UI auth flow.
|
||||
self.delete_device(
|
||||
self.user_tok,
|
||||
device_id,
|
||||
200,
|
||||
{
|
||||
@ -264,7 +287,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||
# Create a second login.
|
||||
self.login("test", self.user_pass)
|
||||
|
||||
device_ids = self.get_device_ids()
|
||||
device_ids = self.get_device_ids(self.user_tok)
|
||||
self.assertEqual(len(device_ids), 2)
|
||||
|
||||
# Attempt to delete the first device.
|
||||
@ -298,12 +321,12 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||
# Create a second login.
|
||||
self.login("test", self.user_pass)
|
||||
|
||||
device_ids = self.get_device_ids()
|
||||
device_ids = self.get_device_ids(self.user_tok)
|
||||
self.assertEqual(len(device_ids), 2)
|
||||
|
||||
# Attempt to delete the first device.
|
||||
# Returns a 401 as per the spec
|
||||
channel = self.delete_device(device_ids[0], 401)
|
||||
channel = self.delete_device(self.user_tok, device_ids[0], 401)
|
||||
|
||||
# Grab the session
|
||||
session = channel.json_body["session"]
|
||||
@ -313,6 +336,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||
# Make another request providing the UI auth flow, but try to delete the
|
||||
# second device. This results in an error.
|
||||
self.delete_device(
|
||||
self.user_tok,
|
||||
device_ids[1],
|
||||
403,
|
||||
{
|
||||
@@ -324,3 +348,39 @@ class UIAuthTests(unittest.HomeserverTestCase):
                 },
             },
         )
+
+    def test_does_not_offer_password_for_sso_user(self):
+        login_resp = self.helper.login_via_oidc("username")
+        user_tok = login_resp["access_token"]
+        device_id = login_resp["device_id"]
+
+        # now call the device deletion API: we should get the option to auth with SSO
+        # and not password.
+        channel = self.delete_device(user_tok, device_id, 401)
+
+        flows = channel.json_body["flows"]
+        self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
+
+    def test_does_not_offer_sso_for_password_user(self):
+        # now call the device deletion API: we should get the option to auth with SSO
+        # and not password.
+        device_ids = self.get_device_ids(self.user_tok)
+        channel = self.delete_device(self.user_tok, device_ids[0], 401)
+
+        flows = channel.json_body["flows"]
+        self.assertEqual(flows, [{"stages": ["m.login.password"]}])
+
+    def test_offers_both_flows_for_upgraded_user(self):
+        """A user that had a password and then logged in with SSO should get both flows
+        """
+        login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+        self.assertEqual(login_resp["user_id"], self.user)
+
+        device_ids = self.get_device_ids(self.user_tok)
+        channel = self.delete_device(self.user_tok, device_ids[0], 401)
+
+        flows = channel.json_body["flows"]
+        # we have no particular expectations of ordering here
+        self.assertIn({"stages": ["m.login.password"]}, flows)
+        self.assertIn({"stages": ["m.login.sso"]}, flows)
+        self.assertEqual(len(flows), 2)
@@ -18,41 +18,15 @@ import re

 from mock import patch

-import attr
-
 from twisted.internet._resolver import HostResolution
 from twisted.internet.address import IPv4Address, IPv6Address
 from twisted.internet.error import DNSLookupError
-from twisted.python.failure import Failure
 from twisted.test.proto_helpers import AccumulatingProtocol
-from twisted.web._newclient import ResponseDone

 from tests import unittest
 from tests.server import FakeTransport


-@attr.s
-class FakeResponse:
-    version = attr.ib()
-    code = attr.ib()
-    phrase = attr.ib()
-    headers = attr.ib()
-    body = attr.ib()
-    absoluteURI = attr.ib()
-
-    @property
-    def request(self):
-        @attr.s
-        class FakeTransport:
-            absoluteURI = self.absoluteURI
-
-        return FakeTransport()
-
-    def deliverBody(self, protocol):
-        protocol.dataReceived(self.body)
-        protocol.connectionLost(Failure(ResponseDone()))
-
-
 class URLPreviewTests(unittest.HomeserverTestCase):

     hijack_auth = True
@@ -216,8 +216,9 @@ def make_request(
         and not path.startswith(b"/_matrix")
         and not path.startswith(b"/_synapse")
     ):
+        if path.startswith(b"/"):
+            path = path[1:]
         path = b"/_matrix/client/r0/" + path
-        path = path.replace(b"//", b"/")

     if not path.startswith(b"/"):
         path = b"/" + path
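The change above stops the client-API prefixing from producing a double slash. A standalone sketch of the resulting path handling (the helper name is hypothetical; make_request does this inline, and its real guard has more clauses than the hunk shows):

def normalize_test_path(path: bytes) -> bytes:
    if not path.startswith(b"/_matrix") and not path.startswith(b"/_synapse"):
        # strip any leading slash first, so prefixing never yields "//"
        if path.startswith(b"/"):
            path = path[1:]
        path = b"/_matrix/client/r0/" + path
    if not path.startswith(b"/"):
        path = b"/" + path
    return path


assert normalize_test_path(b"/login") == b"/_matrix/client/r0/login"
assert normalize_test_path(b"login") == b"/_matrix/client/r0/login"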
@@ -258,6 +259,7 @@ def make_request(
         for k, v in custom_headers:
             req.requestHeaders.addRawHeader(k, v)

+    req.parseCookies()
     req.requestReceived(method, path, b"1.1")

     if await_result:
@@ -22,6 +22,11 @@ import warnings
 from asyncio import Future
 from typing import Any, Awaitable, Callable, TypeVar

+import attr
+
+from twisted.python.failure import Failure
+from twisted.web.client import ResponseDone
+
 TV = TypeVar("TV")

@@ -80,3 +85,25 @@ def setup_awaitable_errors() -> Callable[[], None]:
     sys.unraisablehook = unraisablehook  # type: ignore

     return cleanup
+
+
+@attr.s
+class FakeResponse:
+    """A fake twisted.web.IResponse object
+
+    there is a similar class at treq.test.test_response, but it lacks a `phrase`
+    attribute, and didn't support deliverBody until recently.
+    """
+
+    # HTTP response code
+    code = attr.ib(type=int)
+
+    # HTTP response phrase (eg b'OK' for a 200)
+    phrase = attr.ib(type=bytes)
+
+    # body of the response
+    body = attr.ib(type=bytes)
+
+    def deliverBody(self, protocol):
+        protocol.dataReceived(self.body)
+        protocol.connectionLost(Failure(ResponseDone()))
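The shared FakeResponse implements just enough of twisted's IResponse for body delivery, so standard helpers can consume it. A small sketch (assumes the Synapse test environment, where tests.test_utils is importable):

from twisted.web.client import readBody

from tests.test_utils import FakeResponse

resp = FakeResponse(code=200, phrase=b"OK", body=b'{"sub": "remote_user"}')

collected = []
readBody(resp).addCallback(collected.append)  # deliverBody fires synchronously here
assert collected == [b'{"sub": "remote_user"}']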
@@ -20,7 +20,7 @@ import hmac
 import inspect
 import logging
 import time
-from typing import Optional, Tuple, Type, TypeVar, Union, overload
+from typing import Dict, Optional, Tuple, Type, TypeVar, Union, overload

 from mock import Mock, patch

@@ -46,6 +46,7 @@ from synapse.logging.context import (
 )
 from synapse.server import HomeServer
 from synapse.types import UserID, create_requester
+from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.ratelimitutils import FederationRateLimiter

 from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
@@ -320,15 +321,28 @@ class HomeserverTestCase(TestCase):
         """
         Create a the root resource for the test server.

-        The default implementation creates a JsonResource and calls each function in
-        `servlets` to register servletes against it
+        The default calls `self.create_resource_dict` and builds the resultant dict
+        into a tree.
         """
-        resource = JsonResource(self.hs)
+        root_resource = Resource()
+        create_resource_tree(self.create_resource_dict(), root_resource)
+        return root_resource
+
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        """Create a resource tree for the test server
+
+        A resource tree is a mapping from path to twisted.web.resource.
+
+        The default implementation creates a JsonResource and calls each function in
+        `servlets` to register servlets against it.
+        """
+        servlet_resource = JsonResource(self.hs)
         for servlet in self.servlets:
-            servlet(self.hs, resource)
-
-        return resource
+            servlet(self.hs, servlet_resource)
+        return {
+            "/_matrix/client": servlet_resource,
+            "/_synapse/admin": servlet_resource,
+        }

     def default_config(self):
         """
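With the root resource now assembled from `create_resource_dict`, a test case mounts extra resources by overriding that method rather than registering everything against a single JsonResource; the replication and typing test changes elsewhere in this diff follow exactly that pattern. A minimal sketch (the class name is illustrative):

from typing import Dict

from twisted.web.resource import Resource

from synapse.replication.http import ReplicationRestResource

from tests.unittest import HomeserverTestCase


class ExampleWorkerTestCase(HomeserverTestCase):
    def create_resource_dict(self) -> Dict[str, Resource]:
        d = super().create_resource_dict()
        # mount the replication API alongside the default /_matrix/client resources
        d["/_synapse/replication"] = ReplicationRestResource(self.hs)
        return d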
@@ -691,13 +705,29 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
     A federating homeserver that authenticates incoming requests as `other.example.com`.
     """

-    def prepare(self, reactor, clock, homeserver):
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d["/_matrix/federation"] = TestTransportLayerServer(self.hs)
+        return d
+
+
+class TestTransportLayerServer(JsonResource):
+    """A test implementation of TransportLayerServer
+
+    authenticates incoming requests as `other.example.com`.
+    """
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
         class Authenticator:
             def authenticate_request(self, request, content):
                 return succeed("other.example.com")

+        authenticator = Authenticator()
+
         ratelimiter = FederationRateLimiter(
-            clock,
+            hs.get_clock(),
             FederationRateLimitConfig(
                 window_size=1,
                 sleep_limit=1,
@@ -706,11 +736,8 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
                 concurrent_requests=1000,
             ),
         )
-        federation_server.register_servlets(
-            homeserver, self.resource, Authenticator(), ratelimiter
-        )

-        return super().prepare(reactor, clock, homeserver)
+        federation_server.register_servlets(hs, self, authenticator, ratelimiter)


 def override_config(extra_config):