Add type annotations to trace decorator. (#13328)

Functions that are decorated with `trace` are now properly typed
and the type hints for them are fixed.
This commit is contained in:
Patrick Cloke 2022-07-19 14:14:30 -04:00 committed by GitHub
parent 47822fd2e8
commit a6895dd576
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 102 additions and 55 deletions

1
changelog.d/13328.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to `trace` decorator.

View File

@ -217,7 +217,7 @@ class FederationClient(FederationBase):
)
async def claim_client_keys(
self, destination: str, content: JsonDict, timeout: int
self, destination: str, content: JsonDict, timeout: Optional[int]
) -> JsonDict:
"""Claims one-time keys for a device hosted on a remote server.

View File

@ -619,7 +619,7 @@ class TransportLayerClient:
)
async def claim_client_keys(
self, destination: str, query_content: JsonDict, timeout: int
self, destination: str, query_content: JsonDict, timeout: Optional[int]
) -> JsonDict:
"""Claim one-time keys for a list of devices hosted on a remote server.

View File

@ -15,7 +15,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json
@ -92,7 +92,11 @@ class E2eKeysHandler:
@trace
async def query_devices(
self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
self,
query_body: JsonDict,
timeout: int,
from_user_id: str,
from_device_id: Optional[str],
) -> JsonDict:
"""Handle a device key query from a client
@ -120,9 +124,7 @@ class E2eKeysHandler:
the number of in-flight queries at a time.
"""
async with self._query_devices_linearizer.queue((from_user_id, from_device_id)):
device_keys_query: Dict[str, Iterable[str]] = query_body.get(
"device_keys", {}
)
device_keys_query: Dict[str, List[str]] = query_body.get("device_keys", {})
# separate users by domain.
# make a map from domain to user_id to device_ids
@ -392,7 +394,7 @@ class E2eKeysHandler:
@trace
async def query_local_devices(
self, query: Dict[str, Optional[List[str]]]
self, query: Mapping[str, Optional[List[str]]]
) -> Dict[str, Dict[str, dict]]:
"""Get E2E device keys for local users
@ -461,7 +463,7 @@ class E2eKeysHandler:
@trace
async def claim_one_time_keys(
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
) -> JsonDict:
local_query: List[Tuple[str, str, str]] = []
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}

View File

@ -84,14 +84,13 @@ the function becomes the operation name for the span.
return something_usual_and_useful
Operation names can be explicitly set for a function by passing the
operation name to ``trace``
Operation names can be explicitly set for a function by using ``trace_with_opname``:
.. code-block:: python
from synapse.logging.opentracing import trace
from synapse.logging.opentracing import trace_with_opname
@trace(opname="a_better_operation_name")
@trace_with_opname("a_better_operation_name")
def interesting_badly_named_function(*args, **kwargs):
# Does all kinds of cool and expected things
return something_usual_and_useful
@ -798,33 +797,31 @@ def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanConte
# Tracing decorators
def trace(func=None, opname: Optional[str] = None):
def trace_with_opname(opname: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""
Decorator to trace a function.
Sets the operation name to that of the function's or that given
as operation_name. See the module's doc string for usage
examples.
Decorator to trace a function with a custom opname.
See the module's doc string for usage examples.
"""
def decorator(func):
def decorator(func: Callable[P, R]) -> Callable[P, R]:
if opentracing is None:
return func # type: ignore[unreachable]
_opname = opname if opname else func.__name__
if inspect.iscoroutinefunction(func):
@wraps(func)
async def _trace_inner(*args, **kwargs):
with start_active_span(_opname):
return await func(*args, **kwargs)
async def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
with start_active_span(opname):
return await func(*args, **kwargs) # type: ignore[misc]
else:
# The other case here handles both sync functions and those
# decorated with inlineDeferred.
@wraps(func)
def _trace_inner(*args, **kwargs):
scope = start_active_span(_opname)
def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
scope = start_active_span(opname)
scope.__enter__()
try:
@ -858,14 +855,23 @@ def trace(func=None, opname: Optional[str] = None):
scope.__exit__(type(e), None, e.__traceback__)
raise
return _trace_inner
return _trace_inner # type: ignore[return-value]
if func:
return decorator(func)
else:
return decorator
def trace(func: Callable[P, R]) -> Callable[P, R]:
"""
Decorator to trace a function.
Sets the operation name to that of the function's name.
See the module's doc string for usage examples.
"""
return trace_with_opname(func.__name__)(func)
def tag_args(func: Callable[P, R]) -> Callable[P, R]:
"""
Tags all of the args to the active span.

View File

@ -29,7 +29,7 @@ from synapse.http import RequestTimedOutError
from synapse.http.server import HttpServer, is_method_cancellable
from synapse.http.site import SynapseRequest
from synapse.logging import opentracing
from synapse.logging.opentracing import trace
from synapse.logging.opentracing import trace_with_opname
from synapse.types import JsonDict
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
@ -196,7 +196,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
"ascii"
)
@trace(opname="outgoing_replication_request")
@trace_with_opname("outgoing_replication_request")
async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
with outgoing_gauge.track_inprogress():
if instance_name == local_instance_name:

View File

@ -26,7 +26,7 @@ from synapse.http.servlet import (
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname
from synapse.types import JsonDict, StreamToken
from ._base import client_patterns, interactive_auth_handler
@ -71,7 +71,7 @@ class KeyUploadServlet(RestServlet):
self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = hs.get_device_handler()
@trace(opname="upload_keys")
@trace_with_opname("upload_keys")
async def on_POST(
self, request: SynapseRequest, device_id: Optional[str]
) -> Tuple[int, JsonDict]:

View File

@ -13,7 +13,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Optional, Tuple, cast
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
@ -127,7 +127,7 @@ class RoomKeysServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
version = parse_string(request, "version")
version = parse_string(request, "version", required=True)
if session_id:
body = {"sessions": {session_id: body}}
@ -196,8 +196,11 @@ class RoomKeysServlet(RestServlet):
user_id = requester.user.to_string()
version = parse_string(request, "version", required=True)
room_keys = await self.e2e_room_keys_handler.get_room_keys(
room_keys = cast(
JsonDict,
await self.e2e_room_keys_handler.get_room_keys(
user_id, version, room_id, session_id
),
)
# Convert room_keys to the right format to return.
@ -240,7 +243,7 @@ class RoomKeysServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
version = parse_string(request, "version")
version = parse_string(request, "version", required=True)
ret = await self.e2e_room_keys_handler.delete_room_keys(
user_id, version, room_id, session_id

View File

@ -19,7 +19,7 @@ from synapse.http import servlet
from synapse.http.server import HttpServer
from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import set_tag, trace
from synapse.logging.opentracing import set_tag, trace_with_opname
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.types import JsonDict
@ -43,7 +43,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
self.txns = HttpTransactionCache(hs)
self.device_message_handler = hs.get_device_message_handler()
@trace(opname="sendToDevice")
@trace_with_opname("sendToDevice")
def on_PUT(
self, request: SynapseRequest, message_type: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:

View File

@ -37,7 +37,7 @@ from synapse.handlers.sync import (
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import trace
from synapse.logging.opentracing import trace_with_opname
from synapse.types import JsonDict, StreamToken
from synapse.util import json_decoder
@ -210,7 +210,7 @@ class SyncRestServlet(RestServlet):
logger.debug("Event formatting complete")
return 200, response_content
@trace(opname="sync.encode_response")
@trace_with_opname("sync.encode_response")
async def encode_response(
self,
time_now: int,
@ -315,7 +315,7 @@ class SyncRestServlet(RestServlet):
]
}
@trace(opname="sync.encode_joined")
@trace_with_opname("sync.encode_joined")
async def encode_joined(
self,
rooms: List[JoinedSyncResult],
@ -340,7 +340,7 @@ class SyncRestServlet(RestServlet):
return joined
@trace(opname="sync.encode_invited")
@trace_with_opname("sync.encode_invited")
async def encode_invited(
self,
rooms: List[InvitedSyncResult],
@ -371,7 +371,7 @@ class SyncRestServlet(RestServlet):
return invited
@trace(opname="sync.encode_knocked")
@trace_with_opname("sync.encode_knocked")
async def encode_knocked(
self,
rooms: List[KnockedSyncResult],
@ -420,7 +420,7 @@ class SyncRestServlet(RestServlet):
return knocked
@trace(opname="sync.encode_archived")
@trace_with_opname("sync.encode_archived")
async def encode_archived(
self,
rooms: List[ArchivedSyncResult],

View File

@ -669,7 +669,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
@trace
async def get_user_devices_from_cache(
self, query_list: List[Tuple[str, str]]
self, query_list: List[Tuple[str, Optional[str]]]
) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
"""Get the devices (and keys if any) for remote users from the cache.

View File

@ -22,11 +22,14 @@ from typing import (
List,
Optional,
Tuple,
Union,
cast,
overload,
)
import attr
from canonicaljson import encode_canonical_json
from typing_extensions import Literal
from synapse.api.constants import DeviceKeyAlgorithms
from synapse.appservice import (
@ -113,7 +116,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
user_devices = devices[user_id]
results = []
for device_id, device in user_devices.items():
result = {"device_id": device_id}
result: JsonDict = {"device_id": device_id}
keys = device.keys
if keys:
@ -156,6 +159,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
rv[user_id] = {}
for device_id, device_info in device_keys.items():
r = device_info.keys
if r is None:
continue
r["unsigned"] = {}
display_name = device_info.display_name
if display_name is not None:
@ -164,13 +170,42 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return rv
@overload
async def get_e2e_device_keys_and_signatures(
self,
query_list: Collection[Tuple[str, Optional[str]]],
include_all_devices: Literal[False] = False,
) -> Dict[str, Dict[str, DeviceKeyLookupResult]]:
...
@overload
async def get_e2e_device_keys_and_signatures(
self,
query_list: Collection[Tuple[str, Optional[str]]],
include_all_devices: bool = False,
include_deleted_devices: Literal[False] = False,
) -> Dict[str, Dict[str, DeviceKeyLookupResult]]:
...
@overload
async def get_e2e_device_keys_and_signatures(
self,
query_list: Collection[Tuple[str, Optional[str]]],
include_all_devices: Literal[True],
include_deleted_devices: Literal[True],
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
...
@trace
async def get_e2e_device_keys_and_signatures(
self,
query_list: List[Tuple[str, Optional[str]]],
query_list: Collection[Tuple[str, Optional[str]]],
include_all_devices: bool = False,
include_deleted_devices: bool = False,
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
) -> Union[
Dict[str, Dict[str, DeviceKeyLookupResult]],
Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]],
]:
"""Fetch a list of device keys
Any cross-signatures made on the keys by the owner of the device are also
@ -1044,7 +1079,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
_claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
db_autocommit = False
row = await self.db_pool.runInteraction(
claim_row = await self.db_pool.runInteraction(
"claim_e2e_one_time_keys",
_claim_e2e_one_time_key,
user_id,
@ -1052,11 +1087,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
algorithm,
db_autocommit=db_autocommit,
)
if row:
if claim_row:
device_results = results.setdefault(user_id, {}).setdefault(
device_id, {}
)
device_results[row[0]] = row[1]
device_results[claim_row[0]] = claim_row[1]
continue
# No one-time key available, so see if there's a fallback