mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Add type hints for account validity handler (#8620)
This also fixes a bug by fixing handling of an account which doesn't expire.
This commit is contained in:
parent
66e6801c3e
commit
10f45d85bb
1
changelog.d/8620.bugfix
Normal file
1
changelog.d/8620.bugfix
Normal file
@ -0,0 +1 @@
|
|||||||
|
Fix a bug where the account validity endpoint would silently fail if the user ID did not have an expiration time. It now returns a 400 error.
|
1
mypy.ini
1
mypy.ini
@ -17,6 +17,7 @@ files =
|
|||||||
synapse/federation,
|
synapse/federation,
|
||||||
synapse/handlers/_base.py,
|
synapse/handlers/_base.py,
|
||||||
synapse/handlers/account_data.py,
|
synapse/handlers/account_data.py,
|
||||||
|
synapse/handlers/account_validity.py,
|
||||||
synapse/handlers/appservice.py,
|
synapse/handlers/appservice.py,
|
||||||
synapse/handlers/auth.py,
|
synapse/handlers/auth.py,
|
||||||
synapse/handlers/cas_handler.py,
|
synapse/handlers/cas_handler.py,
|
||||||
|
@ -18,19 +18,22 @@ import email.utils
|
|||||||
import logging
|
import logging
|
||||||
from email.mime.multipart import MIMEMultipart
|
from email.mime.multipart import MIMEMultipart
|
||||||
from email.mime.text import MIMEText
|
from email.mime.text import MIMEText
|
||||||
from typing import List
|
from typing import TYPE_CHECKING, List
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError, SynapseError
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
from synapse.util import stringutils
|
from synapse.util import stringutils
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AccountValidityHandler:
|
class AccountValidityHandler:
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.config = hs.config
|
self.config = hs.config
|
||||||
self.store = self.hs.get_datastore()
|
self.store = self.hs.get_datastore()
|
||||||
@ -67,7 +70,7 @@ class AccountValidityHandler:
|
|||||||
self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
|
self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
|
||||||
|
|
||||||
@wrap_as_background_process("send_renewals")
|
@wrap_as_background_process("send_renewals")
|
||||||
async def _send_renewal_emails(self):
|
async def _send_renewal_emails(self) -> None:
|
||||||
"""Gets the list of users whose account is expiring in the amount of time
|
"""Gets the list of users whose account is expiring in the amount of time
|
||||||
configured in the ``renew_at`` parameter from the ``account_validity``
|
configured in the ``renew_at`` parameter from the ``account_validity``
|
||||||
configuration, and sends renewal emails to all of these users as long as they
|
configuration, and sends renewal emails to all of these users as long as they
|
||||||
@ -81,11 +84,25 @@ class AccountValidityHandler:
|
|||||||
user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
|
user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
|
||||||
)
|
)
|
||||||
|
|
||||||
async def send_renewal_email_to_user(self, user_id: str):
|
async def send_renewal_email_to_user(self, user_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Send a renewal email for a specific user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID to send a renewal email for.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SynapseError if the user is not set to renew.
|
||||||
|
"""
|
||||||
expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
|
expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
|
||||||
|
|
||||||
|
# If this user isn't set to be expired, raise an error.
|
||||||
|
if expiration_ts is None:
|
||||||
|
raise SynapseError(400, "User has no expiration time: %s" % (user_id,))
|
||||||
|
|
||||||
await self._send_renewal_email(user_id, expiration_ts)
|
await self._send_renewal_email(user_id, expiration_ts)
|
||||||
|
|
||||||
async def _send_renewal_email(self, user_id: str, expiration_ts: int):
|
async def _send_renewal_email(self, user_id: str, expiration_ts: int) -> None:
|
||||||
"""Sends out a renewal email to every email address attached to the given user
|
"""Sends out a renewal email to every email address attached to the given user
|
||||||
with a unique link allowing them to renew their account.
|
with a unique link allowing them to renew their account.
|
||||||
|
|
||||||
|
@ -131,7 +131,7 @@ class ProfileHandler(BaseHandler):
|
|||||||
profile = await self.store.get_from_remote_profile_cache(user_id)
|
profile = await self.store.get_from_remote_profile_cache(user_id)
|
||||||
return profile or {}
|
return profile or {}
|
||||||
|
|
||||||
async def get_displayname(self, target_user: UserID) -> str:
|
async def get_displayname(self, target_user: UserID) -> Optional[str]:
|
||||||
if self.hs.is_mine(target_user):
|
if self.hs.is_mine(target_user):
|
||||||
try:
|
try:
|
||||||
displayname = await self.store.get_profile_displayname(
|
displayname = await self.store.get_profile_displayname(
|
||||||
@ -218,7 +218,7 @@ class ProfileHandler(BaseHandler):
|
|||||||
|
|
||||||
await self._update_join_states(requester, target_user)
|
await self._update_join_states(requester, target_user)
|
||||||
|
|
||||||
async def get_avatar_url(self, target_user: UserID) -> str:
|
async def get_avatar_url(self, target_user: UserID) -> Optional[str]:
|
||||||
if self.hs.is_mine(target_user):
|
if self.hs.is_mine(target_user):
|
||||||
try:
|
try:
|
||||||
avatar_url = await self.store.get_profile_avatar_url(
|
avatar_url = await self.store.get_profile_avatar_url(
|
||||||
|
@ -39,7 +39,7 @@ class ProfileWorkerStore(SQLBaseStore):
|
|||||||
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
|
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_profile_displayname(self, user_localpart: str) -> str:
|
async def get_profile_displayname(self, user_localpart: str) -> Optional[str]:
|
||||||
return await self.db_pool.simple_select_one_onecol(
|
return await self.db_pool.simple_select_one_onecol(
|
||||||
table="profiles",
|
table="profiles",
|
||||||
keyvalues={"user_id": user_localpart},
|
keyvalues={"user_id": user_localpart},
|
||||||
@ -47,7 +47,7 @@ class ProfileWorkerStore(SQLBaseStore):
|
|||||||
desc="get_profile_displayname",
|
desc="get_profile_displayname",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_profile_avatar_url(self, user_localpart: str) -> str:
|
async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]:
|
||||||
return await self.db_pool.simple_select_one_onecol(
|
return await self.db_pool.simple_select_one_onecol(
|
||||||
table="profiles",
|
table="profiles",
|
||||||
keyvalues={"user_id": user_localpart},
|
keyvalues={"user_id": user_localpart},
|
||||||
|
@ -240,13 +240,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||||||
desc="get_renewal_token_for_user",
|
desc="get_renewal_token_for_user",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_users_expiring_soon(self) -> List[Dict[str, int]]:
|
async def get_users_expiring_soon(self) -> List[Dict[str, Any]]:
|
||||||
"""Selects users whose account will expire in the [now, now + renew_at] time
|
"""Selects users whose account will expire in the [now, now + renew_at] time
|
||||||
window (see configuration for account_validity for information on what renew_at
|
window (see configuration for account_validity for information on what renew_at
|
||||||
refers to).
|
refers to).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of dictionaries mapping user ID to expiration time (in milliseconds).
|
A list of dictionaries, each with a user ID and expiration time (in milliseconds).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def select_users_txn(txn, now_ms, renew_at):
|
def select_users_txn(txn, now_ms, renew_at):
|
||||||
|
Loading…
Reference in New Issue
Block a user