mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Remove not needed database updates in modify user admin API (#10627)
This commit is contained in:
parent
0c3565da4c
commit
220f901229
1
changelog.d/10627.misc
Normal file
1
changelog.d/10627.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Remove not needed database updates in modify user admin API.
|
@ -21,11 +21,15 @@ It returns a JSON body like the following:
|
|||||||
"threepids": [
|
"threepids": [
|
||||||
{
|
{
|
||||||
"medium": "email",
|
"medium": "email",
|
||||||
"address": "<user_mail_1>"
|
"address": "<user_mail_1>",
|
||||||
|
"added_at": 1586458409743,
|
||||||
|
"validated_at": 1586458409743
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"medium": "email",
|
"medium": "email",
|
||||||
"address": "<user_mail_2>"
|
"address": "<user_mail_2>",
|
||||||
|
"added_at": 1586458409743,
|
||||||
|
"validated_at": 1586458409743
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"avatar_url": "<avatar_url>",
|
"avatar_url": "<avatar_url>",
|
||||||
|
@ -228,13 +228,18 @@ class UserRestServletV2(RestServlet):
|
|||||||
if not isinstance(deactivate, bool):
|
if not isinstance(deactivate, bool):
|
||||||
raise SynapseError(400, "'deactivated' parameter is not of type boolean")
|
raise SynapseError(400, "'deactivated' parameter is not of type boolean")
|
||||||
|
|
||||||
# convert into List[Tuple[str, str]]
|
# convert List[Dict[str, str]] into Set[Tuple[str, str]]
|
||||||
if external_ids is not None:
|
if external_ids is not None:
|
||||||
new_external_ids = []
|
new_external_ids = {
|
||||||
for external_id in external_ids:
|
(external_id["auth_provider"], external_id["external_id"])
|
||||||
new_external_ids.append(
|
for external_id in external_ids
|
||||||
(external_id["auth_provider"], external_id["external_id"])
|
}
|
||||||
)
|
|
||||||
|
# convert List[Dict[str, str]] into Set[Tuple[str, str]]
|
||||||
|
if threepids is not None:
|
||||||
|
new_threepids = {
|
||||||
|
(threepid["medium"], threepid["address"]) for threepid in threepids
|
||||||
|
}
|
||||||
|
|
||||||
if user: # modify user
|
if user: # modify user
|
||||||
if "displayname" in body:
|
if "displayname" in body:
|
||||||
@ -243,29 +248,39 @@ class UserRestServletV2(RestServlet):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if threepids is not None:
|
if threepids is not None:
|
||||||
# remove old threepids from user
|
# get changed threepids (added and removed)
|
||||||
old_threepids = await self.store.user_get_threepids(user_id)
|
# convert List[Dict[str, Any]] into Set[Tuple[str, str]]
|
||||||
for threepid in old_threepids:
|
cur_threepids = {
|
||||||
|
(threepid["medium"], threepid["address"])
|
||||||
|
for threepid in await self.store.user_get_threepids(user_id)
|
||||||
|
}
|
||||||
|
add_threepids = new_threepids - cur_threepids
|
||||||
|
del_threepids = cur_threepids - new_threepids
|
||||||
|
|
||||||
|
# remove old threepids
|
||||||
|
for medium, address in del_threepids:
|
||||||
try:
|
try:
|
||||||
await self.auth_handler.delete_threepid(
|
await self.auth_handler.delete_threepid(
|
||||||
user_id, threepid["medium"], threepid["address"], None
|
user_id, medium, address, None
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to remove threepids")
|
logger.exception("Failed to remove threepids")
|
||||||
raise SynapseError(500, "Failed to remove threepids")
|
raise SynapseError(500, "Failed to remove threepids")
|
||||||
|
|
||||||
# add new threepids to user
|
# add new threepids
|
||||||
current_time = self.hs.get_clock().time_msec()
|
current_time = self.hs.get_clock().time_msec()
|
||||||
for threepid in threepids:
|
for medium, address in add_threepids:
|
||||||
await self.auth_handler.add_threepid(
|
await self.auth_handler.add_threepid(
|
||||||
user_id, threepid["medium"], threepid["address"], current_time
|
user_id, medium, address, current_time
|
||||||
)
|
)
|
||||||
|
|
||||||
if external_ids is not None:
|
if external_ids is not None:
|
||||||
# get changed external_ids (added and removed)
|
# get changed external_ids (added and removed)
|
||||||
cur_external_ids = await self.store.get_external_ids_by_user(user_id)
|
cur_external_ids = set(
|
||||||
add_external_ids = set(new_external_ids) - set(cur_external_ids)
|
await self.store.get_external_ids_by_user(user_id)
|
||||||
del_external_ids = set(cur_external_ids) - set(new_external_ids)
|
)
|
||||||
|
add_external_ids = new_external_ids - cur_external_ids
|
||||||
|
del_external_ids = cur_external_ids - new_external_ids
|
||||||
|
|
||||||
# remove old external_ids
|
# remove old external_ids
|
||||||
for auth_provider, external_id in del_external_ids:
|
for auth_provider, external_id in del_external_ids:
|
||||||
@ -348,9 +363,9 @@ class UserRestServletV2(RestServlet):
|
|||||||
|
|
||||||
if threepids is not None:
|
if threepids is not None:
|
||||||
current_time = self.hs.get_clock().time_msec()
|
current_time = self.hs.get_clock().time_msec()
|
||||||
for threepid in threepids:
|
for medium, address in new_threepids:
|
||||||
await self.auth_handler.add_threepid(
|
await self.auth_handler.add_threepid(
|
||||||
user_id, threepid["medium"], threepid["address"], current_time
|
user_id, medium, address, current_time
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
self.hs.config.email_enable_notifs
|
self.hs.config.email_enable_notifs
|
||||||
@ -362,8 +377,8 @@ class UserRestServletV2(RestServlet):
|
|||||||
kind="email",
|
kind="email",
|
||||||
app_id="m.email",
|
app_id="m.email",
|
||||||
app_display_name="Email Notifications",
|
app_display_name="Email Notifications",
|
||||||
device_display_name=threepid["address"],
|
device_display_name=address,
|
||||||
pushkey=threepid["address"],
|
pushkey=address,
|
||||||
lang=None, # We don't know a user's language here
|
lang=None, # We don't know a user's language here
|
||||||
data={},
|
data={},
|
||||||
)
|
)
|
||||||
|
@ -754,16 +754,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||||||
)
|
)
|
||||||
return user_id
|
return user_id
|
||||||
|
|
||||||
def get_user_id_by_threepid_txn(self, txn, medium, address):
|
def get_user_id_by_threepid_txn(
|
||||||
|
self, txn, medium: str, address: str
|
||||||
|
) -> Optional[str]:
|
||||||
"""Returns user id from threepid
|
"""Returns user id from threepid
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
txn (cursor):
|
txn (cursor):
|
||||||
medium (str): threepid medium e.g. email
|
medium: threepid medium e.g. email
|
||||||
address (str): threepid address e.g. me@example.com
|
address: threepid address e.g. me@example.com
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str|None: user id or None if no user id/threepid mapping exists
|
user id, or None if no user id/threepid mapping exists
|
||||||
"""
|
"""
|
||||||
ret = self.db_pool.simple_select_one_txn(
|
ret = self.db_pool.simple_select_one_txn(
|
||||||
txn,
|
txn,
|
||||||
@ -776,14 +778,21 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||||||
return ret["user_id"]
|
return ret["user_id"]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
|
async def user_add_threepid(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
medium: str,
|
||||||
|
address: str,
|
||||||
|
validated_at: int,
|
||||||
|
added_at: int,
|
||||||
|
) -> None:
|
||||||
await self.db_pool.simple_upsert(
|
await self.db_pool.simple_upsert(
|
||||||
"user_threepids",
|
"user_threepids",
|
||||||
{"medium": medium, "address": address},
|
{"medium": medium, "address": address},
|
||||||
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
|
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def user_get_threepids(self, user_id):
|
async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]:
|
||||||
return await self.db_pool.simple_select_list(
|
return await self.db_pool.simple_select_list(
|
||||||
"user_threepids",
|
"user_threepids",
|
||||||
{"user_id": user_id},
|
{"user_id": user_id},
|
||||||
@ -791,7 +800,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||||||
"user_get_threepids",
|
"user_get_threepids",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def user_delete_threepid(self, user_id, medium, address) -> None:
|
async def user_delete_threepid(
|
||||||
|
self, user_id: str, medium: str, address: str
|
||||||
|
) -> None:
|
||||||
await self.db_pool.simple_delete(
|
await self.db_pool.simple_delete(
|
||||||
"user_threepids",
|
"user_threepids",
|
||||||
keyvalues={"user_id": user_id, "medium": medium, "address": address},
|
keyvalues={"user_id": user_id, "medium": medium, "address": address},
|
||||||
|
@ -1431,12 +1431,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||||||
self.assertEqual("Bob's name", channel.json_body["displayname"])
|
self.assertEqual("Bob's name", channel.json_body["displayname"])
|
||||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||||
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
||||||
|
self.assertEqual(1, len(channel.json_body["threepids"]))
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
"external_id1", channel.json_body["external_ids"][0]["external_id"]
|
"external_id1", channel.json_body["external_ids"][0]["external_id"]
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
"auth_provider1", channel.json_body["external_ids"][0]["auth_provider"]
|
"auth_provider1", channel.json_body["external_ids"][0]["auth_provider"]
|
||||||
)
|
)
|
||||||
|
self.assertEqual(1, len(channel.json_body["external_ids"]))
|
||||||
self.assertFalse(channel.json_body["admin"])
|
self.assertFalse(channel.json_body["admin"])
|
||||||
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
|
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
|
||||||
self._check_fields(channel.json_body)
|
self._check_fields(channel.json_body)
|
||||||
@ -1676,18 +1678,53 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||||||
Test setting threepid for an other user.
|
Test setting threepid for an other user.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Delete old and add new threepid to user
|
# Add two threepids to user
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"PUT",
|
"PUT",
|
||||||
self.url_other_user,
|
self.url_other_user,
|
||||||
access_token=self.admin_user_tok,
|
access_token=self.admin_user_tok,
|
||||||
content={"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]},
|
content={
|
||||||
|
"threepids": [
|
||||||
|
{"medium": "email", "address": "bob1@bob.bob"},
|
||||||
|
{"medium": "email", "address": "bob2@bob.bob"},
|
||||||
|
],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(200, channel.code, msg=channel.json_body)
|
self.assertEqual(200, channel.code, msg=channel.json_body)
|
||||||
self.assertEqual("@user:test", channel.json_body["name"])
|
self.assertEqual("@user:test", channel.json_body["name"])
|
||||||
|
self.assertEqual(2, len(channel.json_body["threepids"]))
|
||||||
|
# result does not always have the same sort order, therefore it becomes sorted
|
||||||
|
sorted_result = sorted(
|
||||||
|
channel.json_body["threepids"], key=lambda k: k["address"]
|
||||||
|
)
|
||||||
|
self.assertEqual("email", sorted_result[0]["medium"])
|
||||||
|
self.assertEqual("bob1@bob.bob", sorted_result[0]["address"])
|
||||||
|
self.assertEqual("email", sorted_result[1]["medium"])
|
||||||
|
self.assertEqual("bob2@bob.bob", sorted_result[1]["address"])
|
||||||
|
self._check_fields(channel.json_body)
|
||||||
|
|
||||||
|
# Set a new and remove a threepid
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
self.url_other_user,
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
content={
|
||||||
|
"threepids": [
|
||||||
|
{"medium": "email", "address": "bob2@bob.bob"},
|
||||||
|
{"medium": "email", "address": "bob3@bob.bob"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(200, channel.code, msg=channel.json_body)
|
||||||
|
self.assertEqual("@user:test", channel.json_body["name"])
|
||||||
|
self.assertEqual(2, len(channel.json_body["threepids"]))
|
||||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||||
self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
|
self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
|
||||||
|
self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
|
||||||
|
self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
|
||||||
|
self._check_fields(channel.json_body)
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
@ -1698,8 +1735,24 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.assertEqual(200, channel.code, msg=channel.json_body)
|
self.assertEqual(200, channel.code, msg=channel.json_body)
|
||||||
self.assertEqual("@user:test", channel.json_body["name"])
|
self.assertEqual("@user:test", channel.json_body["name"])
|
||||||
|
self.assertEqual(2, len(channel.json_body["threepids"]))
|
||||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||||
self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
|
self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
|
||||||
|
self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
|
||||||
|
self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
|
||||||
|
self._check_fields(channel.json_body)
|
||||||
|
|
||||||
|
# Remove threepids
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
self.url_other_user,
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
content={"threepids": []},
|
||||||
|
)
|
||||||
|
self.assertEqual(200, channel.code, msg=channel.json_body)
|
||||||
|
self.assertEqual("@user:test", channel.json_body["name"])
|
||||||
|
self.assertEqual(0, len(channel.json_body["threepids"]))
|
||||||
|
self._check_fields(channel.json_body)
|
||||||
|
|
||||||
def test_set_external_id(self):
|
def test_set_external_id(self):
|
||||||
"""
|
"""
|
||||||
@ -1778,6 +1831,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.assertEqual(200, channel.code, msg=channel.json_body)
|
self.assertEqual(200, channel.code, msg=channel.json_body)
|
||||||
self.assertEqual("@user:test", channel.json_body["name"])
|
self.assertEqual("@user:test", channel.json_body["name"])
|
||||||
|
self.assertEqual(2, len(channel.json_body["external_ids"]))
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
channel.json_body["external_ids"],
|
channel.json_body["external_ids"],
|
||||||
[
|
[
|
||||||
|
Loading…
Reference in New Issue
Block a user