Merge remote-tracking branch 'origin/release-v1.20.0' into develop

Erik Johnston 2020-09-18 10:50:04 +01:00
commit 5e42e61609
4 changed files with 54 additions and 17 deletions

changelog.d/8342.bugfix (new file)

@@ -0,0 +1 @@
+Fix ratelimiting of federation `/send` requests.

synapse/federation/federation_server.py

@@ -97,10 +97,16 @@ class FederationServer(FederationBase):
         self.state = hs.get_state_handler()
         self.device_handler = hs.get_device_handler()

+        self._federation_ratelimiter = hs.get_federation_ratelimiter()
+
         self._server_linearizer = Linearizer("fed_server")
         self._transaction_linearizer = Linearizer("fed_txn_handler")

+        # We cache results for transaction with the same ID
+        self._transaction_resp_cache = ResponseCache(
+            hs, "fed_txn_handler", timeout_ms=30000
+        )
+
         self.transaction_actions = TransactionActions(self.store)

         self.registry = hs.get_federation_registry()
@@ -135,19 +141,41 @@ class FederationServer(FederationBase):
         request_time = self._clock.time_msec()

         transaction = Transaction(**transaction_data)
+        transaction_id = transaction.transaction_id  # type: ignore

-        if not transaction.transaction_id:  # type: ignore
+        if not transaction_id:
             raise Exception("Transaction missing transaction_id")

-        logger.debug("[%s] Got transaction", transaction.transaction_id)  # type: ignore
+        logger.debug("[%s] Got transaction", transaction_id)

-        # use a linearizer to ensure that we don't process the same transaction
-        # multiple times in parallel.
-        with (
-            await self._transaction_linearizer.queue(
-                (origin, transaction.transaction_id)  # type: ignore
-            )
-        ):
-            result = await self._handle_incoming_transaction(
-                origin, transaction, request_time
-            )
+        # We wrap in a ResponseCache so that we de-duplicate retried
+        # transactions.
+        return await self._transaction_resp_cache.wrap(
+            (origin, transaction_id),
+            self._on_incoming_transaction_inner,
+            origin,
+            transaction,
+            request_time,
+        )
+
+    async def _on_incoming_transaction_inner(
+        self, origin: str, transaction: Transaction, request_time: int
+    ) -> Tuple[int, Dict[str, Any]]:
+        # Use a linearizer to ensure that transactions from a remote are
+        # processed in order.
+        with await self._transaction_linearizer.queue(origin):
+            # We rate limit here *after* we've queued up the incoming requests,
+            # so that we don't fill up the ratelimiter with blocked requests.
+            #
+            # This is important as the ratelimiter allows N concurrent requests
+            # at a time, and only starts ratelimiting if there are more requests
+            # than that being processed at a time. If we queued up requests in
+            # the linearizer/response cache *after* the ratelimiting then those
+            # queued up requests would count as part of the allowed limit of N
+            # concurrent requests.
+            with self._federation_ratelimiter.ratelimit(origin) as d:
+                await d
+
+                result = await self._handle_incoming_transaction(
+                    origin, transaction, request_time
+                )
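
Note: taken together, the new path de-duplicates retried `/send` transactions via the response cache, serialises them per origin via the linearizer, and only then takes a ratelimiter slot, so queued transactions never count against the concurrency limit. The following standalone sketch illustrates that ordering with plain asyncio primitives; SimpleResponseCache, handle_send and the semaphore are illustrative stand-ins, not Synapse's real ResponseCache/Linearizer/FederationRateLimiter APIs.

# A minimal, illustrative sketch (not Synapse's real classes) of
# "de-duplicate, queue, then ratelimit".
import asyncio
from typing import Any, Awaitable, Callable, Dict, Tuple

class SimpleResponseCache:
    """Caches in-flight/completed results keyed by (origin, transaction_id)."""

    def __init__(self) -> None:
        self._entries: Dict[Tuple[str, str], Any] = {}

    def wrap(
        self, key: Tuple[str, str], callback: Callable[..., Awaitable[Any]], *args: Any
    ) -> Awaitable[Any]:
        if key not in self._entries:
            # First caller does the work; retries of the same transaction
            # share (and wait for) the same future.
            self._entries[key] = asyncio.ensure_future(callback(*args))
        return asyncio.shield(self._entries[key])

per_origin_locks: Dict[str, asyncio.Lock] = {}
ratelimiter = asyncio.Semaphore(3)  # stand-in for "N concurrent requests"
resp_cache = SimpleResponseCache()

async def handle_send(origin: str, txn_id: str) -> Tuple[int, dict]:
    async def inner() -> Tuple[int, dict]:
        # Queue in the per-origin "linearizer" *before* touching the
        # ratelimiter, so waiting transactions don't occupy a slot.
        lock = per_origin_locks.setdefault(origin, asyncio.Lock())
        async with lock:
            async with ratelimiter:
                await asyncio.sleep(0)  # stand-in for real processing
                return 200, {"pdus": {}}

    return await resp_cache.wrap((origin, txn_id), inner)

async def main() -> None:
    # A retried transaction with the same ID gets the cached response.
    first, retry = await asyncio.gather(
        handle_send("other.example", "txn1"),
        handle_send("other.example", "txn1"),
    )
    assert first == retry

asyncio.run(main())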

synapse/federation/transport/server.py

@@ -45,7 +45,6 @@ from synapse.logging.opentracing import (
 )
 from synapse.server import HomeServer
 from synapse.types import ThirdPartyInstanceID, get_domain_from_id
-from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.versionstring import get_version_string

 logger = logging.getLogger(__name__)
@@ -72,9 +71,7 @@ class TransportLayerServer(JsonResource):
         super(TransportLayerServer, self).__init__(hs, canonical_json=False)

         self.authenticator = Authenticator(hs)
-        self.ratelimiter = FederationRateLimiter(
-            self.clock, config=hs.config.rc_federation
-        )
+        self.ratelimiter = hs.get_federation_ratelimiter()

         self.register_servlets()
@@ -272,6 +269,8 @@ class BaseFederationServlet:
     PREFIX = FEDERATION_V1_PREFIX  # Allows specifying the API version

+    RATELIMIT = True  # Whether to rate limit requests or not
+
     def __init__(self, handler, authenticator, ratelimiter, server_name):
         self.handler = handler
         self.authenticator = authenticator
@@ -335,7 +334,7 @@ class BaseFederationServlet:
             )

             with scope:
-                if origin:
+                if origin and self.RATELIMIT:
                     with ratelimiter.ratelimit(origin) as d:
                         await d
                         if request._disconnected:
@@ -372,6 +371,10 @@ class BaseFederationServlet:
 class FederationSendServlet(BaseFederationServlet):
     PATH = "/send/(?P<transaction_id>[^/]*)/?"

+    # We ratelimit manually in the handler as we queue up the requests and we
+    # don't want to fill up the ratelimiter with blocked requests.
+    RATELIMIT = False
+
     def __init__(self, handler, server_name, **kwargs):
         super(FederationSendServlet, self).__init__(
             handler, server_name=server_name, **kwargs
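
Note: with RATELIMIT = False, /send skips the generic rate limiting in BaseFederationServlet and relies on the manual rate limiting inside FederationServer shown above; every other federation servlet keeps the default. A simplified sketch of that class-attribute opt-out follows (illustrative class names, not Synapse's real servlets).

# Illustrative sketch of a class-attribute opt-out; simplified stand-ins,
# not Synapse's actual servlet classes.
import asyncio

class BaseServletSketch:
    RATELIMIT = True  # default: the shared wrapper rate limits every request

    async def handle(self, origin: str) -> None:
        if self.RATELIMIT:
            print(f"ratelimiting {origin} before dispatch")
        await self.on_request(origin)

    async def on_request(self, origin: str) -> None:
        raise NotImplementedError

class SendServletSketch(BaseServletSketch):
    # /send queues transactions first, so it opts out here and rate limits
    # inside the handler instead (after the linearizer, as in the diff above).
    RATELIMIT = False

    async def on_request(self, origin: str) -> None:
        print(f"queued transaction from {origin}")

asyncio.run(SendServletSketch().handle("other.example"))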

synapse/server.py

@@ -114,6 +114,7 @@ from synapse.streams.events import EventSources
 from synapse.types import DomainSpecificString
 from synapse.util import Clock
 from synapse.util.distributor import Distributor
+from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.stringutils import random_string

 logger = logging.getLogger(__name__)
@@ -642,6 +643,10 @@ class HomeServer(metaclass=abc.ABCMeta):
     def get_replication_streams(self) -> Dict[str, Stream]:
         return {stream.NAME: stream(self) for stream in STREAMS_MAP.values()}

+    @cache_in_self
+    def get_federation_ratelimiter(self) -> FederationRateLimiter:
+        return FederationRateLimiter(self.clock, config=self.config.rc_federation)
+
     async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
         return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
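
Note: because get_federation_ratelimiter is memoised on the homeserver, FederationServer and TransportLayerServer now share a single FederationRateLimiter and therefore the same per-origin buckets. A minimal sketch of the memoisation pattern follows (an illustration of the idea, not Synapse's actual cache_in_self implementation).

# Minimal sketch of a cache_in_self-style decorator: build the dependency
# once, then hand every caller the same instance. Illustrative only.
import functools

def cache_in_self(builder):
    attr = "_cached_" + builder.__name__

    @functools.wraps(builder)
    def getter(self):
        if not hasattr(self, attr):
            setattr(self, attr, builder(self))
        return getattr(self, attr)

    return getter

class HomeServerSketch:
    @cache_in_self
    def get_federation_ratelimiter(self):
        return object()  # stand-in for FederationRateLimiter(clock, config=...)

hs = HomeServerSketch()
# FederationServer and TransportLayerServer would both call this and therefore
# share one limiter (and one set of per-origin buckets).
assert hs.get_federation_ratelimiter() is hs.get_federation_ratelimiter()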