Allow admins to require a manual approval process before new accounts can be used (using MSC3866) (#13556)

This commit is contained in:
Brendan Abolivier 2022-09-29 14:23:24 +01:00 committed by GitHub
parent 8625ad8099
commit be76cd8200
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 731 additions and 34 deletions

View file

@ -166,27 +166,49 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
@cached()
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
"""Deprecated: use get_userinfo_by_id instead"""
return await self.db_pool.simple_select_one(
table="users",
keyvalues={"name": user_id},
retcols=[
"name",
"password_hash",
"is_guest",
"admin",
"consent_version",
"consent_ts",
"consent_server_notice_sent",
"appservice_id",
"creation_ts",
"user_type",
"deactivated",
"shadow_banned",
],
allow_none=True,
def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
# We could technically use simple_select_one here, but it would not perform
# the COALESCEs (unless hacked into the column names), which could yield
# confusing results.
txn.execute(
"""
SELECT
name, password_hash, is_guest, admin, consent_version, consent_ts,
consent_server_notice_sent, appservice_id, creation_ts, user_type,
deactivated, COALESCE(shadow_banned, FALSE) AS shadow_banned,
COALESCE(approved, TRUE) AS approved
FROM users
WHERE name = ?
""",
(user_id,),
)
rows = self.db_pool.cursor_to_dict(txn)
if len(rows) == 0:
return None
return rows[0]
row = await self.db_pool.runInteraction(
desc="get_user_by_id",
func=get_user_by_id_txn,
)
if row is not None:
# If we're using SQLite our boolean values will be integers. Because we
# present some of this data as is to e.g. server admins via REST APIs, we
# want to make sure we're returning the right type of data.
# Note: when adding a column name to this list, be wary of NULLable columns,
# since NULL values will be turned into False.
boolean_columns = ["admin", "deactivated", "shadow_banned", "approved"]
for column in boolean_columns:
if not isinstance(row[column], bool):
row[column] = bool(row[column])
return row
async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
"""Get a UserInfo object for a user by user ID.
@ -1779,6 +1801,40 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return res if res else False
@cached()
async def is_user_approved(self, user_id: str) -> bool:
"""Checks if a user is approved and therefore can be allowed to log in.
If the user's 'approved' column is NULL, we consider it as true given it means
the user was registered when support for an approval flow was either disabled
or nonexistent.
Args:
user_id: the user to check the approval status of.
Returns:
A boolean that is True if the user is approved, False otherwise.
"""
def is_user_approved_txn(txn: LoggingTransaction) -> bool:
txn.execute(
"""
SELECT COALESCE(approved, TRUE) AS approved FROM users WHERE name = ?
""",
(user_id,),
)
rows = self.db_pool.cursor_to_dict(txn)
# We cast to bool because the value returned by the database engine might
# be an integer if we're using SQLite.
return bool(rows[0]["approved"])
return await self.db_pool.runInteraction(
desc="is_user_pending_approval",
func=is_user_approved_txn,
)
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(
@ -1916,6 +1972,29 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))
def update_user_approval_status_txn(
self, txn: LoggingTransaction, user_id: str, approved: bool
) -> None:
"""Set the user's 'approved' flag to the given value.
The boolean is turned into an int because the column is a smallint.
Args:
txn: the current database transaction.
user_id: the user to update the flag for.
approved: the value to set the flag to.
"""
self.db_pool.simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
updatevalues={"approved": approved},
)
# Invalidate the caches of methods that read the value of the 'approved' flag.
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
self._invalidate_cache_and_stream(txn, self.is_user_approved, (user_id,))
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
def __init__(
@ -1933,6 +2012,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
# If support for MSC3866 is enabled and configured to require approval for new
# account, we will create new users with an 'approved' flag set to false.
self._require_approval = (
hs.config.experimental.msc3866.enabled
and hs.config.experimental.msc3866.require_approval_for_new_accounts
)
async def add_access_token_to_user(
self,
user_id: str,
@ -2065,6 +2151,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
admin: bool = False,
user_type: Optional[str] = None,
shadow_banned: bool = False,
approved: bool = False,
) -> None:
"""Attempts to register an account.
@ -2083,6 +2170,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
or None for a normal user.
shadow_banned: Whether the user is shadow-banned, i.e. they may be
told their requests succeeded but we ignore them.
approved: Whether to consider the user has already been approved by an
administrator.
Raises:
StoreError if the user_id could not be registered.
@ -2099,6 +2188,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
admin,
user_type,
shadow_banned,
approved,
)
def _register_user(
@ -2113,11 +2203,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
admin: bool,
user_type: Optional[str],
shadow_banned: bool,
approved: bool,
) -> None:
user_id_obj = UserID.from_string(user_id)
now = int(self._clock.time())
user_approved = approved or not self._require_approval
try:
if was_guest:
# Ensure that the guest user actually exists
@ -2143,6 +2236,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"admin": 1 if admin else 0,
"user_type": user_type,
"shadow_banned": shadow_banned,
"approved": user_approved,
},
)
else:
@ -2158,6 +2252,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"admin": 1 if admin else 0,
"user_type": user_type,
"shadow_banned": shadow_banned,
"approved": user_approved,
},
)
@ -2503,6 +2598,25 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
start_or_continue_validation_session_txn,
)
async def update_user_approval_status(
self, user_id: UserID, approved: bool
) -> None:
"""Set the user's 'approved' flag to the given value.
The boolean will be turned into an int (in update_user_approval_status_txn)
because the column is a smallint.
Args:
user_id: the user to update the flag for.
approved: the value to set the flag to.
"""
await self.db_pool.runInteraction(
"update_user_approval_status",
self.update_user_approval_status_txn,
user_id.to_string(),
approved,
)
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
"""