Finish up work to allow per-user feature flags (#17392)

Follows on from @H-Shay's great work at
https://github.com/matrix-org/synapse/pull/15344 and MSC4026.

Also enables its use for MSC3881, mainly as an easy but concrete example
of how to use it.
This commit is contained in:
Erik Johnston 2024-07-05 13:02:35 +01:00 committed by GitHub
parent 45b35f8eae
commit 57538eb4d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 189 additions and 49 deletions

1
changelog.d/17392.misc Normal file
View File

@ -0,0 +1 @@
Finish up work to allow per-user feature flags.

View File

@ -2,13 +2,8 @@
This API allows a server administrator to enable or disable some experimental features on a per-user This API allows a server administrator to enable or disable some experimental features on a per-user
basis. The currently supported features are: basis. The currently supported features are:
- [MSC3026](https://github.com/matrix-org/matrix-spec-proposals/pull/3026): busy
presence state enabled
- [MSC3881](https://github.com/matrix-org/matrix-spec-proposals/pull/3881): enable remotely toggling push notifications - [MSC3881](https://github.com/matrix-org/matrix-spec-proposals/pull/3881): enable remotely toggling push notifications
for another client for another client
- [MSC3967](https://github.com/matrix-org/matrix-spec-proposals/pull/3967): do not require
UIA when first uploading cross-signing keys.
To use it, you will need to authenticate by providing an `access_token` To use it, you will need to authenticate by providing an `access_token`
for a server admin: see [Admin API](../usage/administration/admin_api/). for a server admin: see [Admin API](../usage/administration/admin_api/).

View File

@ -31,7 +31,9 @@ from synapse.rest.admin import admin_patterns, assert_requester_is_admin
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, UserID
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from typing_extensions import assert_never
from synapse.server import HomeServer, HomeServerConfig
class ExperimentalFeature(str, Enum): class ExperimentalFeature(str, Enum):
@ -39,9 +41,14 @@ class ExperimentalFeature(str, Enum):
Currently supported per-user features Currently supported per-user features
""" """
MSC3026 = "msc3026"
MSC3881 = "msc3881" MSC3881 = "msc3881"
def is_globally_enabled(self, config: "HomeServerConfig") -> bool:
if self is ExperimentalFeature.MSC3881:
return config.experimental.msc3881_enabled
assert_never(self)
class ExperimentalFeaturesRestServlet(RestServlet): class ExperimentalFeaturesRestServlet(RestServlet):
""" """

View File

@ -32,6 +32,7 @@ from synapse.http.servlet import (
) )
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.push import PusherConfigException from synapse.push import PusherConfigException
from synapse.rest.admin.experimental_features import ExperimentalFeature
from synapse.rest.client._base import client_patterns from synapse.rest.client._base import client_patterns
from synapse.rest.synapse.client.unsubscribe import UnsubscribeResource from synapse.rest.synapse.client.unsubscribe import UnsubscribeResource
from synapse.types import JsonDict from synapse.types import JsonDict
@ -49,20 +50,22 @@ class PushersRestServlet(RestServlet):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled self._store = hs.get_datastores().main
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
user = requester.user user_id = requester.user.to_string()
pushers = await self.hs.get_datastores().main.get_pushers_by_user_id( msc3881_enabled = await self._store.is_feature_enabled(
user.to_string() user_id, ExperimentalFeature.MSC3881
) )
pushers = await self.hs.get_datastores().main.get_pushers_by_user_id(user_id)
pusher_dicts = [p.as_dict() for p in pushers] pusher_dicts = [p.as_dict() for p in pushers]
for pusher in pusher_dicts: for pusher in pusher_dicts:
if self._msc3881_enabled: if msc3881_enabled:
pusher["org.matrix.msc3881.enabled"] = pusher["enabled"] pusher["org.matrix.msc3881.enabled"] = pusher["enabled"]
pusher["org.matrix.msc3881.device_id"] = pusher["device_id"] pusher["org.matrix.msc3881.device_id"] = pusher["device_id"]
del pusher["enabled"] del pusher["enabled"]
@ -80,11 +83,15 @@ class PushersSetRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.pusher_pool = self.hs.get_pusherpool() self.pusher_pool = self.hs.get_pusherpool()
self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled self._store = hs.get_datastores().main
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
user = requester.user user_id = requester.user.to_string()
msc3881_enabled = await self._store.is_feature_enabled(
user_id, ExperimentalFeature.MSC3881
)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -95,7 +102,7 @@ class PushersSetRestServlet(RestServlet):
and content["kind"] is None and content["kind"] is None
): ):
await self.pusher_pool.remove_pusher( await self.pusher_pool.remove_pusher(
content["app_id"], content["pushkey"], user_id=user.to_string() content["app_id"], content["pushkey"], user_id=user_id
) )
return 200, {} return 200, {}
@ -120,19 +127,19 @@ class PushersSetRestServlet(RestServlet):
append = content["append"] append = content["append"]
enabled = True enabled = True
if self._msc3881_enabled and "org.matrix.msc3881.enabled" in content: if msc3881_enabled and "org.matrix.msc3881.enabled" in content:
enabled = content["org.matrix.msc3881.enabled"] enabled = content["org.matrix.msc3881.enabled"]
if not append: if not append:
await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
app_id=content["app_id"], app_id=content["app_id"],
pushkey=content["pushkey"], pushkey=content["pushkey"],
not_user_id=user.to_string(), not_user_id=user_id,
) )
try: try:
await self.pusher_pool.add_or_update_pusher( await self.pusher_pool.add_or_update_pusher(
user_id=user.to_string(), user_id=user_id,
kind=content["kind"], kind=content["kind"],
app_id=content["app_id"], app_id=content["app_id"],
app_display_name=content["app_display_name"], app_display_name=content["app_display_name"],

View File

@ -25,11 +25,11 @@ import logging
import re import re
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.api.constants import RoomCreationPreset from synapse.api.constants import RoomCreationPreset
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.rest.admin.experimental_features import ExperimentalFeature
from synapse.types import JsonDict from synapse.types import JsonDict
if TYPE_CHECKING: if TYPE_CHECKING:
@ -45,6 +45,8 @@ class VersionsRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.config = hs.config self.config = hs.config
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
# Calculate these once since they shouldn't change after start-up. # Calculate these once since they shouldn't change after start-up.
self.e2ee_forced_public = ( self.e2ee_forced_public = (
@ -60,7 +62,17 @@ class VersionsRestServlet(RestServlet):
in self.config.room.encryption_enabled_by_default_for_room_presets in self.config.room.encryption_enabled_by_default_for_room_presets
) )
def on_GET(self, request: Request) -> Tuple[int, JsonDict]: async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
msc3881_enabled = self.config.experimental.msc3881_enabled
if self.auth.has_access_token(request):
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
msc3881_enabled = await self.store.is_feature_enabled(
user_id, ExperimentalFeature.MSC3881
)
return ( return (
200, 200,
{ {
@ -124,7 +136,7 @@ class VersionsRestServlet(RestServlet):
# TODO: this is no longer needed once unstable MSC3882 does not need to be supported: # TODO: this is no longer needed once unstable MSC3882 does not need to be supported:
"org.matrix.msc3882": self.config.auth.login_via_existing_enabled, "org.matrix.msc3882": self.config.auth.login_via_existing_enabled,
# Adds support for remotely enabling/disabling pushers, as per MSC3881 # Adds support for remotely enabling/disabling pushers, as per MSC3881
"org.matrix.msc3881": self.config.experimental.msc3881_enabled, "org.matrix.msc3881": msc3881_enabled,
# Adds support for filtering /messages by event relation. # Adds support for filtering /messages by event relation.
"org.matrix.msc3874": self.config.experimental.msc3874_enabled, "org.matrix.msc3874": self.config.experimental.msc3874_enabled,
# Adds support for simple HTTP rendezvous as per MSC3886 # Adds support for simple HTTP rendezvous as per MSC3886

View File

@ -21,7 +21,11 @@
from typing import TYPE_CHECKING, Dict, FrozenSet, List, Tuple, cast from typing import TYPE_CHECKING, Dict, FrozenSet, List, Tuple, cast
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main import CacheInvalidationWorkerStore from synapse.storage.databases.main import CacheInvalidationWorkerStore
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
@ -73,12 +77,54 @@ class ExperimentalFeaturesStore(CacheInvalidationWorkerStore):
features: features:
pairs of features and True/False for whether the feature should be enabled pairs of features and True/False for whether the feature should be enabled
""" """
def set_features_for_user_txn(txn: LoggingTransaction) -> None:
for feature, enabled in features.items(): for feature, enabled in features.items():
await self.db_pool.simple_upsert( self.db_pool.simple_upsert_txn(
txn,
table="per_user_experimental_features", table="per_user_experimental_features",
keyvalues={"feature": feature, "user_id": user}, keyvalues={"feature": feature, "user_id": user},
values={"enabled": enabled}, values={"enabled": enabled},
insertion_values={"user_id": user, "feature": feature}, insertion_values={"user_id": user, "feature": feature},
) )
await self.invalidate_cache_and_stream("list_enabled_features", (user,)) self._invalidate_cache_and_stream(
txn, self.is_feature_enabled, (user, feature)
)
self._invalidate_cache_and_stream(txn, self.list_enabled_features, (user,))
return await self.db_pool.runInteraction(
"set_features_for_user", set_features_for_user_txn
)
@cached()
async def is_feature_enabled(
self, user_id: str, feature: "ExperimentalFeature"
) -> bool:
"""
Checks to see if a given feature is enabled for the user
Args:
user_id: the user to be queried on
feature: the feature in question
Returns:
True if the feature is enabled, False if it is not or if the feature was
not found.
"""
if feature.is_globally_enabled(self.hs.config):
return True
# if it's not enabled globally, check if it is enabled per-user
res = await self.db_pool.simple_select_one_onecol(
table="per_user_experimental_features",
keyvalues={"user_id": user_id, "feature": feature},
retcol="enabled",
allow_none=True,
desc="get_feature_enabled",
)
# None and false are treated the same
db_enabled = bool(res)
return db_enabled

View File

@ -26,7 +26,8 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin import synapse.rest.admin
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.push import PusherConfig, PusherConfigException from synapse.push import PusherConfig, PusherConfigException
from synapse.rest.client import login, push_rule, pusher, receipts, room from synapse.rest.admin.experimental_features import ExperimentalFeature
from synapse.rest.client import login, push_rule, pusher, receipts, room, versions
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
@ -42,6 +43,7 @@ class HTTPPusherTests(HomeserverTestCase):
receipts.register_servlets, receipts.register_servlets,
push_rule.register_servlets, push_rule.register_servlets,
pusher.register_servlets, pusher.register_servlets,
versions.register_servlets,
] ]
user_id = True user_id = True
hijack_auth = False hijack_auth = False
@ -969,6 +971,84 @@ class HTTPPusherTests(HomeserverTestCase):
lookup_result.device_id, lookup_result.device_id,
) )
def test_device_id_feature_flag(self) -> None:
"""Tests that a pusher created with a given device ID shows that device ID in
GET /pushers requests when feature is enabled for the user
"""
user_id = self.register_user("user", "pass")
access_token = self.login("user", "pass")
# We create the pusher with an HTTP request rather than with
# _make_user_with_pusher so that we can test the device ID is correctly set when
# creating a pusher via an API call.
self.make_request(
method="POST",
path="/pushers/set",
content={
"kind": "http",
"app_id": "m.http",
"app_display_name": "HTTP Push Notifications",
"device_display_name": "pushy push",
"pushkey": "a@example.com",
"lang": "en",
"data": {"url": "http://example.com/_matrix/push/v1/notify"},
},
access_token=access_token,
)
# Look up the user info for the access token so we can compare the device ID.
store = self.hs.get_datastores().main
lookup_result = self.get_success(store.get_user_by_access_token(access_token))
assert lookup_result is not None
# Check field is not there before we enable the feature flag
channel = self.make_request("GET", "/pushers", access_token=access_token)
self.assertEqual(channel.code, 200)
self.assertEqual(len(channel.json_body["pushers"]), 1)
self.assertNotIn(
"org.matrix.msc3881.device_id", channel.json_body["pushers"][0]
)
self.get_success(
store.set_features_for_user(user_id, {ExperimentalFeature.MSC3881: True})
)
# Get the user's devices and check it has the correct device ID.
channel = self.make_request("GET", "/pushers", access_token=access_token)
self.assertEqual(channel.code, 200)
self.assertEqual(len(channel.json_body["pushers"]), 1)
self.assertEqual(
channel.json_body["pushers"][0]["org.matrix.msc3881.device_id"],
lookup_result.device_id,
)
def test_msc3881_client_versions_flag(self) -> None:
"""Tests that MSC3881 only appears in /versions if user has it enabled."""
user_id = self.register_user("user", "pass")
access_token = self.login("user", "pass")
# Check feature is disabled in /versions
channel = self.make_request(
"GET", "/_matrix/client/versions", access_token=access_token
)
self.assertEqual(channel.code, 200)
self.assertFalse(channel.json_body["unstable_features"]["org.matrix.msc3881"])
# Enable feature for user
self.get_success(
self.hs.get_datastores().main.set_features_for_user(
user_id, {ExperimentalFeature.MSC3881: True}
)
)
# Check feature is now enabled in /versions for user
channel = self.make_request(
"GET", "/_matrix/client/versions", access_token=access_token
)
self.assertEqual(channel.code, 200)
self.assertTrue(channel.json_body["unstable_features"]["org.matrix.msc3881"])
@override_config({"push": {"jitter_delay": "10s"}}) @override_config({"push": {"jitter_delay": "10s"}})
def test_jitter(self) -> None: def test_jitter(self) -> None:
"""Tests that enabling jitter actually delays sending push.""" """Tests that enabling jitter actually delays sending push."""

View File

@ -384,7 +384,7 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase):
"PUT", "PUT",
url, url,
content={ content={
"features": {"msc3026": True, "msc3881": True}, "features": {"msc3881": True},
}, },
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
@ -399,10 +399,6 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase):
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
self.assertEqual(
True,
channel.json_body["features"]["msc3026"],
)
self.assertEqual( self.assertEqual(
True, True,
channel.json_body["features"]["msc3881"], channel.json_body["features"]["msc3881"],
@ -413,7 +409,7 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
url, url,
content={"features": {"msc3026": False}}, content={"features": {"msc3881": False}},
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
@ -429,10 +425,6 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
self.assertEqual( self.assertEqual(
False, False,
channel.json_body["features"]["msc3026"],
)
self.assertEqual(
True,
channel.json_body["features"]["msc3881"], channel.json_body["features"]["msc3881"],
) )
@ -441,7 +433,7 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",
url, url,
content={"features": {"msc3026": False}}, content={"features": {"msc3881": False}},
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)