Allow rate limiters to passively record actions they cannot limit (#13253)
Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
parent 0eb7e69768
commit 599c403d99
changelog.d/13253.misc (new file)
@@ -0,0 +1 @@
+Preparatory work for a per-room rate limiter on joins.

@@ -27,6 +27,33 @@ class Ratelimiter:
     """
     Ratelimit actions marked by arbitrary keys.
 
+    (Note that the source code speaks of "actions" and "burst_count" rather than
+    "tokens" and a "bucket_size".)
+
+    This is a "leaky bucket as a meter". For each key to be tracked there is a bucket
+    containing some number 0 <= T <= `burst_count` of tokens corresponding to previously
+    permitted requests for that key. Each bucket starts empty, and gradually leaks
+    tokens at a rate of `rate_hz`.
+
+    Upon an incoming request, we must determine:
+     - the key that this request falls under (which bucket to inspect), and
+     - the cost C of this request in tokens.
+    Then, if there is room in the bucket for C tokens (T + C <= `burst_count`),
+    the request is permitted and C tokens are added to the bucket.
+    Otherwise the request is denied, and the bucket continues to hold T tokens.
+
+    This means that the limiter enforces an average request frequency of `rate_hz`,
+    while accumulating a buffer of up to `burst_count` requests which can be consumed
+    instantaneously.
+
+    The tricky bit is the leaking. We do not want to have a periodic process which
+    leaks every bucket! Instead, we track
+     - the time point when the bucket was last completely empty, and
+     - how many tokens have been added to the bucket since then.
+    Then, for each incoming request, we can calculate how many tokens have leaked
+    since that time point, and use that to decide whether to accept or reject the
+    request.
+
     Args:
         clock: A homeserver clock, for retrieving the current time
         rate_hz: The long term number of actions that can be performed in a second.
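The docstring above sidesteps a periodic leak task by deriving each bucket's current fill level on demand from the stored (tokens added, time last empty, rate_hz) values. A rough standalone sketch of that calculation follows; the helper name and parameters are illustrative, not part of Synapse's API:

def bucket_has_room(
    action_count: float,  # tokens added since the bucket was last completely empty
    time_start: float,    # when the bucket was last completely empty
    rate_hz: float,       # leak rate, in tokens per second
    time_now_s: float,
    burst_count: float,
    cost: float = 1.0,
) -> bool:
    # Tokens that have leaked away since the bucket was last empty.
    leaked = (time_now_s - time_start) * rate_hz
    # Tokens still held; the bucket can never hold fewer than zero.
    tokens = max(0.0, action_count - leaked)
    # Permit the request only if the bucket has room for its full cost.
    return tokens + cost <= burst_count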
@@ -41,14 +68,30 @@ class Ratelimiter:
         self.burst_count = burst_count
         self.store = store
 
-        # A ordered dictionary keeping track of actions, when they were last
-        # performed and how often. Each entry is a mapping from a key of arbitrary type
-        # to a tuple representing:
-        #   * How many times an action has occurred since a point in time
-        #   * The point in time
-        #   * The rate_hz of this particular entry. This can vary per request
+        # An ordered dictionary representing the token buckets tracked by this rate
+        # limiter. Each entry maps a key of arbitrary type to a tuple representing:
+        #   * The number of tokens currently in the bucket,
+        #   * The time point when the bucket was last completely empty, and
+        #   * The rate_hz (leak rate) of this particular bucket.
         self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict()
 
+    def _get_key(
+        self, requester: Optional[Requester], key: Optional[Hashable]
+    ) -> Hashable:
+        """Use the requester's MXID as a fallback key if no key is provided."""
+        if key is None:
+            if not requester:
+                raise ValueError("Must supply at least one of `requester` or `key`")
+
+            key = requester.user.to_string()
+        return key
+
+    def _get_action_counts(
+        self, key: Hashable, time_now_s: float
+    ) -> Tuple[float, float, float]:
+        """Retrieve the action counts, with a fallback representing an empty bucket."""
+        return self.actions.get(key, (0.0, time_now_s, 0.0))
+
     async def can_do_action(
         self,
         requester: Optional[Requester],
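To make the new comment concrete, here is a small illustration (with invented values) of what the bucket storage can end up holding, keyed either by a user's MXID or by any other hashable key:

from collections import OrderedDict

# Illustration only: each entry maps an arbitrary hashable key to
# (tokens currently in the bucket, time the bucket was last completely empty, rate_hz).
buckets = OrderedDict()
buckets["@alice:example.com"] = (2.0, 1657000000.0, 0.1)          # keyed by a user's MXID
buckets[("join", "!room:example.com")] = (1.0, 1657000030.0, 0.1)  # or by any other hashable key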
@@ -88,11 +131,7 @@ class Ratelimiter:
             * The reactor timestamp for when the action can be performed next.
               -1 if rate_hz is less than or equal to zero
         """
-        if key is None:
-            if not requester:
-                raise ValueError("Must supply at least one of `requester` or `key`")
-
-            key = requester.user.to_string()
+        key = self._get_key(requester, key)
 
         if requester:
            # Disable rate limiting of users belonging to any AS that is configured
@@ -121,7 +160,7 @@ class Ratelimiter:
         self._prune_message_counts(time_now_s)
 
         # Check if there is an existing count entry for this key
-        action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, 0.0))
+        action_count, time_start, _ = self._get_action_counts(key, time_now_s)
 
         # Check whether performing another action is allowed
         time_delta = time_now_s - time_start
@@ -164,6 +203,37 @@ class Ratelimiter:
 
         return allowed, time_allowed
 
+    def record_action(
+        self,
+        requester: Optional[Requester],
+        key: Optional[Hashable] = None,
+        n_actions: int = 1,
+        _time_now_s: Optional[float] = None,
+    ) -> None:
+        """Record that one or more actions took place, even if they violate the rate limit.
+
+        This is useful for tracking the frequency of events that happen across
+        federation which we still want to impose local rate limits on. For instance, if
+        we are alice.com monitoring a particular room, we cannot prevent bob.com
+        from joining users to that room. However, we can track the number of recent
+        joins in the room and refuse to serve new joins ourselves if there have been too
+        many in the room across both homeservers.
+
+        Args:
+            requester: The requester that is doing the action, if any.
+            key: An arbitrary key used to classify an action. Defaults to the
+                requester's user ID.
+            n_actions: The number of times the user performed this action. Unlike
+                `can_do_action`, the count is always incremented, even if that
+                exceeds the rate limit.
+            _time_now_s: The current time. Optional, defaults to the current time according
+                to self.clock. Only used by tests.
+        """
+        key = self._get_key(requester, key)
+        time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
+        action_count, time_start, rate_hz = self._get_action_counts(key, time_now_s)
+        self.actions[key] = (action_count + n_actions, time_start, rate_hz)
+
     def _prune_message_counts(self, time_now_s: float) -> None:
         """Remove message count entries that have not exceeded their defined
         rate_hz limit
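The alice.com/bob.com scenario in the docstring corresponds to a calling pattern along the following lines: joins observed over federation are only recorded, while locally served joins are actively limited against the same bucket. This is an illustrative sketch, not code from this change; `limiter` is assumed to be a Ratelimiter whose keys are room IDs:

async def on_remote_join_observed(limiter: Ratelimiter, room_id: str) -> None:
    # We cannot stop another homeserver joining users to the room, but we can
    # passively count the join against the room's bucket.
    limiter.record_action(requester=None, key=room_id)

async def on_local_join_request(limiter: Ratelimiter, room_id: str) -> None:
    # Locally served joins are checked against (and, if allowed, added to) the
    # same bucket, so remote joins reduce the remaining local allowance.
    allowed, _retry_at = await limiter.can_do_action(requester=None, key=room_id)
    if not allowed:
        raise RuntimeError("Too many recent joins in this room")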
@@ -314,3 +314,77 @@ class TestRatelimiter(unittest.HomeserverTestCase):
 
         # Check that we get rate limited after using that token.
         self.assertFalse(consume_at(11.1))
+
+    def test_record_action_which_doesnt_fill_bucket(self) -> None:
+        limiter = Ratelimiter(
+            store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
+        )
+
+        # Observe two actions, leaving room in the bucket for one more.
+        limiter.record_action(requester=None, key="a", n_actions=2, _time_now_s=0.0)
+
+        # We should be able to take a new action now.
+        success, _ = self.get_success_or_raise(
+            limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
+        )
+        self.assertTrue(success)
+
+        # ... but not two.
+        success, _ = self.get_success_or_raise(
+            limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
+        )
+        self.assertFalse(success)
+
+    def test_record_action_which_fills_bucket(self) -> None:
+        limiter = Ratelimiter(
+            store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
+        )
+
+        # Observe three actions, filling up the bucket.
+        limiter.record_action(requester=None, key="a", n_actions=3, _time_now_s=0.0)
+
+        # We should be unable to take a new action now.
+        success, _ = self.get_success_or_raise(
+            limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
+        )
+        self.assertFalse(success)
+
+        # If we wait 10 seconds to leak a token, we should be able to take one action...
+        success, _ = self.get_success_or_raise(
+            limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
+        )
+        self.assertTrue(success)
+
+        # ... but not two.
+        success, _ = self.get_success_or_raise(
+            limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
+        )
+        self.assertFalse(success)
+
+    def test_record_action_which_overfills_bucket(self) -> None:
+        limiter = Ratelimiter(
+            store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
+        )
+
+        # Observe four actions, exceeding the bucket.
+        limiter.record_action(requester=None, key="a", n_actions=4, _time_now_s=0.0)
+
+        # We should be prevented from taking a new action now.
+        success, _ = self.get_success_or_raise(
+            limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
+        )
+        self.assertFalse(success)
+
+        # If we wait 10 seconds to leak a token, we should be unable to take an action
+        # because the bucket is still full.
+        success, _ = self.get_success_or_raise(
+            limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
+        )
+        self.assertFalse(success)
+
+        # But after another 10 seconds we leak a second token, giving us room for
+        # action.
+        success, _ = self.get_success_or_raise(
+            limiter.can_do_action(requester=None, key="a", _time_now_s=20.0)
+        )
+        self.assertTrue(success)
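The 10-second jumps in _time_now_s above are deliberate: with rate_hz=0.1 the bucket leaks exactly one token every 10 seconds, so each jump frees room for exactly one further action. As a quick standalone check of that arithmetic:

rate_hz = 0.1   # tokens leaked per second, as configured in the tests above
wait_s = 10.0   # how far the tests advance _time_now_s between calls
assert rate_hz * wait_s == 1.0  # exactly one token leaks per 10-second jump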