forked-synapse/synapse/handlers/ui_auth/checkers.py
2023-11-21 15:29:58 -05:00

332 lines
12 KiB
Python

#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, ClassVar, Sequence, Type
from twisted.web.client import PartialDownloadError
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.util import json_decoder
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class UserInteractiveAuthChecker(ABC):
    """Interface that every user-interactive auth stage checker implements.

    Concrete subclasses set ``AUTH_TYPE`` and provide ``is_enabled`` and
    ``check_auth``.
    """

    # Conceptually an "abstract class property": instantiating a subclass
    # that never sets AUTH_TYPE ought to be an error. Declaring it as a bare
    # `ClassVar` is simpler than stacking @property, @abstractmethod and
    # @classmethod (assuming that would even be the right order).
    AUTH_TYPE: ClassVar[str]

    def __init__(self, hs: "HomeServer"):  # noqa: B027
        """Accept the homeserver; the base class itself keeps no state."""

    @abstractmethod
    def is_enabled(self) -> bool:
        """Report whether the homeserver's configuration lets this checker work.

        Returns:
            True if this login type is enabled.
        """
        raise NotImplementedError

    @abstractmethod
    async def check_auth(self, authdict: dict, clientip: str) -> Any:
        """Attempt this auth stage using the client's authentication dict.

        Args:
            authdict: authentication dictionary from the client
            clientip: The IP address of the client.

        Raises:
            LoginError if authentication failed.

        Returns:
            The result of authentication (to pass back to the client?)
        """
        raise NotImplementedError
class DummyAuthChecker(UserInteractiveAuthChecker):
    """Checker for the dummy stage: always enabled, and every attempt passes."""

    AUTH_TYPE = LoginType.DUMMY

    def is_enabled(self) -> bool:
        """Report this stage as available regardless of configuration."""
        return True

    async def check_auth(self, authdict: dict, clientip: str) -> Any:
        """Accept the stage without looking at ``authdict`` or ``clientip``."""
        return True
class TermsAuthChecker(UserInteractiveAuthChecker):
    """Checker for the terms stage: always enabled, and every attempt passes."""

    AUTH_TYPE = LoginType.TERMS

    def is_enabled(self) -> bool:
        """Report this stage as available regardless of configuration."""
        return True

    async def check_auth(self, authdict: dict, clientip: str) -> Any:
        """Accept the stage without looking at ``authdict`` or ``clientip``."""
        return True
class RecaptchaAuthChecker(UserInteractiveAuthChecker):
    """Checker for the reCAPTCHA stage.

    The client's captcha response is posted to the configured siteverify API
    and the stage passes only if the reply reports success.
    """

    AUTH_TYPE = LoginType.RECAPTCHA

    def __init__(self, hs: "HomeServer"):
        super().__init__(hs)
        captcha_config = hs.config.captcha
        # The checker is only usable when a private key has been configured.
        self._enabled = bool(captcha_config.recaptcha_private_key)
        self._http_client = hs.get_proxied_http_client()
        self._url = captcha_config.recaptcha_siteverify_api
        self._secret = captcha_config.recaptcha_private_key

    def is_enabled(self) -> bool:
        """Return True if a reCAPTCHA private key is configured."""
        return self._enabled

    async def check_auth(self, authdict: dict, clientip: str) -> Any:
        """Verify the captcha response supplied by the client.

        Args:
            authdict: authentication dictionary from the client; must carry
                the captcha answer under "response".
            clientip: The IP address of the client, forwarded as "remoteip".

        Raises:
            LoginError: 400 if "response" is absent, 401 if the verification
                reply does not report success.

        Returns:
            True when the verification reply reports success.
        """
        try:
            user_response = authdict["response"]
        except KeyError:
            # The client attempted the captcha stage without supplying the
            # parameter, so this is a bad request.
            raise LoginError(
                400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED
            )

        logger.info(
            "Submitting recaptcha response %s with remoteip %s", user_response, clientip
        )

        # The HTTP client was obtained once in __init__ and is reused here.
        try:
            assert self._secret is not None
            form_fields = {
                "secret": self._secret,
                "response": user_response,
                "remoteip": clientip,
            }
            resp_body = await self._http_client.post_urlencoded_get_json(
                self._url, args=form_fields
            )
        except PartialDownloadError as pde:
            # Twisted reported a partial download; fall back to decoding the
            # body that is carried on the exception.
            partial_body = pde.response
            # For mypy's benefit. A general Error.response is Optional[bytes],
            # but a PartialDownloadError.response should be bytes AFAICS.
            assert partial_body is not None
            resp_body = json_decoder.decode(partial_body.decode("utf-8"))

        if "success" in resp_body:
            succeeded = resp_body["success"]
            # The hostname is deliberately NOT checked: the CAPTCHA may be
            # presented by whatever client the user is using, all that matters
            # is that one was completed.
            logger.info(
                "%s reCAPTCHA from hostname %s",
                "Successful" if succeeded else "Failed",
                resp_body.get("hostname"),
            )
            if succeeded:
                return True

        raise LoginError(
            401, "Captcha authentication failed", errcode=Codes.UNAUTHORIZED
        )
class _BaseThreepidAuthChecker:
    """Shared logic for checkers that validate a third-party identifier.

    Used as a mixin by the email and msisdn checkers in this module, which
    call ``_check_threepid`` with their own medium.
    """

    def __init__(self, hs: "HomeServer"):
        self.hs = hs
        self.store = hs.get_datastores().main

    async def _check_threepid(self, medium: str, authdict: dict) -> dict:
        """Resolve the validated threepid that the client's credentials refer to.

        Args:
            medium: Either "msisdn" or "email"; any other value is a
                programming error.
            authdict: authentication dictionary from the client. Must contain
                "threepid_creds"; for email these must in turn be an object
                holding "client_secret" and "sid".

        Raises:
            LoginError: 400 if "threepid_creds" is absent or (for email)
                malformed; 401 if no validated threepid could be found or the
                one found has a different medium from the one requested.
            SynapseError: 400 if verification of the requested medium is not
                enabled on this homeserver.
            AssertionError: if ``medium`` is not a recognised value.

        Returns:
            The threepid as a dict, with the client's "threepid_creds" added
            under that key.
        """
        if "threepid_creds" not in authdict:
            raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)

        threepid_creds = authdict["threepid_creds"]

        identity_handler = self.hs.get_identity_handler()

        # NOTE(review): this logs the client-supplied credentials verbatim,
        # including any client_secret - confirm that is acceptable at INFO.
        logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))

        # msisdns are currently always verified via the IS
        if medium == "msisdn":
            if not self.hs.config.registration.account_threepid_delegate_msisdn:
                raise SynapseError(
                    400, "Phone number verification is not enabled on this homeserver"
                )
            threepid = await identity_handler.threepid_from_creds(
                self.hs.config.registration.account_threepid_delegate_msisdn,
                threepid_creds,
            )
        elif medium == "email":
            if self.hs.config.email.can_verify_email:
                # The credentials come straight from the client, so check
                # their shape before indexing into them: a malformed request
                # must be answered with a 400 rather than surfacing as an
                # unhandled KeyError/TypeError.
                if not isinstance(threepid_creds, dict):
                    raise LoginError(
                        400, "threepid_creds must be an object", Codes.INVALID_PARAM
                    )
                for required_key in ("client_secret", "sid"):
                    if required_key not in threepid_creds:
                        raise LoginError(
                            400,
                            "Missing %s in threepid_creds" % (required_key,),
                            Codes.MISSING_PARAM,
                        )

                threepid = None
                row = await self.store.get_threepid_validation_session(
                    medium,
                    threepid_creds["client_secret"],
                    sid=threepid_creds["sid"],
                    validated=True,
                )

                if row:
                    threepid = {
                        "medium": row.medium,
                        "address": row.address,
                        "validated_at": row.validated_at,
                    }

                    # Valid threepid returned, delete from the db
                    await self.store.delete_threepid_session(threepid_creds["sid"])
            else:
                raise SynapseError(
                    400, "Email address verification is not enabled on this homeserver"
                )
        else:
            # this can't happen!
            raise AssertionError("Unrecognized threepid medium: %s" % (medium,))

        if not threepid:
            raise LoginError(
                401, "Unable to get validated threepid", errcode=Codes.UNAUTHORIZED
            )

        # Guard against credentials that validate a different kind of
        # identifier from the one this stage is meant to prove.
        if threepid["medium"] != medium:
            raise LoginError(
                401,
                "Expecting threepid of type '%s', got '%s'"
                % (medium, threepid["medium"]),
                errcode=Codes.UNAUTHORIZED,
            )

        # Hand the original credentials back alongside the validated threepid.
        threepid["threepid_creds"] = authdict["threepid_creds"]

        return threepid
class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
    """Checker for the email identity stage, backed by ``_check_threepid``."""

    AUTH_TYPE = LoginType.EMAIL_IDENTITY

    def __init__(self, hs: "HomeServer"):
        # Both bases are initialised explicitly, each with the homeserver.
        UserInteractiveAuthChecker.__init__(self, hs)
        _BaseThreepidAuthChecker.__init__(self, hs)

    def is_enabled(self) -> bool:
        """Return the email config's ``can_verify_email`` setting."""
        email_config = self.hs.config.email
        return email_config.can_verify_email

    async def check_auth(self, authdict: dict, clientip: str) -> Any:
        """Validate the client's credentials as an "email" threepid.

        ``clientip`` is not used by this stage.
        """
        validated_threepid = await self._check_threepid("email", authdict)
        return validated_threepid
class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
    """Checker for the msisdn stage, backed by ``_check_threepid``."""

    AUTH_TYPE = LoginType.MSISDN

    def __init__(self, hs: "HomeServer"):
        # Both bases are initialised explicitly, each with the homeserver.
        UserInteractiveAuthChecker.__init__(self, hs)
        _BaseThreepidAuthChecker.__init__(self, hs)

    def is_enabled(self) -> bool:
        """Return True if an msisdn threepid delegate is configured."""
        msisdn_delegate = self.hs.config.registration.account_threepid_delegate_msisdn
        return bool(msisdn_delegate)

    async def check_auth(self, authdict: dict, clientip: str) -> Any:
        """Validate the client's credentials as an "msisdn" threepid.

        ``clientip`` is not used by this stage.
        """
        validated_threepid = await self._check_threepid("msisdn", authdict)
        return validated_threepid
class RegistrationTokenAuthChecker(UserInteractiveAuthChecker):
    """Checker for the registration token stage."""

    AUTH_TYPE = LoginType.REGISTRATION_TOKEN

    def __init__(self, hs: "HomeServer"):
        super().__init__(hs)
        self.hs = hs
        registration_config = hs.config.registration
        # Enabled when tokens are required for registration, or when the
        # 3pid-bypass variant of token registration is switched on.
        self._enabled = bool(
            registration_config.registration_requires_token
            or registration_config.enable_registration_token_3pid_bypass
        )
        self.store = hs.get_datastores().main

    def is_enabled(self) -> bool:
        """Return the enabled flag computed from the registration config."""
        return self._enabled

    async def check_auth(self, authdict: dict, clientip: str) -> Any:
        """Check the registration token supplied for a UIA session.

        Args:
            authdict: authentication dictionary from the client; must carry a
                string "token" and a "session".
            clientip: The IP address of the client (unused by this stage).

        Raises:
            LoginError: 400 if "token" or "session" is missing, the token is
                not a string, or it differs from the token already stored for
                the session; 401 if the token is not valid.

        Returns:
            The accepted token.
        """
        if "token" not in authdict:
            raise LoginError(400, "Missing registration token", Codes.MISSING_PARAM)
        token = authdict["token"]
        if not isinstance(token, str):
            raise LoginError(
                400, "Registration token must be a string", Codes.INVALID_PARAM
            )
        if "session" not in authdict:
            raise LoginError(400, "Missing UIA session", Codes.MISSING_PARAM)
        session = authdict["session"]

        # Imported here rather than at module level to avoid cyclic dependencies
        from synapse.handlers.ui_auth import UIAuthSessionDataConstants

        auth_handler = self.hs.get_auth_handler()

        # If this stage was already completed for the session, return early so
        # that `pending` is not incremented a second time.
        stored_token = await auth_handler.get_session_data(
            session, UIAuthSessionDataConstants.REGISTRATION_TOKEN
        )
        if stored_token:
            if token == stored_token:
                return token
            raise LoginError(
                400, "Registration token has changed", Codes.INVALID_PARAM
            )

        if not await self.store.registration_token_is_valid(token):
            raise LoginError(
                401, "Invalid registration token", errcode=Codes.UNAUTHORIZED
            )

        # Bump the pending counter so that a token with limited uses can't be
        # used up by someone else in the meantime.
        await self.store.set_registration_token_pending(token)

        # Remember the token in the UIA session, so that `completed` can be
        # incremented once registration finishes.
        await auth_handler.set_session_data(
            session,
            UIAuthSessionDataConstants.REGISTRATION_TOKEN,
            token,
        )

        # The token becomes the result of this authentication stage in
        # ui_auth_sessions_credentials, which allows the pending counter for
        # tokens to be decremented when expired sessions are deleted.
        return token
# Every concrete checker class defined in this module, in declaration order.
# The private `_BaseThreepidAuthChecker` mixin is not listed: it does not
# derive from UserInteractiveAuthChecker.
# NOTE(review): presumably consumers instantiate each class with the
# HomeServer and look checkers up by AUTH_TYPE - confirm against the caller.
INTERACTIVE_AUTH_CHECKERS: Sequence[Type[UserInteractiveAuthChecker]] = [
DummyAuthChecker,
TermsAuthChecker,
RecaptchaAuthChecker,
EmailIdentityAuthChecker,
MsisdnAuthChecker,
RegistrationTokenAuthChecker,
]
"""A list of UserInteractiveAuthChecker classes"""