From c7a7cdf7346c9268b1d4f483b31e1fdc39b6d7e0 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 2 Sep 2014 17:57:04 +0100 Subject: [PATCH] Add ratelimiting function to basehandler --- synapse/api/errors.py | 1 + synapse/app/homeserver.py | 1 + synapse/config/homeserver.py | 4 +++- synapse/handlers/_base.py | 17 +++++++++++++++++ synapse/server.py | 5 +++++ 5 files changed, 27 insertions(+), 1 deletion(-) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 21ededc5a..3f33ca5b9 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -28,6 +28,7 @@ class Codes(object): UNKNOWN = "M_UNKNOWN" NOT_FOUND = "M_NOT_FOUND" UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN" + LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED" class CodeMessageException(Exception): diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 606c9c650..8a7cd07fe 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -247,6 +247,7 @@ def setup(): upload_dir=os.path.abspath("uploads"), db_name=config.database_path, tls_context_factory=tls_context_factory, + config=config, ) hs.register_servlets() diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 18072e319..a9aa4c735 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -17,8 +17,10 @@ from .tls import TlsConfig from .server import ServerConfig from .logger import LoggingConfig from .database import DatabaseConfig +from .ratelimiting import RatelimitConfig -class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig): +class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, + RatelimitConfig): pass if __name__=='__main__': diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index b37c8be96..dc1298366 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -14,6 +14,7 @@ # limitations under the License. from twisted.internet import defer +from synapse.api.errors import cs_error, Codes class BaseHandler(object): @@ -25,8 +26,24 @@ class BaseHandler(object): self.room_lock = hs.get_room_lock_manager() self.state_handler = hs.get_state_handler() self.distributor = hs.get_distributor() + self.ratelimiter = hs.get_ratelimiter() + self.clock = hs.get_clock() self.hs = hs + def ratelimit(self, user_id): + time_now = self.clock.time() + allowed, time_allowed = self.ratelimiter.send_message( + user_id, time_now, + msg_rate_hz=self.hs.config.rc_messages_per_second, + burst_count=self.hs.config.rc_messsage_burst_count, + ) + if not allowed: + raise cs_error( + "Limit exceeded", + Codes.M_LIMIT_EXCEEDED, + retry_after_ms=1000*(time_allowed - time_now), + ) + class BaseRoomHandler(BaseHandler): diff --git a/synapse/server.py b/synapse/server.py index 3e72b2bcd..35e311a47 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -32,6 +32,7 @@ from synapse.util import Clock from synapse.util.distributor import Distributor from synapse.util.lockutils import LockManager from synapse.streams.events import EventSources +from synapse.api.ratelimiting import Ratelimiter class BaseHomeServer(object): @@ -73,6 +74,7 @@ class BaseHomeServer(object): 'resource_for_web_client', 'resource_for_content_repo', 'event_sources', + 'ratelimiter', ] def __init__(self, hostname, **kwargs): @@ -190,6 +192,9 @@ class HomeServer(BaseHomeServer): def build_event_sources(self): return EventSources(self) + def build_ratelimiter(self): + return Ratelimiter() + def register_servlets(self): """ Register all servlets associated with this HomeServer. """