diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index 9605d7d1b..9cffaec8f 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -32,7 +32,9 @@ class Distributor(object): model will do for today. """ - def __init__(self): + def __init__(self, suppress_failures=True): + self.suppress_failures = suppress_failures + self.signals = {} self.pre_registration = {} @@ -40,7 +42,9 @@ class Distributor(object): if name in self.signals: raise KeyError("%r already has a signal named %s" % (self, name)) - self.signals[name] = Signal(name) + self.signals[name] = Signal(name, + suppress_failures=self.suppress_failures, + ) if name in self.pre_registration: signal = self.signals[name] @@ -74,8 +78,9 @@ class Signal(object): method into all of the observers. """ - def __init__(self, name): + def __init__(self, name, suppress_failures): self.name = name + self.suppress_failures = suppress_failures self.observers = [] def observe(self, observer): @@ -104,6 +109,10 @@ class Signal(object): failure.type, failure.value, failure.getTracebackObject())) + if not self.suppress_failures: + raise failure deferreds.append(d.addErrback(eb)) - return defer.DeferredList(deferreds) + return defer.DeferredList( + deferreds, fireOnOneErrback=not self.suppress_failures + ) diff --git a/tests/test_distributor.py b/tests/test_distributor.py index 21c91f335..2869fdfd7 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py @@ -13,9 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest - from twisted.internet import defer +from twisted.trial import unittest from mock import Mock, patch @@ -75,6 +74,24 @@ class DistributorTestCase(unittest.TestCase): self.assertIsInstance(mock_logger.warning.call_args[0][0], str) + @defer.inlineCallbacks + def test_signal_catch_no_suppress(self): + # Gut-wrenching + self.dist.suppress_failures = False + + self.dist.declare("whail") + + observer = Mock() + observer.return_value = defer.fail( + Exception("Oopsie") + ) + + self.dist.observe("whail", observer) + + d = self.dist.fire("whail") + + yield self.assertFailure(d, Exception) + def test_signal_prereg(self): observer = Mock() self.dist.observe("flare", observer) @@ -85,5 +102,6 @@ class DistributorTestCase(unittest.TestCase): observer.assert_called_with(4, 5) def test_signal_undeclared(self): - with self.assertRaises(KeyError): + def code(): self.dist.fire("notification") + self.assertRaises(KeyError, code)