diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 660dfb56e..06cc8d90b 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -23,7 +23,7 @@ class Ratelimiter(object): def __init__(self): self.message_counts = collections.OrderedDict() - def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count): + def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count, update=True): """Can the user send a message? Args: user_id: The user sending a message. @@ -32,12 +32,15 @@ class Ratelimiter(object): second. burst_count: How many messages the user can send before being limited. + update (bool): Whether to update the message rates or not. This is + useful to check if a message would be allowed to be sent before + its ready to be actually sent. Returns: A pair of a bool indicating if they can send a message now and a time in seconds of when they can next send a message. """ self.prune_message_counts(time_now_s) - message_count, time_start, _ignored = self.message_counts.pop( + message_count, time_start, _ignored = self.message_counts.get( user_id, (0., time_now_s, None), ) time_delta = time_now_s - time_start @@ -52,9 +55,10 @@ class Ratelimiter(object): allowed = True message_count += 1 - self.message_counts[user_id] = ( - message_count, time_start, msg_rate_hz - ) + if update: + self.message_counts[user_id] = ( + message_count, time_start, msg_rate_hz + ) if msg_rate_hz > 0: time_allowed = ( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 30ea9630f..59eb26bea 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -16,7 +16,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.api.errors import AuthError, Codes, SynapseError, LimitExceededError from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator @@ -239,6 +239,21 @@ class MessageHandler(BaseHandler): "Tried to send member event through non-member codepath" ) + # We check here if we are currently being rate limited, so that we + # don't do unnecessary work. We check again just before we actually + # send the event. + time_now = self.clock.time() + allowed, time_allowed = self.ratelimiter.send_message( + event.sender, time_now, + msg_rate_hz=self.hs.config.rc_messages_per_second, + burst_count=self.hs.config.rc_message_burst_count, + update=False, + ) + if not allowed: + raise LimitExceededError( + retry_after_ms=int(1000 * (time_allowed - time_now)), + ) + user = UserID.from_string(event.sender) assert self.hs.is_mine(user), "User must be our own: %s" % (user,)