Track notification counts per thread (implement MSC3773). (#13776)

When retrieving counts of notifications segment the results based on the
thread ID, but choose whether to return them as individual threads or as
a single summed field by letting the client opt-in via a sync flag.

The summarization code is also updated to be per thread, instead of per
room.
This commit is contained in:
Patrick Cloke 2022-10-04 09:47:04 -04:00 committed by GitHub
parent 94017e867d
commit b4ec4f5e71
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 513 additions and 92 deletions

View file

@ -88,7 +88,7 @@ from typing import (
import attr
from synapse.api.constants import ReceiptTypes
from synapse.api.constants import MAIN_TIMELINE, ReceiptTypes
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
@ -157,7 +157,7 @@ class UserPushAction(EmailPushAction):
@attr.s(slots=True, auto_attribs=True)
class NotifCounts:
"""
The per-user, per-room count of notifications. Used by sync and push.
The per-user, per-room, per-thread count of notifications. Used by sync and push.
"""
notify_count: int = 0
@ -165,6 +165,21 @@ class NotifCounts:
highlight_count: int = 0
@attr.s(slots=True, auto_attribs=True)
class RoomNotifCounts:
"""
The per-user, per-room count of notifications. Used by sync and push.
"""
main_timeline: NotifCounts
# Map of thread ID to the notification counts.
threads: Dict[str, NotifCounts]
def __len__(self) -> int:
# To properly account for the amount of space in any caches.
return len(self.threads) + 1
def _serialize_action(
actions: Collection[Union[Mapping, str]], is_highlight: bool
) -> str:
@ -338,12 +353,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
return result
@cached(tree=True, max_entries=5000)
@cached(tree=True, max_entries=5000, iterable=True)
async def get_unread_event_push_actions_by_room_for_user(
self,
room_id: str,
user_id: str,
) -> NotifCounts:
) -> RoomNotifCounts:
"""Get the notification count, the highlight count and the unread message count
for a given user in a given room after their latest read receipt.
@ -356,8 +371,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id: The user to retrieve the counts for.
Returns
A NotifCounts object containing the notification count, the highlight count
and the unread message count.
A RoomNotifCounts object containing the notification count, the
highlight count and the unread message count for both the main timeline
and threads.
"""
return await self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room",
@ -371,7 +387,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
txn: LoggingTransaction,
room_id: str,
user_id: str,
) -> NotifCounts:
) -> RoomNotifCounts:
# Get the stream ordering of the user's latest receipt in the room.
result = self.get_last_unthreaded_receipt_for_user_txn(
txn,
@ -406,7 +422,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
room_id: str,
user_id: str,
receipt_stream_ordering: int,
) -> NotifCounts:
) -> RoomNotifCounts:
"""Get the number of unread messages for a user/room that have happened
since the given stream ordering.
@ -418,12 +434,19 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
receipt in the room. If there are no receipts, the stream ordering
of the user's join event.
Returns
A NotifCounts object containing the notification count, the highlight count
and the unread message count.
Returns:
A RoomNotifCounts object containing the notification count, the
highlight count and the unread message count for both the main timeline
and threads.
"""
counts = NotifCounts()
main_counts = NotifCounts()
thread_counts: Dict[str, NotifCounts] = {}
def _get_thread(thread_id: str) -> NotifCounts:
if thread_id == MAIN_TIMELINE:
return main_counts
return thread_counts.setdefault(thread_id, NotifCounts())
# First we pull the counts from the summary table.
#
@ -440,52 +463,61 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# receipt).
txn.execute(
"""
SELECT stream_ordering, notif_count, COALESCE(unread_count, 0)
SELECT stream_ordering, notif_count, COALESCE(unread_count, 0), thread_id
FROM event_push_summary
WHERE room_id = ? AND user_id = ?
AND (
(last_receipt_stream_ordering IS NULL AND stream_ordering > ?)
OR last_receipt_stream_ordering = ?
)
) AND (notif_count != 0 OR COALESCE(unread_count, 0) != 0)
""",
(room_id, user_id, receipt_stream_ordering, receipt_stream_ordering),
)
row = txn.fetchone()
max_summary_stream_ordering = 0
for summary_stream_ordering, notif_count, unread_count, thread_id in txn:
counts = _get_thread(thread_id)
counts.notify_count += notif_count
counts.unread_count += unread_count
summary_stream_ordering = 0
if row:
summary_stream_ordering = row[0]
counts.notify_count += row[1]
counts.unread_count += row[2]
# Summaries will only be used if they have not been invalidated by
# a recent receipt; track the latest stream ordering or a valid summary.
#
# Note that since there's only one read receipt in the room per user,
# valid summaries are contiguous.
max_summary_stream_ordering = max(
summary_stream_ordering, max_summary_stream_ordering
)
# Next we need to count highlights, which aren't summarised
sql = """
SELECT COUNT(*) FROM event_push_actions
SELECT COUNT(*), thread_id FROM event_push_actions
WHERE user_id = ?
AND room_id = ?
AND stream_ordering > ?
AND highlight = 1
GROUP BY thread_id
"""
txn.execute(sql, (user_id, room_id, receipt_stream_ordering))
row = txn.fetchone()
if row:
counts.highlight_count += row[0]
for highlight_count, thread_id in txn:
_get_thread(thread_id).highlight_count += highlight_count
# Finally we need to count push actions that aren't included in the
# summary returned above. This might be due to recent events that haven't
# been summarised yet or the summary is out of date due to a recent read
# receipt.
start_unread_stream_ordering = max(
receipt_stream_ordering, summary_stream_ordering
receipt_stream_ordering, max_summary_stream_ordering
)
notify_count, unread_count = self._get_notif_unread_count_for_user_room(
unread_counts = self._get_notif_unread_count_for_user_room(
txn, room_id, user_id, start_unread_stream_ordering
)
counts.notify_count += notify_count
counts.unread_count += unread_count
for notif_count, unread_count, thread_id in unread_counts:
counts = _get_thread(thread_id)
counts.notify_count += notif_count
counts.unread_count += unread_count
return counts
return RoomNotifCounts(main_counts, thread_counts)
def _get_notif_unread_count_for_user_room(
self,
@ -494,7 +526,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id: str,
stream_ordering: int,
max_stream_ordering: Optional[int] = None,
) -> Tuple[int, int]:
) -> List[Tuple[int, int, str]]:
"""Returns the notify and unread counts from `event_push_actions` for
the given user/room in the given range.
@ -510,13 +542,14 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
If this is not given, then no maximum is applied.
Return:
A tuple of the notif count and unread count in the given range.
A tuple of the notif count and unread count in the given range for
each thread.
"""
# If there have been no events in the room since the stream ordering,
# there can't be any push actions either.
if not self._events_stream_cache.has_entity_changed(room_id, stream_ordering):
return 0, 0
return []
clause = ""
args = [user_id, room_id, stream_ordering]
@ -527,26 +560,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# If the max stream ordering is less than the min stream ordering,
# then obviously there are zero push actions in that range.
if max_stream_ordering <= stream_ordering:
return 0, 0
return []
sql = f"""
SELECT
COUNT(CASE WHEN notif = 1 THEN 1 END),
COUNT(CASE WHEN unread = 1 THEN 1 END)
FROM event_push_actions ea
WHERE user_id = ?
COUNT(CASE WHEN unread = 1 THEN 1 END),
thread_id
FROM event_push_actions ea
WHERE user_id = ?
AND room_id = ?
AND ea.stream_ordering > ?
{clause}
GROUP BY thread_id
"""
txn.execute(sql, args)
row = txn.fetchone()
if row:
return cast(Tuple[int, int], row)
return 0, 0
return cast(List[Tuple[int, int, str]], txn.fetchall())
async def get_push_action_users_in_range(
self, min_stream_ordering: int, max_stream_ordering: int
@ -1099,26 +1129,34 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# Fetch the notification counts between the stream ordering of the
# latest receipt and what was previously summarised.
notif_count, unread_count = self._get_notif_unread_count_for_user_room(
unread_counts = self._get_notif_unread_count_for_user_room(
txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering
)
# Replace the previous summary with the new counts.
#
# TODO(threads): Upsert per-thread instead of setting them all to main.
self.db_pool.simple_upsert_txn(
# First mark the summary for all threads in the room as cleared.
self.db_pool.simple_update_txn(
txn,
table="event_push_summary",
keyvalues={"room_id": room_id, "user_id": user_id},
values={
"notif_count": notif_count,
"unread_count": unread_count,
keyvalues={"user_id": user_id, "room_id": room_id},
updatevalues={
"notif_count": 0,
"unread_count": 0,
"stream_ordering": old_rotate_stream_ordering,
"last_receipt_stream_ordering": stream_ordering,
"thread_id": "main",
},
)
# Then any updated threads get their notification count and unread
# count updated.
self.db_pool.simple_update_many_txn(
txn,
table="event_push_summary",
key_names=("room_id", "user_id", "thread_id"),
key_values=[(room_id, user_id, row[2]) for row in unread_counts],
value_names=("notif_count", "unread_count"),
value_values=[(row[0], row[1]) for row in unread_counts],
)
# We always update `event_push_summary_last_receipt_stream_id` to
# ensure that we don't rescan the same receipts for remote users.
@ -1204,23 +1242,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# Calculate the new counts that should be upserted into event_push_summary
sql = """
SELECT user_id, room_id,
SELECT user_id, room_id, thread_id,
coalesce(old.%s, 0) + upd.cnt,
upd.stream_ordering
FROM (
SELECT user_id, room_id, count(*) as cnt,
SELECT user_id, room_id, thread_id, count(*) as cnt,
max(ea.stream_ordering) as stream_ordering
FROM event_push_actions AS ea
LEFT JOIN event_push_summary AS old USING (user_id, room_id)
LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id)
WHERE ? < ea.stream_ordering AND ea.stream_ordering <= ?
AND (
old.last_receipt_stream_ordering IS NULL
OR old.last_receipt_stream_ordering < ea.stream_ordering
)
AND %s = 1
GROUP BY user_id, room_id
GROUP BY user_id, room_id, thread_id
) AS upd
LEFT JOIN event_push_summary AS old USING (user_id, room_id)
LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id)
"""
# First get the count of unread messages.
@ -1234,11 +1272,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# object because we might not have the same amount of rows in each of them. To do
# this, we use a dict indexed on the user ID and room ID to make it easier to
# populate.
summaries: Dict[Tuple[str, str], _EventPushSummary] = {}
summaries: Dict[Tuple[str, str, str], _EventPushSummary] = {}
for row in txn:
summaries[(row[0], row[1])] = _EventPushSummary(
unread_count=row[2],
stream_ordering=row[3],
summaries[(row[0], row[1], row[2])] = _EventPushSummary(
unread_count=row[3],
stream_ordering=row[4],
notif_count=0,
)
@ -1249,34 +1287,35 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
)
for row in txn:
if (row[0], row[1]) in summaries:
summaries[(row[0], row[1])].notif_count = row[2]
if (row[0], row[1], row[2]) in summaries:
summaries[(row[0], row[1], row[2])].notif_count = row[3]
else:
# Because the rules on notifying are different than the rules on marking
# a message unread, we might end up with messages that notify but aren't
# marked unread, so we might not have a summary for this (user, room)
# tuple to complete.
summaries[(row[0], row[1])] = _EventPushSummary(
summaries[(row[0], row[1], row[2])] = _EventPushSummary(
unread_count=0,
stream_ordering=row[3],
notif_count=row[2],
stream_ordering=row[4],
notif_count=row[3],
)
logger.info("Rotating notifications, handling %d rows", len(summaries))
# TODO(threads): Update on a per-thread basis.
self.db_pool.simple_upsert_many_txn(
txn,
table="event_push_summary",
key_names=("user_id", "room_id"),
key_values=[(user_id, room_id) for user_id, room_id in summaries],
value_names=("notif_count", "unread_count", "stream_ordering", "thread_id"),
key_names=("user_id", "room_id", "thread_id"),
key_values=[
(user_id, room_id, thread_id)
for user_id, room_id, thread_id in summaries
],
value_names=("notif_count", "unread_count", "stream_ordering"),
value_values=[
(
summary.notif_count,
summary.unread_count,
summary.stream_ordering,
"main",
)
for summary in summaries.values()
],
@ -1288,7 +1327,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
)
async def _remove_old_push_actions_that_have_rotated(self) -> None:
"""Clear out old push actions that have been summarised."""
"""
Clear out old push actions that have been summarised (and are older than
1 day ago).
"""
# We want to clear out anything that is older than a day that *has* already
# been rotated.