Add missing type hints to synapse.logging.context (#11556)

This commit is contained in:
Sean Quah 2021-12-14 17:35:28 +00:00 committed by GitHub
parent 2519beaad2
commit 0147b3de20
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 215 additions and 122 deletions

View file

@ -22,20 +22,33 @@ them.
See doc/log_contexts.rst for details on how this works.
"""
import inspect
import logging
import threading
import typing
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Optional,
Tuple,
Type,
TypeVar,
Union,
overload,
)
import attr
from typing_extensions import Literal
from twisted.internet import defer, threads
from twisted.python.threadpool import ThreadPool
if TYPE_CHECKING:
from synapse.logging.scopecontextmanager import _LogContextScope
from synapse.types import ISynapseReactor
logger = logging.getLogger(__name__)
@ -66,7 +79,7 @@ except Exception:
# a hook which can be set during testing to assert that we aren't abusing logcontexts.
def logcontext_error(msg: str):
def logcontext_error(msg: str) -> None:
logger.warning(msg)
@ -223,22 +236,19 @@ class _Sentinel:
def __str__(self) -> str:
return "sentinel"
def copy_to(self, record):
def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
pass
def start(self, rusage: "Optional[resource.struct_rusage]"):
def stop(self, rusage: "Optional[resource.struct_rusage]") -> None:
pass
def stop(self, rusage: "Optional[resource.struct_rusage]"):
def add_database_transaction(self, duration_sec: float) -> None:
pass
def add_database_transaction(self, duration_sec):
def add_database_scheduled(self, sched_sec: float) -> None:
pass
def add_database_scheduled(self, sched_sec):
pass
def record_event_fetch(self, event_count):
def record_event_fetch(self, event_count: int) -> None:
pass
def __bool__(self) -> Literal[False]:
@ -379,7 +389,12 @@ class LoggingContext:
)
return self
def __exit__(self, type, value, traceback) -> None:
def __exit__(
self,
type: Optional[Type[BaseException]],
value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
"""Restore the logging context in thread local storage to the state it
was before this context was entered.
Returns:
@ -399,17 +414,6 @@ class LoggingContext:
# recorded against the correct metrics.
self.finished = True
def copy_to(self, record) -> None:
"""Copy logging fields from this context to a log record or
another LoggingContext
"""
# we track the current request
record.request = self.request
# we also track the current scope:
record.scope = self.scope
def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
"""
Record that this logcontext is currently running.
@ -626,7 +630,12 @@ class PreserveLoggingContext:
def __enter__(self) -> None:
self._old_context = set_current_context(self._new_context)
def __exit__(self, type, value, traceback) -> None:
def __exit__(
self,
type: Optional[Type[BaseException]],
value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
context = set_current_context(self._old_context)
if context != self._new_context:
@ -711,16 +720,61 @@ def nested_logging_context(suffix: str) -> LoggingContext:
)
def preserve_fn(f):
R = TypeVar("R")
@overload
def preserve_fn( # type: ignore[misc]
f: Callable[..., Awaitable[R]],
) -> Callable[..., "defer.Deferred[R]"]:
# The `type: ignore[misc]` above suppresses
# "Overloaded function signatures 1 and 2 overlap with incompatible return types"
...
@overload
def preserve_fn(f: Callable[..., R]) -> Callable[..., "defer.Deferred[R]"]:
...
def preserve_fn(
f: Union[
Callable[..., R],
Callable[..., Awaitable[R]],
]
) -> Callable[..., "defer.Deferred[R]"]:
"""Function decorator which wraps the function with run_in_background"""
def g(*args, **kwargs):
def g(*args: Any, **kwargs: Any) -> "defer.Deferred[R]":
return run_in_background(f, *args, **kwargs)
return g
def run_in_background(f, *args, **kwargs) -> defer.Deferred:
@overload
def run_in_background( # type: ignore[misc]
f: Callable[..., Awaitable[R]], *args: Any, **kwargs: Any
) -> "defer.Deferred[R]":
# The `type: ignore[misc]` above suppresses
# "Overloaded function signatures 1 and 2 overlap with incompatible return types"
...
@overload
def run_in_background(
f: Callable[..., R], *args: Any, **kwargs: Any
) -> "defer.Deferred[R]":
...
def run_in_background(
f: Union[
Callable[..., R],
Callable[..., Awaitable[R]],
],
*args: Any,
**kwargs: Any,
) -> "defer.Deferred[R]":
"""Calls a function, ensuring that the current context is restored after
return from the function, and that the sentinel context is set once the
deferred returned by the function completes.
@ -751,6 +805,10 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred:
# At this point we should have a Deferred, if not then f was a synchronous
# function, wrap it in a Deferred for consistency.
if not isinstance(res, defer.Deferred):
# `res` is not a `Deferred` and not a `Coroutine`.
# There are no other types of `Awaitable`s we expect to encounter in Synapse.
assert not isinstance(res, Awaitable)
return defer.succeed(res)
if res.called and not res.paused:
@ -778,13 +836,14 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred:
return res
def make_deferred_yieldable(deferred):
"""Given a deferred (or coroutine), make it follow the Synapse logcontext
rules:
T = TypeVar("T")
If the deferred has completed (or is not actually a Deferred), essentially
does nothing (just returns another completed deferred with the
result/failure).
def make_deferred_yieldable(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
"""Given a deferred, make it follow the Synapse logcontext rules:
If the deferred has completed, essentially does nothing (just returns another
completed deferred with the result/failure).
If the deferred has not yet completed, resets the logcontext before
returning a deferred. Then, when the deferred completes, restores the
@ -792,16 +851,6 @@ def make_deferred_yieldable(deferred):
(This is more-or-less the opposite operation to run_in_background.)
"""
if inspect.isawaitable(deferred):
# If we're given a coroutine we convert it to a deferred so that we
# run it and find out if it immediately finishes, it it does then we
# don't need to fiddle with log contexts at all and can return
# immediately.
deferred = defer.ensureDeferred(deferred)
if not isinstance(deferred, defer.Deferred):
return deferred
if deferred.called and not deferred.paused:
# it looks like this deferred is ready to run any callbacks we give it
# immediately. We may as well optimise out the logcontext faffery.
@ -823,7 +872,9 @@ def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
return result
def defer_to_thread(reactor, f, *args, **kwargs):
def defer_to_thread(
reactor: "ISynapseReactor", f: Callable[..., R], *args: Any, **kwargs: Any
) -> "defer.Deferred[R]":
"""
Calls the function `f` using a thread from the reactor's default threadpool and
returns the result as a Deferred.
@ -855,7 +906,13 @@ def defer_to_thread(reactor, f, *args, **kwargs):
return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs)
def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
def defer_to_threadpool(
reactor: "ISynapseReactor",
threadpool: ThreadPool,
f: Callable[..., R],
*args: Any,
**kwargs: Any,
) -> "defer.Deferred[R]":
"""
A wrapper for twisted.internet.threads.deferToThreadpool, which handles
logcontexts correctly.
@ -897,7 +954,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
assert isinstance(curr_context, LoggingContext)
parent_context = curr_context
def g():
def g() -> R:
with LoggingContext(str(curr_context), parent_context=parent_context):
return f(*args, **kwargs)