Allow admins to require a manual approval process before new accounts can be used (using MSC3866) (#13556)

This commit is contained in:
Brendan Abolivier 2022-09-29 14:23:24 +01:00 committed by GitHub
parent 8625ad8099
commit be76cd8200
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 731 additions and 34 deletions

View File

@ -0,0 +1 @@
Allow server admins to require a manual approval process before new accounts can be used (using [MSC3866](https://github.com/matrix-org/matrix-spec-proposals/pull/3866)).

View File

@ -107,7 +107,7 @@ BOOLEAN_COLUMNS = {
"redactions": ["have_censored"], "redactions": ["have_censored"],
"room_stats_state": ["is_federatable"], "room_stats_state": ["is_federatable"],
"local_media_repository": ["safe_from_quarantine"], "local_media_repository": ["safe_from_quarantine"],
"users": ["shadow_banned"], "users": ["shadow_banned", "approved"],
"e2e_fallback_keys_json": ["used"], "e2e_fallback_keys_json": ["used"],
"access_tokens": ["used"], "access_tokens": ["used"],
"device_lists_changes_in_room": ["converted_to_destinations"], "device_lists_changes_in_room": ["converted_to_destinations"],

View File

@ -269,3 +269,14 @@ class PublicRoomsFilterFields:
GENERIC_SEARCH_TERM: Final = "generic_search_term" GENERIC_SEARCH_TERM: Final = "generic_search_term"
ROOM_TYPES: Final = "room_types" ROOM_TYPES: Final = "room_types"
class ApprovalNoticeMedium:
"""Identifier for the medium this server will use to serve notice of approval for a
specific user's registration.
As defined in https://github.com/matrix-org/matrix-spec-proposals/blob/babolivier/m_not_approved/proposals/3866-user-not-approved-error.md
"""
NONE = "org.matrix.msc3866.none"
EMAIL = "org.matrix.msc3866.email"

View File

@ -106,6 +106,8 @@ class Codes(str, Enum):
# Part of MSC3895. # Part of MSC3895.
UNABLE_DUE_TO_PARTIAL_STATE = "ORG.MATRIX.MSC3895_UNABLE_DUE_TO_PARTIAL_STATE" UNABLE_DUE_TO_PARTIAL_STATE = "ORG.MATRIX.MSC3895_UNABLE_DUE_TO_PARTIAL_STATE"
USER_AWAITING_APPROVAL = "ORG.MATRIX.MSC3866_USER_AWAITING_APPROVAL"
class CodeMessageException(RuntimeError): class CodeMessageException(RuntimeError):
"""An exception with integer code and message string attributes. """An exception with integer code and message string attributes.
@ -566,6 +568,20 @@ class UnredactedContentDeletedError(SynapseError):
return cs_error(self.msg, self.errcode, **extra) return cs_error(self.msg, self.errcode, **extra)
class NotApprovedError(SynapseError):
def __init__(
self,
msg: str,
approval_notice_medium: str,
):
super().__init__(
code=403,
msg=msg,
errcode=Codes.USER_AWAITING_APPROVAL,
additional_fields={"approval_notice_medium": approval_notice_medium},
)
def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs: Any) -> "JsonDict": def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs: Any) -> "JsonDict":
"""Utility method for constructing an error response for client-server """Utility method for constructing an error response for client-server
interactions. interactions.

View File

@ -14,10 +14,25 @@
from typing import Any from typing import Any
import attr
from synapse.config._base import Config from synapse.config._base import Config
from synapse.types import JsonDict from synapse.types import JsonDict
@attr.s(auto_attribs=True, frozen=True, slots=True)
class MSC3866Config:
"""Configuration for MSC3866 (mandating approval for new users)"""
# Whether the base support for the approval process is enabled. This includes the
# ability for administrators to check and update the approval of users, even if no
# approval is currently required.
enabled: bool = False
# Whether to require that new users are approved by an admin before their account
# can be used. Note that this setting is ignored if 'enabled' is false.
require_approval_for_new_accounts: bool = False
class ExperimentalConfig(Config): class ExperimentalConfig(Config):
"""Config section for enabling experimental features""" """Config section for enabling experimental features"""
@ -97,6 +112,10 @@ class ExperimentalConfig(Config):
# MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices. # MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices.
self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False) self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False)
# MSC3866: M_USER_AWAITING_APPROVAL error code
raw_msc3866_config = experimental.get("msc3866", {})
self.msc3866 = MSC3866Config(**raw_msc3866_config)
# MSC3881: Remotely toggle push notifications for another client # MSC3881: Remotely toggle push notifications for another client
self.msc3881_enabled: bool = experimental.get("msc3881_enabled", False) self.msc3881_enabled: bool = experimental.get("msc3881_enabled", False)

View File

@ -32,6 +32,7 @@ class AdminHandler:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers() self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state self._state_storage_controller = self._storage_controllers.state
self._msc3866_enabled = hs.config.experimental.msc3866.enabled
async def get_whois(self, user: UserID) -> JsonDict: async def get_whois(self, user: UserID) -> JsonDict:
connections = [] connections = []
@ -75,6 +76,10 @@ class AdminHandler:
"is_guest", "is_guest",
} }
if self._msc3866_enabled:
# Only include the approved flag if support for MSC3866 is enabled.
user_info_to_return.add("approved")
# Restrict returned keys to a known set. # Restrict returned keys to a known set.
user_info_dict = { user_info_dict = {
key: value key: value

View File

@ -1009,6 +1009,17 @@ class AuthHandler:
return res[0] return res[0]
return None return None
async def is_user_approved(self, user_id: str) -> bool:
"""Checks if a user is approved and therefore can be allowed to log in.
Args:
user_id: the user to check the approval status of.
Returns:
A boolean that is True if the user is approved, False otherwise.
"""
return await self.store.is_user_approved(user_id)
async def _find_user_id_and_pwd_hash( async def _find_user_id_and_pwd_hash(
self, user_id: str self, user_id: str
) -> Optional[Tuple[str, str]]: ) -> Optional[Tuple[str, str]]:

View File

@ -220,6 +220,7 @@ class RegistrationHandler:
by_admin: bool = False, by_admin: bool = False,
user_agent_ips: Optional[List[Tuple[str, str]]] = None, user_agent_ips: Optional[List[Tuple[str, str]]] = None,
auth_provider_id: Optional[str] = None, auth_provider_id: Optional[str] = None,
approved: bool = False,
) -> str: ) -> str:
"""Registers a new client on the server. """Registers a new client on the server.
@ -246,6 +247,8 @@ class RegistrationHandler:
user_agent_ips: Tuples of user-agents and IP addresses used user_agent_ips: Tuples of user-agents and IP addresses used
during the registration process. during the registration process.
auth_provider_id: The SSO IdP the user used, if any. auth_provider_id: The SSO IdP the user used, if any.
approved: True if the new user should be considered already
approved by an administrator.
Returns: Returns:
The registered user_id. The registered user_id.
Raises: Raises:
@ -307,6 +310,7 @@ class RegistrationHandler:
user_type=user_type, user_type=user_type,
address=address, address=address,
shadow_banned=shadow_banned, shadow_banned=shadow_banned,
approved=approved,
) )
profile = await self.store.get_profileinfo(localpart) profile = await self.store.get_profileinfo(localpart)
@ -695,6 +699,7 @@ class RegistrationHandler:
user_type: Optional[str] = None, user_type: Optional[str] = None,
address: Optional[str] = None, address: Optional[str] = None,
shadow_banned: bool = False, shadow_banned: bool = False,
approved: bool = False,
) -> None: ) -> None:
"""Register user in the datastore. """Register user in the datastore.
@ -713,6 +718,7 @@ class RegistrationHandler:
api.constants.UserTypes, or None for a normal user. api.constants.UserTypes, or None for a normal user.
address: the IP address used to perform the registration. address: the IP address used to perform the registration.
shadow_banned: Whether to shadow-ban the user shadow_banned: Whether to shadow-ban the user
approved: Whether to mark the user as approved by an administrator
""" """
if self.hs.config.worker.worker_app: if self.hs.config.worker.worker_app:
await self._register_client( await self._register_client(
@ -726,6 +732,7 @@ class RegistrationHandler:
user_type=user_type, user_type=user_type,
address=address, address=address,
shadow_banned=shadow_banned, shadow_banned=shadow_banned,
approved=approved,
) )
else: else:
await self.store.register_user( await self.store.register_user(
@ -738,6 +745,7 @@ class RegistrationHandler:
admin=admin, admin=admin,
user_type=user_type, user_type=user_type,
shadow_banned=shadow_banned, shadow_banned=shadow_banned,
approved=approved,
) )
# Only call the account validity module(s) on the main process, to avoid # Only call the account validity module(s) on the main process, to avoid

View File

@ -51,6 +51,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
user_type: Optional[str], user_type: Optional[str],
address: Optional[str], address: Optional[str],
shadow_banned: bool, shadow_banned: bool,
approved: bool,
) -> JsonDict: ) -> JsonDict:
""" """
Args: Args:
@ -68,6 +69,8 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
or None for a normal user. or None for a normal user.
address: the IP address used to perform the regitration. address: the IP address used to perform the regitration.
shadow_banned: Whether to shadow-ban the user shadow_banned: Whether to shadow-ban the user
approved: Whether the user should be considered already approved by an
administrator.
""" """
return { return {
"password_hash": password_hash, "password_hash": password_hash,
@ -79,6 +82,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
"user_type": user_type, "user_type": user_type,
"address": address, "address": address,
"shadow_banned": shadow_banned, "shadow_banned": shadow_banned,
"approved": approved,
} }
async def _handle_request( # type: ignore[override] async def _handle_request( # type: ignore[override]
@ -99,6 +103,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
user_type=content["user_type"], user_type=content["user_type"],
address=content["address"], address=content["address"],
shadow_banned=content["shadow_banned"], shadow_banned=content["shadow_banned"],
approved=content["approved"],
) )
return 200, {} return 200, {}

View File

@ -69,6 +69,7 @@ class UsersRestServletV2(RestServlet):
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler() self.admin_handler = hs.get_admin_handler()
self._msc3866_enabled = hs.config.experimental.msc3866.enabled
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
@ -95,6 +96,13 @@ class UsersRestServletV2(RestServlet):
guests = parse_boolean(request, "guests", default=True) guests = parse_boolean(request, "guests", default=True)
deactivated = parse_boolean(request, "deactivated", default=False) deactivated = parse_boolean(request, "deactivated", default=False)
# If support for MSC3866 is not enabled, apply no filtering based on the
# `approved` column.
if self._msc3866_enabled:
approved = parse_boolean(request, "approved", default=True)
else:
approved = True
order_by = parse_string( order_by = parse_string(
request, request,
"order_by", "order_by",
@ -115,8 +123,22 @@ class UsersRestServletV2(RestServlet):
direction = parse_string(request, "dir", default="f", allowed_values=("f", "b")) direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
users, total = await self.store.get_users_paginate( users, total = await self.store.get_users_paginate(
start, limit, user_id, name, guests, deactivated, order_by, direction start,
limit,
user_id,
name,
guests,
deactivated,
order_by,
direction,
approved,
) )
# If support for MSC3866 is not enabled, don't show the approval flag.
if not self._msc3866_enabled:
for user in users:
del user["approved"]
ret = {"users": users, "total": total} ret = {"users": users, "total": total}
if (start + limit) < total: if (start + limit) < total:
ret["next_token"] = str(start + len(users)) ret["next_token"] = str(start + len(users))
@ -163,6 +185,7 @@ class UserRestServletV2(RestServlet):
self.deactivate_account_handler = hs.get_deactivate_account_handler() self.deactivate_account_handler = hs.get_deactivate_account_handler()
self.registration_handler = hs.get_registration_handler() self.registration_handler = hs.get_registration_handler()
self.pusher_pool = hs.get_pusherpool() self.pusher_pool = hs.get_pusherpool()
self._msc3866_enabled = hs.config.experimental.msc3866.enabled
async def on_GET( async def on_GET(
self, request: SynapseRequest, user_id: str self, request: SynapseRequest, user_id: str
@ -239,6 +262,15 @@ class UserRestServletV2(RestServlet):
HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean" HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean"
) )
approved: Optional[bool] = None
if "approved" in body and self._msc3866_enabled:
approved = body["approved"]
if not isinstance(approved, bool):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"'approved' parameter is not of type boolean",
)
# convert List[Dict[str, str]] into List[Tuple[str, str]] # convert List[Dict[str, str]] into List[Tuple[str, str]]
if external_ids is not None: if external_ids is not None:
new_external_ids = [ new_external_ids = [
@ -343,6 +375,9 @@ class UserRestServletV2(RestServlet):
if "user_type" in body: if "user_type" in body:
await self.store.set_user_type(target_user, user_type) await self.store.set_user_type(target_user, user_type)
if approved is not None:
await self.store.update_user_approval_status(target_user, approved)
user = await self.admin_handler.get_user(target_user) user = await self.admin_handler.get_user(target_user)
assert user is not None assert user is not None
@ -355,6 +390,10 @@ class UserRestServletV2(RestServlet):
if password is not None: if password is not None:
password_hash = await self.auth_handler.hash(password) password_hash = await self.auth_handler.hash(password)
new_user_approved = True
if self._msc3866_enabled and approved is not None:
new_user_approved = approved
user_id = await self.registration_handler.register_user( user_id = await self.registration_handler.register_user(
localpart=target_user.localpart, localpart=target_user.localpart,
password_hash=password_hash, password_hash=password_hash,
@ -362,6 +401,7 @@ class UserRestServletV2(RestServlet):
default_display_name=displayname, default_display_name=displayname,
user_type=user_type, user_type=user_type,
by_admin=True, by_admin=True,
approved=new_user_approved,
) )
if threepids is not None: if threepids is not None:
@ -550,6 +590,7 @@ class UserRegisterServlet(RestServlet):
user_type=user_type, user_type=user_type,
default_display_name=displayname, default_display_name=displayname,
by_admin=True, by_admin=True,
approved=True,
) )
result = await register._create_registration_details(user_id, body) result = await register._create_registration_details(user_id, body)

View File

@ -28,7 +28,14 @@ from typing import (
from typing_extensions import TypedDict from typing_extensions import TypedDict
from synapse.api.errors import Codes, InvalidClientTokenError, LoginError, SynapseError from synapse.api.constants import ApprovalNoticeMedium
from synapse.api.errors import (
Codes,
InvalidClientTokenError,
LoginError,
NotApprovedError,
SynapseError,
)
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.api.urls import CLIENT_API_PREFIX from synapse.api.urls import CLIENT_API_PREFIX
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
@ -55,11 +62,11 @@ logger = logging.getLogger(__name__)
class LoginResponse(TypedDict, total=False): class LoginResponse(TypedDict, total=False):
user_id: str user_id: str
access_token: str access_token: Optional[str]
home_server: str home_server: str
expires_in_ms: Optional[int] expires_in_ms: Optional[int]
refresh_token: Optional[str] refresh_token: Optional[str]
device_id: str device_id: Optional[str]
well_known: Optional[Dict[str, Any]] well_known: Optional[Dict[str, Any]]
@ -92,6 +99,12 @@ class LoginRestServlet(RestServlet):
hs.config.registration.refreshable_access_token_lifetime is not None hs.config.registration.refreshable_access_token_lifetime is not None
) )
# Whether we need to check if the user has been approved or not.
self._require_approval = (
hs.config.experimental.msc3866.enabled
and hs.config.experimental.msc3866.require_approval_for_new_accounts
)
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -220,6 +233,14 @@ class LoginRestServlet(RestServlet):
except KeyError: except KeyError:
raise SynapseError(400, "Missing JSON keys.") raise SynapseError(400, "Missing JSON keys.")
if self._require_approval:
approved = await self.auth_handler.is_user_approved(result["user_id"])
if not approved:
raise NotApprovedError(
msg="This account is pending approval by a server administrator.",
approval_notice_medium=ApprovalNoticeMedium.NONE,
)
well_known_data = self._well_known_builder.get_well_known() well_known_data = self._well_known_builder.get_well_known()
if well_known_data: if well_known_data:
result["well_known"] = well_known_data result["well_known"] = well_known_data
@ -356,6 +377,16 @@ class LoginRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM, errcode=Codes.INVALID_PARAM,
) )
if self._require_approval:
approved = await self.auth_handler.is_user_approved(user_id)
if not approved:
# If the user isn't approved (and needs to be) we won't allow them to
# actually log in, so we don't want to create a device/access token.
return LoginResponse(
user_id=user_id,
home_server=self.hs.hostname,
)
initial_display_name = login_submission.get("initial_device_display_name") initial_display_name = login_submission.get("initial_device_display_name")
( (
device_id, device_id,

View File

@ -21,10 +21,15 @@ from twisted.web.server import Request
import synapse import synapse
import synapse.api.auth import synapse.api.auth
import synapse.types import synapse.types
from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.constants import (
APP_SERVICE_REGISTRATION_TYPE,
ApprovalNoticeMedium,
LoginType,
)
from synapse.api.errors import ( from synapse.api.errors import (
Codes, Codes,
InteractiveAuthIncompleteError, InteractiveAuthIncompleteError,
NotApprovedError,
SynapseError, SynapseError,
ThreepidValidationError, ThreepidValidationError,
UnrecognizedRequestError, UnrecognizedRequestError,
@ -414,6 +419,11 @@ class RegisterRestServlet(RestServlet):
hs.config.registration.inhibit_user_in_use_error hs.config.registration.inhibit_user_in_use_error
) )
self._require_approval = (
hs.config.experimental.msc3866.enabled
and hs.config.experimental.msc3866.require_approval_for_new_accounts
)
self._registration_flows = _calculate_registration_flows( self._registration_flows = _calculate_registration_flows(
hs.config, self.auth_handler hs.config, self.auth_handler
) )
@ -734,6 +744,12 @@ class RegisterRestServlet(RestServlet):
access_token=return_dict.get("access_token"), access_token=return_dict.get("access_token"),
) )
if self._require_approval:
raise NotApprovedError(
msg="This account needs to be approved by an administrator before it can be used.",
approval_notice_medium=ApprovalNoticeMedium.NONE,
)
return 200, return_dict return 200, return_dict
async def _do_appservice_registration( async def _do_appservice_registration(
@ -778,7 +794,9 @@ class RegisterRestServlet(RestServlet):
"user_id": user_id, "user_id": user_id,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
} }
if not params.get("inhibit_login", False): # We don't want to log the user in if we're going to deny them access because
# they need to be approved first.
if not params.get("inhibit_login", False) and not self._require_approval:
device_id = params.get("device_id") device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name") initial_display_name = params.get("initial_device_display_name")
( (

View File

@ -203,6 +203,7 @@ class DataStore(
deactivated: bool = False, deactivated: bool = False,
order_by: str = UserSortOrder.USER_ID.value, order_by: str = UserSortOrder.USER_ID.value,
direction: str = "f", direction: str = "f",
approved: bool = True,
) -> Tuple[List[JsonDict], int]: ) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of users from """Function to retrieve a paginated list of users from
users list. This will return a json list of users and the users list. This will return a json list of users and the
@ -217,6 +218,7 @@ class DataStore(
deactivated: whether to include deactivated users deactivated: whether to include deactivated users
order_by: the sort order of the returned list order_by: the sort order of the returned list
direction: sort ascending or descending direction: sort ascending or descending
approved: whether to include approved users
Returns: Returns:
A tuple of a list of mappings from user to information and a count of total users. A tuple of a list of mappings from user to information and a count of total users.
""" """
@ -249,6 +251,11 @@ class DataStore(
if not deactivated: if not deactivated:
filters.append("deactivated = 0") filters.append("deactivated = 0")
if not approved:
# We ignore NULL values for the approved flag because these should only
# be already existing users that we consider as already approved.
filters.append("approved IS FALSE")
where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else "" where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
sql_base = f""" sql_base = f"""
@ -262,7 +269,7 @@ class DataStore(
sql = f""" sql = f"""
SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, SELECT name, user_type, is_guest, admin, deactivated, shadow_banned,
displayname, avatar_url, creation_ts * 1000 as creation_ts displayname, avatar_url, creation_ts * 1000 as creation_ts, approved
{sql_base} {sql_base}
ORDER BY {order_by_column} {order}, u.name ASC ORDER BY {order_by_column} {order}, u.name ASC
LIMIT ? OFFSET ? LIMIT ? OFFSET ?

View File

@ -166,27 +166,49 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
@cached() @cached()
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
"""Deprecated: use get_userinfo_by_id instead""" """Deprecated: use get_userinfo_by_id instead"""
return await self.db_pool.simple_select_one(
table="users", def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
keyvalues={"name": user_id}, # We could technically use simple_select_one here, but it would not perform
retcols=[ # the COALESCEs (unless hacked into the column names), which could yield
"name", # confusing results.
"password_hash", txn.execute(
"is_guest", """
"admin", SELECT
"consent_version", name, password_hash, is_guest, admin, consent_version, consent_ts,
"consent_ts", consent_server_notice_sent, appservice_id, creation_ts, user_type,
"consent_server_notice_sent", deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
"appservice_id", COALESCE(approved, TRUE) AS approved
"creation_ts", FROM users
"user_type", WHERE name = ?
"deactivated", """,
"shadow_banned", (user_id,),
],
allow_none=True,
desc="get_user_by_id",
) )
rows = self.db_pool.cursor_to_dict(txn)
if len(rows) == 0:
return None
return rows[0]
row = await self.db_pool.runInteraction(
desc="get_user_by_id",
func=get_user_by_id_txn,
)
if row is not None:
# If we're using SQLite our boolean values will be integers. Because we
# present some of this data as is to e.g. server admins via REST APIs, we
# want to make sure we're returning the right type of data.
# Note: when adding a column name to this list, be wary of NULLable columns,
# since NULL values will be turned into False.
boolean_columns = ["admin", "deactivated", "shadow_banned", "approved"]
for column in boolean_columns:
if not isinstance(row[column], bool):
row[column] = bool(row[column])
return row
async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]: async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Get a UserInfo object for a user by user ID. """Get a UserInfo object for a user by user ID.
@ -1779,6 +1801,40 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return res if res else False return res if res else False
@cached()
async def is_user_approved(self, user_id: str) -> bool:
"""Checks if a user is approved and therefore can be allowed to log in.
If the user's 'approved' column is NULL, we consider it as true given it means
the user was registered when support for an approval flow was either disabled
or nonexistent.
Args:
user_id: the user to check the approval status of.
Returns:
A boolean that is True if the user is approved, False otherwise.
"""
def is_user_approved_txn(txn: LoggingTransaction) -> bool:
txn.execute(
"""
SELECT COALESCE(approved, TRUE) AS approved FROM users WHERE name = ?
""",
(user_id,),
)
rows = self.db_pool.cursor_to_dict(txn)
# We cast to bool because the value returned by the database engine might
# be an integer if we're using SQLite.
return bool(rows[0]["approved"])
return await self.db_pool.runInteraction(
desc="is_user_pending_approval",
func=is_user_approved_txn,
)
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__( def __init__(
@ -1916,6 +1972,29 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,))
def update_user_approval_status_txn(
self, txn: LoggingTransaction, user_id: str, approved: bool
) -> None:
"""Set the user's 'approved' flag to the given value.
The boolean is turned into an int because the column is a smallint.
Args:
txn: the current database transaction.
user_id: the user to update the flag for.
approved: the value to set the flag to.
"""
self.db_pool.simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
updatevalues={"approved": approved},
)
# Invalidate the caches of methods that read the value of the 'approved' flag.
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
self._invalidate_cache_and_stream(txn, self.is_user_approved, (user_id,))
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
def __init__( def __init__(
@ -1933,6 +2012,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
# If support for MSC3866 is enabled and configured to require approval for new
# account, we will create new users with an 'approved' flag set to false.
self._require_approval = (
hs.config.experimental.msc3866.enabled
and hs.config.experimental.msc3866.require_approval_for_new_accounts
)
async def add_access_token_to_user( async def add_access_token_to_user(
self, self,
user_id: str, user_id: str,
@ -2065,6 +2151,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
admin: bool = False, admin: bool = False,
user_type: Optional[str] = None, user_type: Optional[str] = None,
shadow_banned: bool = False, shadow_banned: bool = False,
approved: bool = False,
) -> None: ) -> None:
"""Attempts to register an account. """Attempts to register an account.
@ -2083,6 +2170,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
or None for a normal user. or None for a normal user.
shadow_banned: Whether the user is shadow-banned, i.e. they may be shadow_banned: Whether the user is shadow-banned, i.e. they may be
told their requests succeeded but we ignore them. told their requests succeeded but we ignore them.
approved: Whether to consider the user has already been approved by an
administrator.
Raises: Raises:
StoreError if the user_id could not be registered. StoreError if the user_id could not be registered.
@ -2099,6 +2188,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
admin, admin,
user_type, user_type,
shadow_banned, shadow_banned,
approved,
) )
def _register_user( def _register_user(
@ -2113,11 +2203,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
admin: bool, admin: bool,
user_type: Optional[str], user_type: Optional[str],
shadow_banned: bool, shadow_banned: bool,
approved: bool,
) -> None: ) -> None:
user_id_obj = UserID.from_string(user_id) user_id_obj = UserID.from_string(user_id)
now = int(self._clock.time()) now = int(self._clock.time())
user_approved = approved or not self._require_approval
try: try:
if was_guest: if was_guest:
# Ensure that the guest user actually exists # Ensure that the guest user actually exists
@ -2143,6 +2236,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"admin": 1 if admin else 0, "admin": 1 if admin else 0,
"user_type": user_type, "user_type": user_type,
"shadow_banned": shadow_banned, "shadow_banned": shadow_banned,
"approved": user_approved,
}, },
) )
else: else:
@ -2158,6 +2252,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"admin": 1 if admin else 0, "admin": 1 if admin else 0,
"user_type": user_type, "user_type": user_type,
"shadow_banned": shadow_banned, "shadow_banned": shadow_banned,
"approved": user_approved,
}, },
) )
@ -2503,6 +2598,25 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
start_or_continue_validation_session_txn, start_or_continue_validation_session_txn,
) )
async def update_user_approval_status(
self, user_id: UserID, approved: bool
) -> None:
"""Set the user's 'approved' flag to the given value.
The boolean will be turned into an int (in update_user_approval_status_txn)
because the column is a smallint.
Args:
user_id: the user to update the flag for.
approved: the value to set the flag to.
"""
await self.db_pool.runInteraction(
"update_user_approval_status",
self.update_user_approval_status_txn,
user_id.to_string(),
approved,
)
def find_max_generated_user_id_localpart(cur: Cursor) -> int: def find_max_generated_user_id_localpart(cur: Cursor) -> int:
""" """

View File

@ -0,0 +1,20 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Add a column to the users table to track whether the user needs to be approved by an
-- administrator.
-- A NULL column means the user was created before this feature was supported by Synapse,
-- and should be considered as TRUE.
ALTER TABLE users ADD COLUMN approved BOOLEAN;

View File

@ -25,10 +25,10 @@ from parameterized import parameterized, parameterized_class
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import UserTypes from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.rest.client import devices, login, logout, profile, room, sync from synapse.rest.client import devices, login, logout, profile, register, room, sync
from synapse.rest.media.v1.filepath import MediaFilePaths from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, UserID
@ -578,6 +578,16 @@ class UsersListTestCase(unittest.HomeserverTestCase):
_search_test(None, "foo", "user_id") _search_test(None, "foo", "user_id")
_search_test(None, "bar", "user_id") _search_test(None, "bar", "user_id")
@override_config(
{
"experimental_features": {
"msc3866": {
"enabled": True,
"require_approval_for_new_accounts": True,
}
}
}
)
def test_invalid_parameter(self) -> None: def test_invalid_parameter(self) -> None:
""" """
If parameters are invalid, an error is returned. If parameters are invalid, an error is returned.
@ -623,6 +633,16 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid approved
channel = self.make_request(
"GET",
self.url + "?approved=not_bool",
access_token=self.admin_user_tok,
)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# unkown order_by # unkown order_by
channel = self.make_request( channel = self.make_request(
"GET", "GET",
@ -841,6 +861,69 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self._order_test([self.admin_user, user1, user2], "creation_ts", "f") self._order_test([self.admin_user, user1, user2], "creation_ts", "f")
self._order_test([user2, user1, self.admin_user], "creation_ts", "b") self._order_test([user2, user1, self.admin_user], "creation_ts", "b")
@override_config(
{
"experimental_features": {
"msc3866": {
"enabled": True,
"require_approval_for_new_accounts": True,
}
}
}
)
def test_filter_out_approved(self) -> None:
"""Tests that the endpoint can filter out approved users."""
# Create our users.
self._create_users(2)
# Get the list of users.
channel = self.make_request(
"GET",
self.url,
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, channel.result)
# Exclude the admin, because we don't want to accidentally un-approve the admin.
non_admin_user_ids = [
user["name"]
for user in channel.json_body["users"]
if user["name"] != self.admin_user
]
self.assertEqual(2, len(non_admin_user_ids), non_admin_user_ids)
# Select a user and un-approve them. We do this rather than the other way around
# because, since these users are created by an admin, we consider them already
# approved.
not_approved_user = non_admin_user_ids[0]
channel = self.make_request(
"PUT",
f"/_synapse/admin/v2/users/{not_approved_user}",
{"approved": False},
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, channel.result)
# Now get the list of users again, this time filtering out approved users.
channel = self.make_request(
"GET",
self.url + "?approved=false",
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, channel.result)
non_admin_user_ids = [
user["name"]
for user in channel.json_body["users"]
if user["name"] != self.admin_user
]
# We should only have our unapproved user now.
self.assertEqual(1, len(non_admin_user_ids), non_admin_user_ids)
self.assertEqual(not_approved_user, non_admin_user_ids[0])
def _order_test( def _order_test(
self, self,
expected_user_list: List[str], expected_user_list: List[str],
@ -1272,6 +1355,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets,
login.register_servlets, login.register_servlets,
sync.register_servlets, sync.register_servlets,
register.register_servlets,
] ]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@ -2536,6 +2620,104 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Ensure they're still alive # Ensure they're still alive
self.assertEqual(0, channel.json_body["deactivated"]) self.assertEqual(0, channel.json_body["deactivated"])
@override_config(
{
"experimental_features": {
"msc3866": {
"enabled": True,
"require_approval_for_new_accounts": True,
}
}
}
)
def test_approve_account(self) -> None:
"""Tests that approving an account correctly sets the approved flag for the user."""
url = self.url_prefix % "@bob:test"
# Create the user using the client-server API since otherwise the user will be
# marked as approved automatically.
channel = self.make_request(
"POST",
"register",
{
"username": "bob",
"password": "test",
"auth": {"type": LoginType.DUMMY},
},
)
self.assertEqual(403, channel.code, channel.result)
self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
self.assertEqual(
ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
)
# Get user
channel = self.make_request(
"GET",
url,
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIs(False, channel.json_body["approved"])
# Approve user
channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
content={"approved": True},
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIs(True, channel.json_body["approved"])
# Check that the user is now approved
channel = self.make_request(
"GET",
url,
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIs(True, channel.json_body["approved"])
@override_config(
{
"experimental_features": {
"msc3866": {
"enabled": True,
"require_approval_for_new_accounts": True,
}
}
}
)
def test_register_approved(self) -> None:
url = self.url_prefix % "@bob:test"
# Create user
channel = self.make_request(
"PUT",
url,
access_token=self.admin_user_tok,
content={"password": "abc123", "approved": True},
)
self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual(1, channel.json_body["approved"])
# Get user
channel = self.make_request(
"GET",
url,
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual(1, channel.json_body["approved"])
def _is_erased(self, user_id: str, expect: bool) -> None: def _is_erased(self, user_id: str, expect: bool) -> None:
"""Assert that the user is erased or not""" """Assert that the user is erased or not"""
d = self.store.is_user_erased(user_id) d = self.store.is_user_erased(user_id)

View File

@ -20,7 +20,8 @@ from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource from twisted.web.resource import Resource
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import LoginType from synapse.api.constants import ApprovalNoticeMedium, LoginType
from synapse.api.errors import Codes
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.rest.client import account, auth, devices, login, logout, register from synapse.rest.client import account, auth, devices, login, logout, register
from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.rest.synapse.client import build_synapse_client_resource_tree
@ -567,6 +568,36 @@ class UIAuthTests(unittest.HomeserverTestCase):
body={"auth": {"session": session_id}}, body={"auth": {"session": session_id}},
) )
@skip_unless(HAS_OIDC, "requires OIDC")
@override_config(
{
"oidc_config": TEST_OIDC_CONFIG,
"experimental_features": {
"msc3866": {
"enabled": True,
"require_approval_for_new_accounts": True,
}
},
}
)
def test_sso_not_approved(self) -> None:
"""Tests that if we register a user via SSO while requiring approval for new
accounts, we still raise the correct error before logging the user in.
"""
login_resp = self.helper.login_via_oidc("username", expected_status=403)
self.assertEqual(login_resp["errcode"], Codes.USER_AWAITING_APPROVAL)
self.assertEqual(
ApprovalNoticeMedium.NONE, login_resp["approval_notice_medium"]
)
# Check that we didn't register a device for the user during the login attempt.
devices = self.get_success(
self.hs.get_datastores().main.get_devices_by_user("@username:test")
)
self.assertEqual(len(devices), 0)
class RefreshAuthTests(unittest.HomeserverTestCase): class RefreshAuthTests(unittest.HomeserverTestCase):
servlets = [ servlets = [

View File

@ -23,6 +23,8 @@ from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource from twisted.web.resource import Resource
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import ApprovalNoticeMedium, LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.rest.client import devices, login, logout, register from synapse.rest.client import devices, login, logout, register
from synapse.rest.client.account import WhoamiRestServlet from synapse.rest.client.account import WhoamiRestServlet
@ -94,6 +96,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
logout.register_servlets, logout.register_servlets,
devices.register_servlets, devices.register_servlets,
lambda hs, http_server: WhoamiRestServlet(hs).register(http_server), lambda hs, http_server: WhoamiRestServlet(hs).register(http_server),
register.register_servlets,
] ]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
@ -406,6 +409,44 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400) self.assertEqual(channel.code, 400)
self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM") self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM")
@override_config(
{
"experimental_features": {
"msc3866": {
"enabled": True,
"require_approval_for_new_accounts": True,
}
}
}
)
def test_require_approval(self) -> None:
channel = self.make_request(
"POST",
"register",
{
"username": "kermit",
"password": "monkey",
"auth": {"type": LoginType.DUMMY},
},
)
self.assertEqual(403, channel.code, channel.result)
self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
self.assertEqual(
ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
)
params = {
"type": LoginType.PASSWORD,
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
channel = self.make_request("POST", LOGIN_URL, params)
self.assertEqual(403, channel.code, channel.result)
self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
self.assertEqual(
ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
)
@skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC") @skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC")
class MultiSSOTestCase(unittest.HomeserverTestCase): class MultiSSOTestCase(unittest.HomeserverTestCase):

View File

@ -22,7 +22,11 @@ import pkg_resources
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.constants import (
APP_SERVICE_REGISTRATION_TYPE,
ApprovalNoticeMedium,
LoginType,
)
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.rest.client import account, account_validity, login, logout, register, sync from synapse.rest.client import account, account_validity, login, logout, register, sync
@ -765,6 +769,32 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.json_body) self.assertEqual(channel.code, 400, channel.json_body)
self.assertEqual(channel.json_body["errcode"], Codes.USER_IN_USE) self.assertEqual(channel.json_body["errcode"], Codes.USER_IN_USE)
@override_config(
{
"experimental_features": {
"msc3866": {
"enabled": True,
"require_approval_for_new_accounts": True,
}
}
}
)
def test_require_approval(self) -> None:
channel = self.make_request(
"POST",
"register",
{
"username": "kermit",
"password": "monkey",
"auth": {"type": LoginType.DUMMY},
},
)
self.assertEqual(403, channel.code, channel.result)
self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"])
self.assertEqual(
ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"]
)
class AccountValidityTestCase(unittest.HomeserverTestCase): class AccountValidityTestCase(unittest.HomeserverTestCase):

View File

@ -543,8 +543,12 @@ class RestHelper:
return channel.json_body return channel.json_body
def login_via_oidc(self, remote_user_id: str) -> JsonDict: def login_via_oidc(
"""Log in (as a new user) via OIDC self,
remote_user_id: str,
expected_status: int = 200,
) -> JsonDict:
"""Log in via OIDC
Returns the result of the final token login. Returns the result of the final token login.
@ -578,7 +582,9 @@ class RestHelper:
"/login", "/login",
content={"type": "m.login.token", "token": login_token}, content={"type": "m.login.token", "token": login_token},
) )
assert channel.code == HTTPStatus.OK assert (
channel.code == expected_status
), f"unexpected status in response: {channel.code}"
return channel.json_body return channel.json_body
def auth_via_oidc( def auth_via_oidc(

View File

@ -16,9 +16,10 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import ThreepidValidationError from synapse.api.errors import ThreepidValidationError
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict, UserID
from synapse.util import Clock from synapse.util import Clock
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase, override_config
class RegistrationStoreTestCase(HomeserverTestCase): class RegistrationStoreTestCase(HomeserverTestCase):
@ -48,6 +49,7 @@ class RegistrationStoreTestCase(HomeserverTestCase):
"user_type": None, "user_type": None,
"deactivated": 0, "deactivated": 0,
"shadow_banned": 0, "shadow_banned": 0,
"approved": 1,
}, },
(self.get_success(self.store.get_user_by_id(self.user_id))), (self.get_success(self.store.get_user_by_id(self.user_id))),
) )
@ -166,3 +168,101 @@ class RegistrationStoreTestCase(HomeserverTestCase):
ThreepidValidationError, ThreepidValidationError,
) )
self.assertEqual(e.value.msg, "Validation token not found or has expired", e) self.assertEqual(e.value.msg, "Validation token not found or has expired", e)
class ApprovalRequiredRegistrationTestCase(HomeserverTestCase):
def default_config(self) -> JsonDict:
config = super().default_config()
# If there's already some config for this feature in the default config, it
# means we're overriding it with @override_config. In this case we don't want
# to do anything more with it.
msc3866_config = config.get("experimental_features", {}).get("msc3866")
if msc3866_config is not None:
return config
# Require approval for all new accounts.
config["experimental_features"] = {
"msc3866": {
"enabled": True,
"require_approval_for_new_accounts": True,
}
}
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.user_id = "@my-user:test"
self.pwhash = "{xx1}123456789"
@override_config(
{
"experimental_features": {
"msc3866": {
"enabled": True,
"require_approval_for_new_accounts": False,
}
}
}
)
def test_approval_not_required(self) -> None:
"""Tests that if we don't require approval for new accounts, newly created
accounts are automatically marked as approved.
"""
self.get_success(self.store.register_user(self.user_id, self.pwhash))
user = self.get_success(self.store.get_user_by_id(self.user_id))
assert user is not None
self.assertTrue(user["approved"])
approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertTrue(approved)
def test_approval_required(self) -> None:
"""Tests that if we require approval for new accounts, newly created accounts
are not automatically marked as approved.
"""
self.get_success(self.store.register_user(self.user_id, self.pwhash))
user = self.get_success(self.store.get_user_by_id(self.user_id))
assert user is not None
self.assertFalse(user["approved"])
approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertFalse(approved)
def test_override(self) -> None:
"""Tests that if we require approval for new accounts, but we explicitly say the
new user should be considered approved, they're marked as approved.
"""
self.get_success(
self.store.register_user(
self.user_id,
self.pwhash,
approved=True,
)
)
user = self.get_success(self.store.get_user_by_id(self.user_id))
self.assertIsNotNone(user)
assert user is not None
self.assertEqual(user["approved"], 1)
approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertTrue(approved)
def test_approve_user(self) -> None:
"""Tests that approving the user updates their approval status."""
self.get_success(self.store.register_user(self.user_id, self.pwhash))
approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertFalse(approved)
self.get_success(
self.store.update_user_approval_status(
UserID.from_string(self.user_id), True
)
)
approved = self.get_success(self.store.is_user_approved(self.user_id))
self.assertTrue(approved)