Convert event_push_actions, registration, and roommember datastores to async (#8197)

This commit is contained in:
Patrick Cloke 2020-08-28 11:34:50 -04:00 committed by GitHub
parent 22b926c284
commit d58fda99ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 169 additions and 160 deletions

1
changelog.d/8197.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View File

@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import List from typing import Dict, List, Union
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
@ -383,19 +383,20 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# Now return the first `limit` # Now return the first `limit`
return notifs[:limit] return notifs[:limit]
def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering): async def get_if_maybe_push_in_range_for_user(
self, user_id: str, min_stream_ordering: int
) -> bool:
"""A fast check to see if there might be something to push for the """A fast check to see if there might be something to push for the
user since the given stream ordering. May return false positives. user since the given stream ordering. May return false positives.
Useful to know whether to bother starting a pusher on start up or not. Useful to know whether to bother starting a pusher on start up or not.
Args: Args:
user_id (str) user_id
min_stream_ordering (int) min_stream_ordering
Returns: Returns:
Deferred[bool]: True if there may be push to process, False if True if there may be push to process, False if there definitely isn't.
there definitely isn't.
""" """
def _get_if_maybe_push_in_range_for_user_txn(txn): def _get_if_maybe_push_in_range_for_user_txn(txn):
@ -408,22 +409,20 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, min_stream_ordering)) txn.execute(sql, (user_id, min_stream_ordering))
return bool(txn.fetchone()) return bool(txn.fetchone())
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_if_maybe_push_in_range_for_user", "get_if_maybe_push_in_range_for_user",
_get_if_maybe_push_in_range_for_user_txn, _get_if_maybe_push_in_range_for_user_txn,
) )
async def add_push_actions_to_staging(self, event_id, user_id_actions): async def add_push_actions_to_staging(
self, event_id: str, user_id_actions: Dict[str, List[Union[dict, str]]]
) -> None:
"""Add the push actions for the event to the push action staging area. """Add the push actions for the event to the push action staging area.
Args: Args:
event_id (str) event_id
user_id_actions (dict[str, list[dict|str])]): A dictionary mapping user_id_actions: A mapping of user_id to list of push actions, where
user_id to list of push actions, where an action can either be an action can either be a string or dict.
a string or dict.
Returns:
Deferred
""" """
if not user_id_actions: if not user_id_actions:
@ -507,7 +506,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago "Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago
) )
def find_first_stream_ordering_after_ts(self, ts): async def find_first_stream_ordering_after_ts(self, ts: int) -> int:
"""Gets the stream ordering corresponding to a given timestamp. """Gets the stream ordering corresponding to a given timestamp.
Specifically, finds the stream_ordering of the first event that was Specifically, finds the stream_ordering of the first event that was
@ -516,13 +515,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
relatively slow. relatively slow.
Args: Args:
ts (int): timestamp in millis ts: timestamp in millis
Returns: Returns:
Deferred[int]: stream ordering of the first event received on/after stream ordering of the first event received on/after the timestamp
the timestamp
""" """
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"_find_first_stream_ordering_after_ts_txn", "_find_first_stream_ordering_after_ts_txn",
self._find_first_stream_ordering_after_ts_txn, self._find_first_stream_ordering_after_ts_txn,
ts, ts,

View File

@ -17,7 +17,7 @@
import logging import logging
import re import re
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Tuple
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@ -84,17 +84,17 @@ class RegistrationWorkerStore(SQLBaseStore):
return is_trial return is_trial
@cached() @cached()
def get_user_by_access_token(self, token): async def get_user_by_access_token(self, token: str) -> Optional[dict]:
"""Get a user from the given access token. """Get a user from the given access token.
Args: Args:
token (str): The access token of a user. token: The access token of a user.
Returns: Returns:
defer.Deferred: None, if the token did not match, otherwise dict None, if the token did not match, otherwise dict
including the keys `name`, `is_guest`, `device_id`, `token_id`, including the keys `name`, `is_guest`, `device_id`, `token_id`,
`valid_until_ms`. `valid_until_ms`.
""" """
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_user_by_access_token", self._query_for_auth, token "get_user_by_access_token", self._query_for_auth, token
) )
@ -281,13 +281,12 @@ class RegistrationWorkerStore(SQLBaseStore):
return bool(res) if res else False return bool(res) if res else False
def set_server_admin(self, user, admin): async def set_server_admin(self, user: UserID, admin: bool) -> None:
"""Sets whether a user is an admin of this homeserver. """Sets whether a user is an admin of this homeserver.
Args: Args:
user (UserID): user ID of the user to test user: user ID of the user to test
admin (bool): true iff the user is to be a server admin, admin: true iff the user is to be a server admin, false otherwise.
false otherwise.
""" """
def set_server_admin_txn(txn): def set_server_admin_txn(txn):
@ -298,7 +297,7 @@ class RegistrationWorkerStore(SQLBaseStore):
txn, self.get_user_by_id, (user.to_string(),) txn, self.get_user_by_id, (user.to_string(),)
) )
return self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
def _query_for_auth(self, txn, token): def _query_for_auth(self, txn, token):
sql = ( sql = (
@ -364,9 +363,11 @@ class RegistrationWorkerStore(SQLBaseStore):
) )
return True if res == UserTypes.SUPPORT else False return True if res == UserTypes.SUPPORT else False
def get_users_by_id_case_insensitive(self, user_id): async def get_users_by_id_case_insensitive(self, user_id: str) -> Dict[str, str]:
"""Gets users that match user_id case insensitively. """Gets users that match user_id case insensitively.
Returns a mapping of user_id -> password_hash.
Returns:
A mapping of user_id -> password_hash.
""" """
def f(txn): def f(txn):
@ -374,7 +375,7 @@ class RegistrationWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,)) txn.execute(sql, (user_id,))
return dict(txn) return dict(txn)
return self.db_pool.runInteraction("get_users_by_id_case_insensitive", f) return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
async def get_user_by_external_id( async def get_user_by_external_id(
self, auth_provider: str, external_id: str self, auth_provider: str, external_id: str
@ -408,7 +409,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return await self.db_pool.runInteraction("count_users", _count_users) return await self.db_pool.runInteraction("count_users", _count_users)
def count_daily_user_type(self): async def count_daily_user_type(self) -> Dict[str, int]:
""" """
Counts 1) native non guest users Counts 1) native non guest users
2) native guests users 2) native guests users
@ -437,7 +438,7 @@ class RegistrationWorkerStore(SQLBaseStore):
results[row[0]] = row[1] results[row[0]] = row[1]
return results return results
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"count_daily_user_type", _count_daily_user_type "count_daily_user_type", _count_daily_user_type
) )
@ -663,24 +664,29 @@ class RegistrationWorkerStore(SQLBaseStore):
# Convert the integer into a boolean. # Convert the integer into a boolean.
return res == 1 return res == 1
def get_threepid_validation_session( async def get_threepid_validation_session(
self, medium, client_secret, address=None, sid=None, validated=True self,
): medium: Optional[str],
client_secret: str,
address: Optional[str] = None,
sid: Optional[str] = None,
validated: Optional[bool] = True,
) -> Optional[Dict[str, Any]]:
"""Gets a session_id and last_send_attempt (if available) for a """Gets a session_id and last_send_attempt (if available) for a
combination of validation metadata combination of validation metadata
Args: Args:
medium (str|None): The medium of the 3PID medium: The medium of the 3PID
address (str|None): The address of the 3PID client_secret: A unique string provided by the client to help identify this
sid (str|None): The ID of the validation session
client_secret (str): A unique string provided by the client to help identify this
validation attempt validation attempt
validated (bool|None): Whether sessions should be filtered by address: The address of the 3PID
sid: The ID of the validation session
validated: Whether sessions should be filtered by
whether they have been validated already or not. None to whether they have been validated already or not. None to
perform no filtering perform no filtering
Returns: Returns:
Deferred[dict|None]: A dict containing the following: A dict containing the following:
* address - address of the 3pid * address - address of the 3pid
* medium - medium of the 3pid * medium - medium of the 3pid
* client_secret - a secret provided by the client for this validation session * client_secret - a secret provided by the client for this validation session
@ -726,17 +732,17 @@ class RegistrationWorkerStore(SQLBaseStore):
return rows[0] return rows[0]
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_threepid_validation_session", get_threepid_validation_session_txn "get_threepid_validation_session", get_threepid_validation_session_txn
) )
def delete_threepid_session(self, session_id): async def delete_threepid_session(self, session_id: str) -> None:
"""Removes a threepid validation session from the database. This can """Removes a threepid validation session from the database. This can
be done after validation has been performed and whatever action was be done after validation has been performed and whatever action was
waiting on it has been carried out waiting on it has been carried out
Args: Args:
session_id (str): The ID of the session to delete session_id: The ID of the session to delete
""" """
def delete_threepid_session_txn(txn): def delete_threepid_session_txn(txn):
@ -751,7 +757,7 @@ class RegistrationWorkerStore(SQLBaseStore):
keyvalues={"session_id": session_id}, keyvalues={"session_id": session_id},
) )
return self.db_pool.runInteraction( await self.db_pool.runInteraction(
"delete_threepid_session", delete_threepid_session_txn "delete_threepid_session", delete_threepid_session_txn
) )
@ -941,43 +947,40 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="add_access_token_to_user", desc="add_access_token_to_user",
) )
def register_user( async def register_user(
self, self,
user_id, user_id: str,
password_hash=None, password_hash: Optional[str] = None,
was_guest=False, was_guest: bool = False,
make_guest=False, make_guest: bool = False,
appservice_id=None, appservice_id: Optional[str] = None,
create_profile_with_displayname=None, create_profile_with_displayname: Optional[str] = None,
admin=False, admin: bool = False,
user_type=None, user_type: Optional[str] = None,
shadow_banned=False, shadow_banned: bool = False,
): ) -> None:
"""Attempts to register an account. """Attempts to register an account.
Args: Args:
user_id (str): The desired user ID to register. user_id: The desired user ID to register.
password_hash (str|None): Optional. The password hash for this user. password_hash: Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being was_guest: Whether this is a guest account being upgraded to a
upgraded to a non-guest account. non-guest account.
make_guest (boolean): True if the the new user should be guest, make_guest: True if the the new user should be guest, false to add a
false to add a regular user account. regular user account.
appservice_id (str): The ID of the appservice registering the user. appservice_id: The ID of the appservice registering the user.
create_profile_with_displayname (unicode): Optionally create a profile for create_profile_with_displayname: Optionally create a profile for
the user, setting their displayname to the given value the user, setting their displayname to the given value
admin (boolean): is an admin user? admin: is an admin user?
user_type (str|None): type of user. One of the values from user_type: type of user. One of the values from api.constants.UserTypes,
api.constants.UserTypes, or None for a normal user. or None for a normal user.
shadow_banned (bool): Whether the user is shadow-banned, shadow_banned: Whether the user is shadow-banned, i.e. they may be
i.e. they may be told their requests succeeded but we ignore them. told their requests succeeded but we ignore them.
Raises: Raises:
StoreError if the user_id could not be registered. StoreError if the user_id could not be registered.
Returns:
Deferred
""" """
return self.db_pool.runInteraction( await self.db_pool.runInteraction(
"register_user", "register_user",
self._register_user, self._register_user,
user_id, user_id,
@ -1101,7 +1104,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="record_user_external_id", desc="record_user_external_id",
) )
def user_set_password_hash(self, user_id, password_hash): async def user_set_password_hash(self, user_id: str, password_hash: str) -> None:
""" """
NB. This does *not* evict any cache because the one use for this NB. This does *not* evict any cache because the one use for this
removes most of the entries subsequently anyway so it would be removes most of the entries subsequently anyway so it would be
@ -1114,17 +1117,18 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
) )
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return self.db_pool.runInteraction( await self.db_pool.runInteraction(
"user_set_password_hash", user_set_password_hash_txn "user_set_password_hash", user_set_password_hash_txn
) )
def user_set_consent_version(self, user_id, consent_version): async def user_set_consent_version(
self, user_id: str, consent_version: str
) -> None:
"""Updates the user table to record privacy policy consent """Updates the user table to record privacy policy consent
Args: Args:
user_id (str): full mxid of the user to update user_id: full mxid of the user to update
consent_version (str): version of the policy the user has consented consent_version: version of the policy the user has consented to
to
Raises: Raises:
StoreError(404) if user not found StoreError(404) if user not found
@ -1139,16 +1143,17 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
) )
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return self.db_pool.runInteraction("user_set_consent_version", f) await self.db_pool.runInteraction("user_set_consent_version", f)
def user_set_consent_server_notice_sent(self, user_id, consent_version): async def user_set_consent_server_notice_sent(
self, user_id: str, consent_version: str
) -> None:
"""Updates the user table to record that we have sent the user a server """Updates the user table to record that we have sent the user a server
notice about privacy policy consent notice about privacy policy consent
Args: Args:
user_id (str): full mxid of the user to update user_id: full mxid of the user to update
consent_version (str): version of the policy we have notified the consent_version: version of the policy we have notified the user about
user about
Raises: Raises:
StoreError(404) if user not found StoreError(404) if user not found
@ -1163,22 +1168,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
) )
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return self.db_pool.runInteraction("user_set_consent_server_notice_sent", f) await self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None): async def user_delete_access_tokens(
self,
user_id: str,
except_token_id: Optional[str] = None,
device_id: Optional[str] = None,
) -> List[Tuple[str, int, Optional[str]]]:
""" """
Invalidate access tokens belonging to a user Invalidate access tokens belonging to a user
Args: Args:
user_id (str): ID of user the tokens belong to user_id: ID of user the tokens belong to
except_token_id (str): list of access_tokens IDs which should except_token_id: access_tokens ID which should *not* be deleted
*not* be deleted device_id: ID of device the tokens are associated with.
device_id (str|None): ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will If None, tokens associated with any device (or no device) will
be deleted be deleted
Returns: Returns:
defer.Deferred[list[str, int, str|None, int]]: a list of A tuple of (token, token id, device id) for each of the deleted tokens
(token, token id, device id) for each of the deleted tokens
""" """
def f(txn): def f(txn):
@ -1209,9 +1217,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return tokens_and_devices return tokens_and_devices
return self.db_pool.runInteraction("user_delete_access_tokens", f) return await self.db_pool.runInteraction("user_delete_access_tokens", f)
def delete_access_token(self, access_token): async def delete_access_token(self, access_token: str) -> None:
def f(txn): def f(txn):
self.db_pool.simple_delete_one_txn( self.db_pool.simple_delete_one_txn(
txn, table="access_tokens", keyvalues={"token": access_token} txn, table="access_tokens", keyvalues={"token": access_token}
@ -1221,7 +1229,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
txn, self.get_user_by_access_token, (access_token,) txn, self.get_user_by_access_token, (access_token,)
) )
return self.db_pool.runInteraction("delete_access_token", f) await self.db_pool.runInteraction("delete_access_token", f)
@cached() @cached()
async def is_guest(self, user_id: str) -> bool: async def is_guest(self, user_id: str) -> bool:
@ -1272,24 +1280,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="get_users_pending_deactivation", desc="get_users_pending_deactivation",
) )
def validate_threepid_session(self, session_id, client_secret, token, current_ts): async def validate_threepid_session(
self, session_id: str, client_secret: str, token: str, current_ts: int
) -> Optional[str]:
"""Attempt to validate a threepid session using a token """Attempt to validate a threepid session using a token
Args: Args:
session_id (str): The id of a validation session session_id: The id of a validation session
client_secret (str): A unique string provided by the client to client_secret: A unique string provided by the client to help identify
help identify this validation attempt this validation attempt
token (str): A validation token token: A validation token
current_ts (int): The current unix time in milliseconds. Used for current_ts: The current unix time in milliseconds. Used for checking
checking token expiry status token expiry status
Raises: Raises:
ThreepidValidationError: if a matching validation token was not found or has ThreepidValidationError: if a matching validation token was not found or has
expired expired
Returns: Returns:
deferred str|None: A str representing a link to redirect the user A str representing a link to redirect the user to if there is one.
to if there is one.
""" """
# Insert everything into a transaction in order to run atomically # Insert everything into a transaction in order to run atomically
@ -1359,36 +1368,35 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return next_link return next_link
# Return next_link if it exists # Return next_link if it exists
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"validate_threepid_session_txn", validate_threepid_session_txn "validate_threepid_session_txn", validate_threepid_session_txn
) )
def start_or_continue_validation_session( async def start_or_continue_validation_session(
self, self,
medium, medium: str,
address, address: str,
session_id, session_id: str,
client_secret, client_secret: str,
send_attempt, send_attempt: int,
next_link, next_link: Optional[str],
token, token: str,
token_expires, token_expires: int,
): ) -> None:
"""Creates a new threepid validation session if it does not already """Creates a new threepid validation session if it does not already
exist and associates a new validation token with it exist and associates a new validation token with it
Args: Args:
medium (str): The medium of the 3PID medium: The medium of the 3PID
address (str): The address of the 3PID address: The address of the 3PID
session_id (str): The id of this validation session session_id: The id of this validation session
client_secret (str): A unique string provided by the client to client_secret: A unique string provided by the client to help
help identify this validation attempt identify this validation attempt
send_attempt (int): The latest send_attempt on this session send_attempt: The latest send_attempt on this session
next_link (str|None): The link to redirect the user to upon next_link: The link to redirect the user to upon successful validation
successful validation token: The validation token
token (str): The validation token token_expires: The timestamp for which after the token will no
token_expires (int): The timestamp for which after the token longer be valid
will no longer be valid
""" """
def start_or_continue_validation_session_txn(txn): def start_or_continue_validation_session_txn(txn):
@ -1417,12 +1425,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
}, },
) )
return self.db_pool.runInteraction( await self.db_pool.runInteraction(
"start_or_continue_validation_session", "start_or_continue_validation_session",
start_or_continue_validation_session_txn, start_or_continue_validation_session_txn,
) )
def cull_expired_threepid_validation_tokens(self): async def cull_expired_threepid_validation_tokens(self) -> None:
"""Remove threepid validation tokens with expiry dates that have passed""" """Remove threepid validation tokens with expiry dates that have passed"""
def cull_expired_threepid_validation_tokens_txn(txn, ts): def cull_expired_threepid_validation_tokens_txn(txn, ts):
@ -1430,9 +1438,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
DELETE FROM threepid_validation_token WHERE DELETE FROM threepid_validation_token WHERE
expires < ? expires < ?
""" """
return txn.execute(sql, (ts,)) txn.execute(sql, (ts,))
return self.db_pool.runInteraction( await self.db_pool.runInteraction(
"cull_expired_threepid_validation_tokens", "cull_expired_threepid_validation_tokens",
cull_expired_threepid_validation_tokens_txn, cull_expired_threepid_validation_tokens_txn,
self.clock.time_msec(), self.clock.time_msec(),

View File

@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase from synapse.events import EventBase
@ -152,8 +152,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
) )
@cached(max_entries=100000, iterable=True) @cached(max_entries=100000, iterable=True)
def get_users_in_room(self, room_id: str): async def get_users_in_room(self, room_id: str) -> List[str]:
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_users_in_room", self.get_users_in_room_txn, room_id "get_users_in_room", self.get_users_in_room_txn, room_id
) )
@ -180,14 +180,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return [r[0] for r in txn] return [r[0] for r in txn]
@cached(max_entries=100000) @cached(max_entries=100000)
def get_room_summary(self, room_id: str): async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
""" Get the details of a room roughly suitable for use by the room """ Get the details of a room roughly suitable for use by the room
summary extension to /sync. Useful when lazy loading room members. summary extension to /sync. Useful when lazy loading room members.
Args: Args:
room_id: The room ID to query room_id: The room ID to query
Returns: Returns:
Deferred[dict[str, MemberSummary]: dict of membership states, pointing to a MemberSummary named tuple.
dict of membership states, pointing to a MemberSummary named tuple.
""" """
def _get_room_summary_txn(txn): def _get_room_summary_txn(txn):
@ -261,20 +260,22 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return res return res
return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn) return await self.db_pool.runInteraction(
"get_room_summary", _get_room_summary_txn
)
@cached() @cached()
def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]: async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser:
"""Get all the rooms the *local* user is invited to. """Get all the rooms the *local* user is invited to.
Args: Args:
user_id: The user ID. user_id: The user ID.
Returns: Returns:
A awaitable list of RoomsForUser. A list of RoomsForUser.
""" """
return self.get_rooms_for_local_user_where_membership_is( return await self.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE] user_id, [Membership.INVITE]
) )
@ -357,7 +358,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results return results
@cached(max_entries=500000, iterable=True) @cached(max_entries=500000, iterable=True)
def get_rooms_for_user_with_stream_ordering(self, user_id: str): async def get_rooms_for_user_with_stream_ordering(
self, user_id: str
) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
"""Returns a set of room_ids the user is currently joined to. """Returns a set of room_ids the user is currently joined to.
If a remote user only returns rooms this server is currently If a remote user only returns rooms this server is currently
@ -367,17 +370,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
user_id user_id
Returns: Returns:
Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns Returns the rooms the user is in currently, along with the stream
the rooms the user is in currently, along with the stream ordering ordering of the most recent join for that user and room.
of the most recent join for that user and room.
""" """
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_rooms_for_user_with_stream_ordering", "get_rooms_for_user_with_stream_ordering",
self._get_rooms_for_user_with_stream_ordering_txn, self._get_rooms_for_user_with_stream_ordering_txn,
user_id, user_id,
) )
def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str): def _get_rooms_for_user_with_stream_ordering_txn(
self, txn, user_id: str
) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
# We use `current_state_events` here and not `local_current_membership` # We use `current_state_events` here and not `local_current_membership`
# as a) this gets called with remote users and b) this only gets called # as a) this gets called with remote users and b) this only gets called
# for rooms the server is participating in. # for rooms the server is participating in.
@ -404,9 +408,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
""" """
txn.execute(sql, (user_id, Membership.JOIN)) txn.execute(sql, (user_id, Membership.JOIN))
results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn) return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
return results
async def get_users_server_still_shares_room_with( async def get_users_server_still_shares_room_with(
self, user_ids: Collection[str] self, user_ids: Collection[str]
@ -711,14 +713,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return count == 0 return count == 0
@cached() @cached()
def get_forgotten_rooms_for_user(self, user_id: str): async def get_forgotten_rooms_for_user(self, user_id: str) -> Set[str]:
"""Gets all rooms the user has forgotten. """Gets all rooms the user has forgotten.
Args: Args:
user_id user_id: The user ID to query the rooms of.
Returns: Returns:
Deferred[set[str]] The forgotten rooms.
""" """
def _get_forgotten_rooms_for_user_txn(txn): def _get_forgotten_rooms_for_user_txn(txn):
@ -744,7 +746,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(sql, (user_id,)) txn.execute(sql, (user_id,))
return {row[0] for row in txn if row[1] == 0} return {row[0] for row in txn if row[1] == 0}
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
) )
@ -973,7 +975,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomMemberStore, self).__init__(database, db_conn, hs) super(RoomMemberStore, self).__init__(database, db_conn, hs)
def forget(self, user_id: str, room_id: str): async def forget(self, user_id: str, room_id: str) -> None:
"""Indicate that user_id wishes to discard history for room_id.""" """Indicate that user_id wishes to discard history for room_id."""
def f(txn): def f(txn):
@ -994,7 +996,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
txn, self.get_forgotten_rooms_for_user, (user_id,) txn, self.get_forgotten_rooms_for_user, (user_id,)
) )
return self.db_pool.runInteraction("forget_membership", f) await self.db_pool.runInteraction("forget_membership", f)
class _JoinedHostsCache(object): class _JoinedHostsCache(object):