From 8000cf131592b6edcded65ef4be20b8ac0f1bfd3 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Tue, 16 Mar 2021 16:44:25 +0100 Subject: [PATCH] Return m.change_password.enabled=false if local database is disabled (#9588) Instead of if the user does not have a password hash. This allows a SSO user to add a password to their account, but only if the local password database is configured. --- changelog.d/9588.bugfix | 1 + synapse/handlers/auth.py | 13 +++++++ synapse/rest/client/v2_alpha/capabilities.py | 23 ++++++------ .../rest/client/v2_alpha/test_capabilities.py | 36 ++++++++++++++++--- 4 files changed, 58 insertions(+), 15 deletions(-) create mode 100644 changelog.d/9588.bugfix diff --git a/changelog.d/9588.bugfix b/changelog.d/9588.bugfix new file mode 100644 index 000000000..b8d614056 --- /dev/null +++ b/changelog.d/9588.bugfix @@ -0,0 +1 @@ +Fix the `/capabilities` endpoint to return `m.change_password` as disabled if the local password database is not used for authentication. Contributed by @dklimpel. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index fb5f8118f..badac8c26 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -886,6 +886,19 @@ class AuthHandler(BaseHandler): ) return result + def can_change_password(self) -> bool: + """Get whether users on this server are allowed to change or set a password. + + Both `config.password_enabled` and `config.password_localdb_enabled` must be true. + + Note that any account (even SSO accounts) are allowed to add passwords if the above + is true. + + Returns: + Whether users on this server are allowed to change or set a password + """ + return self._password_enabled and self._password_localdb_enabled + def get_supported_login_types(self) -> Iterable[str]: """Get a the login types supported for the /login API diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py index 76879ac55..44ccf10ed 100644 --- a/synapse/rest/client/v2_alpha/capabilities.py +++ b/synapse/rest/client/v2_alpha/capabilities.py @@ -13,12 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import TYPE_CHECKING, Tuple from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.servlet import RestServlet +from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from ._base import client_patterns +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -27,21 +33,16 @@ class CapabilitiesRestServlet(RestServlet): PATTERNS = client_patterns("/capabilities$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.config = hs.config self.auth = hs.get_auth() - self.store = hs.get_datastore() + self.auth_handler = hs.get_auth_handler() - async def on_GET(self, request): - requester = await self.auth.get_user_by_req(request, allow_guest=True) - user = await self.store.get_user_by_id(requester.user.to_string()) - change_password = bool(user["password_hash"]) + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await self.auth.get_user_by_req(request, allow_guest=True) + change_password = self.auth_handler.can_change_password() response = { "capabilities": { @@ -58,5 +59,5 @@ class CapabilitiesRestServlet(RestServlet): return 200, response -def register_servlets(hs, http_server): +def register_servlets(hs: "HomeServer", http_server): CapabilitiesRestServlet(hs).register(http_server) diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py index e808339fb..287a1a485 100644 --- a/tests/rest/client/v2_alpha/test_capabilities.py +++ b/tests/rest/client/v2_alpha/test_capabilities.py @@ -18,6 +18,7 @@ from synapse.rest.client.v1 import login from synapse.rest.client.v2_alpha import capabilities from tests import unittest +from tests.unittest import override_config class CapabilitiesTestCase(unittest.HomeserverTestCase): @@ -33,6 +34,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver() self.store = hs.get_datastore() self.config = hs.config + self.auth_handler = hs.get_auth_handler() return hs def test_check_auth_required(self): @@ -56,7 +58,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): capabilities["m.room_versions"]["default"], ) - def test_get_change_password_capabilities(self): + def test_get_change_password_capabilities_password_login(self): localpart = "user" password = "pass" user = self.register_user(localpart, password) @@ -66,10 +68,36 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): capabilities = channel.json_body["capabilities"] self.assertEqual(channel.code, 200) - - # Test case where password is handled outside of Synapse self.assertTrue(capabilities["m.change_password"]["enabled"]) - self.get_success(self.store.user_set_password_hash(user, None)) + + @override_config({"password_config": {"localdb_enabled": False}}) + def test_get_change_password_capabilities_localdb_disabled(self): + localpart = "user" + password = "pass" + user = self.register_user(localpart, password) + access_token = self.get_success( + self.auth_handler.get_access_token_for_user_id( + user, device_id=None, valid_until_ms=None + ) + ) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, 200) + self.assertFalse(capabilities["m.change_password"]["enabled"]) + + @override_config({"password_config": {"enabled": False}}) + def test_get_change_password_capabilities_password_disabled(self): + localpart = "user" + password = "pass" + user = self.register_user(localpart, password) + access_token = self.get_success( + self.auth_handler.get_access_token_for_user_id( + user, device_id=None, valid_until_ms=None + ) + ) + channel = self.make_request("GET", self.url, access_token=access_token) capabilities = channel.json_body["capabilities"]