Add type annotations to trace decorator. (#13328)

Functions that are decorated with `trace` are now properly typed
and the type hints for them are fixed.
This commit is contained in:
Patrick Cloke 2022-07-19 14:14:30 -04:00 committed by GitHub
parent 47822fd2e8
commit a6895dd576
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 102 additions and 55 deletions

View file

@ -669,7 +669,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
@trace
async def get_user_devices_from_cache(
self, query_list: List[Tuple[str, str]]
self, query_list: List[Tuple[str, Optional[str]]]
) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
"""Get the devices (and keys if any) for remote users from the cache.

View file

@ -22,11 +22,14 @@ from typing import (
List,
Optional,
Tuple,
Union,
cast,
overload,
)
import attr
from canonicaljson import encode_canonical_json
from typing_extensions import Literal
from synapse.api.constants import DeviceKeyAlgorithms
from synapse.appservice import (
@ -113,7 +116,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
user_devices = devices[user_id]
results = []
for device_id, device in user_devices.items():
result = {"device_id": device_id}
result: JsonDict = {"device_id": device_id}
keys = device.keys
if keys:
@ -156,6 +159,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
rv[user_id] = {}
for device_id, device_info in device_keys.items():
r = device_info.keys
if r is None:
continue
r["unsigned"] = {}
display_name = device_info.display_name
if display_name is not None:
@ -164,13 +170,42 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return rv
@overload
async def get_e2e_device_keys_and_signatures(
self,
query_list: Collection[Tuple[str, Optional[str]]],
include_all_devices: Literal[False] = False,
) -> Dict[str, Dict[str, DeviceKeyLookupResult]]:
...
@overload
async def get_e2e_device_keys_and_signatures(
self,
query_list: Collection[Tuple[str, Optional[str]]],
include_all_devices: bool = False,
include_deleted_devices: Literal[False] = False,
) -> Dict[str, Dict[str, DeviceKeyLookupResult]]:
...
@overload
async def get_e2e_device_keys_and_signatures(
self,
query_list: Collection[Tuple[str, Optional[str]]],
include_all_devices: Literal[True],
include_deleted_devices: Literal[True],
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
...
@trace
async def get_e2e_device_keys_and_signatures(
self,
query_list: List[Tuple[str, Optional[str]]],
query_list: Collection[Tuple[str, Optional[str]]],
include_all_devices: bool = False,
include_deleted_devices: bool = False,
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
) -> Union[
Dict[str, Dict[str, DeviceKeyLookupResult]],
Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]],
]:
"""Fetch a list of device keys
Any cross-signatures made on the keys by the owner of the device are also
@ -1044,7 +1079,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
_claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
db_autocommit = False
row = await self.db_pool.runInteraction(
claim_row = await self.db_pool.runInteraction(
"claim_e2e_one_time_keys",
_claim_e2e_one_time_key,
user_id,
@ -1052,11 +1087,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
algorithm,
db_autocommit=db_autocommit,
)
if row:
if claim_row:
device_results = results.setdefault(user_id, {}).setdefault(
device_id, {}
)
device_results[row[0]] = row[1]
device_results[claim_row[0]] = claim_row[1]
continue
# No one-time key available, so see if there's a fallback