Support enabling/disabling pushers (from MSC3881) (#13799)

Partial implementation of MSC3881
This commit is contained in:
Brendan Abolivier 2022-09-21 15:39:01 +01:00 committed by GitHub
parent 6bd8763804
commit 8ae42ab8fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 294 additions and 71 deletions

View File

@ -0,0 +1 @@
Add experimental support for [MSC3881: Remotely toggle push notifications for another client](https://github.com/matrix-org/matrix-spec-proposals/pull/3881).

View File

@ -111,6 +111,7 @@ BOOLEAN_COLUMNS = {
"e2e_fallback_keys_json": ["used"], "e2e_fallback_keys_json": ["used"],
"access_tokens": ["used"], "access_tokens": ["used"],
"device_lists_changes_in_room": ["converted_to_destinations"], "device_lists_changes_in_room": ["converted_to_destinations"],
"pushers": ["enabled"],
} }

View File

@ -93,3 +93,6 @@ class ExperimentalConfig(Config):
# MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices. # MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices.
self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False) self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False)
# MSC3881: Remotely toggle push notifications for another client
self.msc3881_enabled: bool = experimental.get("msc3881_enabled", False)

View File

@ -997,7 +997,7 @@ class RegistrationHandler:
assert user_tuple assert user_tuple
token_id = user_tuple.token_id token_id = user_tuple.token_id
await self.pusher_pool.add_pusher( await self.pusher_pool.add_or_update_pusher(
user_id=user_id, user_id=user_id,
access_token=token_id, access_token=token_id,
kind="email", kind="email",
@ -1005,7 +1005,7 @@ class RegistrationHandler:
app_display_name="Email Notifications", app_display_name="Email Notifications",
device_display_name=threepid["address"], device_display_name=threepid["address"],
pushkey=threepid["address"], pushkey=threepid["address"],
lang=None, # We don't know a user's language here lang=None,
data={}, data={},
) )

View File

@ -116,6 +116,7 @@ class PusherConfig:
last_stream_ordering: int last_stream_ordering: int
last_success: Optional[int] last_success: Optional[int]
failing_since: Optional[int] failing_since: Optional[int]
enabled: bool
def as_dict(self) -> Dict[str, Any]: def as_dict(self) -> Dict[str, Any]:
"""Information that can be retrieved about a pusher after creation.""" """Information that can be retrieved about a pusher after creation."""
@ -128,6 +129,7 @@ class PusherConfig:
"lang": self.lang, "lang": self.lang,
"profile_tag": self.profile_tag, "profile_tag": self.profile_tag,
"pushkey": self.pushkey, "pushkey": self.pushkey,
"enabled": self.enabled,
} }

View File

@ -94,7 +94,7 @@ class PusherPool:
return return
run_as_background_process("start_pushers", self._start_pushers) run_as_background_process("start_pushers", self._start_pushers)
async def add_pusher( async def add_or_update_pusher(
self, self,
user_id: str, user_id: str,
access_token: Optional[int], access_token: Optional[int],
@ -106,6 +106,7 @@ class PusherPool:
lang: Optional[str], lang: Optional[str],
data: JsonDict, data: JsonDict,
profile_tag: str = "", profile_tag: str = "",
enabled: bool = True,
) -> Optional[Pusher]: ) -> Optional[Pusher]:
"""Creates a new pusher and adds it to the pool """Creates a new pusher and adds it to the pool
@ -147,9 +148,20 @@ class PusherPool:
last_stream_ordering=last_stream_ordering, last_stream_ordering=last_stream_ordering,
last_success=None, last_success=None,
failing_since=None, failing_since=None,
enabled=enabled,
) )
) )
# Before we actually persist the pusher, we check if the user already has one
# for this app ID and pushkey. If so, we want to keep the access token in place,
# since this could be one device modifying (e.g. enabling/disabling) another
# device's pusher.
existing_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey(
user_id, app_id, pushkey
)
if existing_config:
access_token = existing_config.access_token
await self.store.add_pusher( await self.store.add_pusher(
user_id=user_id, user_id=user_id,
access_token=access_token, access_token=access_token,
@ -163,8 +175,9 @@ class PusherPool:
data=data, data=data,
last_stream_ordering=last_stream_ordering, last_stream_ordering=last_stream_ordering,
profile_tag=profile_tag, profile_tag=profile_tag,
enabled=enabled,
) )
pusher = await self.start_pusher_by_id(app_id, pushkey, user_id) pusher = await self.process_pusher_change_by_id(app_id, pushkey, user_id)
return pusher return pusher
@ -276,10 +289,25 @@ class PusherPool:
except Exception: except Exception:
logger.exception("Exception in pusher on_new_receipts") logger.exception("Exception in pusher on_new_receipts")
async def start_pusher_by_id( async def _get_pusher_config_for_user_by_app_id_and_pushkey(
self, user_id: str, app_id: str, pushkey: str
) -> Optional[PusherConfig]:
resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
pusher_config = None
for r in resultlist:
if r.user_name == user_id:
pusher_config = r
return pusher_config
async def process_pusher_change_by_id(
self, app_id: str, pushkey: str, user_id: str self, app_id: str, pushkey: str, user_id: str
) -> Optional[Pusher]: ) -> Optional[Pusher]:
"""Look up the details for the given pusher, and start it """Look up the details for the given pusher, and either start it if its
"enabled" flag is True, or try to stop it otherwise.
If the pusher is new and its "enabled" flag is False, the stop is a noop.
Returns: Returns:
The pusher started, if any The pusher started, if any
@ -290,12 +318,13 @@ class PusherPool:
if not self._pusher_shard_config.should_handle(self._instance_name, user_id): if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
return None return None
resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) pusher_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey(
user_id, app_id, pushkey
)
pusher_config = None if pusher_config and not pusher_config.enabled:
for r in resultlist: self.maybe_stop_pusher(app_id, pushkey, user_id)
if r.user_name == user_id: return None
pusher_config = r
pusher = None pusher = None
if pusher_config: if pusher_config:
@ -305,7 +334,7 @@ class PusherPool:
async def _start_pushers(self) -> None: async def _start_pushers(self) -> None:
"""Start all the pushers""" """Start all the pushers"""
pushers = await self.store.get_all_pushers() pushers = await self.store.get_enabled_pushers()
# Stagger starting up the pushers so we don't completely drown the # Stagger starting up the pushers so we don't completely drown the
# process on start up. # process on start up.
@ -363,6 +392,8 @@ class PusherPool:
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).inc() synapse_pushers.labels(type(pusher).__name__, pusher.app_id).inc()
logger.info("Starting pusher %s / %s", pusher.user_id, appid_pushkey)
# Check if there *may* be push to process. We do this as this check is a # Check if there *may* be push to process. We do this as this check is a
# lot cheaper to do than actually fetching the exact rows we need to # lot cheaper to do than actually fetching the exact rows we need to
# push. # push.
@ -382,16 +413,7 @@ class PusherPool:
return pusher return pusher
async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None: async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
appid_pushkey = "%s:%s" % (app_id, pushkey) self.maybe_stop_pusher(app_id, pushkey, user_id)
byuser = self.pushers.get(user_id, {})
if appid_pushkey in byuser:
logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
pusher = byuser.pop(appid_pushkey)
pusher.on_stop()
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
# We can only delete pushers on master. # We can only delete pushers on master.
if self._remove_pusher_client: if self._remove_pusher_client:
@ -402,3 +424,22 @@ class PusherPool:
await self.store.delete_pusher_by_app_id_pushkey_user_id( await self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id, pushkey, user_id app_id, pushkey, user_id
) )
def maybe_stop_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
"""Stops a pusher with the given app ID and push key if one is running.
Args:
app_id: the pusher's app ID.
pushkey: the pusher's push key.
user_id: the user the pusher belongs to. Only used for logging.
"""
appid_pushkey = "%s:%s" % (app_id, pushkey)
byuser = self.pushers.get(user_id, {})
if appid_pushkey in byuser:
logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
pusher = byuser.pop(appid_pushkey)
pusher.on_stop()
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()

View File

@ -189,7 +189,9 @@ class ReplicationDataHandler:
if row.deleted: if row.deleted:
self.stop_pusher(row.user_id, row.app_id, row.pushkey) self.stop_pusher(row.user_id, row.app_id, row.pushkey)
else: else:
await self.start_pusher(row.user_id, row.app_id, row.pushkey) await self.process_pusher_change(
row.user_id, row.app_id, row.pushkey
)
elif stream_name == EventsStream.NAME: elif stream_name == EventsStream.NAME:
# We shouldn't get multiple rows per token for events stream, so # We shouldn't get multiple rows per token for events stream, so
# we don't need to optimise this for multiple rows. # we don't need to optimise this for multiple rows.
@ -334,13 +336,15 @@ class ReplicationDataHandler:
logger.info("Stopping pusher %r / %r", user_id, key) logger.info("Stopping pusher %r / %r", user_id, key)
pusher.on_stop() pusher.on_stop()
async def start_pusher(self, user_id: str, app_id: str, pushkey: str) -> None: async def process_pusher_change(
self, user_id: str, app_id: str, pushkey: str
) -> None:
if not self._notify_pushers: if not self._notify_pushers:
return return
key = "%s:%s" % (app_id, pushkey) key = "%s:%s" % (app_id, pushkey)
logger.info("Starting pusher %r / %r", user_id, key) logger.info("Starting pusher %r / %r", user_id, key)
await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id) await self._pusher_pool.process_pusher_change_by_id(app_id, pushkey, user_id)
class FederationSenderHandler: class FederationSenderHandler:

View File

@ -375,7 +375,7 @@ class UserRestServletV2(RestServlet):
and self.hs.config.email.email_notif_for_new_users and self.hs.config.email.email_notif_for_new_users
and medium == "email" and medium == "email"
): ):
await self.pusher_pool.add_pusher( await self.pusher_pool.add_or_update_pusher(
user_id=user_id, user_id=user_id,
access_token=None, access_token=None,
kind="email", kind="email",
@ -383,7 +383,7 @@ class UserRestServletV2(RestServlet):
app_display_name="Email Notifications", app_display_name="Email Notifications",
device_display_name=address, device_display_name=address,
pushkey=address, pushkey=address,
lang=None, # We don't know a user's language here lang=None,
data={}, data={},
) )

View File

@ -42,6 +42,7 @@ class PushersRestServlet(RestServlet):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
@ -51,9 +52,14 @@ class PushersRestServlet(RestServlet):
user.to_string() user.to_string()
) )
filtered_pushers = [p.as_dict() for p in pushers] pusher_dicts = [p.as_dict() for p in pushers]
return 200, {"pushers": filtered_pushers} for pusher in pusher_dicts:
if self._msc3881_enabled:
pusher["org.matrix.msc3881.enabled"] = pusher["enabled"]
del pusher["enabled"]
return 200, {"pushers": pusher_dicts}
class PushersSetRestServlet(RestServlet): class PushersSetRestServlet(RestServlet):
@ -65,6 +71,7 @@ class PushersSetRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.pusher_pool = self.hs.get_pusherpool() self.pusher_pool = self.hs.get_pusherpool()
self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
@ -103,6 +110,10 @@ class PushersSetRestServlet(RestServlet):
if "append" in content: if "append" in content:
append = content["append"] append = content["append"]
enabled = True
if self._msc3881_enabled and "org.matrix.msc3881.enabled" in content:
enabled = content["org.matrix.msc3881.enabled"]
if not append: if not append:
await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
app_id=content["app_id"], app_id=content["app_id"],
@ -111,7 +122,7 @@ class PushersSetRestServlet(RestServlet):
) )
try: try:
await self.pusher_pool.add_pusher( await self.pusher_pool.add_or_update_pusher(
user_id=user.to_string(), user_id=user.to_string(),
access_token=requester.access_token_id, access_token=requester.access_token_id,
kind=content["kind"], kind=content["kind"],
@ -122,6 +133,7 @@ class PushersSetRestServlet(RestServlet):
lang=content["lang"], lang=content["lang"],
data=content["data"], data=content["data"],
profile_tag=content.get("profile_tag", ""), profile_tag=content.get("profile_tag", ""),
enabled=enabled,
) )
except PusherConfigException as pce: except PusherConfigException as pce:
raise SynapseError( raise SynapseError(

View File

@ -89,6 +89,11 @@ class PusherWorkerStore(SQLBaseStore):
) )
continue continue
# If we're using SQLite, then boolean values are integers. This is
# troublesome since some code using the return value of this method might
# expect it to be a boolean, or will expose it to clients (in responses).
r["enabled"] = bool(r["enabled"])
yield PusherConfig(**r) yield PusherConfig(**r)
async def get_pushers_by_app_id_and_pushkey( async def get_pushers_by_app_id_and_pushkey(
@ -100,38 +105,52 @@ class PusherWorkerStore(SQLBaseStore):
return await self.get_pushers_by({"user_name": user_id}) return await self.get_pushers_by({"user_name": user_id})
async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]: async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]:
ret = await self.db_pool.simple_select_list( """Retrieve pushers that match the given criteria.
"pushers",
keyvalues, Args:
[ keyvalues: A {column: value} dictionary.
"id",
"user_name", Returns:
"access_token", The pushers for which the given columns have the given values.
"profile_tag", """
"kind",
"app_id", def get_pushers_by_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
"app_display_name", # We could technically use simple_select_list here, but we need to call
"device_display_name", # COALESCE on the 'enabled' column. While it is technically possible to give
"pushkey", # simple_select_list the whole `COALESCE(...) AS ...` as a column name, it
"ts", # feels a bit hacky, so it's probably better to just inline the query.
"lang", sql = """
"data", SELECT
"last_stream_ordering", id, user_name, access_token, profile_tag, kind, app_id,
"last_success", app_display_name, device_display_name, pushkey, ts, lang, data,
"failing_since", last_stream_ordering, last_success, failing_since,
], COALESCE(enabled, TRUE) AS enabled
FROM pushers
"""
sql += "WHERE %s" % (" AND ".join("%s = ?" % (k,) for k in keyvalues),)
txn.execute(sql, list(keyvalues.values()))
return self.db_pool.cursor_to_dict(txn)
ret = await self.db_pool.runInteraction(
desc="get_pushers_by", desc="get_pushers_by",
func=get_pushers_by_txn,
) )
return self._decode_pushers_rows(ret) return self._decode_pushers_rows(ret)
async def get_all_pushers(self) -> Iterator[PusherConfig]: async def get_enabled_pushers(self) -> Iterator[PusherConfig]:
def get_pushers(txn: LoggingTransaction) -> Iterator[PusherConfig]: def get_enabled_pushers_txn(txn: LoggingTransaction) -> Iterator[PusherConfig]:
txn.execute("SELECT * FROM pushers") txn.execute("SELECT * FROM pushers WHERE COALESCE(enabled, TRUE)")
rows = self.db_pool.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
return self._decode_pushers_rows(rows) return self._decode_pushers_rows(rows)
return await self.db_pool.runInteraction("get_all_pushers", get_pushers) return await self.db_pool.runInteraction(
"get_enabled_pushers", get_enabled_pushers_txn
)
async def get_all_updated_pushers_rows( async def get_all_updated_pushers_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int self, instance_name: str, last_id: int, current_id: int, limit: int
@ -476,6 +495,7 @@ class PusherStore(PusherWorkerStore):
data: Optional[JsonDict], data: Optional[JsonDict],
last_stream_ordering: int, last_stream_ordering: int,
profile_tag: str = "", profile_tag: str = "",
enabled: bool = True,
) -> None: ) -> None:
async with self._pushers_id_gen.get_next() as stream_id: async with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on # no need to lock because `pushers` has a unique key on
@ -494,6 +514,7 @@ class PusherStore(PusherWorkerStore):
"last_stream_ordering": last_stream_ordering, "last_stream_ordering": last_stream_ordering,
"profile_tag": profile_tag, "profile_tag": profile_tag,
"id": stream_id, "id": stream_id,
"enabled": enabled,
}, },
desc="add_pusher", desc="add_pusher",
lock=False, lock=False,

View File

@ -0,0 +1,16 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE pushers ADD COLUMN enabled BOOLEAN;

View File

@ -114,7 +114,7 @@ class EmailPusherTests(HomeserverTestCase):
) )
self.pusher = self.get_success( self.pusher = self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_or_update_pusher(
user_id=self.user_id, user_id=self.user_id,
access_token=self.token_id, access_token=self.token_id,
kind="email", kind="email",
@ -136,7 +136,7 @@ class EmailPusherTests(HomeserverTestCase):
""" """
with self.assertRaises(SynapseError) as cm: with self.assertRaises(SynapseError) as cm:
self.get_success_or_raise( self.get_success_or_raise(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_or_update_pusher(
user_id=self.user_id, user_id=self.user_id,
access_token=self.token_id, access_token=self.token_id,
kind="email", kind="email",

View File

@ -19,8 +19,8 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.push import PusherConfigException from synapse.push import PusherConfig, PusherConfigException
from synapse.rest.client import login, push_rule, receipts, room from synapse.rest.client import login, push_rule, pusher, receipts, room
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
@ -35,6 +35,7 @@ class HTTPPusherTests(HomeserverTestCase):
login.register_servlets, login.register_servlets,
receipts.register_servlets, receipts.register_servlets,
push_rule.register_servlets, push_rule.register_servlets,
pusher.register_servlets,
] ]
user_id = True user_id = True
hijack_auth = False hijack_auth = False
@ -74,7 +75,7 @@ class HTTPPusherTests(HomeserverTestCase):
def test_data(data: Optional[JsonDict]) -> None: def test_data(data: Optional[JsonDict]) -> None:
self.get_failure( self.get_failure(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id, user_id=user_id,
access_token=token_id, access_token=token_id,
kind="http", kind="http",
@ -119,7 +120,7 @@ class HTTPPusherTests(HomeserverTestCase):
token_id = user_tuple.token_id token_id = user_tuple.token_id
self.get_success( self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id, user_id=user_id,
access_token=token_id, access_token=token_id,
kind="http", kind="http",
@ -235,7 +236,7 @@ class HTTPPusherTests(HomeserverTestCase):
token_id = user_tuple.token_id token_id = user_tuple.token_id
self.get_success( self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id, user_id=user_id,
access_token=token_id, access_token=token_id,
kind="http", kind="http",
@ -355,7 +356,7 @@ class HTTPPusherTests(HomeserverTestCase):
token_id = user_tuple.token_id token_id = user_tuple.token_id
self.get_success( self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id, user_id=user_id,
access_token=token_id, access_token=token_id,
kind="http", kind="http",
@ -441,7 +442,7 @@ class HTTPPusherTests(HomeserverTestCase):
token_id = user_tuple.token_id token_id = user_tuple.token_id
self.get_success( self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id, user_id=user_id,
access_token=token_id, access_token=token_id,
kind="http", kind="http",
@ -518,7 +519,7 @@ class HTTPPusherTests(HomeserverTestCase):
token_id = user_tuple.token_id token_id = user_tuple.token_id
self.get_success( self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id, user_id=user_id,
access_token=token_id, access_token=token_id,
kind="http", kind="http",
@ -624,7 +625,7 @@ class HTTPPusherTests(HomeserverTestCase):
token_id = user_tuple.token_id token_id = user_tuple.token_id
self.get_success( self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id, user_id=user_id,
access_token=token_id, access_token=token_id,
kind="http", kind="http",
@ -728,18 +729,38 @@ class HTTPPusherTests(HomeserverTestCase):
) )
self.assertEqual(channel.code, 200, channel.json_body) self.assertEqual(channel.code, 200, channel.json_body)
def _make_user_with_pusher(self, username: str) -> Tuple[str, str]: def _make_user_with_pusher(
self, username: str, enabled: bool = True
) -> Tuple[str, str]:
"""Registers a user and creates a pusher for them.
Args:
username: the localpart of the new user's Matrix ID.
enabled: whether to create the pusher in an enabled or disabled state.
"""
user_id = self.register_user(username, "pass") user_id = self.register_user(username, "pass")
access_token = self.login(username, "pass") access_token = self.login(username, "pass")
# Register the pusher # Register the pusher
self._set_pusher(user_id, access_token, enabled)
return user_id, access_token
def _set_pusher(self, user_id: str, access_token: str, enabled: bool) -> None:
"""Creates or updates the pusher for the given user.
Args:
user_id: the user's Matrix ID.
access_token: the access token associated with the pusher.
enabled: whether to enable or disable the pusher.
"""
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token) self.hs.get_datastores().main.get_user_by_access_token(access_token)
) )
token_id = user_tuple.token_id token_id = user_tuple.token_id
self.get_success( self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id, user_id=user_id,
access_token=token_id, access_token=token_id,
kind="http", kind="http",
@ -749,11 +770,10 @@ class HTTPPusherTests(HomeserverTestCase):
pushkey="a@example.com", pushkey="a@example.com",
lang=None, lang=None,
data={"url": "http://example.com/_matrix/push/v1/notify"}, data={"url": "http://example.com/_matrix/push/v1/notify"},
enabled=enabled,
) )
) )
return user_id, access_token
def test_dont_notify_rule_overrides_message(self) -> None: def test_dont_notify_rule_overrides_message(self) -> None:
""" """
The override push rule will suppress notification The override push rule will suppress notification
@ -791,3 +811,105 @@ class HTTPPusherTests(HomeserverTestCase):
# The user sends a message back (sends a notification) # The user sends a message back (sends a notification)
self.helper.send(room, body="Hello", tok=access_token) self.helper.send(room, body="Hello", tok=access_token)
self.assertEqual(len(self.push_attempts), 1) self.assertEqual(len(self.push_attempts), 1)
@override_config({"experimental_features": {"msc3881_enabled": True}})
def test_disable(self) -> None:
"""Tests that disabling a pusher means it's not pushed to anymore."""
user_id, access_token = self._make_user_with_pusher("user")
other_user_id, other_access_token = self._make_user_with_pusher("otheruser")
room = self.helper.create_room_as(user_id, tok=access_token)
self.helper.join(room=room, user=other_user_id, tok=other_access_token)
# Send a message and check that it generated a push.
self.helper.send(room, body="Hi!", tok=other_access_token)
self.assertEqual(len(self.push_attempts), 1)
# Disable the pusher.
self._set_pusher(user_id, access_token, enabled=False)
# Send another message and check that it did not generate a push.
self.helper.send(room, body="Hi!", tok=other_access_token)
self.assertEqual(len(self.push_attempts), 1)
# Get the pushers for the user and check that it is marked as disabled.
channel = self.make_request("GET", "/pushers", access_token=access_token)
self.assertEqual(channel.code, 200)
self.assertEqual(len(channel.json_body["pushers"]), 1)
enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]
self.assertFalse(enabled)
self.assertTrue(isinstance(enabled, bool))
@override_config({"experimental_features": {"msc3881_enabled": True}})
def test_enable(self) -> None:
"""Tests that enabling a disabled pusher means it gets pushed to."""
# Create the user with the pusher already disabled.
user_id, access_token = self._make_user_with_pusher("user", enabled=False)
other_user_id, other_access_token = self._make_user_with_pusher("otheruser")
room = self.helper.create_room_as(user_id, tok=access_token)
self.helper.join(room=room, user=other_user_id, tok=other_access_token)
# Send a message and check that it did not generate a push.
self.helper.send(room, body="Hi!", tok=other_access_token)
self.assertEqual(len(self.push_attempts), 0)
# Enable the pusher.
self._set_pusher(user_id, access_token, enabled=True)
# Send another message and check that it did generate a push.
self.helper.send(room, body="Hi!", tok=other_access_token)
self.assertEqual(len(self.push_attempts), 1)
# Get the pushers for the user and check that it is marked as enabled.
channel = self.make_request("GET", "/pushers", access_token=access_token)
self.assertEqual(channel.code, 200)
self.assertEqual(len(channel.json_body["pushers"]), 1)
enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]
self.assertTrue(enabled)
self.assertTrue(isinstance(enabled, bool))
@override_config({"experimental_features": {"msc3881_enabled": True}})
def test_null_enabled(self) -> None:
"""Tests that a pusher that has an 'enabled' column set to NULL (eg pushers
created before the column was introduced) is considered enabled.
"""
# We intentionally set 'enabled' to None so that it's stored as NULL in the
# database.
user_id, access_token = self._make_user_with_pusher("user", enabled=None) # type: ignore[arg-type]
channel = self.make_request("GET", "/pushers", access_token=access_token)
self.assertEqual(channel.code, 200)
self.assertEqual(len(channel.json_body["pushers"]), 1)
self.assertTrue(channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"])
def test_update_different_device_access_token(self) -> None:
"""Tests that if we create a pusher from one device, the update it from another
device, the access token associated with the pusher stays the same.
"""
# Create a user with a pusher.
user_id, access_token = self._make_user_with_pusher("user")
# Get the token ID for the current access token, since that's what we store in
# the pushers table.
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
token_id = user_tuple.token_id
# Generate a new access token, and update the pusher with it.
new_token = self.login("user", "pass")
self._set_pusher(user_id, new_token, enabled=False)
# Get the current list of pushers for the user.
ret = self.get_success(
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
)
pushers: List[PusherConfig] = list(ret)
# Check that we still have one pusher, and that the access token associated with
# it didn't change.
self.assertEqual(len(pushers), 1)
self.assertEqual(pushers[0].access_token, token_id)

View File

@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
token_id = user_dict.token_id token_id = user_dict.token_id
self.get_success( self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id, user_id=user_id,
access_token=token_id, access_token=token_id,
kind="http", kind="http",

View File

@ -2839,7 +2839,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
token_id = user_tuple.token_id token_id = user_tuple.token_id
self.get_success( self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_or_update_pusher(
user_id=self.other_user, user_id=self.other_user,
access_token=token_id, access_token=token_id,
kind="http", kind="http",