mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2024-10-01 11:49:51 -04:00
Convert internal pusher dicts to attrs classes. (#8940)
This improves type hinting and should use less memory.
This commit is contained in:
parent
7a332850e6
commit
bd30cfe86a
1
changelog.d/8940.misc
Normal file
1
changelog.d/8940.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add type hints to push module.
|
1
mypy.ini
1
mypy.ini
@ -65,6 +65,7 @@ files =
|
|||||||
synapse/state,
|
synapse/state,
|
||||||
synapse/storage/databases/main/appservice.py,
|
synapse/storage/databases/main/appservice.py,
|
||||||
synapse/storage/databases/main/events.py,
|
synapse/storage/databases/main/events.py,
|
||||||
|
synapse/storage/databases/main/pusher.py,
|
||||||
synapse/storage/databases/main/registration.py,
|
synapse/storage/databases/main/registration.py,
|
||||||
synapse/storage/databases/main/stream.py,
|
synapse/storage/databases/main/stream.py,
|
||||||
synapse/storage/databases/main/ui_auth.py,
|
synapse/storage/databases/main/ui_auth.py,
|
||||||
|
@ -14,24 +14,70 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
from typing import TYPE_CHECKING, Any, Dict
|
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||||
|
|
||||||
from synapse.types import RoomStreamToken
|
import attr
|
||||||
|
|
||||||
|
from synapse.types import JsonDict, RoomStreamToken
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.app.homeserver import HomeServer
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True)
|
||||||
|
class PusherConfig:
|
||||||
|
"""Parameters necessary to configure a pusher."""
|
||||||
|
|
||||||
|
id = attr.ib(type=Optional[str])
|
||||||
|
user_name = attr.ib(type=str)
|
||||||
|
access_token = attr.ib(type=Optional[int])
|
||||||
|
profile_tag = attr.ib(type=str)
|
||||||
|
kind = attr.ib(type=str)
|
||||||
|
app_id = attr.ib(type=str)
|
||||||
|
app_display_name = attr.ib(type=str)
|
||||||
|
device_display_name = attr.ib(type=str)
|
||||||
|
pushkey = attr.ib(type=str)
|
||||||
|
ts = attr.ib(type=int)
|
||||||
|
lang = attr.ib(type=Optional[str])
|
||||||
|
data = attr.ib(type=Optional[JsonDict])
|
||||||
|
last_stream_ordering = attr.ib(type=Optional[int])
|
||||||
|
last_success = attr.ib(type=Optional[int])
|
||||||
|
failing_since = attr.ib(type=Optional[int])
|
||||||
|
|
||||||
|
def as_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Information that can be retrieved about a pusher after creation."""
|
||||||
|
return {
|
||||||
|
"app_display_name": self.app_display_name,
|
||||||
|
"app_id": self.app_id,
|
||||||
|
"data": self.data,
|
||||||
|
"device_display_name": self.device_display_name,
|
||||||
|
"kind": self.kind,
|
||||||
|
"lang": self.lang,
|
||||||
|
"profile_tag": self.profile_tag,
|
||||||
|
"pushkey": self.pushkey,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True)
|
||||||
|
class ThrottleParams:
|
||||||
|
"""Parameters for controlling the rate of sending pushes via email."""
|
||||||
|
|
||||||
|
last_sent_ts = attr.ib(type=int)
|
||||||
|
throttle_ms = attr.ib(type=int)
|
||||||
|
|
||||||
|
|
||||||
class Pusher(metaclass=abc.ABCMeta):
|
class Pusher(metaclass=abc.ABCMeta):
|
||||||
def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any]):
|
def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.store = self.hs.get_datastore()
|
self.store = self.hs.get_datastore()
|
||||||
self.clock = self.hs.get_clock()
|
self.clock = self.hs.get_clock()
|
||||||
|
|
||||||
self.pusher_id = pusherdict["id"]
|
self.pusher_id = pusher_config.id
|
||||||
self.user_id = pusherdict["user_name"]
|
self.user_id = pusher_config.user_name
|
||||||
self.app_id = pusherdict["app_id"]
|
self.app_id = pusher_config.app_id
|
||||||
self.pushkey = pusherdict["pushkey"]
|
self.pushkey = pusher_config.pushkey
|
||||||
|
|
||||||
|
self.last_stream_ordering = pusher_config.last_stream_ordering
|
||||||
|
|
||||||
# This is the highest stream ordering we know it's safe to process.
|
# This is the highest stream ordering we know it's safe to process.
|
||||||
# When new events arrive, we'll be given a window of new events: we
|
# When new events arrive, we'll be given a window of new events: we
|
||||||
|
@ -14,13 +14,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||||
|
|
||||||
from twisted.internet.base import DelayedCall
|
from twisted.internet.base import DelayedCall
|
||||||
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
|
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
|
||||||
|
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.push import Pusher
|
from synapse.push import Pusher, PusherConfig, ThrottleParams
|
||||||
from synapse.push.mailer import Mailer
|
from synapse.push.mailer import Mailer
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -60,15 +60,14 @@ class EmailPusher(Pusher):
|
|||||||
factor out the common parts
|
factor out the common parts
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any], mailer: Mailer):
|
def __init__(self, hs: "HomeServer", pusher_config: PusherConfig, mailer: Mailer):
|
||||||
super().__init__(hs, pusherdict)
|
super().__init__(hs, pusher_config)
|
||||||
self.mailer = mailer
|
self.mailer = mailer
|
||||||
|
|
||||||
self.store = self.hs.get_datastore()
|
self.store = self.hs.get_datastore()
|
||||||
self.email = pusherdict["pushkey"]
|
self.email = pusher_config.pushkey
|
||||||
self.last_stream_ordering = pusherdict["last_stream_ordering"]
|
|
||||||
self.timed_call = None # type: Optional[DelayedCall]
|
self.timed_call = None # type: Optional[DelayedCall]
|
||||||
self.throttle_params = {} # type: Dict[str, Dict[str, int]]
|
self.throttle_params = {} # type: Dict[str, ThrottleParams]
|
||||||
self._inited = False
|
self._inited = False
|
||||||
|
|
||||||
self._is_processing = False
|
self._is_processing = False
|
||||||
@ -132,6 +131,7 @@ class EmailPusher(Pusher):
|
|||||||
|
|
||||||
if not self._inited:
|
if not self._inited:
|
||||||
# this is our first loop: load up the throttle params
|
# this is our first loop: load up the throttle params
|
||||||
|
assert self.pusher_id is not None
|
||||||
self.throttle_params = await self.store.get_throttle_params_by_room(
|
self.throttle_params = await self.store.get_throttle_params_by_room(
|
||||||
self.pusher_id
|
self.pusher_id
|
||||||
)
|
)
|
||||||
@ -157,6 +157,7 @@ class EmailPusher(Pusher):
|
|||||||
being run.
|
being run.
|
||||||
"""
|
"""
|
||||||
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
|
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
|
||||||
|
assert start is not None
|
||||||
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email(
|
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email(
|
||||||
self.user_id, start, self.max_stream_ordering
|
self.user_id, start, self.max_stream_ordering
|
||||||
)
|
)
|
||||||
@ -244,13 +245,13 @@ class EmailPusher(Pusher):
|
|||||||
|
|
||||||
def get_room_throttle_ms(self, room_id: str) -> int:
|
def get_room_throttle_ms(self, room_id: str) -> int:
|
||||||
if room_id in self.throttle_params:
|
if room_id in self.throttle_params:
|
||||||
return self.throttle_params[room_id]["throttle_ms"]
|
return self.throttle_params[room_id].throttle_ms
|
||||||
else:
|
else:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def get_room_last_sent_ts(self, room_id: str) -> int:
|
def get_room_last_sent_ts(self, room_id: str) -> int:
|
||||||
if room_id in self.throttle_params:
|
if room_id in self.throttle_params:
|
||||||
return self.throttle_params[room_id]["last_sent_ts"]
|
return self.throttle_params[room_id].last_sent_ts
|
||||||
else:
|
else:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
@ -301,10 +302,10 @@ class EmailPusher(Pusher):
|
|||||||
new_throttle_ms = min(
|
new_throttle_ms = min(
|
||||||
current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS
|
current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS
|
||||||
)
|
)
|
||||||
self.throttle_params[room_id] = {
|
self.throttle_params[room_id] = ThrottleParams(
|
||||||
"last_sent_ts": self.clock.time_msec(),
|
self.clock.time_msec(), new_throttle_ms,
|
||||||
"throttle_ms": new_throttle_ms,
|
)
|
||||||
}
|
assert self.pusher_id is not None
|
||||||
await self.store.set_throttle_params(
|
await self.store.set_throttle_params(
|
||||||
self.pusher_id, room_id, self.throttle_params[room_id]
|
self.pusher_id, room_id, self.throttle_params[room_id]
|
||||||
)
|
)
|
||||||
|
@ -25,7 +25,7 @@ from synapse.api.constants import EventTypes
|
|||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.logging import opentracing
|
from synapse.logging import opentracing
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
from synapse.push import Pusher, PusherConfigException
|
from synapse.push import Pusher, PusherConfig, PusherConfigException
|
||||||
|
|
||||||
from . import push_rule_evaluator, push_tools
|
from . import push_rule_evaluator, push_tools
|
||||||
|
|
||||||
@ -62,33 +62,29 @@ class HttpPusher(Pusher):
|
|||||||
# This one's in ms because we compare it against the clock
|
# This one's in ms because we compare it against the clock
|
||||||
GIVE_UP_AFTER_MS = 24 * 60 * 60 * 1000
|
GIVE_UP_AFTER_MS = 24 * 60 * 60 * 1000
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any]):
|
def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
|
||||||
super().__init__(hs, pusherdict)
|
super().__init__(hs, pusher_config)
|
||||||
self.storage = self.hs.get_storage()
|
self.storage = self.hs.get_storage()
|
||||||
self.app_display_name = pusherdict["app_display_name"]
|
self.app_display_name = pusher_config.app_display_name
|
||||||
self.device_display_name = pusherdict["device_display_name"]
|
self.device_display_name = pusher_config.device_display_name
|
||||||
self.pushkey_ts = pusherdict["ts"]
|
self.pushkey_ts = pusher_config.ts
|
||||||
self.data = pusherdict["data"]
|
self.data = pusher_config.data
|
||||||
self.last_stream_ordering = pusherdict["last_stream_ordering"]
|
|
||||||
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
|
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
|
||||||
self.failing_since = pusherdict["failing_since"]
|
self.failing_since = pusher_config.failing_since
|
||||||
self.timed_call = None
|
self.timed_call = None
|
||||||
self._is_processing = False
|
self._is_processing = False
|
||||||
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
|
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
|
||||||
|
|
||||||
if "data" not in pusherdict:
|
self.data = pusher_config.data
|
||||||
raise PusherConfigException("No 'data' key for HTTP pusher")
|
if self.data is None:
|
||||||
self.data = pusherdict["data"]
|
raise PusherConfigException("'data' key can not be null for HTTP pusher")
|
||||||
|
|
||||||
self.name = "%s/%s/%s" % (
|
self.name = "%s/%s/%s" % (
|
||||||
pusherdict["user_name"],
|
pusher_config.user_name,
|
||||||
pusherdict["app_id"],
|
pusher_config.app_id,
|
||||||
pusherdict["pushkey"],
|
pusher_config.pushkey,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.data is None:
|
|
||||||
raise PusherConfigException("data can not be null for HTTP pusher")
|
|
||||||
|
|
||||||
# Validate that there's a URL and it is of the proper form.
|
# Validate that there's a URL and it is of the proper form.
|
||||||
if "url" not in self.data:
|
if "url" not in self.data:
|
||||||
raise PusherConfigException("'url' required in data for HTTP pusher")
|
raise PusherConfigException("'url' required in data for HTTP pusher")
|
||||||
@ -180,6 +176,7 @@ class HttpPusher(Pusher):
|
|||||||
Never call this directly: use _process which will only allow this to
|
Never call this directly: use _process which will only allow this to
|
||||||
run once per pusher.
|
run once per pusher.
|
||||||
"""
|
"""
|
||||||
|
assert self.last_stream_ordering is not None
|
||||||
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http(
|
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http(
|
||||||
self.user_id, self.last_stream_ordering, self.max_stream_ordering
|
self.user_id, self.last_stream_ordering, self.max_stream_ordering
|
||||||
)
|
)
|
||||||
@ -208,6 +205,7 @@ class HttpPusher(Pusher):
|
|||||||
http_push_processed_counter.inc()
|
http_push_processed_counter.inc()
|
||||||
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
|
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
|
||||||
self.last_stream_ordering = push_action["stream_ordering"]
|
self.last_stream_ordering = push_action["stream_ordering"]
|
||||||
|
assert self.last_stream_ordering is not None
|
||||||
pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success(
|
pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success(
|
||||||
self.app_id,
|
self.app_id,
|
||||||
self.pushkey,
|
self.pushkey,
|
||||||
@ -314,6 +312,8 @@ class HttpPusher(Pusher):
|
|||||||
# or may do so (i.e. is encrypted so has unknown effects).
|
# or may do so (i.e. is encrypted so has unknown effects).
|
||||||
priority = "high"
|
priority = "high"
|
||||||
|
|
||||||
|
# This was checked in the __init__, but mypy doesn't seem to know that.
|
||||||
|
assert self.data is not None
|
||||||
if self.data.get("format") == "event_id_only":
|
if self.data.get("format") == "event_id_only":
|
||||||
d = {
|
d = {
|
||||||
"notification": {
|
"notification": {
|
||||||
|
@ -14,9 +14,9 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
|
from typing import TYPE_CHECKING, Callable, Dict, Optional
|
||||||
|
|
||||||
from synapse.push import Pusher
|
from synapse.push import Pusher, PusherConfig
|
||||||
from synapse.push.emailpusher import EmailPusher
|
from synapse.push.emailpusher import EmailPusher
|
||||||
from synapse.push.httppusher import HttpPusher
|
from synapse.push.httppusher import HttpPusher
|
||||||
from synapse.push.mailer import Mailer
|
from synapse.push.mailer import Mailer
|
||||||
@ -34,7 +34,7 @@ class PusherFactory:
|
|||||||
|
|
||||||
self.pusher_types = {
|
self.pusher_types = {
|
||||||
"http": HttpPusher
|
"http": HttpPusher
|
||||||
} # type: Dict[str, Callable[[HomeServer, dict], Pusher]]
|
} # type: Dict[str, Callable[[HomeServer, PusherConfig], Pusher]]
|
||||||
|
|
||||||
logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
|
logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
|
||||||
if hs.config.email_enable_notifs:
|
if hs.config.email_enable_notifs:
|
||||||
@ -47,18 +47,18 @@ class PusherFactory:
|
|||||||
|
|
||||||
logger.info("defined email pusher type")
|
logger.info("defined email pusher type")
|
||||||
|
|
||||||
def create_pusher(self, pusherdict: Dict[str, Any]) -> Optional[Pusher]:
|
def create_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]:
|
||||||
kind = pusherdict["kind"]
|
kind = pusher_config.kind
|
||||||
f = self.pusher_types.get(kind, None)
|
f = self.pusher_types.get(kind, None)
|
||||||
if not f:
|
if not f:
|
||||||
return None
|
return None
|
||||||
logger.debug("creating %s pusher for %r", kind, pusherdict)
|
logger.debug("creating %s pusher for %r", kind, pusher_config)
|
||||||
return f(self.hs, pusherdict)
|
return f(self.hs, pusher_config)
|
||||||
|
|
||||||
def _create_email_pusher(
|
def _create_email_pusher(
|
||||||
self, _hs: "HomeServer", pusherdict: Dict[str, Any]
|
self, _hs: "HomeServer", pusher_config: PusherConfig
|
||||||
) -> EmailPusher:
|
) -> EmailPusher:
|
||||||
app_name = self._app_name_from_pusherdict(pusherdict)
|
app_name = self._app_name_from_pusherdict(pusher_config)
|
||||||
mailer = self.mailers.get(app_name)
|
mailer = self.mailers.get(app_name)
|
||||||
if not mailer:
|
if not mailer:
|
||||||
mailer = Mailer(
|
mailer = Mailer(
|
||||||
@ -68,10 +68,10 @@ class PusherFactory:
|
|||||||
template_text=self._notif_template_text,
|
template_text=self._notif_template_text,
|
||||||
)
|
)
|
||||||
self.mailers[app_name] = mailer
|
self.mailers[app_name] = mailer
|
||||||
return EmailPusher(self.hs, pusherdict, mailer)
|
return EmailPusher(self.hs, pusher_config, mailer)
|
||||||
|
|
||||||
def _app_name_from_pusherdict(self, pusherdict: Dict[str, Any]) -> str:
|
def _app_name_from_pusherdict(self, pusher_config: PusherConfig) -> str:
|
||||||
data = pusherdict["data"]
|
data = pusher_config.data
|
||||||
|
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
brand = data.get("brand")
|
brand = data.get("brand")
|
||||||
|
@ -15,7 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
from typing import TYPE_CHECKING, Dict, Iterable, Optional
|
||||||
|
|
||||||
from prometheus_client import Gauge
|
from prometheus_client import Gauge
|
||||||
|
|
||||||
@ -23,9 +23,9 @@ from synapse.metrics.background_process_metrics import (
|
|||||||
run_as_background_process,
|
run_as_background_process,
|
||||||
wrap_as_background_process,
|
wrap_as_background_process,
|
||||||
)
|
)
|
||||||
from synapse.push import Pusher, PusherConfigException
|
from synapse.push import Pusher, PusherConfig, PusherConfigException
|
||||||
from synapse.push.pusher import PusherFactory
|
from synapse.push.pusher import PusherFactory
|
||||||
from synapse.types import RoomStreamToken
|
from synapse.types import JsonDict, RoomStreamToken
|
||||||
from synapse.util.async_helpers import concurrently_execute
|
from synapse.util.async_helpers import concurrently_execute
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -77,7 +77,7 @@ class PusherPool:
|
|||||||
# map from user id to app_id:pushkey to pusher
|
# map from user id to app_id:pushkey to pusher
|
||||||
self.pushers = {} # type: Dict[str, Dict[str, Pusher]]
|
self.pushers = {} # type: Dict[str, Dict[str, Pusher]]
|
||||||
|
|
||||||
def start(self):
|
def start(self) -> None:
|
||||||
"""Starts the pushers off in a background process.
|
"""Starts the pushers off in a background process.
|
||||||
"""
|
"""
|
||||||
if not self._should_start_pushers:
|
if not self._should_start_pushers:
|
||||||
@ -87,16 +87,16 @@ class PusherPool:
|
|||||||
|
|
||||||
async def add_pusher(
|
async def add_pusher(
|
||||||
self,
|
self,
|
||||||
user_id,
|
user_id: str,
|
||||||
access_token,
|
access_token: Optional[int],
|
||||||
kind,
|
kind: str,
|
||||||
app_id,
|
app_id: str,
|
||||||
app_display_name,
|
app_display_name: str,
|
||||||
device_display_name,
|
device_display_name: str,
|
||||||
pushkey,
|
pushkey: str,
|
||||||
lang,
|
lang: Optional[str],
|
||||||
data,
|
data: JsonDict,
|
||||||
profile_tag="",
|
profile_tag: str = "",
|
||||||
) -> Optional[Pusher]:
|
) -> Optional[Pusher]:
|
||||||
"""Creates a new pusher and adds it to the pool
|
"""Creates a new pusher and adds it to the pool
|
||||||
|
|
||||||
@ -111,21 +111,23 @@ class PusherPool:
|
|||||||
# recreated, added and started: this means we have only one
|
# recreated, added and started: this means we have only one
|
||||||
# code path adding pushers.
|
# code path adding pushers.
|
||||||
self.pusher_factory.create_pusher(
|
self.pusher_factory.create_pusher(
|
||||||
{
|
PusherConfig(
|
||||||
"id": None,
|
id=None,
|
||||||
"user_name": user_id,
|
user_name=user_id,
|
||||||
"kind": kind,
|
access_token=access_token,
|
||||||
"app_id": app_id,
|
profile_tag=profile_tag,
|
||||||
"app_display_name": app_display_name,
|
kind=kind,
|
||||||
"device_display_name": device_display_name,
|
app_id=app_id,
|
||||||
"pushkey": pushkey,
|
app_display_name=app_display_name,
|
||||||
"ts": time_now_msec,
|
device_display_name=device_display_name,
|
||||||
"lang": lang,
|
pushkey=pushkey,
|
||||||
"data": data,
|
ts=time_now_msec,
|
||||||
"last_stream_ordering": None,
|
lang=lang,
|
||||||
"last_success": None,
|
data=data,
|
||||||
"failing_since": None,
|
last_stream_ordering=None,
|
||||||
}
|
last_success=None,
|
||||||
|
failing_since=None,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# create the pusher setting last_stream_ordering to the current maximum
|
# create the pusher setting last_stream_ordering to the current maximum
|
||||||
@ -151,43 +153,44 @@ class PusherPool:
|
|||||||
return pusher
|
return pusher
|
||||||
|
|
||||||
async def remove_pushers_by_app_id_and_pushkey_not_user(
|
async def remove_pushers_by_app_id_and_pushkey_not_user(
|
||||||
self, app_id, pushkey, not_user_id
|
self, app_id: str, pushkey: str, not_user_id: str
|
||||||
):
|
) -> None:
|
||||||
to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
|
to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
|
||||||
for p in to_remove:
|
for p in to_remove:
|
||||||
if p["user_name"] != not_user_id:
|
if p.user_name != not_user_id:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Removing pusher for app id %s, pushkey %s, user %s",
|
"Removing pusher for app id %s, pushkey %s, user %s",
|
||||||
app_id,
|
app_id,
|
||||||
pushkey,
|
pushkey,
|
||||||
p["user_name"],
|
p.user_name,
|
||||||
)
|
)
|
||||||
await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
|
await self.remove_pusher(p.app_id, p.pushkey, p.user_name)
|
||||||
|
|
||||||
async def remove_pushers_by_access_token(self, user_id, access_tokens):
|
async def remove_pushers_by_access_token(
|
||||||
|
self, user_id: str, access_tokens: Iterable[int]
|
||||||
|
) -> None:
|
||||||
"""Remove the pushers for a given user corresponding to a set of
|
"""Remove the pushers for a given user corresponding to a set of
|
||||||
access_tokens.
|
access_tokens.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): user to remove pushers for
|
user_id: user to remove pushers for
|
||||||
access_tokens (Iterable[int]): access token *ids* to remove pushers
|
access_tokens: access token *ids* to remove pushers for
|
||||||
for
|
|
||||||
"""
|
"""
|
||||||
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
|
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
|
||||||
return
|
return
|
||||||
|
|
||||||
tokens = set(access_tokens)
|
tokens = set(access_tokens)
|
||||||
for p in await self.store.get_pushers_by_user_id(user_id):
|
for p in await self.store.get_pushers_by_user_id(user_id):
|
||||||
if p["access_token"] in tokens:
|
if p.access_token in tokens:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Removing pusher for app id %s, pushkey %s, user %s",
|
"Removing pusher for app id %s, pushkey %s, user %s",
|
||||||
p["app_id"],
|
p.app_id,
|
||||||
p["pushkey"],
|
p.pushkey,
|
||||||
p["user_name"],
|
p.user_name,
|
||||||
)
|
)
|
||||||
await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
|
await self.remove_pusher(p.app_id, p.pushkey, p.user_name)
|
||||||
|
|
||||||
def on_new_notifications(self, max_token: RoomStreamToken):
|
def on_new_notifications(self, max_token: RoomStreamToken) -> None:
|
||||||
if not self.pushers:
|
if not self.pushers:
|
||||||
# nothing to do here.
|
# nothing to do here.
|
||||||
return
|
return
|
||||||
@ -206,7 +209,7 @@ class PusherPool:
|
|||||||
self._on_new_notifications(max_token)
|
self._on_new_notifications(max_token)
|
||||||
|
|
||||||
@wrap_as_background_process("on_new_notifications")
|
@wrap_as_background_process("on_new_notifications")
|
||||||
async def _on_new_notifications(self, max_token: RoomStreamToken):
|
async def _on_new_notifications(self, max_token: RoomStreamToken) -> None:
|
||||||
# We just use the minimum stream ordering and ignore the vector clock
|
# We just use the minimum stream ordering and ignore the vector clock
|
||||||
# component. This is safe to do as long as we *always* ignore the vector
|
# component. This is safe to do as long as we *always* ignore the vector
|
||||||
# clock components.
|
# clock components.
|
||||||
@ -236,7 +239,9 @@ class PusherPool:
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Exception in pusher on_new_notifications")
|
logger.exception("Exception in pusher on_new_notifications")
|
||||||
|
|
||||||
async def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
|
async def on_new_receipts(
|
||||||
|
self, min_stream_id: int, max_stream_id: int, affected_room_ids: Iterable[str]
|
||||||
|
) -> None:
|
||||||
if not self.pushers:
|
if not self.pushers:
|
||||||
# nothing to do here.
|
# nothing to do here.
|
||||||
return
|
return
|
||||||
@ -280,14 +285,14 @@ class PusherPool:
|
|||||||
|
|
||||||
resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
|
resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
|
||||||
|
|
||||||
pusher_dict = None
|
pusher_config = None
|
||||||
for r in resultlist:
|
for r in resultlist:
|
||||||
if r["user_name"] == user_id:
|
if r.user_name == user_id:
|
||||||
pusher_dict = r
|
pusher_config = r
|
||||||
|
|
||||||
pusher = None
|
pusher = None
|
||||||
if pusher_dict:
|
if pusher_config:
|
||||||
pusher = await self._start_pusher(pusher_dict)
|
pusher = await self._start_pusher(pusher_config)
|
||||||
|
|
||||||
return pusher
|
return pusher
|
||||||
|
|
||||||
@ -302,44 +307,44 @@ class PusherPool:
|
|||||||
|
|
||||||
logger.info("Started pushers")
|
logger.info("Started pushers")
|
||||||
|
|
||||||
async def _start_pusher(self, pusherdict: Dict[str, Any]) -> Optional[Pusher]:
|
async def _start_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]:
|
||||||
"""Start the given pusher
|
"""Start the given pusher
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pusherdict: dict with the values pulled from the db table
|
pusher_config: The pusher configuration with the values pulled from the db table
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The newly created pusher or None.
|
The newly created pusher or None.
|
||||||
"""
|
"""
|
||||||
if not self._pusher_shard_config.should_handle(
|
if not self._pusher_shard_config.should_handle(
|
||||||
self._instance_name, pusherdict["user_name"]
|
self._instance_name, pusher_config.user_name
|
||||||
):
|
):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
p = self.pusher_factory.create_pusher(pusherdict)
|
p = self.pusher_factory.create_pusher(pusher_config)
|
||||||
except PusherConfigException as e:
|
except PusherConfigException as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s",
|
"Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s",
|
||||||
pusherdict["id"],
|
pusher_config.id,
|
||||||
pusherdict.get("user_name"),
|
pusher_config.user_name,
|
||||||
pusherdict.get("app_id"),
|
pusher_config.app_id,
|
||||||
pusherdict.get("pushkey"),
|
pusher_config.pushkey,
|
||||||
e,
|
e,
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Couldn't start pusher id %i: caught Exception", pusherdict["id"],
|
"Couldn't start pusher id %i: caught Exception", pusher_config.id,
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if not p:
|
if not p:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"])
|
appid_pushkey = "%s:%s" % (pusher_config.app_id, pusher_config.pushkey)
|
||||||
|
|
||||||
byuser = self.pushers.setdefault(pusherdict["user_name"], {})
|
byuser = self.pushers.setdefault(pusher_config.user_name, {})
|
||||||
if appid_pushkey in byuser:
|
if appid_pushkey in byuser:
|
||||||
byuser[appid_pushkey].on_stop()
|
byuser[appid_pushkey].on_stop()
|
||||||
byuser[appid_pushkey] = p
|
byuser[appid_pushkey] = p
|
||||||
@ -349,8 +354,8 @@ class PusherPool:
|
|||||||
# Check if there *may* be push to process. We do this as this check is a
|
# Check if there *may* be push to process. We do this as this check is a
|
||||||
# lot cheaper to do than actually fetching the exact rows we need to
|
# lot cheaper to do than actually fetching the exact rows we need to
|
||||||
# push.
|
# push.
|
||||||
user_id = pusherdict["user_name"]
|
user_id = pusher_config.user_name
|
||||||
last_stream_ordering = pusherdict["last_stream_ordering"]
|
last_stream_ordering = pusher_config.last_stream_ordering
|
||||||
if last_stream_ordering:
|
if last_stream_ordering:
|
||||||
have_notifs = await self.store.get_if_maybe_push_in_range_for_user(
|
have_notifs = await self.store.get_if_maybe_push_in_range_for_user(
|
||||||
user_id, last_stream_ordering
|
user_id, last_stream_ordering
|
||||||
@ -364,7 +369,7 @@ class PusherPool:
|
|||||||
|
|
||||||
return p
|
return p
|
||||||
|
|
||||||
async def remove_pusher(self, app_id, pushkey, user_id):
|
async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
|
||||||
appid_pushkey = "%s:%s" % (app_id, pushkey)
|
appid_pushkey = "%s:%s" % (app_id, pushkey)
|
||||||
|
|
||||||
byuser = self.pushers.get(user_id, {})
|
byuser = self.pushers.get(user_id, {})
|
||||||
|
@ -12,21 +12,31 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
from synapse.storage.types import Connection
|
||||||
from synapse.storage.util.id_generators import _load_current_id
|
from synapse.storage.util.id_generators import _load_current_id
|
||||||
|
|
||||||
|
|
||||||
class SlavedIdTracker:
|
class SlavedIdTracker:
|
||||||
def __init__(self, db_conn, table, column, extra_tables=[], step=1):
|
def __init__(
|
||||||
|
self,
|
||||||
|
db_conn: Connection,
|
||||||
|
table: str,
|
||||||
|
column: str,
|
||||||
|
extra_tables: Optional[List[Tuple[str, str]]] = None,
|
||||||
|
step: int = 1,
|
||||||
|
):
|
||||||
self.step = step
|
self.step = step
|
||||||
self._current = _load_current_id(db_conn, table, column, step)
|
self._current = _load_current_id(db_conn, table, column, step)
|
||||||
for table, column in extra_tables:
|
if extra_tables:
|
||||||
self.advance(None, _load_current_id(db_conn, table, column))
|
for table, column in extra_tables:
|
||||||
|
self.advance(None, _load_current_id(db_conn, table, column))
|
||||||
|
|
||||||
def advance(self, instance_name, new_id):
|
def advance(self, instance_name: Optional[str], new_id: int):
|
||||||
self._current = (max if self.step > 0 else min)(self._current, new_id)
|
self._current = (max if self.step > 0 else min)(self._current, new_id)
|
||||||
|
|
||||||
def get_current_token(self):
|
def get_current_token(self) -> int:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -13,26 +13,33 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from synapse.replication.tcp.streams import PushersStream
|
from synapse.replication.tcp.streams import PushersStream
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool
|
||||||
from synapse.storage.databases.main.pusher import PusherWorkerStore
|
from synapse.storage.databases.main.pusher import PusherWorkerStore
|
||||||
|
from synapse.storage.types import Connection
|
||||||
|
|
||||||
from ._base import BaseSlavedStore
|
from ._base import BaseSlavedStore
|
||||||
from ._slaved_id_tracker import SlavedIdTracker
|
from ._slaved_id_tracker import SlavedIdTracker
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
|
|
||||||
class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
|
class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
|
||||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
self._pushers_id_gen = SlavedIdTracker(
|
self._pushers_id_gen = SlavedIdTracker( # type: ignore
|
||||||
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
|
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_pushers_stream_token(self):
|
def get_pushers_stream_token(self) -> int:
|
||||||
return self._pushers_id_gen.get_current_token()
|
return self._pushers_id_gen.get_current_token()
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
def process_replication_rows(
|
||||||
|
self, stream_name: str, instance_name: str, token, rows
|
||||||
|
) -> None:
|
||||||
if stream_name == PushersStream.NAME:
|
if stream_name == PushersStream.NAME:
|
||||||
self._pushers_id_gen.advance(instance_name, token)
|
self._pushers_id_gen.advance(instance_name, token) # type: ignore
|
||||||
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||||
|
@ -42,17 +42,6 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_GET_PUSHERS_ALLOWED_KEYS = {
|
|
||||||
"app_display_name",
|
|
||||||
"app_id",
|
|
||||||
"data",
|
|
||||||
"device_display_name",
|
|
||||||
"kind",
|
|
||||||
"lang",
|
|
||||||
"profile_tag",
|
|
||||||
"pushkey",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class UsersRestServlet(RestServlet):
|
class UsersRestServlet(RestServlet):
|
||||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
|
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
|
||||||
@ -770,10 +759,7 @@ class PushersRestServlet(RestServlet):
|
|||||||
|
|
||||||
pushers = await self.store.get_pushers_by_user_id(user_id)
|
pushers = await self.store.get_pushers_by_user_id(user_id)
|
||||||
|
|
||||||
filtered_pushers = [
|
filtered_pushers = [p.as_dict() for p in pushers]
|
||||||
{k: v for k, v in p.items() if k in _GET_PUSHERS_ALLOWED_KEYS}
|
|
||||||
for p in pushers
|
|
||||||
]
|
|
||||||
|
|
||||||
return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)}
|
return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)}
|
||||||
|
|
||||||
|
@ -28,17 +28,6 @@ from synapse.rest.client.v2_alpha._base import client_patterns
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ALLOWED_KEYS = {
|
|
||||||
"app_display_name",
|
|
||||||
"app_id",
|
|
||||||
"data",
|
|
||||||
"device_display_name",
|
|
||||||
"kind",
|
|
||||||
"lang",
|
|
||||||
"profile_tag",
|
|
||||||
"pushkey",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class PushersRestServlet(RestServlet):
|
class PushersRestServlet(RestServlet):
|
||||||
PATTERNS = client_patterns("/pushers$", v1=True)
|
PATTERNS = client_patterns("/pushers$", v1=True)
|
||||||
@ -54,9 +43,7 @@ class PushersRestServlet(RestServlet):
|
|||||||
|
|
||||||
pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
|
pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
|
||||||
|
|
||||||
filtered_pushers = [
|
filtered_pushers = [p.as_dict() for p in pushers]
|
||||||
{k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers
|
|
||||||
]
|
|
||||||
|
|
||||||
return 200, {"pushers": filtered_pushers}
|
return 200, {"pushers": filtered_pushers}
|
||||||
|
|
||||||
|
@ -149,9 +149,6 @@ class DataStore(
|
|||||||
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
|
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
|
||||||
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
|
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
|
||||||
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
|
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
|
||||||
self._pushers_id_gen = StreamIdGenerator(
|
|
||||||
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
|
|
||||||
)
|
|
||||||
self._group_updates_id_gen = StreamIdGenerator(
|
self._group_updates_id_gen = StreamIdGenerator(
|
||||||
db_conn, "local_group_updates", "stream_id"
|
db_conn, "local_group_updates", "stream_id"
|
||||||
)
|
)
|
||||||
|
@ -15,18 +15,32 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Iterable, Iterator, List, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
|
|
||||||
|
from synapse.push import PusherConfig, ThrottleParams
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
|
from synapse.storage.database import DatabasePool
|
||||||
|
from synapse.storage.types import Connection
|
||||||
|
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||||
|
from synapse.types import JsonDict
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class PusherWorkerStore(SQLBaseStore):
|
class PusherWorkerStore(SQLBaseStore):
|
||||||
def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]:
|
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
|
||||||
|
super().__init__(database, db_conn, hs)
|
||||||
|
self._pushers_id_gen = StreamIdGenerator(
|
||||||
|
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
|
||||||
|
)
|
||||||
|
|
||||||
|
def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
|
||||||
"""JSON-decode the data in the rows returned from the `pushers` table
|
"""JSON-decode the data in the rows returned from the `pushers` table
|
||||||
|
|
||||||
Drops any rows whose data cannot be decoded
|
Drops any rows whose data cannot be decoded
|
||||||
@ -44,21 +58,23 @@ class PusherWorkerStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
yield r
|
yield PusherConfig(**r)
|
||||||
|
|
||||||
async def user_has_pusher(self, user_id):
|
async def user_has_pusher(self, user_id: str) -> bool:
|
||||||
ret = await self.db_pool.simple_select_one_onecol(
|
ret = await self.db_pool.simple_select_one_onecol(
|
||||||
"pushers", {"user_name": user_id}, "id", allow_none=True
|
"pushers", {"user_name": user_id}, "id", allow_none=True
|
||||||
)
|
)
|
||||||
return ret is not None
|
return ret is not None
|
||||||
|
|
||||||
def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey):
|
async def get_pushers_by_app_id_and_pushkey(
|
||||||
return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
|
self, app_id: str, pushkey: str
|
||||||
|
) -> Iterator[PusherConfig]:
|
||||||
|
return await self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
|
||||||
|
|
||||||
def get_pushers_by_user_id(self, user_id):
|
async def get_pushers_by_user_id(self, user_id: str) -> Iterator[PusherConfig]:
|
||||||
return self.get_pushers_by({"user_name": user_id})
|
return await self.get_pushers_by({"user_name": user_id})
|
||||||
|
|
||||||
async def get_pushers_by(self, keyvalues):
|
async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]:
|
||||||
ret = await self.db_pool.simple_select_list(
|
ret = await self.db_pool.simple_select_list(
|
||||||
"pushers",
|
"pushers",
|
||||||
keyvalues,
|
keyvalues,
|
||||||
@ -83,7 +99,7 @@ class PusherWorkerStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
return self._decode_pushers_rows(ret)
|
return self._decode_pushers_rows(ret)
|
||||||
|
|
||||||
async def get_all_pushers(self):
|
async def get_all_pushers(self) -> Iterator[PusherConfig]:
|
||||||
def get_pushers(txn):
|
def get_pushers(txn):
|
||||||
txn.execute("SELECT * FROM pushers")
|
txn.execute("SELECT * FROM pushers")
|
||||||
rows = self.db_pool.cursor_to_dict(txn)
|
rows = self.db_pool.cursor_to_dict(txn)
|
||||||
@ -159,14 +175,16 @@ class PusherWorkerStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@cached(num_args=1, max_entries=15000)
|
@cached(num_args=1, max_entries=15000)
|
||||||
async def get_if_user_has_pusher(self, user_id):
|
async def get_if_user_has_pusher(self, user_id: str):
|
||||||
# This only exists for the cachedList decorator
|
# This only exists for the cachedList decorator
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@cachedList(
|
@cachedList(
|
||||||
cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
|
cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
|
||||||
)
|
)
|
||||||
async def get_if_users_have_pushers(self, user_ids):
|
async def get_if_users_have_pushers(
|
||||||
|
self, user_ids: Iterable[str]
|
||||||
|
) -> Dict[str, bool]:
|
||||||
rows = await self.db_pool.simple_select_many_batch(
|
rows = await self.db_pool.simple_select_many_batch(
|
||||||
table="pushers",
|
table="pushers",
|
||||||
column="user_name",
|
column="user_name",
|
||||||
@ -224,7 +242,7 @@ class PusherWorkerStore(SQLBaseStore):
|
|||||||
return bool(updated)
|
return bool(updated)
|
||||||
|
|
||||||
async def update_pusher_failing_since(
|
async def update_pusher_failing_since(
|
||||||
self, app_id, pushkey, user_id, failing_since
|
self, app_id: str, pushkey: str, user_id: str, failing_since: Optional[int]
|
||||||
) -> None:
|
) -> None:
|
||||||
await self.db_pool.simple_update(
|
await self.db_pool.simple_update(
|
||||||
table="pushers",
|
table="pushers",
|
||||||
@ -233,7 +251,9 @@ class PusherWorkerStore(SQLBaseStore):
|
|||||||
desc="update_pusher_failing_since",
|
desc="update_pusher_failing_since",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_throttle_params_by_room(self, pusher_id):
|
async def get_throttle_params_by_room(
|
||||||
|
self, pusher_id: str
|
||||||
|
) -> Dict[str, ThrottleParams]:
|
||||||
res = await self.db_pool.simple_select_list(
|
res = await self.db_pool.simple_select_list(
|
||||||
"pusher_throttle",
|
"pusher_throttle",
|
||||||
{"pusher": pusher_id},
|
{"pusher": pusher_id},
|
||||||
@ -243,43 +263,44 @@ class PusherWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
params_by_room = {}
|
params_by_room = {}
|
||||||
for row in res:
|
for row in res:
|
||||||
params_by_room[row["room_id"]] = {
|
params_by_room[row["room_id"]] = ThrottleParams(
|
||||||
"last_sent_ts": row["last_sent_ts"],
|
row["last_sent_ts"], row["throttle_ms"],
|
||||||
"throttle_ms": row["throttle_ms"],
|
)
|
||||||
}
|
|
||||||
|
|
||||||
return params_by_room
|
return params_by_room
|
||||||
|
|
||||||
async def set_throttle_params(self, pusher_id, room_id, params) -> None:
|
async def set_throttle_params(
|
||||||
|
self, pusher_id: str, room_id: str, params: ThrottleParams
|
||||||
|
) -> None:
|
||||||
# no need to lock because `pusher_throttle` has a primary key on
|
# no need to lock because `pusher_throttle` has a primary key on
|
||||||
# (pusher, room_id) so simple_upsert will retry
|
# (pusher, room_id) so simple_upsert will retry
|
||||||
await self.db_pool.simple_upsert(
|
await self.db_pool.simple_upsert(
|
||||||
"pusher_throttle",
|
"pusher_throttle",
|
||||||
{"pusher": pusher_id, "room_id": room_id},
|
{"pusher": pusher_id, "room_id": room_id},
|
||||||
params,
|
{"last_sent_ts": params.last_sent_ts, "throttle_ms": params.throttle_ms},
|
||||||
desc="set_throttle_params",
|
desc="set_throttle_params",
|
||||||
lock=False,
|
lock=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PusherStore(PusherWorkerStore):
|
class PusherStore(PusherWorkerStore):
|
||||||
def get_pushers_stream_token(self):
|
def get_pushers_stream_token(self) -> int:
|
||||||
return self._pushers_id_gen.get_current_token()
|
return self._pushers_id_gen.get_current_token()
|
||||||
|
|
||||||
async def add_pusher(
|
async def add_pusher(
|
||||||
self,
|
self,
|
||||||
user_id,
|
user_id: str,
|
||||||
access_token,
|
access_token: Optional[int],
|
||||||
kind,
|
kind: str,
|
||||||
app_id,
|
app_id: str,
|
||||||
app_display_name,
|
app_display_name: str,
|
||||||
device_display_name,
|
device_display_name: str,
|
||||||
pushkey,
|
pushkey: str,
|
||||||
pushkey_ts,
|
pushkey_ts: int,
|
||||||
lang,
|
lang: Optional[str],
|
||||||
data,
|
data: Optional[JsonDict],
|
||||||
last_stream_ordering,
|
last_stream_ordering: int,
|
||||||
profile_tag="",
|
profile_tag: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
async with self._pushers_id_gen.get_next() as stream_id:
|
async with self._pushers_id_gen.get_next() as stream_id:
|
||||||
# no need to lock because `pushers` has a unique key on
|
# no need to lock because `pushers` has a unique key on
|
||||||
@ -311,16 +332,16 @@ class PusherStore(PusherWorkerStore):
|
|||||||
# invalidate, since we the user might not have had a pusher before
|
# invalidate, since we the user might not have had a pusher before
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"add_pusher",
|
"add_pusher",
|
||||||
self._invalidate_cache_and_stream,
|
self._invalidate_cache_and_stream, # type: ignore
|
||||||
self.get_if_user_has_pusher,
|
self.get_if_user_has_pusher,
|
||||||
(user_id,),
|
(user_id,),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def delete_pusher_by_app_id_pushkey_user_id(
|
async def delete_pusher_by_app_id_pushkey_user_id(
|
||||||
self, app_id, pushkey, user_id
|
self, app_id: str, pushkey: str, user_id: str
|
||||||
) -> None:
|
) -> None:
|
||||||
def delete_pusher_txn(txn, stream_id):
|
def delete_pusher_txn(txn, stream_id):
|
||||||
self._invalidate_cache_and_stream(
|
self._invalidate_cache_and_stream( # type: ignore
|
||||||
txn, self.get_if_user_has_pusher, (user_id,)
|
txn, self.get_if_user_has_pusher, (user_id,)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -153,12 +153,12 @@ class StreamIdGenerator:
|
|||||||
|
|
||||||
return _AsyncCtxManagerWrapper(manager())
|
return _AsyncCtxManagerWrapper(manager())
|
||||||
|
|
||||||
def get_current_token(self):
|
def get_current_token(self) -> int:
|
||||||
"""Returns the maximum stream id such that all stream ids less than or
|
"""Returns the maximum stream id such that all stream ids less than or
|
||||||
equal to it have been successfully persisted.
|
equal to it have been successfully persisted.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
int
|
The maximum stream id.
|
||||||
"""
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if self._unfinished_ids:
|
if self._unfinished_ids:
|
||||||
|
@ -209,7 +209,7 @@ class EmailPusherTests(HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
pushers = list(pushers)
|
pushers = list(pushers)
|
||||||
self.assertEqual(len(pushers), 1)
|
self.assertEqual(len(pushers), 1)
|
||||||
last_stream_ordering = pushers[0]["last_stream_ordering"]
|
last_stream_ordering = pushers[0].last_stream_ordering
|
||||||
|
|
||||||
# Advance time a bit, so the pusher will register something has happened
|
# Advance time a bit, so the pusher will register something has happened
|
||||||
self.pump(10)
|
self.pump(10)
|
||||||
@ -220,7 +220,7 @@ class EmailPusherTests(HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
pushers = list(pushers)
|
pushers = list(pushers)
|
||||||
self.assertEqual(len(pushers), 1)
|
self.assertEqual(len(pushers), 1)
|
||||||
self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"])
|
self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
|
||||||
|
|
||||||
# One email was attempted to be sent
|
# One email was attempted to be sent
|
||||||
self.assertEqual(len(self.email_attempts), 1)
|
self.assertEqual(len(self.email_attempts), 1)
|
||||||
@ -238,4 +238,4 @@ class EmailPusherTests(HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
pushers = list(pushers)
|
pushers = list(pushers)
|
||||||
self.assertEqual(len(pushers), 1)
|
self.assertEqual(len(pushers), 1)
|
||||||
self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
|
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
|
||||||
|
@ -144,7 +144,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
pushers = list(pushers)
|
pushers = list(pushers)
|
||||||
self.assertEqual(len(pushers), 1)
|
self.assertEqual(len(pushers), 1)
|
||||||
last_stream_ordering = pushers[0]["last_stream_ordering"]
|
last_stream_ordering = pushers[0].last_stream_ordering
|
||||||
|
|
||||||
# Advance time a bit, so the pusher will register something has happened
|
# Advance time a bit, so the pusher will register something has happened
|
||||||
self.pump()
|
self.pump()
|
||||||
@ -155,7 +155,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
pushers = list(pushers)
|
pushers = list(pushers)
|
||||||
self.assertEqual(len(pushers), 1)
|
self.assertEqual(len(pushers), 1)
|
||||||
self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"])
|
self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
|
||||||
|
|
||||||
# One push was attempted to be sent -- it'll be the first message
|
# One push was attempted to be sent -- it'll be the first message
|
||||||
self.assertEqual(len(self.push_attempts), 1)
|
self.assertEqual(len(self.push_attempts), 1)
|
||||||
@ -176,8 +176,8 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
pushers = list(pushers)
|
pushers = list(pushers)
|
||||||
self.assertEqual(len(pushers), 1)
|
self.assertEqual(len(pushers), 1)
|
||||||
self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
|
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
|
||||||
last_stream_ordering = pushers[0]["last_stream_ordering"]
|
last_stream_ordering = pushers[0].last_stream_ordering
|
||||||
|
|
||||||
# Now it'll try and send the second push message, which will be the second one
|
# Now it'll try and send the second push message, which will be the second one
|
||||||
self.assertEqual(len(self.push_attempts), 2)
|
self.assertEqual(len(self.push_attempts), 2)
|
||||||
@ -198,7 +198,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
pushers = list(pushers)
|
pushers = list(pushers)
|
||||||
self.assertEqual(len(pushers), 1)
|
self.assertEqual(len(pushers), 1)
|
||||||
self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
|
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
|
||||||
|
|
||||||
def test_sends_high_priority_for_encrypted(self):
|
def test_sends_high_priority_for_encrypted(self):
|
||||||
"""
|
"""
|
||||||
|
@ -766,7 +766,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
pushers = list(pushers)
|
pushers = list(pushers)
|
||||||
self.assertEqual(len(pushers), 1)
|
self.assertEqual(len(pushers), 1)
|
||||||
self.assertEqual("@bob:test", pushers[0]["user_name"])
|
self.assertEqual("@bob:test", pushers[0].user_name)
|
||||||
|
|
||||||
@override_config(
|
@override_config(
|
||||||
{
|
{
|
||||||
|
Loading…
Reference in New Issue
Block a user