Convert directory, e2e_room_keys, end_to_end_keys, monthly_active_users databases to async (#8042)

Patrick Cloke 2020-08-07 13:36:29 -04:00 committed by GitHub
parent f3fe6961b2
commit 7f837959ea
GPG Key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
10 changed files with 141 additions and 120 deletions

changelog.d/8042.misc (new file)

@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.


@@ -136,7 +136,9 @@ class DeviceWorkerStore(SQLBaseStore):
         master_key_by_user = {}
         self_signing_key_by_user = {}
         for user in users:
-            cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
+            cross_signing_key = yield defer.ensureDeferred(
+                self.get_e2e_cross_signing_key(user, "master")
+            )
             if cross_signing_key:
                 key_id, verify_key = get_verify_key_from_cross_signing_key(
                     cross_signing_key

@@ -149,8 +151,8 @@ class DeviceWorkerStore(SQLBaseStore):
                     "device_id": verify_key.version,
                 }
-            cross_signing_key = yield self.get_e2e_cross_signing_key(
-                user, "self_signing"
-            )
+            cross_signing_key = yield defer.ensureDeferred(
+                self.get_e2e_cross_signing_key(user, "self_signing")
+            )
             if cross_signing_key:
                 key_id, verify_key = get_verify_key_from_cross_signing_key(

@@ -246,7 +248,7 @@ class DeviceWorkerStore(SQLBaseStore):
             destination (str): The host the device updates are intended for
             from_stream_id (int): The minimum stream_id to filter updates by, exclusive
             query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping
-                user_id/device_id to update stream_id and the relevent json-encoded
+                user_id/device_id to update stream_id and the relevant json-encoded
                 opentracing context
         Returns:

@@ -599,7 +601,7 @@ class DeviceWorkerStore(SQLBaseStore):
            between the requested tokens due to the limit.
            The token returned can be used in a subsequent call to this
-           function to get further updatees.
+           function to get further updates.
            The updates are a list of 2-tuples of stream ID and the row data
        """


@@ -14,30 +14,29 @@
 # limitations under the License.
 from collections import namedtuple
-from typing import Optional
-from twisted.internet import defer
+from typing import Iterable, Optional
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore
+from synapse.types import RoomAlias
 from synapse.util.caches.descriptors import cached
 RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
 class DirectoryWorkerStore(SQLBaseStore):
-    @defer.inlineCallbacks
-    def get_association_from_room_alias(self, room_alias):
-        """ Get's the room_id and server list for a given room_alias
+    async def get_association_from_room_alias(
+        self, room_alias: RoomAlias
+    ) -> Optional[RoomAliasMapping]:
+        """Gets the room_id and server list for a given room_alias
         Args:
-            room_alias (RoomAlias)
+            room_alias: The alias to translate to an ID.
         Returns:
-            Deferred: results in namedtuple with keys "room_id" and
-            "servers" or None if no association can be found
+            The room alias mapping or None if no association can be found.
         """
-        room_id = yield self.db_pool.simple_select_one_onecol(
+        room_id = await self.db_pool.simple_select_one_onecol(
             "room_aliases",
             {"room_alias": room_alias.to_string()},
             "room_id",

@@ -48,7 +47,7 @@ class DirectoryWorkerStore(SQLBaseStore):
         if not room_id:
             return None
-        servers = yield self.db_pool.simple_select_onecol(
+        servers = await self.db_pool.simple_select_onecol(
             "room_alias_servers",
             {"room_alias": room_alias.to_string()},
             "server",

@@ -79,18 +78,20 @@ class DirectoryWorkerStore(SQLBaseStore):
 class DirectoryStore(DirectoryWorkerStore):
-    @defer.inlineCallbacks
-    def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
+    async def create_room_alias_association(
+        self,
+        room_alias: RoomAlias,
+        room_id: str,
+        servers: Iterable[str],
+        creator: Optional[str] = None,
+    ) -> None:
         """ Creates an association between a room alias and room_id/servers
         Args:
-            room_alias (RoomAlias)
-            room_id (str)
-            servers (list)
-            creator (str): Optional user_id of creator.
-        Returns:
-            Deferred
+            room_alias: The alias to create.
+            room_id: The target of the alias.
+            servers: A list of servers through which it may be possible to join the room
+            creator: Optional user_id of creator.
         """
         def alias_txn(txn):

@@ -118,24 +119,22 @@ class DirectoryStore(DirectoryWorkerStore):
             )
         try:
-            ret = yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "create_room_alias_association", alias_txn
             )
         except self.database_engine.module.IntegrityError:
             raise SynapseError(
                 409, "Room alias %s already exists" % room_alias.to_string()
             )
-        return ret
-    @defer.inlineCallbacks
-    def delete_room_alias(self, room_alias):
-        room_id = yield self.db_pool.runInteraction(
+    async def delete_room_alias(self, room_alias: RoomAlias) -> str:
+        room_id = await self.db_pool.runInteraction(
             "delete_room_alias", self._delete_room_alias_txn, room_alias
         )
         return room_id
-    def _delete_room_alias_txn(self, txn, room_alias):
+    def _delete_room_alias_txn(self, txn, room_alias: RoomAlias) -> str:
         txn.execute(
             "SELECT room_id FROM room_aliases WHERE room_alias = ?",
             (room_alias.to_string(),),
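The directory store changes above show the mechanical pattern applied throughout this commit: drop @defer.inlineCallbacks, declare the method async def, replace yield with await, move parameter types from the docstring into annotations, and delete the "Returns: Deferred" boilerplate. A minimal sketch of the resulting shape, assuming a hypothetical store and database pool rather than the real Synapse classes:

from typing import Optional


class FakeDBPool:
    # hypothetical stand-in for Synapse's DatabasePool helpers
    def __init__(self):
        self._rows = {("example_table", "key1"): "a-value"}

    async def simple_select_one_onecol(self, table, keyvalues, retcol):
        return self._rows.get((table, keyvalues.get("key")))


class ExampleWorkerStore:
    def __init__(self, db_pool: FakeDBPool):
        self.db_pool = db_pool

    async def get_value(self, key: str) -> Optional[str]:
        """Looks up a single value for the given key.

        Args:
            key: the lookup key.

        Returns:
            The value, or None if nothing is stored under the key.
        """
        return await self.db_pool.simple_select_one_onecol(
            "example_table", {"key": key}, "value"
        )

Coroutine callers simply await these methods; the remaining generator-based callers (as in the DeviceWorkerStore hunks above) wrap them in defer.ensureDeferred.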


@@ -14,8 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from twisted.internet import defer
 from synapse.api.errors import StoreError
 from synapse.logging.opentracing import log_kv, trace
 from synapse.storage._base import SQLBaseStore, db_to_json

@@ -23,8 +21,9 @@ from synapse.util import json_encoder
 class EndToEndRoomKeyStore(SQLBaseStore):
-    @defer.inlineCallbacks
-    def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
+    async def update_e2e_room_key(
+        self, user_id, version, room_id, session_id, room_key
+    ):
         """Replaces the encrypted E2E room key for a given session in a given backup
         Args:

@@ -37,7 +36,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             StoreError
         """
-        yield self.db_pool.simple_update_one(
+        await self.db_pool.simple_update_one(
             table="e2e_room_keys",
             keyvalues={
                 "user_id": user_id,

@@ -54,8 +53,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             desc="update_e2e_room_key",
         )
-    @defer.inlineCallbacks
-    def add_e2e_room_keys(self, user_id, version, room_keys):
+    async def add_e2e_room_keys(self, user_id, version, room_keys):
         """Bulk add room keys to a given backup.
         Args:

@@ -88,13 +86,12 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 }
             )
-        yield self.db_pool.simple_insert_many(
+        await self.db_pool.simple_insert_many(
             table="e2e_room_keys", values=values, desc="add_e2e_room_keys"
         )
     @trace
-    @defer.inlineCallbacks
-    def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
+    async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
         """Bulk get the E2E room keys for a given backup, optionally filtered to a given
         room, or a given session.

@@ -109,7 +106,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 the backup (or for the specified room)
         Returns:
-            A deferred list of dicts giving the session_data and message metadata for
+            A list of dicts giving the session_data and message metadata for
             these room keys.
         """

@@ -124,7 +121,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         if session_id:
             keyvalues["session_id"] = session_id
-        rows = yield self.db_pool.simple_select_list(
+        rows = await self.db_pool.simple_select_list(
             table="e2e_room_keys",
             keyvalues=keyvalues,
             retcols=(

@@ -242,8 +239,9 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         )
     @trace
-    @defer.inlineCallbacks
-    def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
+    async def delete_e2e_room_keys(
+        self, user_id, version, room_id=None, session_id=None
+    ):
         """Bulk delete the E2E room keys for a given backup, optionally filtered to a given
         room or a given session.

@@ -258,7 +256,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
                 the backup (or for the specified room)
         Returns:
-            A deferred of the deletion transaction
+            The deletion transaction
         """
         keyvalues = {"user_id": user_id, "version": int(version)}

@@ -267,7 +265,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         if session_id:
             keyvalues["session_id"] = session_id
-        yield self.db_pool.simple_delete(
+        await self.db_pool.simple_delete(
             table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys"
         )


@@ -14,12 +14,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, List, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple
 from canonicaljson import encode_canonical_json
 from twisted.enterprise.adbapi import Connection
-from twisted.internet import defer
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, db_to_json

@@ -31,8 +30,7 @@ from synapse.util.iterutils import batch_iter
 class EndToEndKeyWorkerStore(SQLBaseStore):
     @trace
-    @defer.inlineCallbacks
-    def get_e2e_device_keys(
+    async def get_e2e_device_keys(
         self, query_list, include_all_devices=False, include_deleted_devices=False
     ):
         """Fetch a list of device keys.

@@ -52,7 +50,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         if not query_list:
             return {}
-        results = yield self.db_pool.runInteraction(
+        results = await self.db_pool.runInteraction(
             "get_e2e_device_keys",
             self._get_e2e_device_keys_txn,
             query_list,

@@ -175,8 +173,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         log_kv(result)
         return result
-    @defer.inlineCallbacks
-    def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
+    async def get_e2e_one_time_keys(
+        self, user_id: str, device_id: str, key_ids: List[str]
+    ) -> Dict[Tuple[str, str], str]:
         """Retrieve a number of one-time keys for a user
         Args:

@@ -186,11 +185,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
                 retrieve
         Returns:
-            deferred resolving to Dict[(str, str), str]: map from (algorithm,
-            key_id) to json string for key
+            A map from (algorithm, key_id) to json string for key
         """
-        rows = yield self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="e2e_one_time_keys_json",
             column="key_id",
             iterable=key_ids,

@@ -202,17 +200,21 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
         return result
-    @defer.inlineCallbacks
-    def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
+    async def add_e2e_one_time_keys(
+        self,
+        user_id: str,
+        device_id: str,
+        time_now: int,
+        new_keys: Iterable[Tuple[str, str, str]],
+    ) -> None:
         """Insert some new one time keys for a device. Errors if any of the
         keys already exist.
         Args:
-            user_id(str): id of user to get keys for
-            device_id(str): id of device to get keys for
-            time_now(long): insertion time to record (ms since epoch)
-            new_keys(iterable[(str, str, str)]: keys to add - each a tuple of
-                (algorithm, key_id, key json)
+            user_id: id of user to get keys for
+            device_id: id of device to get keys for
+            time_now: insertion time to record (ms since epoch)
+            new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
         """
         def _add_e2e_one_time_keys(txn):

@@ -242,7 +244,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
                 txn, self.count_e2e_one_time_keys, (user_id, device_id)
             )
-        yield self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
         )

@@ -269,22 +271,23 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             "count_e2e_one_time_keys", _count_e2e_one_time_keys
         )
-    @defer.inlineCallbacks
-    def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
+    async def get_e2e_cross_signing_key(
+        self, user_id: str, key_type: str, from_user_id: Optional[str] = None
+    ) -> Optional[dict]:
         """Returns a user's cross-signing key.
         Args:
-            user_id (str): the user whose key is being requested
-            key_type (str): the type of key that is being requested: either 'master'
+            user_id: the user whose key is being requested
+            key_type: the type of key that is being requested: either 'master'
                 for a master key, 'self_signing' for a self-signing key, or
                 'user_signing' for a user-signing key
-            from_user_id (str): if specified, signatures made by this user on
+            from_user_id: if specified, signatures made by this user on
                 the self-signing key will be included in the result
         Returns:
             dict of the key data or None if not found
         """
-        res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
+        res = await self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
         user_keys = res.get(user_id)
         if not user_keys:
             return None

@@ -450,28 +453,26 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         return keys
-    @defer.inlineCallbacks
-    def get_e2e_cross_signing_keys_bulk(
-        self, user_ids: List[str], from_user_id: str = None
-    ) -> defer.Deferred:
+    async def get_e2e_cross_signing_keys_bulk(
+        self, user_ids: List[str], from_user_id: Optional[str] = None
+    ) -> Dict[str, Dict[str, dict]]:
         """Returns the cross-signing keys for a set of users.
         Args:
-            user_ids (list[str]): the users whose keys are being requested
-            from_user_id (str): if specified, signatures made by this user on
+            user_ids: the users whose keys are being requested
+            from_user_id: if specified, signatures made by this user on
                 the self-signing keys will be included in the result
         Returns:
-            Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to
-            key data. If a user's cross-signing keys were not found, either
-            their user ID will not be in the dict, or their user ID will map
-            to None.
+            A map of user ID to key type to key data. If a user's cross-signing
+            keys were not found, either their user ID will not be in the dict,
+            or their user ID will map to None.
         """
-        result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
+        result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
         if from_user_id:
-            result = yield self.db_pool.runInteraction(
+            result = await self.db_pool.runInteraction(
                 "get_e2e_cross_signing_signatures",
                 self._get_e2e_cross_signing_signatures_txn,
                 result,
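With the conversion, return types that previously lived in docstrings (Deferred[dict[...]]) become ordinary annotations such as Optional[dict] and Dict[str, Dict[str, dict]] above. A hedged usage sketch of calling the converted methods from another coroutine; store is assumed to be an EndToEndKeyWorkerStore instance, and the snippet itself is illustrative only:

async def print_master_key(store, user_id: str) -> None:
    # get_e2e_cross_signing_key now returns Optional[dict] instead of a Deferred
    key = await store.get_e2e_cross_signing_key(user_id, "master")
    if key is None:
        print("no master cross-signing key for", user_id)
    else:
        print("master key for", user_id, "is", key)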


@@ -15,8 +15,6 @@
 import logging
 from typing import List
-from twisted.internet import defer
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool, make_in_list_sql_clause
 from synapse.util.caches.descriptors import cached

@@ -252,16 +250,12 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
             "reap_monthly_active_users", _reap_users, reserved_users
         )
-    @defer.inlineCallbacks
-    def upsert_monthly_active_user(self, user_id):
+    async def upsert_monthly_active_user(self, user_id: str) -> None:
         """Updates or inserts the user into the monthly active user table, which
         is used to track the current MAU usage of the server
         Args:
-            user_id (str): user to add/update
-        Returns:
-            Deferred
+            user_id: user to add/update
         """
         # Support user never to be included in MAU stats. Note I can't easily call this
         # from upsert_monthly_active_user_txn because then I need a _txn form of

@@ -271,11 +265,11 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
         # _initialise_reserved_users reasoning that it would be very strange to
         # include a support user in this context.
-        is_support = yield self.is_support_user(user_id)
+        is_support = await self.is_support_user(user_id)
         if is_support:
             return
-        yield self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
         )

@@ -322,8 +316,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
         return is_insert
-    @defer.inlineCallbacks
-    def populate_monthly_active_users(self, user_id):
+    async def populate_monthly_active_users(self, user_id):
         """Checks on the state of monthly active user limits and optionally
         add the user to the monthly active tables

@@ -332,14 +325,14 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
         """
         if self._limit_usage_by_mau or self._mau_stats_only:
             # Trial users and guests should not be included as part of MAU group
-            is_guest = yield self.is_guest(user_id)
+            is_guest = await self.is_guest(user_id)
             if is_guest:
                 return
-            is_trial = yield self.is_trial_user(user_id)
+            is_trial = await self.is_trial_user(user_id)
             if is_trial:
                 return
-            last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id)
+            last_seen_timestamp = await self.user_last_seen_monthly_active(user_id)
             now = self.hs.get_clock().time_msec()
             # We want to reduce to the total number of db writes, and are happy

@@ -352,10 +345,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
             # False, there is no point in checking get_monthly_active_count - it
             # adds no value and will break the logic if max_mau_value is exceeded.
             if not self._limit_usage_by_mau:
-                yield self.upsert_monthly_active_user(user_id)
+                await self.upsert_monthly_active_user(user_id)
             else:
-                count = yield self.get_monthly_active_count()
+                count = await self.get_monthly_active_count()
                 if count < self._max_mau_value:
-                    yield self.upsert_monthly_active_user(user_id)
+                    await self.upsert_monthly_active_user(user_id)
                 elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
-                    yield self.upsert_monthly_active_user(user_id)
+                    await self.upsert_monthly_active_user(user_id)


@@ -120,7 +120,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         self.mock_as_api.query_alias.return_value = make_awaitable(True)
         self.mock_store.get_app_services.return_value = services
-        self.mock_store.get_association_from_room_alias.return_value = defer.succeed(
+        self.mock_store.get_association_from_room_alias.return_value = make_awaitable(
             Mock(room_id=room_id, servers=servers)
         )
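Because get_association_from_room_alias is now a coroutine, the mocked return value in this handler test must be awaitable, so defer.succeed(...) is swapped for the make_awaitable test helper (imported from tests.test_utils in the MAU tests below). A plausible sketch of what such a helper needs to provide, assuming a Future-based implementation rather than the project's actual code:

from asyncio import Future
from typing import Any, Awaitable


def make_awaitable_sketch(result: Any) -> Awaitable[Any]:
    # a completed Future can be awaited and resolves immediately to `result`,
    # which makes it a convenient return value for a Mock standing in for a coroutine
    future: Future = Future()
    future.set_result(result)
    return future

Note that the MAU tests below pass it via Mock(side_effect=lambda user_id: make_awaitable(None)), which returns a fresh awaitable on every call.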


@@ -34,9 +34,11 @@ class DirectoryStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_room_to_alias(self):
-        yield self.store.create_room_alias_association(
-            room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
-        )
+        yield defer.ensureDeferred(
+            self.store.create_room_alias_association(
+                room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+            )
+        )
         self.assertEquals(
             ["#my-room:test"],

@@ -45,24 +47,36 @@
     @defer.inlineCallbacks
     def test_alias_to_room(self):
-        yield self.store.create_room_alias_association(
-            room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
-        )
+        yield defer.ensureDeferred(
+            self.store.create_room_alias_association(
+                room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+            )
+        )
         self.assertObjectHasAttributes(
             {"room_id": self.room.to_string(), "servers": ["test"]},
-            (yield self.store.get_association_from_room_alias(self.alias)),
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_association_from_room_alias(self.alias)
+                )
+            ),
         )
     @defer.inlineCallbacks
     def test_delete_alias(self):
-        yield self.store.create_room_alias_association(
-            room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
-        )
+        yield defer.ensureDeferred(
+            self.store.create_room_alias_association(
+                room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+            )
+        )
-        room_id = yield self.store.delete_room_alias(self.alias)
+        room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias))
         self.assertEqual(self.room.to_string(), room_id)
         self.assertIsNone(
-            (yield self.store.get_association_from_room_alias(self.alias))
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_association_from_room_alias(self.alias)
+                )
+            )
         )


@@ -34,7 +34,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         yield self.store.set_e2e_device_keys("user", "device", now, json)
-        res = yield self.store.get_e2e_device_keys((("user", "device"),))
+        res = yield defer.ensureDeferred(
+            self.store.get_e2e_device_keys((("user", "device"),))
+        )
         self.assertIn("user", res)
         self.assertIn("device", res["user"])
         dev = res["user"]["device"]

@@ -63,7 +65,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         yield self.store.set_e2e_device_keys("user", "device", now, json)
         yield self.store.store_device("user", "device", "display_name")
-        res = yield self.store.get_e2e_device_keys((("user", "device"),))
+        res = yield defer.ensureDeferred(
+            self.store.get_e2e_device_keys((("user", "device"),))
+        )
         self.assertIn("user", res)
         self.assertIn("device", res["user"])
         dev = res["user"]["device"]

@@ -85,8 +89,8 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
         yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
         yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
-        res = yield self.store.get_e2e_device_keys(
-            (("user1", "device1"), ("user2", "device2"))
+        res = yield defer.ensureDeferred(
+            self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2")))
         )
         self.assertIn("user1", res)
         self.assertIn("device1", res["user1"])


@@ -19,6 +19,7 @@ from twisted.internet import defer
 from synapse.api.constants import UserTypes
 from tests import unittest
+from tests.test_utils import make_awaitable
 from tests.unittest import default_config, override_config
 FORTY_DAYS = 40 * 24 * 60 * 60

@@ -230,7 +231,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         )
         self.get_success(d)
-        self.store.upsert_monthly_active_user = Mock()
+        self.store.upsert_monthly_active_user = Mock(
+            side_effect=lambda user_id: make_awaitable(None)
+        )
         d = self.store.populate_monthly_active_users(user_id)
         self.get_success(d)

@@ -238,7 +241,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.store.upsert_monthly_active_user.assert_not_called()
     def test_populate_monthly_users_should_update(self):
-        self.store.upsert_monthly_active_user = Mock()
+        self.store.upsert_monthly_active_user = Mock(
+            side_effect=lambda user_id: make_awaitable(None)
+        )
         self.store.is_trial_user = Mock(return_value=defer.succeed(False))

@@ -251,7 +256,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.store.upsert_monthly_active_user.assert_called_once()
     def test_populate_monthly_users_should_not_update(self):
-        self.store.upsert_monthly_active_user = Mock()
+        self.store.upsert_monthly_active_user = Mock(
+            side_effect=lambda user_id: make_awaitable(None)
+        )
         self.store.is_trial_user = Mock(return_value=defer.succeed(False))
         self.store.user_last_seen_monthly_active = Mock(

@@ -333,7 +340,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
     @override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
     def test_no_users_when_not_tracking(self):
-        self.store.upsert_monthly_active_user = Mock()
+        self.store.upsert_monthly_active_user = Mock(
+            side_effect=lambda user_id: make_awaitable(None)
+        )
         self.get_success(self.store.populate_monthly_active_users("@user:sever"))
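The Mock changes in these MAU tests are needed because populate_monthly_active_users now awaits upsert_monthly_active_user, and a bare Mock() returns a plain Mock that cannot be awaited. A minimal sketch of the failure mode and the fix, using hypothetical code rather than the test suite itself:

import asyncio
from unittest.mock import Mock


async def _noop(user_id: str) -> None:
    return None


async def code_under_test(store) -> None:
    # stands in for populate_monthly_active_users awaiting the upsert
    await store.upsert_monthly_active_user("@alice:example.org")


store = Mock()
# store.upsert_monthly_active_user = Mock()  # awaiting the result would raise TypeError
store.upsert_monthly_active_user = Mock(side_effect=_noop)  # returns a coroutine per call

asyncio.run(code_under_test(store))
store.upsert_monthly_active_user.assert_called_once_with("@alice:example.org")

The tests above use make_awaitable(None) instead of a coroutine function; either approach gives the await something it can consume.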