Add type hints to event_push_actions. (#11594)

This commit is contained in:
Patrick Cloke 2021-12-21 08:25:34 -05:00 committed by GitHub
parent 2215954147
commit b6102230a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 226 additions and 155 deletions

1
changelog.d/11594.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints to storage classes.

View File

@ -28,7 +28,6 @@ exclude = (?x)
|synapse/storage/databases/main/cache.py |synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py |synapse/storage/databases/main/devices.py
|synapse/storage/databases/main/event_federation.py |synapse/storage/databases/main/event_federation.py
|synapse/storage/databases/main/event_push_actions.py
|synapse/storage/databases/main/events_bg_updates.py |synapse/storage/databases/main/events_bg_updates.py
|synapse/storage/databases/main/group_server.py |synapse/storage/databases/main/group_server.py
|synapse/storage/databases/main/metrics.py |synapse/storage/databases/main/metrics.py
@ -200,6 +199,9 @@ disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.end_to_end_keys] [mypy-synapse.storage.databases.main.end_to_end_keys]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.event_push_actions]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.events_worker] [mypy-synapse.storage.databases.main.events_worker]
disallow_untyped_defs = True disallow_untyped_defs = True

View File

@ -36,6 +36,7 @@ from synapse.events import EventBase
from synapse.logging.context import current_context from synapse.logging.context import current_context
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
from synapse.push.clientformat import format_push_rules_for_user from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.databases.main.event_push_actions import NotifCounts
from synapse.storage.roommember import MemberSummary from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import ( from synapse.types import (
@ -1041,7 +1042,7 @@ class SyncHandler:
async def unread_notifs_for_room_id( async def unread_notifs_for_room_id(
self, room_id: str, sync_config: SyncConfig self, room_id: str, sync_config: SyncConfig
) -> Dict[str, int]: ) -> NotifCounts:
with Measure(self.clock, "unread_notifs_for_room_id"): with Measure(self.clock, "unread_notifs_for_room_id"):
last_unread_event_id = await self.store.get_last_receipt_event_id_for_user( last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
user_id=sync_config.user.to_string(), user_id=sync_config.user.to_string(),
@ -1049,10 +1050,9 @@ class SyncHandler:
receipt_type=ReceiptTypes.READ, receipt_type=ReceiptTypes.READ,
) )
notifs = await self.store.get_unread_event_push_actions_by_room_for_user( return await self.store.get_unread_event_push_actions_by_room_for_user(
room_id, sync_config.user.to_string(), last_unread_event_id room_id, sync_config.user.to_string(), last_unread_event_id
) )
return notifs
async def generate_sync_result( async def generate_sync_result(
self, self,
@ -2174,10 +2174,10 @@ class SyncHandler:
if room_sync or always_include: if room_sync or always_include:
notifs = await self.unread_notifs_for_room_id(room_id, sync_config) notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
unread_notifications["notification_count"] = notifs["notify_count"] unread_notifications["notification_count"] = notifs.notify_count
unread_notifications["highlight_count"] = notifs["highlight_count"] unread_notifications["highlight_count"] = notifs.highlight_count
room_sync.unread_count = notifs["unread_count"] room_sync.unread_count = notifs.unread_count
sync_result_builder.joined.append(room_sync) sync_result_builder.joined.append(room_sync)

View File

@ -177,12 +177,12 @@ class EmailPusher(Pusher):
return return
for push_action in unprocessed: for push_action in unprocessed:
received_at = push_action["received_ts"] received_at = push_action.received_ts
if received_at is None: if received_at is None:
received_at = 0 received_at = 0
notif_ready_at = received_at + DELAY_BEFORE_MAIL_MS notif_ready_at = received_at + DELAY_BEFORE_MAIL_MS
room_ready_at = self.room_ready_to_notify_at(push_action["room_id"]) room_ready_at = self.room_ready_to_notify_at(push_action.room_id)
should_notify_at = max(notif_ready_at, room_ready_at) should_notify_at = max(notif_ready_at, room_ready_at)
@ -193,23 +193,23 @@ class EmailPusher(Pusher):
# to be delivered. # to be delivered.
reason: EmailReason = { reason: EmailReason = {
"room_id": push_action["room_id"], "room_id": push_action.room_id,
"now": self.clock.time_msec(), "now": self.clock.time_msec(),
"received_at": received_at, "received_at": received_at,
"delay_before_mail_ms": DELAY_BEFORE_MAIL_MS, "delay_before_mail_ms": DELAY_BEFORE_MAIL_MS,
"last_sent_ts": self.get_room_last_sent_ts(push_action["room_id"]), "last_sent_ts": self.get_room_last_sent_ts(push_action.room_id),
"throttle_ms": self.get_room_throttle_ms(push_action["room_id"]), "throttle_ms": self.get_room_throttle_ms(push_action.room_id),
} }
await self.send_notification(unprocessed, reason) await self.send_notification(unprocessed, reason)
await self.save_last_stream_ordering_and_success( await self.save_last_stream_ordering_and_success(
max(ea["stream_ordering"] for ea in unprocessed) max(ea.stream_ordering for ea in unprocessed)
) )
# we update the throttle on all the possible unprocessed push actions # we update the throttle on all the possible unprocessed push actions
for ea in unprocessed: for ea in unprocessed:
await self.sent_notif_update_throttle(ea["room_id"], ea) await self.sent_notif_update_throttle(ea.room_id, ea)
break break
else: else:
if soonest_due_at is None or should_notify_at < soonest_due_at: if soonest_due_at is None or should_notify_at < soonest_due_at:
@ -284,10 +284,10 @@ class EmailPusher(Pusher):
# THROTTLE_RESET_AFTER_MS after the previous one that triggered a # THROTTLE_RESET_AFTER_MS after the previous one that triggered a
# notif, we release the throttle. Otherwise, the throttle is increased. # notif, we release the throttle. Otherwise, the throttle is increased.
time_of_previous_notifs = await self.store.get_time_of_last_push_action_before( time_of_previous_notifs = await self.store.get_time_of_last_push_action_before(
notified_push_action["stream_ordering"] notified_push_action.stream_ordering
) )
time_of_this_notifs = notified_push_action["received_ts"] time_of_this_notifs = notified_push_action.received_ts
if time_of_previous_notifs is not None and time_of_this_notifs is not None: if time_of_previous_notifs is not None and time_of_this_notifs is not None:
gap = time_of_this_notifs - time_of_previous_notifs gap = time_of_this_notifs - time_of_previous_notifs

View File

@ -199,7 +199,7 @@ class HttpPusher(Pusher):
"http-push", "http-push",
tags={ tags={
"authenticated_entity": self.user_id, "authenticated_entity": self.user_id,
"event_id": push_action["event_id"], "event_id": push_action.event_id,
"app_id": self.app_id, "app_id": self.app_id,
"app_display_name": self.app_display_name, "app_display_name": self.app_display_name,
}, },
@ -209,7 +209,7 @@ class HttpPusher(Pusher):
if processed: if processed:
http_push_processed_counter.inc() http_push_processed_counter.inc()
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"] self.last_stream_ordering = push_action.stream_ordering
pusher_still_exists = ( pusher_still_exists = (
await self.store.update_pusher_last_stream_ordering_and_success( await self.store.update_pusher_last_stream_ordering_and_success(
self.app_id, self.app_id,
@ -252,7 +252,7 @@ class HttpPusher(Pusher):
self.pushkey, self.pushkey,
) )
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"] self.last_stream_ordering = push_action.stream_ordering
await self.store.update_pusher_last_stream_ordering( await self.store.update_pusher_last_stream_ordering(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
@ -275,17 +275,17 @@ class HttpPusher(Pusher):
break break
async def _process_one(self, push_action: HttpPushAction) -> bool: async def _process_one(self, push_action: HttpPushAction) -> bool:
if "notify" not in push_action["actions"]: if "notify" not in push_action.actions:
return True return True
tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"]) tweaks = push_rule_evaluator.tweaks_for_actions(push_action.actions)
badge = await push_tools.get_badge_count( badge = await push_tools.get_badge_count(
self.hs.get_datastore(), self.hs.get_datastore(),
self.user_id, self.user_id,
group_by_room=self._group_unread_count_by_room, group_by_room=self._group_unread_count_by_room,
) )
event = await self.store.get_event(push_action["event_id"], allow_none=True) event = await self.store.get_event(push_action.event_id, allow_none=True)
if event is None: if event is None:
return True # It's been redacted return True # It's been redacted
rejected = await self.dispatch_push(event, tweaks, badge) rejected = await self.dispatch_push(event, tweaks, badge)

View File

@ -232,15 +232,13 @@ class Mailer:
reason: The notification that was ready and is the cause of an email reason: The notification that was ready and is the cause of an email
being sent. being sent.
""" """
rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions]) rooms_in_order = deduped_ordered_list([pa.room_id for pa in push_actions])
notif_events = await self.store.get_events( notif_events = await self.store.get_events([pa.event_id for pa in push_actions])
[pa["event_id"] for pa in push_actions]
)
notifs_by_room: Dict[str, List[EmailPushAction]] = {} notifs_by_room: Dict[str, List[EmailPushAction]] = {}
for pa in push_actions: for pa in push_actions:
notifs_by_room.setdefault(pa["room_id"], []).append(pa) notifs_by_room.setdefault(pa.room_id, []).append(pa)
# collect the current state for all the rooms in which we have # collect the current state for all the rooms in which we have
# notifications # notifications
@ -264,7 +262,7 @@ class Mailer:
await concurrently_execute(_fetch_room_state, rooms_in_order, 3) await concurrently_execute(_fetch_room_state, rooms_in_order, 3)
# actually sort our so-called rooms_in_order list, most recent room first # actually sort our so-called rooms_in_order list, most recent room first
rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1]["received_ts"] or 0)) rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1].received_ts or 0))
rooms: List[RoomVars] = [] rooms: List[RoomVars] = []
@ -356,7 +354,7 @@ class Mailer:
# Check if one of the notifs is an invite event for the user. # Check if one of the notifs is an invite event for the user.
is_invite = False is_invite = False
for n in notifs: for n in notifs:
ev = notif_events[n["event_id"]] ev = notif_events[n.event_id]
if ev.type == EventTypes.Member and ev.state_key == user_id: if ev.type == EventTypes.Member and ev.state_key == user_id:
if ev.content.get("membership") == Membership.INVITE: if ev.content.get("membership") == Membership.INVITE:
is_invite = True is_invite = True
@ -376,7 +374,7 @@ class Mailer:
if not is_invite: if not is_invite:
for n in notifs: for n in notifs:
notifvars = await self._get_notif_vars( notifvars = await self._get_notif_vars(
n, user_id, notif_events[n["event_id"]], room_state_ids n, user_id, notif_events[n.event_id], room_state_ids
) )
# merge overlapping notifs together. # merge overlapping notifs together.
@ -444,15 +442,15 @@ class Mailer:
""" """
results = await self.store.get_events_around( results = await self.store.get_events_around(
notif["room_id"], notif.room_id,
notif["event_id"], notif.event_id,
before_limit=CONTEXT_BEFORE, before_limit=CONTEXT_BEFORE,
after_limit=CONTEXT_AFTER, after_limit=CONTEXT_AFTER,
) )
ret: NotifVars = { ret: NotifVars = {
"link": self._make_notif_link(notif), "link": self._make_notif_link(notif),
"ts": notif["received_ts"], "ts": notif.received_ts,
"messages": [], "messages": [],
} }
@ -516,7 +514,7 @@ class Mailer:
ret: MessageVars = { ret: MessageVars = {
"event_type": event.type, "event_type": event.type,
"is_historical": event.event_id != notif["event_id"], "is_historical": event.event_id != notif.event_id,
"id": event.event_id, "id": event.event_id,
"ts": event.origin_server_ts, "ts": event.origin_server_ts,
"sender_name": sender_name, "sender_name": sender_name,
@ -610,7 +608,7 @@ class Mailer:
# See if one of the notifs is an invite event for the user # See if one of the notifs is an invite event for the user
invite_event = None invite_event = None
for n in notifs: for n in notifs:
ev = notif_events[n["event_id"]] ev = notif_events[n.event_id]
if ev.type == EventTypes.Member and ev.state_key == user_id: if ev.type == EventTypes.Member and ev.state_key == user_id:
if ev.content.get("membership") == Membership.INVITE: if ev.content.get("membership") == Membership.INVITE:
invite_event = ev invite_event = ev
@ -659,7 +657,7 @@ class Mailer:
if len(notifs) == 1: if len(notifs) == 1:
# There is just the one notification, so give some detail # There is just the one notification, so give some detail
sender_name = None sender_name = None
event = notif_events[notifs[0]["event_id"]] event = notif_events[notifs[0].event_id]
if ("m.room.member", event.sender) in room_state_ids: if ("m.room.member", event.sender) in room_state_ids:
state_event_id = room_state_ids[("m.room.member", event.sender)] state_event_id = room_state_ids[("m.room.member", event.sender)]
state_event = await self.store.get_event(state_event_id) state_event = await self.store.get_event(state_event_id)
@ -753,9 +751,9 @@ class Mailer:
# are already in descending received_ts. # are already in descending received_ts.
sender_ids = {} sender_ids = {}
for n in notifs: for n in notifs:
sender = notif_events[n["event_id"]].sender sender = notif_events[n.event_id].sender
if sender not in sender_ids: if sender not in sender_ids:
sender_ids[sender] = n["event_id"] sender_ids[sender] = n.event_id
# Get the actual member events (in order to calculate a pretty name for # Get the actual member events (in order to calculate a pretty name for
# the room). # the room).
@ -830,17 +828,17 @@ class Mailer:
if self.hs.config.email.email_riot_base_url: if self.hs.config.email.email_riot_base_url:
return "%s/#/room/%s/%s" % ( return "%s/#/room/%s/%s" % (
self.hs.config.email.email_riot_base_url, self.hs.config.email.email_riot_base_url,
notif["room_id"], notif.room_id,
notif["event_id"], notif.event_id,
) )
elif self.app_name == "Vector": elif self.app_name == "Vector":
# need /beta for Universal Links to work on iOS # need /beta for Universal Links to work on iOS
return "https://vector.im/beta/#/room/%s/%s" % ( return "https://vector.im/beta/#/room/%s/%s" % (
notif["room_id"], notif.room_id,
notif["event_id"], notif.event_id,
) )
else: else:
return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"]) return "https://matrix.to/#/%s/%s" % (notif.room_id, notif.event_id)
def _make_unsubscribe_link( def _make_unsubscribe_link(
self, user_id: str, app_id: str, email_address: str self, user_id: str, app_id: str, email_address: str

View File

@ -37,7 +37,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
room_id, user_id, last_unread_event_id room_id, user_id, last_unread_event_id
) )
) )
if notifs["notify_count"] == 0: if notifs.notify_count == 0:
continue continue
if group_by_room: if group_by_room:
@ -45,7 +45,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
badge += 1 badge += 1
else: else:
# increment the badge count by the number of unread messages in the room # increment the badge count by the number of unread messages in the room
badge += notifs["notify_count"] badge += notifs.notify_count
return badge return badge

View File

@ -58,7 +58,7 @@ class NotificationsServlet(RestServlet):
user_id, ReceiptTypes.READ user_id, ReceiptTypes.READ
) )
notif_event_ids = [pa["event_id"] for pa in push_actions] notif_event_ids = [pa.event_id for pa in push_actions]
notif_events = await self.store.get_events(notif_event_ids) notif_events = await self.store.get_events(notif_event_ids)
returned_push_actions = [] returned_push_actions = []
@ -67,30 +67,30 @@ class NotificationsServlet(RestServlet):
for pa in push_actions: for pa in push_actions:
returned_pa = { returned_pa = {
"room_id": pa["room_id"], "room_id": pa.room_id,
"profile_tag": pa["profile_tag"], "profile_tag": pa.profile_tag,
"actions": pa["actions"], "actions": pa.actions,
"ts": pa["received_ts"], "ts": pa.received_ts,
"event": ( "event": (
await self._event_serializer.serialize_event( await self._event_serializer.serialize_event(
notif_events[pa["event_id"]], notif_events[pa.event_id],
self.clock.time_msec(), self.clock.time_msec(),
event_format=format_event_for_client_v2_without_room_id, event_format=format_event_for_client_v2_without_room_id,
) )
), ),
} }
if pa["room_id"] not in receipts_by_room: if pa.room_id not in receipts_by_room:
returned_pa["read"] = False returned_pa["read"] = False
else: else:
receipt = receipts_by_room[pa["room_id"]] receipt = receipts_by_room[pa.room_id]
returned_pa["read"] = ( returned_pa["read"] = (
receipt["topological_ordering"], receipt["topological_ordering"],
receipt["stream_ordering"], receipt["stream_ordering"],
) >= (pa["topological_ordering"], pa["stream_ordering"]) ) >= (pa.topological_ordering, pa.stream_ordering)
returned_push_actions.append(returned_pa) returned_push_actions.append(returned_pa)
next_token = str(pa["stream_ordering"]) next_token = str(pa.stream_ordering)
return 200, {"notifications": returned_push_actions, "next_token": next_token} return 200, {"notifications": returned_push_actions, "next_token": next_token}

View File

@ -16,7 +16,6 @@ import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import attr import attr
from typing_extensions import TypedDict
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
@ -34,29 +33,64 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}] DEFAULT_NOTIF_ACTION: List[Union[dict, str]] = [
DEFAULT_HIGHLIGHT_ACTION = [ "notify",
{"set_tweak": "highlight", "value": False},
]
DEFAULT_HIGHLIGHT_ACTION: List[Union[dict, str]] = [
"notify", "notify",
{"set_tweak": "sound", "value": "default"}, {"set_tweak": "sound", "value": "default"},
{"set_tweak": "highlight"}, {"set_tweak": "highlight"},
] ]
class BasePushAction(TypedDict): @attr.s(slots=True, frozen=True, auto_attribs=True)
class HttpPushAction:
"""
HttpPushAction instances include the information used to generate HTTP
requests to a push gateway.
"""
event_id: str event_id: str
room_id: str
stream_ordering: int
actions: List[Union[dict, str]] actions: List[Union[dict, str]]
class HttpPushAction(BasePushAction): @attr.s(slots=True, frozen=True, auto_attribs=True)
room_id: str
stream_ordering: int
class EmailPushAction(HttpPushAction): class EmailPushAction(HttpPushAction):
"""
EmailPushAction instances include the information used to render an email
push notification.
"""
received_ts: Optional[int] received_ts: Optional[int]
def _serialize_action(actions, is_highlight): @attr.s(slots=True, frozen=True, auto_attribs=True)
class UserPushAction(EmailPushAction):
"""
UserPushAction instances include the necessary information to respond to
/notifications requests.
"""
topological_ordering: int
highlight: bool
profile_tag: str
@attr.s(slots=True, frozen=True, auto_attribs=True)
class NotifCounts:
"""
The per-user, per-room count of notifications. Used by sync and push.
"""
notify_count: int
unread_count: int
highlight_count: int
def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str:
"""Custom serializer for actions. This allows us to "compress" common actions. """Custom serializer for actions. This allows us to "compress" common actions.
We use the fact that most users have the same actions for notifs (and for We use the fact that most users have the same actions for notifs (and for
@ -74,7 +108,7 @@ def _serialize_action(actions, is_highlight):
return json_encoder.encode(actions) return json_encoder.encode(actions)
def _deserialize_action(actions, is_highlight): def _deserialize_action(actions: str, is_highlight: bool) -> List[Union[dict, str]]:
"""Custom deserializer for actions. This allows us to "compress" common actions""" """Custom deserializer for actions. This allows us to "compress" common actions"""
if actions: if actions:
return db_to_json(actions) return db_to_json(actions)
@ -95,8 +129,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
# These get correctly set by _find_stream_orderings_for_times_txn # These get correctly set by _find_stream_orderings_for_times_txn
self.stream_ordering_month_ago = None self.stream_ordering_month_ago: Optional[int] = None
self.stream_ordering_day_ago = None self.stream_ordering_day_ago: Optional[int] = None
cur = db_conn.cursor(txn_name="_find_stream_orderings_for_times_txn") cur = db_conn.cursor(txn_name="_find_stream_orderings_for_times_txn")
self._find_stream_orderings_for_times_txn(cur) self._find_stream_orderings_for_times_txn(cur)
@ -120,7 +154,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
room_id: str, room_id: str,
user_id: str, user_id: str,
last_read_event_id: Optional[str], last_read_event_id: Optional[str],
) -> Dict[str, int]: ) -> NotifCounts:
"""Get the notification count, the highlight count and the unread message count """Get the notification count, the highlight count and the unread message count
for a given user in a given room after the given read receipt. for a given user in a given room after the given read receipt.
@ -149,15 +183,15 @@ class EventPushActionsWorkerStore(SQLBaseStore):
def _get_unread_counts_by_receipt_txn( def _get_unread_counts_by_receipt_txn(
self, self,
txn, txn: LoggingTransaction,
room_id, room_id: str,
user_id, user_id: str,
last_read_event_id, last_read_event_id: Optional[str],
): ) -> NotifCounts:
stream_ordering = None stream_ordering = None
if last_read_event_id is not None: if last_read_event_id is not None:
stream_ordering = self.get_stream_id_for_event_txn( stream_ordering = self.get_stream_id_for_event_txn( # type: ignore[attr-defined]
txn, txn,
last_read_event_id, last_read_event_id,
allow_none=True, allow_none=True,
@ -175,13 +209,15 @@ class EventPushActionsWorkerStore(SQLBaseStore):
retcol="event_id", retcol="event_id",
) )
stream_ordering = self.get_stream_id_for_event_txn(txn, event_id) stream_ordering = self.get_stream_id_for_event_txn(txn, event_id) # type: ignore[attr-defined]
return self._get_unread_counts_by_pos_txn( return self._get_unread_counts_by_pos_txn(
txn, room_id, user_id, stream_ordering txn, room_id, user_id, stream_ordering
) )
def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering): def _get_unread_counts_by_pos_txn(
self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
) -> NotifCounts:
sql = ( sql = (
"SELECT" "SELECT"
" COUNT(CASE WHEN notif = 1 THEN 1 END)," " COUNT(CASE WHEN notif = 1 THEN 1 END),"
@ -219,16 +255,16 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# for this row. # for this row.
unread_count += row[1] unread_count += row[1]
return { return NotifCounts(
"notify_count": notif_count, notify_count=notif_count,
"unread_count": unread_count, unread_count=unread_count,
"highlight_count": highlight_count, highlight_count=highlight_count,
} )
async def get_push_action_users_in_range( async def get_push_action_users_in_range(
self, min_stream_ordering, max_stream_ordering self, min_stream_ordering: int, max_stream_ordering: int
): ) -> List[str]:
def f(txn): def f(txn: LoggingTransaction) -> List[str]:
sql = ( sql = (
"SELECT DISTINCT(user_id) FROM event_push_actions WHERE" "SELECT DISTINCT(user_id) FROM event_push_actions WHERE"
" stream_ordering >= ? AND stream_ordering <= ? AND notif = 1" " stream_ordering >= ? AND stream_ordering <= ? AND notif = 1"
@ -236,8 +272,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (min_stream_ordering, max_stream_ordering)) txn.execute(sql, (min_stream_ordering, max_stream_ordering))
return [r[0] for r in txn] return [r[0] for r in txn]
ret = await self.db_pool.runInteraction("get_push_action_users_in_range", f) return await self.db_pool.runInteraction("get_push_action_users_in_range", f)
return ret
async def get_unread_push_actions_for_user_in_range_for_http( async def get_unread_push_actions_for_user_in_range_for_http(
self, self,
@ -263,7 +298,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
""" """
# find rooms that have a read receipt in them and return the next # find rooms that have a read receipt in them and return the next
# push actions # push actions
def get_after_receipt(txn): def get_after_receipt(
txn: LoggingTransaction,
) -> List[Tuple[str, str, int, str, bool]]:
# find rooms that have a read receipt in them and return the next # find rooms that have a read receipt in them and return the next
# push actions # push actions
sql = ( sql = (
@ -289,7 +326,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
) )
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args) txn.execute(sql, args)
return txn.fetchall() return txn.fetchall() # type: ignore[return-value]
after_read_receipt = await self.db_pool.runInteraction( after_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
@ -298,7 +335,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# There are rooms with push actions in them but you don't have a read receipt in # There are rooms with push actions in them but you don't have a read receipt in
# them e.g. rooms you've been invited to, so get push actions for rooms which do # them e.g. rooms you've been invited to, so get push actions for rooms which do
# not have read receipts in them too. # not have read receipts in them too.
def get_no_receipt(txn): def get_no_receipt(
txn: LoggingTransaction,
) -> List[Tuple[str, str, int, str, bool]]:
sql = ( sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
" ep.highlight " " ep.highlight "
@ -318,19 +357,19 @@ class EventPushActionsWorkerStore(SQLBaseStore):
) )
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args) txn.execute(sql, args)
return txn.fetchall() return txn.fetchall() # type: ignore[return-value]
no_read_receipt = await self.db_pool.runInteraction( no_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
) )
notifs = [ notifs = [
{ HttpPushAction(
"event_id": row[0], event_id=row[0],
"room_id": row[1], room_id=row[1],
"stream_ordering": row[2], stream_ordering=row[2],
"actions": _deserialize_action(row[3], row[4]), actions=_deserialize_action(row[3], row[4]),
} )
for row in after_read_receipt + no_read_receipt for row in after_read_receipt + no_read_receipt
] ]
@ -338,7 +377,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# contain results from the first query, correctly ordered, followed # contain results from the first query, correctly ordered, followed
# by results from the second query, but we want them all ordered # by results from the second query, but we want them all ordered
# by stream_ordering, oldest first. # by stream_ordering, oldest first.
notifs.sort(key=lambda r: r["stream_ordering"]) notifs.sort(key=lambda r: r.stream_ordering)
# Take only up to the limit. We have to stop at the limit because # Take only up to the limit. We have to stop at the limit because
# one of the subqueries may have hit the limit. # one of the subqueries may have hit the limit.
@ -368,7 +407,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
""" """
# find rooms that have a read receipt in them and return the most recent # find rooms that have a read receipt in them and return the most recent
# push actions # push actions
def get_after_receipt(txn): def get_after_receipt(
txn: LoggingTransaction,
) -> List[Tuple[str, str, int, str, bool, int]]:
sql = ( sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
" ep.highlight, e.received_ts" " ep.highlight, e.received_ts"
@ -393,7 +434,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
) )
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args) txn.execute(sql, args)
return txn.fetchall() return txn.fetchall() # type: ignore[return-value]
after_read_receipt = await self.db_pool.runInteraction( after_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
@ -402,7 +443,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# There are rooms with push actions in them but you don't have a read receipt in # There are rooms with push actions in them but you don't have a read receipt in
# them e.g. rooms you've been invited to, so get push actions for rooms which do # them e.g. rooms you've been invited to, so get push actions for rooms which do
# not have read receipts in them too. # not have read receipts in them too.
def get_no_receipt(txn): def get_no_receipt(
txn: LoggingTransaction,
) -> List[Tuple[str, str, int, str, bool, int]]:
sql = ( sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
" ep.highlight, e.received_ts" " ep.highlight, e.received_ts"
@ -422,7 +465,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
) )
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit] args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args) txn.execute(sql, args)
return txn.fetchall() return txn.fetchall() # type: ignore[return-value]
no_read_receipt = await self.db_pool.runInteraction( no_read_receipt = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
@ -430,13 +473,13 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# Make a list of dicts from the two sets of results. # Make a list of dicts from the two sets of results.
notifs = [ notifs = [
{ EmailPushAction(
"event_id": row[0], event_id=row[0],
"room_id": row[1], room_id=row[1],
"stream_ordering": row[2], stream_ordering=row[2],
"actions": _deserialize_action(row[3], row[4]), actions=_deserialize_action(row[3], row[4]),
"received_ts": row[5], received_ts=row[5],
} )
for row in after_read_receipt + no_read_receipt for row in after_read_receipt + no_read_receipt
] ]
@ -444,7 +487,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# contain results from the first query, correctly ordered, followed # contain results from the first query, correctly ordered, followed
# by results from the second query, but we want them all ordered # by results from the second query, but we want them all ordered
# by received_ts (most recent first) # by received_ts (most recent first)
notifs.sort(key=lambda r: -(r["received_ts"] or 0)) notifs.sort(key=lambda r: -(r.received_ts or 0))
# Now return the first `limit` # Now return the first `limit`
return notifs[:limit] return notifs[:limit]
@ -465,7 +508,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
True if there may be push to process, False if there definitely isn't. True if there may be push to process, False if there definitely isn't.
""" """
def _get_if_maybe_push_in_range_for_user_txn(txn): def _get_if_maybe_push_in_range_for_user_txn(txn: LoggingTransaction) -> bool:
sql = """ sql = """
SELECT 1 FROM event_push_actions SELECT 1 FROM event_push_actions
WHERE user_id = ? AND stream_ordering > ? AND notif = 1 WHERE user_id = ? AND stream_ordering > ? AND notif = 1
@ -499,19 +542,21 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# This is a helper function for generating the necessary tuple that # This is a helper function for generating the necessary tuple that
# can be used to insert into the `event_push_actions_staging` table. # can be used to insert into the `event_push_actions_staging` table.
def _gen_entry(user_id, actions): def _gen_entry(
user_id: str, actions: List[Union[dict, str]]
) -> Tuple[str, str, str, int, int, int]:
is_highlight = 1 if _action_has_highlight(actions) else 0 is_highlight = 1 if _action_has_highlight(actions) else 0
notif = 1 if "notify" in actions else 0 notif = 1 if "notify" in actions else 0
return ( return (
event_id, # event_id column event_id, # event_id column
user_id, # user_id column user_id, # user_id column
_serialize_action(actions, is_highlight), # actions column _serialize_action(actions, bool(is_highlight)), # actions column
notif, # notif column notif, # notif column
is_highlight, # highlight column is_highlight, # highlight column
int(count_as_unread), # unread column int(count_as_unread), # unread column
) )
def _add_push_actions_to_staging_txn(txn): def _add_push_actions_to_staging_txn(txn: LoggingTransaction) -> None:
# We don't use simple_insert_many here to avoid the overhead # We don't use simple_insert_many here to avoid the overhead
# of generating lists of dicts. # of generating lists of dicts.
@ -539,12 +584,11 @@ class EventPushActionsWorkerStore(SQLBaseStore):
""" """
try: try:
res = await self.db_pool.simple_delete( await self.db_pool.simple_delete(
table="event_push_actions_staging", table="event_push_actions_staging",
keyvalues={"event_id": event_id}, keyvalues={"event_id": event_id},
desc="remove_push_actions_from_staging", desc="remove_push_actions_from_staging",
) )
return res
except Exception: except Exception:
# this method is called from an exception handler, so propagating # this method is called from an exception handler, so propagating
# another exception here really isn't helpful - there's nothing # another exception here really isn't helpful - there's nothing
@ -597,7 +641,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
) )
@staticmethod @staticmethod
def _find_first_stream_ordering_after_ts_txn(txn, ts): def _find_first_stream_ordering_after_ts_txn(
txn: LoggingTransaction, ts: int
) -> int:
""" """
Find the stream_ordering of the first event that was received on or Find the stream_ordering of the first event that was received on or
after a given timestamp. This is relatively slow as there is no index after a given timestamp. This is relatively slow as there is no index
@ -609,14 +655,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
stream_ordering stream_ordering
Args: Args:
txn (twisted.enterprise.adbapi.Transaction): txn:
ts (int): timestamp to search for ts: timestamp to search for
Returns: Returns:
int: stream ordering The stream ordering
""" """
txn.execute("SELECT MAX(stream_ordering) FROM events") txn.execute("SELECT MAX(stream_ordering) FROM events")
max_stream_ordering = txn.fetchone()[0] max_stream_ordering = txn.fetchone()[0] # type: ignore[index]
if max_stream_ordering is None: if max_stream_ordering is None:
return 0 return 0
@ -672,8 +718,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
return range_end return range_end
async def get_time_of_last_push_action_before(self, stream_ordering): async def get_time_of_last_push_action_before(
def f(txn): self, stream_ordering: int
) -> Optional[int]:
def f(txn: LoggingTransaction) -> Optional[Tuple[int]]:
sql = ( sql = (
"SELECT e.received_ts" "SELECT e.received_ts"
" FROM event_push_actions AS ep" " FROM event_push_actions AS ep"
@ -683,7 +731,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" LIMIT 1" " LIMIT 1"
) )
txn.execute(sql, (stream_ordering,)) txn.execute(sql, (stream_ordering,))
return txn.fetchone() return txn.fetchone() # type: ignore[return-value]
result = await self.db_pool.runInteraction( result = await self.db_pool.runInteraction(
"get_time_of_last_push_action_before", f "get_time_of_last_push_action_before", f
@ -691,7 +739,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
return result[0] if result else None return result[0] if result else None
@wrap_as_background_process("rotate_notifs") @wrap_as_background_process("rotate_notifs")
async def _rotate_notifs(self): async def _rotate_notifs(self) -> None:
if self._doing_notif_rotation or self.stream_ordering_day_ago is None: if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
return return
self._doing_notif_rotation = True self._doing_notif_rotation = True
@ -709,7 +757,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
finally: finally:
self._doing_notif_rotation = False self._doing_notif_rotation = False
def _rotate_notifs_txn(self, txn): def _rotate_notifs_txn(self, txn: LoggingTransaction) -> bool:
"""Archives older notifications into event_push_summary. Returns whether """Archives older notifications into event_push_summary. Returns whether
the archiving process has caught up or not. the archiving process has caught up or not.
""" """
@ -734,6 +782,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
stream_row = txn.fetchone() stream_row = txn.fetchone()
if stream_row: if stream_row:
(offset_stream_ordering,) = stream_row (offset_stream_ordering,) = stream_row
assert self.stream_ordering_day_ago is not None
rotate_to_stream_ordering = min( rotate_to_stream_ordering = min(
self.stream_ordering_day_ago, offset_stream_ordering self.stream_ordering_day_ago, offset_stream_ordering
) )
@ -749,7 +798,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# We have caught up iff we were limited by `stream_ordering_day_ago` # We have caught up iff we were limited by `stream_ordering_day_ago`
return caught_up return caught_up
def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): def _rotate_notifs_before_txn(
self, txn: LoggingTransaction, rotate_to_stream_ordering: int
) -> None:
old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn( old_rotate_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
txn, txn,
table="event_push_summary_stream_ordering", table="event_push_summary_stream_ordering",
@ -870,8 +921,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
) )
def _remove_old_push_actions_before_txn( def _remove_old_push_actions_before_txn(
self, txn, room_id, user_id, stream_ordering self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
): ) -> None:
""" """
Purges old push actions for a user and room before a given Purges old push actions for a user and room before a given
stream_ordering. stream_ordering.
@ -943,9 +994,15 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
) )
async def get_push_actions_for_user( async def get_push_actions_for_user(
self, user_id, before=None, limit=50, only_highlight=False self,
): user_id: str,
def f(txn): before: Optional[str] = None,
limit: int = 50,
only_highlight: bool = False,
) -> List[UserPushAction]:
def f(
txn: LoggingTransaction,
) -> List[Tuple[str, str, int, int, str, bool, str, int]]:
before_clause = "" before_clause = ""
if before: if before:
before_clause = "AND epa.stream_ordering < ?" before_clause = "AND epa.stream_ordering < ?"
@ -972,32 +1029,42 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
" LIMIT ?" % (before_clause,) " LIMIT ?" % (before_clause,)
) )
txn.execute(sql, args) txn.execute(sql, args)
return self.db_pool.cursor_to_dict(txn) return txn.fetchall() # type: ignore[return-value]
push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f) push_actions = await self.db_pool.runInteraction("get_push_actions_for_user", f)
for pa in push_actions: return [
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) UserPushAction(
return push_actions event_id=row[0],
room_id=row[1],
stream_ordering=row[2],
actions=_deserialize_action(row[4], row[5]),
received_ts=row[7],
topological_ordering=row[3],
highlight=row[5],
profile_tag=row[6],
)
for row in push_actions
]
def _action_has_highlight(actions): def _action_has_highlight(actions: List[Union[dict, str]]) -> bool:
for action in actions: for action in actions:
try: if not isinstance(action, dict):
if action.get("set_tweak", None) == "highlight": continue
return action.get("value", True)
except AttributeError: if action.get("set_tweak", None) == "highlight":
pass return action.get("value", True)
return False return False
@attr.s(slots=True) @attr.s(slots=True, auto_attribs=True)
class _EventPushSummary: class _EventPushSummary:
"""Summary of pending event push actions for a given user in a given room. """Summary of pending event push actions for a given user in a given room.
Used in _rotate_notifs_before_txn to manipulate results from event_push_actions. Used in _rotate_notifs_before_txn to manipulate results from event_push_actions.
""" """
unread_count = attr.ib(type=int) unread_count: int
stream_ordering = attr.ib(type=int) stream_ordering: int
old_user_id = attr.ib(type=str) old_user_id: str
notif_count = attr.ib(type=int) notif_count: int

View File

@ -20,6 +20,7 @@ from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
from synapse.handlers.room import RoomEventSource from synapse.handlers.room import RoomEventSource
from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.storage.databases.main.event_push_actions import NotifCounts
from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
from synapse.types import PersistedEventPosition from synapse.types import PersistedEventPosition
@ -166,7 +167,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check( self.check(
"get_unread_event_push_actions_by_room_for_user", "get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id], [ROOM_ID, USER_ID_2, event1.event_id],
{"highlight_count": 0, "unread_count": 0, "notify_count": 0}, NotifCounts(highlight_count=0, unread_count=0, notify_count=0),
) )
self.persist( self.persist(
@ -179,7 +180,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check( self.check(
"get_unread_event_push_actions_by_room_for_user", "get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id], [ROOM_ID, USER_ID_2, event1.event_id],
{"highlight_count": 0, "unread_count": 0, "notify_count": 1}, NotifCounts(highlight_count=0, unread_count=0, notify_count=1),
) )
self.persist( self.persist(
@ -194,7 +195,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check( self.check(
"get_unread_event_push_actions_by_room_for_user", "get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id], [ROOM_ID, USER_ID_2, event1.event_id],
{"highlight_count": 1, "unread_count": 0, "notify_count": 2}, NotifCounts(highlight_count=1, unread_count=0, notify_count=2),
) )
def test_get_rooms_for_user_with_stream_ordering(self): def test_get_rooms_for_user_with_stream_ordering(self):

View File

@ -14,6 +14,8 @@
from unittest.mock import Mock from unittest.mock import Mock
from synapse.storage.databases.main.event_push_actions import NotifCounts
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
USER_ID = "@user:example.com" USER_ID = "@user:example.com"
@ -57,11 +59,11 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
) )
self.assertEquals( self.assertEquals(
counts, counts,
{ NotifCounts(
"notify_count": noitf_count, notify_count=noitf_count,
"unread_count": 0, # Unread counts are tested in the sync tests. unread_count=0, # Unread counts are tested in the sync tests.
"highlight_count": highlight_count, highlight_count=highlight_count,
}, ),
) )
def _inject_actions(stream, action): def _inject_actions(stream, action):