Trace functions which return Awaitable (#15650)

This commit is contained in:
Eric Eastwood 2023-06-06 17:39:22 -05:00 committed by GitHub
parent 4e6390cb10
commit 8bfded81f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 59 additions and 22 deletions

1
changelog.d/15650.misc Normal file
View File

@ -0,0 +1 @@
Add support for tracing functions which return `Awaitable`s.

View File

@ -171,6 +171,7 @@ from functools import wraps
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Awaitable,
Callable, Callable,
Collection, Collection,
ContextManager, ContextManager,
@ -903,6 +904,7 @@ def _custom_sync_async_decorator(
""" """
if inspect.iscoroutinefunction(func): if inspect.iscoroutinefunction(func):
# For this branch, we handle async functions like `async def func() -> RInner`.
# In this branch, R = Awaitable[RInner], for some other type RInner # In this branch, R = Awaitable[RInner], for some other type RInner
@wraps(func) @wraps(func)
async def _wrapper( async def _wrapper(
@ -914,15 +916,16 @@ def _custom_sync_async_decorator(
return await func(*args, **kwargs) # type: ignore[misc] return await func(*args, **kwargs) # type: ignore[misc]
else: else:
# The other case here handles both sync functions and those # The other case here handles sync functions including those decorated with
# decorated with inlineDeferred. # `@defer.inlineCallbacks` or that return a `Deferred` or other `Awaitable`.
@wraps(func) @wraps(func)
def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R: def _wrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
scope = wrapping_logic(func, *args, **kwargs) scope = wrapping_logic(func, *args, **kwargs)
scope.__enter__() scope.__enter__()
try: try:
result = func(*args, **kwargs) result = func(*args, **kwargs)
if isinstance(result, defer.Deferred): if isinstance(result, defer.Deferred):
def call_back(result: R) -> R: def call_back(result: R) -> R:
@ -930,20 +933,32 @@ def _custom_sync_async_decorator(
return result return result
def err_back(result: R) -> R: def err_back(result: R) -> R:
# TODO: Pass the error details into `scope.__exit__(...)` for
# consistency with the other paths.
scope.__exit__(None, None, None) scope.__exit__(None, None, None)
return result return result
result.addCallbacks(call_back, err_back) result.addCallbacks(call_back, err_back)
else: elif inspect.isawaitable(result):
if inspect.isawaitable(result):
logger.error(
"@trace may not have wrapped %s correctly! "
"The function is not async but returned a %s.",
func.__qualname__,
type(result).__name__,
)
async def wrap_awaitable() -> Any:
try:
assert isinstance(result, Awaitable)
awaited_result = await result
scope.__exit__(None, None, None)
return awaited_result
except Exception as e:
scope.__exit__(type(e), None, e.__traceback__)
raise
# The original method returned an awaitable, eg. a coroutine, so we
# create another awaitable wrapping it that calls
# `scope.__exit__(...)`.
return wrap_awaitable()
else:
# Just a simple sync function so we can just exit the scope and
# return the result without any fuss.
scope.__exit__(None, None, None) scope.__exit__(None, None, None)
return result return result

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import cast from typing import Awaitable, cast
from twisted.internet import defer from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactorClock from twisted.test.proto_helpers import MemoryReactorClock
@ -227,8 +227,6 @@ class LogContextScopeManagerTestCase(TestCase):
Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args` Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
with functions that return deferreds with functions that return deferreds
""" """
reactor = MemoryReactorClock()
with LoggingContext("root context"): with LoggingContext("root context"):
@trace_with_opname("fixture_deferred_func", tracer=self._tracer) @trace_with_opname("fixture_deferred_func", tracer=self._tracer)
@ -240,9 +238,6 @@ class LogContextScopeManagerTestCase(TestCase):
result_d1 = fixture_deferred_func() result_d1 = fixture_deferred_func()
# let the tasks complete
reactor.pump((2,) * 8)
self.assertEqual(self.successResultOf(result_d1), "foo") self.assertEqual(self.successResultOf(result_d1), "foo")
# the span should have been reported # the span should have been reported
@ -256,8 +251,6 @@ class LogContextScopeManagerTestCase(TestCase):
Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args` Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
with async functions with async functions
""" """
reactor = MemoryReactorClock()
with LoggingContext("root context"): with LoggingContext("root context"):
@trace_with_opname("fixture_async_func", tracer=self._tracer) @trace_with_opname("fixture_async_func", tracer=self._tracer)
@ -267,9 +260,6 @@ class LogContextScopeManagerTestCase(TestCase):
d1 = defer.ensureDeferred(fixture_async_func()) d1 = defer.ensureDeferred(fixture_async_func())
# let the tasks complete
reactor.pump((2,) * 8)
self.assertEqual(self.successResultOf(d1), "foo") self.assertEqual(self.successResultOf(d1), "foo")
# the span should have been reported # the span should have been reported
@ -277,3 +267,34 @@ class LogContextScopeManagerTestCase(TestCase):
[span.operation_name for span in self._reporter.get_spans()], [span.operation_name for span in self._reporter.get_spans()],
["fixture_async_func"], ["fixture_async_func"],
) )
def test_trace_decorator_awaitable_return(self) -> None:
"""
Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args`
with functions that return an awaitable (e.g. a coroutine)
"""
with LoggingContext("root context"):
# Something we can return without `await` to get a coroutine
async def fixture_async_func() -> str:
return "foo"
# The actual kind of function we want to test that returns an awaitable
@trace_with_opname("fixture_awaitable_return_func", tracer=self._tracer)
@tag_args
def fixture_awaitable_return_func() -> Awaitable[str]:
return fixture_async_func()
# Something we can run with `defer.ensureDeferred(runner())` and pump the
# whole async tasks through to completion.
async def runner() -> str:
return await fixture_awaitable_return_func()
d1 = defer.ensureDeferred(runner())
self.assertEqual(self.successResultOf(d1), "foo")
# the span should have been reported
self.assertEqual(
[span.operation_name for span in self._reporter.get_spans()],
["fixture_awaitable_return_func"],
)