Add type hints for account validity handler (#8620)

This also fixes a bug by fixing handling of an account which doesn't expire.
This commit is contained in:
Patrick Cloke 2020-10-26 14:17:31 -04:00 committed by GitHub
parent 66e6801c3e
commit 10f45d85bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 31 additions and 12 deletions

1
changelog.d/8620.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a bug where the account validity endpoint would silently fail if the user ID did not have an expiration time. It now returns a 400 error.

View File

@ -17,6 +17,7 @@ files =
synapse/federation, synapse/federation,
synapse/handlers/_base.py, synapse/handlers/_base.py,
synapse/handlers/account_data.py, synapse/handlers/account_data.py,
synapse/handlers/account_validity.py,
synapse/handlers/appservice.py, synapse/handlers/appservice.py,
synapse/handlers/auth.py, synapse/handlers/auth.py,
synapse/handlers/cas_handler.py, synapse/handlers/cas_handler.py,

View File

@ -18,19 +18,22 @@ import email.utils
import logging import logging
from email.mime.multipart import MIMEMultipart from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText from email.mime.text import MIMEText
from typing import List from typing import TYPE_CHECKING, List
from synapse.api.errors import StoreError from synapse.api.errors import StoreError, SynapseError
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.types import UserID from synapse.types import UserID
from synapse.util import stringutils from synapse.util import stringutils
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AccountValidityHandler: class AccountValidityHandler:
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.config = hs.config self.config = hs.config
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
@ -67,7 +70,7 @@ class AccountValidityHandler:
self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000) self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
@wrap_as_background_process("send_renewals") @wrap_as_background_process("send_renewals")
async def _send_renewal_emails(self): async def _send_renewal_emails(self) -> None:
"""Gets the list of users whose account is expiring in the amount of time """Gets the list of users whose account is expiring in the amount of time
configured in the ``renew_at`` parameter from the ``account_validity`` configured in the ``renew_at`` parameter from the ``account_validity``
configuration, and sends renewal emails to all of these users as long as they configuration, and sends renewal emails to all of these users as long as they
@ -81,11 +84,25 @@ class AccountValidityHandler:
user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"] user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
) )
async def send_renewal_email_to_user(self, user_id: str): async def send_renewal_email_to_user(self, user_id: str) -> None:
"""
Send a renewal email for a specific user.
Args:
user_id: The user ID to send a renewal email for.
Raises:
SynapseError if the user is not set to renew.
"""
expiration_ts = await self.store.get_expiration_ts_for_user(user_id) expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
# If this user isn't set to be expired, raise an error.
if expiration_ts is None:
raise SynapseError(400, "User has no expiration time: %s" % (user_id,))
await self._send_renewal_email(user_id, expiration_ts) await self._send_renewal_email(user_id, expiration_ts)
async def _send_renewal_email(self, user_id: str, expiration_ts: int): async def _send_renewal_email(self, user_id: str, expiration_ts: int) -> None:
"""Sends out a renewal email to every email address attached to the given user """Sends out a renewal email to every email address attached to the given user
with a unique link allowing them to renew their account. with a unique link allowing them to renew their account.

View File

@ -131,7 +131,7 @@ class ProfileHandler(BaseHandler):
profile = await self.store.get_from_remote_profile_cache(user_id) profile = await self.store.get_from_remote_profile_cache(user_id)
return profile or {} return profile or {}
async def get_displayname(self, target_user: UserID) -> str: async def get_displayname(self, target_user: UserID) -> Optional[str]:
if self.hs.is_mine(target_user): if self.hs.is_mine(target_user):
try: try:
displayname = await self.store.get_profile_displayname( displayname = await self.store.get_profile_displayname(
@ -218,7 +218,7 @@ class ProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user) await self._update_join_states(requester, target_user)
async def get_avatar_url(self, target_user: UserID) -> str: async def get_avatar_url(self, target_user: UserID) -> Optional[str]:
if self.hs.is_mine(target_user): if self.hs.is_mine(target_user):
try: try:
avatar_url = await self.store.get_profile_avatar_url( avatar_url = await self.store.get_profile_avatar_url(

View File

@ -39,7 +39,7 @@ class ProfileWorkerStore(SQLBaseStore):
avatar_url=profile["avatar_url"], display_name=profile["displayname"] avatar_url=profile["avatar_url"], display_name=profile["displayname"]
) )
async def get_profile_displayname(self, user_localpart: str) -> str: async def get_profile_displayname(self, user_localpart: str) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol( return await self.db_pool.simple_select_one_onecol(
table="profiles", table="profiles",
keyvalues={"user_id": user_localpart}, keyvalues={"user_id": user_localpart},
@ -47,7 +47,7 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_displayname", desc="get_profile_displayname",
) )
async def get_profile_avatar_url(self, user_localpart: str) -> str: async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol( return await self.db_pool.simple_select_one_onecol(
table="profiles", table="profiles",
keyvalues={"user_id": user_localpart}, keyvalues={"user_id": user_localpart},

View File

@ -240,13 +240,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="get_renewal_token_for_user", desc="get_renewal_token_for_user",
) )
async def get_users_expiring_soon(self) -> List[Dict[str, int]]: async def get_users_expiring_soon(self) -> List[Dict[str, Any]]:
"""Selects users whose account will expire in the [now, now + renew_at] time """Selects users whose account will expire in the [now, now + renew_at] time
window (see configuration for account_validity for information on what renew_at window (see configuration for account_validity for information on what renew_at
refers to). refers to).
Returns: Returns:
A list of dictionaries mapping user ID to expiration time (in milliseconds). A list of dictionaries, each with a user ID and expiration time (in milliseconds).
""" """
def select_users_txn(txn, now_ms, renew_at): def select_users_txn(txn, now_ms, renew_at):