Remove the deprecated BaseHandler. (#11005)

The shared ratelimit function was replaced with a dedicated
RequestRatelimiter class (accessible from the HomeServer
object).

Other properties were copied to each sub-class that inherited
from BaseHandler.
This commit is contained in:
Patrick Cloke 2021-10-08 07:44:43 -04:00 committed by GitHub
parent 49a683d871
commit eb9ddc8c2e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 166 additions and 215 deletions

1
changelog.d/11005.misc Normal file
View File

@ -0,0 +1 @@
Remove the deprecated `BaseHandler` object.

View File

@ -17,6 +17,7 @@ from collections import OrderedDict
from typing import Hashable, Optional, Tuple from typing import Hashable, Optional, Tuple
from synapse.api.errors import LimitExceededError from synapse.api.errors import LimitExceededError
from synapse.config.ratelimiting import RateLimitConfig
from synapse.storage.databases.main import DataStore from synapse.storage.databases.main import DataStore
from synapse.types import Requester from synapse.types import Requester
from synapse.util import Clock from synapse.util import Clock
@ -233,3 +234,88 @@ class Ratelimiter:
raise LimitExceededError( raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now_s)) retry_after_ms=int(1000 * (time_allowed - time_now_s))
) )
class RequestRatelimiter:
def __init__(
self,
store: DataStore,
clock: Clock,
rc_message: RateLimitConfig,
rc_admin_redaction: Optional[RateLimitConfig],
):
self.store = store
self.clock = clock
# The rate_hz and burst_count are overridden on a per-user basis
self.request_ratelimiter = Ratelimiter(
store=self.store, clock=self.clock, rate_hz=0, burst_count=0
)
self._rc_message = rc_message
# Check whether ratelimiting room admin message redaction is enabled
# by the presence of rate limits in the config
if rc_admin_redaction:
self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=rc_admin_redaction.per_second,
burst_count=rc_admin_redaction.burst_count,
)
else:
self.admin_redaction_ratelimiter = None
async def ratelimit(
self,
requester: Requester,
update: bool = True,
is_admin_redaction: bool = False,
) -> None:
"""Ratelimits requests.
Args:
requester
update: Whether to record that a request is being processed.
Set to False when doing multiple checks for one request (e.g.
to check up front if we would reject the request), and set to
True for the last call for a given request.
is_admin_redaction: Whether this is a room admin/moderator
redacting an event. If so then we may apply different
ratelimits depending on config.
Raises:
LimitExceededError if the request should be ratelimited
"""
user_id = requester.user.to_string()
# The AS user itself is never rate limited.
app_service = self.store.get_app_service_by_user_id(user_id)
if app_service is not None:
return # do not ratelimit app service senders
messages_per_second = self._rc_message.per_second
burst_count = self._rc_message.burst_count
# Check if there is a per user override in the DB.
override = await self.store.get_ratelimit_for_user(user_id)
if override:
# If overridden with a null Hz then ratelimiting has been entirely
# disabled for the user
if not override.messages_per_second:
return
messages_per_second = override.messages_per_second
burst_count = override.burst_count
if is_admin_redaction and self.admin_redaction_ratelimiter:
# If we have separate config for admin redactions, use a separate
# ratelimiter as to not have user_ids clash
await self.admin_redaction_ratelimiter.ratelimit(requester, update=update)
else:
# Override rate and burst count per-user
await self.request_ratelimiter.ratelimit(
requester,
rate_hz=messages_per_second,
burst_count=burst_count,
update=update,
)

View File

@ -1,120 +0,0 @@
# Copyright 2014 - 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Optional
from synapse.api.ratelimiting import Ratelimiter
from synapse.types import Requester
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class BaseHandler:
"""
Common base class for the event handlers.
Deprecated: new code should not use this. Instead, Handler classes should define the
fields they actually need. The utility methods should either be factored out to
standalone helper functions, or to different Handler classes.
"""
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
self.state_handler = hs.get_state_handler()
self.distributor = hs.get_distributor()
self.clock = hs.get_clock()
self.hs = hs
# The rate_hz and burst_count are overridden on a per-user basis
self.request_ratelimiter = Ratelimiter(
store=self.store, clock=self.clock, rate_hz=0, burst_count=0
)
self._rc_message = self.hs.config.ratelimiting.rc_message
# Check whether ratelimiting room admin message redaction is enabled
# by the presence of rate limits in the config
if self.hs.config.ratelimiting.rc_admin_redaction:
self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=self.hs.config.ratelimiting.rc_admin_redaction.per_second,
burst_count=self.hs.config.ratelimiting.rc_admin_redaction.burst_count,
)
else:
self.admin_redaction_ratelimiter = None
self.server_name = hs.hostname
self.event_builder_factory = hs.get_event_builder_factory()
async def ratelimit(
self,
requester: Requester,
update: bool = True,
is_admin_redaction: bool = False,
) -> None:
"""Ratelimits requests.
Args:
requester
update: Whether to record that a request is being processed.
Set to False when doing multiple checks for one request (e.g.
to check up front if we would reject the request), and set to
True for the last call for a given request.
is_admin_redaction: Whether this is a room admin/moderator
redacting an event. If so then we may apply different
ratelimits depending on config.
Raises:
LimitExceededError if the request should be ratelimited
"""
user_id = requester.user.to_string()
# The AS user itself is never rate limited.
app_service = self.store.get_app_service_by_user_id(user_id)
if app_service is not None:
return # do not ratelimit app service senders
messages_per_second = self._rc_message.per_second
burst_count = self._rc_message.burst_count
# Check if there is a per user override in the DB.
override = await self.store.get_ratelimit_for_user(user_id)
if override:
# If overridden with a null Hz then ratelimiting has been entirely
# disabled for the user
if not override.messages_per_second:
return
messages_per_second = override.messages_per_second
burst_count = override.burst_count
if is_admin_redaction and self.admin_redaction_ratelimiter:
# If we have separate config for admin redactions, use a separate
# ratelimiter as to not have user_ids clash
await self.admin_redaction_ratelimiter.ratelimit(requester, update=update)
else:
# Override rate and burst count per-user
await self.request_ratelimiter.ratelimit(
requester,
rate_hz=messages_per_second,
burst_count=burst_count,
update=update,
)

View File

@ -21,18 +21,15 @@ from synapse.events import EventBase
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AdminHandler(BaseHandler): class AdminHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) self.store = hs.get_datastore()
self.storage = hs.get_storage() self.storage = hs.get_storage()
self.state_store = self.storage.state self.state_store = self.storage.state

View File

@ -52,7 +52,6 @@ from synapse.api.errors import (
UserDeactivatedError, UserDeactivatedError,
) )
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.handlers._base import BaseHandler
from synapse.handlers.ui_auth import ( from synapse.handlers.ui_auth import (
INTERACTIVE_AUTH_CHECKERS, INTERACTIVE_AUTH_CHECKERS,
UIAuthSessionDataConstants, UIAuthSessionDataConstants,
@ -186,12 +185,13 @@ class LoginTokenAttributes:
auth_provider_id = attr.ib(type=str) auth_provider_id = attr.ib(type=str)
class AuthHandler(BaseHandler): class AuthHandler:
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.checkers: Dict[str, UserInteractiveAuthChecker] = {} self.checkers: Dict[str, UserInteractiveAuthChecker] = {}
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS: for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
inst = auth_checker_class(hs) inst = auth_checker_class(hs)

View File

@ -19,19 +19,17 @@ from synapse.api.errors import SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import Requester, UserID, create_requester from synapse.types import Requester, UserID, create_requester
from ._base import BaseHandler
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DeactivateAccountHandler(BaseHandler): class DeactivateAccountHandler:
"""Handler which deals with deactivating user accounts.""" """Handler which deals with deactivating user accounts."""
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) self.store = hs.get_datastore()
self.hs = hs self.hs = hs
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler() self._device_handler = hs.get_device_handler()

View File

@ -40,8 +40,6 @@ from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
from ._base import BaseHandler
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -50,14 +48,16 @@ logger = logging.getLogger(__name__)
MAX_DEVICE_DISPLAY_NAME_LEN = 100 MAX_DEVICE_DISPLAY_NAME_LEN = 100
class DeviceWorkerHandler(BaseHandler): class DeviceWorkerHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) self.clock = hs.get_clock()
self.hs = hs self.hs = hs
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.state_store = hs.get_storage().state self.state_store = hs.get_storage().state
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self.server_name = hs.hostname
@trace @trace
async def get_devices_by_user(self, user_id: str) -> List[JsonDict]: async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:

View File

@ -31,18 +31,16 @@ from synapse.appservice import ApplicationService
from synapse.storage.databases.main.directory import RoomAliasMapping from synapse.storage.databases.main.directory import RoomAliasMapping
from synapse.types import JsonDict, Requester, RoomAlias, UserID, get_domain_from_id from synapse.types import JsonDict, Requester, RoomAlias, UserID, get_domain_from_id
from ._base import BaseHandler
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DirectoryHandler(BaseHandler): class DirectoryHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) self.auth = hs.get_auth()
self.hs = hs
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.appservice_handler = hs.get_application_service_handler() self.appservice_handler = hs.get_application_service_handler()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
@ -51,6 +49,7 @@ class DirectoryHandler(BaseHandler):
self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
self.require_membership = hs.config.server.require_membership_for_aliases self.require_membership = hs.config.server.require_membership_for_aliases
self.third_party_event_rules = hs.get_third_party_event_rules() self.third_party_event_rules = hs.get_third_party_event_rules()
self.server_name = hs.hostname
self.federation = hs.get_federation_client() self.federation = hs.get_federation_client()
hs.get_federation_registry().register_query_handler( hs.get_federation_registry().register_query_handler(

View File

@ -25,8 +25,6 @@ from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, UserID
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -34,11 +32,11 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class EventStreamHandler(BaseHandler): class EventStreamHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) self.store = hs.get_datastore()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.hs = hs
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
@ -138,9 +136,9 @@ class EventStreamHandler(BaseHandler):
return chunk return chunk
class EventHandler(BaseHandler): class EventHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) self.store = hs.get_datastore()
self.storage = hs.get_storage() self.storage = hs.get_storage()
async def get_event( async def get_event(

View File

@ -53,7 +53,6 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.federation.federation_client import InvalidResponseError from synapse.federation.federation_client import InvalidResponseError
from synapse.handlers._base import BaseHandler
from synapse.http.servlet import assert_params_in_dict from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import ( from synapse.logging.context import (
make_deferred_yieldable, make_deferred_yieldable,
@ -78,15 +77,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class FederationHandler(BaseHandler): class FederationHandler:
"""Handles general incoming federation requests """Handles general incoming federation requests
Incoming events are *not* handled here, for which see FederationEventHandler. Incoming events are *not* handled here, for which see FederationEventHandler.
""" """
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -99,6 +96,7 @@ class FederationHandler(BaseHandler):
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.spam_checker = hs.get_spam_checker() self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.event_builder_factory = hs.get_event_builder_factory()
self._event_auth_handler = hs.get_event_auth_handler() self._event_auth_handler = hs.get_event_auth_handler()
self._server_notices_mxid = hs.config.servernotices.server_notices_mxid self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
self.config = hs.config self.config = hs.config

View File

@ -39,8 +39,6 @@ from synapse.util.stringutils import (
valid_id_server_location, valid_id_server_location,
) )
from ._base import BaseHandler
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -49,10 +47,9 @@ logger = logging.getLogger(__name__)
id_server_scheme = "https://" id_server_scheme = "https://"
class IdentityHandler(BaseHandler): class IdentityHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) self.store = hs.get_datastore()
# An HTTP client for contacting trusted URLs. # An HTTP client for contacting trusted URLs.
self.http_client = SimpleHttpClient(hs) self.http_client = SimpleHttpClient(hs)
# An HTTP client for contacting identity servers specified by clients. # An HTTP client for contacting identity servers specified by clients.

View File

@ -31,8 +31,6 @@ from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -40,9 +38,11 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class InitialSyncHandler(BaseHandler): class InitialSyncHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.state_handler = hs.get_state_handler()
self.hs = hs self.hs = hs
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.clock = hs.get_clock() self.clock = hs.get_clock()

View File

@ -62,8 +62,6 @@ from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.events.third_party_rules import ThirdPartyEventRules from synapse.events.third_party_rules import ThirdPartyEventRules
from synapse.server import HomeServer from synapse.server import HomeServer
@ -433,8 +431,7 @@ class EventCreationHandler:
self.send_event = ReplicationSendEventRestServlet.make_client(hs) self.send_event = ReplicationSendEventRestServlet.make_client(hs)
# This is only used to get at ratelimit function self.request_ratelimiter = hs.get_request_ratelimiter()
self.base_handler = BaseHandler(hs)
# We arbitrarily limit concurrent event creation for a room to 5. # We arbitrarily limit concurrent event creation for a room to 5.
# This is to stop us from diverging history *too* much. # This is to stop us from diverging history *too* much.
@ -1322,7 +1319,7 @@ class EventCreationHandler:
original_event and event.sender != original_event.sender original_event and event.sender != original_event.sender
) )
await self.base_handler.ratelimit( await self.request_ratelimiter.ratelimit(
requester, is_admin_redaction=is_admin_redaction requester, is_admin_redaction=is_admin_redaction
) )

View File

@ -32,8 +32,6 @@ from synapse.types import (
get_domain_from_id, get_domain_from_id,
) )
from ._base import BaseHandler
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -43,7 +41,7 @@ MAX_DISPLAYNAME_LEN = 256
MAX_AVATAR_URL_LEN = 1000 MAX_AVATAR_URL_LEN = 1000
class ProfileHandler(BaseHandler): class ProfileHandler:
"""Handles fetching and updating user profile information. """Handles fetching and updating user profile information.
ProfileHandler can be instantiated directly on workers and will ProfileHandler can be instantiated directly on workers and will
@ -54,7 +52,9 @@ class ProfileHandler(BaseHandler):
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.hs = hs
self.federation = hs.get_federation_client() self.federation = hs.get_federation_client()
hs.get_federation_registry().register_query_handler( hs.get_federation_registry().register_query_handler(
@ -62,6 +62,7 @@ class ProfileHandler(BaseHandler):
) )
self.user_directory_handler = hs.get_user_directory_handler() self.user_directory_handler = hs.get_user_directory_handler()
self.request_ratelimiter = hs.get_request_ratelimiter()
if hs.config.worker.run_background_tasks: if hs.config.worker.run_background_tasks:
self.clock.looping_call( self.clock.looping_call(
@ -346,7 +347,7 @@ class ProfileHandler(BaseHandler):
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
return return
await self.ratelimit(requester) await self.request_ratelimiter.ratelimit(requester)
# Do not actually update the room state for shadow-banned users. # Do not actually update the room state for shadow-banned users.
if requester.shadow_banned: if requester.shadow_banned:

View File

@ -17,17 +17,14 @@ from typing import TYPE_CHECKING
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from ._base import BaseHandler
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ReadMarkerHandler(BaseHandler): class ReadMarkerHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.server_name = hs.config.server.server_name self.server_name = hs.config.server.server_name
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.account_data_handler = hs.get_account_data_handler() self.account_data_handler = hs.get_account_data_handler()

View File

@ -16,7 +16,6 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from synapse.api.constants import ReadReceiptEventFields from synapse.api.constants import ReadReceiptEventFields
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.handlers._base import BaseHandler
from synapse.streams import EventSource from synapse.streams import EventSource
from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
@ -26,10 +25,9 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ReceiptsHandler(BaseHandler): class ReceiptsHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) self.notifier = hs.get_notifier()
self.server_name = hs.config.server.server_name self.server_name = hs.config.server.server_name
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.event_auth_handler = hs.get_event_auth_handler() self.event_auth_handler = hs.get_event_auth_handler()

View File

@ -41,8 +41,6 @@ from synapse.spam_checker_api import RegistrationBehaviour
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID, create_requester from synapse.types import RoomAlias, UserID, create_requester
from ._base import BaseHandler
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -85,9 +83,10 @@ class LoginDict(TypedDict):
refresh_token: Optional[str] refresh_token: Optional[str]
class RegistrationHandler(BaseHandler): class RegistrationHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
@ -515,7 +514,7 @@ class RegistrationHandler(BaseHandler):
# we don't have a local user in the room to craft up an invite with. # we don't have a local user in the room to craft up an invite with.
requires_invite = await self.store.is_host_joined( requires_invite = await self.store.is_host_joined(
room_id, room_id,
self.server_name, self._server_name,
) )
if requires_invite: if requires_invite:

View File

@ -76,8 +76,6 @@ from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import parse_and_validate_server_name from synapse.util.stringutils import parse_and_validate_server_name
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -88,15 +86,18 @@ id_server_scheme = "https://"
FIVE_MINUTES_IN_MS = 5 * 60 * 1000 FIVE_MINUTES_IN_MS = 5 * 60 * 1000
class RoomCreationHandler(BaseHandler): class RoomCreationHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.hs = hs
self.spam_checker = hs.get_spam_checker() self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler() self.room_member_handler = hs.get_room_member_handler()
self._event_auth_handler = hs.get_event_auth_handler() self._event_auth_handler = hs.get_event_auth_handler()
self.config = hs.config self.config = hs.config
self.request_ratelimiter = hs.get_request_ratelimiter()
# Room state based off defined presets # Room state based off defined presets
self._presets_dict: Dict[str, Dict[str, Any]] = { self._presets_dict: Dict[str, Dict[str, Any]] = {
@ -162,7 +163,7 @@ class RoomCreationHandler(BaseHandler):
Raises: Raises:
ShadowBanError if the requester is shadow-banned. ShadowBanError if the requester is shadow-banned.
""" """
await self.ratelimit(requester) await self.request_ratelimiter.ratelimit(requester)
user_id = requester.user.to_string() user_id = requester.user.to_string()
@ -665,7 +666,7 @@ class RoomCreationHandler(BaseHandler):
raise SynapseError(403, "You are not permitted to create rooms") raise SynapseError(403, "You are not permitted to create rooms")
if ratelimit: if ratelimit:
await self.ratelimit(requester) await self.request_ratelimiter.ratelimit(requester)
room_version_id = config.get( room_version_id = config.get(
"room_version", self.config.server.default_room_version.identifier "room_version", self.config.server.default_room_version.identifier

View File

@ -36,8 +36,6 @@ from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from ._base import BaseHandler
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -49,9 +47,10 @@ REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
class RoomListHandler(BaseHandler): class RoomListHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) self.store = hs.get_datastore()
self.hs = hs
self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
self.response_cache: ResponseCache[ self.response_cache: ResponseCache[
Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]] Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]]

View File

@ -51,8 +51,6 @@ from synapse.types import (
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_left_room from synapse.util.distributor import user_left_room
from ._base import BaseHandler
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -118,9 +116,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count, burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
) )
# This is only used to get at the ratelimit function. It's fine there are self.request_ratelimiter = hs.get_request_ratelimiter()
# multiple of these as it doesn't store state.
self.base_handler = BaseHandler(hs)
@abc.abstractmethod @abc.abstractmethod
async def _remote_join( async def _remote_join(
@ -1275,7 +1271,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# We need to rate limit *before* we send out any 3PID invites, so we # We need to rate limit *before* we send out any 3PID invites, so we
# can't just rely on the standard ratelimiting of events. # can't just rely on the standard ratelimiting of events.
await self.base_handler.ratelimit(requester) await self.request_ratelimiter.ratelimit(requester)
can_invite = await self.third_party_event_rules.check_threepid_can_be_invited( can_invite = await self.third_party_event_rules.check_threepid_can_be_invited(
medium, address, room_id medium, address, room_id

View File

@ -22,7 +22,6 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException, UserAttributes from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.servlet import parse_string from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
@ -51,9 +50,11 @@ class Saml2SessionData:
ui_auth_session_id: Optional[str] = None ui_auth_session_id: Optional[str] = None
class SamlHandler(BaseHandler): class SamlHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.server_name = hs.hostname
self._saml_client = Saml2Client(hs.config.saml2.saml2_sp_config) self._saml_client = Saml2Client(hs.config.saml2.saml2_sp_config)
self._saml_idp_entityid = hs.config.saml2.saml2_idp_entityid self._saml_idp_entityid = hs.config.saml2.saml2_idp_entityid

View File

@ -26,17 +26,18 @@ from synapse.storage.state import StateFilter
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, UserID
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SearchHandler(BaseHandler): class SearchHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) self.store = hs.get_datastore()
self.state_handler = hs.get_state_handler()
self.clock = hs.get_clock()
self.hs = hs
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage() self.storage = hs.get_storage()
self.state_store = self.storage.state self.state_store = self.storage.state

View File

@ -17,19 +17,17 @@ from typing import TYPE_CHECKING, Optional
from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.types import Requester from synapse.types import Requester
from ._base import BaseHandler
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SetPasswordHandler(BaseHandler): class SetPasswordHandler:
"""Handler which deals with changing user account passwords""" """Handler which deals with changing user account passwords"""
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs) self.store = hs.get_datastore()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler() self._device_handler = hs.get_device_handler()

View File

@ -39,7 +39,7 @@ from twisted.web.resource import IResource
from synapse.api.auth import Auth from synapse.api.auth import Auth
from synapse.api.filtering import Filtering from synapse.api.filtering import Filtering
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter, RequestRatelimiter
from synapse.appservice.api import ApplicationServiceApi from synapse.appservice.api import ApplicationServiceApi
from synapse.appservice.scheduler import ApplicationServiceScheduler from synapse.appservice.scheduler import ApplicationServiceScheduler
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
@ -816,3 +816,12 @@ class HomeServer(metaclass=abc.ABCMeta):
def should_send_federation(self) -> bool: def should_send_federation(self) -> bool:
"Should this server be sending federation traffic directly?" "Should this server be sending federation traffic directly?"
return self.config.worker.send_federation return self.config.worker.send_federation
@cache_in_self
def get_request_ratelimiter(self) -> RequestRatelimiter:
return RequestRatelimiter(
self.get_datastore(),
self.get_clock(),
self.config.ratelimiting.rc_message,
self.config.ratelimiting.rc_admin_redaction,
)