diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 0cd5501b0..615720492 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -933,8 +933,9 @@ class FederationHandler(BaseHandler): # lots of requests for missing prev_events which we do actually # have. Hence we fire off the deferred, but don't wait for it. - synapse.util.logcontext.reset_context_after_deferred( - self._handle_queued_pdus(room_queue)) + synapse.util.logcontext.preserve_fn(self._handle_queued_pdus)( + room_queue + ) defer.returnValue(True) diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index d73670f9f..7cbe390b1 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -308,47 +308,44 @@ def preserve_context_over_deferred(deferred, context=None): return d -def reset_context_after_deferred(deferred): - """If the deferred is incomplete, add a callback which will reset the - context. +def preserve_fn(f): + """Wraps a function, to ensure that the current context is restored after + return from the function, and that the sentinel context is set once the + deferred returned by the funtion completes. - This is useful when you want to fire off a deferred, but don't want to - wait for it to complete. (The deferred will restore the current log context - when it completes, so if you don't do anything, it will leak log context.) - - (If this feels asymmetric, consider it this way: we are effectively forking - a new thread of execution. We are probably currently within a - ``with LoggingContext()`` block, which is supposed to have a single entry - and exit point. But by spawning off another deferred, we are effectively - adding a new exit point.) - - Args: - deferred (defer.Deferred): deferred + Useful for wrapping functions that return a deferred which you don't yield + on. """ def reset_context(result): LoggingContext.set_current_context(LoggingContext.sentinel) return result - if not deferred.called: - deferred.addBoth(reset_context) - - -def preserve_fn(f): - """Ensures that function is called with correct context and that context is - restored after return. Useful for wrapping functions that return a deferred - which you don't yield on. - """ + # XXX: why is this here rather than inside g? surely we want to preserve + # the context from the time the function was called, not when it was + # wrapped? current = LoggingContext.current_context() def g(*args, **kwargs): - with PreserveLoggingContext(current): - res = f(*args, **kwargs) - if isinstance(res, defer.Deferred): - return preserve_context_over_deferred( - res, context=LoggingContext.sentinel - ) - else: - return res + res = f(*args, **kwargs) + if isinstance(res, defer.Deferred) and not res.called: + # The function will have reset the context before returning, so + # we need to restore it now. + LoggingContext.set_current_context(current) + + # The original context will be restored when the deferred + # completes, but there is nothing waiting for it, so it will + # get leaked into the reactor or some other function which + # wasn't expecting it. We therefore need to reset the context + # here. + # + # (If this feels asymmetric, consider it this way: we are + # effectively forking a new thread of execution. We are + # probably currently within a ``with LoggingContext()`` block, + # which is supposed to have a single entry and exit point. But + # by spawning off another deferred, we are effectively + # adding a new exit point.) + res.addBoth(reset_context) + return res return g diff --git a/tests/util/test_log_context.py b/tests/util/test_log_context.py index 65a330a0e..9ffe209c4 100644 --- a/tests/util/test_log_context.py +++ b/tests/util/test_log_context.py @@ -1,8 +1,10 @@ +import twisted.python.failure from twisted.internet import defer from twisted.internet import reactor from .. import unittest from synapse.util.async import sleep +from synapse.util import logcontext from synapse.util.logcontext import LoggingContext @@ -33,3 +35,62 @@ class LoggingContextTestCase(unittest.TestCase): context_one.test_key = "one" yield sleep(0) self._check_test_key("one") + + def _test_preserve_fn(self, function): + sentinel_context = LoggingContext.current_context() + + callback_completed = [False] + + @defer.inlineCallbacks + def cb(): + context_one.test_key = "one" + yield function() + self._check_test_key("one") + + callback_completed[0] = True + + with LoggingContext() as context_one: + context_one.test_key = "one" + + # fire off function, but don't wait on it. + logcontext.preserve_fn(cb)() + + self._check_test_key("one") + + # now wait for the function under test to have run, and check that + # the logcontext is left in a sane state. + d2 = defer.Deferred() + + def check_logcontext(): + if not callback_completed[0]: + reactor.callLater(0.01, check_logcontext) + return + + # make sure that the context was reset before it got thrown back + # into the reactor + try: + self.assertIs(LoggingContext.current_context(), + sentinel_context) + d2.callback(None) + except BaseException: + d2.errback(twisted.python.failure.Failure()) + + reactor.callLater(0.01, check_logcontext) + + # test is done once d2 finishes + return d2 + + def test_preserve_fn_with_blocking_fn(self): + @defer.inlineCallbacks + def blocking_function(): + yield sleep(0) + + return self._test_preserve_fn(blocking_function) + + def test_preserve_fn_with_non_blocking_fn(self): + @defer.inlineCallbacks + def nonblocking_function(): + with logcontext.PreserveLoggingContext(): + yield defer.succeed(None) + + return self._test_preserve_fn(nonblocking_function)