Convert internal pusher dicts to attrs classes. (#8940)

This improves type hinting and should use less memory.
This commit is contained in:
Patrick Cloke 2020-12-16 11:25:30 -05:00 committed by GitHub
parent 7a332850e6
commit bd30cfe86a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 266 additions and 204 deletions

View file

@ -14,24 +14,70 @@
# limitations under the License.
import abc
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any, Dict, Optional
from synapse.types import RoomStreamToken
import attr
from synapse.types import JsonDict, RoomStreamToken
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
@attr.s(slots=True)
class PusherConfig:
"""Parameters necessary to configure a pusher."""
id = attr.ib(type=Optional[str])
user_name = attr.ib(type=str)
access_token = attr.ib(type=Optional[int])
profile_tag = attr.ib(type=str)
kind = attr.ib(type=str)
app_id = attr.ib(type=str)
app_display_name = attr.ib(type=str)
device_display_name = attr.ib(type=str)
pushkey = attr.ib(type=str)
ts = attr.ib(type=int)
lang = attr.ib(type=Optional[str])
data = attr.ib(type=Optional[JsonDict])
last_stream_ordering = attr.ib(type=Optional[int])
last_success = attr.ib(type=Optional[int])
failing_since = attr.ib(type=Optional[int])
def as_dict(self) -> Dict[str, Any]:
"""Information that can be retrieved about a pusher after creation."""
return {
"app_display_name": self.app_display_name,
"app_id": self.app_id,
"data": self.data,
"device_display_name": self.device_display_name,
"kind": self.kind,
"lang": self.lang,
"profile_tag": self.profile_tag,
"pushkey": self.pushkey,
}
@attr.s(slots=True)
class ThrottleParams:
"""Parameters for controlling the rate of sending pushes via email."""
last_sent_ts = attr.ib(type=int)
throttle_ms = attr.ib(type=int)
class Pusher(metaclass=abc.ABCMeta):
def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any]):
def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
self.hs = hs
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.pusher_id = pusherdict["id"]
self.user_id = pusherdict["user_name"]
self.app_id = pusherdict["app_id"]
self.pushkey = pusherdict["pushkey"]
self.pusher_id = pusher_config.id
self.user_id = pusher_config.user_name
self.app_id = pusher_config.app_id
self.pushkey = pusher_config.pushkey
self.last_stream_ordering = pusher_config.last_stream_ordering
# This is the highest stream ordering we know it's safe to process.
# When new events arrive, we'll be given a window of new events: we

View file

@ -14,13 +14,13 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional
from twisted.internet.base import DelayedCall
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import Pusher
from synapse.push import Pusher, PusherConfig, ThrottleParams
from synapse.push.mailer import Mailer
if TYPE_CHECKING:
@ -60,15 +60,14 @@ class EmailPusher(Pusher):
factor out the common parts
"""
def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any], mailer: Mailer):
super().__init__(hs, pusherdict)
def __init__(self, hs: "HomeServer", pusher_config: PusherConfig, mailer: Mailer):
super().__init__(hs, pusher_config)
self.mailer = mailer
self.store = self.hs.get_datastore()
self.email = pusherdict["pushkey"]
self.last_stream_ordering = pusherdict["last_stream_ordering"]
self.email = pusher_config.pushkey
self.timed_call = None # type: Optional[DelayedCall]
self.throttle_params = {} # type: Dict[str, Dict[str, int]]
self.throttle_params = {} # type: Dict[str, ThrottleParams]
self._inited = False
self._is_processing = False
@ -132,6 +131,7 @@ class EmailPusher(Pusher):
if not self._inited:
# this is our first loop: load up the throttle params
assert self.pusher_id is not None
self.throttle_params = await self.store.get_throttle_params_by_room(
self.pusher_id
)
@ -157,6 +157,7 @@ class EmailPusher(Pusher):
being run.
"""
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
assert start is not None
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email(
self.user_id, start, self.max_stream_ordering
)
@ -244,13 +245,13 @@ class EmailPusher(Pusher):
def get_room_throttle_ms(self, room_id: str) -> int:
if room_id in self.throttle_params:
return self.throttle_params[room_id]["throttle_ms"]
return self.throttle_params[room_id].throttle_ms
else:
return 0
def get_room_last_sent_ts(self, room_id: str) -> int:
if room_id in self.throttle_params:
return self.throttle_params[room_id]["last_sent_ts"]
return self.throttle_params[room_id].last_sent_ts
else:
return 0
@ -301,10 +302,10 @@ class EmailPusher(Pusher):
new_throttle_ms = min(
current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS
)
self.throttle_params[room_id] = {
"last_sent_ts": self.clock.time_msec(),
"throttle_ms": new_throttle_ms,
}
self.throttle_params[room_id] = ThrottleParams(
self.clock.time_msec(), new_throttle_ms,
)
assert self.pusher_id is not None
await self.store.set_throttle_params(
self.pusher_id, room_id, self.throttle_params[room_id]
)

View file

@ -25,7 +25,7 @@ from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.logging import opentracing
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import Pusher, PusherConfigException
from synapse.push import Pusher, PusherConfig, PusherConfigException
from . import push_rule_evaluator, push_tools
@ -62,33 +62,29 @@ class HttpPusher(Pusher):
# This one's in ms because we compare it against the clock
GIVE_UP_AFTER_MS = 24 * 60 * 60 * 1000
def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any]):
super().__init__(hs, pusherdict)
def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
super().__init__(hs, pusher_config)
self.storage = self.hs.get_storage()
self.app_display_name = pusherdict["app_display_name"]
self.device_display_name = pusherdict["device_display_name"]
self.pushkey_ts = pusherdict["ts"]
self.data = pusherdict["data"]
self.last_stream_ordering = pusherdict["last_stream_ordering"]
self.app_display_name = pusher_config.app_display_name
self.device_display_name = pusher_config.device_display_name
self.pushkey_ts = pusher_config.ts
self.data = pusher_config.data
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.failing_since = pusherdict["failing_since"]
self.failing_since = pusher_config.failing_since
self.timed_call = None
self._is_processing = False
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
if "data" not in pusherdict:
raise PusherConfigException("No 'data' key for HTTP pusher")
self.data = pusherdict["data"]
self.data = pusher_config.data
if self.data is None:
raise PusherConfigException("'data' key can not be null for HTTP pusher")
self.name = "%s/%s/%s" % (
pusherdict["user_name"],
pusherdict["app_id"],
pusherdict["pushkey"],
pusher_config.user_name,
pusher_config.app_id,
pusher_config.pushkey,
)
if self.data is None:
raise PusherConfigException("data can not be null for HTTP pusher")
# Validate that there's a URL and it is of the proper form.
if "url" not in self.data:
raise PusherConfigException("'url' required in data for HTTP pusher")
@ -180,6 +176,7 @@ class HttpPusher(Pusher):
Never call this directly: use _process which will only allow this to
run once per pusher.
"""
assert self.last_stream_ordering is not None
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http(
self.user_id, self.last_stream_ordering, self.max_stream_ordering
)
@ -208,6 +205,7 @@ class HttpPusher(Pusher):
http_push_processed_counter.inc()
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"]
assert self.last_stream_ordering is not None
pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success(
self.app_id,
self.pushkey,
@ -314,6 +312,8 @@ class HttpPusher(Pusher):
# or may do so (i.e. is encrypted so has unknown effects).
priority = "high"
# This was checked in the __init__, but mypy doesn't seem to know that.
assert self.data is not None
if self.data.get("format") == "event_id_only":
d = {
"notification": {

View file

@ -14,9 +14,9 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
from typing import TYPE_CHECKING, Callable, Dict, Optional
from synapse.push import Pusher
from synapse.push import Pusher, PusherConfig
from synapse.push.emailpusher import EmailPusher
from synapse.push.httppusher import HttpPusher
from synapse.push.mailer import Mailer
@ -34,7 +34,7 @@ class PusherFactory:
self.pusher_types = {
"http": HttpPusher
} # type: Dict[str, Callable[[HomeServer, dict], Pusher]]
} # type: Dict[str, Callable[[HomeServer, PusherConfig], Pusher]]
logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
if hs.config.email_enable_notifs:
@ -47,18 +47,18 @@ class PusherFactory:
logger.info("defined email pusher type")
def create_pusher(self, pusherdict: Dict[str, Any]) -> Optional[Pusher]:
kind = pusherdict["kind"]
def create_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]:
kind = pusher_config.kind
f = self.pusher_types.get(kind, None)
if not f:
return None
logger.debug("creating %s pusher for %r", kind, pusherdict)
return f(self.hs, pusherdict)
logger.debug("creating %s pusher for %r", kind, pusher_config)
return f(self.hs, pusher_config)
def _create_email_pusher(
self, _hs: "HomeServer", pusherdict: Dict[str, Any]
self, _hs: "HomeServer", pusher_config: PusherConfig
) -> EmailPusher:
app_name = self._app_name_from_pusherdict(pusherdict)
app_name = self._app_name_from_pusherdict(pusher_config)
mailer = self.mailers.get(app_name)
if not mailer:
mailer = Mailer(
@ -68,10 +68,10 @@ class PusherFactory:
template_text=self._notif_template_text,
)
self.mailers[app_name] = mailer
return EmailPusher(self.hs, pusherdict, mailer)
return EmailPusher(self.hs, pusher_config, mailer)
def _app_name_from_pusherdict(self, pusherdict: Dict[str, Any]) -> str:
data = pusherdict["data"]
def _app_name_from_pusherdict(self, pusher_config: PusherConfig) -> str:
data = pusher_config.data
if isinstance(data, dict):
brand = data.get("brand")

View file

@ -15,7 +15,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Dict, Iterable, Optional
from prometheus_client import Gauge
@ -23,9 +23,9 @@ from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
from synapse.push import Pusher, PusherConfigException
from synapse.push import Pusher, PusherConfig, PusherConfigException
from synapse.push.pusher import PusherFactory
from synapse.types import RoomStreamToken
from synapse.types import JsonDict, RoomStreamToken
from synapse.util.async_helpers import concurrently_execute
if TYPE_CHECKING:
@ -77,7 +77,7 @@ class PusherPool:
# map from user id to app_id:pushkey to pusher
self.pushers = {} # type: Dict[str, Dict[str, Pusher]]
def start(self):
def start(self) -> None:
"""Starts the pushers off in a background process.
"""
if not self._should_start_pushers:
@ -87,16 +87,16 @@ class PusherPool:
async def add_pusher(
self,
user_id,
access_token,
kind,
app_id,
app_display_name,
device_display_name,
pushkey,
lang,
data,
profile_tag="",
user_id: str,
access_token: Optional[int],
kind: str,
app_id: str,
app_display_name: str,
device_display_name: str,
pushkey: str,
lang: Optional[str],
data: JsonDict,
profile_tag: str = "",
) -> Optional[Pusher]:
"""Creates a new pusher and adds it to the pool
@ -111,21 +111,23 @@ class PusherPool:
# recreated, added and started: this means we have only one
# code path adding pushers.
self.pusher_factory.create_pusher(
{
"id": None,
"user_name": user_id,
"kind": kind,
"app_id": app_id,
"app_display_name": app_display_name,
"device_display_name": device_display_name,
"pushkey": pushkey,
"ts": time_now_msec,
"lang": lang,
"data": data,
"last_stream_ordering": None,
"last_success": None,
"failing_since": None,
}
PusherConfig(
id=None,
user_name=user_id,
access_token=access_token,
profile_tag=profile_tag,
kind=kind,
app_id=app_id,
app_display_name=app_display_name,
device_display_name=device_display_name,
pushkey=pushkey,
ts=time_now_msec,
lang=lang,
data=data,
last_stream_ordering=None,
last_success=None,
failing_since=None,
)
)
# create the pusher setting last_stream_ordering to the current maximum
@ -151,43 +153,44 @@ class PusherPool:
return pusher
async def remove_pushers_by_app_id_and_pushkey_not_user(
self, app_id, pushkey, not_user_id
):
self, app_id: str, pushkey: str, not_user_id: str
) -> None:
to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
for p in to_remove:
if p["user_name"] != not_user_id:
if p.user_name != not_user_id:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
app_id,
pushkey,
p["user_name"],
p.user_name,
)
await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
await self.remove_pusher(p.app_id, p.pushkey, p.user_name)
async def remove_pushers_by_access_token(self, user_id, access_tokens):
async def remove_pushers_by_access_token(
self, user_id: str, access_tokens: Iterable[int]
) -> None:
"""Remove the pushers for a given user corresponding to a set of
access_tokens.
Args:
user_id (str): user to remove pushers for
access_tokens (Iterable[int]): access token *ids* to remove pushers
for
user_id: user to remove pushers for
access_tokens: access token *ids* to remove pushers for
"""
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
return
tokens = set(access_tokens)
for p in await self.store.get_pushers_by_user_id(user_id):
if p["access_token"] in tokens:
if p.access_token in tokens:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
p["app_id"],
p["pushkey"],
p["user_name"],
p.app_id,
p.pushkey,
p.user_name,
)
await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
await self.remove_pusher(p.app_id, p.pushkey, p.user_name)
def on_new_notifications(self, max_token: RoomStreamToken):
def on_new_notifications(self, max_token: RoomStreamToken) -> None:
if not self.pushers:
# nothing to do here.
return
@ -206,7 +209,7 @@ class PusherPool:
self._on_new_notifications(max_token)
@wrap_as_background_process("on_new_notifications")
async def _on_new_notifications(self, max_token: RoomStreamToken):
async def _on_new_notifications(self, max_token: RoomStreamToken) -> None:
# We just use the minimum stream ordering and ignore the vector clock
# component. This is safe to do as long as we *always* ignore the vector
# clock components.
@ -236,7 +239,9 @@ class PusherPool:
except Exception:
logger.exception("Exception in pusher on_new_notifications")
async def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
async def on_new_receipts(
self, min_stream_id: int, max_stream_id: int, affected_room_ids: Iterable[str]
) -> None:
if not self.pushers:
# nothing to do here.
return
@ -280,14 +285,14 @@ class PusherPool:
resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
pusher_dict = None
pusher_config = None
for r in resultlist:
if r["user_name"] == user_id:
pusher_dict = r
if r.user_name == user_id:
pusher_config = r
pusher = None
if pusher_dict:
pusher = await self._start_pusher(pusher_dict)
if pusher_config:
pusher = await self._start_pusher(pusher_config)
return pusher
@ -302,44 +307,44 @@ class PusherPool:
logger.info("Started pushers")
async def _start_pusher(self, pusherdict: Dict[str, Any]) -> Optional[Pusher]:
async def _start_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]:
"""Start the given pusher
Args:
pusherdict: dict with the values pulled from the db table
pusher_config: The pusher configuration with the values pulled from the db table
Returns:
The newly created pusher or None.
"""
if not self._pusher_shard_config.should_handle(
self._instance_name, pusherdict["user_name"]
self._instance_name, pusher_config.user_name
):
return None
try:
p = self.pusher_factory.create_pusher(pusherdict)
p = self.pusher_factory.create_pusher(pusher_config)
except PusherConfigException as e:
logger.warning(
"Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s",
pusherdict["id"],
pusherdict.get("user_name"),
pusherdict.get("app_id"),
pusherdict.get("pushkey"),
pusher_config.id,
pusher_config.user_name,
pusher_config.app_id,
pusher_config.pushkey,
e,
)
return None
except Exception:
logger.exception(
"Couldn't start pusher id %i: caught Exception", pusherdict["id"],
"Couldn't start pusher id %i: caught Exception", pusher_config.id,
)
return None
if not p:
return None
appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"])
appid_pushkey = "%s:%s" % (pusher_config.app_id, pusher_config.pushkey)
byuser = self.pushers.setdefault(pusherdict["user_name"], {})
byuser = self.pushers.setdefault(pusher_config.user_name, {})
if appid_pushkey in byuser:
byuser[appid_pushkey].on_stop()
byuser[appid_pushkey] = p
@ -349,8 +354,8 @@ class PusherPool:
# Check if there *may* be push to process. We do this as this check is a
# lot cheaper to do than actually fetching the exact rows we need to
# push.
user_id = pusherdict["user_name"]
last_stream_ordering = pusherdict["last_stream_ordering"]
user_id = pusher_config.user_name
last_stream_ordering = pusher_config.last_stream_ordering
if last_stream_ordering:
have_notifs = await self.store.get_if_maybe_push_in_range_for_user(
user_id, last_stream_ordering
@ -364,7 +369,7 @@ class PusherPool:
return p
async def remove_pusher(self, app_id, pushkey, user_id):
async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
appid_pushkey = "%s:%s" % (app_id, pushkey)
byuser = self.pushers.get(user_id, {})