mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Convert event_push_actions
, registration
, and roommember
datastores to async (#8197)
This commit is contained in:
parent
22b926c284
commit
d58fda99ff
1
changelog.d/8197.misc
Normal file
1
changelog.d/8197.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Convert various parts of the codebase to async/await.
|
@ -15,7 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import List
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
|
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
|
||||||
@ -383,19 +383,20 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||||||
# Now return the first `limit`
|
# Now return the first `limit`
|
||||||
return notifs[:limit]
|
return notifs[:limit]
|
||||||
|
|
||||||
def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering):
|
async def get_if_maybe_push_in_range_for_user(
|
||||||
|
self, user_id: str, min_stream_ordering: int
|
||||||
|
) -> bool:
|
||||||
"""A fast check to see if there might be something to push for the
|
"""A fast check to see if there might be something to push for the
|
||||||
user since the given stream ordering. May return false positives.
|
user since the given stream ordering. May return false positives.
|
||||||
|
|
||||||
Useful to know whether to bother starting a pusher on start up or not.
|
Useful to know whether to bother starting a pusher on start up or not.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str)
|
user_id
|
||||||
min_stream_ordering (int)
|
min_stream_ordering
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[bool]: True if there may be push to process, False if
|
True if there may be push to process, False if there definitely isn't.
|
||||||
there definitely isn't.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_if_maybe_push_in_range_for_user_txn(txn):
|
def _get_if_maybe_push_in_range_for_user_txn(txn):
|
||||||
@ -408,22 +409,20 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||||||
txn.execute(sql, (user_id, min_stream_ordering))
|
txn.execute(sql, (user_id, min_stream_ordering))
|
||||||
return bool(txn.fetchone())
|
return bool(txn.fetchone())
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_if_maybe_push_in_range_for_user",
|
"get_if_maybe_push_in_range_for_user",
|
||||||
_get_if_maybe_push_in_range_for_user_txn,
|
_get_if_maybe_push_in_range_for_user_txn,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def add_push_actions_to_staging(self, event_id, user_id_actions):
|
async def add_push_actions_to_staging(
|
||||||
|
self, event_id: str, user_id_actions: Dict[str, List[Union[dict, str]]]
|
||||||
|
) -> None:
|
||||||
"""Add the push actions for the event to the push action staging area.
|
"""Add the push actions for the event to the push action staging area.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event_id (str)
|
event_id
|
||||||
user_id_actions (dict[str, list[dict|str])]): A dictionary mapping
|
user_id_actions: A mapping of user_id to list of push actions, where
|
||||||
user_id to list of push actions, where an action can either be
|
an action can either be a string or dict.
|
||||||
a string or dict.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not user_id_actions:
|
if not user_id_actions:
|
||||||
@ -507,7 +506,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||||||
"Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago
|
"Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago
|
||||||
)
|
)
|
||||||
|
|
||||||
def find_first_stream_ordering_after_ts(self, ts):
|
async def find_first_stream_ordering_after_ts(self, ts: int) -> int:
|
||||||
"""Gets the stream ordering corresponding to a given timestamp.
|
"""Gets the stream ordering corresponding to a given timestamp.
|
||||||
|
|
||||||
Specifically, finds the stream_ordering of the first event that was
|
Specifically, finds the stream_ordering of the first event that was
|
||||||
@ -516,13 +515,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||||||
relatively slow.
|
relatively slow.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ts (int): timestamp in millis
|
ts: timestamp in millis
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[int]: stream ordering of the first event received on/after
|
stream ordering of the first event received on/after the timestamp
|
||||||
the timestamp
|
|
||||||
"""
|
"""
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"_find_first_stream_ordering_after_ts_txn",
|
"_find_first_stream_ordering_after_ts_txn",
|
||||||
self._find_first_stream_ordering_after_ts_txn,
|
self._find_first_stream_ordering_after_ts_txn,
|
||||||
ts,
|
ts,
|
||||||
|
@ -17,7 +17,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from synapse.api.constants import UserTypes
|
from synapse.api.constants import UserTypes
|
||||||
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
|
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
|
||||||
@ -84,17 +84,17 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||||||
return is_trial
|
return is_trial
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
def get_user_by_access_token(self, token):
|
async def get_user_by_access_token(self, token: str) -> Optional[dict]:
|
||||||
"""Get a user from the given access token.
|
"""Get a user from the given access token.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
token (str): The access token of a user.
|
token: The access token of a user.
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred: None, if the token did not match, otherwise dict
|
None, if the token did not match, otherwise dict
|
||||||
including the keys `name`, `is_guest`, `device_id`, `token_id`,
|
including the keys `name`, `is_guest`, `device_id`, `token_id`,
|
||||||
`valid_until_ms`.
|
`valid_until_ms`.
|
||||||
"""
|
"""
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_user_by_access_token", self._query_for_auth, token
|
"get_user_by_access_token", self._query_for_auth, token
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -281,13 +281,12 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
return bool(res) if res else False
|
return bool(res) if res else False
|
||||||
|
|
||||||
def set_server_admin(self, user, admin):
|
async def set_server_admin(self, user: UserID, admin: bool) -> None:
|
||||||
"""Sets whether a user is an admin of this homeserver.
|
"""Sets whether a user is an admin of this homeserver.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user (UserID): user ID of the user to test
|
user: user ID of the user to test
|
||||||
admin (bool): true iff the user is to be a server admin,
|
admin: true iff the user is to be a server admin, false otherwise.
|
||||||
false otherwise.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def set_server_admin_txn(txn):
|
def set_server_admin_txn(txn):
|
||||||
@ -298,7 +297,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||||||
txn, self.get_user_by_id, (user.to_string(),)
|
txn, self.get_user_by_id, (user.to_string(),)
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
|
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
|
||||||
|
|
||||||
def _query_for_auth(self, txn, token):
|
def _query_for_auth(self, txn, token):
|
||||||
sql = (
|
sql = (
|
||||||
@ -364,9 +363,11 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
return True if res == UserTypes.SUPPORT else False
|
return True if res == UserTypes.SUPPORT else False
|
||||||
|
|
||||||
def get_users_by_id_case_insensitive(self, user_id):
|
async def get_users_by_id_case_insensitive(self, user_id: str) -> Dict[str, str]:
|
||||||
"""Gets users that match user_id case insensitively.
|
"""Gets users that match user_id case insensitively.
|
||||||
Returns a mapping of user_id -> password_hash.
|
|
||||||
|
Returns:
|
||||||
|
A mapping of user_id -> password_hash.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def f(txn):
|
def f(txn):
|
||||||
@ -374,7 +375,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||||||
txn.execute(sql, (user_id,))
|
txn.execute(sql, (user_id,))
|
||||||
return dict(txn)
|
return dict(txn)
|
||||||
|
|
||||||
return self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
|
return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
|
||||||
|
|
||||||
async def get_user_by_external_id(
|
async def get_user_by_external_id(
|
||||||
self, auth_provider: str, external_id: str
|
self, auth_provider: str, external_id: str
|
||||||
@ -408,7 +409,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
return await self.db_pool.runInteraction("count_users", _count_users)
|
return await self.db_pool.runInteraction("count_users", _count_users)
|
||||||
|
|
||||||
def count_daily_user_type(self):
|
async def count_daily_user_type(self) -> Dict[str, int]:
|
||||||
"""
|
"""
|
||||||
Counts 1) native non guest users
|
Counts 1) native non guest users
|
||||||
2) native guests users
|
2) native guests users
|
||||||
@ -437,7 +438,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||||||
results[row[0]] = row[1]
|
results[row[0]] = row[1]
|
||||||
return results
|
return results
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"count_daily_user_type", _count_daily_user_type
|
"count_daily_user_type", _count_daily_user_type
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -663,24 +664,29 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||||||
# Convert the integer into a boolean.
|
# Convert the integer into a boolean.
|
||||||
return res == 1
|
return res == 1
|
||||||
|
|
||||||
def get_threepid_validation_session(
|
async def get_threepid_validation_session(
|
||||||
self, medium, client_secret, address=None, sid=None, validated=True
|
self,
|
||||||
):
|
medium: Optional[str],
|
||||||
|
client_secret: str,
|
||||||
|
address: Optional[str] = None,
|
||||||
|
sid: Optional[str] = None,
|
||||||
|
validated: Optional[bool] = True,
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
"""Gets a session_id and last_send_attempt (if available) for a
|
"""Gets a session_id and last_send_attempt (if available) for a
|
||||||
combination of validation metadata
|
combination of validation metadata
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
medium (str|None): The medium of the 3PID
|
medium: The medium of the 3PID
|
||||||
address (str|None): The address of the 3PID
|
client_secret: A unique string provided by the client to help identify this
|
||||||
sid (str|None): The ID of the validation session
|
|
||||||
client_secret (str): A unique string provided by the client to help identify this
|
|
||||||
validation attempt
|
validation attempt
|
||||||
validated (bool|None): Whether sessions should be filtered by
|
address: The address of the 3PID
|
||||||
|
sid: The ID of the validation session
|
||||||
|
validated: Whether sessions should be filtered by
|
||||||
whether they have been validated already or not. None to
|
whether they have been validated already or not. None to
|
||||||
perform no filtering
|
perform no filtering
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict|None]: A dict containing the following:
|
A dict containing the following:
|
||||||
* address - address of the 3pid
|
* address - address of the 3pid
|
||||||
* medium - medium of the 3pid
|
* medium - medium of the 3pid
|
||||||
* client_secret - a secret provided by the client for this validation session
|
* client_secret - a secret provided by the client for this validation session
|
||||||
@ -726,17 +732,17 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
return rows[0]
|
return rows[0]
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_threepid_validation_session", get_threepid_validation_session_txn
|
"get_threepid_validation_session", get_threepid_validation_session_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def delete_threepid_session(self, session_id):
|
async def delete_threepid_session(self, session_id: str) -> None:
|
||||||
"""Removes a threepid validation session from the database. This can
|
"""Removes a threepid validation session from the database. This can
|
||||||
be done after validation has been performed and whatever action was
|
be done after validation has been performed and whatever action was
|
||||||
waiting on it has been carried out
|
waiting on it has been carried out
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session_id (str): The ID of the session to delete
|
session_id: The ID of the session to delete
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def delete_threepid_session_txn(txn):
|
def delete_threepid_session_txn(txn):
|
||||||
@ -751,7 +757,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||||||
keyvalues={"session_id": session_id},
|
keyvalues={"session_id": session_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"delete_threepid_session", delete_threepid_session_txn
|
"delete_threepid_session", delete_threepid_session_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -941,43 +947,40 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||||||
desc="add_access_token_to_user",
|
desc="add_access_token_to_user",
|
||||||
)
|
)
|
||||||
|
|
||||||
def register_user(
|
async def register_user(
|
||||||
self,
|
self,
|
||||||
user_id,
|
user_id: str,
|
||||||
password_hash=None,
|
password_hash: Optional[str] = None,
|
||||||
was_guest=False,
|
was_guest: bool = False,
|
||||||
make_guest=False,
|
make_guest: bool = False,
|
||||||
appservice_id=None,
|
appservice_id: Optional[str] = None,
|
||||||
create_profile_with_displayname=None,
|
create_profile_with_displayname: Optional[str] = None,
|
||||||
admin=False,
|
admin: bool = False,
|
||||||
user_type=None,
|
user_type: Optional[str] = None,
|
||||||
shadow_banned=False,
|
shadow_banned: bool = False,
|
||||||
):
|
) -> None:
|
||||||
"""Attempts to register an account.
|
"""Attempts to register an account.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The desired user ID to register.
|
user_id: The desired user ID to register.
|
||||||
password_hash (str|None): Optional. The password hash for this user.
|
password_hash: Optional. The password hash for this user.
|
||||||
was_guest (bool): Optional. Whether this is a guest account being
|
was_guest: Whether this is a guest account being upgraded to a
|
||||||
upgraded to a non-guest account.
|
non-guest account.
|
||||||
make_guest (boolean): True if the the new user should be guest,
|
make_guest: True if the the new user should be guest, false to add a
|
||||||
false to add a regular user account.
|
regular user account.
|
||||||
appservice_id (str): The ID of the appservice registering the user.
|
appservice_id: The ID of the appservice registering the user.
|
||||||
create_profile_with_displayname (unicode): Optionally create a profile for
|
create_profile_with_displayname: Optionally create a profile for
|
||||||
the user, setting their displayname to the given value
|
the user, setting their displayname to the given value
|
||||||
admin (boolean): is an admin user?
|
admin: is an admin user?
|
||||||
user_type (str|None): type of user. One of the values from
|
user_type: type of user. One of the values from api.constants.UserTypes,
|
||||||
api.constants.UserTypes, or None for a normal user.
|
or None for a normal user.
|
||||||
shadow_banned (bool): Whether the user is shadow-banned,
|
shadow_banned: Whether the user is shadow-banned, i.e. they may be
|
||||||
i.e. they may be told their requests succeeded but we ignore them.
|
told their requests succeeded but we ignore them.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
StoreError if the user_id could not be registered.
|
StoreError if the user_id could not be registered.
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred
|
|
||||||
"""
|
"""
|
||||||
return self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"register_user",
|
"register_user",
|
||||||
self._register_user,
|
self._register_user,
|
||||||
user_id,
|
user_id,
|
||||||
@ -1101,7 +1104,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||||||
desc="record_user_external_id",
|
desc="record_user_external_id",
|
||||||
)
|
)
|
||||||
|
|
||||||
def user_set_password_hash(self, user_id, password_hash):
|
async def user_set_password_hash(self, user_id: str, password_hash: str) -> None:
|
||||||
"""
|
"""
|
||||||
NB. This does *not* evict any cache because the one use for this
|
NB. This does *not* evict any cache because the one use for this
|
||||||
removes most of the entries subsequently anyway so it would be
|
removes most of the entries subsequently anyway so it would be
|
||||||
@ -1114,17 +1117,18 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||||||
)
|
)
|
||||||
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
|
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"user_set_password_hash", user_set_password_hash_txn
|
"user_set_password_hash", user_set_password_hash_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def user_set_consent_version(self, user_id, consent_version):
|
async def user_set_consent_version(
|
||||||
|
self, user_id: str, consent_version: str
|
||||||
|
) -> None:
|
||||||
"""Updates the user table to record privacy policy consent
|
"""Updates the user table to record privacy policy consent
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): full mxid of the user to update
|
user_id: full mxid of the user to update
|
||||||
consent_version (str): version of the policy the user has consented
|
consent_version: version of the policy the user has consented to
|
||||||
to
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
StoreError(404) if user not found
|
StoreError(404) if user not found
|
||||||
@ -1139,16 +1143,17 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||||||
)
|
)
|
||||||
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
|
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
|
||||||
|
|
||||||
return self.db_pool.runInteraction("user_set_consent_version", f)
|
await self.db_pool.runInteraction("user_set_consent_version", f)
|
||||||
|
|
||||||
def user_set_consent_server_notice_sent(self, user_id, consent_version):
|
async def user_set_consent_server_notice_sent(
|
||||||
|
self, user_id: str, consent_version: str
|
||||||
|
) -> None:
|
||||||
"""Updates the user table to record that we have sent the user a server
|
"""Updates the user table to record that we have sent the user a server
|
||||||
notice about privacy policy consent
|
notice about privacy policy consent
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): full mxid of the user to update
|
user_id: full mxid of the user to update
|
||||||
consent_version (str): version of the policy we have notified the
|
consent_version: version of the policy we have notified the user about
|
||||||
user about
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
StoreError(404) if user not found
|
StoreError(404) if user not found
|
||||||
@ -1163,22 +1168,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||||||
)
|
)
|
||||||
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
|
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
|
||||||
|
|
||||||
return self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
|
await self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
|
||||||
|
|
||||||
def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
|
async def user_delete_access_tokens(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
except_token_id: Optional[str] = None,
|
||||||
|
device_id: Optional[str] = None,
|
||||||
|
) -> List[Tuple[str, int, Optional[str]]]:
|
||||||
"""
|
"""
|
||||||
Invalidate access tokens belonging to a user
|
Invalidate access tokens belonging to a user
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): ID of user the tokens belong to
|
user_id: ID of user the tokens belong to
|
||||||
except_token_id (str): list of access_tokens IDs which should
|
except_token_id: access_tokens ID which should *not* be deleted
|
||||||
*not* be deleted
|
device_id: ID of device the tokens are associated with.
|
||||||
device_id (str|None): ID of device the tokens are associated with.
|
|
||||||
If None, tokens associated with any device (or no device) will
|
If None, tokens associated with any device (or no device) will
|
||||||
be deleted
|
be deleted
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred[list[str, int, str|None, int]]: a list of
|
A tuple of (token, token id, device id) for each of the deleted tokens
|
||||||
(token, token id, device id) for each of the deleted tokens
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def f(txn):
|
def f(txn):
|
||||||
@ -1209,9 +1217,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||||||
|
|
||||||
return tokens_and_devices
|
return tokens_and_devices
|
||||||
|
|
||||||
return self.db_pool.runInteraction("user_delete_access_tokens", f)
|
return await self.db_pool.runInteraction("user_delete_access_tokens", f)
|
||||||
|
|
||||||
def delete_access_token(self, access_token):
|
async def delete_access_token(self, access_token: str) -> None:
|
||||||
def f(txn):
|
def f(txn):
|
||||||
self.db_pool.simple_delete_one_txn(
|
self.db_pool.simple_delete_one_txn(
|
||||||
txn, table="access_tokens", keyvalues={"token": access_token}
|
txn, table="access_tokens", keyvalues={"token": access_token}
|
||||||
@ -1221,7 +1229,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||||||
txn, self.get_user_by_access_token, (access_token,)
|
txn, self.get_user_by_access_token, (access_token,)
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.db_pool.runInteraction("delete_access_token", f)
|
await self.db_pool.runInteraction("delete_access_token", f)
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
async def is_guest(self, user_id: str) -> bool:
|
async def is_guest(self, user_id: str) -> bool:
|
||||||
@ -1272,24 +1280,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||||||
desc="get_users_pending_deactivation",
|
desc="get_users_pending_deactivation",
|
||||||
)
|
)
|
||||||
|
|
||||||
def validate_threepid_session(self, session_id, client_secret, token, current_ts):
|
async def validate_threepid_session(
|
||||||
|
self, session_id: str, client_secret: str, token: str, current_ts: int
|
||||||
|
) -> Optional[str]:
|
||||||
"""Attempt to validate a threepid session using a token
|
"""Attempt to validate a threepid session using a token
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session_id (str): The id of a validation session
|
session_id: The id of a validation session
|
||||||
client_secret (str): A unique string provided by the client to
|
client_secret: A unique string provided by the client to help identify
|
||||||
help identify this validation attempt
|
this validation attempt
|
||||||
token (str): A validation token
|
token: A validation token
|
||||||
current_ts (int): The current unix time in milliseconds. Used for
|
current_ts: The current unix time in milliseconds. Used for checking
|
||||||
checking token expiry status
|
token expiry status
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ThreepidValidationError: if a matching validation token was not found or has
|
ThreepidValidationError: if a matching validation token was not found or has
|
||||||
expired
|
expired
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
deferred str|None: A str representing a link to redirect the user
|
A str representing a link to redirect the user to if there is one.
|
||||||
to if there is one.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Insert everything into a transaction in order to run atomically
|
# Insert everything into a transaction in order to run atomically
|
||||||
@ -1359,36 +1368,35 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||||||
return next_link
|
return next_link
|
||||||
|
|
||||||
# Return next_link if it exists
|
# Return next_link if it exists
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"validate_threepid_session_txn", validate_threepid_session_txn
|
"validate_threepid_session_txn", validate_threepid_session_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def start_or_continue_validation_session(
|
async def start_or_continue_validation_session(
|
||||||
self,
|
self,
|
||||||
medium,
|
medium: str,
|
||||||
address,
|
address: str,
|
||||||
session_id,
|
session_id: str,
|
||||||
client_secret,
|
client_secret: str,
|
||||||
send_attempt,
|
send_attempt: int,
|
||||||
next_link,
|
next_link: Optional[str],
|
||||||
token,
|
token: str,
|
||||||
token_expires,
|
token_expires: int,
|
||||||
):
|
) -> None:
|
||||||
"""Creates a new threepid validation session if it does not already
|
"""Creates a new threepid validation session if it does not already
|
||||||
exist and associates a new validation token with it
|
exist and associates a new validation token with it
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
medium (str): The medium of the 3PID
|
medium: The medium of the 3PID
|
||||||
address (str): The address of the 3PID
|
address: The address of the 3PID
|
||||||
session_id (str): The id of this validation session
|
session_id: The id of this validation session
|
||||||
client_secret (str): A unique string provided by the client to
|
client_secret: A unique string provided by the client to help
|
||||||
help identify this validation attempt
|
identify this validation attempt
|
||||||
send_attempt (int): The latest send_attempt on this session
|
send_attempt: The latest send_attempt on this session
|
||||||
next_link (str|None): The link to redirect the user to upon
|
next_link: The link to redirect the user to upon successful validation
|
||||||
successful validation
|
token: The validation token
|
||||||
token (str): The validation token
|
token_expires: The timestamp for which after the token will no
|
||||||
token_expires (int): The timestamp for which after the token
|
longer be valid
|
||||||
will no longer be valid
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def start_or_continue_validation_session_txn(txn):
|
def start_or_continue_validation_session_txn(txn):
|
||||||
@ -1417,12 +1425,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"start_or_continue_validation_session",
|
"start_or_continue_validation_session",
|
||||||
start_or_continue_validation_session_txn,
|
start_or_continue_validation_session_txn,
|
||||||
)
|
)
|
||||||
|
|
||||||
def cull_expired_threepid_validation_tokens(self):
|
async def cull_expired_threepid_validation_tokens(self) -> None:
|
||||||
"""Remove threepid validation tokens with expiry dates that have passed"""
|
"""Remove threepid validation tokens with expiry dates that have passed"""
|
||||||
|
|
||||||
def cull_expired_threepid_validation_tokens_txn(txn, ts):
|
def cull_expired_threepid_validation_tokens_txn(txn, ts):
|
||||||
@ -1430,9 +1438,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||||||
DELETE FROM threepid_validation_token WHERE
|
DELETE FROM threepid_validation_token WHERE
|
||||||
expires < ?
|
expires < ?
|
||||||
"""
|
"""
|
||||||
return txn.execute(sql, (ts,))
|
txn.execute(sql, (ts,))
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"cull_expired_threepid_validation_tokens",
|
"cull_expired_threepid_validation_tokens",
|
||||||
cull_expired_threepid_validation_tokens_txn,
|
cull_expired_threepid_validation_tokens_txn,
|
||||||
self.clock.time_msec(),
|
self.clock.time_msec(),
|
||||||
|
@ -15,7 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
|
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
@ -152,8 +152,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@cached(max_entries=100000, iterable=True)
|
@cached(max_entries=100000, iterable=True)
|
||||||
def get_users_in_room(self, room_id: str):
|
async def get_users_in_room(self, room_id: str) -> List[str]:
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_users_in_room", self.get_users_in_room_txn, room_id
|
"get_users_in_room", self.get_users_in_room_txn, room_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -180,13 +180,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||||||
return [r[0] for r in txn]
|
return [r[0] for r in txn]
|
||||||
|
|
||||||
@cached(max_entries=100000)
|
@cached(max_entries=100000)
|
||||||
def get_room_summary(self, room_id: str):
|
async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
|
||||||
""" Get the details of a room roughly suitable for use by the room
|
""" Get the details of a room roughly suitable for use by the room
|
||||||
summary extension to /sync. Useful when lazy loading room members.
|
summary extension to /sync. Useful when lazy loading room members.
|
||||||
Args:
|
Args:
|
||||||
room_id: The room ID to query
|
room_id: The room ID to query
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict[str, MemberSummary]:
|
|
||||||
dict of membership states, pointing to a MemberSummary named tuple.
|
dict of membership states, pointing to a MemberSummary named tuple.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -261,20 +260,22 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn)
|
return await self.db_pool.runInteraction(
|
||||||
|
"get_room_summary", _get_room_summary_txn
|
||||||
|
)
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]:
|
async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser:
|
||||||
"""Get all the rooms the *local* user is invited to.
|
"""Get all the rooms the *local* user is invited to.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: The user ID.
|
user_id: The user ID.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A awaitable list of RoomsForUser.
|
A list of RoomsForUser.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self.get_rooms_for_local_user_where_membership_is(
|
return await self.get_rooms_for_local_user_where_membership_is(
|
||||||
user_id, [Membership.INVITE]
|
user_id, [Membership.INVITE]
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -357,7 +358,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
@cached(max_entries=500000, iterable=True)
|
@cached(max_entries=500000, iterable=True)
|
||||||
def get_rooms_for_user_with_stream_ordering(self, user_id: str):
|
async def get_rooms_for_user_with_stream_ordering(
|
||||||
|
self, user_id: str
|
||||||
|
) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
|
||||||
"""Returns a set of room_ids the user is currently joined to.
|
"""Returns a set of room_ids the user is currently joined to.
|
||||||
|
|
||||||
If a remote user only returns rooms this server is currently
|
If a remote user only returns rooms this server is currently
|
||||||
@ -367,17 +370,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||||||
user_id
|
user_id
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
|
Returns the rooms the user is in currently, along with the stream
|
||||||
the rooms the user is in currently, along with the stream ordering
|
ordering of the most recent join for that user and room.
|
||||||
of the most recent join for that user and room.
|
|
||||||
"""
|
"""
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_rooms_for_user_with_stream_ordering",
|
"get_rooms_for_user_with_stream_ordering",
|
||||||
self._get_rooms_for_user_with_stream_ordering_txn,
|
self._get_rooms_for_user_with_stream_ordering_txn,
|
||||||
user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str):
|
def _get_rooms_for_user_with_stream_ordering_txn(
|
||||||
|
self, txn, user_id: str
|
||||||
|
) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
|
||||||
# We use `current_state_events` here and not `local_current_membership`
|
# We use `current_state_events` here and not `local_current_membership`
|
||||||
# as a) this gets called with remote users and b) this only gets called
|
# as a) this gets called with remote users and b) this only gets called
|
||||||
# for rooms the server is participating in.
|
# for rooms the server is participating in.
|
||||||
@ -404,9 +408,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
txn.execute(sql, (user_id, Membership.JOIN))
|
txn.execute(sql, (user_id, Membership.JOIN))
|
||||||
results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
|
return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
async def get_users_server_still_shares_room_with(
|
async def get_users_server_still_shares_room_with(
|
||||||
self, user_ids: Collection[str]
|
self, user_ids: Collection[str]
|
||||||
@ -711,14 +713,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||||||
return count == 0
|
return count == 0
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
def get_forgotten_rooms_for_user(self, user_id: str):
|
async def get_forgotten_rooms_for_user(self, user_id: str) -> Set[str]:
|
||||||
"""Gets all rooms the user has forgotten.
|
"""Gets all rooms the user has forgotten.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id
|
user_id: The user ID to query the rooms of.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[set[str]]
|
The forgotten rooms.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_forgotten_rooms_for_user_txn(txn):
|
def _get_forgotten_rooms_for_user_txn(txn):
|
||||||
@ -744,7 +746,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||||||
txn.execute(sql, (user_id,))
|
txn.execute(sql, (user_id,))
|
||||||
return {row[0] for row in txn if row[1] == 0}
|
return {row[0] for row in txn if row[1] == 0}
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
|
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -973,7 +975,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
|
|||||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||||
super(RoomMemberStore, self).__init__(database, db_conn, hs)
|
super(RoomMemberStore, self).__init__(database, db_conn, hs)
|
||||||
|
|
||||||
def forget(self, user_id: str, room_id: str):
|
async def forget(self, user_id: str, room_id: str) -> None:
|
||||||
"""Indicate that user_id wishes to discard history for room_id."""
|
"""Indicate that user_id wishes to discard history for room_id."""
|
||||||
|
|
||||||
def f(txn):
|
def f(txn):
|
||||||
@ -994,7 +996,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
|
|||||||
txn, self.get_forgotten_rooms_for_user, (user_id,)
|
txn, self.get_forgotten_rooms_for_user, (user_id,)
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.db_pool.runInteraction("forget_membership", f)
|
await self.db_pool.runInteraction("forget_membership", f)
|
||||||
|
|
||||||
|
|
||||||
class _JoinedHostsCache(object):
|
class _JoinedHostsCache(object):
|
||||||
|
Loading…
Reference in New Issue
Block a user