diff --git a/changelog.d/12391.feature b/changelog.d/12391.feature new file mode 100644 index 000000000..9a064ec8b --- /dev/null +++ b/changelog.d/12391.feature @@ -0,0 +1 @@ +Add a module API for reading and writing global account data. diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 9a61593ff..8f9e62927 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -119,6 +119,7 @@ from synapse.types import ( from synapse.util import Clock from synapse.util.async_helpers import maybe_awaitable from synapse.util.caches.descriptors import cached +from synapse.util.frozenutils import freeze if TYPE_CHECKING: from synapse.app.generic_worker import GenericWorkerSlavedStore @@ -211,6 +212,7 @@ class ModuleApi: # We expose these as properties below in order to attach a helpful docstring. self._http_client: SimpleHttpClient = hs.get_simple_http_client() self._public_room_list_manager = PublicRoomListManager(hs) + self._account_data_manager = AccountDataManager(hs) self._spam_checker = hs.get_spam_checker() self._account_validity_handler = hs.get_account_validity_handler() @@ -431,6 +433,14 @@ class ModuleApi: """ return self._public_room_list_manager + @property + def account_data_manager(self) -> "AccountDataManager": + """Allows reading and modifying users' account data. + + Added in Synapse v1.57.0. + """ + return self._account_data_manager + @property def public_baseurl(self) -> str: """The configured public base URL for this homeserver. @@ -1386,3 +1396,69 @@ class PublicRoomListManager: room_id: The ID of the room. """ await self._store.set_room_is_public(room_id, False) + + +class AccountDataManager: + """ + Allows modules to manage account data. + """ + + def __init__(self, hs: "HomeServer") -> None: + self._hs = hs + self._store = hs.get_datastores().main + self._handler = hs.get_account_data_handler() + + def _validate_user_id(self, user_id: str) -> None: + """ + Validates a user ID is valid and local. + Private method to be used in other account data methods. + """ + user = UserID.from_string(user_id) + if not self._hs.is_mine(user): + raise ValueError( + f"{user_id} is not local to this homeserver; can't access account data for remote users." + ) + + async def get_global(self, user_id: str, data_type: str) -> Optional[JsonDict]: + """ + Gets some global account data, of a specified type, for the specified user. + + The provided user ID must be a valid user ID of a local user. + + Added in Synapse v1.57.0. + """ + self._validate_user_id(user_id) + + data = await self._store.get_global_account_data_by_type_for_user( + user_id, data_type + ) + # We clone and freeze to prevent the module accidentally mutating the + # dict that lives in the cache, as that could introduce nasty bugs. + return freeze(data) + + async def put_global( + self, user_id: str, data_type: str, new_data: JsonDict + ) -> None: + """ + Puts some global account data, of a specified type, for the specified user. + + The provided user ID must be a valid user ID of a local user. + + Please note that this will overwrite existing the account data of that type + for that user! + + Added in Synapse v1.57.0. + """ + self._validate_user_id(user_id) + + if not isinstance(data_type, str): + raise TypeError(f"data_type must be a str; got {type(data_type).__name__}") + + if not isinstance(new_data, dict): + raise TypeError(f"new_data must be a dict; got {type(new_data).__name__}") + + # Ensure the user exists, so we don't just write to users that aren't there. + if await self._store.get_userinfo_by_id(user_id) is None: + raise ValueError(f"User {user_id} does not exist on this server.") + + await self._handler.add_account_data_for_user(user_id, data_type, new_data) diff --git a/tests/module_api/test_account_data_manager.py b/tests/module_api/test_account_data_manager.py new file mode 100644 index 000000000..bec018d9e --- /dev/null +++ b/tests/module_api/test_account_data_manager.py @@ -0,0 +1,157 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.api.errors import SynapseError +from synapse.rest import admin + +from tests.unittest import HomeserverTestCase + + +class ModuleApiTestCase(HomeserverTestCase): + servlets = [ + admin.register_servlets, + ] + + def prepare(self, reactor, clock, homeserver) -> None: + self._store = homeserver.get_datastores().main + self._module_api = homeserver.get_module_api() + self._account_data_mgr = self._module_api.account_data_manager + + self.user_id = self.register_user("kristina", "secret") + + def test_get_global(self) -> None: + """ + Tests that getting global account data through the module API works as + expected, including getting `None` for unset account data. + """ + self.get_success( + self._store.add_account_data_for_user( + self.user_id, "test.data", {"wombat": True} + ) + ) + + # Getting existent account data works as expected. + self.assertEqual( + self.get_success( + self._account_data_mgr.get_global(self.user_id, "test.data") + ), + {"wombat": True}, + ) + + # Getting non-existent account data returns None. + self.assertIsNone( + self.get_success( + self._account_data_mgr.get_global(self.user_id, "no.data.at.all") + ) + ) + + def test_get_global_validation(self) -> None: + """ + Tests that invalid or remote user IDs are treated as errors and raised as exceptions, + whilst getting global account data for a user. + + This is a design choice to try and communicate potential bugs to modules + earlier on. + """ + with self.assertRaises(SynapseError): + self.get_success_or_raise( + self._account_data_mgr.get_global("this isn't a user id", "test.data") + ) + + with self.assertRaises(ValueError): + self.get_success_or_raise( + self._account_data_mgr.get_global("@valid.but:remote", "test.data") + ) + + def test_get_global_no_mutability(self) -> None: + """ + Tests that modules can't introduce bugs into Synapse by mutating the result + of `get_global`. + """ + # First add some account data to set up the test. + self.get_success( + self._store.add_account_data_for_user( + self.user_id, "test.data", {"wombat": True} + ) + ) + + # Now request that data and then mutate it (out of negligence or otherwise). + the_data = self.get_success( + self._account_data_mgr.get_global(self.user_id, "test.data") + ) + with self.assertRaises(TypeError): + # This throws an exception because it's a frozen dict. + the_data["wombat"] = False + + def test_put_global(self) -> None: + """ + Tests that written account data using `put_global` can be read out again later. + """ + + self.get_success( + self._module_api.account_data_manager.put_global( + self.user_id, "test.data", {"wombat": True} + ) + ) + + # Request that account data from the normal store; check it's as we expect. + self.assertEqual( + self.get_success( + self._store.get_global_account_data_by_type_for_user( + self.user_id, "test.data" + ) + ), + {"wombat": True}, + ) + + def test_put_global_validation(self) -> None: + """ + Tests that a module can't write account data to user IDs that don't have + actual users registered to them. + Modules also must supply the correct types. + """ + + with self.assertRaises(SynapseError): + self.get_success_or_raise( + self._account_data_mgr.put_global( + "this isn't a user id", "test.data", {} + ) + ) + + with self.assertRaises(ValueError): + self.get_success_or_raise( + self._account_data_mgr.put_global("@valid.but:remote", "test.data", {}) + ) + + with self.assertRaises(ValueError): + self.get_success_or_raise( + self._module_api.account_data_manager.put_global( + "@notregistered:test", "test.data", {} + ) + ) + + with self.assertRaises(TypeError): + # The account data type must be a string. + self.get_success_or_raise( + self._module_api.account_data_manager.put_global( + self.user_id, 42, {} # type: ignore + ) + ) + + with self.assertRaises(TypeError): + # The account data dict must be a dict. + self.get_success_or_raise( + self._module_api.account_data_manager.put_global( + self.user_id, "test.data", 42 # type: ignore + ) + )