Refactor the notifier.wait_for_events code to be clearer. Add _NotifierUserStream.new_listener that accpets a token to avoid races.

This commit is contained in:
Erik Johnston 2015-06-18 15:49:05 +01:00
parent 050ebccf30
commit 22049ea700
3 changed files with 72 additions and 69 deletions

View File

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor, ObservableDeferred
from synapse.types import StreamToken from synapse.types import StreamToken
import synapse.metrics import synapse.metrics
@ -45,20 +45,16 @@ class _NotificationListener(object):
The events stream handler will have yielded to the deferred, so to The events stream handler will have yielded to the deferred, so to
notify the handler it is sufficient to resolve the deferred. notify the handler it is sufficient to resolve the deferred.
""" """
__slots__ = ["deferred"]
def __init__(self, deferred): def __init__(self, deferred):
self.deferred = deferred object.__setattr__(self, "deferred", deferred)
def notified(self): def __getattr__(self, name):
return self.deferred.called return getattr(self.deferred, name)
def notify(self, token): def __setattr__(self, name, value):
""" Inform whoever is listening about the new events. setattr(self.deferred, name, value)
"""
try:
self.deferred.callback(token)
except defer.AlreadyCalledError:
pass
class _NotifierUserStream(object): class _NotifierUserStream(object):
@ -75,11 +71,12 @@ class _NotifierUserStream(object):
appservice=None): appservice=None):
self.user = str(user) self.user = str(user)
self.appservice = appservice self.appservice = appservice
self.listeners = set()
self.rooms = set(rooms) self.rooms = set(rooms)
self.current_token = current_token self.current_token = current_token
self.last_notified_ms = time_now_ms self.last_notified_ms = time_now_ms
self.notify_deferred = ObservableDeferred(defer.Deferred())
def notify(self, stream_key, stream_id, time_now_ms): def notify(self, stream_key, stream_id, time_now_ms):
"""Notify any listeners for this user of a new event from an """Notify any listeners for this user of a new event from an
event source. event source.
@ -91,12 +88,10 @@ class _NotifierUserStream(object):
self.current_token = self.current_token.copy_and_advance( self.current_token = self.current_token.copy_and_advance(
stream_key, stream_id stream_key, stream_id
) )
if self.listeners:
self.last_notified_ms = time_now_ms self.last_notified_ms = time_now_ms
listeners = self.listeners noify_deferred = self.notify_deferred
self.listeners = set() self.notify_deferred = ObservableDeferred(defer.Deferred())
for listener in listeners: noify_deferred.callback(self.current_token)
listener.notify(self.current_token)
def remove(self, notifier): def remove(self, notifier):
""" Remove this listener from all the indexes in the Notifier """ Remove this listener from all the indexes in the Notifier
@ -114,6 +109,18 @@ class _NotifierUserStream(object):
self.appservice, set() self.appservice, set()
).discard(self) ).discard(self)
def count_listeners(self):
return len(self.noify_deferred.observers())
def new_listener(self, token):
"""Returns a deferred that is resolved when there is a new token
greater than the given token.
"""
if self.current_token.is_after(token):
return _NotificationListener(defer.succeed(self.current_token))
else:
return _NotificationListener(self.notify_deferred.observe())
class Notifier(object): class Notifier(object):
""" This class is responsible for notifying any listeners when there are """ This class is responsible for notifying any listeners when there are
@ -158,7 +165,7 @@ class Notifier(object):
for x in self.appservice_to_user_streams.values(): for x in self.appservice_to_user_streams.values():
all_user_streams |= x all_user_streams |= x
return sum(len(stream.listeners) for stream in all_user_streams) return sum(stream.count_listeners() for stream in all_user_streams)
metrics.register_callback("listeners", count_listeners) metrics.register_callback("listeners", count_listeners)
metrics.register_callback( metrics.register_callback(
@ -286,10 +293,6 @@ class Notifier(object):
"""Wait until the callback returns a non empty response or the """Wait until the callback returns a non empty response or the
timeout fires. timeout fires.
""" """
deferred = defer.Deferred()
time_now_ms = self.clock.time_msec()
user = str(user) user = str(user)
user_stream = self.user_to_user_stream.get(user) user_stream = self.user_to_user_stream.get(user)
if user_stream is None: if user_stream is None:
@ -302,54 +305,38 @@ class Notifier(object):
rooms=rooms, rooms=rooms,
appservice=appservice, appservice=appservice,
current_token=current_token, current_token=current_token,
time_now_ms=time_now_ms, time_now_ms=self.clock.time_msec(),
) )
self._register_with_keys(user_stream) self._register_with_keys(user_stream)
else:
current_token = user_stream.current_token
result = None result = None
if current_token.is_after(from_token):
result = yield callback(from_token, current_token)
if result:
defer.returnValue(result)
if timeout: if timeout:
timer = [None] listener = None
listeners = [] timer = self.clock.call_later(
timed_out = [False] timeout/1000., lambda: listener.cancel()
)
def notify_listeners(): prev_token = from_token
user_stream.listeners.difference_update(listeners) while not result:
for listener in listeners:
listener.notify(current_token)
del listeners[:]
def _timeout_listener():
timed_out[0] = True
timer[0] = None
notify_listeners()
# We create multiple notification listeners so we have to manage
# canceling the timeout ourselves.
timer[0] = self.clock.call_later(timeout/1000., _timeout_listener)
while not result and not timed_out[0]:
deferred = defer.Deferred()
notify_listeners()
listeners.append(_NotificationListener(deferred))
user_stream.listeners.update(listeners)
new_token = yield deferred
result = yield callback(current_token, new_token)
current_token = new_token
if timer[0] is not None:
try: try:
self.clock.cancel_call_later(timer[0]) # We need to start listening to the streams *before* doing
except: # the callback, as otherwise we may miss something.
logger.exception("Failed to cancel notifer timer") current_token = user_stream.current_token
result = yield callback(prev_token, current_token)
if result:
break
prev_token = current_token
listener = user_stream.new_listener(prev_token)
yield listener.deferred
except defer.CancelledError:
break
self.clock.cancel_call_later(timer, ignore_errs=True)
else:
current_token = user_stream.current_token
result = yield callback(from_token, current_token)
defer.returnValue(result) defer.returnValue(result)
@ -367,6 +354,9 @@ class Notifier(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def check_for_updates(before_token, after_token): def check_for_updates(before_token, after_token):
if not after_token.is_after(before_token):
defer.returnValue(None)
events = [] events = []
end_token = from_token end_token = from_token
for name, source in self.event_sources.sources.items(): for name, source in self.event_sources.sources.items():
@ -401,7 +391,7 @@ class Notifier(object):
expired_streams = [] expired_streams = []
expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS
for stream in self.user_to_user_stream.values(): for stream in self.user_to_user_stream.values():
if stream.listeners: if stream.count_listeners():
continue continue
if stream.last_notified_ms < expire_before_ts: if stream.last_notified_ms < expire_before_ts:
expired_streams.append(stream) expired_streams.append(stream)

View File

@ -91,8 +91,12 @@ class Clock(object):
with PreserveLoggingContext(): with PreserveLoggingContext():
return reactor.callLater(delay, wrapped_callback, *args, **kwargs) return reactor.callLater(delay, wrapped_callback, *args, **kwargs)
def cancel_call_later(self, timer): def cancel_call_later(self, timer, ignore_errs=False):
try:
timer.cancel() timer.cancel()
except:
if not ignore_errs:
raise
def time_bound_deferred(self, given_deferred, time_out): def time_bound_deferred(self, given_deferred, time_out):
if given_deferred.called: if given_deferred.called:

View File

@ -45,7 +45,7 @@ class ObservableDeferred(object):
def __init__(self, deferred, consumeErrors=False): def __init__(self, deferred, consumeErrors=False):
object.__setattr__(self, "_deferred", deferred) object.__setattr__(self, "_deferred", deferred)
object.__setattr__(self, "_result", None) object.__setattr__(self, "_result", None)
object.__setattr__(self, "_observers", []) object.__setattr__(self, "_observers", set())
def callback(r): def callback(r):
self._result = (True, r) self._result = (True, r)
@ -74,12 +74,21 @@ class ObservableDeferred(object):
def observe(self): def observe(self):
if not self._result: if not self._result:
d = defer.Deferred() d = defer.Deferred()
self._observers.append(d)
def remove(r):
self._observers.discard(d)
return r
d.addBoth(remove)
self._observers.add(d)
return d return d
else: else:
success, res = self._result success, res = self._result
return defer.succeed(res) if success else defer.fail(res) return defer.succeed(res) if success else defer.fail(res)
def observers(self):
return self._observers
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self._deferred, name) return getattr(self._deferred, name)