Fix destination_is errors seen in sentry. (#13041)

* Rename test_fedclient to match its source file
* Require at least one destination to be truthy
* Explicitly validate user ID in profile endpoint GETs
Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
This commit is contained in:
David Robertson 2022-06-14 18:28:26 +01:00 committed by GitHub
parent aef398457f
commit c99b511db9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 59 additions and 8 deletions

2
changelog.d/13041.bugfix Normal file
View File

@ -0,0 +1,2 @@
Fix a bug introduced in Synapse 1.58 where profile requests for a malformed user ID would ccause an internal error. Synapse now returns 400 Bad Request in this situation.

View File

@ -731,8 +731,11 @@ class MatrixFederationHttpClient:
Returns: Returns:
A list of headers to be added as "Authorization:" headers A list of headers to be added as "Authorization:" headers
""" """
if destination is None and destination_is is None: if not destination and not destination_is:
raise ValueError("destination and destination_is cannot both be None!") raise ValueError(
"At least one of the arguments destination and destination_is "
"must be a nonempty bytestring."
)
request: JsonDict = { request: JsonDict = {
"method": method.decode("ascii"), "method": method.decode("ascii"),

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
""" This module contains REST servlets to do with profile: /profile/<paths> """ """ This module contains REST servlets to do with profile: /profile/<paths> """
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
@ -45,8 +45,12 @@ class ProfileDisplaynameRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user = requester.user requester_user = requester.user
user = UserID.from_string(user_id) if not UserID.is_valid(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
)
user = UserID.from_string(user_id)
await self.profile_handler.check_profile_query_allowed(user, requester_user) await self.profile_handler.check_profile_query_allowed(user, requester_user)
displayname = await self.profile_handler.get_displayname(user) displayname = await self.profile_handler.get_displayname(user)
@ -98,8 +102,12 @@ class ProfileAvatarURLRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user = requester.user requester_user = requester.user
user = UserID.from_string(user_id) if not UserID.is_valid(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
)
user = UserID.from_string(user_id)
await self.profile_handler.check_profile_query_allowed(user, requester_user) await self.profile_handler.check_profile_query_allowed(user, requester_user)
avatar_url = await self.profile_handler.get_avatar_url(user) avatar_url = await self.profile_handler.get_avatar_url(user)
@ -150,8 +158,12 @@ class ProfileRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user = requester.user requester_user = requester.user
user = UserID.from_string(user_id) if not UserID.is_valid(user_id):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Invalid user id", Codes.INVALID_PARAM
)
user = UserID.from_string(user_id)
await self.profile_handler.check_profile_query_allowed(user, requester_user) await self.profile_handler.check_profile_query_allowed(user, requester_user)
displayname = await self.profile_handler.get_displayname(user) displayname = await self.profile_handler.get_displayname(user)

View File

@ -267,7 +267,6 @@ class DomainSpecificString(metaclass=abc.ABCMeta):
) )
domain = parts[1] domain = parts[1]
# This code will need changing if we want to support multiple domain # This code will need changing if we want to support multiple domain
# names on one HS # names on one HS
return cls(localpart=parts[0], domain=domain) return cls(localpart=parts[0], domain=domain)
@ -279,6 +278,8 @@ class DomainSpecificString(metaclass=abc.ABCMeta):
@classmethod @classmethod
def is_valid(cls: Type[DS], s: str) -> bool: def is_valid(cls: Type[DS], s: str) -> bool:
"""Parses the input string and attempts to ensure it is valid.""" """Parses the input string and attempts to ensure it is valid."""
# TODO: this does not reject an empty localpart or an overly-long string.
# See https://spec.matrix.org/v1.2/appendices/#identifier-grammar
try: try:
obj = cls.from_string(s) obj = cls.from_string(s)
# Apply additional validation to the domain. This is only done # Apply additional validation to the domain. This is only done

View File

@ -617,3 +617,17 @@ class FederationClientTests(HomeserverTestCase):
self.assertIsInstance(f.value, RequestSendFailed) self.assertIsInstance(f.value, RequestSendFailed)
self.assertTrue(transport.disconnecting) self.assertTrue(transport.disconnecting)
def test_build_auth_headers_rejects_falsey_destinations(self) -> None:
with self.assertRaises(ValueError):
self.cl.build_auth_headers(None, b"GET", b"https://example.com")
with self.assertRaises(ValueError):
self.cl.build_auth_headers(b"", b"GET", b"https://example.com")
with self.assertRaises(ValueError):
self.cl.build_auth_headers(
None, b"GET", b"https://example.com", destination_is=b""
)
with self.assertRaises(ValueError):
self.cl.build_auth_headers(
b"", b"GET", b"https://example.com", destination_is=b""
)

View File

@ -13,6 +13,8 @@
# limitations under the License. # limitations under the License.
"""Tests REST events for /profile paths.""" """Tests REST events for /profile paths."""
import urllib.parse
from http import HTTPStatus
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -49,6 +51,12 @@ class ProfileTestCase(unittest.HomeserverTestCase):
res = self._get_displayname() res = self._get_displayname()
self.assertEqual(res, "owner") self.assertEqual(res, "owner")
def test_get_displayname_rejects_bad_username(self) -> None:
channel = self.make_request(
"GET", f"/profile/{urllib.parse.quote('@alice:')}/displayname"
)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
def test_set_displayname(self) -> None: def test_set_displayname(self) -> None:
channel = self.make_request( channel = self.make_request(
"PUT", "PUT",

View File

@ -26,10 +26,21 @@ class UserIDTestCase(unittest.HomeserverTestCase):
self.assertEqual("test", user.domain) self.assertEqual("test", user.domain)
self.assertEqual(True, self.hs.is_mine(user)) self.assertEqual(True, self.hs.is_mine(user))
def test_pase_empty(self): def test_parse_rejects_empty_id(self):
with self.assertRaises(SynapseError): with self.assertRaises(SynapseError):
UserID.from_string("") UserID.from_string("")
def test_parse_rejects_missing_sigil(self):
with self.assertRaises(SynapseError):
UserID.from_string("alice:example.com")
def test_parse_rejects_missing_separator(self):
with self.assertRaises(SynapseError):
UserID.from_string("@alice.example.com")
def test_validation_rejects_missing_domain(self):
self.assertFalse(UserID.is_valid("@alice:"))
def test_build(self): def test_build(self):
user = UserID("5678efgh", "my.domain") user = UserID("5678efgh", "my.domain")