Merge pull request #507 from matrix-org/push_badge_counts

Push badge counts
This commit is contained in:
David Baker 2016-01-21 10:09:11 +00:00
commit c1a3021771
5 changed files with 109 additions and 67 deletions

View File

@ -17,6 +17,7 @@ from twisted.internet import defer
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.types import StreamToken from synapse.types import StreamToken
from synapse.api.constants import Membership
import synapse.util.async import synapse.util.async
import push_rule_evaluator as push_rule_evaluator import push_rule_evaluator as push_rule_evaluator
@ -55,6 +56,7 @@ class Pusher(object):
self.backoff_delay = Pusher.INITIAL_BACKOFF self.backoff_delay = Pusher.INITIAL_BACKOFF
self.failing_since = failing_since self.failing_since = failing_since
self.alive = True self.alive = True
self.badge = None
# The last value of last_active_time that we saw # The last value of last_active_time that we saw
self.last_last_active_time = 0 self.last_last_active_time = 0
@ -92,8 +94,7 @@ class Pusher(object):
# we fail to dispatch the push) # we fail to dispatch the push)
config = PaginationConfig(from_token=None, limit='1') config = PaginationConfig(from_token=None, limit='1')
chunk = yield self.evStreamHandler.get_stream( chunk = yield self.evStreamHandler.get_stream(
self.user_id, config, timeout=0, affect_presence=False, self.user_id, config, timeout=0, affect_presence=False
only_room_events=True
) )
self.last_token = chunk['end'] self.last_token = chunk['end']
self.store.update_pusher_last_token( self.store.update_pusher_last_token(
@ -124,9 +125,11 @@ class Pusher(object):
from_tok = StreamToken.from_string(self.last_token) from_tok = StreamToken.from_string(self.last_token)
config = PaginationConfig(from_token=from_tok, limit='1') config = PaginationConfig(from_token=from_tok, limit='1')
timeout = (300 + random.randint(-60, 60)) * 1000 timeout = (300 + random.randint(-60, 60)) * 1000
# note that we need to get read receipts down the stream as we need to
# wake up when one arrives. we don't need to explicitly look for
# them though.
chunk = yield self.evStreamHandler.get_stream( chunk = yield self.evStreamHandler.get_stream(
self.user_id, config, timeout=timeout, affect_presence=False, self.user_id, config, timeout=timeout, affect_presence=False
only_room_events=True
) )
# limiting to 1 may get 1 event plus 1 presence event, so # limiting to 1 may get 1 event plus 1 presence event, so
@ -135,10 +138,10 @@ class Pusher(object):
for c in chunk['chunk']: for c in chunk['chunk']:
if 'event_id' in c: # Hmmm... if 'event_id' in c: # Hmmm...
single_event = c single_event = c
break
if not single_event: if not single_event:
yield self.update_badge()
self.last_token = chunk['end'] self.last_token = chunk['end']
logger.debug("Event stream timeout for pushkey %s", self.pushkey)
yield self.store.update_pusher_last_token( yield self.store.update_pusher_last_token(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
@ -161,7 +164,8 @@ class Pusher(object):
tweaks = rule_evaluator.tweaks_for_actions(actions) tweaks = rule_evaluator.tweaks_for_actions(actions)
if 'notify' in actions: if 'notify' in actions:
rejected = yield self.dispatch_push(single_event, tweaks) self.badge = yield self._get_badge_count()
rejected = yield self.dispatch_push(single_event, tweaks, self.badge)
self.has_unread = True self.has_unread = True
if isinstance(rejected, list) or isinstance(rejected, tuple): if isinstance(rejected, list) or isinstance(rejected, tuple):
processed = True processed = True
@ -181,7 +185,6 @@ class Pusher(object):
yield self.hs.get_pusherpool().remove_pusher( yield self.hs.get_pusherpool().remove_pusher(
self.app_id, pk, self.user_id self.app_id, pk, self.user_id
) )
else:
processed = True processed = True
if not self.alive: if not self.alive:
@ -254,7 +257,7 @@ class Pusher(object):
def stop(self): def stop(self):
self.alive = False self.alive = False
def dispatch_push(self, p, tweaks): def dispatch_push(self, p, tweaks, badge):
""" """
Overridden by implementing classes to actually deliver the notification Overridden by implementing classes to actually deliver the notification
Args: Args:
@ -266,23 +269,47 @@ class Pusher(object):
""" """
pass pass
def reset_badge_count(self): @defer.inlineCallbacks
def update_badge(self):
new_badge = yield self._get_badge_count()
if self.badge != new_badge:
self.badge = new_badge
yield self.send_badge(self.badge)
def send_badge(self, badge):
"""
Overridden by implementing classes to send an updated badge count
"""
pass pass
def presence_changed(self, state): @defer.inlineCallbacks
""" def _get_badge_count(self):
We clear badge counts whenever a user's last_active time is bumped room_list = yield self.store.get_rooms_for_user_where_membership_is(
This is by no means perfect but I think it's the best we can do user_id=self.user_id,
without read receipts. membership_list=(Membership.INVITE, Membership.JOIN)
""" )
if 'last_active' in state.state:
last_active = state.state['last_active'] my_receipts_by_room = yield self.store.get_receipts_for_user(
if last_active > self.last_last_active_time: self.user_id,
self.last_last_active_time = last_active "m.read",
if self.has_unread: )
logger.info("Resetting badge count for %s", self.user_id)
self.reset_badge_count() badge = 0
self.has_unread = False
for r in room_list:
if r.membership == Membership.INVITE:
badge += 1
else:
if r.room_id in my_receipts_by_room:
last_unread_event_id = my_receipts_by_room[r.room_id]
notifs = yield (
self.store.get_unread_event_push_actions_by_room_for_user(
r.room_id, self.user_id, last_unread_event_id
)
)
badge += len(notifs)
defer.returnValue(badge)
class PusherConfigException(Exception): class PusherConfigException(Exception):

View File

@ -51,7 +51,7 @@ class HttpPusher(Pusher):
del self.data_minus_url['url'] del self.data_minus_url['url']
@defer.inlineCallbacks @defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks): def _build_notification_dict(self, event, tweaks, badge):
# we probably do not want to push for every presence update # we probably do not want to push for every presence update
# (we may want to be able to set up notifications when specific # (we may want to be able to set up notifications when specific
# people sign in, but we'd want to only deliver the pertinent ones) # people sign in, but we'd want to only deliver the pertinent ones)
@ -71,7 +71,7 @@ class HttpPusher(Pusher):
'counts': { # -- we don't mark messages as read yet so 'counts': { # -- we don't mark messages as read yet so
# we have no way of knowing # we have no way of knowing
# Just set the badge to 1 until we have read receipts # Just set the badge to 1 until we have read receipts
'unread': 1, 'unread': badge,
# 'missed_calls': 2 # 'missed_calls': 2
}, },
'devices': [ 'devices': [
@ -101,8 +101,8 @@ class HttpPusher(Pusher):
defer.returnValue(d) defer.returnValue(d)
@defer.inlineCallbacks @defer.inlineCallbacks
def dispatch_push(self, event, tweaks): def dispatch_push(self, event, tweaks, badge):
notification_dict = yield self._build_notification_dict(event, tweaks) notification_dict = yield self._build_notification_dict(event, tweaks, badge)
if not notification_dict: if not notification_dict:
defer.returnValue([]) defer.returnValue([])
try: try:
@ -116,15 +116,15 @@ class HttpPusher(Pusher):
defer.returnValue(rejected) defer.returnValue(rejected)
@defer.inlineCallbacks @defer.inlineCallbacks
def reset_badge_count(self): def send_badge(self, badge):
logger.info("Sending updated badge count %d to %r", badge, self.user_id)
d = { d = {
'notification': { 'notification': {
'id': '', 'id': '',
'type': None, 'type': None,
'sender': '', 'sender': '',
'counts': { 'counts': {
'unread': 0, 'unread': badge
'missed_calls': 0
}, },
'devices': [ 'devices': [
{ {

View File

@ -31,21 +31,6 @@ class PusherPool:
self.pushers = {} self.pushers = {}
self.last_pusher_started = -1 self.last_pusher_started = -1
distributor = self.hs.get_distributor()
distributor.observe(
"user_presence_changed", self.user_presence_changed
)
@defer.inlineCallbacks
def user_presence_changed(self, user, state):
user_id = user.to_string()
# until we have read receipts, pushers use this to reset a user's
# badge counters to zero
for p in self.pushers.values():
if p.user_id == user_id:
yield p.presence_changed(state)
@defer.inlineCallbacks @defer.inlineCallbacks
def start(self): def start(self):
pushers = yield self.store.get_all_pushers() pushers = yield self.store.get_all_pushers()

View File

@ -45,6 +45,21 @@ class ReceiptsStore(SQLBaseStore):
desc="get_receipts_for_room", desc="get_receipts_for_room",
) )
@cachedInlineCallbacks(num_args=2)
def get_receipts_for_user(self, user_id, receipt_type):
def f(txn):
sql = (
"SELECT room_id,event_id "
"FROM receipts_linearized "
"WHERE user_id = ? AND receipt_type = ? "
)
txn.execute(sql, (user_id, receipt_type))
return txn.fetchall()
defer.returnValue(dict(
(yield self.runInteraction("get_receipts_for_user", f))
))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
"""Get receipts for multiple rooms for sending to clients. """Get receipts for multiple rooms for sending to clients.
@ -194,29 +209,16 @@ class ReceiptsStore(SQLBaseStore):
def get_max_receipt_stream_id(self): def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_max_token(self) return self._receipts_id_gen.get_max_token(self)
@cachedInlineCallbacks()
def get_graph_receipts_for_room(self, room_id):
"""Get receipts for sending to remote servers.
"""
rows = yield self._simple_select_list(
table="receipts_graph",
keyvalues={"room_id": room_id},
retcols=["receipt_type", "user_id", "event_id"],
desc="get_linearized_receipts_for_room",
)
result = {}
for row in rows:
result.setdefault(
row["user_id"], {}
).setdefault(
row["receipt_type"], []
).append(row["event_id"])
defer.returnValue(result)
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type, def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_id, data, stream_id): user_id, event_id, data, stream_id):
txn.call_after(
self.get_receipts_for_room.invalidate, (room_id, receipt_type)
)
txn.call_after(
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
)
# FIXME: This shouldn't invalidate the whole cache
txn.call_after(self.get_linearized_receipts_for_room.invalidate_all)
# We don't want to clobber receipts for more recent events, so we # We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts # have to compare orderings of existing receipts
@ -324,6 +326,7 @@ class ReceiptsStore(SQLBaseStore):
) )
max_persisted_id = yield self._stream_id_gen.get_max_token(self) max_persisted_id = yield self._stream_id_gen.get_max_token(self)
defer.returnValue((stream_id, max_persisted_id)) defer.returnValue((stream_id, max_persisted_id))
def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids,
@ -336,6 +339,15 @@ class ReceiptsStore(SQLBaseStore):
def insert_graph_receipt_txn(self, txn, room_id, receipt_type, def insert_graph_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_ids, data): user_id, event_ids, data):
txn.call_after(
self.get_receipts_for_room.invalidate, (room_id, receipt_type)
)
txn.call_after(
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
)
# FIXME: This shouldn't invalidate the whole cache
txn.call_after(self.get_linearized_receipts_for_room.invalidate_all)
self._simple_delete_txn( self._simple_delete_txn(
txn, txn,
table="receipts_graph", table="receipts_graph",

View File

@ -0,0 +1,18 @@
/* Copyright 2015, 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE INDEX receipts_linearized_user ON receipts_linearized(
user_id
);