Port handlers.account_validity to async/await.

This commit is contained in:
Erik Johnston 2019-12-10 11:13:15 +00:00
parent 353396e3a7
commit 257ef2c727
3 changed files with 42 additions and 49 deletions

View File

@ -25,7 +25,7 @@ class AccountDataEventSource(object):
user_id = user.to_string() user_id = user.to_string()
last_stream_id = from_key last_stream_id = from_key
current_stream_id = await self.store.get_max_account_data_stream_id() current_stream_id = self.store.get_max_account_data_stream_id()
results = [] results = []
tags = await self.store.get_updated_tags(user_id, last_stream_id) tags = await self.store.get_updated_tags(user_id, last_stream_id)

View File

@ -18,8 +18,7 @@ import email.utils
import logging import logging
from email.mime.multipart import MIMEMultipart from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText from email.mime.text import MIMEText
from typing import List
from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
@ -78,42 +77,39 @@ class AccountValidityHandler(object):
# run as a background process to make sure that the database transactions # run as a background process to make sure that the database transactions
# have a logcontext to report to # have a logcontext to report to
return run_as_background_process( return run_as_background_process(
"send_renewals", self.send_renewal_emails "send_renewals", self._send_renewal_emails
) )
self.clock.looping_call(send_emails, 30 * 60 * 1000) self.clock.looping_call(send_emails, 30 * 60 * 1000)
@defer.inlineCallbacks async def _send_renewal_emails(self):
def send_renewal_emails(self):
"""Gets the list of users whose account is expiring in the amount of time """Gets the list of users whose account is expiring in the amount of time
configured in the ``renew_at`` parameter from the ``account_validity`` configured in the ``renew_at`` parameter from the ``account_validity``
configuration, and sends renewal emails to all of these users as long as they configuration, and sends renewal emails to all of these users as long as they
have an email 3PID attached to their account. have an email 3PID attached to their account.
""" """
expiring_users = yield self.store.get_users_expiring_soon() expiring_users = await self.store.get_users_expiring_soon()
if expiring_users: if expiring_users:
for user in expiring_users: for user in expiring_users:
yield self._send_renewal_email( await self._send_renewal_email(
user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"] user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
) )
@defer.inlineCallbacks async def send_renewal_email_to_user(self, user_id: str):
def send_renewal_email_to_user(self, user_id): expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
expiration_ts = yield self.store.get_expiration_ts_for_user(user_id) await self._send_renewal_email(user_id, expiration_ts)
yield self._send_renewal_email(user_id, expiration_ts)
@defer.inlineCallbacks async def _send_renewal_email(self, user_id: str, expiration_ts: int):
def _send_renewal_email(self, user_id, expiration_ts):
"""Sends out a renewal email to every email address attached to the given user """Sends out a renewal email to every email address attached to the given user
with a unique link allowing them to renew their account. with a unique link allowing them to renew their account.
Args: Args:
user_id (str): ID of the user to send email(s) to. user_id: ID of the user to send email(s) to.
expiration_ts (int): Timestamp in milliseconds for the expiration date of expiration_ts: Timestamp in milliseconds for the expiration date of
this user's account (used in the email templates). this user's account (used in the email templates).
""" """
addresses = yield self._get_email_addresses_for_user(user_id) addresses = await self._get_email_addresses_for_user(user_id)
# Stop right here if the user doesn't have at least one email address. # Stop right here if the user doesn't have at least one email address.
# In this case, they will have to ask their server admin to renew their # In this case, they will have to ask their server admin to renew their
@ -125,7 +121,7 @@ class AccountValidityHandler(object):
return return
try: try:
user_display_name = yield self.store.get_profile_displayname( user_display_name = await self.store.get_profile_displayname(
UserID.from_string(user_id).localpart UserID.from_string(user_id).localpart
) )
if user_display_name is None: if user_display_name is None:
@ -133,7 +129,7 @@ class AccountValidityHandler(object):
except StoreError: except StoreError:
user_display_name = user_id user_display_name = user_id
renewal_token = yield self._get_renewal_token(user_id) renewal_token = await self._get_renewal_token(user_id)
url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % ( url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % (
self.hs.config.public_baseurl, self.hs.config.public_baseurl,
renewal_token, renewal_token,
@ -165,7 +161,7 @@ class AccountValidityHandler(object):
logger.info("Sending renewal email to %s", address) logger.info("Sending renewal email to %s", address)
yield make_deferred_yieldable( await make_deferred_yieldable(
self.sendmail( self.sendmail(
self.hs.config.email_smtp_host, self.hs.config.email_smtp_host,
self._raw_from, self._raw_from,
@ -180,19 +176,18 @@ class AccountValidityHandler(object):
) )
) )
yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True) await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
@defer.inlineCallbacks async def _get_email_addresses_for_user(self, user_id: str) -> List[str]:
def _get_email_addresses_for_user(self, user_id):
"""Retrieve the list of email addresses attached to a user's account. """Retrieve the list of email addresses attached to a user's account.
Args: Args:
user_id (str): ID of the user to lookup email addresses for. user_id: ID of the user to lookup email addresses for.
Returns: Returns:
defer.Deferred[list[str]]: Email addresses for this account. Email addresses for this account.
""" """
threepids = yield self.store.user_get_threepids(user_id) threepids = await self.store.user_get_threepids(user_id)
addresses = [] addresses = []
for threepid in threepids: for threepid in threepids:
@ -201,16 +196,15 @@ class AccountValidityHandler(object):
return addresses return addresses
@defer.inlineCallbacks async def _get_renewal_token(self, user_id: str) -> str:
def _get_renewal_token(self, user_id):
"""Generates a 32-byte long random string that will be inserted into the """Generates a 32-byte long random string that will be inserted into the
user's renewal email's unique link, then saves it into the database. user's renewal email's unique link, then saves it into the database.
Args: Args:
user_id (str): ID of the user to generate a string for. user_id: ID of the user to generate a string for.
Returns: Returns:
defer.Deferred[str]: The generated string. The generated string.
Raises: Raises:
StoreError(500): Couldn't generate a unique string after 5 attempts. StoreError(500): Couldn't generate a unique string after 5 attempts.
@ -219,52 +213,52 @@ class AccountValidityHandler(object):
while attempts < 5: while attempts < 5:
try: try:
renewal_token = stringutils.random_string(32) renewal_token = stringutils.random_string(32)
yield self.store.set_renewal_token_for_user(user_id, renewal_token) await self.store.set_renewal_token_for_user(user_id, renewal_token)
return renewal_token return renewal_token
except StoreError: except StoreError:
attempts += 1 attempts += 1
raise StoreError(500, "Couldn't generate a unique string as refresh string.") raise StoreError(500, "Couldn't generate a unique string as refresh string.")
@defer.inlineCallbacks async def renew_account(self, renewal_token: str) -> bool:
def renew_account(self, renewal_token):
"""Renews the account attached to a given renewal token by pushing back the """Renews the account attached to a given renewal token by pushing back the
expiration date by the current validity period in the server's configuration. expiration date by the current validity period in the server's configuration.
Args: Args:
renewal_token (str): Token sent with the renewal request. renewal_token: Token sent with the renewal request.
Returns: Returns:
bool: Whether the provided token is valid. Whether the provided token is valid.
""" """
try: try:
user_id = yield self.store.get_user_from_renewal_token(renewal_token) user_id = await self.store.get_user_from_renewal_token(renewal_token)
except StoreError: except StoreError:
defer.returnValue(False) return False
logger.debug("Renewing an account for user %s", user_id) logger.debug("Renewing an account for user %s", user_id)
yield self.renew_account_for_user(user_id) await self.renew_account_for_user(user_id)
defer.returnValue(True) return True
@defer.inlineCallbacks async def renew_account_for_user(
def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False): self, user_id: str, expiration_ts: int = None, email_sent: bool = False
) -> int:
"""Renews the account attached to a given user by pushing back the """Renews the account attached to a given user by pushing back the
expiration date by the current validity period in the server's expiration date by the current validity period in the server's
configuration. configuration.
Args: Args:
renewal_token (str): Token sent with the renewal request. renewal_token: Token sent with the renewal request.
expiration_ts (int): New expiration date. Defaults to now + validity period. expiration_ts: New expiration date. Defaults to now + validity period.
email_sent (bool): Whether an email has been sent for this validity period. email_sen: Whether an email has been sent for this validity period.
Defaults to False. Defaults to False.
Returns: Returns:
defer.Deferred[int]: New expiration date for this account, as a timestamp New expiration date for this account, as a timestamp in
in milliseconds since epoch. milliseconds since epoch.
""" """
if expiration_ts is None: if expiration_ts is None:
expiration_ts = self.clock.time_msec() + self._account_validity.period expiration_ts = self.clock.time_msec() + self._account_validity.period
yield self.store.set_account_validity_for_user( await self.store.set_account_validity_for_user(
user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
) )

View File

@ -391,9 +391,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# Email config. # Email config.
self.email_attempts = [] self.email_attempts = []
def sendmail(*args, **kwargs): async def sendmail(*args, **kwargs):
self.email_attempts.append((args, kwargs)) self.email_attempts.append((args, kwargs))
return
config["email"] = { config["email"] = {
"enable_notifs": True, "enable_notifs": True,