diff --git a/changelog.d/11174.feature b/changelog.d/11174.feature new file mode 100644 index 000000000..8eecd9268 --- /dev/null +++ b/changelog.d/11174.feature @@ -0,0 +1 @@ +Users admin API can now also modify user type in addition to allowing it to be set on user creation. diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index 534f8400b..f03539c9f 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -50,7 +50,8 @@ It returns a JSON body like the following: "auth_provider": "", "external_id": "" } - ] + ], + "user_type": null } ``` @@ -97,7 +98,8 @@ with a body of: ], "avatar_url": "", "admin": false, - "deactivated": false + "deactivated": false, + "user_type": null } ``` @@ -135,6 +137,9 @@ Body parameters: unchanged on existing accounts and set to `false` for new accounts. A user cannot be erased by deactivating with this API. For details on deactivating users see [Deactivate Account](#deactivate-account). +- `user_type` - string or null, optional. If provided, the user type will be + adjusted. If `null` given, the user type will be cleared. Other + allowed options are: `bot` and `support`. If the user already exists then optional parameters default to the current value. diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index c0bebc3cf..d14fafbbc 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -326,6 +326,9 @@ class UserRestServletV2(RestServlet): target_user.to_string() ) + if "user_type" in body: + await self.store.set_user_type(target_user, user_type) + user = await self.admin_handler.get_user(target_user) assert user is not None diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 37d47aa82..6c7d6ba50 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -499,6 +499,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn) + async def set_user_type(self, user: UserID, user_type: Optional[UserTypes]) -> None: + """Sets the user type. + + Args: + user: user ID of the user. + user_type: type of the user or None for a user without a type. + """ + + def set_user_type_txn(txn): + self.db_pool.simple_update_one_txn( + txn, "users", {"name": user.to_string()}, {"user_type": user_type} + ) + self._invalidate_cache_and_stream( + txn, self.get_user_by_id, (user.to_string(),) + ) + + await self.db_pool.runInteraction("set_user_type", set_user_type_txn) + def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]: sql = """ SELECT users.name as user_id, diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 839442ddb..25e8d6cf2 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2270,6 +2270,57 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("@user:test", channel.json_body["name"]) self.assertTrue(channel.json_body["admin"]) + def test_set_user_type(self): + """ + Test changing user type. + """ + + # Set to support type + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"user_type": UserTypes.SUPPORT}, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"]) + + # Get user + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"]) + + # Change back to a regular user + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"user_type": None}, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertIsNone(channel.json_body["user_type"]) + + # Get user + channel = self.make_request( + "GET", + self.url_other_user, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertIsNone(channel.json_body["user_type"]) + def test_accidental_deactivation_prevention(self): """ Ensure an account can't accidentally be deactivated by using a str value