Appease mypy

This commit is contained in:
Erik Johnston 2019-10-10 12:15:17 +01:00
parent 791a8c559b
commit 941edad583

View File

@ -18,11 +18,17 @@ from __future__ import print_function
import functools import functools
import sys import sys
from typing import List, Callable, Any
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from twisted.python.failure import Failure from twisted.python.failure import Failure
# Tracks if we've already patched inlineCallbacks
_already_patched = False
def do_patch(): def do_patch():
""" """
Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
@ -30,16 +36,18 @@ def do_patch():
from synapse.logging.context import LoggingContext from synapse.logging.context import LoggingContext
global _already_patched
orig_inline_callbacks = defer.inlineCallbacks orig_inline_callbacks = defer.inlineCallbacks
if hasattr(orig_inline_callbacks, "patched_by_synapse"): if _already_patched:
return return
def new_inline_callbacks(f): def new_inline_callbacks(f):
@functools.wraps(f) @functools.wraps(f)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
start_context = LoggingContext.current_context() start_context = LoggingContext.current_context()
changes = [] changes: List[str] = []
orig = orig_inline_callbacks(_check_yield_points(f, changes, start_context)) orig = orig_inline_callbacks(_check_yield_points(f, changes))
try: try:
res = orig(*args, **kwargs) res = orig(*args, **kwargs)
@ -101,10 +109,10 @@ def do_patch():
return wrapped return wrapped
defer.inlineCallbacks = new_inline_callbacks defer.inlineCallbacks = new_inline_callbacks
new_inline_callbacks.patched_by_synapse = True _already_patched = True
def _check_yield_points(f, changes, start_context): def _check_yield_points(f: Callable, changes: List[str]):
"""Wraps a generator that is about to be passed to defer.inlineCallbacks """Wraps a generator that is about to be passed to defer.inlineCallbacks
checking that after every yield the log contexts are correct. checking that after every yield the log contexts are correct.
@ -114,9 +122,8 @@ def _check_yield_points(f, changes, start_context):
Args: Args:
f: generator function to wrap f: generator function to wrap
changes (list[str]): A list of strings detailing how the contexts changes: A list of strings detailing how the contexts
changed within a function. changed within a function.
start_context (LoggingContext): The initial context we're expecting
Returns: Returns:
function function
@ -126,13 +133,13 @@ def _check_yield_points(f, changes, start_context):
@functools.wraps(f) @functools.wraps(f)
def check_yield_points_inner(*args, **kwargs): def check_yield_points_inner(*args, **kwargs):
expected_context = start_context
gen = f(*args, **kwargs) gen = f(*args, **kwargs)
last_yield_line_no = gen.gi_frame.f_lineno last_yield_line_no = gen.gi_frame.f_lineno
result = None result: Any = None
while True: while True:
expected_context = LoggingContext.current_context()
try: try:
isFailure = isinstance(result, Failure) isFailure = isinstance(result, Failure)
if isFailure: if isFailure:
@ -200,7 +207,7 @@ def _check_yield_points(f, changes, start_context):
"%s changed context from %s to %s, happened between lines %d and %d in %s" "%s changed context from %s to %s, happened between lines %d and %d in %s"
% ( % (
frame.f_code.co_name, frame.f_code.co_name,
start_context, expected_context,
LoggingContext.current_context(), LoggingContext.current_context(),
last_yield_line_no, last_yield_line_no,
frame.f_lineno, frame.f_lineno,
@ -209,8 +216,6 @@ def _check_yield_points(f, changes, start_context):
) )
changes.append(err) changes.append(err)
expected_context = LoggingContext.current_context()
last_yield_line_no = frame.f_lineno last_yield_line_no = frame.f_lineno
return check_yield_points_inner return check_yield_points_inner