Track device IDs for pushers (#13831)

Second half of the MSC3881 implementation
This commit is contained in:
Brendan Abolivier 2022-09-21 16:31:53 +01:00 committed by GitHub
parent 0fd2f2d460
commit ccca14140a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 154 additions and 10 deletions

View File

@ -0,0 +1 @@
Add experimental support for [MSC3881: Remotely toggle push notifications for another client](https://github.com/matrix-org/matrix-spec-proposals/pull/3881).

View File

@ -117,6 +117,7 @@ class PusherConfig:
last_success: Optional[int] last_success: Optional[int]
failing_since: Optional[int] failing_since: Optional[int]
enabled: bool enabled: bool
device_id: Optional[str]
def as_dict(self) -> Dict[str, Any]: def as_dict(self) -> Dict[str, Any]:
"""Information that can be retrieved about a pusher after creation.""" """Information that can be retrieved about a pusher after creation."""
@ -130,6 +131,7 @@ class PusherConfig:
"profile_tag": self.profile_tag, "profile_tag": self.profile_tag,
"pushkey": self.pushkey, "pushkey": self.pushkey,
"enabled": self.enabled, "enabled": self.enabled,
"device_id": self.device_id,
} }

View File

@ -107,6 +107,7 @@ class PusherPool:
data: JsonDict, data: JsonDict,
profile_tag: str = "", profile_tag: str = "",
enabled: bool = True, enabled: bool = True,
device_id: Optional[str] = None,
) -> Optional[Pusher]: ) -> Optional[Pusher]:
"""Creates a new pusher and adds it to the pool """Creates a new pusher and adds it to the pool
@ -149,18 +150,20 @@ class PusherPool:
last_success=None, last_success=None,
failing_since=None, failing_since=None,
enabled=enabled, enabled=enabled,
device_id=device_id,
) )
) )
# Before we actually persist the pusher, we check if the user already has one # Before we actually persist the pusher, we check if the user already has one
# for this app ID and pushkey. If so, we want to keep the access token in place, # this app ID and pushkey. If so, we want to keep the access token and device ID
# since this could be one device modifying (e.g. enabling/disabling) another # in place, since this could be one device modifying (e.g. enabling/disabling)
# device's pusher. # another device's pusher.
existing_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey( existing_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey(
user_id, app_id, pushkey user_id, app_id, pushkey
) )
if existing_config: if existing_config:
access_token = existing_config.access_token access_token = existing_config.access_token
device_id = existing_config.device_id
await self.store.add_pusher( await self.store.add_pusher(
user_id=user_id, user_id=user_id,
@ -176,6 +179,7 @@ class PusherPool:
last_stream_ordering=last_stream_ordering, last_stream_ordering=last_stream_ordering,
profile_tag=profile_tag, profile_tag=profile_tag,
enabled=enabled, enabled=enabled,
device_id=device_id,
) )
pusher = await self.process_pusher_change_by_id(app_id, pushkey, user_id) pusher = await self.process_pusher_change_by_id(app_id, pushkey, user_id)

View File

@ -57,7 +57,9 @@ class PushersRestServlet(RestServlet):
for pusher in pusher_dicts: for pusher in pusher_dicts:
if self._msc3881_enabled: if self._msc3881_enabled:
pusher["org.matrix.msc3881.enabled"] = pusher["enabled"] pusher["org.matrix.msc3881.enabled"] = pusher["enabled"]
pusher["org.matrix.msc3881.device_id"] = pusher["device_id"]
del pusher["enabled"] del pusher["enabled"]
del pusher["device_id"]
return 200, {"pushers": pusher_dicts} return 200, {"pushers": pusher_dicts}
@ -134,6 +136,7 @@ class PushersSetRestServlet(RestServlet):
data=content["data"], data=content["data"],
profile_tag=content.get("profile_tag", ""), profile_tag=content.get("profile_tag", ""),
enabled=enabled, enabled=enabled,
device_id=requester.device_id,
) )
except PusherConfigException as pce: except PusherConfigException as pce:
raise SynapseError( raise SynapseError(

View File

@ -124,7 +124,7 @@ class PusherWorkerStore(SQLBaseStore):
id, user_name, access_token, profile_tag, kind, app_id, id, user_name, access_token, profile_tag, kind, app_id,
app_display_name, device_display_name, pushkey, ts, lang, data, app_display_name, device_display_name, pushkey, ts, lang, data,
last_stream_ordering, last_success, failing_since, last_stream_ordering, last_success, failing_since,
COALESCE(enabled, TRUE) AS enabled COALESCE(enabled, TRUE) AS enabled, device_id
FROM pushers FROM pushers
""" """
@ -477,7 +477,74 @@ class PusherWorkerStore(SQLBaseStore):
return number_deleted return number_deleted
class PusherStore(PusherWorkerStore): class PusherBackgroundUpdatesStore(SQLBaseStore):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
"set_device_id_for_pushers", self._set_device_id_for_pushers
)
async def _set_device_id_for_pushers(
self, progress: JsonDict, batch_size: int
) -> int:
"""Background update to populate the device_id column of the pushers table."""
last_pusher_id = progress.get("pusher_id", 0)
def set_device_id_for_pushers_txn(txn: LoggingTransaction) -> int:
txn.execute(
"""
SELECT p.id, at.device_id
FROM pushers AS p
INNER JOIN access_tokens AS at
ON p.access_token = at.id
WHERE
p.access_token IS NOT NULL
AND at.device_id IS NOT NULL
AND p.id > ?
ORDER BY p.id
LIMIT ?
""",
(last_pusher_id, batch_size),
)
rows = self.db_pool.cursor_to_dict(txn)
if len(rows) == 0:
return 0
self.db_pool.simple_update_many_txn(
txn=txn,
table="pushers",
key_names=("id",),
key_values=[(row["id"],) for row in rows],
value_names=("device_id",),
value_values=[(row["device_id"],) for row in rows],
)
self.db_pool.updates._background_update_progress_txn(
txn, "set_device_id_for_pushers", {"pusher_id": rows[-1]["id"]}
)
return len(rows)
nb_processed = await self.db_pool.runInteraction(
"set_device_id_for_pushers", set_device_id_for_pushers_txn
)
if nb_processed < batch_size:
await self.db_pool.updates._end_background_update(
"set_device_id_for_pushers"
)
return nb_processed
class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
def get_pushers_stream_token(self) -> int: def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token() return self._pushers_id_gen.get_current_token()
@ -496,6 +563,7 @@ class PusherStore(PusherWorkerStore):
last_stream_ordering: int, last_stream_ordering: int,
profile_tag: str = "", profile_tag: str = "",
enabled: bool = True, enabled: bool = True,
device_id: Optional[str] = None,
) -> None: ) -> None:
async with self._pushers_id_gen.get_next() as stream_id: async with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on # no need to lock because `pushers` has a unique key on
@ -515,6 +583,7 @@ class PusherStore(PusherWorkerStore):
"profile_tag": profile_tag, "profile_tag": profile_tag,
"id": stream_id, "id": stream_id,
"enabled": enabled, "enabled": enabled,
"device_id": device_id,
}, },
desc="add_pusher", desc="add_pusher",
lock=False, lock=False,

View File

@ -0,0 +1,20 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Add a device_id column to track the device ID that created the pusher. It's NULLable
-- on purpose, because a) it might not be possible to track down the device that created
-- old pushers (pushers.access_token and access_tokens.device_id are both NULLable), and
-- b) access tokens retrieved via the admin API don't have a device associated to them.
ALTER TABLE pushers ADD COLUMN device_id TEXT;

View File

@ -22,6 +22,7 @@ from synapse.logging.context import make_deferred_yieldable
from synapse.push import PusherConfig, PusherConfigException from synapse.push import PusherConfig, PusherConfigException
from synapse.rest.client import login, push_rule, pusher, receipts, room from synapse.rest.client import login, push_rule, pusher, receipts, room
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
@ -771,6 +772,7 @@ class HTTPPusherTests(HomeserverTestCase):
lang=None, lang=None,
data={"url": "http://example.com/_matrix/push/v1/notify"}, data={"url": "http://example.com/_matrix/push/v1/notify"},
enabled=enabled, enabled=enabled,
device_id=user_tuple.device_id,
) )
) )
@ -885,19 +887,21 @@ class HTTPPusherTests(HomeserverTestCase):
self.assertEqual(len(channel.json_body["pushers"]), 1) self.assertEqual(len(channel.json_body["pushers"]), 1)
self.assertTrue(channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]) self.assertTrue(channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"])
def test_update_different_device_access_token(self) -> None: def test_update_different_device_access_token_device_id(self) -> None:
"""Tests that if we create a pusher from one device, the update it from another """Tests that if we create a pusher from one device, the update it from another
device, the access token associated with the pusher stays the same. device, the access token and device ID associated with the pusher stays the
same.
""" """
# Create a user with a pusher. # Create a user with a pusher.
user_id, access_token = self._make_user_with_pusher("user") user_id, access_token = self._make_user_with_pusher("user")
# Get the token ID for the current access token, since that's what we store in # Get the token ID for the current access token, since that's what we store in
# the pushers table. # the pushers table. Also get the device ID from it.
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token) self.hs.get_datastores().main.get_user_by_access_token(access_token)
) )
token_id = user_tuple.token_id token_id = user_tuple.token_id
device_id = user_tuple.device_id
# Generate a new access token, and update the pusher with it. # Generate a new access token, and update the pusher with it.
new_token = self.login("user", "pass") new_token = self.login("user", "pass")
@ -909,7 +913,48 @@ class HTTPPusherTests(HomeserverTestCase):
) )
pushers: List[PusherConfig] = list(ret) pushers: List[PusherConfig] = list(ret)
# Check that we still have one pusher, and that the access token associated with # Check that we still have one pusher, and that the access token and device ID
# it didn't change. # associated with it didn't change.
self.assertEqual(len(pushers), 1) self.assertEqual(len(pushers), 1)
self.assertEqual(pushers[0].access_token, token_id) self.assertEqual(pushers[0].access_token, token_id)
self.assertEqual(pushers[0].device_id, device_id)
@override_config({"experimental_features": {"msc3881_enabled": True}})
def test_device_id(self) -> None:
"""Tests that a pusher created with a given device ID shows that device ID in
GET /pushers requests.
"""
self.register_user("user", "pass")
access_token = self.login("user", "pass")
# We create the pusher with an HTTP request rather than with
# _make_user_with_pusher so that we can test the device ID is correctly set when
# creating a pusher via an API call.
self.make_request(
method="POST",
path="/pushers/set",
content={
"kind": "http",
"app_id": "m.http",
"app_display_name": "HTTP Push Notifications",
"device_display_name": "pushy push",
"pushkey": "a@example.com",
"lang": "en",
"data": {"url": "http://example.com/_matrix/push/v1/notify"},
},
access_token=access_token,
)
# Look up the user info for the access token so we can compare the device ID.
lookup_result: TokenLookupResult = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
# Get the user's devices and check it has the correct device ID.
channel = self.make_request("GET", "/pushers", access_token=access_token)
self.assertEqual(channel.code, 200)
self.assertEqual(len(channel.json_body["pushers"]), 1)
self.assertEqual(
channel.json_body["pushers"][0]["org.matrix.msc3881.device_id"],
lookup_result.device_id,
)