Converts event_federation and registration databases to async/await (#8061)

This commit is contained in:
Patrick Cloke 2020-08-11 17:21:13 -04:00 committed by GitHub
parent 61d8ff0d44
commit a0acdfa9e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 150 additions and 177 deletions

1
changelog.d/8061.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View File

@ -15,9 +15,7 @@
import itertools
import logging
from queue import Empty, PriorityQueue
from typing import Dict, List, Optional, Set, Tuple
from twisted.internet import defer
from typing import Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process
@ -286,17 +284,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return dict(txn)
@defer.inlineCallbacks
def get_max_depth_of(self, event_ids):
async def get_max_depth_of(self, event_ids: List[str]) -> int:
"""Returns the max depth of a set of event IDs
Args:
event_ids (list[str])
Returns
Deferred[int]
event_ids: The event IDs to calculate the max depth of.
"""
rows = yield self.db_pool.simple_select_many_batch(
rows = await self.db_pool.simple_select_many_batch(
table="events",
column="event_id",
iterable=event_ids,
@ -550,9 +544,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return event_results
@defer.inlineCallbacks
def get_missing_events(self, room_id, earliest_events, latest_events, limit):
ids = yield self.db_pool.runInteraction(
async def get_missing_events(self, room_id, earliest_events, latest_events, limit):
ids = await self.db_pool.runInteraction(
"get_missing_events",
self._get_missing_events,
room_id,
@ -560,7 +553,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
latest_events,
limit,
)
events = yield self.get_events_as_list(ids)
events = await self.get_events_as_list(ids)
return events
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
@ -595,17 +588,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
event_results.reverse()
return event_results
@defer.inlineCallbacks
def get_successor_events(self, event_ids):
async def get_successor_events(self, event_ids: Iterable[str]) -> List[str]:
"""Fetch all events that have the given events as a prev event
Args:
event_ids (iterable[str])
Returns:
Deferred[list[str]]
event_ids: The events to use as the previous events.
"""
rows = yield self.db_pool.simple_select_many_batch(
rows = await self.db_pool.simple_select_many_batch(
table="event_edges",
column="prev_event_id",
iterable=event_ids,
@ -674,8 +663,7 @@ class EventFederationStore(EventFederationWorkerStore):
txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
@defer.inlineCallbacks
def _background_delete_non_state_event_auth(self, progress, batch_size):
async def _background_delete_non_state_event_auth(self, progress, batch_size):
def delete_event_auth(txn):
target_min_stream_id = progress.get("target_min_stream_id_inclusive")
max_stream_id = progress.get("max_stream_id_exclusive")
@ -714,12 +702,12 @@ class EventFederationStore(EventFederationWorkerStore):
return min_stream_id >= target_min_stream_id
result = yield self.db_pool.runInteraction(
result = await self.db_pool.runInteraction(
self.EVENT_AUTH_STATE_ONLY, delete_event_auth
)
if not result:
yield self.db_pool.updates._end_background_update(
await self.db_pool.updates._end_background_update(
self.EVENT_AUTH_STATE_ONLY
)

View File

@ -17,9 +17,8 @@
import logging
import re
from typing import Optional
from typing import Dict, List, Optional
from twisted.internet import defer
from twisted.internet.defer import Deferred
from synapse.api.constants import UserTypes
@ -30,7 +29,7 @@ from synapse.storage.database import DatabasePool
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.caches.descriptors import cached
THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
@ -69,19 +68,15 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_user_by_id",
)
@defer.inlineCallbacks
def is_trial_user(self, user_id):
async def is_trial_user(self, user_id: str) -> bool:
"""Checks if user is in the "trial" period, i.e. within the first
N days of registration defined by `mau_trial_days` config
Args:
user_id (str)
Returns:
Deferred[bool]
user_id: The user to check for trial status.
"""
info = yield self.get_user_by_id(user_id)
info = await self.get_user_by_id(user_id)
if not info:
return False
@ -105,41 +100,42 @@ class RegistrationWorkerStore(SQLBaseStore):
"get_user_by_access_token", self._query_for_auth, token
)
@cachedInlineCallbacks()
def get_expiration_ts_for_user(self, user_id):
@cached()
async def get_expiration_ts_for_user(self, user_id: str) -> Optional[None]:
"""Get the expiration timestamp for the account bearing a given user ID.
Args:
user_id (str): The ID of the user.
user_id: The ID of the user.
Returns:
defer.Deferred: None, if the account has no expiration timestamp,
otherwise int representation of the timestamp (as a number of
milliseconds since epoch).
None, if the account has no expiration timestamp, otherwise int
representation of the timestamp (as a number of milliseconds since epoch).
"""
res = yield self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="account_validity",
keyvalues={"user_id": user_id},
retcol="expiration_ts_ms",
allow_none=True,
desc="get_expiration_ts_for_user",
)
return res
@defer.inlineCallbacks
def set_account_validity_for_user(
self, user_id, expiration_ts, email_sent, renewal_token=None
):
async def set_account_validity_for_user(
self,
user_id: str,
expiration_ts: int,
email_sent: bool,
renewal_token: Optional[str] = None,
) -> None:
"""Updates the account validity properties of the given account, with the
given values.
Args:
user_id (str): ID of the account to update properties for.
expiration_ts (int): New expiration date, as a timestamp in milliseconds
user_id: ID of the account to update properties for.
expiration_ts: New expiration date, as a timestamp in milliseconds
since epoch.
email_sent (bool): True means a renewal email has been sent for this
account and there's no need to send another one for the current validity
email_sent: True means a renewal email has been sent for this account
and there's no need to send another one for the current validity
period.
renewal_token (str): Renewal token the user can use to extend the validity
renewal_token: Renewal token the user can use to extend the validity
of their account. Defaults to no token.
"""
@ -158,75 +154,69 @@ class RegistrationWorkerStore(SQLBaseStore):
txn, self.get_expiration_ts_for_user, (user_id,)
)
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"set_account_validity_for_user", set_account_validity_for_user_txn
)
@defer.inlineCallbacks
def set_renewal_token_for_user(self, user_id, renewal_token):
async def set_renewal_token_for_user(
self, user_id: str, renewal_token: str
) -> None:
"""Defines a renewal token for a given user.
Args:
user_id (str): ID of the user to set the renewal token for.
renewal_token (str): Random unique string that will be used to renew the
user_id: ID of the user to set the renewal token for.
renewal_token: Random unique string that will be used to renew the
user's account.
Raises:
StoreError: The provided token is already set for another user.
"""
yield self.db_pool.simple_update_one(
await self.db_pool.simple_update_one(
table="account_validity",
keyvalues={"user_id": user_id},
updatevalues={"renewal_token": renewal_token},
desc="set_renewal_token_for_user",
)
@defer.inlineCallbacks
def get_user_from_renewal_token(self, renewal_token):
async def get_user_from_renewal_token(self, renewal_token: str) -> str:
"""Get a user ID from a renewal token.
Args:
renewal_token (str): The renewal token to perform the lookup with.
renewal_token: The renewal token to perform the lookup with.
Returns:
defer.Deferred[str]: The ID of the user to which the token belongs.
The ID of the user to which the token belongs.
"""
res = yield self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="account_validity",
keyvalues={"renewal_token": renewal_token},
retcol="user_id",
desc="get_user_from_renewal_token",
)
return res
@defer.inlineCallbacks
def get_renewal_token_for_user(self, user_id):
async def get_renewal_token_for_user(self, user_id: str) -> str:
"""Get the renewal token associated with a given user ID.
Args:
user_id (str): The user ID to lookup a token for.
user_id: The user ID to lookup a token for.
Returns:
defer.Deferred[str]: The renewal token associated with this user ID.
The renewal token associated with this user ID.
"""
res = yield self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="account_validity",
keyvalues={"user_id": user_id},
retcol="renewal_token",
desc="get_renewal_token_for_user",
)
return res
@defer.inlineCallbacks
def get_users_expiring_soon(self):
async def get_users_expiring_soon(self) -> List[Dict[str, int]]:
"""Selects users whose account will expire in the [now, now + renew_at] time
window (see configuration for account_validity for information on what renew_at
refers to).
Returns:
Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]]
A list of dictionaries mapping user ID to expiration time (in milliseconds).
"""
def select_users_txn(txn, now_ms, renew_at):
@ -238,53 +228,49 @@ class RegistrationWorkerStore(SQLBaseStore):
txn.execute(sql, values)
return self.db_pool.cursor_to_dict(txn)
res = yield self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_users_expiring_soon",
select_users_txn,
self.clock.time_msec(),
self.config.account_validity.renew_at,
)
return res
@defer.inlineCallbacks
def set_renewal_mail_status(self, user_id, email_sent):
async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None:
"""Sets or unsets the flag that indicates whether a renewal email has been sent
to the user (and the user hasn't renewed their account yet).
Args:
user_id (str): ID of the user to set/unset the flag for.
email_sent (bool): Flag which indicates whether a renewal email has been sent
user_id: ID of the user to set/unset the flag for.
email_sent: Flag which indicates whether a renewal email has been sent
to this user.
"""
yield self.db_pool.simple_update_one(
await self.db_pool.simple_update_one(
table="account_validity",
keyvalues={"user_id": user_id},
updatevalues={"email_sent": email_sent},
desc="set_renewal_mail_status",
)
@defer.inlineCallbacks
def delete_account_validity_for_user(self, user_id):
async def delete_account_validity_for_user(self, user_id: str) -> None:
"""Deletes the entry for the given user in the account validity table, removing
their expiration date and renewal token.
Args:
user_id (str): ID of the user to remove from the account validity table.
user_id: ID of the user to remove from the account validity table.
"""
yield self.db_pool.simple_delete_one(
await self.db_pool.simple_delete_one(
table="account_validity",
keyvalues={"user_id": user_id},
desc="delete_account_validity_for_user",
)
async def is_server_admin(self, user):
async def is_server_admin(self, user: UserID) -> bool:
"""Determines if a user is an admin of this homeserver.
Args:
user (UserID): user ID of the user to test
user: user ID of the user to test
Returns (bool):
Returns:
true iff the user is a server admin, false otherwise.
"""
res = await self.db_pool.simple_select_one_onecol(
@ -332,32 +318,31 @@ class RegistrationWorkerStore(SQLBaseStore):
return None
@cachedInlineCallbacks()
def is_real_user(self, user_id):
@cached()
async def is_real_user(self, user_id: str) -> bool:
"""Determines if the user is a real user, ie does not have a 'user_type'.
Args:
user_id (str): user id to test
user_id: user id to test
Returns:
Deferred[bool]: True if user 'user_type' is null or empty string
True if user 'user_type' is null or empty string
"""
res = yield self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"is_real_user", self.is_real_user_txn, user_id
)
return res
@cached()
def is_support_user(self, user_id):
async def is_support_user(self, user_id: str) -> bool:
"""Determines if the user is of type UserTypes.SUPPORT
Args:
user_id (str): user id to test
user_id: user id to test
Returns:
Deferred[bool]: True if user is of type UserTypes.SUPPORT
True if user is of type UserTypes.SUPPORT
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"is_support_user", self.is_support_user_txn, user_id
)
@ -413,8 +398,7 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_user_by_external_id",
)
@defer.inlineCallbacks
def count_all_users(self):
async def count_all_users(self):
"""Counts all users registered on the homeserver."""
def _count_users(txn):
@ -424,8 +408,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return rows[0]["users"]
return 0
ret = yield self.db_pool.runInteraction("count_users", _count_users)
return ret
return await self.db_pool.runInteraction("count_users", _count_users)
def count_daily_user_type(self):
"""
@ -460,8 +443,7 @@ class RegistrationWorkerStore(SQLBaseStore):
"count_daily_user_type", _count_daily_user_type
)
@defer.inlineCallbacks
def count_nonbridged_users(self):
async def count_nonbridged_users(self):
def _count_users(txn):
txn.execute(
"""
@ -472,11 +454,9 @@ class RegistrationWorkerStore(SQLBaseStore):
(count,) = txn.fetchone()
return count
ret = yield self.db_pool.runInteraction("count_users", _count_users)
return ret
return await self.db_pool.runInteraction("count_users", _count_users)
@defer.inlineCallbacks
def count_real_users(self):
async def count_real_users(self):
"""Counts all users without a special user_type registered on the homeserver."""
def _count_users(txn):
@ -486,8 +466,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return rows[0]["users"]
return 0
ret = yield self.db_pool.runInteraction("count_real_users", _count_users)
return ret
return await self.db_pool.runInteraction("count_real_users", _count_users)
async def generate_user_id(self) -> str:
"""Generate a suitable localpart for a guest user
@ -537,23 +516,20 @@ class RegistrationWorkerStore(SQLBaseStore):
return ret["user_id"]
return None
@defer.inlineCallbacks
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
yield self.db_pool.simple_upsert(
async def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
await self.db_pool.simple_upsert(
"user_threepids",
{"medium": medium, "address": address},
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
)
@defer.inlineCallbacks
def user_get_threepids(self, user_id):
ret = yield self.db_pool.simple_select_list(
async def user_get_threepids(self, user_id):
return await self.db_pool.simple_select_list(
"user_threepids",
{"user_id": user_id},
["medium", "address", "validated_at", "added_at"],
"user_get_threepids",
)
return ret
def user_delete_threepid(self, user_id, medium, address):
return self.db_pool.simple_delete(
@ -668,18 +644,18 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_id_servers_user_bound",
)
@cachedInlineCallbacks()
def get_user_deactivated_status(self, user_id):
@cached()
async def get_user_deactivated_status(self, user_id: str) -> bool:
"""Retrieve the value for the `deactivated` property for the provided user.
Args:
user_id (str): The ID of the user to retrieve the status for.
user_id: The ID of the user to retrieve the status for.
Returns:
defer.Deferred(bool): The requested value.
True if the user was deactivated, false if the user is still active.
"""
res = yield self.db_pool.simple_select_one_onecol(
res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="deactivated",
@ -818,8 +794,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
)
@defer.inlineCallbacks
def _background_update_set_deactivated_flag(self, progress, batch_size):
async def _background_update_set_deactivated_flag(self, progress, batch_size):
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
for each of them.
"""
@ -870,19 +845,18 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
else:
return False, len(rows)
end, nb_processed = yield self.db_pool.runInteraction(
end, nb_processed = await self.db_pool.runInteraction(
"users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
)
if end:
yield self.db_pool.updates._end_background_update(
await self.db_pool.updates._end_background_update(
"users_set_deactivated_flag"
)
return nb_processed
@defer.inlineCallbacks
def _bg_user_threepids_grandfather(self, progress, batch_size):
async def _bg_user_threepids_grandfather(self, progress, batch_size):
"""We now track which identity servers a user binds their 3PID to, so
we need to handle the case of existing bindings where we didn't track
this.
@ -903,11 +877,11 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
txn.executemany(sql, [(id_server,) for id_server in id_servers])
if id_servers:
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
)
yield self.db_pool.updates._end_background_update("user_threepids_grandfather")
await self.db_pool.updates._end_background_update("user_threepids_grandfather")
return 1
@ -937,23 +911,26 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
@defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
async def add_access_token_to_user(
self,
user_id: str,
token: str,
device_id: Optional[str],
valid_until_ms: Optional[int],
) -> None:
"""Adds an access token for the given user.
Args:
user_id (str): The user ID.
token (str): The new access token to add.
device_id (str): ID of the device to associate with the access
token
valid_until_ms (int|None): when the token is valid until. None for
no expiry.
user_id: The user ID.
token: The new access token to add.
device_id: ID of the device to associate with the access token
valid_until_ms: when the token is valid until. None for no expiry.
Raises:
StoreError if there was a problem adding this.
"""
next_id = self._access_tokens_id_gen.get_next()
yield self.db_pool.simple_insert(
await self.db_pool.simple_insert(
"access_tokens",
{
"id": next_id,
@ -1097,7 +1074,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))
def record_user_external_id(
self, auth_provider: str, external_id: str, user_id: str
@ -1241,9 +1217,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return self.db_pool.runInteraction("delete_access_token", f)
@cachedInlineCallbacks()
def is_guest(self, user_id):
res = yield self.db_pool.simple_select_one_onecol(
@cached()
async def is_guest(self, user_id: str) -> bool:
res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="is_guest",
@ -1481,16 +1457,17 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
self.clock.time_msec(),
)
@defer.inlineCallbacks
def set_user_deactivated_status(self, user_id, deactivated):
async def set_user_deactivated_status(
self, user_id: str, deactivated: bool
) -> None:
"""Set the `deactivated` property for the provided user to the provided value.
Args:
user_id (str): The ID of the user to set the status for.
deactivated (bool): The value to set for `deactivated`.
user_id: The ID of the user to set the status for.
deactivated: The value to set for `deactivated`.
"""
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"set_user_deactivated_status",
self.set_user_deactivated_status_txn,
user_id,
@ -1507,9 +1484,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,)
)
txn.call_after(self.is_guest.invalidate, (user_id,))
@defer.inlineCallbacks
def _set_expiration_date_when_missing(self):
async def _set_expiration_date_when_missing(self):
"""
Retrieves the list of registered users that don't have an expiration date, and
adds an expiration date for each of them.
@ -1533,7 +1510,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
txn, user["name"], use_delta=True
)
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"get_users_with_no_expiration_date",
select_users_with_no_expiration_date_txn,
)

View File

@ -15,8 +15,6 @@
import logging
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
@ -198,8 +196,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
columns=["room_id"],
)
@defer.inlineCallbacks
def _background_deduplicate_state(self, progress, batch_size):
async def _background_deduplicate_state(self, progress, batch_size):
"""This background update will slowly deduplicate state by reencoding
them as deltas.
"""
@ -212,7 +209,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
if max_group is None:
rows = yield self.db_pool.execute(
rows = await self.db_pool.execute(
"_background_deduplicate_state",
None,
"SELECT coalesce(max(id), 0) FROM state_groups",
@ -330,19 +327,18 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
return False, batch_size
finished, result = yield self.db_pool.runInteraction(
finished, result = await self.db_pool.runInteraction(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
)
if finished:
yield self.db_pool.updates._end_background_update(
await self.db_pool.updates._end_background_update(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
)
return result * BATCH_SIZE_SCALE_FACTOR
@defer.inlineCallbacks
def _background_index_state(self, progress, batch_size):
async def _background_index_state(self, progress, batch_size):
def reindex_txn(conn):
conn.rollback()
if isinstance(self.database_engine, PostgresEngine):
@ -365,9 +361,9 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
)
txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
yield self.db_pool.runWithConnection(reindex_txn)
await self.db_pool.runWithConnection(reindex_txn)
yield self.db_pool.updates._end_background_update(
await self.db_pool.updates._end_background_update(
self.STATE_GROUP_INDEX_UPDATE_NAME
)

View File

@ -22,6 +22,7 @@ from synapse.api.errors import Codes, ResourceLimitError, SynapseError
from synapse.handlers.register import RegistrationHandler
from synapse.types import RoomAlias, UserID, create_requester
from tests.test_utils import make_awaitable
from tests.unittest import override_config
from .. import unittest
@ -187,7 +188,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
self.store.is_real_user = Mock(return_value=defer.succeed(False))
self.store.is_real_user = Mock(return_value=make_awaitable(False))
user_id = self.get_success(self.handler.register_user(localpart="support"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
@ -199,8 +200,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self):
room_alias_str = "#room:test"
self.store.count_real_users = Mock(return_value=defer.succeed(1))
self.store.is_real_user = Mock(return_value=defer.succeed(True))
self.store.count_real_users = Mock(return_value=make_awaitable(1))
self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
directory_handler = self.hs.get_handlers().directory_handler
@ -214,8 +215,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
self.store.count_real_users = Mock(return_value=defer.succeed(2))
self.store.is_real_user = Mock(return_value=defer.succeed(True))
self.store.count_real_users = Mock(return_value=make_awaitable(2))
self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)

View File

@ -300,8 +300,12 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.get_success(self.store.register_user(user_id=user2, password_hash=None))
now = int(self.hs.get_clock().time_msec())
self.get_success(
self.store.user_add_threepid(user1, "email", user1_email, now, now)
)
self.get_success(
self.store.user_add_threepid(user2, "email", user2_email, now, now)
)
users = self.get_success(self.store.get_registered_reserved_users())
self.assertEqual(len(users), len(threepids))

View File

@ -58,9 +58,11 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_add_tokens(self):
yield self.store.register_user(self.user_id, self.pwhash)
yield self.store.add_access_token_to_user(
yield defer.ensureDeferred(
self.store.add_access_token_to_user(
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
)
)
result = yield self.store.get_user_by_access_token(self.tokens[1])
@ -74,12 +76,16 @@ class RegistrationStoreTestCase(unittest.TestCase):
def test_user_delete_access_tokens(self):
# add some tokens
yield self.store.register_user(self.user_id, self.pwhash)
yield self.store.add_access_token_to_user(
yield defer.ensureDeferred(
self.store.add_access_token_to_user(
self.user_id, self.tokens[0], device_id=None, valid_until_ms=None
)
yield self.store.add_access_token_to_user(
)
yield defer.ensureDeferred(
self.store.add_access_token_to_user(
self.user_id, self.tokens[1], self.device_id, valid_until_ms=None
)
)
# now delete some
yield self.store.user_delete_access_tokens(