Converts event_federation and registration databases to async/await (#8061)

This commit is contained in:
Patrick Cloke 2020-08-11 17:21:13 -04:00 committed by GitHub
parent 61d8ff0d44
commit a0acdfa9e9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 150 additions and 177 deletions

View file

@ -17,9 +17,8 @@
import logging
import re
from typing import Optional
from typing import Dict, List, Optional
from twisted.internet import defer
from twisted.internet.defer import Deferred
from synapse.api.constants import UserTypes
@ -30,7 +29,7 @@ from synapse.storage.database import DatabasePool
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.caches.descriptors import cached
THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
@ -69,19 +68,15 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_user_by_id",
)
@defer.inlineCallbacks
def is_trial_user(self, user_id):
async def is_trial_user(self, user_id: str) -> bool:
"""Checks if user is in the "trial" period, i.e. within the first
N days of registration defined by `mau_trial_days` config
Args:
user_id (str)
Returns:
Deferred[bool]
user_id: The user to check for trial status.
"""
info = yield self.get_user_by_id(user_id)
info = await self.get_user_by_id(user_id)
if not info:
return False
@ -105,41 +100,42 @@ class RegistrationWorkerStore(SQLBaseStore):
"get_user_by_access_token", self._query_for_auth, token
)
@cachedInlineCallbacks()
def get_expiration_ts_for_user(self, user_id):
@cached()
async def get_expiration_ts_for_user(self, user_id: str) -> Optional[None]:
"""Get the expiration timestamp for the account bearing a given user ID.
Args:
user_id (str): The ID of the user.
user_id: The ID of the user.
Returns:
defer.Deferred: None, if the account has no expiration timestamp,
otherwise int representation of the timestamp (as a number of
milliseconds since epoch).
None, if the account has no expiration timestamp, otherwise int
representation of the timestamp (as a number of milliseconds since epoch).
"""
res = yield self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="account_validity",
keyvalues={"user_id": user_id},
retcol="expiration_ts_ms",
allow_none=True,
desc="get_expiration_ts_for_user",
)
return res
@defer.inlineCallbacks
def set_account_validity_for_user(
self, user_id, expiration_ts, email_sent, renewal_token=None
):
async def set_account_validity_for_user(
self,
user_id: str,
expiration_ts: int,
email_sent: bool,
renewal_token: Optional[str] = None,
) -> None:
"""Updates the account validity properties of the given account, with the
given values.
Args:
user_id (str): ID of the account to update properties for.
expiration_ts (int): New expiration date, as a timestamp in milliseconds
user_id: ID of the account to update properties for.
expiration_ts: New expiration date, as a timestamp in milliseconds
since epoch.
email_sent (bool): True means a renewal email has been sent for this
account and there's no need to send another one for the current validity
email_sent: True means a renewal email has been sent for this account
and there's no need to send another one for the current validity
period.
renewal_token (str): Renewal token the user can use to extend the validity
renewal_token: Renewal token the user can use to extend the validity
of their account. Defaults to no token.
"""
@ -158,75 +154,69 @@ class RegistrationWorkerStore(SQLBaseStore):
txn, self.get_expiration_ts_for_user, (user_id,)
)
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"set_account_validity_for_user", set_account_validity_for_user_txn
)
@defer.inlineCallbacks
def set_renewal_token_for_user(self, user_id, renewal_token):
async def set_renewal_token_for_user(
self, user_id: str, renewal_token: str
) -> None:
"""Defines a renewal token for a given user.
Args:
user_id (str): ID of the user to set the renewal token for.
renewal_token (str): Random unique string that will be used to renew the
user_id: ID of the user to set the renewal token for.
renewal_token: Random unique string that will be used to renew the
user's account.
Raises:
StoreError: The provided token is already set for another user.
"""
yield self.db_pool.simple_update_one(
await self.db_pool.simple_update_one(
table="account_validity",
keyvalues={"user_id": user_id},
updatevalues={"renewal_token": renewal_token},
desc="set_renewal_token_for_user",
)
@defer.inlineCallbacks
def get_user_from_renewal_token(self, renewal_token):
async def get_user_from_renewal_token(self, renewal_token: str) -> str:
"""Get a user ID from a renewal token.
Args:
renewal_token (str): The renewal token to perform the lookup with.
renewal_token: The renewal token to perform the lookup with.
Returns:
defer.Deferred[str]: The ID of the user to which the token belongs.
The ID of the user to which the token belongs.
"""
res = yield self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="account_validity",
keyvalues={"renewal_token": renewal_token},
retcol="user_id",
desc="get_user_from_renewal_token",
)
return res
@defer.inlineCallbacks
def get_renewal_token_for_user(self, user_id):
async def get_renewal_token_for_user(self, user_id: str) -> str:
"""Get the renewal token associated with a given user ID.
Args:
user_id (str): The user ID to lookup a token for.
user_id: The user ID to lookup a token for.
Returns:
defer.Deferred[str]: The renewal token associated with this user ID.
The renewal token associated with this user ID.
"""
res = yield self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="account_validity",
keyvalues={"user_id": user_id},
retcol="renewal_token",
desc="get_renewal_token_for_user",
)
return res
@defer.inlineCallbacks
def get_users_expiring_soon(self):
async def get_users_expiring_soon(self) -> List[Dict[str, int]]:
"""Selects users whose account will expire in the [now, now + renew_at] time
window (see configuration for account_validity for information on what renew_at
refers to).
Returns:
Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]]
A list of dictionaries mapping user ID to expiration time (in milliseconds).
"""
def select_users_txn(txn, now_ms, renew_at):
@ -238,53 +228,49 @@ class RegistrationWorkerStore(SQLBaseStore):
txn.execute(sql, values)
return self.db_pool.cursor_to_dict(txn)
res = yield self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_users_expiring_soon",
select_users_txn,
self.clock.time_msec(),
self.config.account_validity.renew_at,
)
return res
@defer.inlineCallbacks
def set_renewal_mail_status(self, user_id, email_sent):
async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None:
"""Sets or unsets the flag that indicates whether a renewal email has been sent
to the user (and the user hasn't renewed their account yet).
Args:
user_id (str): ID of the user to set/unset the flag for.
email_sent (bool): Flag which indicates whether a renewal email has been sent
user_id: ID of the user to set/unset the flag for.
email_sent: Flag which indicates whether a renewal email has been sent
to this user.
"""
yield self.db_pool.simple_update_one(
await self.db_pool.simple_update_one(
table="account_validity",
keyvalues={"user_id": user_id},
updatevalues={"email_sent": email_sent},
desc="set_renewal_mail_status",
)
@defer.inlineCallbacks
def delete_account_validity_for_user(self, user_id):
async def delete_account_validity_for_user(self, user_id: str) -> None:
"""Deletes the entry for the given user in the account validity table, removing
their expiration date and renewal token.
Args:
user_id (str): ID of the user to remove from the account validity table.
user_id: ID of the user to remove from the account validity table.
"""
yield self.db_pool.simple_delete_one(
await self.db_pool.simple_delete_one(
table="account_validity",
keyvalues={"user_id": user_id},
desc="delete_account_validity_for_user",
)
async def is_server_admin(self, user):
async def is_server_admin(self, user: UserID) -> bool:
"""Determines if a user is an admin of this homeserver.
Args:
user (UserID): user ID of the user to test
user: user ID of the user to test
Returns (bool):
Returns:
true iff the user is a server admin, false otherwise.
"""
res = await self.db_pool.simple_select_one_onecol(
@ -332,32 +318,31 @@ class RegistrationWorkerStore(SQLBaseStore):
return None
@cachedInlineCallbacks()
def is_real_user(self, user_id):
@cached()
async def is_real_user(self, user_id: str) -> bool:
"""Determines if the user is a real user, ie does not have a 'user_type'.
Args:
user_id (str): user id to test
user_id: user id to test
Returns:
Deferred[bool]: True if user 'user_type' is null or empty string
True if user 'user_type' is null or empty string
"""
res = yield self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"is_real_user", self.is_real_user_txn, user_id
)
return res
@cached()
def is_support_user(self, user_id):
async def is_support_user(self, user_id: str) -> bool:
"""Determines if the user is of type UserTypes.SUPPORT
Args:
user_id (str): user id to test
user_id: user id to test
Returns:
Deferred[bool]: True if user is of type UserTypes.SUPPORT
True if user is of type UserTypes.SUPPORT
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"is_support_user", self.is_support_user_txn, user_id
)
@ -413,8 +398,7 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_user_by_external_id",
)
@defer.inlineCallbacks
def count_all_users(self):
async def count_all_users(self):
"""Counts all users registered on the homeserver."""
def _count_users(txn):
@ -424,8 +408,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return rows[0]["users"]
return 0
ret = yield self.db_pool.runInteraction("count_users", _count_users)
return ret
return await self.db_pool.runInteraction("count_users", _count_users)
def count_daily_user_type(self):
"""
@ -460,8 +443,7 @@ class RegistrationWorkerStore(SQLBaseStore):
"count_daily_user_type", _count_daily_user_type
)
@defer.inlineCallbacks
def count_nonbridged_users(self):
async def count_nonbridged_users(self):
def _count_users(txn):
txn.execute(
"""
@ -472,11 +454,9 @@ class RegistrationWorkerStore(SQLBaseStore):
(count,) = txn.fetchone()
return count
ret = yield self.db_pool.runInteraction("count_users", _count_users)
return ret
return await self.db_pool.runInteraction("count_users", _count_users)
@defer.inlineCallbacks
def count_real_users(self):
async def count_real_users(self):
"""Counts all users without a special user_type registered on the homeserver."""
def _count_users(txn):
@ -486,8 +466,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return rows[0]["users"]
return 0
ret = yield self.db_pool.runInteraction("count_real_users", _count_users)
return ret
return await self.db_pool.runInteraction("count_real_users", _count_users)
async def generate_user_id(self) -> str:
"""Generate a suitable localpart for a guest user
@ -537,23 +516,20 @@ class RegistrationWorkerStore(SQLBaseStore):
return ret["user_id"]
return None
@defer.inlineCallbacks
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
yield self.db_pool.simple_upsert(
async def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
await self.db_pool.simple_upsert(
"user_threepids",
{"medium": medium, "address": address},
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
)
@defer.inlineCallbacks
def user_get_threepids(self, user_id):
ret = yield self.db_pool.simple_select_list(
async def user_get_threepids(self, user_id):
return await self.db_pool.simple_select_list(
"user_threepids",
{"user_id": user_id},
["medium", "address", "validated_at", "added_at"],
"user_get_threepids",
)
return ret
def user_delete_threepid(self, user_id, medium, address):
return self.db_pool.simple_delete(
@ -668,18 +644,18 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_id_servers_user_bound",
)
@cachedInlineCallbacks()
def get_user_deactivated_status(self, user_id):
@cached()
async def get_user_deactivated_status(self, user_id: str) -> bool:
"""Retrieve the value for the `deactivated` property for the provided user.
Args:
user_id (str): The ID of the user to retrieve the status for.
user_id: The ID of the user to retrieve the status for.
Returns:
defer.Deferred(bool): The requested value.
True if the user was deactivated, false if the user is still active.
"""
res = yield self.db_pool.simple_select_one_onecol(
res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="deactivated",
@ -818,8 +794,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
)
@defer.inlineCallbacks
def _background_update_set_deactivated_flag(self, progress, batch_size):
async def _background_update_set_deactivated_flag(self, progress, batch_size):
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
for each of them.
"""
@ -870,19 +845,18 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
else:
return False, len(rows)
end, nb_processed = yield self.db_pool.runInteraction(
end, nb_processed = await self.db_pool.runInteraction(
"users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
)
if end:
yield self.db_pool.updates._end_background_update(
await self.db_pool.updates._end_background_update(
"users_set_deactivated_flag"
)
return nb_processed
@defer.inlineCallbacks
def _bg_user_threepids_grandfather(self, progress, batch_size):
async def _bg_user_threepids_grandfather(self, progress, batch_size):
"""We now track which identity servers a user binds their 3PID to, so
we need to handle the case of existing bindings where we didn't track
this.
@ -903,11 +877,11 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
txn.executemany(sql, [(id_server,) for id_server in id_servers])
if id_servers:
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
)
yield self.db_pool.updates._end_background_update("user_threepids_grandfather")
await self.db_pool.updates._end_background_update("user_threepids_grandfather")
return 1
@ -937,23 +911,26 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
@defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
async def add_access_token_to_user(
self,
user_id: str,
token: str,
device_id: Optional[str],
valid_until_ms: Optional[int],
) -> None:
"""Adds an access token for the given user.
Args:
user_id (str): The user ID.
token (str): The new access token to add.
device_id (str): ID of the device to associate with the access
token
valid_until_ms (int|None): when the token is valid until. None for
no expiry.
user_id: The user ID.
token: The new access token to add.
device_id: ID of the device to associate with the access token
valid_until_ms: when the token is valid until. None for no expiry.
Raises:
StoreError if there was a problem adding this.
"""
next_id = self._access_tokens_id_gen.get_next()
yield self.db_pool.simple_insert(
await self.db_pool.simple_insert(
"access_tokens",
{
"id": next_id,
@ -1097,7 +1074,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))
def record_user_external_id(
self, auth_provider: str, external_id: str, user_id: str
@ -1241,9 +1217,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return self.db_pool.runInteraction("delete_access_token", f)
@cachedInlineCallbacks()
def is_guest(self, user_id):
res = yield self.db_pool.simple_select_one_onecol(
@cached()
async def is_guest(self, user_id: str) -> bool:
res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="is_guest",
@ -1481,16 +1457,17 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
self.clock.time_msec(),
)
@defer.inlineCallbacks
def set_user_deactivated_status(self, user_id, deactivated):
async def set_user_deactivated_status(
self, user_id: str, deactivated: bool
) -> None:
"""Set the `deactivated` property for the provided user to the provided value.
Args:
user_id (str): The ID of the user to set the status for.
deactivated (bool): The value to set for `deactivated`.
user_id: The ID of the user to set the status for.
deactivated: The value to set for `deactivated`.
"""
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"set_user_deactivated_status",
self.set_user_deactivated_status_txn,
user_id,
@ -1507,9 +1484,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,)
)
txn.call_after(self.is_guest.invalidate, (user_id,))
@defer.inlineCallbacks
def _set_expiration_date_when_missing(self):
async def _set_expiration_date_when_missing(self):
"""
Retrieves the list of registered users that don't have an expiration date, and
adds an expiration date for each of them.
@ -1533,7 +1510,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
txn, user["name"], use_delta=True
)
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"get_users_with_no_expiration_date",
select_users_with_no_expiration_date_txn,
)