mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-12-12 13:14:19 -05:00
Convert directory, e2e_room_keys, end_to_end_keys, monthly_active_users database to async (#8042)
This commit is contained in:
parent
f3fe6961b2
commit
7f837959ea
1
changelog.d/8042.misc
Normal file
1
changelog.d/8042.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Convert various parts of the codebase to async/await.
|
@ -136,7 +136,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
master_key_by_user = {}
|
master_key_by_user = {}
|
||||||
self_signing_key_by_user = {}
|
self_signing_key_by_user = {}
|
||||||
for user in users:
|
for user in users:
|
||||||
cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
|
cross_signing_key = yield defer.ensureDeferred(
|
||||||
|
self.get_e2e_cross_signing_key(user, "master")
|
||||||
|
)
|
||||||
if cross_signing_key:
|
if cross_signing_key:
|
||||||
key_id, verify_key = get_verify_key_from_cross_signing_key(
|
key_id, verify_key = get_verify_key_from_cross_signing_key(
|
||||||
cross_signing_key
|
cross_signing_key
|
||||||
@ -149,8 +151,8 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
"device_id": verify_key.version,
|
"device_id": verify_key.version,
|
||||||
}
|
}
|
||||||
|
|
||||||
cross_signing_key = yield self.get_e2e_cross_signing_key(
|
cross_signing_key = yield defer.ensureDeferred(
|
||||||
user, "self_signing"
|
self.get_e2e_cross_signing_key(user, "self_signing")
|
||||||
)
|
)
|
||||||
if cross_signing_key:
|
if cross_signing_key:
|
||||||
key_id, verify_key = get_verify_key_from_cross_signing_key(
|
key_id, verify_key = get_verify_key_from_cross_signing_key(
|
||||||
@ -246,7 +248,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
destination (str): The host the device updates are intended for
|
destination (str): The host the device updates are intended for
|
||||||
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
|
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
|
||||||
query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping
|
query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping
|
||||||
user_id/device_id to update stream_id and the relevent json-encoded
|
user_id/device_id to update stream_id and the relevant json-encoded
|
||||||
opentracing context
|
opentracing context
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -599,7 +601,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
between the requested tokens due to the limit.
|
between the requested tokens due to the limit.
|
||||||
|
|
||||||
The token returned can be used in a subsequent call to this
|
The token returned can be used in a subsequent call to this
|
||||||
function to get further updatees.
|
function to get further updates.
|
||||||
|
|
||||||
The updates are a list of 2-tuples of stream ID and the row data
|
The updates are a list of 2-tuples of stream ID and the row data
|
||||||
"""
|
"""
|
||||||
|
@ -14,30 +14,29 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from typing import Optional
|
from typing import Iterable, Optional
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
|
from synapse.types import RoomAlias
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
|
RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
|
||||||
|
|
||||||
|
|
||||||
class DirectoryWorkerStore(SQLBaseStore):
|
class DirectoryWorkerStore(SQLBaseStore):
|
||||||
@defer.inlineCallbacks
|
async def get_association_from_room_alias(
|
||||||
def get_association_from_room_alias(self, room_alias):
|
self, room_alias: RoomAlias
|
||||||
""" Get's the room_id and server list for a given room_alias
|
) -> Optional[RoomAliasMapping]:
|
||||||
|
"""Gets the room_id and server list for a given room_alias
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_alias (RoomAlias)
|
room_alias: The alias to translate to an ID.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred: results in namedtuple with keys "room_id" and
|
The room alias mapping or None if no association can be found.
|
||||||
"servers" or None if no association can be found
|
|
||||||
"""
|
"""
|
||||||
room_id = yield self.db_pool.simple_select_one_onecol(
|
room_id = await self.db_pool.simple_select_one_onecol(
|
||||||
"room_aliases",
|
"room_aliases",
|
||||||
{"room_alias": room_alias.to_string()},
|
{"room_alias": room_alias.to_string()},
|
||||||
"room_id",
|
"room_id",
|
||||||
@ -48,7 +47,7 @@ class DirectoryWorkerStore(SQLBaseStore):
|
|||||||
if not room_id:
|
if not room_id:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
servers = yield self.db_pool.simple_select_onecol(
|
servers = await self.db_pool.simple_select_onecol(
|
||||||
"room_alias_servers",
|
"room_alias_servers",
|
||||||
{"room_alias": room_alias.to_string()},
|
{"room_alias": room_alias.to_string()},
|
||||||
"server",
|
"server",
|
||||||
@ -79,18 +78,20 @@ class DirectoryWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
|
|
||||||
class DirectoryStore(DirectoryWorkerStore):
|
class DirectoryStore(DirectoryWorkerStore):
|
||||||
@defer.inlineCallbacks
|
async def create_room_alias_association(
|
||||||
def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
|
self,
|
||||||
|
room_alias: RoomAlias,
|
||||||
|
room_id: str,
|
||||||
|
servers: Iterable[str],
|
||||||
|
creator: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
""" Creates an association between a room alias and room_id/servers
|
""" Creates an association between a room alias and room_id/servers
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_alias (RoomAlias)
|
room_alias: The alias to create.
|
||||||
room_id (str)
|
room_id: The target of the alias.
|
||||||
servers (list)
|
servers: A list of servers through which it may be possible to join the room
|
||||||
creator (str): Optional user_id of creator.
|
creator: Optional user_id of creator.
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def alias_txn(txn):
|
def alias_txn(txn):
|
||||||
@ -118,24 +119,22 @@ class DirectoryStore(DirectoryWorkerStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ret = yield self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"create_room_alias_association", alias_txn
|
"create_room_alias_association", alias_txn
|
||||||
)
|
)
|
||||||
except self.database_engine.module.IntegrityError:
|
except self.database_engine.module.IntegrityError:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
409, "Room alias %s already exists" % room_alias.to_string()
|
409, "Room alias %s already exists" % room_alias.to_string()
|
||||||
)
|
)
|
||||||
return ret
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def delete_room_alias(self, room_alias: RoomAlias) -> str:
|
||||||
def delete_room_alias(self, room_alias):
|
room_id = await self.db_pool.runInteraction(
|
||||||
room_id = yield self.db_pool.runInteraction(
|
|
||||||
"delete_room_alias", self._delete_room_alias_txn, room_alias
|
"delete_room_alias", self._delete_room_alias_txn, room_alias
|
||||||
)
|
)
|
||||||
|
|
||||||
return room_id
|
return room_id
|
||||||
|
|
||||||
def _delete_room_alias_txn(self, txn, room_alias):
|
def _delete_room_alias_txn(self, txn, room_alias: RoomAlias) -> str:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"SELECT room_id FROM room_aliases WHERE room_alias = ?",
|
"SELECT room_id FROM room_aliases WHERE room_alias = ?",
|
||||||
(room_alias.to_string(),),
|
(room_alias.to_string(),),
|
||||||
|
@ -14,8 +14,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.logging.opentracing import log_kv, trace
|
from synapse.logging.opentracing import log_kv, trace
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
@ -23,8 +21,9 @@ from synapse.util import json_encoder
|
|||||||
|
|
||||||
|
|
||||||
class EndToEndRoomKeyStore(SQLBaseStore):
|
class EndToEndRoomKeyStore(SQLBaseStore):
|
||||||
@defer.inlineCallbacks
|
async def update_e2e_room_key(
|
||||||
def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
|
self, user_id, version, room_id, session_id, room_key
|
||||||
|
):
|
||||||
"""Replaces the encrypted E2E room key for a given session in a given backup
|
"""Replaces the encrypted E2E room key for a given session in a given backup
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -37,7 +36,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
|||||||
StoreError
|
StoreError
|
||||||
"""
|
"""
|
||||||
|
|
||||||
yield self.db_pool.simple_update_one(
|
await self.db_pool.simple_update_one(
|
||||||
table="e2e_room_keys",
|
table="e2e_room_keys",
|
||||||
keyvalues={
|
keyvalues={
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
@ -54,8 +53,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
|||||||
desc="update_e2e_room_key",
|
desc="update_e2e_room_key",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def add_e2e_room_keys(self, user_id, version, room_keys):
|
||||||
def add_e2e_room_keys(self, user_id, version, room_keys):
|
|
||||||
"""Bulk add room keys to a given backup.
|
"""Bulk add room keys to a given backup.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -88,13 +86,12 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.db_pool.simple_insert_many(
|
await self.db_pool.simple_insert_many(
|
||||||
table="e2e_room_keys", values=values, desc="add_e2e_room_keys"
|
table="e2e_room_keys", values=values, desc="add_e2e_room_keys"
|
||||||
)
|
)
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
@defer.inlineCallbacks
|
async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
|
||||||
def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
|
|
||||||
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
|
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
|
||||||
room, or a given session.
|
room, or a given session.
|
||||||
|
|
||||||
@ -109,7 +106,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
|||||||
the backup (or for the specified room)
|
the backup (or for the specified room)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A deferred list of dicts giving the session_data and message metadata for
|
A list of dicts giving the session_data and message metadata for
|
||||||
these room keys.
|
these room keys.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -124,7 +121,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
|||||||
if session_id:
|
if session_id:
|
||||||
keyvalues["session_id"] = session_id
|
keyvalues["session_id"] = session_id
|
||||||
|
|
||||||
rows = yield self.db_pool.simple_select_list(
|
rows = await self.db_pool.simple_select_list(
|
||||||
table="e2e_room_keys",
|
table="e2e_room_keys",
|
||||||
keyvalues=keyvalues,
|
keyvalues=keyvalues,
|
||||||
retcols=(
|
retcols=(
|
||||||
@ -242,8 +239,9 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
@defer.inlineCallbacks
|
async def delete_e2e_room_keys(
|
||||||
def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
|
self, user_id, version, room_id=None, session_id=None
|
||||||
|
):
|
||||||
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
|
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
|
||||||
room or a given session.
|
room or a given session.
|
||||||
|
|
||||||
@ -258,7 +256,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
|||||||
the backup (or for the specified room)
|
the backup (or for the specified room)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A deferred of the deletion transaction
|
The deletion transaction
|
||||||
"""
|
"""
|
||||||
|
|
||||||
keyvalues = {"user_id": user_id, "version": int(version)}
|
keyvalues = {"user_id": user_id, "version": int(version)}
|
||||||
@ -267,7 +265,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
|||||||
if session_id:
|
if session_id:
|
||||||
keyvalues["session_id"] = session_id
|
keyvalues["session_id"] = session_id
|
||||||
|
|
||||||
yield self.db_pool.simple_delete(
|
await self.db_pool.simple_delete(
|
||||||
table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys"
|
table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -14,12 +14,11 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
|
|
||||||
from twisted.enterprise.adbapi import Connection
|
from twisted.enterprise.adbapi import Connection
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.logging.opentracing import log_kv, set_tag, trace
|
from synapse.logging.opentracing import log_kv, set_tag, trace
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
@ -31,8 +30,7 @@ from synapse.util.iterutils import batch_iter
|
|||||||
|
|
||||||
class EndToEndKeyWorkerStore(SQLBaseStore):
|
class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||||
@trace
|
@trace
|
||||||
@defer.inlineCallbacks
|
async def get_e2e_device_keys(
|
||||||
def get_e2e_device_keys(
|
|
||||||
self, query_list, include_all_devices=False, include_deleted_devices=False
|
self, query_list, include_all_devices=False, include_deleted_devices=False
|
||||||
):
|
):
|
||||||
"""Fetch a list of device keys.
|
"""Fetch a list of device keys.
|
||||||
@ -52,7 +50,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||||||
if not query_list:
|
if not query_list:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
results = yield self.db_pool.runInteraction(
|
results = await self.db_pool.runInteraction(
|
||||||
"get_e2e_device_keys",
|
"get_e2e_device_keys",
|
||||||
self._get_e2e_device_keys_txn,
|
self._get_e2e_device_keys_txn,
|
||||||
query_list,
|
query_list,
|
||||||
@ -175,8 +173,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||||||
log_kv(result)
|
log_kv(result)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_e2e_one_time_keys(
|
||||||
def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
|
self, user_id: str, device_id: str, key_ids: List[str]
|
||||||
|
) -> Dict[Tuple[str, str], str]:
|
||||||
"""Retrieve a number of one-time keys for a user
|
"""Retrieve a number of one-time keys for a user
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -186,11 +185,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||||||
retrieve
|
retrieve
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
deferred resolving to Dict[(str, str), str]: map from (algorithm,
|
A map from (algorithm, key_id) to json string for key
|
||||||
key_id) to json string for key
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
rows = yield self.db_pool.simple_select_many_batch(
|
rows = await self.db_pool.simple_select_many_batch(
|
||||||
table="e2e_one_time_keys_json",
|
table="e2e_one_time_keys_json",
|
||||||
column="key_id",
|
column="key_id",
|
||||||
iterable=key_ids,
|
iterable=key_ids,
|
||||||
@ -202,17 +200,21 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||||||
log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
|
log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def add_e2e_one_time_keys(
|
||||||
def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
|
self,
|
||||||
|
user_id: str,
|
||||||
|
device_id: str,
|
||||||
|
time_now: int,
|
||||||
|
new_keys: Iterable[Tuple[str, str, str]],
|
||||||
|
) -> None:
|
||||||
"""Insert some new one time keys for a device. Errors if any of the
|
"""Insert some new one time keys for a device. Errors if any of the
|
||||||
keys already exist.
|
keys already exist.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id(str): id of user to get keys for
|
user_id: id of user to get keys for
|
||||||
device_id(str): id of device to get keys for
|
device_id: id of device to get keys for
|
||||||
time_now(long): insertion time to record (ms since epoch)
|
time_now: insertion time to record (ms since epoch)
|
||||||
new_keys(iterable[(str, str, str)]: keys to add - each a tuple of
|
new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
|
||||||
(algorithm, key_id, key json)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _add_e2e_one_time_keys(txn):
|
def _add_e2e_one_time_keys(txn):
|
||||||
@ -242,7 +244,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||||||
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
|
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -269,22 +271,23 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||||||
"count_e2e_one_time_keys", _count_e2e_one_time_keys
|
"count_e2e_one_time_keys", _count_e2e_one_time_keys
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_e2e_cross_signing_key(
|
||||||
def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
|
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
|
||||||
|
) -> Optional[dict]:
|
||||||
"""Returns a user's cross-signing key.
|
"""Returns a user's cross-signing key.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): the user whose key is being requested
|
user_id: the user whose key is being requested
|
||||||
key_type (str): the type of key that is being requested: either 'master'
|
key_type: the type of key that is being requested: either 'master'
|
||||||
for a master key, 'self_signing' for a self-signing key, or
|
for a master key, 'self_signing' for a self-signing key, or
|
||||||
'user_signing' for a user-signing key
|
'user_signing' for a user-signing key
|
||||||
from_user_id (str): if specified, signatures made by this user on
|
from_user_id: if specified, signatures made by this user on
|
||||||
the self-signing key will be included in the result
|
the self-signing key will be included in the result
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict of the key data or None if not found
|
dict of the key data or None if not found
|
||||||
"""
|
"""
|
||||||
res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
|
res = await self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
|
||||||
user_keys = res.get(user_id)
|
user_keys = res.get(user_id)
|
||||||
if not user_keys:
|
if not user_keys:
|
||||||
return None
|
return None
|
||||||
@ -450,28 +453,26 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
return keys
|
return keys
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_e2e_cross_signing_keys_bulk(
|
||||||
def get_e2e_cross_signing_keys_bulk(
|
self, user_ids: List[str], from_user_id: Optional[str] = None
|
||||||
self, user_ids: List[str], from_user_id: str = None
|
) -> Dict[str, Dict[str, dict]]:
|
||||||
) -> defer.Deferred:
|
|
||||||
"""Returns the cross-signing keys for a set of users.
|
"""Returns the cross-signing keys for a set of users.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_ids (list[str]): the users whose keys are being requested
|
user_ids: the users whose keys are being requested
|
||||||
from_user_id (str): if specified, signatures made by this user on
|
from_user_id: if specified, signatures made by this user on
|
||||||
the self-signing keys will be included in the result
|
the self-signing keys will be included in the result
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to
|
A map of user ID to key type to key data. If a user's cross-signing
|
||||||
key data. If a user's cross-signing keys were not found, either
|
keys were not found, either their user ID will not be in the dict,
|
||||||
their user ID will not be in the dict, or their user ID will map
|
or their user ID will map to None.
|
||||||
to None.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
|
result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
|
||||||
|
|
||||||
if from_user_id:
|
if from_user_id:
|
||||||
result = yield self.db_pool.runInteraction(
|
result = await self.db_pool.runInteraction(
|
||||||
"get_e2e_cross_signing_signatures",
|
"get_e2e_cross_signing_signatures",
|
||||||
self._get_e2e_cross_signing_signatures_txn,
|
self._get_e2e_cross_signing_signatures_txn,
|
||||||
result,
|
result,
|
||||||
|
@ -15,8 +15,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
|
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
@ -252,16 +250,12 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
|
|||||||
"reap_monthly_active_users", _reap_users, reserved_users
|
"reap_monthly_active_users", _reap_users, reserved_users
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def upsert_monthly_active_user(self, user_id: str) -> None:
|
||||||
def upsert_monthly_active_user(self, user_id):
|
|
||||||
"""Updates or inserts the user into the monthly active user table, which
|
"""Updates or inserts the user into the monthly active user table, which
|
||||||
is used to track the current MAU usage of the server
|
is used to track the current MAU usage of the server
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): user to add/update
|
user_id: user to add/update
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred
|
|
||||||
"""
|
"""
|
||||||
# Support user never to be included in MAU stats. Note I can't easily call this
|
# Support user never to be included in MAU stats. Note I can't easily call this
|
||||||
# from upsert_monthly_active_user_txn because then I need a _txn form of
|
# from upsert_monthly_active_user_txn because then I need a _txn form of
|
||||||
@ -271,11 +265,11 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
|
|||||||
# _initialise_reserved_users reasoning that it would be very strange to
|
# _initialise_reserved_users reasoning that it would be very strange to
|
||||||
# include a support user in this context.
|
# include a support user in this context.
|
||||||
|
|
||||||
is_support = yield self.is_support_user(user_id)
|
is_support = await self.is_support_user(user_id)
|
||||||
if is_support:
|
if is_support:
|
||||||
return
|
return
|
||||||
|
|
||||||
yield self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
|
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -322,8 +316,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
|
|||||||
|
|
||||||
return is_insert
|
return is_insert
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def populate_monthly_active_users(self, user_id):
|
||||||
def populate_monthly_active_users(self, user_id):
|
|
||||||
"""Checks on the state of monthly active user limits and optionally
|
"""Checks on the state of monthly active user limits and optionally
|
||||||
add the user to the monthly active tables
|
add the user to the monthly active tables
|
||||||
|
|
||||||
@ -332,14 +325,14 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
|
|||||||
"""
|
"""
|
||||||
if self._limit_usage_by_mau or self._mau_stats_only:
|
if self._limit_usage_by_mau or self._mau_stats_only:
|
||||||
# Trial users and guests should not be included as part of MAU group
|
# Trial users and guests should not be included as part of MAU group
|
||||||
is_guest = yield self.is_guest(user_id)
|
is_guest = await self.is_guest(user_id)
|
||||||
if is_guest:
|
if is_guest:
|
||||||
return
|
return
|
||||||
is_trial = yield self.is_trial_user(user_id)
|
is_trial = await self.is_trial_user(user_id)
|
||||||
if is_trial:
|
if is_trial:
|
||||||
return
|
return
|
||||||
|
|
||||||
last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id)
|
last_seen_timestamp = await self.user_last_seen_monthly_active(user_id)
|
||||||
now = self.hs.get_clock().time_msec()
|
now = self.hs.get_clock().time_msec()
|
||||||
|
|
||||||
# We want to reduce to the total number of db writes, and are happy
|
# We want to reduce to the total number of db writes, and are happy
|
||||||
@ -352,10 +345,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
|
|||||||
# False, there is no point in checking get_monthly_active_count - it
|
# False, there is no point in checking get_monthly_active_count - it
|
||||||
# adds no value and will break the logic if max_mau_value is exceeded.
|
# adds no value and will break the logic if max_mau_value is exceeded.
|
||||||
if not self._limit_usage_by_mau:
|
if not self._limit_usage_by_mau:
|
||||||
yield self.upsert_monthly_active_user(user_id)
|
await self.upsert_monthly_active_user(user_id)
|
||||||
else:
|
else:
|
||||||
count = yield self.get_monthly_active_count()
|
count = await self.get_monthly_active_count()
|
||||||
if count < self._max_mau_value:
|
if count < self._max_mau_value:
|
||||||
yield self.upsert_monthly_active_user(user_id)
|
await self.upsert_monthly_active_user(user_id)
|
||||||
elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
|
elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
|
||||||
yield self.upsert_monthly_active_user(user_id)
|
await self.upsert_monthly_active_user(user_id)
|
||||||
|
@ -120,7 +120,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
self.mock_as_api.query_alias.return_value = make_awaitable(True)
|
self.mock_as_api.query_alias.return_value = make_awaitable(True)
|
||||||
self.mock_store.get_app_services.return_value = services
|
self.mock_store.get_app_services.return_value = services
|
||||||
self.mock_store.get_association_from_room_alias.return_value = defer.succeed(
|
self.mock_store.get_association_from_room_alias.return_value = make_awaitable(
|
||||||
Mock(room_id=room_id, servers=servers)
|
Mock(room_id=room_id, servers=servers)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -34,8 +34,10 @@ class DirectoryStoreTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_room_to_alias(self):
|
def test_room_to_alias(self):
|
||||||
yield self.store.create_room_alias_association(
|
yield defer.ensureDeferred(
|
||||||
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
|
self.store.create_room_alias_association(
|
||||||
|
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
@ -45,24 +47,36 @@ class DirectoryStoreTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_alias_to_room(self):
|
def test_alias_to_room(self):
|
||||||
yield self.store.create_room_alias_association(
|
yield defer.ensureDeferred(
|
||||||
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
|
self.store.create_room_alias_association(
|
||||||
|
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertObjectHasAttributes(
|
self.assertObjectHasAttributes(
|
||||||
{"room_id": self.room.to_string(), "servers": ["test"]},
|
{"room_id": self.room.to_string(), "servers": ["test"]},
|
||||||
(yield self.store.get_association_from_room_alias(self.alias)),
|
(
|
||||||
|
yield defer.ensureDeferred(
|
||||||
|
self.store.get_association_from_room_alias(self.alias)
|
||||||
|
)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_delete_alias(self):
|
def test_delete_alias(self):
|
||||||
yield self.store.create_room_alias_association(
|
yield defer.ensureDeferred(
|
||||||
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
|
self.store.create_room_alias_association(
|
||||||
|
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
room_id = yield self.store.delete_room_alias(self.alias)
|
room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias))
|
||||||
self.assertEqual(self.room.to_string(), room_id)
|
self.assertEqual(self.room.to_string(), room_id)
|
||||||
|
|
||||||
self.assertIsNone(
|
self.assertIsNone(
|
||||||
(yield self.store.get_association_from_room_alias(self.alias))
|
(
|
||||||
|
yield defer.ensureDeferred(
|
||||||
|
self.store.get_association_from_room_alias(self.alias)
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
@ -34,7 +34,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
|
|||||||
|
|
||||||
yield self.store.set_e2e_device_keys("user", "device", now, json)
|
yield self.store.set_e2e_device_keys("user", "device", now, json)
|
||||||
|
|
||||||
res = yield self.store.get_e2e_device_keys((("user", "device"),))
|
res = yield defer.ensureDeferred(
|
||||||
|
self.store.get_e2e_device_keys((("user", "device"),))
|
||||||
|
)
|
||||||
self.assertIn("user", res)
|
self.assertIn("user", res)
|
||||||
self.assertIn("device", res["user"])
|
self.assertIn("device", res["user"])
|
||||||
dev = res["user"]["device"]
|
dev = res["user"]["device"]
|
||||||
@ -63,7 +65,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
|
|||||||
yield self.store.set_e2e_device_keys("user", "device", now, json)
|
yield self.store.set_e2e_device_keys("user", "device", now, json)
|
||||||
yield self.store.store_device("user", "device", "display_name")
|
yield self.store.store_device("user", "device", "display_name")
|
||||||
|
|
||||||
res = yield self.store.get_e2e_device_keys((("user", "device"),))
|
res = yield defer.ensureDeferred(
|
||||||
|
self.store.get_e2e_device_keys((("user", "device"),))
|
||||||
|
)
|
||||||
self.assertIn("user", res)
|
self.assertIn("user", res)
|
||||||
self.assertIn("device", res["user"])
|
self.assertIn("device", res["user"])
|
||||||
dev = res["user"]["device"]
|
dev = res["user"]["device"]
|
||||||
@ -85,8 +89,8 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
|
|||||||
yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
|
yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
|
||||||
yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
|
yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
|
||||||
|
|
||||||
res = yield self.store.get_e2e_device_keys(
|
res = yield defer.ensureDeferred(
|
||||||
(("user1", "device1"), ("user2", "device2"))
|
self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2")))
|
||||||
)
|
)
|
||||||
self.assertIn("user1", res)
|
self.assertIn("user1", res)
|
||||||
self.assertIn("device1", res["user1"])
|
self.assertIn("device1", res["user1"])
|
||||||
|
@ -19,6 +19,7 @@ from twisted.internet import defer
|
|||||||
from synapse.api.constants import UserTypes
|
from synapse.api.constants import UserTypes
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
from tests.test_utils import make_awaitable
|
||||||
from tests.unittest import default_config, override_config
|
from tests.unittest import default_config, override_config
|
||||||
|
|
||||||
FORTY_DAYS = 40 * 24 * 60 * 60
|
FORTY_DAYS = 40 * 24 * 60 * 60
|
||||||
@ -230,7 +231,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
self.get_success(d)
|
self.get_success(d)
|
||||||
|
|
||||||
self.store.upsert_monthly_active_user = Mock()
|
self.store.upsert_monthly_active_user = Mock(
|
||||||
|
side_effect=lambda user_id: make_awaitable(None)
|
||||||
|
)
|
||||||
|
|
||||||
d = self.store.populate_monthly_active_users(user_id)
|
d = self.store.populate_monthly_active_users(user_id)
|
||||||
self.get_success(d)
|
self.get_success(d)
|
||||||
@ -238,7 +241,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
|||||||
self.store.upsert_monthly_active_user.assert_not_called()
|
self.store.upsert_monthly_active_user.assert_not_called()
|
||||||
|
|
||||||
def test_populate_monthly_users_should_update(self):
|
def test_populate_monthly_users_should_update(self):
|
||||||
self.store.upsert_monthly_active_user = Mock()
|
self.store.upsert_monthly_active_user = Mock(
|
||||||
|
side_effect=lambda user_id: make_awaitable(None)
|
||||||
|
)
|
||||||
|
|
||||||
self.store.is_trial_user = Mock(return_value=defer.succeed(False))
|
self.store.is_trial_user = Mock(return_value=defer.succeed(False))
|
||||||
|
|
||||||
@ -251,7 +256,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
|||||||
self.store.upsert_monthly_active_user.assert_called_once()
|
self.store.upsert_monthly_active_user.assert_called_once()
|
||||||
|
|
||||||
def test_populate_monthly_users_should_not_update(self):
|
def test_populate_monthly_users_should_not_update(self):
|
||||||
self.store.upsert_monthly_active_user = Mock()
|
self.store.upsert_monthly_active_user = Mock(
|
||||||
|
side_effect=lambda user_id: make_awaitable(None)
|
||||||
|
)
|
||||||
|
|
||||||
self.store.is_trial_user = Mock(return_value=defer.succeed(False))
|
self.store.is_trial_user = Mock(return_value=defer.succeed(False))
|
||||||
self.store.user_last_seen_monthly_active = Mock(
|
self.store.user_last_seen_monthly_active = Mock(
|
||||||
@ -333,7 +340,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
@override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
|
@override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
|
||||||
def test_no_users_when_not_tracking(self):
|
def test_no_users_when_not_tracking(self):
|
||||||
self.store.upsert_monthly_active_user = Mock()
|
self.store.upsert_monthly_active_user = Mock(
|
||||||
|
side_effect=lambda user_id: make_awaitable(None)
|
||||||
|
)
|
||||||
|
|
||||||
self.get_success(self.store.populate_monthly_active_users("@user:sever"))
|
self.get_success(self.store.populate_monthly_active_users("@user:sever"))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user