Re-implement unread counts (again) (#8059)

This commit is contained in:
Brendan Abolivier 2020-09-02 17:19:37 +01:00 committed by GitHub
parent 0d4f614fda
commit 5a1dd297c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 457 additions and 122 deletions

1
changelog.d/8059.feature Normal file
View File

@ -0,0 +1 @@
Add unread messages count to sync responses, as specified in [MSC2654](https://github.com/matrix-org/matrix-doc/pull/2654).

View File

@ -95,7 +95,12 @@ class TimelineBatch:
__bool__ = __nonzero__ # python3 __bool__ = __nonzero__ # python3
@attr.s(slots=True, frozen=True) # We can't freeze this class, because we need to update it after it's instantiated to
# update its unread count. This is because we calculate the unread count for a room only
# if there are updates for it, which we check after the instance has been created.
# This should not be a big deal because we update the notification counts afterwards as
# well anyway.
@attr.s(slots=True)
class JoinedSyncResult: class JoinedSyncResult:
room_id = attr.ib(type=str) room_id = attr.ib(type=str)
timeline = attr.ib(type=TimelineBatch) timeline = attr.ib(type=TimelineBatch)
@ -104,6 +109,7 @@ class JoinedSyncResult:
account_data = attr.ib(type=List[JsonDict]) account_data = attr.ib(type=List[JsonDict])
unread_notifications = attr.ib(type=JsonDict) unread_notifications = attr.ib(type=JsonDict)
summary = attr.ib(type=Optional[JsonDict]) summary = attr.ib(type=Optional[JsonDict])
unread_count = attr.ib(type=int)
def __nonzero__(self) -> bool: def __nonzero__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used """Make the result appear empty if there are no updates. This is used
@ -931,7 +937,7 @@ class SyncHandler(object):
async def unread_notifs_for_room_id( async def unread_notifs_for_room_id(
self, room_id: str, sync_config: SyncConfig self, room_id: str, sync_config: SyncConfig
) -> Optional[Dict[str, str]]: ) -> Dict[str, int]:
with Measure(self.clock, "unread_notifs_for_room_id"): with Measure(self.clock, "unread_notifs_for_room_id"):
last_unread_event_id = await self.store.get_last_receipt_event_id_for_user( last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
user_id=sync_config.user.to_string(), user_id=sync_config.user.to_string(),
@ -939,15 +945,10 @@ class SyncHandler(object):
receipt_type="m.read", receipt_type="m.read",
) )
if last_unread_event_id: notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
notifs = await self.store.get_unread_event_push_actions_by_room_for_user( room_id, sync_config.user.to_string(), last_unread_event_id
room_id, sync_config.user.to_string(), last_unread_event_id )
) return notifs
return notifs
# There is no new information in this period, so your notification
# count is whatever it was last time.
return None
async def generate_sync_result( async def generate_sync_result(
self, self,
@ -1886,7 +1887,7 @@ class SyncHandler(object):
) )
if room_builder.rtype == "joined": if room_builder.rtype == "joined":
unread_notifications = {} # type: Dict[str, str] unread_notifications = {} # type: Dict[str, int]
room_sync = JoinedSyncResult( room_sync = JoinedSyncResult(
room_id=room_id, room_id=room_id,
timeline=batch, timeline=batch,
@ -1895,14 +1896,16 @@ class SyncHandler(object):
account_data=account_data_events, account_data=account_data_events,
unread_notifications=unread_notifications, unread_notifications=unread_notifications,
summary=summary, summary=summary,
unread_count=0,
) )
if room_sync or always_include: if room_sync or always_include:
notifs = await self.unread_notifs_for_room_id(room_id, sync_config) notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
if notifs is not None: unread_notifications["notification_count"] = notifs["notify_count"]
unread_notifications["notification_count"] = notifs["notify_count"] unread_notifications["highlight_count"] = notifs["highlight_count"]
unread_notifications["highlight_count"] = notifs["highlight_count"]
room_sync.unread_count = notifs["unread_count"]
sync_result_builder.joined.append(room_sync) sync_result_builder.joined.append(room_sync)

View File

@ -19,8 +19,10 @@ from collections import namedtuple
from prometheus_client import Counter from prometheus_client import Counter
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership, RelationTypes
from synapse.event_auth import get_user_power_level from synapse.event_auth import get_user_power_level
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY from synapse.state import POWER_KEY
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches import register_cache from synapse.util.caches import register_cache
@ -51,6 +53,48 @@ push_rules_delta_state_cache_metric = register_cache(
) )
STATE_EVENT_TYPES_TO_MARK_UNREAD = {
EventTypes.Topic,
EventTypes.Name,
EventTypes.RoomAvatar,
EventTypes.Tombstone,
}
def _should_count_as_unread(event: EventBase, context: EventContext) -> bool:
# Exclude rejected and soft-failed events.
if context.rejected or event.internal_metadata.is_soft_failed():
return False
# Exclude notices.
if (
not event.is_state()
and event.type == EventTypes.Message
and event.content.get("msgtype") == "m.notice"
):
return False
# Exclude edits.
relates_to = event.content.get("m.relates_to", {})
if relates_to.get("rel_type") == RelationTypes.REPLACE:
return False
# Mark events that have a non-empty string body as unread.
body = event.content.get("body")
if isinstance(body, str) and body:
return True
# Mark some state events as unread.
if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD:
return True
# Mark encrypted events as unread.
if not event.is_state() and event.type == EventTypes.Encrypted:
return True
return False
class BulkPushRuleEvaluator(object): class BulkPushRuleEvaluator(object):
"""Calculates the outcome of push rules for an event for all users in the """Calculates the outcome of push rules for an event for all users in the
room at once. room at once.
@ -133,9 +177,12 @@ class BulkPushRuleEvaluator(object):
return pl_event.content if pl_event else {}, sender_level return pl_event.content if pl_event else {}, sender_level
async def action_for_event_by_user(self, event, context) -> None: async def action_for_event_by_user(self, event, context) -> None:
"""Given an event and context, evaluate the push rules and insert the """Given an event and context, evaluate the push rules, check if the message
results into the event_push_actions_staging table. should increment the unread count, and insert the results into the
event_push_actions_staging table.
""" """
count_as_unread = _should_count_as_unread(event, context)
rules_by_user = await self._get_rules_for_event(event, context) rules_by_user = await self._get_rules_for_event(event, context)
actions_by_user = {} actions_by_user = {}
@ -172,6 +219,8 @@ class BulkPushRuleEvaluator(object):
if event.type == EventTypes.Member and event.state_key == uid: if event.type == EventTypes.Member and event.state_key == uid:
display_name = event.content.get("displayname", None) display_name = event.content.get("displayname", None)
actions_by_user[uid] = []
for rule in rules: for rule in rules:
if "enabled" in rule and not rule["enabled"]: if "enabled" in rule and not rule["enabled"]:
continue continue
@ -189,7 +238,9 @@ class BulkPushRuleEvaluator(object):
# Mark in the DB staging area the push actions for users who should be # Mark in the DB staging area the push actions for users who should be
# notified for this event. (This will then get handled when we persist # notified for this event. (This will then get handled when we persist
# the event) # the event)
await self.store.add_push_actions_to_staging(event.event_id, actions_by_user) await self.store.add_push_actions_to_staging(
event.event_id, actions_by_user, count_as_unread,
)
def _condition_checker(evaluator, conditions, uid, display_name, cache): def _condition_checker(evaluator, conditions, uid, display_name, cache):
@ -369,8 +420,8 @@ class RulesForRoom(object):
Args: Args:
ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets
updated with any new rules. updated with any new rules.
member_event_ids (list): List of event ids for membership events that member_event_ids (dict): Dict of user id to event id for membership events
have happened since the last time we filled rules_by_user that have happened since the last time we filled rules_by_user
state_group: The state group we are currently computing push rules state_group: The state group we are currently computing push rules
for. Used when updating the cache. for. Used when updating the cache.
""" """
@ -390,34 +441,19 @@ class RulesForRoom(object):
if logger.isEnabledFor(logging.DEBUG): if logger.isEnabledFor(logging.DEBUG):
logger.debug("Found members %r: %r", self.room_id, members.values()) logger.debug("Found members %r: %r", self.room_id, members.values())
interested_in_user_ids = { user_ids = {
user_id user_id
for user_id, membership in members.values() for user_id, membership in members.values()
if membership == Membership.JOIN if membership == Membership.JOIN
} }
logger.debug("Joined: %r", interested_in_user_ids) logger.debug("Joined: %r", user_ids)
if_users_with_pushers = await self.store.get_if_users_have_pushers( # Previously we only considered users with pushers or read receipts in that
interested_in_user_ids, on_invalidate=self.invalidate_all_cb # room. We can't do this anymore because we use push actions to calculate unread
) # counts, which don't rely on the user having pushers or sent a read receipt into
# the room. Therefore we just need to filter for local users here.
user_ids = { user_ids = list(filter(self.is_mine_id, user_ids))
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
}
logger.debug("With pushers: %r", user_ids)
users_with_receipts = await self.store.get_users_with_read_receipts_in_room(
self.room_id, on_invalidate=self.invalidate_all_cb
)
logger.debug("With receipts: %r", users_with_receipts)
# any users with pushers must be ours: they have pushers
for uid in users_with_receipts:
if uid in interested_in_user_ids:
user_ids.add(uid)
rules_by_user = await self.store.bulk_get_push_rules( rules_by_user = await self.store.bulk_get_push_rules(
user_ids, on_invalidate=self.invalidate_all_cb user_ids, on_invalidate=self.invalidate_all_cb

View File

@ -36,7 +36,7 @@ async def get_badge_count(store, user_id):
) )
# return one badge count per conversation, as count per # return one badge count per conversation, as count per
# message is so noisy as to be almost useless # message is so noisy as to be almost useless
badge += 1 if notifs["notify_count"] else 0 badge += 1 if notifs["unread_count"] else 0
return badge return badge

View File

@ -425,6 +425,7 @@ class SyncRestServlet(RestServlet):
result["ephemeral"] = {"events": ephemeral_events} result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications result["unread_notifications"] = room.unread_notifications
result["summary"] = room.summary result["summary"] = room.summary
result["org.matrix.msc2654.unread_count"] = room.unread_count
return result return result

View File

@ -15,7 +15,9 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict, List, Union from typing import Dict, List, Optional, Tuple, Union
import attr
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
@ -88,8 +90,26 @@ class EventPushActionsWorkerStore(SQLBaseStore):
@cached(num_args=3, tree=True, max_entries=5000) @cached(num_args=3, tree=True, max_entries=5000)
async def get_unread_event_push_actions_by_room_for_user( async def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id self, room_id: str, user_id: str, last_read_event_id: Optional[str],
): ) -> Dict[str, int]:
"""Get the notification count, the highlight count and the unread message count
for a given user in a given room after the given read receipt.
Note that this function assumes the user to be a current member of the room,
since it's either called by the sync handler to handle joined room entries, or by
the HTTP pusher to calculate the badge of unread joined rooms.
Args:
room_id: The room to retrieve the counts in.
user_id: The user to retrieve the counts for.
last_read_event_id: The event associated with the latest read receipt for
this user in this room. None if no receipt for this user in this room.
Returns
A dict containing the counts mentioned earlier in this docstring,
respectively under the keys "notify_count", "highlight_count" and
"unread_count".
"""
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room", "get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn, self._get_unread_counts_by_receipt_txn,
@ -99,69 +119,71 @@ class EventPushActionsWorkerStore(SQLBaseStore):
) )
def _get_unread_counts_by_receipt_txn( def _get_unread_counts_by_receipt_txn(
self, txn, room_id, user_id, last_read_event_id self, txn, room_id, user_id, last_read_event_id,
): ):
sql = ( stream_ordering = None
"SELECT stream_ordering"
" FROM events"
" WHERE room_id = ? AND event_id = ?"
)
txn.execute(sql, (room_id, last_read_event_id))
results = txn.fetchall()
if len(results) == 0:
return {"notify_count": 0, "highlight_count": 0}
stream_ordering = results[0][0] if last_read_event_id is not None:
stream_ordering = self.get_stream_id_for_event_txn(
txn, last_read_event_id, allow_none=True,
)
if stream_ordering is None:
# Either last_read_event_id is None, or it's an event we don't have (e.g.
# because it's been purged), in which case retrieve the stream ordering for
# the latest membership event from this user in this room (which we assume is
# a join).
event_id = self.db_pool.simple_select_one_onecol_txn(
txn=txn,
table="local_current_membership",
keyvalues={"room_id": room_id, "user_id": user_id},
retcol="event_id",
)
stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
return self._get_unread_counts_by_pos_txn( return self._get_unread_counts_by_pos_txn(
txn, room_id, user_id, stream_ordering txn, room_id, user_id, stream_ordering
) )
def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering): def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering):
# First get number of notifications.
# We don't need to put a notif=1 clause as all rows always have
# notif=1
sql = ( sql = (
"SELECT count(*)" "SELECT"
" COUNT(CASE WHEN notif = 1 THEN 1 END),"
" COUNT(CASE WHEN highlight = 1 THEN 1 END),"
" COUNT(CASE WHEN unread = 1 THEN 1 END)"
" FROM event_push_actions ea" " FROM event_push_actions ea"
" WHERE" " WHERE user_id = ?"
" user_id = ?" " AND room_id = ?"
" AND room_id = ?" " AND stream_ordering > ?"
" AND stream_ordering > ?"
) )
txn.execute(sql, (user_id, room_id, stream_ordering)) txn.execute(sql, (user_id, room_id, stream_ordering))
row = txn.fetchone() row = txn.fetchone()
notify_count = row[0] if row else 0
(notif_count, highlight_count, unread_count) = (0, 0, 0)
if row:
(notif_count, highlight_count, unread_count) = row
txn.execute( txn.execute(
""" """
SELECT notif_count FROM event_push_summary SELECT notif_count, unread_count FROM event_push_summary
WHERE room_id = ? AND user_id = ? AND stream_ordering > ? WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
""", """,
(room_id, user_id, stream_ordering), (room_id, user_id, stream_ordering),
) )
rows = txn.fetchall()
if rows:
notify_count += rows[0][0]
# Now get the number of highlights
sql = (
"SELECT count(*)"
" FROM event_push_actions ea"
" WHERE"
" highlight = 1"
" AND user_id = ?"
" AND room_id = ?"
" AND stream_ordering > ?"
)
txn.execute(sql, (user_id, room_id, stream_ordering))
row = txn.fetchone() row = txn.fetchone()
highlight_count = row[0] if row else 0
return {"notify_count": notify_count, "highlight_count": highlight_count} if row:
notif_count += row[0]
unread_count += row[1]
return {
"notify_count": notif_count,
"unread_count": unread_count,
"highlight_count": highlight_count,
}
async def get_push_action_users_in_range( async def get_push_action_users_in_range(
self, min_stream_ordering, max_stream_ordering self, min_stream_ordering, max_stream_ordering
@ -222,6 +244,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?" " AND ep.user_id = ?"
" AND ep.stream_ordering > ?" " AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?" " AND ep.stream_ordering <= ?"
" AND ep.notif = 1"
" ORDER BY ep.stream_ordering ASC LIMIT ?" " ORDER BY ep.stream_ordering ASC LIMIT ?"
) )
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@ -250,6 +273,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?" " AND ep.user_id = ?"
" AND ep.stream_ordering > ?" " AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?" " AND ep.stream_ordering <= ?"
" AND ep.notif = 1"
" ORDER BY ep.stream_ordering ASC LIMIT ?" " ORDER BY ep.stream_ordering ASC LIMIT ?"
) )
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@ -324,6 +348,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?" " AND ep.user_id = ?"
" AND ep.stream_ordering > ?" " AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?" " AND ep.stream_ordering <= ?"
" AND ep.notif = 1"
" ORDER BY ep.stream_ordering DESC LIMIT ?" " ORDER BY ep.stream_ordering DESC LIMIT ?"
) )
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@ -352,6 +377,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?" " AND ep.user_id = ?"
" AND ep.stream_ordering > ?" " AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?" " AND ep.stream_ordering <= ?"
" AND ep.notif = 1"
" ORDER BY ep.stream_ordering DESC LIMIT ?" " ORDER BY ep.stream_ordering DESC LIMIT ?"
) )
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@ -402,7 +428,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
def _get_if_maybe_push_in_range_for_user_txn(txn): def _get_if_maybe_push_in_range_for_user_txn(txn):
sql = """ sql = """
SELECT 1 FROM event_push_actions SELECT 1 FROM event_push_actions
WHERE user_id = ? AND stream_ordering > ? WHERE user_id = ? AND stream_ordering > ? AND notif = 1
LIMIT 1 LIMIT 1
""" """
@ -415,7 +441,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
) )
async def add_push_actions_to_staging( async def add_push_actions_to_staging(
self, event_id: str, user_id_actions: Dict[str, List[Union[dict, str]]] self,
event_id: str,
user_id_actions: Dict[str, List[Union[dict, str]]],
count_as_unread: bool,
) -> None: ) -> None:
"""Add the push actions for the event to the push action staging area. """Add the push actions for the event to the push action staging area.
@ -423,21 +452,23 @@ class EventPushActionsWorkerStore(SQLBaseStore):
event_id event_id
user_id_actions: A mapping of user_id to list of push actions, where user_id_actions: A mapping of user_id to list of push actions, where
an action can either be a string or dict. an action can either be a string or dict.
count_as_unread: Whether this event should increment unread counts.
""" """
if not user_id_actions: if not user_id_actions:
return return
# This is a helper function for generating the necessary tuple that # This is a helper function for generating the necessary tuple that
# can be used to inert into the `event_push_actions_staging` table. # can be used to insert into the `event_push_actions_staging` table.
def _gen_entry(user_id, actions): def _gen_entry(user_id, actions):
is_highlight = 1 if _action_has_highlight(actions) else 0 is_highlight = 1 if _action_has_highlight(actions) else 0
notif = 1 if "notify" in actions else 0
return ( return (
event_id, # event_id column event_id, # event_id column
user_id, # user_id column user_id, # user_id column
_serialize_action(actions, is_highlight), # actions column _serialize_action(actions, is_highlight), # actions column
1, # notif column notif, # notif column
is_highlight, # highlight column is_highlight, # highlight column
int(count_as_unread), # unread column
) )
def _add_push_actions_to_staging_txn(txn): def _add_push_actions_to_staging_txn(txn):
@ -446,8 +477,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
sql = """ sql = """
INSERT INTO event_push_actions_staging INSERT INTO event_push_actions_staging
(event_id, user_id, actions, notif, highlight) (event_id, user_id, actions, notif, highlight, unread)
VALUES (?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?)
""" """
txn.executemany( txn.executemany(
@ -811,24 +842,63 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
# Calculate the new counts that should be upserted into event_push_summary # Calculate the new counts that should be upserted into event_push_summary
sql = """ sql = """
SELECT user_id, room_id, SELECT user_id, room_id,
coalesce(old.notif_count, 0) + upd.notif_count, coalesce(old.%s, 0) + upd.cnt,
upd.stream_ordering, upd.stream_ordering,
old.user_id old.user_id
FROM ( FROM (
SELECT user_id, room_id, count(*) as notif_count, SELECT user_id, room_id, count(*) as cnt,
max(stream_ordering) as stream_ordering max(stream_ordering) as stream_ordering
FROM event_push_actions FROM event_push_actions
WHERE ? <= stream_ordering AND stream_ordering < ? WHERE ? <= stream_ordering AND stream_ordering < ?
AND highlight = 0 AND highlight = 0
AND %s = 1
GROUP BY user_id, room_id GROUP BY user_id, room_id
) AS upd ) AS upd
LEFT JOIN event_push_summary AS old USING (user_id, room_id) LEFT JOIN event_push_summary AS old USING (user_id, room_id)
""" """
txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering)) # First get the count of unread messages.
rows = txn.fetchall() txn.execute(
sql % ("unread_count", "unread"),
(old_rotate_stream_ordering, rotate_to_stream_ordering),
)
logger.info("Rotating notifications, handling %d rows", len(rows)) # We need to merge results from the two requests (the one that retrieves the
# unread count and the one that retrieves the notifications count) into a single
# object because we might not have the same amount of rows in each of them. To do
# this, we use a dict indexed on the user ID and room ID to make it easier to
# populate.
summaries = {} # type: Dict[Tuple[str, str], _EventPushSummary]
for row in txn:
summaries[(row[0], row[1])] = _EventPushSummary(
unread_count=row[2],
stream_ordering=row[3],
old_user_id=row[4],
notif_count=0,
)
# Then get the count of notifications.
txn.execute(
sql % ("notif_count", "notif"),
(old_rotate_stream_ordering, rotate_to_stream_ordering),
)
for row in txn:
if (row[0], row[1]) in summaries:
summaries[(row[0], row[1])].notif_count = row[2]
else:
# Because the rules on notifying are different than the rules on marking
# a message unread, we might end up with messages that notify but aren't
# marked unread, so we might not have a summary for this (user, room)
# tuple to complete.
summaries[(row[0], row[1])] = _EventPushSummary(
unread_count=0,
stream_ordering=row[3],
old_user_id=row[4],
notif_count=row[2],
)
logger.info("Rotating notifications, handling %d rows", len(summaries))
# If the `old.user_id` above is NULL then we know there isn't already an # If the `old.user_id` above is NULL then we know there isn't already an
# entry in the table, so we simply insert it. Otherwise we update the # entry in the table, so we simply insert it. Otherwise we update the
@ -838,22 +908,34 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
table="event_push_summary", table="event_push_summary",
values=[ values=[
{ {
"user_id": row[0], "user_id": user_id,
"room_id": row[1], "room_id": room_id,
"notif_count": row[2], "notif_count": summary.notif_count,
"stream_ordering": row[3], "unread_count": summary.unread_count,
"stream_ordering": summary.stream_ordering,
} }
for row in rows for ((user_id, room_id), summary) in summaries.items()
if row[4] is None if summary.old_user_id is None
], ],
) )
txn.executemany( txn.executemany(
""" """
UPDATE event_push_summary SET notif_count = ?, stream_ordering = ? UPDATE event_push_summary
SET notif_count = ?, unread_count = ?, stream_ordering = ?
WHERE user_id = ? AND room_id = ? WHERE user_id = ? AND room_id = ?
""", """,
((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None), (
(
summary.notif_count,
summary.unread_count,
summary.stream_ordering,
user_id,
room_id,
)
for ((user_id, room_id), summary) in summaries.items()
if summary.old_user_id is not None
),
) )
txn.execute( txn.execute(
@ -879,3 +961,15 @@ def _action_has_highlight(actions):
pass pass
return False return False
@attr.s
class _EventPushSummary:
"""Summary of pending event push actions for a given user in a given room.
Used in _rotate_notifs_before_txn to manipulate results from event_push_actions.
"""
unread_count = attr.ib(type=int)
stream_ordering = attr.ib(type=int)
old_user_id = attr.ib(type=str)
notif_count = attr.ib(type=int)

View File

@ -1298,9 +1298,9 @@ class PersistEventsStore:
sql = """ sql = """
INSERT INTO event_push_actions ( INSERT INTO event_push_actions (
room_id, event_id, user_id, actions, stream_ordering, room_id, event_id, user_id, actions, stream_ordering,
topological_ordering, notif, highlight topological_ordering, notif, highlight, unread
) )
SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread
FROM event_push_actions_staging FROM event_push_actions_staging
WHERE event_id = ? WHERE event_id = ?
""" """

View File

@ -0,0 +1,26 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- We're hijacking the push actions to store unread messages and unread counts (specified
-- in MSC2654) because doing otherwise would result in either performance issues or
-- reimplementing a consequent bit of the push actions.
-- Add columns to event_push_actions and event_push_actions_staging to track unread
-- messages and calculate unread counts.
ALTER TABLE event_push_actions_staging ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0;
ALTER TABLE event_push_actions ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0;
-- Add column to event_push_summary
ALTER TABLE event_push_summary ADD COLUMN unread_count BIGINT NOT NULL DEFAULT 0;

View File

@ -47,7 +47,11 @@ from synapse.api.filtering import Filter
from synapse.events import EventBase from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_in_list_sql_clause from synapse.storage.database import (
DatabasePool,
LoggingTransaction,
make_in_list_sql_clause,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
@ -593,8 +597,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns: Returns:
A stream ID. A stream ID.
""" """
return await self.db_pool.simple_select_one_onecol( return await self.db_pool.runInteraction(
table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering" "get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id,
)
def get_stream_id_for_event_txn(
self, txn: LoggingTransaction, event_id: str, allow_none=False,
) -> int:
return self.db_pool.simple_select_one_onecol_txn(
txn=txn,
table="events",
keyvalues={"event_id": event_id},
retcol="stream_ordering",
allow_none=allow_none,
) )
async def get_stream_token_for_event(self, event_id: str) -> str: async def get_stream_token_for_event(self, event_id: str) -> str:

View File

@ -160,7 +160,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check( self.check(
"get_unread_event_push_actions_by_room_for_user", "get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id], [ROOM_ID, USER_ID_2, event1.event_id],
{"highlight_count": 0, "notify_count": 0}, {"highlight_count": 0, "unread_count": 0, "notify_count": 0},
) )
self.persist( self.persist(
@ -173,7 +173,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check( self.check(
"get_unread_event_push_actions_by_room_for_user", "get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id], [ROOM_ID, USER_ID_2, event1.event_id],
{"highlight_count": 0, "notify_count": 1}, {"highlight_count": 0, "unread_count": 0, "notify_count": 1},
) )
self.persist( self.persist(
@ -188,7 +188,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check( self.check(
"get_unread_event_push_actions_by_room_for_user", "get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id], [ROOM_ID, USER_ID_2, event1.event_id],
{"highlight_count": 1, "notify_count": 2}, {"highlight_count": 1, "unread_count": 0, "notify_count": 2},
) )
def test_get_rooms_for_user_with_stream_ordering(self): def test_get_rooms_for_user_with_stream_ordering(self):
@ -368,7 +368,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.get_success( self.get_success(
self.master_store.add_push_actions_to_staging( self.master_store.add_push_actions_to_staging(
event.event_id, {user_id: actions for user_id, actions in push_actions} event.event_id,
{user_id: actions for user_id, actions in push_actions},
False,
) )
) )
return event, context return event, context

View File

@ -16,9 +16,9 @@
import json import json
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.rest.client.v1 import login, room from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import sync from synapse.rest.client.v2_alpha import read_marker, sync
from tests import unittest from tests import unittest
from tests.server import TimedOutException from tests.server import TimedOutException
@ -324,3 +324,156 @@ class SyncTypingTests(unittest.HomeserverTestCase):
"GET", sync_url % (access_token, next_batch) "GET", sync_url % (access_token, next_batch)
) )
self.assertRaises(TimedOutException, self.render, request) self.assertRaises(TimedOutException, self.render, request)
class UnreadMessagesTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
read_marker.register_servlets,
room.register_servlets,
sync.register_servlets,
]
def prepare(self, reactor, clock, hs):
self.url = "/sync?since=%s"
self.next_batch = "s0"
# Register the first user (used to check the unread counts).
self.user_id = self.register_user("kermit", "monkey")
self.tok = self.login("kermit", "monkey")
# Create the room we'll check unread counts for.
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
# Register the second user (used to send events to the room).
self.user2 = self.register_user("kermit2", "monkey")
self.tok2 = self.login("kermit2", "monkey")
# Change the power levels of the room so that the second user can send state
# events.
self.helper.send_state(
self.room_id,
EventTypes.PowerLevels,
{
"users": {self.user_id: 100, self.user2: 100},
"users_default": 0,
"events": {
"m.room.name": 50,
"m.room.power_levels": 100,
"m.room.history_visibility": 100,
"m.room.canonical_alias": 50,
"m.room.avatar": 50,
"m.room.tombstone": 100,
"m.room.server_acl": 100,
"m.room.encryption": 100,
},
"events_default": 0,
"state_default": 50,
"ban": 50,
"kick": 50,
"redact": 50,
"invite": 0,
},
tok=self.tok,
)
def test_unread_counts(self):
"""Tests that /sync returns the right value for the unread count (MSC2654)."""
# Check that our own messages don't increase the unread count.
self.helper.send(self.room_id, "hello", tok=self.tok)
self._check_unread_count(0)
# Join the new user and check that this doesn't increase the unread count.
self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
self._check_unread_count(0)
# Check that the new user sending a message increases our unread count.
res = self.helper.send(self.room_id, "hello", tok=self.tok2)
self._check_unread_count(1)
# Send a read receipt to tell the server we've read the latest event.
body = json.dumps({"m.read": res["event_id"]}).encode("utf8")
request, channel = self.make_request(
"POST",
"/rooms/%s/read_markers" % self.room_id,
body,
access_token=self.tok,
)
self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
# Check that the unread counter is back to 0.
self._check_unread_count(0)
# Check that room name changes increase the unread counter.
self.helper.send_state(
self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2,
)
self._check_unread_count(1)
# Check that room topic changes increase the unread counter.
self.helper.send_state(
self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2,
)
self._check_unread_count(2)
# Check that encrypted messages increase the unread counter.
self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2)
self._check_unread_count(3)
# Check that custom events with a body increase the unread counter.
self.helper.send_event(
self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2,
)
self._check_unread_count(4)
# Check that edits don't increase the unread counter.
self.helper.send_event(
room_id=self.room_id,
type=EventTypes.Message,
content={
"body": "hello",
"msgtype": "m.text",
"m.relates_to": {"rel_type": RelationTypes.REPLACE},
},
tok=self.tok2,
)
self._check_unread_count(4)
# Check that notices don't increase the unread counter.
self.helper.send_event(
room_id=self.room_id,
type=EventTypes.Message,
content={"body": "hello", "msgtype": "m.notice"},
tok=self.tok2,
)
self._check_unread_count(4)
# Check that tombstone events changes increase the unread counter.
self.helper.send_state(
self.room_id,
EventTypes.Tombstone,
{"replacement_room": "!someroom:test"},
tok=self.tok2,
)
self._check_unread_count(5)
def _check_unread_count(self, expected_count: True):
"""Syncs and compares the unread count with the expected value."""
request, channel = self.make_request(
"GET", self.url % self.next_batch, access_token=self.tok,
)
self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
room_entry = channel.json_body["rooms"]["join"][self.room_id]
self.assertEqual(
room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry,
)
# Store the next batch for the next request.
self.next_batch = channel.json_body["next_batch"]

View File

@ -67,7 +67,11 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
) )
self.assertEquals( self.assertEquals(
counts, counts,
{"notify_count": noitf_count, "highlight_count": highlight_count}, {
"notify_count": noitf_count,
"unread_count": 0, # Unread counts are tested in the sync tests.
"highlight_count": highlight_count,
},
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -80,7 +84,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
yield defer.ensureDeferred( yield defer.ensureDeferred(
self.store.add_push_actions_to_staging( self.store.add_push_actions_to_staging(
event.event_id, {user_id: action} event.event_id, {user_id: action}, False,
) )
) )
yield defer.ensureDeferred( yield defer.ensureDeferred(