Convert internal pusher dicts to attrs classes. (#8940)

This improves type hinting and should use less memory.
This commit is contained in:
Patrick Cloke 2020-12-16 11:25:30 -05:00 committed by GitHub
parent 7a332850e6
commit bd30cfe86a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 266 additions and 204 deletions

1
changelog.d/8940.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to push module.

View File

@ -65,6 +65,7 @@ files =
synapse/state, synapse/state,
synapse/storage/databases/main/appservice.py, synapse/storage/databases/main/appservice.py,
synapse/storage/databases/main/events.py, synapse/storage/databases/main/events.py,
synapse/storage/databases/main/pusher.py,
synapse/storage/databases/main/registration.py, synapse/storage/databases/main/registration.py,
synapse/storage/databases/main/stream.py, synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py, synapse/storage/databases/main/ui_auth.py,

View File

@ -14,24 +14,70 @@
# limitations under the License. # limitations under the License.
import abc import abc
from typing import TYPE_CHECKING, Any, Dict from typing import TYPE_CHECKING, Any, Dict, Optional
from synapse.types import RoomStreamToken import attr
from synapse.types import JsonDict, RoomStreamToken
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer from synapse.app.homeserver import HomeServer
@attr.s(slots=True)
class PusherConfig:
"""Parameters necessary to configure a pusher."""
id = attr.ib(type=Optional[str])
user_name = attr.ib(type=str)
access_token = attr.ib(type=Optional[int])
profile_tag = attr.ib(type=str)
kind = attr.ib(type=str)
app_id = attr.ib(type=str)
app_display_name = attr.ib(type=str)
device_display_name = attr.ib(type=str)
pushkey = attr.ib(type=str)
ts = attr.ib(type=int)
lang = attr.ib(type=Optional[str])
data = attr.ib(type=Optional[JsonDict])
last_stream_ordering = attr.ib(type=Optional[int])
last_success = attr.ib(type=Optional[int])
failing_since = attr.ib(type=Optional[int])
def as_dict(self) -> Dict[str, Any]:
"""Information that can be retrieved about a pusher after creation."""
return {
"app_display_name": self.app_display_name,
"app_id": self.app_id,
"data": self.data,
"device_display_name": self.device_display_name,
"kind": self.kind,
"lang": self.lang,
"profile_tag": self.profile_tag,
"pushkey": self.pushkey,
}
@attr.s(slots=True)
class ThrottleParams:
"""Parameters for controlling the rate of sending pushes via email."""
last_sent_ts = attr.ib(type=int)
throttle_ms = attr.ib(type=int)
class Pusher(metaclass=abc.ABCMeta): class Pusher(metaclass=abc.ABCMeta):
def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any]): def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
self.hs = hs self.hs = hs
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
self.pusher_id = pusherdict["id"] self.pusher_id = pusher_config.id
self.user_id = pusherdict["user_name"] self.user_id = pusher_config.user_name
self.app_id = pusherdict["app_id"] self.app_id = pusher_config.app_id
self.pushkey = pusherdict["pushkey"] self.pushkey = pusher_config.pushkey
self.last_stream_ordering = pusher_config.last_stream_ordering
# This is the highest stream ordering we know it's safe to process. # This is the highest stream ordering we know it's safe to process.
# When new events arrive, we'll be given a window of new events: we # When new events arrive, we'll be given a window of new events: we

View File

@ -14,13 +14,13 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Dict, List, Optional
from twisted.internet.base import DelayedCall from twisted.internet.base import DelayedCall
from twisted.internet.error import AlreadyCalled, AlreadyCancelled from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import Pusher from synapse.push import Pusher, PusherConfig, ThrottleParams
from synapse.push.mailer import Mailer from synapse.push.mailer import Mailer
if TYPE_CHECKING: if TYPE_CHECKING:
@ -60,15 +60,14 @@ class EmailPusher(Pusher):
factor out the common parts factor out the common parts
""" """
def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any], mailer: Mailer): def __init__(self, hs: "HomeServer", pusher_config: PusherConfig, mailer: Mailer):
super().__init__(hs, pusherdict) super().__init__(hs, pusher_config)
self.mailer = mailer self.mailer = mailer
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.email = pusherdict["pushkey"] self.email = pusher_config.pushkey
self.last_stream_ordering = pusherdict["last_stream_ordering"]
self.timed_call = None # type: Optional[DelayedCall] self.timed_call = None # type: Optional[DelayedCall]
self.throttle_params = {} # type: Dict[str, Dict[str, int]] self.throttle_params = {} # type: Dict[str, ThrottleParams]
self._inited = False self._inited = False
self._is_processing = False self._is_processing = False
@ -132,6 +131,7 @@ class EmailPusher(Pusher):
if not self._inited: if not self._inited:
# this is our first loop: load up the throttle params # this is our first loop: load up the throttle params
assert self.pusher_id is not None
self.throttle_params = await self.store.get_throttle_params_by_room( self.throttle_params = await self.store.get_throttle_params_by_room(
self.pusher_id self.pusher_id
) )
@ -157,6 +157,7 @@ class EmailPusher(Pusher):
being run. being run.
""" """
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
assert start is not None
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email( unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email(
self.user_id, start, self.max_stream_ordering self.user_id, start, self.max_stream_ordering
) )
@ -244,13 +245,13 @@ class EmailPusher(Pusher):
def get_room_throttle_ms(self, room_id: str) -> int: def get_room_throttle_ms(self, room_id: str) -> int:
if room_id in self.throttle_params: if room_id in self.throttle_params:
return self.throttle_params[room_id]["throttle_ms"] return self.throttle_params[room_id].throttle_ms
else: else:
return 0 return 0
def get_room_last_sent_ts(self, room_id: str) -> int: def get_room_last_sent_ts(self, room_id: str) -> int:
if room_id in self.throttle_params: if room_id in self.throttle_params:
return self.throttle_params[room_id]["last_sent_ts"] return self.throttle_params[room_id].last_sent_ts
else: else:
return 0 return 0
@ -301,10 +302,10 @@ class EmailPusher(Pusher):
new_throttle_ms = min( new_throttle_ms = min(
current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS
) )
self.throttle_params[room_id] = { self.throttle_params[room_id] = ThrottleParams(
"last_sent_ts": self.clock.time_msec(), self.clock.time_msec(), new_throttle_ms,
"throttle_ms": new_throttle_ms, )
} assert self.pusher_id is not None
await self.store.set_throttle_params( await self.store.set_throttle_params(
self.pusher_id, room_id, self.throttle_params[room_id] self.pusher_id, room_id, self.throttle_params[room_id]
) )

View File

@ -25,7 +25,7 @@ from synapse.api.constants import EventTypes
from synapse.events import EventBase from synapse.events import EventBase
from synapse.logging import opentracing from synapse.logging import opentracing
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import Pusher, PusherConfigException from synapse.push import Pusher, PusherConfig, PusherConfigException
from . import push_rule_evaluator, push_tools from . import push_rule_evaluator, push_tools
@ -62,33 +62,29 @@ class HttpPusher(Pusher):
# This one's in ms because we compare it against the clock # This one's in ms because we compare it against the clock
GIVE_UP_AFTER_MS = 24 * 60 * 60 * 1000 GIVE_UP_AFTER_MS = 24 * 60 * 60 * 1000
def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any]): def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
super().__init__(hs, pusherdict) super().__init__(hs, pusher_config)
self.storage = self.hs.get_storage() self.storage = self.hs.get_storage()
self.app_display_name = pusherdict["app_display_name"] self.app_display_name = pusher_config.app_display_name
self.device_display_name = pusherdict["device_display_name"] self.device_display_name = pusher_config.device_display_name
self.pushkey_ts = pusherdict["ts"] self.pushkey_ts = pusher_config.ts
self.data = pusherdict["data"] self.data = pusher_config.data
self.last_stream_ordering = pusherdict["last_stream_ordering"]
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.failing_since = pusherdict["failing_since"] self.failing_since = pusher_config.failing_since
self.timed_call = None self.timed_call = None
self._is_processing = False self._is_processing = False
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
if "data" not in pusherdict: self.data = pusher_config.data
raise PusherConfigException("No 'data' key for HTTP pusher") if self.data is None:
self.data = pusherdict["data"] raise PusherConfigException("'data' key can not be null for HTTP pusher")
self.name = "%s/%s/%s" % ( self.name = "%s/%s/%s" % (
pusherdict["user_name"], pusher_config.user_name,
pusherdict["app_id"], pusher_config.app_id,
pusherdict["pushkey"], pusher_config.pushkey,
) )
if self.data is None:
raise PusherConfigException("data can not be null for HTTP pusher")
# Validate that there's a URL and it is of the proper form. # Validate that there's a URL and it is of the proper form.
if "url" not in self.data: if "url" not in self.data:
raise PusherConfigException("'url' required in data for HTTP pusher") raise PusherConfigException("'url' required in data for HTTP pusher")
@ -180,6 +176,7 @@ class HttpPusher(Pusher):
Never call this directly: use _process which will only allow this to Never call this directly: use _process which will only allow this to
run once per pusher. run once per pusher.
""" """
assert self.last_stream_ordering is not None
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http( unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http(
self.user_id, self.last_stream_ordering, self.max_stream_ordering self.user_id, self.last_stream_ordering, self.max_stream_ordering
) )
@ -208,6 +205,7 @@ class HttpPusher(Pusher):
http_push_processed_counter.inc() http_push_processed_counter.inc()
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"] self.last_stream_ordering = push_action["stream_ordering"]
assert self.last_stream_ordering is not None
pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success( pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
@ -314,6 +312,8 @@ class HttpPusher(Pusher):
# or may do so (i.e. is encrypted so has unknown effects). # or may do so (i.e. is encrypted so has unknown effects).
priority = "high" priority = "high"
# This was checked in the __init__, but mypy doesn't seem to know that.
assert self.data is not None
if self.data.get("format") == "event_id_only": if self.data.get("format") == "event_id_only":
d = { d = {
"notification": { "notification": {

View File

@ -14,9 +14,9 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional from typing import TYPE_CHECKING, Callable, Dict, Optional
from synapse.push import Pusher from synapse.push import Pusher, PusherConfig
from synapse.push.emailpusher import EmailPusher from synapse.push.emailpusher import EmailPusher
from synapse.push.httppusher import HttpPusher from synapse.push.httppusher import HttpPusher
from synapse.push.mailer import Mailer from synapse.push.mailer import Mailer
@ -34,7 +34,7 @@ class PusherFactory:
self.pusher_types = { self.pusher_types = {
"http": HttpPusher "http": HttpPusher
} # type: Dict[str, Callable[[HomeServer, dict], Pusher]] } # type: Dict[str, Callable[[HomeServer, PusherConfig], Pusher]]
logger.info("email enable notifs: %r", hs.config.email_enable_notifs) logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
if hs.config.email_enable_notifs: if hs.config.email_enable_notifs:
@ -47,18 +47,18 @@ class PusherFactory:
logger.info("defined email pusher type") logger.info("defined email pusher type")
def create_pusher(self, pusherdict: Dict[str, Any]) -> Optional[Pusher]: def create_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]:
kind = pusherdict["kind"] kind = pusher_config.kind
f = self.pusher_types.get(kind, None) f = self.pusher_types.get(kind, None)
if not f: if not f:
return None return None
logger.debug("creating %s pusher for %r", kind, pusherdict) logger.debug("creating %s pusher for %r", kind, pusher_config)
return f(self.hs, pusherdict) return f(self.hs, pusher_config)
def _create_email_pusher( def _create_email_pusher(
self, _hs: "HomeServer", pusherdict: Dict[str, Any] self, _hs: "HomeServer", pusher_config: PusherConfig
) -> EmailPusher: ) -> EmailPusher:
app_name = self._app_name_from_pusherdict(pusherdict) app_name = self._app_name_from_pusherdict(pusher_config)
mailer = self.mailers.get(app_name) mailer = self.mailers.get(app_name)
if not mailer: if not mailer:
mailer = Mailer( mailer = Mailer(
@ -68,10 +68,10 @@ class PusherFactory:
template_text=self._notif_template_text, template_text=self._notif_template_text,
) )
self.mailers[app_name] = mailer self.mailers[app_name] = mailer
return EmailPusher(self.hs, pusherdict, mailer) return EmailPusher(self.hs, pusher_config, mailer)
def _app_name_from_pusherdict(self, pusherdict: Dict[str, Any]) -> str: def _app_name_from_pusherdict(self, pusher_config: PusherConfig) -> str:
data = pusherdict["data"] data = pusher_config.data
if isinstance(data, dict): if isinstance(data, dict):
brand = data.get("brand") brand = data.get("brand")

View File

@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Any, Dict, Optional from typing import TYPE_CHECKING, Dict, Iterable, Optional
from prometheus_client import Gauge from prometheus_client import Gauge
@ -23,9 +23,9 @@ from synapse.metrics.background_process_metrics import (
run_as_background_process, run_as_background_process,
wrap_as_background_process, wrap_as_background_process,
) )
from synapse.push import Pusher, PusherConfigException from synapse.push import Pusher, PusherConfig, PusherConfigException
from synapse.push.pusher import PusherFactory from synapse.push.pusher import PusherFactory
from synapse.types import RoomStreamToken from synapse.types import JsonDict, RoomStreamToken
from synapse.util.async_helpers import concurrently_execute from synapse.util.async_helpers import concurrently_execute
if TYPE_CHECKING: if TYPE_CHECKING:
@ -77,7 +77,7 @@ class PusherPool:
# map from user id to app_id:pushkey to pusher # map from user id to app_id:pushkey to pusher
self.pushers = {} # type: Dict[str, Dict[str, Pusher]] self.pushers = {} # type: Dict[str, Dict[str, Pusher]]
def start(self): def start(self) -> None:
"""Starts the pushers off in a background process. """Starts the pushers off in a background process.
""" """
if not self._should_start_pushers: if not self._should_start_pushers:
@ -87,16 +87,16 @@ class PusherPool:
async def add_pusher( async def add_pusher(
self, self,
user_id, user_id: str,
access_token, access_token: Optional[int],
kind, kind: str,
app_id, app_id: str,
app_display_name, app_display_name: str,
device_display_name, device_display_name: str,
pushkey, pushkey: str,
lang, lang: Optional[str],
data, data: JsonDict,
profile_tag="", profile_tag: str = "",
) -> Optional[Pusher]: ) -> Optional[Pusher]:
"""Creates a new pusher and adds it to the pool """Creates a new pusher and adds it to the pool
@ -111,21 +111,23 @@ class PusherPool:
# recreated, added and started: this means we have only one # recreated, added and started: this means we have only one
# code path adding pushers. # code path adding pushers.
self.pusher_factory.create_pusher( self.pusher_factory.create_pusher(
{ PusherConfig(
"id": None, id=None,
"user_name": user_id, user_name=user_id,
"kind": kind, access_token=access_token,
"app_id": app_id, profile_tag=profile_tag,
"app_display_name": app_display_name, kind=kind,
"device_display_name": device_display_name, app_id=app_id,
"pushkey": pushkey, app_display_name=app_display_name,
"ts": time_now_msec, device_display_name=device_display_name,
"lang": lang, pushkey=pushkey,
"data": data, ts=time_now_msec,
"last_stream_ordering": None, lang=lang,
"last_success": None, data=data,
"failing_since": None, last_stream_ordering=None,
} last_success=None,
failing_since=None,
)
) )
# create the pusher setting last_stream_ordering to the current maximum # create the pusher setting last_stream_ordering to the current maximum
@ -151,43 +153,44 @@ class PusherPool:
return pusher return pusher
async def remove_pushers_by_app_id_and_pushkey_not_user( async def remove_pushers_by_app_id_and_pushkey_not_user(
self, app_id, pushkey, not_user_id self, app_id: str, pushkey: str, not_user_id: str
): ) -> None:
to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
for p in to_remove: for p in to_remove:
if p["user_name"] != not_user_id: if p.user_name != not_user_id:
logger.info( logger.info(
"Removing pusher for app id %s, pushkey %s, user %s", "Removing pusher for app id %s, pushkey %s, user %s",
app_id, app_id,
pushkey, pushkey,
p["user_name"], p.user_name,
) )
await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) await self.remove_pusher(p.app_id, p.pushkey, p.user_name)
async def remove_pushers_by_access_token(self, user_id, access_tokens): async def remove_pushers_by_access_token(
self, user_id: str, access_tokens: Iterable[int]
) -> None:
"""Remove the pushers for a given user corresponding to a set of """Remove the pushers for a given user corresponding to a set of
access_tokens. access_tokens.
Args: Args:
user_id (str): user to remove pushers for user_id: user to remove pushers for
access_tokens (Iterable[int]): access token *ids* to remove pushers access_tokens: access token *ids* to remove pushers for
for
""" """
if not self._pusher_shard_config.should_handle(self._instance_name, user_id): if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
return return
tokens = set(access_tokens) tokens = set(access_tokens)
for p in await self.store.get_pushers_by_user_id(user_id): for p in await self.store.get_pushers_by_user_id(user_id):
if p["access_token"] in tokens: if p.access_token in tokens:
logger.info( logger.info(
"Removing pusher for app id %s, pushkey %s, user %s", "Removing pusher for app id %s, pushkey %s, user %s",
p["app_id"], p.app_id,
p["pushkey"], p.pushkey,
p["user_name"], p.user_name,
) )
await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) await self.remove_pusher(p.app_id, p.pushkey, p.user_name)
def on_new_notifications(self, max_token: RoomStreamToken): def on_new_notifications(self, max_token: RoomStreamToken) -> None:
if not self.pushers: if not self.pushers:
# nothing to do here. # nothing to do here.
return return
@ -206,7 +209,7 @@ class PusherPool:
self._on_new_notifications(max_token) self._on_new_notifications(max_token)
@wrap_as_background_process("on_new_notifications") @wrap_as_background_process("on_new_notifications")
async def _on_new_notifications(self, max_token: RoomStreamToken): async def _on_new_notifications(self, max_token: RoomStreamToken) -> None:
# We just use the minimum stream ordering and ignore the vector clock # We just use the minimum stream ordering and ignore the vector clock
# component. This is safe to do as long as we *always* ignore the vector # component. This is safe to do as long as we *always* ignore the vector
# clock components. # clock components.
@ -236,7 +239,9 @@ class PusherPool:
except Exception: except Exception:
logger.exception("Exception in pusher on_new_notifications") logger.exception("Exception in pusher on_new_notifications")
async def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids): async def on_new_receipts(
self, min_stream_id: int, max_stream_id: int, affected_room_ids: Iterable[str]
) -> None:
if not self.pushers: if not self.pushers:
# nothing to do here. # nothing to do here.
return return
@ -280,14 +285,14 @@ class PusherPool:
resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
pusher_dict = None pusher_config = None
for r in resultlist: for r in resultlist:
if r["user_name"] == user_id: if r.user_name == user_id:
pusher_dict = r pusher_config = r
pusher = None pusher = None
if pusher_dict: if pusher_config:
pusher = await self._start_pusher(pusher_dict) pusher = await self._start_pusher(pusher_config)
return pusher return pusher
@ -302,44 +307,44 @@ class PusherPool:
logger.info("Started pushers") logger.info("Started pushers")
async def _start_pusher(self, pusherdict: Dict[str, Any]) -> Optional[Pusher]: async def _start_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]:
"""Start the given pusher """Start the given pusher
Args: Args:
pusherdict: dict with the values pulled from the db table pusher_config: The pusher configuration with the values pulled from the db table
Returns: Returns:
The newly created pusher or None. The newly created pusher or None.
""" """
if not self._pusher_shard_config.should_handle( if not self._pusher_shard_config.should_handle(
self._instance_name, pusherdict["user_name"] self._instance_name, pusher_config.user_name
): ):
return None return None
try: try:
p = self.pusher_factory.create_pusher(pusherdict) p = self.pusher_factory.create_pusher(pusher_config)
except PusherConfigException as e: except PusherConfigException as e:
logger.warning( logger.warning(
"Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s", "Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s",
pusherdict["id"], pusher_config.id,
pusherdict.get("user_name"), pusher_config.user_name,
pusherdict.get("app_id"), pusher_config.app_id,
pusherdict.get("pushkey"), pusher_config.pushkey,
e, e,
) )
return None return None
except Exception: except Exception:
logger.exception( logger.exception(
"Couldn't start pusher id %i: caught Exception", pusherdict["id"], "Couldn't start pusher id %i: caught Exception", pusher_config.id,
) )
return None return None
if not p: if not p:
return None return None
appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"]) appid_pushkey = "%s:%s" % (pusher_config.app_id, pusher_config.pushkey)
byuser = self.pushers.setdefault(pusherdict["user_name"], {}) byuser = self.pushers.setdefault(pusher_config.user_name, {})
if appid_pushkey in byuser: if appid_pushkey in byuser:
byuser[appid_pushkey].on_stop() byuser[appid_pushkey].on_stop()
byuser[appid_pushkey] = p byuser[appid_pushkey] = p
@ -349,8 +354,8 @@ class PusherPool:
# Check if there *may* be push to process. We do this as this check is a # Check if there *may* be push to process. We do this as this check is a
# lot cheaper to do than actually fetching the exact rows we need to # lot cheaper to do than actually fetching the exact rows we need to
# push. # push.
user_id = pusherdict["user_name"] user_id = pusher_config.user_name
last_stream_ordering = pusherdict["last_stream_ordering"] last_stream_ordering = pusher_config.last_stream_ordering
if last_stream_ordering: if last_stream_ordering:
have_notifs = await self.store.get_if_maybe_push_in_range_for_user( have_notifs = await self.store.get_if_maybe_push_in_range_for_user(
user_id, last_stream_ordering user_id, last_stream_ordering
@ -364,7 +369,7 @@ class PusherPool:
return p return p
async def remove_pusher(self, app_id, pushkey, user_id): async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
appid_pushkey = "%s:%s" % (app_id, pushkey) appid_pushkey = "%s:%s" % (app_id, pushkey)
byuser = self.pushers.get(user_id, {}) byuser = self.pushers.get(user_id, {})

View File

@ -12,21 +12,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, Optional, Tuple
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import _load_current_id from synapse.storage.util.id_generators import _load_current_id
class SlavedIdTracker: class SlavedIdTracker:
def __init__(self, db_conn, table, column, extra_tables=[], step=1): def __init__(
self,
db_conn: Connection,
table: str,
column: str,
extra_tables: Optional[List[Tuple[str, str]]] = None,
step: int = 1,
):
self.step = step self.step = step
self._current = _load_current_id(db_conn, table, column, step) self._current = _load_current_id(db_conn, table, column, step)
for table, column in extra_tables: if extra_tables:
self.advance(None, _load_current_id(db_conn, table, column)) for table, column in extra_tables:
self.advance(None, _load_current_id(db_conn, table, column))
def advance(self, instance_name, new_id): def advance(self, instance_name: Optional[str], new_id: int):
self._current = (max if self.step > 0 else min)(self._current, new_id) self._current = (max if self.step > 0 else min)(self._current, new_id)
def get_current_token(self): def get_current_token(self) -> int:
""" """
Returns: Returns:

View File

@ -13,26 +13,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING
from synapse.replication.tcp.streams import PushersStream from synapse.replication.tcp.streams import PushersStream
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.pusher import PusherWorkerStore from synapse.storage.databases.main.pusher import PusherWorkerStore
from synapse.storage.types import Connection
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self._pushers_id_gen = SlavedIdTracker( self._pushers_id_gen = SlavedIdTracker( # type: ignore
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
) )
def get_pushers_stream_token(self): def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token() return self._pushers_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows): def process_replication_rows(
self, stream_name: str, instance_name: str, token, rows
) -> None:
if stream_name == PushersStream.NAME: if stream_name == PushersStream.NAME:
self._pushers_id_gen.advance(instance_name, token) self._pushers_id_gen.advance(instance_name, token) # type: ignore
return super().process_replication_rows(stream_name, instance_name, token, rows) return super().process_replication_rows(stream_name, instance_name, token, rows)

View File

@ -42,17 +42,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_GET_PUSHERS_ALLOWED_KEYS = {
"app_display_name",
"app_id",
"data",
"device_display_name",
"kind",
"lang",
"profile_tag",
"pushkey",
}
class UsersRestServlet(RestServlet): class UsersRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$") PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
@ -770,10 +759,7 @@ class PushersRestServlet(RestServlet):
pushers = await self.store.get_pushers_by_user_id(user_id) pushers = await self.store.get_pushers_by_user_id(user_id)
filtered_pushers = [ filtered_pushers = [p.as_dict() for p in pushers]
{k: v for k, v in p.items() if k in _GET_PUSHERS_ALLOWED_KEYS}
for p in pushers
]
return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)} return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)}

View File

@ -28,17 +28,6 @@ from synapse.rest.client.v2_alpha._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ALLOWED_KEYS = {
"app_display_name",
"app_id",
"data",
"device_display_name",
"kind",
"lang",
"profile_tag",
"pushkey",
}
class PushersRestServlet(RestServlet): class PushersRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers$", v1=True) PATTERNS = client_patterns("/pushers$", v1=True)
@ -54,9 +43,7 @@ class PushersRestServlet(RestServlet):
pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
filtered_pushers = [ filtered_pushers = [p.as_dict() for p in pushers]
{k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers
]
return 200, {"pushers": filtered_pushers} return 200, {"pushers": filtered_pushers}

View File

@ -149,9 +149,6 @@ class DataStore(
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
self._pushers_id_gen = StreamIdGenerator(
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
self._group_updates_id_gen = StreamIdGenerator( self._group_updates_id_gen = StreamIdGenerator(
db_conn, "local_group_updates", "stream_id" db_conn, "local_group_updates", "stream_id"
) )

View File

@ -15,18 +15,32 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Iterable, Iterator, List, Tuple from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from synapse.push import PusherConfig, ThrottleParams
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PusherWorkerStore(SQLBaseStore): class PusherWorkerStore(SQLBaseStore):
def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]: def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self._pushers_id_gen = StreamIdGenerator(
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
"""JSON-decode the data in the rows returned from the `pushers` table """JSON-decode the data in the rows returned from the `pushers` table
Drops any rows whose data cannot be decoded Drops any rows whose data cannot be decoded
@ -44,21 +58,23 @@ class PusherWorkerStore(SQLBaseStore):
) )
continue continue
yield r yield PusherConfig(**r)
async def user_has_pusher(self, user_id): async def user_has_pusher(self, user_id: str) -> bool:
ret = await self.db_pool.simple_select_one_onecol( ret = await self.db_pool.simple_select_one_onecol(
"pushers", {"user_name": user_id}, "id", allow_none=True "pushers", {"user_name": user_id}, "id", allow_none=True
) )
return ret is not None return ret is not None
def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey): async def get_pushers_by_app_id_and_pushkey(
return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey}) self, app_id: str, pushkey: str
) -> Iterator[PusherConfig]:
return await self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
def get_pushers_by_user_id(self, user_id): async def get_pushers_by_user_id(self, user_id: str) -> Iterator[PusherConfig]:
return self.get_pushers_by({"user_name": user_id}) return await self.get_pushers_by({"user_name": user_id})
async def get_pushers_by(self, keyvalues): async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]:
ret = await self.db_pool.simple_select_list( ret = await self.db_pool.simple_select_list(
"pushers", "pushers",
keyvalues, keyvalues,
@ -83,7 +99,7 @@ class PusherWorkerStore(SQLBaseStore):
) )
return self._decode_pushers_rows(ret) return self._decode_pushers_rows(ret)
async def get_all_pushers(self): async def get_all_pushers(self) -> Iterator[PusherConfig]:
def get_pushers(txn): def get_pushers(txn):
txn.execute("SELECT * FROM pushers") txn.execute("SELECT * FROM pushers")
rows = self.db_pool.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
@ -159,14 +175,16 @@ class PusherWorkerStore(SQLBaseStore):
) )
@cached(num_args=1, max_entries=15000) @cached(num_args=1, max_entries=15000)
async def get_if_user_has_pusher(self, user_id): async def get_if_user_has_pusher(self, user_id: str):
# This only exists for the cachedList decorator # This only exists for the cachedList decorator
raise NotImplementedError() raise NotImplementedError()
@cachedList( @cachedList(
cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1, cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
) )
async def get_if_users_have_pushers(self, user_ids): async def get_if_users_have_pushers(
self, user_ids: Iterable[str]
) -> Dict[str, bool]:
rows = await self.db_pool.simple_select_many_batch( rows = await self.db_pool.simple_select_many_batch(
table="pushers", table="pushers",
column="user_name", column="user_name",
@ -224,7 +242,7 @@ class PusherWorkerStore(SQLBaseStore):
return bool(updated) return bool(updated)
async def update_pusher_failing_since( async def update_pusher_failing_since(
self, app_id, pushkey, user_id, failing_since self, app_id: str, pushkey: str, user_id: str, failing_since: Optional[int]
) -> None: ) -> None:
await self.db_pool.simple_update( await self.db_pool.simple_update(
table="pushers", table="pushers",
@ -233,7 +251,9 @@ class PusherWorkerStore(SQLBaseStore):
desc="update_pusher_failing_since", desc="update_pusher_failing_since",
) )
async def get_throttle_params_by_room(self, pusher_id): async def get_throttle_params_by_room(
self, pusher_id: str
) -> Dict[str, ThrottleParams]:
res = await self.db_pool.simple_select_list( res = await self.db_pool.simple_select_list(
"pusher_throttle", "pusher_throttle",
{"pusher": pusher_id}, {"pusher": pusher_id},
@ -243,43 +263,44 @@ class PusherWorkerStore(SQLBaseStore):
params_by_room = {} params_by_room = {}
for row in res: for row in res:
params_by_room[row["room_id"]] = { params_by_room[row["room_id"]] = ThrottleParams(
"last_sent_ts": row["last_sent_ts"], row["last_sent_ts"], row["throttle_ms"],
"throttle_ms": row["throttle_ms"], )
}
return params_by_room return params_by_room
async def set_throttle_params(self, pusher_id, room_id, params) -> None: async def set_throttle_params(
self, pusher_id: str, room_id: str, params: ThrottleParams
) -> None:
# no need to lock because `pusher_throttle` has a primary key on # no need to lock because `pusher_throttle` has a primary key on
# (pusher, room_id) so simple_upsert will retry # (pusher, room_id) so simple_upsert will retry
await self.db_pool.simple_upsert( await self.db_pool.simple_upsert(
"pusher_throttle", "pusher_throttle",
{"pusher": pusher_id, "room_id": room_id}, {"pusher": pusher_id, "room_id": room_id},
params, {"last_sent_ts": params.last_sent_ts, "throttle_ms": params.throttle_ms},
desc="set_throttle_params", desc="set_throttle_params",
lock=False, lock=False,
) )
class PusherStore(PusherWorkerStore): class PusherStore(PusherWorkerStore):
def get_pushers_stream_token(self): def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token() return self._pushers_id_gen.get_current_token()
async def add_pusher( async def add_pusher(
self, self,
user_id, user_id: str,
access_token, access_token: Optional[int],
kind, kind: str,
app_id, app_id: str,
app_display_name, app_display_name: str,
device_display_name, device_display_name: str,
pushkey, pushkey: str,
pushkey_ts, pushkey_ts: int,
lang, lang: Optional[str],
data, data: Optional[JsonDict],
last_stream_ordering, last_stream_ordering: int,
profile_tag="", profile_tag: str = "",
) -> None: ) -> None:
async with self._pushers_id_gen.get_next() as stream_id: async with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on # no need to lock because `pushers` has a unique key on
@ -311,16 +332,16 @@ class PusherStore(PusherWorkerStore):
# invalidate, since we the user might not have had a pusher before # invalidate, since we the user might not have had a pusher before
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"add_pusher", "add_pusher",
self._invalidate_cache_and_stream, self._invalidate_cache_and_stream, # type: ignore
self.get_if_user_has_pusher, self.get_if_user_has_pusher,
(user_id,), (user_id,),
) )
async def delete_pusher_by_app_id_pushkey_user_id( async def delete_pusher_by_app_id_pushkey_user_id(
self, app_id, pushkey, user_id self, app_id: str, pushkey: str, user_id: str
) -> None: ) -> None:
def delete_pusher_txn(txn, stream_id): def delete_pusher_txn(txn, stream_id):
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream( # type: ignore
txn, self.get_if_user_has_pusher, (user_id,) txn, self.get_if_user_has_pusher, (user_id,)
) )

View File

@ -153,12 +153,12 @@ class StreamIdGenerator:
return _AsyncCtxManagerWrapper(manager()) return _AsyncCtxManagerWrapper(manager())
def get_current_token(self): def get_current_token(self) -> int:
"""Returns the maximum stream id such that all stream ids less than or """Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted. equal to it have been successfully persisted.
Returns: Returns:
int The maximum stream id.
""" """
with self._lock: with self._lock:
if self._unfinished_ids: if self._unfinished_ids:

View File

@ -209,7 +209,7 @@ class EmailPusherTests(HomeserverTestCase):
) )
pushers = list(pushers) pushers = list(pushers)
self.assertEqual(len(pushers), 1) self.assertEqual(len(pushers), 1)
last_stream_ordering = pushers[0]["last_stream_ordering"] last_stream_ordering = pushers[0].last_stream_ordering
# Advance time a bit, so the pusher will register something has happened # Advance time a bit, so the pusher will register something has happened
self.pump(10) self.pump(10)
@ -220,7 +220,7 @@ class EmailPusherTests(HomeserverTestCase):
) )
pushers = list(pushers) pushers = list(pushers)
self.assertEqual(len(pushers), 1) self.assertEqual(len(pushers), 1)
self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"]) self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
# One email was attempted to be sent # One email was attempted to be sent
self.assertEqual(len(self.email_attempts), 1) self.assertEqual(len(self.email_attempts), 1)
@ -238,4 +238,4 @@ class EmailPusherTests(HomeserverTestCase):
) )
pushers = list(pushers) pushers = list(pushers)
self.assertEqual(len(pushers), 1) self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)

View File

@ -144,7 +144,7 @@ class HTTPPusherTests(HomeserverTestCase):
) )
pushers = list(pushers) pushers = list(pushers)
self.assertEqual(len(pushers), 1) self.assertEqual(len(pushers), 1)
last_stream_ordering = pushers[0]["last_stream_ordering"] last_stream_ordering = pushers[0].last_stream_ordering
# Advance time a bit, so the pusher will register something has happened # Advance time a bit, so the pusher will register something has happened
self.pump() self.pump()
@ -155,7 +155,7 @@ class HTTPPusherTests(HomeserverTestCase):
) )
pushers = list(pushers) pushers = list(pushers)
self.assertEqual(len(pushers), 1) self.assertEqual(len(pushers), 1)
self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"]) self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
# One push was attempted to be sent -- it'll be the first message # One push was attempted to be sent -- it'll be the first message
self.assertEqual(len(self.push_attempts), 1) self.assertEqual(len(self.push_attempts), 1)
@ -176,8 +176,8 @@ class HTTPPusherTests(HomeserverTestCase):
) )
pushers = list(pushers) pushers = list(pushers)
self.assertEqual(len(pushers), 1) self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
last_stream_ordering = pushers[0]["last_stream_ordering"] last_stream_ordering = pushers[0].last_stream_ordering
# Now it'll try and send the second push message, which will be the second one # Now it'll try and send the second push message, which will be the second one
self.assertEqual(len(self.push_attempts), 2) self.assertEqual(len(self.push_attempts), 2)
@ -198,7 +198,7 @@ class HTTPPusherTests(HomeserverTestCase):
) )
pushers = list(pushers) pushers = list(pushers)
self.assertEqual(len(pushers), 1) self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
def test_sends_high_priority_for_encrypted(self): def test_sends_high_priority_for_encrypted(self):
""" """

View File

@ -766,7 +766,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
) )
pushers = list(pushers) pushers = list(pushers)
self.assertEqual(len(pushers), 1) self.assertEqual(len(pushers), 1)
self.assertEqual("@bob:test", pushers[0]["user_name"]) self.assertEqual("@bob:test", pushers[0].user_name)
@override_config( @override_config(
{ {