mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-02-28 03:11:12 -05:00
ObservableDeferred: run observers in order (#11229)
This commit is contained in:
parent
93aa670642
commit
46d0937447
1
changelog.d/11229.misc
Normal file
1
changelog.d/11229.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
`ObservableDeferred`: run registered observers in order.
|
@ -22,11 +22,11 @@ from typing import (
|
|||||||
Any,
|
Any,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
|
Collection,
|
||||||
Dict,
|
Dict,
|
||||||
Generic,
|
Generic,
|
||||||
Hashable,
|
Hashable,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
|
||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
@ -76,12 +76,17 @@ class ObservableDeferred(Generic[_T]):
|
|||||||
def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False):
|
def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False):
|
||||||
object.__setattr__(self, "_deferred", deferred)
|
object.__setattr__(self, "_deferred", deferred)
|
||||||
object.__setattr__(self, "_result", None)
|
object.__setattr__(self, "_result", None)
|
||||||
object.__setattr__(self, "_observers", set())
|
object.__setattr__(self, "_observers", [])
|
||||||
|
|
||||||
def callback(r):
|
def callback(r):
|
||||||
object.__setattr__(self, "_result", (True, r))
|
object.__setattr__(self, "_result", (True, r))
|
||||||
while self._observers:
|
|
||||||
observer = self._observers.pop()
|
# once we have set _result, no more entries will be added to _observers,
|
||||||
|
# so it's safe to replace it with the empty tuple.
|
||||||
|
observers = self._observers
|
||||||
|
object.__setattr__(self, "_observers", ())
|
||||||
|
|
||||||
|
for observer in observers:
|
||||||
try:
|
try:
|
||||||
observer.callback(r)
|
observer.callback(r)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -95,12 +100,16 @@ class ObservableDeferred(Generic[_T]):
|
|||||||
|
|
||||||
def errback(f):
|
def errback(f):
|
||||||
object.__setattr__(self, "_result", (False, f))
|
object.__setattr__(self, "_result", (False, f))
|
||||||
while self._observers:
|
|
||||||
|
# once we have set _result, no more entries will be added to _observers,
|
||||||
|
# so it's safe to replace it with the empty tuple.
|
||||||
|
observers = self._observers
|
||||||
|
object.__setattr__(self, "_observers", ())
|
||||||
|
|
||||||
|
for observer in observers:
|
||||||
# This is a little bit of magic to correctly propagate stack
|
# This is a little bit of magic to correctly propagate stack
|
||||||
# traces when we `await` on one of the observer deferreds.
|
# traces when we `await` on one of the observer deferreds.
|
||||||
f.value.__failure__ = f
|
f.value.__failure__ = f
|
||||||
|
|
||||||
observer = self._observers.pop()
|
|
||||||
try:
|
try:
|
||||||
observer.errback(f)
|
observer.errback(f)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -127,20 +136,13 @@ class ObservableDeferred(Generic[_T]):
|
|||||||
"""
|
"""
|
||||||
if not self._result:
|
if not self._result:
|
||||||
d: "defer.Deferred[_T]" = defer.Deferred()
|
d: "defer.Deferred[_T]" = defer.Deferred()
|
||||||
|
self._observers.append(d)
|
||||||
def remove(r):
|
|
||||||
self._observers.discard(d)
|
|
||||||
return r
|
|
||||||
|
|
||||||
d.addBoth(remove)
|
|
||||||
|
|
||||||
self._observers.add(d)
|
|
||||||
return d
|
return d
|
||||||
else:
|
else:
|
||||||
success, res = self._result
|
success, res = self._result
|
||||||
return defer.succeed(res) if success else defer.fail(res)
|
return defer.succeed(res) if success else defer.fail(res)
|
||||||
|
|
||||||
def observers(self) -> "List[defer.Deferred[_T]]":
|
def observers(self) -> "Collection[defer.Deferred[_T]]":
|
||||||
return self._observers
|
return self._observers
|
||||||
|
|
||||||
def has_called(self) -> bool:
|
def has_called(self) -> bool:
|
||||||
|
@ -47,9 +47,7 @@ class DeferredCacheTestCase(TestCase):
|
|||||||
self.assertTrue(set_d.called)
|
self.assertTrue(set_d.called)
|
||||||
return r
|
return r
|
||||||
|
|
||||||
# TODO: Actually ObservableDeferred *doesn't* run its tests in order on py3.8.
|
get_d.addCallback(check1)
|
||||||
# maybe we should fix that?
|
|
||||||
# get_d.addCallback(check1)
|
|
||||||
|
|
||||||
# now fire off all the deferreds
|
# now fire off all the deferreds
|
||||||
origin_d.callback(99)
|
origin_d.callback(99)
|
||||||
|
@ -21,11 +21,78 @@ from synapse.logging.context import (
|
|||||||
PreserveLoggingContext,
|
PreserveLoggingContext,
|
||||||
current_context,
|
current_context,
|
||||||
)
|
)
|
||||||
from synapse.util.async_helpers import timeout_deferred
|
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
|
||||||
|
|
||||||
from tests.unittest import TestCase
|
from tests.unittest import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class ObservableDeferredTest(TestCase):
|
||||||
|
def test_succeed(self):
|
||||||
|
origin_d = Deferred()
|
||||||
|
observable = ObservableDeferred(origin_d)
|
||||||
|
|
||||||
|
observer1 = observable.observe()
|
||||||
|
observer2 = observable.observe()
|
||||||
|
|
||||||
|
self.assertFalse(observer1.called)
|
||||||
|
self.assertFalse(observer2.called)
|
||||||
|
|
||||||
|
# check the first observer is called first
|
||||||
|
def check_called_first(res):
|
||||||
|
self.assertFalse(observer2.called)
|
||||||
|
return res
|
||||||
|
|
||||||
|
observer1.addBoth(check_called_first)
|
||||||
|
|
||||||
|
# store the results
|
||||||
|
results = [None, None]
|
||||||
|
|
||||||
|
def check_val(res, idx):
|
||||||
|
results[idx] = res
|
||||||
|
return res
|
||||||
|
|
||||||
|
observer1.addCallback(check_val, 0)
|
||||||
|
observer2.addCallback(check_val, 1)
|
||||||
|
|
||||||
|
origin_d.callback(123)
|
||||||
|
self.assertEqual(results[0], 123, "observer 1 callback result")
|
||||||
|
self.assertEqual(results[1], 123, "observer 2 callback result")
|
||||||
|
|
||||||
|
def test_failure(self):
|
||||||
|
origin_d = Deferred()
|
||||||
|
observable = ObservableDeferred(origin_d, consumeErrors=True)
|
||||||
|
|
||||||
|
observer1 = observable.observe()
|
||||||
|
observer2 = observable.observe()
|
||||||
|
|
||||||
|
self.assertFalse(observer1.called)
|
||||||
|
self.assertFalse(observer2.called)
|
||||||
|
|
||||||
|
# check the first observer is called first
|
||||||
|
def check_called_first(res):
|
||||||
|
self.assertFalse(observer2.called)
|
||||||
|
return res
|
||||||
|
|
||||||
|
observer1.addBoth(check_called_first)
|
||||||
|
|
||||||
|
# store the results
|
||||||
|
results = [None, None]
|
||||||
|
|
||||||
|
def check_val(res, idx):
|
||||||
|
results[idx] = res
|
||||||
|
return None
|
||||||
|
|
||||||
|
observer1.addErrback(check_val, 0)
|
||||||
|
observer2.addErrback(check_val, 1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
raise Exception("gah!")
|
||||||
|
except Exception as e:
|
||||||
|
origin_d.errback(e)
|
||||||
|
self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result")
|
||||||
|
self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result")
|
||||||
|
|
||||||
|
|
||||||
class TimeoutDeferredTest(TestCase):
|
class TimeoutDeferredTest(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.clock = Clock()
|
self.clock = Clock()
|
Loading…
x
Reference in New Issue
Block a user