mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-09-20 00:24:36 -04:00
Ratelimit 3PID /requestToken API (#9238)
This commit is contained in:
parent
54a6afeee3
commit
4b73488e81
11 changed files with 159 additions and 14 deletions
|
@ -24,7 +24,7 @@ import pkg_resources
|
|||
|
||||
import synapse.rest.admin
|
||||
from synapse.api.constants import LoginType, Membership
|
||||
from synapse.api.errors import Codes
|
||||
from synapse.api.errors import Codes, HttpResponseException
|
||||
from synapse.rest.client.v1 import login, room
|
||||
from synapse.rest.client.v2_alpha import account, register
|
||||
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
|
||||
|
@ -112,6 +112,56 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
|||
# Assert we can't log in with the old password
|
||||
self.attempt_wrong_password_login("kermit", old_password)
|
||||
|
||||
@override_config({"rc_3pid_validation": {"burst_count": 3}})
|
||||
def test_ratelimit_by_email(self):
|
||||
"""Test that we ratelimit /requestToken for the same email.
|
||||
"""
|
||||
old_password = "monkey"
|
||||
new_password = "kangeroo"
|
||||
|
||||
user_id = self.register_user("kermit", old_password)
|
||||
self.login("kermit", old_password)
|
||||
|
||||
email = "test1@example.com"
|
||||
|
||||
# Add a threepid
|
||||
self.get_success(
|
||||
self.store.user_add_threepid(
|
||||
user_id=user_id,
|
||||
medium="email",
|
||||
address=email,
|
||||
validated_at=0,
|
||||
added_at=0,
|
||||
)
|
||||
)
|
||||
|
||||
def reset(ip):
|
||||
client_secret = "foobar"
|
||||
session_id = self._request_token(email, client_secret, ip)
|
||||
|
||||
self.assertEquals(len(self.email_attempts), 1)
|
||||
link = self._get_link_from_email()
|
||||
|
||||
self._validate_token(link)
|
||||
|
||||
self._reset_password(new_password, session_id, client_secret)
|
||||
|
||||
self.email_attempts.clear()
|
||||
|
||||
# We expect to be able to make three requests before getting rate
|
||||
# limited.
|
||||
#
|
||||
# We change IPs to ensure that we're not being ratelimited due to the
|
||||
# same IP
|
||||
reset("127.0.0.1")
|
||||
reset("127.0.0.2")
|
||||
reset("127.0.0.3")
|
||||
|
||||
with self.assertRaises(HttpResponseException) as cm:
|
||||
reset("127.0.0.4")
|
||||
|
||||
self.assertEqual(cm.exception.code, 429)
|
||||
|
||||
def test_basic_password_reset_canonicalise_email(self):
|
||||
"""Test basic password reset flow
|
||||
Request password reset with different spelling
|
||||
|
@ -239,13 +289,18 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.assertIsNotNone(session_id)
|
||||
|
||||
def _request_token(self, email, client_secret):
|
||||
def _request_token(self, email, client_secret, ip="127.0.0.1"):
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
b"account/password/email/requestToken",
|
||||
{"client_secret": client_secret, "email": email, "send_attempt": 1},
|
||||
client_ip=ip,
|
||||
)
|
||||
self.assertEquals(200, channel.code, channel.result)
|
||||
|
||||
if channel.code != 200:
|
||||
raise HttpResponseException(
|
||||
channel.code, channel.result["reason"], channel.result["body"],
|
||||
)
|
||||
|
||||
return channel.json_body["sid"]
|
||||
|
||||
|
@ -509,6 +564,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
|||
def test_address_trim(self):
|
||||
self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar"))
|
||||
|
||||
@override_config({"rc_3pid_validation": {"burst_count": 3}})
|
||||
def test_ratelimit_by_ip(self):
|
||||
"""Tests that adding emails is ratelimited by IP
|
||||
"""
|
||||
|
||||
# We expect to be able to set three emails before getting ratelimited.
|
||||
self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar"))
|
||||
self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar"))
|
||||
self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar"))
|
||||
|
||||
with self.assertRaises(HttpResponseException) as cm:
|
||||
self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar"))
|
||||
|
||||
self.assertEqual(cm.exception.code, 429)
|
||||
|
||||
def test_add_email_if_disabled(self):
|
||||
"""Test adding email to profile when doing so is disallowed
|
||||
"""
|
||||
|
@ -777,7 +847,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
|||
body["next_link"] = next_link
|
||||
|
||||
channel = self.make_request("POST", b"account/3pid/email/requestToken", body,)
|
||||
self.assertEquals(expect_code, channel.code, channel.result)
|
||||
|
||||
if channel.code != expect_code:
|
||||
raise HttpResponseException(
|
||||
channel.code, channel.result["reason"], channel.result["body"],
|
||||
)
|
||||
|
||||
return channel.json_body.get("sid")
|
||||
|
||||
|
@ -823,10 +897,12 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
|||
def _add_email(self, request_email, expected_email):
|
||||
"""Test adding an email to profile
|
||||
"""
|
||||
previous_email_attempts = len(self.email_attempts)
|
||||
|
||||
client_secret = "foobar"
|
||||
session_id = self._request_token(request_email, client_secret)
|
||||
|
||||
self.assertEquals(len(self.email_attempts), 1)
|
||||
self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1)
|
||||
link = self._get_link_from_email()
|
||||
|
||||
self._validate_token(link)
|
||||
|
@ -855,4 +931,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||
self.assertEqual(expected_email, channel.json_body["threepids"][0]["address"])
|
||||
|
||||
threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
|
||||
self.assertIn(expected_email, threepids)
|
||||
|
|
|
@ -47,6 +47,7 @@ class FakeChannel:
|
|||
site = attr.ib(type=Site)
|
||||
_reactor = attr.ib()
|
||||
result = attr.ib(type=dict, default=attr.Factory(dict))
|
||||
_ip = attr.ib(type=str, default="127.0.0.1")
|
||||
_producer = None
|
||||
|
||||
@property
|
||||
|
@ -120,7 +121,7 @@ class FakeChannel:
|
|||
def getPeer(self):
|
||||
# We give an address so that getClientIP returns a non null entry,
|
||||
# causing us to record the MAU
|
||||
return address.IPv4Address("TCP", "127.0.0.1", 3423)
|
||||
return address.IPv4Address("TCP", self._ip, 3423)
|
||||
|
||||
def getHost(self):
|
||||
return None
|
||||
|
@ -196,6 +197,7 @@ def make_request(
|
|||
custom_headers: Optional[
|
||||
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
|
||||
] = None,
|
||||
client_ip: str = "127.0.0.1",
|
||||
) -> FakeChannel:
|
||||
"""
|
||||
Make a web request using the given method, path and content, and render it
|
||||
|
@ -223,6 +225,9 @@ def make_request(
|
|||
will pump the reactor until the the renderer tells the channel the request
|
||||
is finished.
|
||||
|
||||
client_ip: The IP to use as the requesting IP. Useful for testing
|
||||
ratelimiting.
|
||||
|
||||
Returns:
|
||||
channel
|
||||
"""
|
||||
|
@ -250,7 +255,7 @@ def make_request(
|
|||
if isinstance(content, str):
|
||||
content = content.encode("utf8")
|
||||
|
||||
channel = FakeChannel(site, reactor)
|
||||
channel = FakeChannel(site, reactor, ip=client_ip)
|
||||
|
||||
req = request(channel)
|
||||
req.content = BytesIO(content)
|
||||
|
|
|
@ -386,6 +386,7 @@ class HomeserverTestCase(TestCase):
|
|||
custom_headers: Optional[
|
||||
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
|
||||
] = None,
|
||||
client_ip: str = "127.0.0.1",
|
||||
) -> FakeChannel:
|
||||
"""
|
||||
Create a SynapseRequest at the path using the method and containing the
|
||||
|
@ -410,6 +411,9 @@ class HomeserverTestCase(TestCase):
|
|||
|
||||
custom_headers: (name, value) pairs to add as request headers
|
||||
|
||||
client_ip: The IP to use as the requesting IP. Useful for testing
|
||||
ratelimiting.
|
||||
|
||||
Returns:
|
||||
The FakeChannel object which stores the result of the request.
|
||||
"""
|
||||
|
@ -426,6 +430,7 @@ class HomeserverTestCase(TestCase):
|
|||
content_is_form,
|
||||
await_result,
|
||||
custom_headers,
|
||||
client_ip,
|
||||
)
|
||||
|
||||
def setup_test_homeserver(self, *args, **kwargs):
|
||||
|
|
|
@ -157,6 +157,7 @@ def default_config(name, parse=False):
|
|||
"local": {"per_second": 10000, "burst_count": 10000},
|
||||
"remote": {"per_second": 10000, "burst_count": 10000},
|
||||
},
|
||||
"rc_3pid_validation": {"per_second": 10000, "burst_count": 10000},
|
||||
"saml2_enabled": False,
|
||||
"default_identity_server": None,
|
||||
"key_refresh_interval": 24 * 60 * 60 * 1000,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue