Merge branch 'ratelimiting' into develop

This commit is contained in:
Mark Haines 2014-09-03 09:15:52 +01:00
commit 30ad0c5674
14 changed files with 244 additions and 10 deletions

View file

@ -28,6 +28,7 @@ class Codes(object):
UNKNOWN = "M_UNKNOWN"
NOT_FOUND = "M_NOT_FOUND"
UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN"
LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
class CodeMessageException(Exception):
@ -39,10 +40,13 @@ class CodeMessageException(Exception):
self.code = code
self.msg = msg
def error_dict(self):
return cs_error(self.msg)
class SynapseError(CodeMessageException):
"""A base error which can be caught for all synapse events."""
def __init__(self, code, msg, errcode=""):
def __init__(self, code, msg, errcode=Codes.UNKNOWN):
"""Constructs a synapse error.
Args:
@ -53,6 +57,11 @@ class SynapseError(CodeMessageException):
super(SynapseError, self).__init__(code, msg)
self.errcode = errcode
def error_dict(self):
return cs_error(
self.msg,
self.errcode,
)
class RoomError(SynapseError):
"""An error raised when a room event fails."""
@ -91,13 +100,25 @@ class StoreError(SynapseError):
pass
def cs_exception(exception):
if isinstance(exception, SynapseError):
class LimitExceededError(SynapseError):
"""A client has sent too many requests and is being throttled.
"""
def __init__(self, code=429, msg="Too Many Requests", retry_after_ms=None,
errcode=Codes.LIMIT_EXCEEDED):
super(LimitExceededError, self).__init__(code, msg, errcode)
self.retry_after_ms = retry_after_ms
def error_dict(self):
return cs_error(
exception.msg,
Codes.UNKNOWN if not exception.errcode else exception.errcode)
elif isinstance(exception, CodeMessageException):
return cs_error(exception.msg)
self.msg,
self.errcode,
retry_after_ms=self.retry_after_ms,
)
def cs_exception(exception):
if isinstance(exception, CodeMessageException):
return exception.error_dict()
else:
logging.error("Unknown exception type: %s", type(exception))

View file

@ -0,0 +1,65 @@
import collections
class Ratelimiter(object):
"""
Ratelimit message sending by user.
"""
def __init__(self):
self.message_counts = collections.OrderedDict()
def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count):
"""Can the user send a message?
Args:
user_id: The user sending a message.
time_now_s: The time now.
msg_rate_hz: The long term number of messages a user can send in a
second.
burst_count: How many messages the user can send before being
limited.
Returns:
A pair of a bool indicating if they can send a message now and a
time in seconds of when they can next send a message.
"""
self.prune_message_counts(time_now_s)
message_count, time_start, _ignored = self.message_counts.pop(
user_id, (0., time_now_s, None),
)
time_delta = time_now_s - time_start
sent_count = message_count - time_delta * msg_rate_hz
if sent_count < 0:
allowed = True
time_start = time_now_s
messagecount = 1.
elif sent_count > burst_count - 1.:
allowed = False
else:
allowed = True
message_count += 1
self.message_counts[user_id] = (
message_count, time_start, msg_rate_hz
)
if msg_rate_hz > 0:
time_allowed = (
time_start + (message_count - burst_count + 1) / msg_rate_hz
)
if time_allowed < time_now_s:
time_allowed = time_now_s
else:
time_allowed = -1
return allowed, time_allowed
def prune_message_counts(self, time_now_s):
for user_id in self.message_counts.keys():
message_count, time_start, msg_rate_hz = (
self.message_counts[user_id]
)
time_delta = time_now_s - time_start
if message_count - time_delta * msg_rate_hz > 0:
break
else:
del self.message_counts[user_id]