Convert additional databases to async/await part 3 (#8201)

This commit is contained in:
Patrick Cloke 2020-09-01 11:04:17 -04:00 committed by GitHub
parent 7d103a594e
commit 37db6252b7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 121 additions and 87 deletions

View file

@ -34,13 +34,15 @@ if TYPE_CHECKING:
class EndToEndKeyWorkerStore(SQLBaseStore):
def get_e2e_device_keys_for_federation_query(self, user_id: str):
async def get_e2e_device_keys_for_federation_query(
self, user_id: str
) -> Tuple[int, List[JsonDict]]:
"""Get all devices (with any device keys) for a user
Returns:
Deferred which resolves to (stream_id, devices)
(stream_id, devices)
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_e2e_device_keys_for_federation_query",
self._get_e2e_device_keys_for_federation_query_txn,
user_id,
@ -292,10 +294,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
)
@cached(max_entries=10000)
def count_e2e_one_time_keys(self, user_id, device_id):
async def count_e2e_one_time_keys(
self, user_id: str, device_id: str
) -> Dict[str, int]:
""" Count the number of one time keys the server has for a device
Returns:
Dict mapping from algorithm to number of keys for that algorithm.
A mapping from algorithm to number of keys for that algorithm.
"""
def _count_e2e_one_time_keys(txn):
@ -310,7 +314,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
result[algorithm] = key_count
return result
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
@ -348,7 +352,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
list_name="user_ids",
num_args=1,
)
def _get_bare_e2e_cross_signing_keys_bulk(
async def _get_bare_e2e_cross_signing_keys_bulk(
self, user_ids: List[str]
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
@ -356,16 +360,15 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
the signatures for the calling user need to be fetched.
Args:
user_ids (list[str]): the users whose keys are being requested
user_ids: the users whose keys are being requested
Returns:
dict[str, dict[str, dict]]: mapping from user ID to key type to key
data. If a user's cross-signing keys were not found, either
their user ID will not be in the dict, or their user ID will map
to None.
A mapping from user ID to key type to key data. If a user's cross-signing
keys were not found, either their user ID will not be in the dict, or
their user ID will map to None.
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_bare_e2e_cross_signing_keys_bulk",
self._get_bare_e2e_cross_signing_keys_bulk_txn,
user_ids,
@ -588,7 +591,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
async def set_e2e_device_keys(
self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
) -> bool:
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
"""
@ -624,12 +629,21 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
log_kv({"message": "Device keys stored."})
return True
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"set_e2e_device_keys", _set_e2e_device_keys_txn
)
def claim_e2e_one_time_keys(self, query_list):
"""Take a list of one time keys out of the database"""
async def claim_e2e_one_time_keys(
self, query_list: Iterable[Tuple[str, str, str]]
) -> Dict[str, Dict[str, Dict[str, bytes]]]:
"""Take a list of one time keys out of the database.
Args:
query_list: An iterable of tuples of (user ID, device ID, algorithm).
Returns:
A map of user ID -> a map device ID -> a map of key ID -> JSON bytes.
"""
@trace
def _claim_e2e_one_time_keys(txn):
@ -665,11 +679,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
)
return result
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)
def delete_e2e_keys_by_device(self, user_id, device_id):
async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
def delete_e2e_keys_by_device_txn(txn):
log_kv(
{
@ -692,7 +706,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)