Ratelimit cross-user key sharing requests. (#8957)

This commit is contained in:
Patrick Cloke 2021-02-19 13:20:34 -05:00 committed by GitHub
parent 179c0953ff
commit fc8b3d8809
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 67 additions and 17 deletions

1
changelog.d/8957.feature Normal file
View File

@ -0,0 +1 @@
Add rate limiters to cross-user key sharing requests.

View File

@ -98,11 +98,14 @@ class EventTypes:
Retention = "m.room.retention" Retention = "m.room.retention"
Presence = "m.presence"
Dummy = "org.matrix.dummy_event" Dummy = "org.matrix.dummy_event"
class EduTypes:
Presence = "m.presence"
RoomKeyRequest = "m.room_key_request"
class RejectedReason: class RejectedReason:
AUTH_ERROR = "auth_error" AUTH_ERROR = "auth_error"

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Optional, Tuple from typing import Hashable, Optional, Tuple
from synapse.api.errors import LimitExceededError from synapse.api.errors import LimitExceededError
from synapse.types import Requester from synapse.types import Requester
@ -42,7 +42,9 @@ class Ratelimiter:
# * How many times an action has occurred since a point in time # * How many times an action has occurred since a point in time
# * The point in time # * The point in time
# * The rate_hz of this particular entry. This can vary per request # * The rate_hz of this particular entry. This can vary per request
self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int, float]] self.actions = (
OrderedDict()
) # type: OrderedDict[Hashable, Tuple[float, int, float]]
def can_requester_do_action( def can_requester_do_action(
self, self,
@ -82,7 +84,7 @@ class Ratelimiter:
def can_do_action( def can_do_action(
self, self,
key: Any, key: Hashable,
rate_hz: Optional[float] = None, rate_hz: Optional[float] = None,
burst_count: Optional[int] = None, burst_count: Optional[int] = None,
update: bool = True, update: bool = True,
@ -175,7 +177,7 @@ class Ratelimiter:
def ratelimit( def ratelimit(
self, self,
key: Any, key: Hashable,
rate_hz: Optional[float] = None, rate_hz: Optional[float] = None,
burst_count: Optional[int] = None, burst_count: Optional[int] = None,
update: bool = True, update: bool = True,

View File

@ -102,6 +102,16 @@ class RatelimitConfig(Config):
defaults={"per_second": 0.01, "burst_count": 3}, defaults={"per_second": 0.01, "burst_count": 3},
) )
# Ratelimit cross-user key requests:
# * For local requests this is keyed by the sending device.
# * For requests received over federation this is keyed by the origin.
#
# Note that this isn't exposed in the configuration as it is obscure.
self.rc_key_requests = RateLimitConfig(
config.get("rc_key_requests", {}),
defaults={"per_second": 20, "burst_count": 100},
)
self.rc_3pid_validation = RateLimitConfig( self.rc_3pid_validation = RateLimitConfig(
config.get("rc_3pid_validation") or {}, config.get("rc_3pid_validation") or {},
defaults={"per_second": 0.003, "burst_count": 5}, defaults={"per_second": 0.003, "burst_count": 5},

View File

@ -34,7 +34,7 @@ from twisted.internet import defer
from twisted.internet.abstract import isIPAddress from twisted.internet.abstract import isIPAddress
from twisted.python import failure from twisted.python import failure
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EduTypes, EventTypes, Membership
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
Codes, Codes,
@ -44,6 +44,7 @@ from synapse.api.errors import (
SynapseError, SynapseError,
UnsupportedRoomVersionError, UnsupportedRoomVersionError,
) )
from synapse.api.ratelimiting import Ratelimiter
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase from synapse.events import EventBase
from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.federation_base import FederationBase, event_from_pdu_json
@ -869,6 +870,13 @@ class FederationHandlerRegistry:
# EDU received. # EDU received.
self._edu_type_to_instance = {} # type: Dict[str, List[str]] self._edu_type_to_instance = {} # type: Dict[str, List[str]]
# A rate limiter for incoming room key requests per origin.
self._room_key_request_rate_limiter = Ratelimiter(
clock=self.clock,
rate_hz=self.config.rc_key_requests.per_second,
burst_count=self.config.rc_key_requests.burst_count,
)
def register_edu_handler( def register_edu_handler(
self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]] self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
): ):
@ -917,7 +925,15 @@ class FederationHandlerRegistry:
self._edu_type_to_instance[edu_type] = instance_names self._edu_type_to_instance[edu_type] = instance_names
async def on_edu(self, edu_type: str, origin: str, content: dict): async def on_edu(self, edu_type: str, origin: str, content: dict):
if not self.config.use_presence and edu_type == "m.presence": if not self.config.use_presence and edu_type == EduTypes.Presence:
return
# If the incoming room key requests from a particular origin are over
# the limit, drop them.
if (
edu_type == EduTypes.RoomKeyRequest
and not self._room_key_request_rate_limiter.can_do_action(origin)
):
return return
# Check if we have a handler on this instance # Check if we have a handler on this instance

View File

@ -16,7 +16,9 @@
import logging import logging
from typing import TYPE_CHECKING, Any, Dict from typing import TYPE_CHECKING, Any, Dict
from synapse.api.constants import EduTypes
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.logging.context import run_in_background from synapse.logging.context import run_in_background
from synapse.logging.opentracing import ( from synapse.logging.opentracing import (
get_active_span_text_map, get_active_span_text_map,
@ -25,7 +27,7 @@ from synapse.logging.opentracing import (
start_active_span, start_active_span,
) )
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
@ -78,6 +80,12 @@ class DeviceMessageHandler:
ReplicationUserDevicesResyncRestServlet.make_client(hs) ReplicationUserDevicesResyncRestServlet.make_client(hs)
) )
self._ratelimiter = Ratelimiter(
clock=hs.get_clock(),
rate_hz=hs.config.rc_key_requests.per_second,
burst_count=hs.config.rc_key_requests.burst_count,
)
async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
local_messages = {} local_messages = {}
sender_user_id = content["sender"] sender_user_id = content["sender"]
@ -168,15 +176,27 @@ class DeviceMessageHandler:
async def send_device_message( async def send_device_message(
self, self,
sender_user_id: str, requester: Requester,
message_type: str, message_type: str,
messages: Dict[str, Dict[str, JsonDict]], messages: Dict[str, Dict[str, JsonDict]],
) -> None: ) -> None:
sender_user_id = requester.user.to_string()
set_tag("number_of_messages", len(messages)) set_tag("number_of_messages", len(messages))
set_tag("sender", sender_user_id) set_tag("sender", sender_user_id)
local_messages = {} local_messages = {}
remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]] remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
for user_id, by_device in messages.items(): for user_id, by_device in messages.items():
# Ratelimit local cross-user key requests by the sending device.
if (
message_type == EduTypes.RoomKeyRequest
and user_id != sender_user_id
and self._ratelimiter.can_do_action(
(sender_user_id, requester.device_id)
)
):
continue
# we use UserID.from_string to catch invalid user ids # we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)): if self.is_mine(UserID.from_string(user_id)):
messages_by_device = { messages_by_device = {

View File

@ -17,7 +17,7 @@ import logging
import random import random
from typing import TYPE_CHECKING, Iterable, List, Optional from typing import TYPE_CHECKING, Iterable, List, Optional
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EduTypes, EventTypes, Membership
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError
from synapse.events import EventBase from synapse.events import EventBase
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
@ -113,7 +113,7 @@ class EventStreamHandler(BaseHandler):
states = await presence_handler.get_states(users) states = await presence_handler.get_states(users)
to_add.extend( to_add.extend(
{ {
"type": EventTypes.Presence, "type": EduTypes.Presence,
"content": format_user_presence_state(state, time_now), "content": format_user_presence_state(state, time_now),
} }
for state in states for state in states

View File

@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional, Tuple
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EduTypes, EventTypes, Membership
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
@ -412,7 +412,7 @@ class InitialSyncHandler(BaseHandler):
return [ return [
{ {
"type": EventTypes.Presence, "type": EduTypes.Presence,
"content": format_user_presence_state(s, time_now), "content": format_user_presence_state(s, time_now),
} }
for s in states for s in states

View File

@ -56,10 +56,8 @@ class SendToDeviceRestServlet(servlet.RestServlet):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert_params_in_dict(content, ("messages",)) assert_params_in_dict(content, ("messages",))
sender_user_id = requester.user.to_string()
await self.device_message_handler.send_device_message( await self.device_message_handler.send_device_message(
sender_user_id, message_type, content["messages"] requester, message_type, content["messages"]
) )
response = (200, {}) # type: Tuple[int, dict] response = (200, {}) # type: Tuple[int, dict]