Allow use of both @trace and @tag_args stacked on the same function (#13453)

```py
@trace
@tag_args
async def get_oldest_event_ids_with_depth_in_room(...)
  ...
```

Before this PR, you would see a warning in the logs and the span was not exported:
```
2022-08-03 19:11:59,383 - synapse.logging.opentracing - 835 - ERROR - GET-0 - @trace may not have wrapped EventFederationWorkerStore.get_oldest_event_ids_with_depth_in_room correctly! The function is not async but returned a coroutine.
```
This commit is contained in:
Eric Eastwood 2022-08-09 14:32:33 -05:00 committed by GitHub
parent 1595052b26
commit 1b09b0832e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 191 additions and 61 deletions

1
changelog.d/13453.misc Normal file
View File

@ -0,0 +1 @@
Allow use of both `@trace` and `@tag_args` stacked on the same function (tracing).

View File

@ -173,6 +173,7 @@ from typing import (
Any, Any,
Callable, Callable,
Collection, Collection,
ContextManager,
Dict, Dict,
Generator, Generator,
Iterable, Iterable,
@ -823,75 +824,117 @@ def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanConte
# Tracing decorators # Tracing decorators
def trace_with_opname(opname: str) -> Callable[[Callable[P, R]], Callable[P, R]]: def _custom_sync_async_decorator(
func: Callable[P, R],
wrapping_logic: Callable[[Callable[P, R], Any, Any], ContextManager[None]],
) -> Callable[P, R]:
"""
Decorates a function that is sync or async (coroutines), or that returns a Twisted
`Deferred`. The custom business logic of the decorator goes in `wrapping_logic`.
Example usage:
```py
# Decorator to time the function and log it out
def duration(func: Callable[P, R]) -> Callable[P, R]:
@contextlib.contextmanager
def _wrapping_logic(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> Generator[None, None, None]:
start_ts = time.time()
try:
yield
finally:
end_ts = time.time()
duration = end_ts - start_ts
logger.info("%s took %s seconds", func.__name__, duration)
return _custom_sync_async_decorator(func, _wrapping_logic)
```
Args:
func: The function to be decorated
wrapping_logic: The business logic of your custom decorator.
This should be a ContextManager so you are able to run your logic
before/after the function as desired.
"""
if inspect.iscoroutinefunction(func):
@wraps(func)
async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
with wrapping_logic(func, *args, **kwargs):
return await func(*args, **kwargs) # type: ignore[misc]
else:
# The other case here handles both sync functions and those
# decorated with inlineDeferred.
@wraps(func)
def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
scope = wrapping_logic(func, *args, **kwargs)
scope.__enter__()
try:
result = func(*args, **kwargs)
if isinstance(result, defer.Deferred):
def call_back(result: R) -> R:
scope.__exit__(None, None, None)
return result
def err_back(result: R) -> R:
scope.__exit__(None, None, None)
return result
result.addCallbacks(call_back, err_back)
else:
if inspect.isawaitable(result):
logger.error(
"@trace may not have wrapped %s correctly! "
"The function is not async but returned a %s.",
func.__qualname__,
type(result).__name__,
)
scope.__exit__(None, None, None)
return result
except Exception as e:
scope.__exit__(type(e), None, e.__traceback__)
raise
return _wrapper # type: ignore[return-value]
def trace_with_opname(
opname: str,
*,
tracer: Optional["opentracing.Tracer"] = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
""" """
Decorator to trace a function with a custom opname. Decorator to trace a function with a custom opname.
See the module's doc string for usage examples. See the module's doc string for usage examples.
""" """
def decorator(func: Callable[P, R]) -> Callable[P, R]: # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909
if opentracing is None: @contextlib.contextmanager # type: ignore[arg-type]
return func # type: ignore[unreachable] def _wrapping_logic(
func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
) -> Generator[None, None, None]:
with start_active_span(opname, tracer=tracer):
yield
if inspect.iscoroutinefunction(func): def _decorator(func: Callable[P, R]) -> Callable[P, R]:
if not opentracing:
return func
@wraps(func) return _custom_sync_async_decorator(func, _wrapping_logic)
async def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
with start_active_span(opname):
return await func(*args, **kwargs) # type: ignore[misc]
else: return _decorator
# The other case here handles both sync functions and those
# decorated with inlineDeferred.
@wraps(func)
def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
scope = start_active_span(opname)
scope.__enter__()
try:
result = func(*args, **kwargs)
if isinstance(result, defer.Deferred):
def call_back(result: R) -> R:
scope.__exit__(None, None, None)
return result
def err_back(result: R) -> R:
scope.__exit__(None, None, None)
return result
result.addCallbacks(call_back, err_back)
else:
if inspect.isawaitable(result):
logger.error(
"@trace may not have wrapped %s correctly! "
"The function is not async but returned a %s.",
func.__qualname__,
type(result).__name__,
)
scope.__exit__(None, None, None)
return result
except Exception as e:
scope.__exit__(type(e), None, e.__traceback__)
raise
return _trace_inner # type: ignore[return-value]
return decorator
def trace(func: Callable[P, R]) -> Callable[P, R]: def trace(func: Callable[P, R]) -> Callable[P, R]:
""" """
Decorator to trace a function. Decorator to trace a function.
Sets the operation name to that of the function's name. Sets the operation name to that of the function's name.
See the module's doc string for usage examples. See the module's doc string for usage examples.
""" """
@ -900,7 +943,7 @@ def trace(func: Callable[P, R]) -> Callable[P, R]:
def tag_args(func: Callable[P, R]) -> Callable[P, R]: def tag_args(func: Callable[P, R]) -> Callable[P, R]:
""" """
Tags all of the args to the active span. Decorator to tag all of the args to the active span.
Args: Args:
func: `func` is assumed to be a method taking a `self` parameter, or a func: `func` is assumed to be a method taking a `self` parameter, or a
@ -911,22 +954,25 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]:
if not opentracing: if not opentracing:
return func return func
@wraps(func) # type-ignore: mypy bug, see https://github.com/python/mypy/issues/12909
def _tag_args_inner(*args: P.args, **kwargs: P.kwargs) -> R: @contextlib.contextmanager # type: ignore[arg-type]
def _wrapping_logic(
func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
) -> Generator[None, None, None]:
argspec = inspect.getfullargspec(func) argspec = inspect.getfullargspec(func)
# We use `[1:]` to skip the `self` object reference and `start=1` to # We use `[1:]` to skip the `self` object reference and `start=1` to
# make the index line up with `argspec.args`. # make the index line up with `argspec.args`.
# #
# FIXME: We could update this handle any type of function by ignoring the # FIXME: We could update this to handle any type of function by ignoring the
# first argument only if it's named `self` or `cls`. This isn't fool-proof # first argument only if it's named `self` or `cls`. This isn't fool-proof
# but handles the idiomatic cases. # but handles the idiomatic cases.
for i, arg in enumerate(args[1:], start=1): # type: ignore[index] for i, arg in enumerate(args[1:], start=1): # type: ignore[index]
set_tag("ARG_" + argspec.args[i], str(arg)) set_tag("ARG_" + argspec.args[i], str(arg))
set_tag("args", str(args[len(argspec.args) :])) # type: ignore[index] set_tag("args", str(args[len(argspec.args) :])) # type: ignore[index]
set_tag("kwargs", str(kwargs)) set_tag("kwargs", str(kwargs))
return func(*args, **kwargs) yield
return _tag_args_inner return _custom_sync_async_decorator(func, _wrapping_logic)
@contextlib.contextmanager @contextlib.contextmanager

View File

@ -25,6 +25,8 @@ from synapse.logging.context import (
from synapse.logging.opentracing import ( from synapse.logging.opentracing import (
start_active_span, start_active_span,
start_active_span_follows_from, start_active_span_follows_from,
tag_args,
trace_with_opname,
) )
from synapse.util import Clock from synapse.util import Clock
@ -38,8 +40,12 @@ try:
except ImportError: except ImportError:
jaeger_client = None # type: ignore jaeger_client = None # type: ignore
import logging
from tests.unittest import TestCase from tests.unittest import TestCase
logger = logging.getLogger(__name__)
class LogContextScopeManagerTestCase(TestCase): class LogContextScopeManagerTestCase(TestCase):
""" """
@ -194,3 +200,80 @@ class LogContextScopeManagerTestCase(TestCase):
self._reporter.get_spans(), self._reporter.get_spans(),
[scopes[1].span, scopes[2].span, scopes[0].span], [scopes[1].span, scopes[2].span, scopes[0].span],
) )
def test_trace_decorator_sync(self) -> None:
"""
Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
with sync functions
"""
with LoggingContext("root context"):
@trace_with_opname("fixture_sync_func", tracer=self._tracer)
@tag_args
def fixture_sync_func() -> str:
return "foo"
result = fixture_sync_func()
self.assertEqual(result, "foo")
# the span should have been reported
self.assertEqual(
[span.operation_name for span in self._reporter.get_spans()],
["fixture_sync_func"],
)
def test_trace_decorator_deferred(self) -> None:
"""
Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
with functions that return deferreds
"""
reactor = MemoryReactorClock()
with LoggingContext("root context"):
@trace_with_opname("fixture_deferred_func", tracer=self._tracer)
@tag_args
def fixture_deferred_func() -> "defer.Deferred[str]":
d1: defer.Deferred[str] = defer.Deferred()
d1.callback("foo")
return d1
result_d1 = fixture_deferred_func()
# let the tasks complete
reactor.pump((2,) * 8)
self.assertEqual(self.successResultOf(result_d1), "foo")
# the span should have been reported
self.assertEqual(
[span.operation_name for span in self._reporter.get_spans()],
["fixture_deferred_func"],
)
def test_trace_decorator_async(self) -> None:
"""
Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
with async functions
"""
reactor = MemoryReactorClock()
with LoggingContext("root context"):
@trace_with_opname("fixture_async_func", tracer=self._tracer)
@tag_args
async def fixture_async_func() -> str:
return "foo"
d1 = defer.ensureDeferred(fixture_async_func())
# let the tasks complete
reactor.pump((2,) * 8)
self.assertEqual(self.successResultOf(d1), "foo")
# the span should have been reported
self.assertEqual(
[span.operation_name for span in self._reporter.get_spans()],
["fixture_async_func"],
)