mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
23740eaa3d
During the migration the automated script to update the copyright headers accidentally got rid of some of the existing copyright lines. Reinstate them.
333 lines
12 KiB
Python
333 lines
12 KiB
Python
#
|
|
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
|
#
|
|
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
|
# Copyright (C) 2023 New Vector, Ltd
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Affero General Public License as
|
|
# published by the Free Software Foundation, either version 3 of the
|
|
# License, or (at your option) any later version.
|
|
#
|
|
# See the GNU Affero General Public License for more details:
|
|
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
|
#
|
|
# Originally licensed under the Apache License, Version 2.0:
|
|
# <http://www.apache.org/licenses/LICENSE-2.0>.
|
|
#
|
|
# [This file includes modifications made by New Vector Limited]
|
|
#
|
|
#
|
|
|
|
import logging
|
|
from abc import ABC, abstractmethod
|
|
from typing import TYPE_CHECKING, Any, ClassVar, Sequence, Type
|
|
|
|
from twisted.web.client import PartialDownloadError
|
|
|
|
from synapse.api.constants import LoginType
|
|
from synapse.api.errors import Codes, LoginError, SynapseError
|
|
from synapse.util import json_decoder
|
|
|
|
if TYPE_CHECKING:
|
|
from synapse.server import HomeServer
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class UserInteractiveAuthChecker(ABC):
    """Interface for a single User-Interactive Auth stage ("login type").

    A concrete checker names the stage it handles through ``AUTH_TYPE`` and
    implements ``is_enabled`` and ``check_auth``.
    """

    # Every concrete subclass is expected to set this. Ideally it would be an
    # abstract class-level property, so that leaving it out fails when the
    # subclass is instantiated, but a plain `ClassVar` annotation is far
    # simpler than stacking @classmethod / @property / @abstractmethod.
    AUTH_TYPE: ClassVar[str]

    def __init__(self, hs: "HomeServer"):  # noqa: B027
        """Intentionally a no-op; subclasses may still chain to it."""

    @abstractmethod
    def is_enabled(self) -> bool:
        """Report whether the homeserver configuration allows this stage.

        Returns:
            True when this login type is enabled.
        """
        raise NotImplementedError

    @abstractmethod
    async def check_auth(self, authdict: dict, clientip: str) -> Any:
        """Validate the client's submission for this stage.

        Args:
            authdict: the authentication dictionary sent by the client.
            clientip: the IP address of the client.

        Raises:
            LoginError: if authentication failed.

        Returns:
            The stage's result; what the caller does with it is up to the
            caller.
        """
        raise NotImplementedError
|
|
|
|
|
|
class DummyAuthChecker(UserInteractiveAuthChecker):
    """The dummy stage: always enabled, and accepts any submission."""

    AUTH_TYPE = LoginType.DUMMY

    def is_enabled(self) -> bool:
        """Always available, regardless of configuration."""
        return True

    async def check_auth(self, authdict: dict, clientip: str) -> Any:
        """Succeed unconditionally; neither argument is inspected."""
        return True
|
|
|
|
|
|
class TermsAuthChecker(UserInteractiveAuthChecker):
    """The terms stage: always enabled, and this checker accepts any submission.

    No part of the client's payload is examined here.
    """

    AUTH_TYPE = LoginType.TERMS

    def is_enabled(self) -> bool:
        """Always available, regardless of configuration."""
        return True

    async def check_auth(self, authdict: dict, clientip: str) -> Any:
        """Succeed unconditionally; neither argument is inspected."""
        return True
|
|
|
|
|
|
class RecaptchaAuthChecker(UserInteractiveAuthChecker):
    """Checks a client's reCAPTCHA response against the configured siteverify API.

    The stage is enabled only when a reCAPTCHA private key is configured.
    """

    AUTH_TYPE = LoginType.RECAPTCHA

    def __init__(self, hs: "HomeServer"):
        super().__init__(hs)
        captcha_config = hs.config.captcha
        private_key = captcha_config.recaptcha_private_key
        self._enabled = bool(private_key)
        # One (proxied) HTTP client is obtained up front and reused for
        # every verification request.
        self._http_client = hs.get_proxied_http_client()
        self._url = captcha_config.recaptcha_siteverify_api
        self._secret = private_key

    def is_enabled(self) -> bool:
        return self._enabled

    async def _fetch_verdict(self, user_response: Any, clientip: str) -> Any:
        """POST the client's response to the siteverify endpoint.

        Returns:
            The JSON body returned by the verification endpoint, decoded.
        """
        # The secret is falsy whenever is_enabled() is False; this assert
        # mainly narrows the type for mypy.
        assert self._secret is not None

        form = {
            "secret": self._secret,
            "response": user_response,
            "remoteip": clientip,
        }
        try:
            return await self._http_client.post_urlencoded_get_json(
                self._url, args=form
            )
        except PartialDownloadError as partial:
            # Twisted can signal this error while still handing over the body
            # bytes it received (on `.response`); decode those instead.
            body = partial.response
            # Error.response is Optional[bytes] in general, but for a
            # PartialDownloadError it should always be populated.
            assert body is not None
            return json_decoder.decode(body.decode("utf-8"))

    async def check_auth(self, authdict: dict, clientip: str) -> Any:
        """Verify the captcha response in ``authdict["response"]``.

        Raises:
            LoginError: 400 (CAPTCHA_NEEDED) when no response was supplied,
                401 (UNAUTHORIZED) when verification did not succeed.

        Returns:
            True on success.
        """
        try:
            user_response = authdict["response"]
        except KeyError:
            # The client attempted the captcha stage but left out the
            # parameter: bad request.
            raise LoginError(
                400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED
            )

        logger.info(
            "Submitting recaptcha response %s with remoteip %s", user_response, clientip
        )

        verdict = await self._fetch_verdict(user_response, clientip)

        if "success" not in verdict:
            raise LoginError(
                401, "Captcha authentication failed", errcode=Codes.UNAUTHORIZED
            )

        passed = verdict["success"]
        # The hostname is logged but deliberately NOT checked: the CAPTCHA may
        # be presented by whatever client the user is on; all we care about is
        # that one was completed.
        logger.info(
            "%s reCAPTCHA from hostname %s",
            "Successful" if passed else "Failed",
            verdict.get("hostname"),
        )
        if not passed:
            raise LoginError(
                401, "Captcha authentication failed", errcode=Codes.UNAUTHORIZED
            )
        return True
|
|
|
|
|
|
class _BaseThreepidAuthChecker:
    """Mixin holding the shared logic for UI auth stages that validate a 3PID.

    Provides ``_check_threepid``. It resolves the client-supplied
    ``threepid_creds`` to a validated third-party identifier of the
    requested medium.
    """

    def __init__(self, hs: "HomeServer"):
        self.hs = hs
        self.store = hs.get_datastores().main

    async def _check_threepid(self, medium: str, authdict: dict) -> dict:
        """Resolve the client's ``threepid_creds`` to a validated threepid.

        Args:
            medium: the expected medium; "msisdn" or "email".
            authdict: the auth dictionary from the client. It must contain a
                ``threepid_creds`` object.

        Returns:
            The validated threepid as a dict with at least a ``medium`` key,
            plus a ``threepid_creds`` key holding the client's creds.
            - For email it also carries ``address`` and ``validated_at``.
            - For msisdn the remaining keys are whatever the identity handler
              returned.

        Raises:
            LoginError: if ``threepid_creds`` is missing or malformed, no
                validated threepid could be found, or its medium does not
                match ``medium``.
            SynapseError: if verification for ``medium`` is not enabled on
                this homeserver.
            AssertionError: if ``medium`` is not a recognised medium.
        """
        if "threepid_creds" not in authdict:
            raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)

        threepid_creds = authdict["threepid_creds"]
        # This is client-controlled input. Reject anything that isn't an
        # object up front, so a malformed request gets a 400 rather than a
        # TypeError (i.e. a 500) when it is indexed below.
        if not isinstance(threepid_creds, dict):
            raise LoginError(
                400, "threepid_creds must be an object", Codes.INVALID_PARAM
            )

        identity_handler = self.hs.get_identity_handler()

        # NOTE(review): this logs the full creds, including client_secret, at
        # INFO level; consider redacting.
        logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))

        # msisdns are currently always verified via the IS
        if medium == "msisdn":
            if not self.hs.config.registration.account_threepid_delegate_msisdn:
                raise SynapseError(
                    400, "Phone number verification is not enabled on this homeserver"
                )
            threepid = await identity_handler.threepid_from_creds(
                self.hs.config.registration.account_threepid_delegate_msisdn,
                threepid_creds,
            )
        elif medium == "email":
            if not self.hs.config.email.can_verify_email:
                raise SynapseError(
                    400, "Email address verification is not enabled on this homeserver"
                )

            # Both keys are read directly below. Validate them so a
            # malformed request fails with a 400 instead of a KeyError, or
            # a non-string value being handed to the database query.
            client_secret = threepid_creds.get("client_secret")
            sid = threepid_creds.get("sid")
            if client_secret is None or sid is None:
                raise LoginError(
                    400,
                    "Missing client_secret or sid in threepid_creds",
                    Codes.MISSING_PARAM,
                )
            if not isinstance(client_secret, str) or not isinstance(sid, str):
                raise LoginError(
                    400,
                    "client_secret and sid in threepid_creds must be strings",
                    Codes.INVALID_PARAM,
                )

            threepid = None
            row = await self.store.get_threepid_validation_session(
                medium,
                client_secret,
                sid=sid,
                validated=True,
            )

            if row:
                threepid = {
                    "medium": row.medium,
                    "address": row.address,
                    "validated_at": row.validated_at,
                }

                # Valid threepid returned, delete from the db
                await self.store.delete_threepid_session(sid)
        else:
            # this can't happen!
            raise AssertionError("Unrecognized threepid medium: %s" % (medium,))

        if not threepid:
            raise LoginError(
                401, "Unable to get validated threepid", errcode=Codes.UNAUTHORIZED
            )

        if threepid["medium"] != medium:
            raise LoginError(
                401,
                "Expecting threepid of type '%s', got '%s'"
                % (medium, threepid["medium"]),
                errcode=Codes.UNAUTHORIZED,
            )

        threepid["threepid_creds"] = threepid_creds

        return threepid
|
|
|
|
|
|
class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
    """UI auth stage that validates an email address via ``_check_threepid``.

    Enabled only when the homeserver's email config reports that it can
    verify email addresses.
    """

    AUTH_TYPE = LoginType.EMAIL_IDENTITY

    def __init__(self, hs: "HomeServer"):
        # Each base is initialised explicitly, in this order.
        for base in (UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
            base.__init__(self, hs)

    def is_enabled(self) -> bool:
        email_config = self.hs.config.email
        return email_config.can_verify_email

    async def check_auth(self, authdict: dict, clientip: str) -> Any:
        """Return the validated email threepid; ``clientip`` is unused."""
        expected_medium = "email"
        return await self._check_threepid(expected_medium, authdict)
|
|
|
|
|
|
class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
    """UI auth stage that validates a phone number (msisdn) via ``_check_threepid``.

    Enabled only when ``account_threepid_delegate_msisdn`` is configured.
    """

    AUTH_TYPE = LoginType.MSISDN

    def __init__(self, hs: "HomeServer"):
        # Each base is initialised explicitly, in this order.
        for base in (UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
            base.__init__(self, hs)

    def is_enabled(self) -> bool:
        delegate = self.hs.config.registration.account_threepid_delegate_msisdn
        return bool(delegate)

    async def check_auth(self, authdict: dict, clientip: str) -> Any:
        """Return the validated msisdn threepid; ``clientip`` is unused."""
        expected_medium = "msisdn"
        return await self._check_threepid(expected_medium, authdict)
|
|
|
|
|
|
class RegistrationTokenAuthChecker(UserInteractiveAuthChecker):
    """UI auth stage that checks a registration token.

    Enabled when ``registration_requires_token`` or
    ``enable_registration_token_3pid_bypass`` is set.
    """

    AUTH_TYPE = LoginType.REGISTRATION_TOKEN

    def __init__(self, hs: "HomeServer"):
        super().__init__(hs)
        self.hs = hs
        reg_config = hs.config.registration
        token_required = bool(reg_config.registration_requires_token)
        self._enabled = token_required or bool(
            reg_config.enable_registration_token_3pid_bypass
        )
        self.store = hs.get_datastores().main

    def is_enabled(self) -> bool:
        return self._enabled

    async def check_auth(self, authdict: dict, clientip: str) -> Any:
        """Validate ``authdict["token"]`` for the UIA session ``authdict["session"]``.

        Raises:
            LoginError: 400 if the token or session is missing, the token is
                not a string, or it differs from the token already accepted
                for this session; 401 if the token is not valid.

        Returns:
            The token string.
        """
        if "token" not in authdict:
            raise LoginError(400, "Missing registration token", Codes.MISSING_PARAM)
        token = authdict["token"]
        if not isinstance(token, str):
            raise LoginError(
                400, "Registration token must be a string", Codes.INVALID_PARAM
            )
        if "session" not in authdict:
            raise LoginError(400, "Missing UIA session", Codes.MISSING_PARAM)
        session = authdict["session"]

        # Imported here rather than at module level to avoid an import cycle.
        from synapse.handlers.ui_auth import UIAuthSessionDataConstants

        auth_handler = self.hs.get_auth_handler()
        session_key = UIAuthSessionDataConstants.REGISTRATION_TOKEN

        # If this stage already succeeded for the session, its token is stored
        # there. Short-circuit so `pending` isn't incremented a second time.
        previously_accepted = await auth_handler.get_session_data(session, session_key)
        if previously_accepted:
            if token == previously_accepted:
                return token
            raise LoginError(
                400, "Registration token has changed", Codes.INVALID_PARAM
            )

        if not await self.store.registration_token_is_valid(token):
            raise LoginError(
                401, "Invalid registration token", errcode=Codes.UNAUTHORIZED
            )

        # Bump the pending counter, so that a token with limited uses can't be
        # used up by someone else in the meantime.
        await self.store.set_registration_token_pending(token)
        # Record the token on the UIA session, so that `completed` can be
        # incremented once registration finishes.
        await auth_handler.set_session_data(session, session_key, token)
        # The returned token is stored as this stage's result in
        # ui_auth_sessions_credentials. That lets the pending counter be
        # decremented when expired sessions are deleted.
        return token
|
|
|
|
|
|
# Every UserInteractiveAuthChecker implementation defined in this module, in
# declaration order.
INTERACTIVE_AUTH_CHECKERS: Sequence[Type[UserInteractiveAuthChecker]] = [
    DummyAuthChecker,
    TermsAuthChecker,
    RecaptchaAuthChecker,
    EmailIdentityAuthChecker,
    MsisdnAuthChecker,
    RegistrationTokenAuthChecker,
]
|