Convert event_push_actions, registration, and roommember datastores to async (#8197)

This commit is contained in:
Patrick Cloke 2020-08-28 11:34:50 -04:00 committed by GitHub
parent 22b926c284
commit d58fda99ff
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 169 additions and 160 deletions

View file

@ -17,7 +17,7 @@
import logging
import re
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@ -84,17 +84,17 @@ class RegistrationWorkerStore(SQLBaseStore):
return is_trial
@cached()
def get_user_by_access_token(self, token):
async def get_user_by_access_token(self, token: str) -> Optional[dict]:
"""Get a user from the given access token.
Args:
token (str): The access token of a user.
token: The access token of a user.
Returns:
defer.Deferred: None, if the token did not match, otherwise dict
including the keys `name`, `is_guest`, `device_id`, `token_id`,
`valid_until_ms`.
None, if the token did not match, otherwise dict
including the keys `name`, `is_guest`, `device_id`, `token_id`,
`valid_until_ms`.
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_user_by_access_token", self._query_for_auth, token
)
@ -281,13 +281,12 @@ class RegistrationWorkerStore(SQLBaseStore):
return bool(res) if res else False
def set_server_admin(self, user, admin):
async def set_server_admin(self, user: UserID, admin: bool) -> None:
"""Sets whether a user is an admin of this homeserver.
Args:
user (UserID): user ID of the user to test
admin (bool): true iff the user is to be a server admin,
false otherwise.
user: user ID of the user to test
admin: true iff the user is to be a server admin, false otherwise.
"""
def set_server_admin_txn(txn):
@ -298,7 +297,7 @@ class RegistrationWorkerStore(SQLBaseStore):
txn, self.get_user_by_id, (user.to_string(),)
)
return self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
def _query_for_auth(self, txn, token):
sql = (
@ -364,9 +363,11 @@ class RegistrationWorkerStore(SQLBaseStore):
)
return True if res == UserTypes.SUPPORT else False
def get_users_by_id_case_insensitive(self, user_id):
async def get_users_by_id_case_insensitive(self, user_id: str) -> Dict[str, str]:
"""Gets users that match user_id case insensitively.
Returns a mapping of user_id -> password_hash.
Returns:
A mapping of user_id -> password_hash.
"""
def f(txn):
@ -374,7 +375,7 @@ class RegistrationWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,))
return dict(txn)
return self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
async def get_user_by_external_id(
self, auth_provider: str, external_id: str
@ -408,7 +409,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return await self.db_pool.runInteraction("count_users", _count_users)
def count_daily_user_type(self):
async def count_daily_user_type(self) -> Dict[str, int]:
"""
Counts 1) native non guest users
2) native guests users
@ -437,7 +438,7 @@ class RegistrationWorkerStore(SQLBaseStore):
results[row[0]] = row[1]
return results
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"count_daily_user_type", _count_daily_user_type
)
@ -663,24 +664,29 @@ class RegistrationWorkerStore(SQLBaseStore):
# Convert the integer into a boolean.
return res == 1
def get_threepid_validation_session(
self, medium, client_secret, address=None, sid=None, validated=True
):
async def get_threepid_validation_session(
self,
medium: Optional[str],
client_secret: str,
address: Optional[str] = None,
sid: Optional[str] = None,
validated: Optional[bool] = True,
) -> Optional[Dict[str, Any]]:
"""Gets a session_id and last_send_attempt (if available) for a
combination of validation metadata
Args:
medium (str|None): The medium of the 3PID
address (str|None): The address of the 3PID
sid (str|None): The ID of the validation session
client_secret (str): A unique string provided by the client to help identify this
medium: The medium of the 3PID
client_secret: A unique string provided by the client to help identify this
validation attempt
validated (bool|None): Whether sessions should be filtered by
address: The address of the 3PID
sid: The ID of the validation session
validated: Whether sessions should be filtered by
whether they have been validated already or not. None to
perform no filtering
Returns:
Deferred[dict|None]: A dict containing the following:
A dict containing the following:
* address - address of the 3pid
* medium - medium of the 3pid
* client_secret - a secret provided by the client for this validation session
@ -726,17 +732,17 @@ class RegistrationWorkerStore(SQLBaseStore):
return rows[0]
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_threepid_validation_session", get_threepid_validation_session_txn
)
def delete_threepid_session(self, session_id):
async def delete_threepid_session(self, session_id: str) -> None:
"""Removes a threepid validation session from the database. This can
be done after validation has been performed and whatever action was
waiting on it has been carried out
Args:
session_id (str): The ID of the session to delete
session_id: The ID of the session to delete
"""
def delete_threepid_session_txn(txn):
@ -751,7 +757,7 @@ class RegistrationWorkerStore(SQLBaseStore):
keyvalues={"session_id": session_id},
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"delete_threepid_session", delete_threepid_session_txn
)
@ -941,43 +947,40 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="add_access_token_to_user",
)
def register_user(
async def register_user(
self,
user_id,
password_hash=None,
was_guest=False,
make_guest=False,
appservice_id=None,
create_profile_with_displayname=None,
admin=False,
user_type=None,
shadow_banned=False,
):
user_id: str,
password_hash: Optional[str] = None,
was_guest: bool = False,
make_guest: bool = False,
appservice_id: Optional[str] = None,
create_profile_with_displayname: Optional[str] = None,
admin: bool = False,
user_type: Optional[str] = None,
shadow_banned: bool = False,
) -> None:
"""Attempts to register an account.
Args:
user_id (str): The desired user ID to register.
password_hash (str|None): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
make_guest (boolean): True if the the new user should be guest,
false to add a regular user account.
appservice_id (str): The ID of the appservice registering the user.
create_profile_with_displayname (unicode): Optionally create a profile for
user_id: The desired user ID to register.
password_hash: Optional. The password hash for this user.
was_guest: Whether this is a guest account being upgraded to a
non-guest account.
make_guest: True if the the new user should be guest, false to add a
regular user account.
appservice_id: The ID of the appservice registering the user.
create_profile_with_displayname: Optionally create a profile for
the user, setting their displayname to the given value
admin (boolean): is an admin user?
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
shadow_banned (bool): Whether the user is shadow-banned,
i.e. they may be told their requests succeeded but we ignore them.
admin: is an admin user?
user_type: type of user. One of the values from api.constants.UserTypes,
or None for a normal user.
shadow_banned: Whether the user is shadow-banned, i.e. they may be
told their requests succeeded but we ignore them.
Raises:
StoreError if the user_id could not be registered.
Returns:
Deferred
"""
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"register_user",
self._register_user,
user_id,
@ -1101,7 +1104,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="record_user_external_id",
)
def user_set_password_hash(self, user_id, password_hash):
async def user_set_password_hash(self, user_id: str, password_hash: str) -> None:
"""
NB. This does *not* evict any cache because the one use for this
removes most of the entries subsequently anyway so it would be
@ -1114,17 +1117,18 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"user_set_password_hash", user_set_password_hash_txn
)
def user_set_consent_version(self, user_id, consent_version):
async def user_set_consent_version(
self, user_id: str, consent_version: str
) -> None:
"""Updates the user table to record privacy policy consent
Args:
user_id (str): full mxid of the user to update
consent_version (str): version of the policy the user has consented
to
user_id: full mxid of the user to update
consent_version: version of the policy the user has consented to
Raises:
StoreError(404) if user not found
@ -1139,16 +1143,17 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return self.db_pool.runInteraction("user_set_consent_version", f)
await self.db_pool.runInteraction("user_set_consent_version", f)
def user_set_consent_server_notice_sent(self, user_id, consent_version):
async def user_set_consent_server_notice_sent(
self, user_id: str, consent_version: str
) -> None:
"""Updates the user table to record that we have sent the user a server
notice about privacy policy consent
Args:
user_id (str): full mxid of the user to update
consent_version (str): version of the policy we have notified the
user about
user_id: full mxid of the user to update
consent_version: version of the policy we have notified the user about
Raises:
StoreError(404) if user not found
@ -1163,22 +1168,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
await self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
async def user_delete_access_tokens(
self,
user_id: str,
except_token_id: Optional[str] = None,
device_id: Optional[str] = None,
) -> List[Tuple[str, int, Optional[str]]]:
"""
Invalidate access tokens belonging to a user
Args:
user_id (str): ID of user the tokens belong to
except_token_id (str): list of access_tokens IDs which should
*not* be deleted
device_id (str|None): ID of device the tokens are associated with.
user_id: ID of user the tokens belong to
except_token_id: access_tokens ID which should *not* be deleted
device_id: ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
Returns:
defer.Deferred[list[str, int, str|None, int]]: a list of
(token, token id, device id) for each of the deleted tokens
A tuple of (token, token id, device id) for each of the deleted tokens
"""
def f(txn):
@ -1209,9 +1217,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return tokens_and_devices
return self.db_pool.runInteraction("user_delete_access_tokens", f)
return await self.db_pool.runInteraction("user_delete_access_tokens", f)
def delete_access_token(self, access_token):
async def delete_access_token(self, access_token: str) -> None:
def f(txn):
self.db_pool.simple_delete_one_txn(
txn, table="access_tokens", keyvalues={"token": access_token}
@ -1221,7 +1229,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
txn, self.get_user_by_access_token, (access_token,)
)
return self.db_pool.runInteraction("delete_access_token", f)
await self.db_pool.runInteraction("delete_access_token", f)
@cached()
async def is_guest(self, user_id: str) -> bool:
@ -1272,24 +1280,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="get_users_pending_deactivation",
)
def validate_threepid_session(self, session_id, client_secret, token, current_ts):
async def validate_threepid_session(
self, session_id: str, client_secret: str, token: str, current_ts: int
) -> Optional[str]:
"""Attempt to validate a threepid session using a token
Args:
session_id (str): The id of a validation session
client_secret (str): A unique string provided by the client to
help identify this validation attempt
token (str): A validation token
current_ts (int): The current unix time in milliseconds. Used for
checking token expiry status
session_id: The id of a validation session
client_secret: A unique string provided by the client to help identify
this validation attempt
token: A validation token
current_ts: The current unix time in milliseconds. Used for checking
token expiry status
Raises:
ThreepidValidationError: if a matching validation token was not found or has
expired
Returns:
deferred str|None: A str representing a link to redirect the user
to if there is one.
A str representing a link to redirect the user to if there is one.
"""
# Insert everything into a transaction in order to run atomically
@ -1359,36 +1368,35 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return next_link
# Return next_link if it exists
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"validate_threepid_session_txn", validate_threepid_session_txn
)
def start_or_continue_validation_session(
async def start_or_continue_validation_session(
self,
medium,
address,
session_id,
client_secret,
send_attempt,
next_link,
token,
token_expires,
):
medium: str,
address: str,
session_id: str,
client_secret: str,
send_attempt: int,
next_link: Optional[str],
token: str,
token_expires: int,
) -> None:
"""Creates a new threepid validation session if it does not already
exist and associates a new validation token with it
Args:
medium (str): The medium of the 3PID
address (str): The address of the 3PID
session_id (str): The id of this validation session
client_secret (str): A unique string provided by the client to
help identify this validation attempt
send_attempt (int): The latest send_attempt on this session
next_link (str|None): The link to redirect the user to upon
successful validation
token (str): The validation token
token_expires (int): The timestamp for which after the token
will no longer be valid
medium: The medium of the 3PID
address: The address of the 3PID
session_id: The id of this validation session
client_secret: A unique string provided by the client to help
identify this validation attempt
send_attempt: The latest send_attempt on this session
next_link: The link to redirect the user to upon successful validation
token: The validation token
token_expires: The timestamp for which after the token will no
longer be valid
"""
def start_or_continue_validation_session_txn(txn):
@ -1417,12 +1425,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
},
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"start_or_continue_validation_session",
start_or_continue_validation_session_txn,
)
def cull_expired_threepid_validation_tokens(self):
async def cull_expired_threepid_validation_tokens(self) -> None:
"""Remove threepid validation tokens with expiry dates that have passed"""
def cull_expired_threepid_validation_tokens_txn(txn, ts):
@ -1430,9 +1438,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
DELETE FROM threepid_validation_token WHERE
expires < ?
"""
return txn.execute(sql, (ts,))
txn.execute(sql, (ts,))
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"cull_expired_threepid_validation_tokens",
cull_expired_threepid_validation_tokens_txn,
self.clock.time_msec(),