Aggregate unread notif count query for badge count calculation (#14255)

Fetch the unread notification counts used by the badge counts
in push notifications for all rooms at once (instead of fetching
them per room).
This commit is contained in:
Nick Mills-Barrett 2022-11-30 13:45:06 +00:00 committed by GitHub
parent 4569eda944
commit e8bce8999f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 198 additions and 27 deletions

1
changelog.d/14255.misc Normal file
View File

@ -0,0 +1 @@
Optimise push badge count calculations. Contributed by Nick @ Beeper (@fizzadar).

View File

@ -17,7 +17,6 @@ from synapse.events import EventBase
from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage.controllers import StorageControllers from synapse.storage.controllers import StorageControllers
from synapse.storage.databases.main import DataStore from synapse.storage.databases.main import DataStore
from synapse.util.async_helpers import concurrently_execute
async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int: async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int:
@ -26,23 +25,12 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
badge = len(invites) badge = len(invites)
room_notifs = [] room_to_count = await store.get_unread_counts_by_room_for_user(user_id)
for room_id, notify_count in room_to_count.items():
async def get_room_unread_count(room_id: str) -> None: # room_to_count may include rooms which the user has left,
room_notifs.append( # ignore those.
await store.get_unread_event_push_actions_by_room_for_user( if room_id not in joins:
room_id, continue
user_id,
)
)
await concurrently_execute(get_room_unread_count, joins, 10)
for notifs in room_notifs:
# Combine the counts from all the threads.
notify_count = notifs.main_timeline.notify_count + sum(
n.notify_count for n in notifs.threads.values()
)
if notify_count == 0: if notify_count == 0:
continue continue
@ -51,8 +39,10 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
# return one badge count per conversation # return one badge count per conversation
badge += 1 badge += 1
else: else:
# increment the badge count by the number of unread messages in the room # Increase badge by number of notifications in room
# NOTE: this includes threaded and unthreaded notifications.
badge += notify_count badge += notify_count
return badge return badge

View File

@ -74,6 +74,7 @@ receipt.
""" """
import logging import logging
from collections import defaultdict
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Collection, Collection,
@ -95,6 +96,7 @@ from synapse.storage.database import (
DatabasePool, DatabasePool,
LoggingDatabaseConnection, LoggingDatabaseConnection,
LoggingTransaction, LoggingTransaction,
PostgresEngine,
) )
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.stream import StreamWorkerStore from synapse.storage.databases.main.stream import StreamWorkerStore
@ -463,6 +465,153 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
return result return result
async def get_unread_counts_by_room_for_user(self, user_id: str) -> Dict[str, int]:
"""Get the notification count by room for a user. Only considers notifications,
not highlight or unread counts, and threads are currently aggregated under their room.
This function is intentionally not cached because it is called to calculate the
unread badge for push notifications and thus the result is expected to change.
Note that this function assumes the user is a member of the room. Because
summary rows are not removed when a user leaves a room, the caller must
filter out those results from the result.
Returns:
A map of room ID to notification counts for the given user.
"""
return await self.db_pool.runInteraction(
"get_unread_counts_by_room_for_user",
self._get_unread_counts_by_room_for_user_txn,
user_id,
)
def _get_unread_counts_by_room_for_user_txn(
self, txn: LoggingTransaction, user_id: str
) -> Dict[str, int]:
receipt_types_clause, args = make_in_list_sql_clause(
self.database_engine,
"receipt_type",
(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
)
args.extend([user_id, user_id])
receipts_cte = f"""
WITH all_receipts AS (
SELECT room_id, thread_id, MAX(event_stream_ordering) AS max_receipt_stream_ordering
FROM receipts_linearized
LEFT JOIN events USING (room_id, event_id)
WHERE
{receipt_types_clause}
AND user_id = ?
GROUP BY room_id, thread_id
)
"""
receipts_joins = """
LEFT JOIN (
SELECT room_id, thread_id,
max_receipt_stream_ordering AS threaded_receipt_stream_ordering
FROM all_receipts
WHERE thread_id IS NOT NULL
) AS threaded_receipts USING (room_id, thread_id)
LEFT JOIN (
SELECT room_id, thread_id,
max_receipt_stream_ordering AS unthreaded_receipt_stream_ordering
FROM all_receipts
WHERE thread_id IS NULL
) AS unthreaded_receipts USING (room_id)
"""
# First get summary counts by room / thread for the user. We use the max receipt
# stream ordering of both threaded & unthreaded receipts to compare against the
# summary table.
#
# PostgreSQL and SQLite differ in comparing scalar numerics.
if isinstance(self.database_engine, PostgresEngine):
# GREATEST ignores NULLs.
max_clause = """GREATEST(
threaded_receipt_stream_ordering,
unthreaded_receipt_stream_ordering
)"""
else:
# MAX returns NULL if any are NULL, so COALESCE to 0 first.
max_clause = """MAX(
COALESCE(threaded_receipt_stream_ordering, 0),
COALESCE(unthreaded_receipt_stream_ordering, 0)
)"""
sql = f"""
{receipts_cte}
SELECT eps.room_id, eps.thread_id, notif_count
FROM event_push_summary AS eps
{receipts_joins}
WHERE user_id = ?
AND notif_count != 0
AND (
(last_receipt_stream_ordering IS NULL AND stream_ordering > {max_clause})
OR last_receipt_stream_ordering = {max_clause}
)
"""
txn.execute(sql, args)
seen_thread_ids = set()
room_to_count: Dict[str, int] = defaultdict(int)
for room_id, thread_id, notif_count in txn:
room_to_count[room_id] += notif_count
seen_thread_ids.add(thread_id)
# Now get any event push actions that haven't been rotated using the same OR
# join and filter by receipt and event push summary rotated up to stream ordering.
sql = f"""
{receipts_cte}
SELECT epa.room_id, epa.thread_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count
FROM event_push_actions AS epa
{receipts_joins}
WHERE user_id = ?
AND epa.notif = 1
AND stream_ordering > (SELECT stream_ordering FROM event_push_summary_stream_ordering)
AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering)
AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering)
GROUP BY epa.room_id, epa.thread_id
"""
txn.execute(sql, args)
for room_id, thread_id, notif_count in txn:
# Note: only count push actions we have valid summaries for with up to date receipt.
if thread_id not in seen_thread_ids:
continue
room_to_count[room_id] += notif_count
thread_id_clause, thread_ids_args = make_in_list_sql_clause(
self.database_engine, "epa.thread_id", seen_thread_ids
)
# Finally re-check event_push_actions for any rooms not in the summary, ignoring
# the rotated up-to position. This handles the case where a read receipt has arrived
# but not been rotated meaning the summary table is out of date, so we go back to
# the push actions table.
sql = f"""
{receipts_cte}
SELECT epa.room_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count
FROM event_push_actions AS epa
{receipts_joins}
WHERE user_id = ?
AND NOT {thread_id_clause}
AND epa.notif = 1
AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering)
AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering)
GROUP BY epa.room_id
"""
args.extend(thread_ids_args)
txn.execute(sql, args)
for room_id, notif_count in txn:
room_to_count[room_id] += notif_count
return room_to_count
@cached(tree=True, max_entries=5000, iterable=True) @cached(tree=True, max_entries=5000, iterable=True)
async def get_unread_event_push_actions_by_room_for_user( async def get_unread_event_push_actions_by_room_for_user(
self, self,

View File

@ -156,7 +156,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
last_event_id: str last_event_id: str
def _assert_counts(noitf_count: int, highlight_count: int) -> None: def _assert_counts(notif_count: int, highlight_count: int) -> None:
counts = self.get_success( counts = self.get_success(
self.store.db_pool.runInteraction( self.store.db_pool.runInteraction(
"get-unread-counts", "get-unread-counts",
@ -168,13 +168,22 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
self.assertEqual( self.assertEqual(
counts.main_timeline, counts.main_timeline,
NotifCounts( NotifCounts(
notify_count=noitf_count, notify_count=notif_count,
unread_count=0, unread_count=0,
highlight_count=highlight_count, highlight_count=highlight_count,
), ),
) )
self.assertEqual(counts.threads, {}) self.assertEqual(counts.threads, {})
aggregate_counts = self.get_success(
self.store.db_pool.runInteraction(
"get-aggregate-unread-counts",
self.store._get_unread_counts_by_room_for_user_txn,
user_id,
)
)
self.assertEqual(aggregate_counts[room_id], notif_count)
def _create_event(highlight: bool = False) -> str: def _create_event(highlight: bool = False) -> str:
result = self.helper.send_event( result = self.helper.send_event(
room_id, room_id,
@ -283,7 +292,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
last_event_id: str last_event_id: str
def _assert_counts( def _assert_counts(
noitf_count: int, notif_count: int,
highlight_count: int, highlight_count: int,
thread_notif_count: int, thread_notif_count: int,
thread_highlight_count: int, thread_highlight_count: int,
@ -299,7 +308,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
self.assertEqual( self.assertEqual(
counts.main_timeline, counts.main_timeline,
NotifCounts( NotifCounts(
notify_count=noitf_count, notify_count=notif_count,
unread_count=0, unread_count=0,
highlight_count=highlight_count, highlight_count=highlight_count,
), ),
@ -318,6 +327,17 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
else: else:
self.assertEqual(counts.threads, {}) self.assertEqual(counts.threads, {})
aggregate_counts = self.get_success(
self.store.db_pool.runInteraction(
"get-aggregate-unread-counts",
self.store._get_unread_counts_by_room_for_user_txn,
user_id,
)
)
self.assertEqual(
aggregate_counts[room_id], notif_count + thread_notif_count
)
def _create_event( def _create_event(
highlight: bool = False, thread_id: Optional[str] = None highlight: bool = False, thread_id: Optional[str] = None
) -> str: ) -> str:
@ -454,7 +474,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
last_event_id: str last_event_id: str
def _assert_counts( def _assert_counts(
noitf_count: int, notif_count: int,
highlight_count: int, highlight_count: int,
thread_notif_count: int, thread_notif_count: int,
thread_highlight_count: int, thread_highlight_count: int,
@ -470,7 +490,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
self.assertEqual( self.assertEqual(
counts.main_timeline, counts.main_timeline,
NotifCounts( NotifCounts(
notify_count=noitf_count, notify_count=notif_count,
unread_count=0, unread_count=0,
highlight_count=highlight_count, highlight_count=highlight_count,
), ),
@ -489,6 +509,17 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
else: else:
self.assertEqual(counts.threads, {}) self.assertEqual(counts.threads, {})
aggregate_counts = self.get_success(
self.store.db_pool.runInteraction(
"get-aggregate-unread-counts",
self.store._get_unread_counts_by_room_for_user_txn,
user_id,
)
)
self.assertEqual(
aggregate_counts[room_id], notif_count + thread_notif_count
)
def _create_event( def _create_event(
highlight: bool = False, thread_id: Optional[str] = None highlight: bool = False, thread_id: Optional[str] = None
) -> str: ) -> str:
@ -646,7 +677,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
) )
return result["event_id"] return result["event_id"]
def _assert_counts(noitf_count: int, thread_notif_count: int) -> None: def _assert_counts(notif_count: int, thread_notif_count: int) -> None:
counts = self.get_success( counts = self.get_success(
self.store.db_pool.runInteraction( self.store.db_pool.runInteraction(
"get-unread-counts", "get-unread-counts",
@ -658,7 +689,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
self.assertEqual( self.assertEqual(
counts.main_timeline, counts.main_timeline,
NotifCounts( NotifCounts(
notify_count=noitf_count, unread_count=0, highlight_count=0 notify_count=notif_count, unread_count=0, highlight_count=0
), ),
) )
if thread_notif_count: if thread_notif_count: