Converts event_federation and registration databases to async/await (#8061)

This commit is contained in:
Patrick Cloke 2020-08-11 17:21:13 -04:00 committed by GitHub
parent 61d8ff0d44
commit a0acdfa9e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 150 additions and 177 deletions

1
changelog.d/8061.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View File

@ -15,9 +15,7 @@
import itertools import itertools
import logging import logging
from queue import Empty, PriorityQueue from queue import Empty, PriorityQueue
from typing import Dict, List, Optional, Set, Tuple from typing import Dict, Iterable, List, Optional, Set, Tuple
from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
@ -286,17 +284,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return dict(txn) return dict(txn)
@defer.inlineCallbacks async def get_max_depth_of(self, event_ids: List[str]) -> int:
def get_max_depth_of(self, event_ids):
"""Returns the max depth of a set of event IDs """Returns the max depth of a set of event IDs
Args: Args:
event_ids (list[str]) event_ids: The event IDs to calculate the max depth of.
Returns
Deferred[int]
""" """
rows = yield self.db_pool.simple_select_many_batch( rows = await self.db_pool.simple_select_many_batch(
table="events", table="events",
column="event_id", column="event_id",
iterable=event_ids, iterable=event_ids,
@ -550,9 +544,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return event_results return event_results
@defer.inlineCallbacks async def get_missing_events(self, room_id, earliest_events, latest_events, limit):
def get_missing_events(self, room_id, earliest_events, latest_events, limit): ids = await self.db_pool.runInteraction(
ids = yield self.db_pool.runInteraction(
"get_missing_events", "get_missing_events",
self._get_missing_events, self._get_missing_events,
room_id, room_id,
@ -560,7 +553,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
latest_events, latest_events,
limit, limit,
) )
events = yield self.get_events_as_list(ids) events = await self.get_events_as_list(ids)
return events return events
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
@ -595,17 +588,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
event_results.reverse() event_results.reverse()
return event_results return event_results
@defer.inlineCallbacks async def get_successor_events(self, event_ids: Iterable[str]) -> List[str]:
def get_successor_events(self, event_ids):
"""Fetch all events that have the given events as a prev event """Fetch all events that have the given events as a prev event
Args: Args:
event_ids (iterable[str]) event_ids: The events to use as the previous events.
Returns:
Deferred[list[str]]
""" """
rows = yield self.db_pool.simple_select_many_batch( rows = await self.db_pool.simple_select_many_batch(
table="event_edges", table="event_edges",
column="prev_event_id", column="prev_event_id",
iterable=event_ids, iterable=event_ids,
@ -674,8 +663,7 @@ class EventFederationStore(EventFederationWorkerStore):
txn.execute(query, (room_id,)) txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
@defer.inlineCallbacks async def _background_delete_non_state_event_auth(self, progress, batch_size):
def _background_delete_non_state_event_auth(self, progress, batch_size):
def delete_event_auth(txn): def delete_event_auth(txn):
target_min_stream_id = progress.get("target_min_stream_id_inclusive") target_min_stream_id = progress.get("target_min_stream_id_inclusive")
max_stream_id = progress.get("max_stream_id_exclusive") max_stream_id = progress.get("max_stream_id_exclusive")
@ -714,12 +702,12 @@ class EventFederationStore(EventFederationWorkerStore):
return min_stream_id >= target_min_stream_id return min_stream_id >= target_min_stream_id
result = yield self.db_pool.runInteraction( result = await self.db_pool.runInteraction(
self.EVENT_AUTH_STATE_ONLY, delete_event_auth self.EVENT_AUTH_STATE_ONLY, delete_event_auth
) )
if not result: if not result:
yield self.db_pool.updates._end_background_update( await self.db_pool.updates._end_background_update(
self.EVENT_AUTH_STATE_ONLY self.EVENT_AUTH_STATE_ONLY
) )

View File

@ -17,9 +17,8 @@
import logging import logging
import re import re
from typing import Optional from typing import Dict, List, Optional
from twisted.internet import defer
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
@ -30,7 +29,7 @@ from synapse.storage.database import DatabasePool
from synapse.storage.types import Cursor from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID from synapse.types import UserID
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached
THIRTY_MINUTES_IN_MS = 30 * 60 * 1000 THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
@ -69,19 +68,15 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_user_by_id", desc="get_user_by_id",
) )
@defer.inlineCallbacks async def is_trial_user(self, user_id: str) -> bool:
def is_trial_user(self, user_id):
"""Checks if user is in the "trial" period, i.e. within the first """Checks if user is in the "trial" period, i.e. within the first
N days of registration defined by `mau_trial_days` config N days of registration defined by `mau_trial_days` config
Args: Args:
user_id (str) user_id: The user to check for trial status.
Returns:
Deferred[bool]
""" """
info = yield self.get_user_by_id(user_id) info = await self.get_user_by_id(user_id)
if not info: if not info:
return False return False
@ -105,41 +100,42 @@ class RegistrationWorkerStore(SQLBaseStore):
"get_user_by_access_token", self._query_for_auth, token "get_user_by_access_token", self._query_for_auth, token
) )
@cachedInlineCallbacks() @cached()
def get_expiration_ts_for_user(self, user_id): async def get_expiration_ts_for_user(self, user_id: str) -> Optional[None]:
"""Get the expiration timestamp for the account bearing a given user ID. """Get the expiration timestamp for the account bearing a given user ID.
Args: Args:
user_id (str): The ID of the user. user_id: The ID of the user.
Returns: Returns:
defer.Deferred: None, if the account has no expiration timestamp, None, if the account has no expiration timestamp, otherwise int
otherwise int representation of the timestamp (as a number of representation of the timestamp (as a number of milliseconds since epoch).
milliseconds since epoch).
""" """
res = yield self.db_pool.simple_select_one_onecol( return await self.db_pool.simple_select_one_onecol(
table="account_validity", table="account_validity",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcol="expiration_ts_ms", retcol="expiration_ts_ms",
allow_none=True, allow_none=True,
desc="get_expiration_ts_for_user", desc="get_expiration_ts_for_user",
) )
return res
@defer.inlineCallbacks async def set_account_validity_for_user(
def set_account_validity_for_user( self,
self, user_id, expiration_ts, email_sent, renewal_token=None user_id: str,
): expiration_ts: int,
email_sent: bool,
renewal_token: Optional[str] = None,
) -> None:
"""Updates the account validity properties of the given account, with the """Updates the account validity properties of the given account, with the
given values. given values.
Args: Args:
user_id (str): ID of the account to update properties for. user_id: ID of the account to update properties for.
expiration_ts (int): New expiration date, as a timestamp in milliseconds expiration_ts: New expiration date, as a timestamp in milliseconds
since epoch. since epoch.
email_sent (bool): True means a renewal email has been sent for this email_sent: True means a renewal email has been sent for this account
account and there's no need to send another one for the current validity and there's no need to send another one for the current validity
period. period.
renewal_token (str): Renewal token the user can use to extend the validity renewal_token: Renewal token the user can use to extend the validity
of their account. Defaults to no token. of their account. Defaults to no token.
""" """
@ -158,75 +154,69 @@ class RegistrationWorkerStore(SQLBaseStore):
txn, self.get_expiration_ts_for_user, (user_id,) txn, self.get_expiration_ts_for_user, (user_id,)
) )
yield self.db_pool.runInteraction( await self.db_pool.runInteraction(
"set_account_validity_for_user", set_account_validity_for_user_txn "set_account_validity_for_user", set_account_validity_for_user_txn
) )
@defer.inlineCallbacks async def set_renewal_token_for_user(
def set_renewal_token_for_user(self, user_id, renewal_token): self, user_id: str, renewal_token: str
) -> None:
"""Defines a renewal token for a given user. """Defines a renewal token for a given user.
Args: Args:
user_id (str): ID of the user to set the renewal token for. user_id: ID of the user to set the renewal token for.
renewal_token (str): Random unique string that will be used to renew the renewal_token: Random unique string that will be used to renew the
user's account. user's account.
Raises: Raises:
StoreError: The provided token is already set for another user. StoreError: The provided token is already set for another user.
""" """
yield self.db_pool.simple_update_one( await self.db_pool.simple_update_one(
table="account_validity", table="account_validity",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
updatevalues={"renewal_token": renewal_token}, updatevalues={"renewal_token": renewal_token},
desc="set_renewal_token_for_user", desc="set_renewal_token_for_user",
) )
@defer.inlineCallbacks async def get_user_from_renewal_token(self, renewal_token: str) -> str:
def get_user_from_renewal_token(self, renewal_token):
"""Get a user ID from a renewal token. """Get a user ID from a renewal token.
Args: Args:
renewal_token (str): The renewal token to perform the lookup with. renewal_token: The renewal token to perform the lookup with.
Returns: Returns:
defer.Deferred[str]: The ID of the user to which the token belongs. The ID of the user to which the token belongs.
""" """
res = yield self.db_pool.simple_select_one_onecol( return await self.db_pool.simple_select_one_onecol(
table="account_validity", table="account_validity",
keyvalues={"renewal_token": renewal_token}, keyvalues={"renewal_token": renewal_token},
retcol="user_id", retcol="user_id",
desc="get_user_from_renewal_token", desc="get_user_from_renewal_token",
) )
return res async def get_renewal_token_for_user(self, user_id: str) -> str:
@defer.inlineCallbacks
def get_renewal_token_for_user(self, user_id):
"""Get the renewal token associated with a given user ID. """Get the renewal token associated with a given user ID.
Args: Args:
user_id (str): The user ID to lookup a token for. user_id: The user ID to lookup a token for.
Returns: Returns:
defer.Deferred[str]: The renewal token associated with this user ID. The renewal token associated with this user ID.
""" """
res = yield self.db_pool.simple_select_one_onecol( return await self.db_pool.simple_select_one_onecol(
table="account_validity", table="account_validity",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcol="renewal_token", retcol="renewal_token",
desc="get_renewal_token_for_user", desc="get_renewal_token_for_user",
) )
return res async def get_users_expiring_soon(self) -> List[Dict[str, int]]:
@defer.inlineCallbacks
def get_users_expiring_soon(self):
"""Selects users whose account will expire in the [now, now + renew_at] time """Selects users whose account will expire in the [now, now + renew_at] time
window (see configuration for account_validity for information on what renew_at window (see configuration for account_validity for information on what renew_at
refers to). refers to).
Returns: Returns:
Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]] A list of dictionaries mapping user ID to expiration time (in milliseconds).
""" """
def select_users_txn(txn, now_ms, renew_at): def select_users_txn(txn, now_ms, renew_at):
@ -238,53 +228,49 @@ class RegistrationWorkerStore(SQLBaseStore):
txn.execute(sql, values) txn.execute(sql, values)
return self.db_pool.cursor_to_dict(txn) return self.db_pool.cursor_to_dict(txn)
res = yield self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_users_expiring_soon", "get_users_expiring_soon",
select_users_txn, select_users_txn,
self.clock.time_msec(), self.clock.time_msec(),
self.config.account_validity.renew_at, self.config.account_validity.renew_at,
) )
return res async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None:
@defer.inlineCallbacks
def set_renewal_mail_status(self, user_id, email_sent):
"""Sets or unsets the flag that indicates whether a renewal email has been sent """Sets or unsets the flag that indicates whether a renewal email has been sent
to the user (and the user hasn't renewed their account yet). to the user (and the user hasn't renewed their account yet).
Args: Args:
user_id (str): ID of the user to set/unset the flag for. user_id: ID of the user to set/unset the flag for.
email_sent (bool): Flag which indicates whether a renewal email has been sent email_sent: Flag which indicates whether a renewal email has been sent
to this user. to this user.
""" """
yield self.db_pool.simple_update_one( await self.db_pool.simple_update_one(
table="account_validity", table="account_validity",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
updatevalues={"email_sent": email_sent}, updatevalues={"email_sent": email_sent},
desc="set_renewal_mail_status", desc="set_renewal_mail_status",
) )
@defer.inlineCallbacks async def delete_account_validity_for_user(self, user_id: str) -> None:
def delete_account_validity_for_user(self, user_id):
"""Deletes the entry for the given user in the account validity table, removing """Deletes the entry for the given user in the account validity table, removing
their expiration date and renewal token. their expiration date and renewal token.
Args: Args:
user_id (str): ID of the user to remove from the account validity table. user_id: ID of the user to remove from the account validity table.
""" """
yield self.db_pool.simple_delete_one( await self.db_pool.simple_delete_one(
table="account_validity", table="account_validity",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
desc="delete_account_validity_for_user", desc="delete_account_validity_for_user",
) )
async def is_server_admin(self, user): async def is_server_admin(self, user: UserID) -> bool:
"""Determines if a user is an admin of this homeserver. """Determines if a user is an admin of this homeserver.
Args: Args:
user (UserID): user ID of the user to test user: user ID of the user to test
Returns (bool): Returns:
true iff the user is a server admin, false otherwise. true iff the user is a server admin, false otherwise.
""" """
res = await self.db_pool.simple_select_one_onecol( res = await self.db_pool.simple_select_one_onecol(
@ -332,32 +318,31 @@ class RegistrationWorkerStore(SQLBaseStore):
return None return None
@cachedInlineCallbacks() @cached()
def is_real_user(self, user_id): async def is_real_user(self, user_id: str) -> bool:
"""Determines if the user is a real user, ie does not have a 'user_type'. """Determines if the user is a real user, ie does not have a 'user_type'.
Args: Args:
user_id (str): user id to test user_id: user id to test
Returns: Returns:
Deferred[bool]: True if user 'user_type' is null or empty string True if user 'user_type' is null or empty string
""" """
res = yield self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"is_real_user", self.is_real_user_txn, user_id "is_real_user", self.is_real_user_txn, user_id
) )
return res
@cached() @cached()
def is_support_user(self, user_id): async def is_support_user(self, user_id: str) -> bool:
"""Determines if the user is of type UserTypes.SUPPORT """Determines if the user is of type UserTypes.SUPPORT
Args: Args:
user_id (str): user id to test user_id: user id to test
Returns: Returns:
Deferred[bool]: True if user is of type UserTypes.SUPPORT True if user is of type UserTypes.SUPPORT
""" """
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"is_support_user", self.is_support_user_txn, user_id "is_support_user", self.is_support_user_txn, user_id
) )
@ -413,8 +398,7 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_user_by_external_id", desc="get_user_by_external_id",
) )
@defer.inlineCallbacks async def count_all_users(self):
def count_all_users(self):
"""Counts all users registered on the homeserver.""" """Counts all users registered on the homeserver."""
def _count_users(txn): def _count_users(txn):
@ -424,8 +408,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return rows[0]["users"] return rows[0]["users"]
return 0 return 0
ret = yield self.db_pool.runInteraction("count_users", _count_users) return await self.db_pool.runInteraction("count_users", _count_users)
return ret
def count_daily_user_type(self): def count_daily_user_type(self):
""" """
@ -460,8 +443,7 @@ class RegistrationWorkerStore(SQLBaseStore):
"count_daily_user_type", _count_daily_user_type "count_daily_user_type", _count_daily_user_type
) )
@defer.inlineCallbacks async def count_nonbridged_users(self):
def count_nonbridged_users(self):
def _count_users(txn): def _count_users(txn):
txn.execute( txn.execute(
""" """
@ -472,11 +454,9 @@ class RegistrationWorkerStore(SQLBaseStore):
(count,) = txn.fetchone() (count,) = txn.fetchone()
return count return count
ret = yield self.db_pool.runInteraction("count_users", _count_users) return await self.db_pool.runInteraction("count_users", _count_users)
return ret
@defer.inlineCallbacks async def count_real_users(self):
def count_real_users(self):
"""Counts all users without a special user_type registered on the homeserver.""" """Counts all users without a special user_type registered on the homeserver."""
def _count_users(txn): def _count_users(txn):
@ -486,8 +466,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return rows[0]["users"] return rows[0]["users"]
return 0 return 0
ret = yield self.db_pool.runInteraction("count_real_users", _count_users) return await self.db_pool.runInteraction("count_real_users", _count_users)
return ret
async def generate_user_id(self) -> str: async def generate_user_id(self) -> str:
"""Generate a suitable localpart for a guest user """Generate a suitable localpart for a guest user
@ -537,23 +516,20 @@ class RegistrationWorkerStore(SQLBaseStore):
return ret["user_id"] return ret["user_id"]
return None return None
@defer.inlineCallbacks async def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
def user_add_threepid(self, user_id, medium, address, validated_at, added_at): await self.db_pool.simple_upsert(
yield self.db_pool.simple_upsert(
"user_threepids", "user_threepids",
{"medium": medium, "address": address}, {"medium": medium, "address": address},
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at}, {"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
) )
@defer.inlineCallbacks async def user_get_threepids(self, user_id):
def user_get_threepids(self, user_id): return await self.db_pool.simple_select_list(
ret = yield self.db_pool.simple_select_list(
"user_threepids", "user_threepids",
{"user_id": user_id}, {"user_id": user_id},
["medium", "address", "validated_at", "added_at"], ["medium", "address", "validated_at", "added_at"],
"user_get_threepids", "user_get_threepids",
) )
return ret
def user_delete_threepid(self, user_id, medium, address): def user_delete_threepid(self, user_id, medium, address):
return self.db_pool.simple_delete( return self.db_pool.simple_delete(
@ -668,18 +644,18 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_id_servers_user_bound", desc="get_id_servers_user_bound",
) )
@cachedInlineCallbacks() @cached()
def get_user_deactivated_status(self, user_id): async def get_user_deactivated_status(self, user_id: str) -> bool:
"""Retrieve the value for the `deactivated` property for the provided user. """Retrieve the value for the `deactivated` property for the provided user.
Args: Args:
user_id (str): The ID of the user to retrieve the status for. user_id: The ID of the user to retrieve the status for.
Returns: Returns:
defer.Deferred(bool): The requested value. True if the user was deactivated, false if the user is still active.
""" """
res = yield self.db_pool.simple_select_one_onecol( res = await self.db_pool.simple_select_one_onecol(
table="users", table="users",
keyvalues={"name": user_id}, keyvalues={"name": user_id},
retcol="deactivated", retcol="deactivated",
@ -818,8 +794,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
"users_set_deactivated_flag", self._background_update_set_deactivated_flag "users_set_deactivated_flag", self._background_update_set_deactivated_flag
) )
@defer.inlineCallbacks async def _background_update_set_deactivated_flag(self, progress, batch_size):
def _background_update_set_deactivated_flag(self, progress, batch_size):
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1 """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
for each of them. for each of them.
""" """
@ -870,19 +845,18 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
else: else:
return False, len(rows) return False, len(rows)
end, nb_processed = yield self.db_pool.runInteraction( end, nb_processed = await self.db_pool.runInteraction(
"users_set_deactivated_flag", _background_update_set_deactivated_flag_txn "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
) )
if end: if end:
yield self.db_pool.updates._end_background_update( await self.db_pool.updates._end_background_update(
"users_set_deactivated_flag" "users_set_deactivated_flag"
) )
return nb_processed return nb_processed
@defer.inlineCallbacks async def _bg_user_threepids_grandfather(self, progress, batch_size):
def _bg_user_threepids_grandfather(self, progress, batch_size):
"""We now track which identity servers a user binds their 3PID to, so """We now track which identity servers a user binds their 3PID to, so
we need to handle the case of existing bindings where we didn't track we need to handle the case of existing bindings where we didn't track
this. this.
@ -903,11 +877,11 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
txn.executemany(sql, [(id_server,) for id_server in id_servers]) txn.executemany(sql, [(id_server,) for id_server in id_servers])
if id_servers: if id_servers:
yield self.db_pool.runInteraction( await self.db_pool.runInteraction(
"_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
) )
yield self.db_pool.updates._end_background_update("user_threepids_grandfather") await self.db_pool.updates._end_background_update("user_threepids_grandfather")
return 1 return 1
@ -937,23 +911,26 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS) hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
@defer.inlineCallbacks async def add_access_token_to_user(
def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms): self,
user_id: str,
token: str,
device_id: Optional[str],
valid_until_ms: Optional[int],
) -> None:
"""Adds an access token for the given user. """Adds an access token for the given user.
Args: Args:
user_id (str): The user ID. user_id: The user ID.
token (str): The new access token to add. token: The new access token to add.
device_id (str): ID of the device to associate with the access device_id: ID of the device to associate with the access token
token valid_until_ms: when the token is valid until. None for no expiry.
valid_until_ms (int|None): when the token is valid until. None for
no expiry.
Raises: Raises:
StoreError if there was a problem adding this. StoreError if there was a problem adding this.
""" """
next_id = self._access_tokens_id_gen.get_next() next_id = self._access_tokens_id_gen.get_next()
yield self.db_pool.simple_insert( await self.db_pool.simple_insert(
"access_tokens", "access_tokens",
{ {
"id": next_id, "id": next_id,
@ -1097,7 +1074,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
) )
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))
def record_user_external_id( def record_user_external_id(
self, auth_provider: str, external_id: str, user_id: str self, auth_provider: str, external_id: str, user_id: str
@ -1241,9 +1217,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return self.db_pool.runInteraction("delete_access_token", f) return self.db_pool.runInteraction("delete_access_token", f)
@cachedInlineCallbacks() @cached()
def is_guest(self, user_id): async def is_guest(self, user_id: str) -> bool:
res = yield self.db_pool.simple_select_one_onecol( res = await self.db_pool.simple_select_one_onecol(
table="users", table="users",
keyvalues={"name": user_id}, keyvalues={"name": user_id},
retcol="is_guest", retcol="is_guest",
@ -1481,16 +1457,17 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
self.clock.time_msec(), self.clock.time_msec(),
) )
@defer.inlineCallbacks async def set_user_deactivated_status(
def set_user_deactivated_status(self, user_id, deactivated): self, user_id: str, deactivated: bool
) -> None:
"""Set the `deactivated` property for the provided user to the provided value. """Set the `deactivated` property for the provided user to the provided value.
Args: Args:
user_id (str): The ID of the user to set the status for. user_id: The ID of the user to set the status for.
deactivated (bool): The value to set for `deactivated`. deactivated: The value to set for `deactivated`.
""" """
yield self.db_pool.runInteraction( await self.db_pool.runInteraction(
"set_user_deactivated_status", "set_user_deactivated_status",
self.set_user_deactivated_status_txn, self.set_user_deactivated_status_txn,
user_id, user_id,
@ -1507,9 +1484,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,) txn, self.get_user_deactivated_status, (user_id,)
) )
txn.call_after(self.is_guest.invalidate, (user_id,))
@defer.inlineCallbacks async def _set_expiration_date_when_missing(self):
def _set_expiration_date_when_missing(self):
""" """
Retrieves the list of registered users that don't have an expiration date, and Retrieves the list of registered users that don't have an expiration date, and
adds an expiration date for each of them. adds an expiration date for each of them.
@ -1533,7 +1510,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
txn, user["name"], use_delta=True txn, user["name"], use_delta=True
) )
yield self.db_pool.runInteraction( await self.db_pool.runInteraction(
"get_users_with_no_expiration_date", "get_users_with_no_expiration_date",
select_users_with_no_expiration_date_txn, select_users_with_no_expiration_date_txn,
) )

View File

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
@ -198,8 +196,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
columns=["room_id"], columns=["room_id"],
) )
@defer.inlineCallbacks async def _background_deduplicate_state(self, progress, batch_size):
def _background_deduplicate_state(self, progress, batch_size):
"""This background update will slowly deduplicate state by reencoding """This background update will slowly deduplicate state by reencoding
them as deltas. them as deltas.
""" """
@ -212,7 +209,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
if max_group is None: if max_group is None:
rows = yield self.db_pool.execute( rows = await self.db_pool.execute(
"_background_deduplicate_state", "_background_deduplicate_state",
None, None,
"SELECT coalesce(max(id), 0) FROM state_groups", "SELECT coalesce(max(id), 0) FROM state_groups",
@ -330,19 +327,18 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
return False, batch_size return False, batch_size
finished, result = yield self.db_pool.runInteraction( finished, result = await self.db_pool.runInteraction(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
) )
if finished: if finished:
yield self.db_pool.updates._end_background_update( await self.db_pool.updates._end_background_update(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
) )
return result * BATCH_SIZE_SCALE_FACTOR return result * BATCH_SIZE_SCALE_FACTOR
@defer.inlineCallbacks async def _background_index_state(self, progress, batch_size):
def _background_index_state(self, progress, batch_size):
def reindex_txn(conn): def reindex_txn(conn):
conn.rollback() conn.rollback()
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
@ -365,9 +361,9 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
) )
txn.execute("DROP INDEX IF EXISTS state_groups_state_id") txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
yield self.db_pool.runWithConnection(reindex_txn) await self.db_pool.runWithConnection(reindex_txn)
yield self.db_pool.updates._end_background_update( await self.db_pool.updates._end_background_update(
self.STATE_GROUP_INDEX_UPDATE_NAME self.STATE_GROUP_INDEX_UPDATE_NAME
) )

View File

@ -22,6 +22,7 @@ from synapse.api.errors import Codes, ResourceLimitError, SynapseError
from synapse.handlers.register import RegistrationHandler from synapse.handlers.register import RegistrationHandler
from synapse.types import RoomAlias, UserID, create_requester from synapse.types import RoomAlias, UserID, create_requester
from tests.test_utils import make_awaitable
from tests.unittest import override_config from tests.unittest import override_config
from .. import unittest from .. import unittest
@ -187,7 +188,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test" room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str] self.hs.config.auto_join_rooms = [room_alias_str]
self.store.is_real_user = Mock(return_value=defer.succeed(False)) self.store.is_real_user = Mock(return_value=make_awaitable(False))
user_id = self.get_success(self.handler.register_user(localpart="support")) user_id = self.get_success(self.handler.register_user(localpart="support"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id)) rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0) self.assertEqual(len(rooms), 0)
@ -199,8 +200,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self): def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self):
room_alias_str = "#room:test" room_alias_str = "#room:test"
self.store.count_real_users = Mock(return_value=defer.succeed(1)) self.store.count_real_users = Mock(return_value=make_awaitable(1))
self.store.is_real_user = Mock(return_value=defer.succeed(True)) self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real")) user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id)) rooms = self.get_success(self.store.get_rooms_for_user(user_id))
directory_handler = self.hs.get_handlers().directory_handler directory_handler = self.hs.get_handlers().directory_handler
@ -214,8 +215,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test" room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str] self.hs.config.auto_join_rooms = [room_alias_str]
self.store.count_real_users = Mock(return_value=defer.succeed(2)) self.store.count_real_users = Mock(return_value=make_awaitable(2))
self.store.is_real_user = Mock(return_value=defer.succeed(True)) self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real")) user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id)) rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0) self.assertEqual(len(rooms), 0)

View File

@ -300,8 +300,12 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.register_user(user_id=user2, password_hash=None)) self.get_success(self.store.register_user(user_id=user2, password_hash=None))
now = int(self.hs.get_clock().time_msec()) now = int(self.hs.get_clock().time_msec())
self.get_success(
self.store.user_add_threepid(user1, "email", user1_email, now, now) self.store.user_add_threepid(user1, "email", user1_email, now, now)
)
self.get_success(
self.store.user_add_threepid(user2, "email", user2_email, now, now) self.store.user_add_threepid(user2, "email", user2_email, now, now)
)
users = self.get_success(self.store.get_registered_reserved_users()) users = self.get_success(self.store.get_registered_reserved_users())
self.assertEqual(len(users), len(threepids)) self.assertEqual(len(users), len(threepids))

View File

@ -58,9 +58,11 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_add_tokens(self): def test_add_tokens(self):
yield self.store.register_user(self.user_id, self.pwhash) yield self.store.register_user(self.user_id, self.pwhash)
yield self.store.add_access_token_to_user( yield defer.ensureDeferred(
self.store.add_access_token_to_user(
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
) )
)
result = yield self.store.get_user_by_access_token(self.tokens[1]) result = yield self.store.get_user_by_access_token(self.tokens[1])
@ -74,12 +76,16 @@ class RegistrationStoreTestCase(unittest.TestCase):
def test_user_delete_access_tokens(self): def test_user_delete_access_tokens(self):
# add some tokens # add some tokens
yield self.store.register_user(self.user_id, self.pwhash) yield self.store.register_user(self.user_id, self.pwhash)
yield self.store.add_access_token_to_user( yield defer.ensureDeferred(
self.store.add_access_token_to_user(
self.user_id, self.tokens[0], device_id=None, valid_until_ms=None self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
) )
yield self.store.add_access_token_to_user( )
yield defer.ensureDeferred(
self.store.add_access_token_to_user(
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
) )
)
# now delete some # now delete some
yield self.store.user_delete_access_tokens( yield self.store.user_delete_access_tokens(