Implement account status endpoints (MSC3720) (#12001)

See matrix-org/matrix-doc#3720

Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com>
This commit is contained in:
Brendan Abolivier 2022-02-22 16:10:10 +01:00 committed by GitHub
parent 94a396e7c4
commit 250104d357
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 511 additions and 6 deletions

View File

@ -0,0 +1 @@
Implement experimental support for [MSC3720](https://github.com/matrix-org/matrix-doc/pull/3720) (account status endpoints).

View File

@ -65,3 +65,6 @@ class ExperimentalConfig(Config):
# experimental support for faster joins over federation (msc2775, msc3706)
# requires a target server with msc3706_enabled enabled.
self.faster_joins_enabled: bool = experimental.get("faster_joins", False)
# MSC3720 (Account status endpoint)
self.msc3720_enabled: bool = experimental.get("msc3720_enabled", False)

View File

@ -56,7 +56,7 @@ from synapse.api.room_versions import (
from synapse.events import EventBase, builder
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.transport.client import SendJoinResponse
from synapse.types import JsonDict, get_domain_from_id
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
@ -1610,6 +1610,64 @@ class FederationClient(FederationBase):
except ValueError as e:
raise InvalidResponseError(str(e))
async def get_account_status(
self, destination: str, user_ids: List[str]
) -> Tuple[JsonDict, List[str]]:
"""Retrieves account statuses for a given list of users on a given remote
homeserver.
If the request fails for any reason, all user IDs for this destination are marked
as failed.
Args:
destination: the destination to contact
user_ids: the user ID(s) for which to request account status(es)
Returns:
The account statuses, as well as the list of user IDs for which it was not
possible to retrieve a status.
"""
try:
res = await self.transport_layer.get_account_status(destination, user_ids)
except Exception:
# If the query failed for any reason, mark all the users as failed.
return {}, user_ids
statuses = res.get("account_statuses", {})
failures = res.get("failures", [])
if not isinstance(statuses, dict) or not isinstance(failures, list):
# Make sure we're not feeding back malformed data back to the caller.
logger.warning(
"Destination %s responded with malformed data to account_status query",
destination,
)
return {}, user_ids
for user_id in user_ids:
# Any account whose status is missing is a user we failed to receive the
# status of.
if user_id not in statuses and user_id not in failures:
failures.append(user_id)
# Filter out any user ID that doesn't belong to the remote server that sent its
# status (or failure).
def filter_user_id(user_id: str) -> bool:
try:
return UserID.from_string(user_id).domain == destination
except SynapseError:
# If the user ID doesn't parse, ignore it.
return False
filtered_statuses = dict(
# item is a (key, value) tuple, so item[0] is the user ID.
filter(lambda item: filter_user_id(item[0]), statuses.items())
)
filtered_failures = list(filter(filter_user_id, failures))
return filtered_statuses, filtered_failures
@attr.s(frozen=True, slots=True, auto_attribs=True)
class TimestampToEventResponse:

View File

@ -258,8 +258,9 @@ class TransportLayerClient:
args: dict,
retry_on_dns_fail: bool,
ignore_backoff: bool = False,
prefix: str = FEDERATION_V1_PREFIX,
) -> JsonDict:
path = _create_v1_path("/query/%s", query_type)
path = _create_path(prefix, "/query/%s", query_type)
return await self.client.get_json(
destination=destination,
@ -1247,6 +1248,22 @@ class TransportLayerClient:
args={"suggested_only": "true" if suggested_only else "false"},
)
async def get_account_status(
self, destination: str, user_ids: List[str]
) -> JsonDict:
"""
Args:
destination: The remote server.
user_ids: The user ID(s) for which to request account status(es).
"""
path = _create_path(
FEDERATION_UNSTABLE_PREFIX, "/org.matrix.msc3720/account_status"
)
return await self.client.post_json(
destination=destination, path=path, data={"user_ids": user_ids}
)
def _create_path(federation_prefix: str, path: str, *args: str) -> str:
"""

View File

@ -24,6 +24,7 @@ from synapse.federation.transport.server._base import (
)
from synapse.federation.transport.server.federation import (
FEDERATION_SERVLET_CLASSES,
FederationAccountStatusServlet,
FederationTimestampLookupServlet,
)
from synapse.federation.transport.server.groups_local import GROUP_LOCAL_SERVLET_CLASSES
@ -336,6 +337,13 @@ def register_servlets(
):
continue
# Only allow the `/account_status` servlet if msc3720 is enabled
if (
servletclass == FederationAccountStatusServlet
and not hs.config.experimental.msc3720_enabled
):
continue
servletclass(
hs=hs,
authenticator=authenticator,

View File

@ -766,6 +766,40 @@ class RoomComplexityServlet(BaseFederationServlet):
return 200, complexity
class FederationAccountStatusServlet(BaseFederationServerServlet):
PATH = "/query/account_status"
PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3720"
def __init__(
self,
hs: "HomeServer",
authenticator: Authenticator,
ratelimiter: FederationRateLimiter,
server_name: str,
):
super().__init__(hs, authenticator, ratelimiter, server_name)
self._account_handler = hs.get_account_handler()
async def on_POST(
self,
origin: str,
content: JsonDict,
query: Mapping[bytes, Sequence[bytes]],
room_id: str,
) -> Tuple[int, JsonDict]:
if "user_ids" not in content:
raise SynapseError(
400, "Required parameter 'user_ids' is missing", Codes.MISSING_PARAM
)
statuses, failures = await self._account_handler.get_account_statuses(
content["user_ids"],
allow_remote=False,
)
return 200, {"account_statuses": statuses, "failures": failures}
FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FederationSendServlet,
FederationEventServlet,
@ -797,4 +831,5 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FederationRoomHierarchyUnstableServlet,
FederationV1SendKnockServlet,
FederationMakeKnockServlet,
FederationAccountStatusServlet,
)

144
synapse/handlers/account.py Normal file
View File

@ -0,0 +1,144 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Dict, List, Tuple
from synapse.api.errors import Codes, SynapseError
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
class AccountHandler:
def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastore()
self._is_mine = hs.is_mine
self._federation_client = hs.get_federation_client()
async def get_account_statuses(
self,
user_ids: List[str],
allow_remote: bool,
) -> Tuple[JsonDict, List[str]]:
"""Get account statuses for a list of user IDs.
If one or more account(s) belong to remote homeservers, retrieve their status(es)
over federation if allowed.
Args:
user_ids: The list of accounts to retrieve the status of.
allow_remote: Whether to try to retrieve the status of remote accounts, if
any.
Returns:
The account statuses as well as the list of users whose statuses could not be
retrieved.
Raises:
SynapseError if a required parameter is missing or malformed, or if one of
the accounts isn't local to this homeserver and allow_remote is False.
"""
statuses = {}
failures = []
remote_users: List[UserID] = []
for raw_user_id in user_ids:
try:
user_id = UserID.from_string(raw_user_id)
except SynapseError:
raise SynapseError(
400,
f"Not a valid Matrix user ID: {raw_user_id}",
Codes.INVALID_PARAM,
)
if self._is_mine(user_id):
status = await self._get_local_account_status(user_id)
statuses[user_id.to_string()] = status
else:
if not allow_remote:
raise SynapseError(
400,
f"Not a local user: {raw_user_id}",
Codes.INVALID_PARAM,
)
remote_users.append(user_id)
if allow_remote and len(remote_users) > 0:
remote_statuses, remote_failures = await self._get_remote_account_statuses(
remote_users,
)
statuses.update(remote_statuses)
failures += remote_failures
return statuses, failures
async def _get_local_account_status(self, user_id: UserID) -> JsonDict:
"""Retrieve the status of a local account.
Args:
user_id: The account to retrieve the status of.
Returns:
The account's status.
"""
status = {"exists": False}
userinfo = await self._store.get_userinfo_by_id(user_id.to_string())
if userinfo is not None:
status = {
"exists": True,
"deactivated": userinfo.is_deactivated,
}
return status
async def _get_remote_account_statuses(
self, remote_users: List[UserID]
) -> Tuple[JsonDict, List[str]]:
"""Send out federation requests to retrieve the statuses of remote accounts.
Args:
remote_users: The accounts to retrieve the statuses of.
Returns:
The statuses of the accounts, and a list of accounts for which no status
could be retrieved.
"""
# Group remote users by destination, so we only send one request per remote
# homeserver.
by_destination: Dict[str, List[str]] = {}
for user in remote_users:
if user.domain not in by_destination:
by_destination[user.domain] = []
by_destination[user.domain].append(user.to_string())
# Retrieve the statuses and failures for remote accounts.
final_statuses: JsonDict = {}
final_failures: List[str] = []
for destination, users in by_destination.items():
statuses, failures = await self._federation_client.get_account_status(
destination,
users,
)
final_statuses.update(statuses)
final_failures += failures
return final_statuses, final_failures

View File

@ -896,6 +896,36 @@ class WhoamiRestServlet(RestServlet):
return 200, response
class AccountStatusRestServlet(RestServlet):
PATTERNS = client_patterns(
"/org.matrix.msc3720/account_status$", unstable=True, releases=()
)
def __init__(self, hs: "HomeServer"):
super().__init__()
self._auth = hs.get_auth()
self._store = hs.get_datastore()
self._is_mine = hs.is_mine
self._federation_client = hs.get_federation_client()
self._account_handler = hs.get_account_handler()
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self._auth.get_user_by_req(request)
body = parse_json_object_from_request(request)
if "user_ids" not in body:
raise SynapseError(
400, "Required parameter 'user_ids' is missing", Codes.MISSING_PARAM
)
statuses, failures = await self._account_handler.get_account_statuses(
body["user_ids"],
allow_remote=True,
)
return 200, {"account_statuses": statuses, "failures": failures}
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
EmailPasswordRequestTokenRestServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server)
@ -910,3 +940,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ThreepidUnbindRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server)
WhoamiRestServlet(hs).register(http_server)
if hs.config.experimental.msc3720_enabled:
AccountStatusRestServlet(hs).register(http_server)

View File

@ -75,6 +75,11 @@ class CapabilitiesRestServlet(RestServlet):
if self.config.experimental.msc3440_enabled:
response["capabilities"]["io.element.thread"] = {"enabled": True}
if self.config.experimental.msc3720_enabled:
response["capabilities"]["org.matrix.msc3720.account_status"] = {
"enabled": True,
}
return HTTPStatus.OK, response

View File

@ -62,6 +62,7 @@ from synapse.federation.sender import AbstractFederationSender, FederationSender
from synapse.federation.transport.client import TransportLayerClient
from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler
from synapse.handlers.account import AccountHandler
from synapse.handlers.account_data import AccountDataHandler
from synapse.handlers.account_validity import AccountValidityHandler
from synapse.handlers.admin import AdminHandler
@ -807,6 +808,10 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_external_cache(self) -> ExternalCache:
return ExternalCache(self)
@cache_in_self
def get_account_handler(self) -> AccountHandler:
return AccountHandler(self)
@cache_in_self
def get_outbound_redis_connection(self) -> "RedisProtocol":
"""

View File

@ -1,6 +1,4 @@
# Copyright 2015-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -17,16 +15,22 @@ import json
import os
import re
from email.parser import Parser
from typing import Optional
from typing import Dict, List, Optional
from unittest.mock import Mock
import pkg_resources
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.constants import LoginType, Membership
from synapse.api.errors import Codes, HttpResponseException
from synapse.appservice import ApplicationService
from synapse.rest import admin
from synapse.rest.client import account, login, register, room
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
from tests.server import FakeSite, make_request
@ -1040,3 +1044,195 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
self.assertIn(expected_email, threepids)
class AccountStatusTestCase(unittest.HomeserverTestCase):
servlets = [
account.register_servlets,
admin.register_servlets,
login.register_servlets,
]
url = "/_matrix/client/unstable/org.matrix.msc3720/account_status"
def make_homeserver(self, reactor, clock):
config = self.default_config()
config["experimental_features"] = {"msc3720_enabled": True}
return self.setup_test_homeserver(config=config)
def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
self.requester = self.register_user("requester", "password")
self.requester_tok = self.login("requester", "password")
self.server_name = homeserver.config.server.server_name
def test_missing_mxid(self):
"""Tests that not providing any MXID raises an error."""
self._test_status(
users=None,
expected_status_code=400,
expected_errcode=Codes.MISSING_PARAM,
)
def test_invalid_mxid(self):
"""Tests that providing an invalid MXID raises an error."""
self._test_status(
users=["bad:test"],
expected_status_code=400,
expected_errcode=Codes.INVALID_PARAM,
)
def test_local_user_not_exists(self):
"""Tests that the account status endpoints correctly reports that a user doesn't
exist.
"""
user = "@unknown:" + self.hs.config.server.server_name
self._test_status(
users=[user],
expected_statuses={
user: {
"exists": False,
},
},
expected_failures=[],
)
def test_local_user_exists(self):
"""Tests that the account status endpoint correctly reports that a user doesn't
exist.
"""
user = self.register_user("someuser", "password")
self._test_status(
users=[user],
expected_statuses={
user: {
"exists": True,
"deactivated": False,
},
},
expected_failures=[],
)
def test_local_user_deactivated(self):
"""Tests that the account status endpoint correctly reports a deactivated user."""
user = self.register_user("someuser", "password")
self.get_success(
self.hs.get_datastore().set_user_deactivated_status(user, deactivated=True)
)
self._test_status(
users=[user],
expected_statuses={
user: {
"exists": True,
"deactivated": True,
},
},
expected_failures=[],
)
def test_mixed_local_and_remote_users(self):
"""Tests that if some users are remote the account status endpoint correctly
merges the remote responses with the local result.
"""
# We use 3 users: one doesn't exist but belongs on the local homeserver, one is
# deactivated and belongs on one remote homeserver, and one belongs to another
# remote homeserver that didn't return any result (the federation code should
# mark that user as a failure).
users = [
"@unknown:" + self.hs.config.server.server_name,
"@deactivated:remote",
"@failed:otherremote",
"@bad:badremote",
]
async def post_json(destination, path, data, *a, **kwa):
if destination == "remote":
return {
"account_statuses": {
users[1]: {
"exists": True,
"deactivated": True,
},
}
}
if destination == "otherremote":
return {}
if destination == "badremote":
# badremote tries to overwrite the status of a user that doesn't belong
# to it (i.e. users[1]) with false data, which Synapse is expected to
# ignore.
return {
"account_statuses": {
users[3]: {
"exists": False,
},
users[1]: {
"exists": False,
},
}
}
# Register a mock that will return the expected result depending on the remote.
self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json)
# Check that we've got the correct response from the client-side endpoint.
self._test_status(
users=users,
expected_statuses={
users[0]: {
"exists": False,
},
users[1]: {
"exists": True,
"deactivated": True,
},
users[3]: {
"exists": False,
},
},
expected_failures=[users[2]],
)
def _test_status(
self,
users: Optional[List[str]],
expected_status_code: int = 200,
expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None,
expected_failures: Optional[List[str]] = None,
expected_errcode: Optional[str] = None,
):
"""Send a request to the account status endpoint and check that the response
matches with what's expected.
Args:
users: The account(s) to request the status of, if any. If set to None, no
`user_id` query parameter will be included in the request.
expected_status_code: The expected HTTP status code.
expected_statuses: The expected account statuses, if any.
expected_failures: The expected failures, if any.
expected_errcode: The expected Matrix error code, if any.
"""
content = {}
if users is not None:
content["user_ids"] = users
channel = self.make_request(
method="POST",
path=self.url,
content=content,
access_token=self.requester_tok,
)
self.assertEqual(channel.code, expected_status_code)
if expected_statuses is not None:
self.assertEqual(channel.json_body["account_statuses"], expected_statuses)
if expected_failures is not None:
self.assertEqual(channel.json_body["failures"], expected_failures)
if expected_errcode is not None:
self.assertEqual(channel.json_body["errcode"], expected_errcode)