mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2024-12-28 15:09:27 -05:00
Add type annotations to trace
decorator. (#13328)
Functions that are decorated with `trace` are now properly typed and the type hints for them are fixed.
This commit is contained in:
parent
47822fd2e8
commit
a6895dd576
1
changelog.d/13328.misc
Normal file
1
changelog.d/13328.misc
Normal file
@ -0,0 +1 @@
|
||||
Add type hints to `trace` decorator.
|
@ -217,7 +217,7 @@ class FederationClient(FederationBase):
|
||||
)
|
||||
|
||||
async def claim_client_keys(
|
||||
self, destination: str, content: JsonDict, timeout: int
|
||||
self, destination: str, content: JsonDict, timeout: Optional[int]
|
||||
) -> JsonDict:
|
||||
"""Claims one-time keys for a device hosted on a remote server.
|
||||
|
||||
|
@ -619,7 +619,7 @@ class TransportLayerClient:
|
||||
)
|
||||
|
||||
async def claim_client_keys(
|
||||
self, destination: str, query_content: JsonDict, timeout: int
|
||||
self, destination: str, query_content: JsonDict, timeout: Optional[int]
|
||||
) -> JsonDict:
|
||||
"""Claim one-time keys for a list of devices hosted on a remote server.
|
||||
|
||||
|
@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple
|
||||
|
||||
import attr
|
||||
from canonicaljson import encode_canonical_json
|
||||
@ -92,7 +92,11 @@ class E2eKeysHandler:
|
||||
|
||||
@trace
|
||||
async def query_devices(
|
||||
self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
|
||||
self,
|
||||
query_body: JsonDict,
|
||||
timeout: int,
|
||||
from_user_id: str,
|
||||
from_device_id: Optional[str],
|
||||
) -> JsonDict:
|
||||
"""Handle a device key query from a client
|
||||
|
||||
@ -120,9 +124,7 @@ class E2eKeysHandler:
|
||||
the number of in-flight queries at a time.
|
||||
"""
|
||||
async with self._query_devices_linearizer.queue((from_user_id, from_device_id)):
|
||||
device_keys_query: Dict[str, Iterable[str]] = query_body.get(
|
||||
"device_keys", {}
|
||||
)
|
||||
device_keys_query: Dict[str, List[str]] = query_body.get("device_keys", {})
|
||||
|
||||
# separate users by domain.
|
||||
# make a map from domain to user_id to device_ids
|
||||
@ -392,7 +394,7 @@ class E2eKeysHandler:
|
||||
|
||||
@trace
|
||||
async def query_local_devices(
|
||||
self, query: Dict[str, Optional[List[str]]]
|
||||
self, query: Mapping[str, Optional[List[str]]]
|
||||
) -> Dict[str, Dict[str, dict]]:
|
||||
"""Get E2E device keys for local users
|
||||
|
||||
@ -461,7 +463,7 @@ class E2eKeysHandler:
|
||||
|
||||
@trace
|
||||
async def claim_one_time_keys(
|
||||
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
|
||||
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
|
||||
) -> JsonDict:
|
||||
local_query: List[Tuple[str, str, str]] = []
|
||||
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
|
||||
|
@ -84,14 +84,13 @@ the function becomes the operation name for the span.
|
||||
return something_usual_and_useful
|
||||
|
||||
|
||||
Operation names can be explicitly set for a function by passing the
|
||||
operation name to ``trace``
|
||||
Operation names can be explicitly set for a function by using ``trace_with_opname``:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from synapse.logging.opentracing import trace
|
||||
from synapse.logging.opentracing import trace_with_opname
|
||||
|
||||
@trace(opname="a_better_operation_name")
|
||||
@trace_with_opname("a_better_operation_name")
|
||||
def interesting_badly_named_function(*args, **kwargs):
|
||||
# Does all kinds of cool and expected things
|
||||
return something_usual_and_useful
|
||||
@ -798,33 +797,31 @@ def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanConte
|
||||
# Tracing decorators
|
||||
|
||||
|
||||
def trace(func=None, opname: Optional[str] = None):
|
||||
def trace_with_opname(opname: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
"""
|
||||
Decorator to trace a function.
|
||||
Sets the operation name to that of the function's or that given
|
||||
as operation_name. See the module's doc string for usage
|
||||
examples.
|
||||
Decorator to trace a function with a custom opname.
|
||||
|
||||
See the module's doc string for usage examples.
|
||||
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
def decorator(func: Callable[P, R]) -> Callable[P, R]:
|
||||
if opentracing is None:
|
||||
return func # type: ignore[unreachable]
|
||||
|
||||
_opname = opname if opname else func.__name__
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
|
||||
@wraps(func)
|
||||
async def _trace_inner(*args, **kwargs):
|
||||
with start_active_span(_opname):
|
||||
return await func(*args, **kwargs)
|
||||
async def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
with start_active_span(opname):
|
||||
return await func(*args, **kwargs) # type: ignore[misc]
|
||||
|
||||
else:
|
||||
# The other case here handles both sync functions and those
|
||||
# decorated with inlineDeferred.
|
||||
@wraps(func)
|
||||
def _trace_inner(*args, **kwargs):
|
||||
scope = start_active_span(_opname)
|
||||
def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
scope = start_active_span(opname)
|
||||
scope.__enter__()
|
||||
|
||||
try:
|
||||
@ -858,12 +855,21 @@ def trace(func=None, opname: Optional[str] = None):
|
||||
scope.__exit__(type(e), None, e.__traceback__)
|
||||
raise
|
||||
|
||||
return _trace_inner
|
||||
return _trace_inner # type: ignore[return-value]
|
||||
|
||||
if func:
|
||||
return decorator(func)
|
||||
else:
|
||||
return decorator
|
||||
return decorator
|
||||
|
||||
|
||||
def trace(func: Callable[P, R]) -> Callable[P, R]:
|
||||
"""
|
||||
Decorator to trace a function.
|
||||
|
||||
Sets the operation name to that of the function's name.
|
||||
|
||||
See the module's doc string for usage examples.
|
||||
"""
|
||||
|
||||
return trace_with_opname(func.__name__)(func)
|
||||
|
||||
|
||||
def tag_args(func: Callable[P, R]) -> Callable[P, R]:
|
||||
|
@ -29,7 +29,7 @@ from synapse.http import RequestTimedOutError
|
||||
from synapse.http.server import HttpServer, is_method_cancellable
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging import opentracing
|
||||
from synapse.logging.opentracing import trace
|
||||
from synapse.logging.opentracing import trace_with_opname
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.util.stringutils import random_string
|
||||
@ -196,7 +196,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
||||
"ascii"
|
||||
)
|
||||
|
||||
@trace(opname="outgoing_replication_request")
|
||||
@trace_with_opname("outgoing_replication_request")
|
||||
async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
|
||||
with outgoing_gauge.track_inprogress():
|
||||
if instance_name == local_instance_name:
|
||||
|
@ -26,7 +26,7 @@ from synapse.http.servlet import (
|
||||
parse_string,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.opentracing import log_kv, set_tag, trace
|
||||
from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname
|
||||
from synapse.types import JsonDict, StreamToken
|
||||
|
||||
from ._base import client_patterns, interactive_auth_handler
|
||||
@ -71,7 +71,7 @@ class KeyUploadServlet(RestServlet):
|
||||
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
@trace(opname="upload_keys")
|
||||
@trace_with_opname("upload_keys")
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, device_id: Optional[str]
|
||||
) -> Tuple[int, JsonDict]:
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, cast
|
||||
|
||||
from synapse.api.errors import Codes, NotFoundError, SynapseError
|
||||
from synapse.http.server import HttpServer
|
||||
@ -127,7 +127,7 @@ class RoomKeysServlet(RestServlet):
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=False)
|
||||
user_id = requester.user.to_string()
|
||||
body = parse_json_object_from_request(request)
|
||||
version = parse_string(request, "version")
|
||||
version = parse_string(request, "version", required=True)
|
||||
|
||||
if session_id:
|
||||
body = {"sessions": {session_id: body}}
|
||||
@ -196,8 +196,11 @@ class RoomKeysServlet(RestServlet):
|
||||
user_id = requester.user.to_string()
|
||||
version = parse_string(request, "version", required=True)
|
||||
|
||||
room_keys = await self.e2e_room_keys_handler.get_room_keys(
|
||||
user_id, version, room_id, session_id
|
||||
room_keys = cast(
|
||||
JsonDict,
|
||||
await self.e2e_room_keys_handler.get_room_keys(
|
||||
user_id, version, room_id, session_id
|
||||
),
|
||||
)
|
||||
|
||||
# Convert room_keys to the right format to return.
|
||||
@ -240,7 +243,7 @@ class RoomKeysServlet(RestServlet):
|
||||
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=False)
|
||||
user_id = requester.user.to_string()
|
||||
version = parse_string(request, "version")
|
||||
version = parse_string(request, "version", required=True)
|
||||
|
||||
ret = await self.e2e_room_keys_handler.delete_room_keys(
|
||||
user_id, version, room_id, session_id
|
||||
|
@ -19,7 +19,7 @@ from synapse.http import servlet
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.opentracing import set_tag, trace
|
||||
from synapse.logging.opentracing import set_tag, trace_with_opname
|
||||
from synapse.rest.client.transactions import HttpTransactionCache
|
||||
from synapse.types import JsonDict
|
||||
|
||||
@ -43,7 +43,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
|
||||
self.txns = HttpTransactionCache(hs)
|
||||
self.device_message_handler = hs.get_device_message_handler()
|
||||
|
||||
@trace(opname="sendToDevice")
|
||||
@trace_with_opname("sendToDevice")
|
||||
def on_PUT(
|
||||
self, request: SynapseRequest, message_type: str, txn_id: str
|
||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
||||
|
@ -37,7 +37,7 @@ from synapse.handlers.sync import (
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.opentracing import trace
|
||||
from synapse.logging.opentracing import trace_with_opname
|
||||
from synapse.types import JsonDict, StreamToken
|
||||
from synapse.util import json_decoder
|
||||
|
||||
@ -210,7 +210,7 @@ class SyncRestServlet(RestServlet):
|
||||
logger.debug("Event formatting complete")
|
||||
return 200, response_content
|
||||
|
||||
@trace(opname="sync.encode_response")
|
||||
@trace_with_opname("sync.encode_response")
|
||||
async def encode_response(
|
||||
self,
|
||||
time_now: int,
|
||||
@ -315,7 +315,7 @@ class SyncRestServlet(RestServlet):
|
||||
]
|
||||
}
|
||||
|
||||
@trace(opname="sync.encode_joined")
|
||||
@trace_with_opname("sync.encode_joined")
|
||||
async def encode_joined(
|
||||
self,
|
||||
rooms: List[JoinedSyncResult],
|
||||
@ -340,7 +340,7 @@ class SyncRestServlet(RestServlet):
|
||||
|
||||
return joined
|
||||
|
||||
@trace(opname="sync.encode_invited")
|
||||
@trace_with_opname("sync.encode_invited")
|
||||
async def encode_invited(
|
||||
self,
|
||||
rooms: List[InvitedSyncResult],
|
||||
@ -371,7 +371,7 @@ class SyncRestServlet(RestServlet):
|
||||
|
||||
return invited
|
||||
|
||||
@trace(opname="sync.encode_knocked")
|
||||
@trace_with_opname("sync.encode_knocked")
|
||||
async def encode_knocked(
|
||||
self,
|
||||
rooms: List[KnockedSyncResult],
|
||||
@ -420,7 +420,7 @@ class SyncRestServlet(RestServlet):
|
||||
|
||||
return knocked
|
||||
|
||||
@trace(opname="sync.encode_archived")
|
||||
@trace_with_opname("sync.encode_archived")
|
||||
async def encode_archived(
|
||||
self,
|
||||
rooms: List[ArchivedSyncResult],
|
||||
|
@ -669,7 +669,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
|
||||
|
||||
@trace
|
||||
async def get_user_devices_from_cache(
|
||||
self, query_list: List[Tuple[str, str]]
|
||||
self, query_list: List[Tuple[str, Optional[str]]]
|
||||
) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
|
||||
"""Get the devices (and keys if any) for remote users from the cache.
|
||||
|
||||
|
@ -22,11 +22,14 @@ from typing import (
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
import attr
|
||||
from canonicaljson import encode_canonical_json
|
||||
from typing_extensions import Literal
|
||||
|
||||
from synapse.api.constants import DeviceKeyAlgorithms
|
||||
from synapse.appservice import (
|
||||
@ -113,7 +116,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
user_devices = devices[user_id]
|
||||
results = []
|
||||
for device_id, device in user_devices.items():
|
||||
result = {"device_id": device_id}
|
||||
result: JsonDict = {"device_id": device_id}
|
||||
|
||||
keys = device.keys
|
||||
if keys:
|
||||
@ -156,6 +159,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
rv[user_id] = {}
|
||||
for device_id, device_info in device_keys.items():
|
||||
r = device_info.keys
|
||||
if r is None:
|
||||
continue
|
||||
|
||||
r["unsigned"] = {}
|
||||
display_name = device_info.display_name
|
||||
if display_name is not None:
|
||||
@ -164,13 +170,42 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
|
||||
return rv
|
||||
|
||||
@overload
|
||||
async def get_e2e_device_keys_and_signatures(
|
||||
self,
|
||||
query_list: Collection[Tuple[str, Optional[str]]],
|
||||
include_all_devices: Literal[False] = False,
|
||||
) -> Dict[str, Dict[str, DeviceKeyLookupResult]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def get_e2e_device_keys_and_signatures(
|
||||
self,
|
||||
query_list: Collection[Tuple[str, Optional[str]]],
|
||||
include_all_devices: bool = False,
|
||||
include_deleted_devices: Literal[False] = False,
|
||||
) -> Dict[str, Dict[str, DeviceKeyLookupResult]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def get_e2e_device_keys_and_signatures(
|
||||
self,
|
||||
query_list: Collection[Tuple[str, Optional[str]]],
|
||||
include_all_devices: Literal[True],
|
||||
include_deleted_devices: Literal[True],
|
||||
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
|
||||
...
|
||||
|
||||
@trace
|
||||
async def get_e2e_device_keys_and_signatures(
|
||||
self,
|
||||
query_list: List[Tuple[str, Optional[str]]],
|
||||
query_list: Collection[Tuple[str, Optional[str]]],
|
||||
include_all_devices: bool = False,
|
||||
include_deleted_devices: bool = False,
|
||||
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
|
||||
) -> Union[
|
||||
Dict[str, Dict[str, DeviceKeyLookupResult]],
|
||||
Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]],
|
||||
]:
|
||||
"""Fetch a list of device keys
|
||||
|
||||
Any cross-signatures made on the keys by the owner of the device are also
|
||||
@ -1044,7 +1079,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
_claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
|
||||
db_autocommit = False
|
||||
|
||||
row = await self.db_pool.runInteraction(
|
||||
claim_row = await self.db_pool.runInteraction(
|
||||
"claim_e2e_one_time_keys",
|
||||
_claim_e2e_one_time_key,
|
||||
user_id,
|
||||
@ -1052,11 +1087,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
algorithm,
|
||||
db_autocommit=db_autocommit,
|
||||
)
|
||||
if row:
|
||||
if claim_row:
|
||||
device_results = results.setdefault(user_id, {}).setdefault(
|
||||
device_id, {}
|
||||
)
|
||||
device_results[row[0]] = row[1]
|
||||
device_results[claim_row[0]] = claim_row[1]
|
||||
continue
|
||||
|
||||
# No one-time key available, so see if there's a fallback
|
||||
|
Loading…
Reference in New Issue
Block a user