diff --git a/changelog.d/8957.feature b/changelog.d/8957.feature new file mode 100644 index 000000000..fa8961f84 --- /dev/null +++ b/changelog.d/8957.feature @@ -0,0 +1 @@ +Add rate limiters to cross-user key sharing requests. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index af8d59cf8..691f8f9ad 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -98,11 +98,14 @@ class EventTypes: Retention = "m.room.retention" - Presence = "m.presence" - Dummy = "org.matrix.dummy_event" +class EduTypes: + Presence = "m.presence" + RoomKeyRequest = "m.room_key_request" + + class RejectedReason: AUTH_ERROR = "auth_error" diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 5d9d5a228..c3f07bc1a 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -14,7 +14,7 @@ # limitations under the License. from collections import OrderedDict -from typing import Any, Optional, Tuple +from typing import Hashable, Optional, Tuple from synapse.api.errors import LimitExceededError from synapse.types import Requester @@ -42,7 +42,9 @@ class Ratelimiter: # * How many times an action has occurred since a point in time # * The point in time # * The rate_hz of this particular entry. This can vary per request - self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int, float]] + self.actions = ( + OrderedDict() + ) # type: OrderedDict[Hashable, Tuple[float, int, float]] def can_requester_do_action( self, @@ -82,7 +84,7 @@ class Ratelimiter: def can_do_action( self, - key: Any, + key: Hashable, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, update: bool = True, @@ -175,7 +177,7 @@ class Ratelimiter: def ratelimit( self, - key: Any, + key: Hashable, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, update: bool = True, diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index def33a60a..847d25122 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -102,6 +102,16 @@ class RatelimitConfig(Config): defaults={"per_second": 0.01, "burst_count": 3}, ) + # Ratelimit cross-user key requests: + # * For local requests this is keyed by the sending device. + # * For requests received over federation this is keyed by the origin. + # + # Note that this isn't exposed in the configuration as it is obscure. + self.rc_key_requests = RateLimitConfig( + config.get("rc_key_requests", {}), + defaults={"per_second": 20, "burst_count": 100}, + ) + self.rc_3pid_validation = RateLimitConfig( config.get("rc_3pid_validation") or {}, defaults={"per_second": 0.003, "burst_count": 5}, diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 8d4bb621e..2f832b47f 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -34,7 +34,7 @@ from twisted.internet import defer from twisted.internet.abstract import isIPAddress from twisted.python import failure -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EduTypes, EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, @@ -44,6 +44,7 @@ from synapse.api.errors import ( SynapseError, UnsupportedRoomVersionError, ) +from synapse.api.ratelimiting import Ratelimiter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.federation.federation_base import FederationBase, event_from_pdu_json @@ -869,6 +870,13 @@ class FederationHandlerRegistry: # EDU received. self._edu_type_to_instance = {} # type: Dict[str, List[str]] + # A rate limiter for incoming room key requests per origin. + self._room_key_request_rate_limiter = Ratelimiter( + clock=self.clock, + rate_hz=self.config.rc_key_requests.per_second, + burst_count=self.config.rc_key_requests.burst_count, + ) + def register_edu_handler( self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]] ): @@ -917,7 +925,15 @@ class FederationHandlerRegistry: self._edu_type_to_instance[edu_type] = instance_names async def on_edu(self, edu_type: str, origin: str, content: dict): - if not self.config.use_presence and edu_type == "m.presence": + if not self.config.use_presence and edu_type == EduTypes.Presence: + return + + # If the incoming room key requests from a particular origin are over + # the limit, drop them. + if ( + edu_type == EduTypes.RoomKeyRequest + and not self._room_key_request_rate_limiter.can_do_action(origin) + ): return # Check if we have a handler on this instance diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 1aa7d803b..7db4f4896 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -16,7 +16,9 @@ import logging from typing import TYPE_CHECKING, Any, Dict +from synapse.api.constants import EduTypes from synapse.api.errors import SynapseError +from synapse.api.ratelimiting import Ratelimiter from synapse.logging.context import run_in_background from synapse.logging.opentracing import ( get_active_span_text_map, @@ -25,7 +27,7 @@ from synapse.logging.opentracing import ( start_active_span, ) from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet -from synapse.types import JsonDict, UserID, get_domain_from_id +from synapse.types import JsonDict, Requester, UserID, get_domain_from_id from synapse.util import json_encoder from synapse.util.stringutils import random_string @@ -78,6 +80,12 @@ class DeviceMessageHandler: ReplicationUserDevicesResyncRestServlet.make_client(hs) ) + self._ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=hs.config.rc_key_requests.per_second, + burst_count=hs.config.rc_key_requests.burst_count, + ) + async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: local_messages = {} sender_user_id = content["sender"] @@ -168,15 +176,27 @@ class DeviceMessageHandler: async def send_device_message( self, - sender_user_id: str, + requester: Requester, message_type: str, messages: Dict[str, Dict[str, JsonDict]], ) -> None: + sender_user_id = requester.user.to_string() + set_tag("number_of_messages", len(messages)) set_tag("sender", sender_user_id) local_messages = {} remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]] for user_id, by_device in messages.items(): + # Ratelimit local cross-user key requests by the sending device. + if ( + message_type == EduTypes.RoomKeyRequest + and user_id != sender_user_id + and self._ratelimiter.can_do_action( + (sender_user_id, requester.device_id) + ) + ): + continue + # we use UserID.from_string to catch invalid user ids if self.is_mine(UserID.from_string(user_id)): messages_by_device = { diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 3e23f82cf..f46cab732 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -17,7 +17,7 @@ import logging import random from typing import TYPE_CHECKING, Iterable, List, Optional -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EduTypes, EventTypes, Membership from synapse.api.errors import AuthError, SynapseError from synapse.events import EventBase from synapse.handlers.presence import format_user_presence_state @@ -113,7 +113,7 @@ class EventStreamHandler(BaseHandler): states = await presence_handler.get_states(users) to_add.extend( { - "type": EventTypes.Presence, + "type": EduTypes.Presence, "content": format_user_presence_state(state, time_now), } for state in states diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 78c3e5a10..71a507667 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional, Tuple from twisted.internet import defer -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EduTypes, EventTypes, Membership from synapse.api.errors import SynapseError from synapse.events.validator import EventValidator from synapse.handlers.presence import format_user_presence_state @@ -412,7 +412,7 @@ class InitialSyncHandler(BaseHandler): return [ { - "type": EventTypes.Presence, + "type": EduTypes.Presence, "content": format_user_presence_state(s, time_now), } for s in states diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index a3dee14ed..79c1b526e 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -56,10 +56,8 @@ class SendToDeviceRestServlet(servlet.RestServlet): content = parse_json_object_from_request(request) assert_params_in_dict(content, ("messages",)) - sender_user_id = requester.user.to_string() - await self.device_message_handler.send_device_message( - sender_user_id, message_type, content["messages"] + requester, message_type, content["messages"] ) response = (200, {}) # type: Tuple[int, dict]