mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Add type hints for account validity handler (#8620)
This also fixes a bug by fixing handling of an account which doesn't expire.
This commit is contained in:
parent
66e6801c3e
commit
10f45d85bb
1
changelog.d/8620.bugfix
Normal file
1
changelog.d/8620.bugfix
Normal file
@ -0,0 +1 @@
|
||||
Fix a bug where the account validity endpoint would silently fail if the user ID did not have an expiration time. It now returns a 400 error.
|
1
mypy.ini
1
mypy.ini
@ -17,6 +17,7 @@ files =
|
||||
synapse/federation,
|
||||
synapse/handlers/_base.py,
|
||||
synapse/handlers/account_data.py,
|
||||
synapse/handlers/account_validity.py,
|
||||
synapse/handlers/appservice.py,
|
||||
synapse/handlers/auth.py,
|
||||
synapse/handlers/cas_handler.py,
|
||||
|
@ -18,19 +18,22 @@ import email.utils
|
||||
import logging
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from typing import List
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.api.errors import StoreError, SynapseError
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||
from synapse.types import UserID
|
||||
from synapse.util import stringutils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AccountValidityHandler:
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.config = hs.config
|
||||
self.store = self.hs.get_datastore()
|
||||
@ -67,7 +70,7 @@ class AccountValidityHandler:
|
||||
self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
|
||||
|
||||
@wrap_as_background_process("send_renewals")
|
||||
async def _send_renewal_emails(self):
|
||||
async def _send_renewal_emails(self) -> None:
|
||||
"""Gets the list of users whose account is expiring in the amount of time
|
||||
configured in the ``renew_at`` parameter from the ``account_validity``
|
||||
configuration, and sends renewal emails to all of these users as long as they
|
||||
@ -81,11 +84,25 @@ class AccountValidityHandler:
|
||||
user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
|
||||
)
|
||||
|
||||
async def send_renewal_email_to_user(self, user_id: str):
|
||||
async def send_renewal_email_to_user(self, user_id: str) -> None:
|
||||
"""
|
||||
Send a renewal email for a specific user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to send a renewal email for.
|
||||
|
||||
Raises:
|
||||
SynapseError if the user is not set to renew.
|
||||
"""
|
||||
expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
|
||||
|
||||
# If this user isn't set to be expired, raise an error.
|
||||
if expiration_ts is None:
|
||||
raise SynapseError(400, "User has no expiration time: %s" % (user_id,))
|
||||
|
||||
await self._send_renewal_email(user_id, expiration_ts)
|
||||
|
||||
async def _send_renewal_email(self, user_id: str, expiration_ts: int):
|
||||
async def _send_renewal_email(self, user_id: str, expiration_ts: int) -> None:
|
||||
"""Sends out a renewal email to every email address attached to the given user
|
||||
with a unique link allowing them to renew their account.
|
||||
|
||||
|
@ -131,7 +131,7 @@ class ProfileHandler(BaseHandler):
|
||||
profile = await self.store.get_from_remote_profile_cache(user_id)
|
||||
return profile or {}
|
||||
|
||||
async def get_displayname(self, target_user: UserID) -> str:
|
||||
async def get_displayname(self, target_user: UserID) -> Optional[str]:
|
||||
if self.hs.is_mine(target_user):
|
||||
try:
|
||||
displayname = await self.store.get_profile_displayname(
|
||||
@ -218,7 +218,7 @@ class ProfileHandler(BaseHandler):
|
||||
|
||||
await self._update_join_states(requester, target_user)
|
||||
|
||||
async def get_avatar_url(self, target_user: UserID) -> str:
|
||||
async def get_avatar_url(self, target_user: UserID) -> Optional[str]:
|
||||
if self.hs.is_mine(target_user):
|
||||
try:
|
||||
avatar_url = await self.store.get_profile_avatar_url(
|
||||
|
@ -39,7 +39,7 @@ class ProfileWorkerStore(SQLBaseStore):
|
||||
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
|
||||
)
|
||||
|
||||
async def get_profile_displayname(self, user_localpart: str) -> str:
|
||||
async def get_profile_displayname(self, user_localpart: str) -> Optional[str]:
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
@ -47,7 +47,7 @@ class ProfileWorkerStore(SQLBaseStore):
|
||||
desc="get_profile_displayname",
|
||||
)
|
||||
|
||||
async def get_profile_avatar_url(self, user_localpart: str) -> str:
|
||||
async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]:
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
|
@ -240,13 +240,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||
desc="get_renewal_token_for_user",
|
||||
)
|
||||
|
||||
async def get_users_expiring_soon(self) -> List[Dict[str, int]]:
|
||||
async def get_users_expiring_soon(self) -> List[Dict[str, Any]]:
|
||||
"""Selects users whose account will expire in the [now, now + renew_at] time
|
||||
window (see configuration for account_validity for information on what renew_at
|
||||
refers to).
|
||||
|
||||
Returns:
|
||||
A list of dictionaries mapping user ID to expiration time (in milliseconds).
|
||||
A list of dictionaries, each with a user ID and expiration time (in milliseconds).
|
||||
"""
|
||||
|
||||
def select_users_txn(txn, now_ms, renew_at):
|
||||
|
Loading…
Reference in New Issue
Block a user