Speed up get_unread_event_push_actions_by_room (#13005)

Fixes #11887 hopefully.

The core change here is that `event_push_summary` now holds a summary of counts up until a much more recent point, meaning that the range of rows we need to count in `event_push_actions` is much smaller.

This needs two major changes:
1. When we get a receipt we need to recalculate `event_push_summary` rather than just delete it
2. The logic for deleting `event_push_actions` is now divorced from calculating `event_push_summary`.

In future it would be good to calculate `event_push_summary` while we persist a new event (it should just be a case of adding one to the relevant rows in `event_push_summary`), as that will further simplify the get counts logic and remove the need for us to periodically update `event_push_summary` in a background job.
This commit is contained in:
Erik Johnston 2022-06-15 16:17:14 +01:00 committed by GitHub
parent 9ad2197fa7
commit 0d1d3e0708
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 322 additions and 151 deletions

1
changelog.d/13005.misc Normal file
View File

@ -0,0 +1 @@
Reduce DB usage of `/sync` when a large number of unread messages have recently been sent in a room.

View File

@ -58,6 +58,9 @@ from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateSt
from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore
from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyBackgroundStore from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyBackgroundStore
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
from synapse.storage.databases.main.events_bg_updates import ( from synapse.storage.databases.main.events_bg_updates import (
EventsBackgroundUpdatesStore, EventsBackgroundUpdatesStore,
) )
@ -199,6 +202,7 @@ R = TypeVar("R")
class Store( class Store(
EventPushActionsWorkerStore,
ClientIpBackgroundUpdateStore, ClientIpBackgroundUpdateStore,
DeviceInboxBackgroundUpdateStore, DeviceInboxBackgroundUpdateStore,
DeviceBackgroundUpdateStore, DeviceBackgroundUpdateStore,

View File

@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tup
import attr import attr
from prometheus_client import Counter from prometheus_client import Counter
from synapse.api.constants import EventTypes, Membership, ReceiptTypes from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import FilterCollection from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@ -1054,14 +1054,10 @@ class SyncHandler:
self, room_id: str, sync_config: SyncConfig self, room_id: str, sync_config: SyncConfig
) -> NotifCounts: ) -> NotifCounts:
with Measure(self.clock, "unread_notifs_for_room_id"): with Measure(self.clock, "unread_notifs_for_room_id"):
last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
user_id=sync_config.user.to_string(),
room_id=room_id,
receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
)
return await self.store.get_unread_event_push_actions_by_room_for_user( return await self.store.get_unread_event_push_actions_by_room_for_user(
room_id, sync_config.user.to_string(), last_unread_event_id room_id,
sync_config.user.to_string(),
) )
async def generate_sync_result( async def generate_sync_result(

View File

@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
from typing import Dict from typing import Dict
from synapse.api.constants import ReceiptTypes
from synapse.events import EventBase from synapse.events import EventBase
from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage.controllers import StorageControllers from synapse.storage.controllers import StorageControllers
@ -24,19 +23,13 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
invites = await store.get_invited_rooms_for_local_user(user_id) invites = await store.get_invited_rooms_for_local_user(user_id)
joins = await store.get_rooms_for_user(user_id) joins = await store.get_rooms_for_user(user_id)
my_receipts_by_room = await store.get_receipts_for_user(
user_id, (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE)
)
badge = len(invites) badge = len(invites)
for room_id in joins: for room_id in joins:
if room_id in my_receipts_by_room:
last_unread_event_id = my_receipts_by_room[room_id]
notifs = await ( notifs = await (
store.get_unread_event_push_actions_by_room_for_user( store.get_unread_event_push_actions_by_room_for_user(
room_id, user_id, last_unread_event_id room_id,
user_id,
) )
) )
if notifs.notify_count == 0: if notifs.notify_count == 0:

View File

@ -92,6 +92,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
"event_search": "event_search_event_id_idx", "event_search": "event_search_event_id_idx",
"local_media_repository_thumbnails": "local_media_repository_thumbnails_method_idx", "local_media_repository_thumbnails": "local_media_repository_thumbnails_method_idx",
"remote_media_cache_thumbnails": "remote_media_repository_thumbnails_method_idx", "remote_media_cache_thumbnails": "remote_media_repository_thumbnails_method_idx",
"event_push_summary": "event_push_summary_unique_index",
} }

View File

@ -104,13 +104,14 @@ class DataStore(
PusherStore, PusherStore,
PushRuleStore, PushRuleStore,
ApplicationServiceTransactionStore, ApplicationServiceTransactionStore,
EventPushActionsStore,
ServerMetricsStore,
ReceiptsStore, ReceiptsStore,
EndToEndKeyStore, EndToEndKeyStore,
EndToEndRoomKeyStore, EndToEndRoomKeyStore,
SearchStore, SearchStore,
TagsStore, TagsStore,
AccountDataStore, AccountDataStore,
EventPushActionsStore,
OpenIdStore, OpenIdStore,
ClientIpWorkerStore, ClientIpWorkerStore,
DeviceStore, DeviceStore,
@ -124,7 +125,6 @@ class DataStore(
UIAuthStore, UIAuthStore,
EventForwardExtremitiesStore, EventForwardExtremitiesStore,
CacheInvalidationWorkerStore, CacheInvalidationWorkerStore,
ServerMetricsStore,
LockStore, LockStore,
SessionStore, SessionStore,
): ):

View File

@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
import attr import attr
from synapse.api.constants import ReceiptTypes
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import ( from synapse.storage.database import (
@ -24,6 +25,8 @@ from synapse.storage.database import (
LoggingDatabaseConnection, LoggingDatabaseConnection,
LoggingTransaction, LoggingTransaction,
) )
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
@ -79,15 +82,15 @@ class UserPushAction(EmailPushAction):
profile_tag: str profile_tag: str
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, auto_attribs=True)
class NotifCounts: class NotifCounts:
""" """
The per-user, per-room count of notifications. Used by sync and push. The per-user, per-room count of notifications. Used by sync and push.
""" """
notify_count: int notify_count: int = 0
unread_count: int unread_count: int = 0
highlight_count: int highlight_count: int = 0
def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str: def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str:
@ -119,7 +122,7 @@ def _deserialize_action(actions: str, is_highlight: bool) -> List[Union[dict, st
return DEFAULT_NOTIF_ACTION return DEFAULT_NOTIF_ACTION
class EventPushActionsWorkerStore(SQLBaseStore): class EventPushActionsWorkerStore(ReceiptsWorkerStore, EventsWorkerStore, SQLBaseStore):
def __init__( def __init__(
self, self,
database: DatabasePool, database: DatabasePool,
@ -148,12 +151,20 @@ class EventPushActionsWorkerStore(SQLBaseStore):
self._rotate_notifs, 30 * 60 * 1000 self._rotate_notifs, 30 * 60 * 1000
) )
@cached(num_args=3, tree=True, max_entries=5000) self.db_pool.updates.register_background_index_update(
"event_push_summary_unique_index",
index_name="event_push_summary_unique_index",
table="event_push_summary",
columns=["user_id", "room_id"],
unique=True,
replaces_index="event_push_summary_user_rm",
)
@cached(tree=True, max_entries=5000)
async def get_unread_event_push_actions_by_room_for_user( async def get_unread_event_push_actions_by_room_for_user(
self, self,
room_id: str, room_id: str,
user_id: str, user_id: str,
last_read_event_id: Optional[str],
) -> NotifCounts: ) -> NotifCounts:
"""Get the notification count, the highlight count and the unread message count """Get the notification count, the highlight count and the unread message count
for a given user in a given room after the given read receipt. for a given user in a given room after the given read receipt.
@ -165,8 +176,6 @@ class EventPushActionsWorkerStore(SQLBaseStore):
Args: Args:
room_id: The room to retrieve the counts in. room_id: The room to retrieve the counts in.
user_id: The user to retrieve the counts for. user_id: The user to retrieve the counts for.
last_read_event_id: The event associated with the latest read receipt for
this user in this room. None if no receipt for this user in this room.
Returns Returns
A dict containing the counts mentioned earlier in this docstring, A dict containing the counts mentioned earlier in this docstring,
@ -178,7 +187,6 @@ class EventPushActionsWorkerStore(SQLBaseStore):
self._get_unread_counts_by_receipt_txn, self._get_unread_counts_by_receipt_txn,
room_id, room_id,
user_id, user_id,
last_read_event_id,
) )
def _get_unread_counts_by_receipt_txn( def _get_unread_counts_by_receipt_txn(
@ -186,17 +194,18 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn: LoggingTransaction, txn: LoggingTransaction,
room_id: str, room_id: str,
user_id: str, user_id: str,
last_read_event_id: Optional[str],
) -> NotifCounts: ) -> NotifCounts:
stream_ordering = None result = self.get_last_receipt_for_user_txn(
if last_read_event_id is not None:
stream_ordering = self.get_stream_id_for_event_txn( # type: ignore[attr-defined]
txn, txn,
last_read_event_id, user_id,
allow_none=True, room_id,
receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
) )
stream_ordering = None
if result:
_, stream_ordering = result
if stream_ordering is None: if stream_ordering is None:
# Either last_read_event_id is None, or it's an event we don't have (e.g. # Either last_read_event_id is None, or it's an event we don't have (e.g.
# because it's been purged), in which case retrieve the stream ordering for # because it's been purged), in which case retrieve the stream ordering for
@ -218,49 +227,95 @@ class EventPushActionsWorkerStore(SQLBaseStore):
def _get_unread_counts_by_pos_txn( def _get_unread_counts_by_pos_txn(
self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
) -> NotifCounts: ) -> NotifCounts:
sql = ( """Get the number of unread messages for a user/room that have happened
"SELECT" since the given stream ordering.
" COUNT(CASE WHEN notif = 1 THEN 1 END)," """
" COUNT(CASE WHEN highlight = 1 THEN 1 END),"
" COUNT(CASE WHEN unread = 1 THEN 1 END)"
" FROM event_push_actions ea"
" WHERE user_id = ?"
" AND room_id = ?"
" AND stream_ordering > ?"
)
txn.execute(sql, (user_id, room_id, stream_ordering)) counts = NotifCounts()
row = txn.fetchone()
(notif_count, highlight_count, unread_count) = (0, 0, 0)
if row:
(notif_count, highlight_count, unread_count) = row
# First we pull the counts from the summary table
txn.execute( txn.execute(
""" """
SELECT notif_count, unread_count FROM event_push_summary SELECT stream_ordering, notif_count, COALESCE(unread_count, 0)
FROM event_push_summary
WHERE room_id = ? AND user_id = ? AND stream_ordering > ? WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
""", """,
(room_id, user_id, stream_ordering), (room_id, user_id, stream_ordering),
) )
row = txn.fetchone() row = txn.fetchone()
summary_stream_ordering = 0
if row: if row:
notif_count += row[0] summary_stream_ordering = row[0]
counts.notify_count += row[1]
counts.unread_count += row[2]
if row[1] is not None: # Next we need to count highlights, which aren't summarized
# The unread_count column of event_push_summary is NULLable, so we need sql = """
# to make sure we don't try increasing the unread counts if it's NULL SELECT COUNT(*) FROM event_push_actions
# for this row. WHERE user_id = ?
unread_count += row[1] AND room_id = ?
AND stream_ordering > ?
AND highlight = 1
"""
txn.execute(sql, (user_id, room_id, stream_ordering))
row = txn.fetchone()
if row:
counts.highlight_count += row[0]
return NotifCounts( # Finally we need to count push actions that haven't been summarized
notify_count=notif_count, # yet.
unread_count=unread_count, # We only want to pull out push actions that we haven't summarized yet.
highlight_count=highlight_count, stream_ordering = max(stream_ordering, summary_stream_ordering)
notify_count, unread_count = self._get_notif_unread_count_for_user_room(
txn, room_id, user_id, stream_ordering
) )
counts.notify_count += notify_count
counts.unread_count += unread_count
return counts
def _get_notif_unread_count_for_user_room(
self,
txn: LoggingTransaction,
room_id: str,
user_id: str,
stream_ordering: int,
max_stream_ordering: Optional[int] = None,
) -> Tuple[int, int]:
"""Returns the notify and unread counts from `event_push_actions` for
the given user/room in the given range.
Does not consult `event_push_summary` table, which may include push
actions that have been deleted from `event_push_actions` table.
"""
clause = ""
args = [user_id, room_id, stream_ordering]
if max_stream_ordering is not None:
clause = "AND ea.stream_ordering <= ?"
args.append(max_stream_ordering)
sql = f"""
SELECT
COUNT(CASE WHEN notif = 1 THEN 1 END),
COUNT(CASE WHEN unread = 1 THEN 1 END)
FROM event_push_actions ea
WHERE user_id = ?
AND room_id = ?
AND ea.stream_ordering > ?
{clause}
"""
txn.execute(sql, args)
row = txn.fetchone()
if row:
return cast(Tuple[int, int], row)
return 0, 0
async def get_push_action_users_in_range( async def get_push_action_users_in_range(
self, min_stream_ordering: int, max_stream_ordering: int self, min_stream_ordering: int, max_stream_ordering: int
) -> List[str]: ) -> List[str]:
@ -754,6 +809,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
if caught_up: if caught_up:
break break
await self.hs.get_clock().sleep(self._rotate_delay) await self.hs.get_clock().sleep(self._rotate_delay)
await self._remove_old_push_actions_that_have_rotated()
finally: finally:
self._doing_notif_rotation = False self._doing_notif_rotation = False
@ -782,20 +839,16 @@ class EventPushActionsWorkerStore(SQLBaseStore):
stream_row = txn.fetchone() stream_row = txn.fetchone()
if stream_row: if stream_row:
(offset_stream_ordering,) = stream_row (offset_stream_ordering,) = stream_row
assert self.stream_ordering_day_ago is not None rotate_to_stream_ordering = offset_stream_ordering
rotate_to_stream_ordering = min( caught_up = False
self.stream_ordering_day_ago, offset_stream_ordering
)
caught_up = offset_stream_ordering >= self.stream_ordering_day_ago
else: else:
rotate_to_stream_ordering = self.stream_ordering_day_ago rotate_to_stream_ordering = self._stream_id_gen.get_current_token()
caught_up = True caught_up = True
logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering) logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering)
self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering) self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering)
# We have caught up iff we were limited by `stream_ordering_day_ago`
return caught_up return caught_up
def _rotate_notifs_before_txn( def _rotate_notifs_before_txn(
@ -819,7 +872,6 @@ class EventPushActionsWorkerStore(SQLBaseStore):
max(stream_ordering) as stream_ordering max(stream_ordering) as stream_ordering
FROM event_push_actions FROM event_push_actions
WHERE ? <= stream_ordering AND stream_ordering < ? WHERE ? <= stream_ordering AND stream_ordering < ?
AND highlight = 0
AND %s = 1 AND %s = 1
GROUP BY user_id, room_id GROUP BY user_id, room_id
) AS upd ) AS upd
@ -915,17 +967,71 @@ class EventPushActionsWorkerStore(SQLBaseStore):
) )
txn.execute( txn.execute(
"DELETE FROM event_push_actions" "UPDATE event_push_summary_stream_ordering SET stream_ordering = ?",
" WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0", (rotate_to_stream_ordering,),
(old_rotate_stream_ordering, rotate_to_stream_ordering), )
async def _remove_old_push_actions_that_have_rotated(
self,
) -> None:
"""Clear out old push actions that have been summarized."""
# We want to clear out anything that older than a day that *has* already
# been rotated.
rotated_upto_stream_ordering = await self.db_pool.simple_select_one_onecol(
table="event_push_summary_stream_ordering",
keyvalues={},
retcol="stream_ordering",
)
max_stream_ordering_to_delete = min(
rotated_upto_stream_ordering, self.stream_ordering_day_ago
)
def remove_old_push_actions_that_have_rotated_txn(
txn: LoggingTransaction,
) -> bool:
# We don't want to clear out too much at a time, so we bound our
# deletes.
batch_size = 10000
txn.execute(
"""
SELECT stream_ordering FROM event_push_actions
WHERE stream_ordering < ? AND highlight = 0
ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
""",
(
max_stream_ordering_to_delete,
batch_size,
),
)
stream_row = txn.fetchone()
if stream_row:
(stream_ordering,) = stream_row
else:
stream_ordering = max_stream_ordering_to_delete
txn.execute(
"""
DELETE FROM event_push_actions
WHERE stream_ordering < ? AND highlight = 0
""",
(stream_ordering,),
) )
logger.info("Rotating notifications, deleted %s push actions", txn.rowcount) logger.info("Rotating notifications, deleted %s push actions", txn.rowcount)
txn.execute( return txn.rowcount < batch_size
"UPDATE event_push_summary_stream_ordering SET stream_ordering = ?",
(rotate_to_stream_ordering,), while True:
done = await self.db_pool.runInteraction(
"_remove_old_push_actions_that_have_rotated",
remove_old_push_actions_that_have_rotated_txn,
) )
if done:
break
def _remove_old_push_actions_before_txn( def _remove_old_push_actions_before_txn(
self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
@ -965,12 +1071,26 @@ class EventPushActionsWorkerStore(SQLBaseStore):
(user_id, room_id, stream_ordering, self.stream_ordering_month_ago), (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
) )
txn.execute( old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
""" txn,
DELETE FROM event_push_summary table="event_push_summary_stream_ordering",
WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? keyvalues={},
""", retcol="stream_ordering",
(room_id, user_id, stream_ordering), )
notif_count, unread_count = self._get_notif_unread_count_for_user_room(
txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering
)
self.db_pool.simple_upsert_txn(
txn,
table="event_push_summary",
keyvalues={"room_id": room_id, "user_id": user_id},
values={
"notif_count": notif_count,
"unread_count": unread_count,
"stream_ordering": old_rotate_stream_ordering,
},
) )

View File

@ -110,9 +110,9 @@ def _load_rules(
# the abstract methods being implemented. # the abstract methods being implemented.
class PushRulesWorkerStore( class PushRulesWorkerStore(
ApplicationServiceWorkerStore, ApplicationServiceWorkerStore,
ReceiptsWorkerStore,
PusherWorkerStore, PusherWorkerStore,
RoomMemberWorkerStore, RoomMemberWorkerStore,
ReceiptsWorkerStore,
EventsWorkerStore, EventsWorkerStore,
SQLBaseStore, SQLBaseStore,
metaclass=abc.ABCMeta, metaclass=abc.ABCMeta,

View File

@ -118,7 +118,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return self._receipts_id_gen.get_current_token() return self._receipts_id_gen.get_current_token()
async def get_last_receipt_event_id_for_user( async def get_last_receipt_event_id_for_user(
self, user_id: str, room_id: str, receipt_types: Iterable[str] self, user_id: str, room_id: str, receipt_types: Collection[str]
) -> Optional[str]: ) -> Optional[str]:
""" """
Fetch the event ID for the latest receipt in a room with one of the given receipt types. Fetch the event ID for the latest receipt in a room with one of the given receipt types.
@ -126,58 +126,63 @@ class ReceiptsWorkerStore(SQLBaseStore):
Args: Args:
user_id: The user to fetch receipts for. user_id: The user to fetch receipts for.
room_id: The room ID to fetch the receipt for. room_id: The room ID to fetch the receipt for.
receipt_type: The receipt types to fetch. Earlier receipt types receipt_type: The receipt types to fetch.
are given priority if multiple receipts point to the same event.
Returns: Returns:
The latest receipt, if one exists. The latest receipt, if one exists.
""" """
latest_event_id: Optional[str] = None result = await self.db_pool.runInteraction(
latest_stream_ordering = 0 "get_last_receipt_event_id_for_user",
for receipt_type in receipt_types: self.get_last_receipt_for_user_txn,
result = await self._get_last_receipt_event_id_for_user( user_id,
user_id, room_id, receipt_type room_id,
receipt_types,
) )
if result is None: if not result:
continue return None
event_id, stream_ordering = result
if latest_event_id is None or latest_stream_ordering < stream_ordering: event_id, _ = result
latest_event_id = event_id return event_id
latest_stream_ordering = stream_ordering
return latest_event_id def get_last_receipt_for_user_txn(
self,
@cached() txn: LoggingTransaction,
async def _get_last_receipt_event_id_for_user( user_id: str,
self, user_id: str, room_id: str, receipt_type: str room_id: str,
receipt_types: Collection[str],
) -> Optional[Tuple[str, int]]: ) -> Optional[Tuple[str, int]]:
""" """
Fetch the event ID and stream ordering for the latest receipt. Fetch the event ID and stream_ordering for the latest receipt in a room
with one of the given receipt types.
Args: Args:
user_id: The user to fetch receipts for. user_id: The user to fetch receipts for.
room_id: The room ID to fetch the receipt for. room_id: The room ID to fetch the receipt for.
receipt_type: The receipt type to fetch. receipt_type: The receipt types to fetch.
Returns: Returns:
The event ID and stream ordering of the latest receipt, if one exists; The latest receipt, if one exists.
otherwise `None`.
""" """
sql = """
clause, args = make_in_list_sql_clause(
self.database_engine, "receipt_type", receipt_types
)
sql = f"""
SELECT event_id, stream_ordering SELECT event_id, stream_ordering
FROM receipts_linearized FROM receipts_linearized
INNER JOIN events USING (room_id, event_id) INNER JOIN events USING (room_id, event_id)
WHERE user_id = ? WHERE {clause}
AND user_id = ?
AND room_id = ? AND room_id = ?
AND receipt_type = ? ORDER BY stream_ordering DESC
LIMIT 1
""" """
def f(txn: LoggingTransaction) -> Optional[Tuple[str, int]]: args.extend((user_id, room_id))
txn.execute(sql, (user_id, room_id, receipt_type)) txn.execute(sql, args)
return cast(Optional[Tuple[str, int]], txn.fetchone())
return await self.db_pool.runInteraction("get_own_receipt_for_user", f) return cast(Optional[Tuple[str, int]], txn.fetchone())
async def get_receipts_for_user( async def get_receipts_for_user(
self, user_id: str, receipt_types: Iterable[str] self, user_id: str, receipt_types: Iterable[str]
@ -577,8 +582,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
) -> None: ) -> None:
self._get_receipts_for_user_with_orderings.invalidate((user_id, receipt_type)) self._get_receipts_for_user_with_orderings.invalidate((user_id, receipt_type))
self._get_linearized_receipts_for_room.invalidate((room_id,)) self._get_linearized_receipts_for_room.invalidate((room_id,))
self._get_last_receipt_event_id_for_user.invalidate(
(user_id, room_id, receipt_type) # We use this method to invalidate so that we don't end up with circular
# dependencies between the receipts and push action stores.
self._attempt_to_invalidate_cache(
"get_unread_event_push_actions_by_room_for_user", (room_id,)
) )
def process_replication_rows( def process_replication_rows(

View File

@ -13,9 +13,10 @@
* limitations under the License. * limitations under the License.
*/ */
-- Aggregate of old notification counts that have been deleted out of the -- Aggregate of notification counts up to `stream_ordering`, including those
-- main event_push_actions table. This count does not include those that were -- that may have been deleted out of the main event_push_actions table. This
-- highlights, as they remain in the event_push_actions table. -- count does not include those that were highlights, as they remain in the
-- event_push_actions table.
CREATE TABLE event_push_summary ( CREATE TABLE event_push_summary (
user_id TEXT NOT NULL, user_id TEXT NOT NULL,
room_id TEXT NOT NULL, room_id TEXT NOT NULL,

View File

@ -0,0 +1,18 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Add a unique index to `event_push_summary`
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(7002, 'event_push_summary_unique_index', '{}');

View File

@ -577,7 +577,7 @@ class HTTPPusherTests(HomeserverTestCase):
# Carry out our option-value specific test # Carry out our option-value specific test
# #
# This push should still only contain an unread count of 1 (for 1 unread room) # This push should still only contain an unread count of 1 (for 1 unread room)
self._check_push_attempt(6, 1) self._check_push_attempt(7, 1)
@override_config({"push": {"group_unread_count_by_room": False}}) @override_config({"push": {"group_unread_count_by_room": False}})
def test_push_unread_count_message_count(self) -> None: def test_push_unread_count_message_count(self) -> None:
@ -591,7 +591,7 @@ class HTTPPusherTests(HomeserverTestCase):
# #
# We're counting every unread message, so there should now be 3 since the # We're counting every unread message, so there should now be 3 since the
# last read receipt # last read receipt
self._check_push_attempt(6, 3) self._check_push_attempt(7, 3)
def _test_push_unread_count(self) -> None: def _test_push_unread_count(self) -> None:
""" """
@ -641,18 +641,18 @@ class HTTPPusherTests(HomeserverTestCase):
response = self.helper.send( response = self.helper.send(
room_id, body="Hello there!", tok=other_access_token room_id, body="Hello there!", tok=other_access_token
) )
# To get an unread count, the user who is getting notified has to have a read
# position in the room. We'll set the read position to this event in a moment
first_message_event_id = response["event_id"] first_message_event_id = response["event_id"]
expected_push_attempts = 1 expected_push_attempts = 1
self._check_push_attempt(expected_push_attempts, 0) self._check_push_attempt(expected_push_attempts, 1)
self._send_read_request(access_token, first_message_event_id, room_id) self._send_read_request(access_token, first_message_event_id, room_id)
# Unread count has not changed. Therefore, ensure that read request does not # Unread count has changed. Therefore, ensure that read request triggers
# trigger a push notification. # a push notification.
self.assertEqual(len(self.push_attempts), 1) expected_push_attempts += 1
self.assertEqual(len(self.push_attempts), expected_push_attempts)
# Send another message # Send another message
response2 = self.helper.send( response2 = self.helper.send(

View File

@ -15,7 +15,9 @@ import logging
from typing import Iterable, Optional from typing import Iterable, Optional
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from parameterized import parameterized
from synapse.api.constants import ReceiptTypes
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
from synapse.handlers.room import RoomEventSource from synapse.handlers.room import RoomEventSource
@ -156,17 +158,26 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
], ],
) )
def test_push_actions_for_user(self): @parameterized.expand([(True,), (False,)])
def test_push_actions_for_user(self, send_receipt: bool):
self.persist(type="m.room.create", key="", creator=USER_ID) self.persist(type="m.room.create", key="", creator=USER_ID)
self.persist(type="m.room.join", key=USER_ID, membership="join") self.persist(type="m.room.member", key=USER_ID, membership="join")
self.persist( self.persist(
type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join" type="m.room.member", sender=USER_ID, key=USER_ID_2, membership="join"
) )
event1 = self.persist(type="m.room.message", msgtype="m.text", body="hello") event1 = self.persist(type="m.room.message", msgtype="m.text", body="hello")
self.replicate() self.replicate()
if send_receipt:
self.get_success(
self.master_store.insert_receipt(
ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], {}
)
)
self.check( self.check(
"get_unread_event_push_actions_by_room_for_user", "get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id], [ROOM_ID, USER_ID_2],
NotifCounts(highlight_count=0, unread_count=0, notify_count=0), NotifCounts(highlight_count=0, unread_count=0, notify_count=0),
) )
@ -179,7 +190,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.replicate() self.replicate()
self.check( self.check(
"get_unread_event_push_actions_by_room_for_user", "get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id], [ROOM_ID, USER_ID_2],
NotifCounts(highlight_count=0, unread_count=0, notify_count=1), NotifCounts(highlight_count=0, unread_count=0, notify_count=1),
) )
@ -194,7 +205,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.replicate() self.replicate()
self.check( self.check(
"get_unread_event_push_actions_by_room_for_user", "get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id], [ROOM_ID, USER_ID_2],
NotifCounts(highlight_count=1, unread_count=0, notify_count=2), NotifCounts(highlight_count=1, unread_count=0, notify_count=2),
) )

View File

@ -51,10 +51,16 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
room_id = "!foo:example.com" room_id = "!foo:example.com"
user_id = "@user1235:example.com" user_id = "@user1235:example.com"
last_read_stream_ordering = [0]
def _assert_counts(noitf_count, highlight_count): def _assert_counts(noitf_count, highlight_count):
counts = self.get_success( counts = self.get_success(
self.store.db_pool.runInteraction( self.store.db_pool.runInteraction(
"", self.store._get_unread_counts_by_pos_txn, room_id, user_id, 0 "",
self.store._get_unread_counts_by_pos_txn,
room_id,
user_id,
last_read_stream_ordering[0],
) )
) )
self.assertEqual( self.assertEqual(
@ -98,6 +104,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
) )
def _mark_read(stream, depth): def _mark_read(stream, depth):
last_read_stream_ordering[0] = stream
self.get_success( self.get_success(
self.store.db_pool.runInteraction( self.store.db_pool.runInteraction(
"", "",
@ -144,8 +151,19 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
_assert_counts(1, 1) _assert_counts(1, 1)
_rotate(9) _rotate(9)
_assert_counts(1, 1) _assert_counts(1, 1)
_rotate(10)
_assert_counts(1, 1) # Check that adding another notification and rotating after highlight
# works.
_inject_actions(10, PlAIN_NOTIF)
_rotate(11)
_assert_counts(2, 1)
# Check that sending read receipts at different points results in the
# right counts.
_mark_read(8, 8)
_assert_counts(1, 0)
_mark_read(10, 10)
_assert_counts(0, 0)
def test_find_first_stream_ordering_after_ts(self): def test_find_first_stream_ordering_after_ts(self):
def add_event(so, ts): def add_event(so, ts):