Port "Allow users to click account renewal links multiple times without hitting an 'Invalid Token' page #74" from synapse-dinsic (#9832)

This attempts to be a direct port of https://github.com/matrix-org/synapse-dinsic/pull/74 to mainline. There was some fiddling required to deal with the changes that have been made to mainline since (mainly dealing with the split of `RegistrationWorkerStore` from `RegistrationStore`, and the changes made to `self.make_request` in test code).
This commit is contained in:
Andrew Morgan 2021-04-19 19:16:34 +01:00 committed by GitHub
parent e694a598f8
commit 71f0623de9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 496 additions and 263 deletions

View file

@ -91,12 +91,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
id_column=None,
)
self._account_validity = hs.config.account_validity
if hs.config.run_background_tasks and self._account_validity.enabled:
self._clock.call_later(
0.0,
self._set_expiration_date_when_missing,
self._account_validity_enabled = (
hs.config.account_validity.account_validity_enabled
)
self._account_validity_period = None
self._account_validity_startup_job_max_delta = None
if self._account_validity_enabled:
self._account_validity_period = (
hs.config.account_validity.account_validity_period
)
self._account_validity_startup_job_max_delta = (
hs.config.account_validity.account_validity_startup_job_max_delta
)
if hs.config.run_background_tasks:
self._clock.call_later(
0.0,
self._set_expiration_date_when_missing,
)
# Create a background job for culling expired 3PID validity tokens
if hs.config.run_background_tasks:
@ -194,6 +206,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
expiration_ts: int,
email_sent: bool,
renewal_token: Optional[str] = None,
token_used_ts: Optional[int] = None,
) -> None:
"""Updates the account validity properties of the given account, with the
given values.
@ -207,6 +220,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
period.
renewal_token: Renewal token the user can use to extend the validity
of their account. Defaults to no token.
token_used_ts: A timestamp of when the current token was used to renew
the account.
"""
def set_account_validity_for_user_txn(txn):
@ -218,6 +233,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"expiration_ts_ms": expiration_ts,
"email_sent": email_sent,
"renewal_token": renewal_token,
"token_used_ts_ms": token_used_ts,
},
)
self._invalidate_cache_and_stream(
@ -231,7 +247,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
async def set_renewal_token_for_user(
self, user_id: str, renewal_token: str
) -> None:
"""Defines a renewal token for a given user.
"""Defines a renewal token for a given user, and clears the token_used timestamp.
Args:
user_id: ID of the user to set the renewal token for.
@ -244,26 +260,40 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.simple_update_one(
table="account_validity",
keyvalues={"user_id": user_id},
updatevalues={"renewal_token": renewal_token},
updatevalues={"renewal_token": renewal_token, "token_used_ts_ms": None},
desc="set_renewal_token_for_user",
)
async def get_user_from_renewal_token(self, renewal_token: str) -> str:
"""Get a user ID from a renewal token.
async def get_user_from_renewal_token(
self, renewal_token: str
) -> Tuple[str, int, Optional[int]]:
"""Get a user ID and renewal status from a renewal token.
Args:
renewal_token: The renewal token to perform the lookup with.
Returns:
The ID of the user to which the token belongs.
A tuple of containing the following values:
* The ID of a user to which the token belongs.
* An int representing the user's expiry timestamp as milliseconds since the
epoch, or 0 if the token was invalid.
* An optional int representing the timestamp of when the user renewed their
account timestamp as milliseconds since the epoch. None if the account
has not been renewed using the current token yet.
"""
return await self.db_pool.simple_select_one_onecol(
ret_dict = await self.db_pool.simple_select_one(
table="account_validity",
keyvalues={"renewal_token": renewal_token},
retcol="user_id",
retcols=["user_id", "expiration_ts_ms", "token_used_ts_ms"],
desc="get_user_from_renewal_token",
)
return (
ret_dict["user_id"],
ret_dict["expiration_ts_ms"],
ret_dict["token_used_ts_ms"],
)
async def get_renewal_token_for_user(self, user_id: str) -> str:
"""Get the renewal token associated with a given user ID.
@ -302,7 +332,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"get_users_expiring_soon",
select_users_txn,
self._clock.time_msec(),
self.config.account_validity.renew_at,
self.config.account_validity_renew_at,
)
async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None:
@ -964,11 +994,11 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
delta equal to 10% of the validity period.
"""
now_ms = self._clock.time_msec()
expiration_ts = now_ms + self._account_validity.period
expiration_ts = now_ms + self._account_validity_period
if use_delta:
expiration_ts = self.rand.randrange(
expiration_ts - self._account_validity.startup_job_max_delta,
expiration_ts - self._account_validity_startup_job_max_delta,
expiration_ts,
)
@ -1412,7 +1442,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
except self.database_engine.module.IntegrityError:
raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
if self._account_validity.enabled:
if self._account_validity_enabled:
self.set_expiration_date_for_user_txn(txn, user_id)
if create_profile_with_displayname: