Merge branch 'release-v1.13.0' into develop

Richard van der Hoff 2020-05-06 15:56:03 +01:00
commit 62ee862119
14 changed files with 264 additions and 150 deletions

changelog.d/7420.misc (new file)

@@ -0,0 +1 @@
+Prevent methods in `synapse.handlers.auth` from polling the homeserver config every request.

changelog.d/7423.misc (new file)

@@ -0,0 +1 @@
+Speed up fetching device lists changes when handling `/sync` requests.

changelog.d/7427.feature (new file)

@@ -0,0 +1 @@
+Add support for running replication over Redis when using workers.

synapse/api/auth.py

@@ -26,16 +26,15 @@ from twisted.internet import defer
 import synapse.logging.opentracing as opentracing
 import synapse.types
 from synapse import event_auth
-from synapse.api.constants import EventTypes, LimitBlockingTypes, Membership, UserTypes
+from synapse.api.auth_blocking import AuthBlocking
+from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import (
     AuthError,
     Codes,
     InvalidClientTokenError,
     MissingClientTokenError,
-    ResourceLimitError,
 )
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
-from synapse.config.server import is_threepid_reserved
 from synapse.events import EventBase
 from synapse.types import StateMap, UserID
 from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
@@ -77,7 +76,11 @@ class Auth(object):
         self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000)
         register_cache("cache", "token_cache", self.token_cache)

+        self._auth_blocking = AuthBlocking(self.hs)
+
         self._account_validity = hs.config.account_validity
+        self._track_appservice_user_ips = hs.config.track_appservice_user_ips
+        self._macaroon_secret_key = hs.config.macaroon_secret_key

     @defer.inlineCallbacks
     def check_from_context(self, room_version: str, event, context, do_sig_check=True):
@@ -191,7 +194,7 @@ class Auth(object):
             opentracing.set_tag("authenticated_entity", user_id)
             opentracing.set_tag("appservice_id", app_service.id)

-            if ip_addr and self.hs.config.track_appservice_user_ips:
+            if ip_addr and self._track_appservice_user_ips:
                 yield self.store.insert_client_ip(
                     user_id=user_id,
                     access_token=access_token,
@@ -454,7 +457,7 @@ class Auth(object):
         # access_tokens include a nonce for uniqueness: any value is acceptable
         v.satisfy_general(lambda c: c.startswith("nonce = "))

-        v.verify(macaroon, self.hs.config.macaroon_secret_key)
+        v.verify(macaroon, self._macaroon_secret_key)

     def _verify_expiry(self, caveat):
         prefix = "time < "
@@ -663,71 +666,5 @@ class Auth(object):
                 % (user_id, room_id),
             )

-    @defer.inlineCallbacks
-    def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
-        """Checks if the user should be rejected for some external reason,
-        such as monthly active user limiting or global disable flag
-
-        Args:
-            user_id(str|None): If present, checks for presence against existing
-                MAU cohort
-
-            threepid(dict|None): If present, checks for presence against configured
-                reserved threepid. Used in cases where the user is trying register
-                with a MAU blocked server, normally they would be rejected but their
-                threepid is on the reserved list. user_id and
-                threepid should never be set at the same time.
-
-            user_type(str|None): If present, is used to decide whether to check against
-                certain blocking reasons like MAU.
-        """
-        # Never fail an auth check for the server notices users or support user
-        # This can be a problem where event creation is prohibited due to blocking
-        if user_id is not None:
-            if user_id == self.hs.config.server_notices_mxid:
-                return
-            if (yield self.store.is_support_user(user_id)):
-                return
-
-        if self.hs.config.hs_disabled:
-            raise ResourceLimitError(
-                403,
-                self.hs.config.hs_disabled_message,
-                errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
-                admin_contact=self.hs.config.admin_contact,
-                limit_type=LimitBlockingTypes.HS_DISABLED,
-            )
-        if self.hs.config.limit_usage_by_mau is True:
-            assert not (user_id and threepid)
-
-            # If the user is already part of the MAU cohort or a trial user
-            if user_id:
-                timestamp = yield self.store.user_last_seen_monthly_active(user_id)
-                if timestamp:
-                    return
-
-                is_trial = yield self.store.is_trial_user(user_id)
-                if is_trial:
-                    return
-            elif threepid:
-                # If the user does not exist yet, but is signing up with a
-                # reserved threepid then pass auth check
-                if is_threepid_reserved(
-                    self.hs.config.mau_limits_reserved_threepids, threepid
-                ):
-                    return
-            elif user_type == UserTypes.SUPPORT:
-                # If the user does not exist yet and is of type "support",
-                # allow registration. Support users are excluded from MAU checks.
-                return
-            # Else if there is no room in the MAU bucket, bail
-            current_mau = yield self.store.get_monthly_active_count()
-            if current_mau >= self.hs.config.max_mau_value:
-                raise ResourceLimitError(
-                    403,
-                    "Monthly Active User Limit Exceeded",
-                    admin_contact=self.hs.config.admin_contact,
-                    errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
-                    limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER,
-                )
+    def check_auth_blocking(self, *args, **kwargs):
+        return self._auth_blocking.check_auth_blocking(*args, **kwargs)

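The hunk above realises the changelog.d/7420.misc entry: configuration values are copied into instance attributes once, at construction time, and the blocking checks are delegated to a dedicated AuthBlocking object. A minimal standalone sketch of that idiom follows (the Config, Blocker, and Auth names here are illustrative, not Synapse's own):

    class Config:
        track_user_ips = True

    class Blocker:
        def __init__(self, config):
            # read the config exactly once, at construction time
            self._enabled = config.track_user_ips

        def check(self):
            if self._enabled:
                ...

    class Auth:
        def __init__(self, config):
            # per-request code now touches plain attributes rather than
            # re-reading the config object on every call
            self._blocker = Blocker(config)

        def check(self, *args, **kwargs):
            # thin forwarder, mirroring Auth.check_auth_blocking above
            return self._blocker.check(*args, **kwargs)
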
synapse/api/auth_blocking.py (new file)

@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.constants import LimitBlockingTypes, UserTypes
+from synapse.api.errors import Codes, ResourceLimitError
+from synapse.config.server import is_threepid_reserved
+
+logger = logging.getLogger(__name__)
+
+
+class AuthBlocking(object):
+    def __init__(self, hs):
+        self.store = hs.get_datastore()
+
+        self._server_notices_mxid = hs.config.server_notices_mxid
+        self._hs_disabled = hs.config.hs_disabled
+        self._hs_disabled_message = hs.config.hs_disabled_message
+        self._admin_contact = hs.config.admin_contact
+        self._max_mau_value = hs.config.max_mau_value
+        self._limit_usage_by_mau = hs.config.limit_usage_by_mau
+        self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
+
+    @defer.inlineCallbacks
+    def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
+        """Checks if the user should be rejected for some external reason,
+        such as monthly active user limiting or global disable flag
+
+        Args:
+            user_id(str|None): If present, checks for presence against existing
+                MAU cohort
+
+            threepid(dict|None): If present, checks for presence against configured
+                reserved threepid. Used in cases where the user is trying to register
+                with a MAU blocked server; normally they would be rejected, but their
+                threepid is on the reserved list. user_id and
+                threepid should never be set at the same time.
+
+            user_type(str|None): If present, is used to decide whether to check against
+                certain blocking reasons like MAU.
+        """
+        # Never fail an auth check for the server notices users or support user
+        # This can be a problem where event creation is prohibited due to blocking
+        if user_id is not None:
+            if user_id == self._server_notices_mxid:
+                return
+            if (yield self.store.is_support_user(user_id)):
+                return
+
+        if self._hs_disabled:
+            raise ResourceLimitError(
+                403,
+                self._hs_disabled_message,
+                errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
+                admin_contact=self._admin_contact,
+                limit_type=LimitBlockingTypes.HS_DISABLED,
+            )
+        if self._limit_usage_by_mau is True:
+            assert not (user_id and threepid)
+
+            # If the user is already part of the MAU cohort or a trial user
+            if user_id:
+                timestamp = yield self.store.user_last_seen_monthly_active(user_id)
+                if timestamp:
+                    return
+
+                is_trial = yield self.store.is_trial_user(user_id)
+                if is_trial:
+                    return
+            elif threepid:
+                # If the user does not exist yet, but is signing up with a
+                # reserved threepid then pass auth check
+                if is_threepid_reserved(self._mau_limits_reserved_threepids, threepid):
+                    return
+            elif user_type == UserTypes.SUPPORT:
+                # If the user does not exist yet and is of type "support",
+                # allow registration. Support users are excluded from MAU checks.
+                return
+            # Else if there is no room in the MAU bucket, bail
+            current_mau = yield self.store.get_monthly_active_count()
+            if current_mau >= self._max_mau_value:
+                raise ResourceLimitError(
+                    403,
+                    "Monthly Active User Limit Exceeded",
+                    admin_contact=self._admin_contact,
+                    errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
+                    limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER,
+                )

synapse/handlers/sync.py

@@ -1143,10 +1143,14 @@ class SyncHandler(object):
                 user_id
             )

-            tracked_users = set(users_who_share_room)
-
-            # Always tell the user about their own devices
-            tracked_users.add(user_id)
+            # Always tell the user about their own devices. We check as the user
+            # ID is almost certainly already included (unless they're not in any
+            # rooms) and taking a copy of the set is relatively expensive.
+            if user_id not in users_who_share_room:
+                users_who_share_room = set(users_who_share_room)
+                users_who_share_room.add(user_id)
+
+            tracked_users = users_who_share_room

             # Step 1a, check for changes in devices of users we share a room with
             users_that_have_changed = await self.store.get_users_whose_devices_changed(

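The hunk above replaces an unconditional set() copy with a membership check: the user is almost always already in users_who_share_room, so the (expensive, for large rooms) copy is skipped in the common case. The idiom in isolation, with illustrative names:

    def ensure_member(users: frozenset, user_id: str):
        """Return a collection containing user_id, copying only when needed."""
        if user_id in users:
            return users  # common case: already present, no copy
        updated = set(users)  # rare case: copy once, then mutate the copy
        updated.add(user_id)
        return updated

    assert ensure_member(frozenset({"@a:x"}), "@a:x") == frozenset({"@a:x"})
    assert "@b:x" in ensure_member(frozenset({"@a:x"}), "@b:x")
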
synapse/replication/tcp/handler.py

@@ -81,9 +81,6 @@ class ReplicationCommandHandler:
         self._instance_id = hs.get_instance_id()
         self._instance_name = hs.get_instance_name()

-        # Set of streams that we've caught up with.
-        self._streams_connected = set()  # type: Set[str]
-
         self._streams = {
             stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
         }  # type: Dict[str, Stream]
@@ -99,9 +96,13 @@ class ReplicationCommandHandler:
         # The factory used to create connections.
         self._factory = None  # type: Optional[ReconnectingClientFactory]

-        # The currently connected connections.
+        # The currently connected connections. (The list of places we need to send
+        # outgoing replication commands to.)
         self._connections = []  # type: List[AbstractConnection]

+        # For each connection, the incoming streams that are coming from that connection
+        self._streams_by_connection = {}  # type: Dict[AbstractConnection, Set[str]]
+
         LaterGauge(
             "synapse_replication_tcp_resource_total_connections",
             "",
@@ -257,9 +258,11 @@ class ReplicationCommandHandler:
         #   2. so we don't race with getting a POSITION command and fetching
         #      missing RDATA.
         with await self._position_linearizer.queue(cmd.stream_name):
-            if stream_name not in self._streams_connected:
-                # If the stream isn't marked as connected then we haven't seen a
-                # `POSITION` command yet, and so we may have missed some rows.
+            # make sure that we've processed a POSITION for this stream *on this
+            # connection*. (A POSITION on another connection is no good, as there
+            # is no guarantee that we have seen all the intermediate updates.)
+            sbc = self._streams_by_connection.get(conn)
+            if not sbc or stream_name not in sbc:
                 # Let's drop the row for now, on the assumption we'll receive a
                 # `POSITION` soon and we'll catch up correctly then.
                 logger.debug(
@@ -302,21 +305,25 @@ class ReplicationCommandHandler:
             # Ignore POSITION that are just our own echoes
             return

-        stream = self._streams.get(cmd.stream_name)
+        logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())
+
+        stream_name = cmd.stream_name
+        stream = self._streams.get(stream_name)
         if not stream:
-            logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
+            logger.error("Got POSITION for unknown stream: %s", stream_name)
             return

         # We protect catching up with a linearizer in case the replication
         # connection reconnects under us.
-        with await self._position_linearizer.queue(cmd.stream_name):
+        with await self._position_linearizer.queue(stream_name):
             # We're about to go and catch up with the stream, so remove from set
             # of connected streams.
-            self._streams_connected.discard(cmd.stream_name)
+            for streams in self._streams_by_connection.values():
+                streams.discard(stream_name)

             # We clear the pending batches for the stream as the fetching of the
             # missing updates below will fetch all rows in the batch.
-            self._pending_batches.pop(cmd.stream_name, [])
+            self._pending_batches.pop(stream_name, [])

             # Find where we previously streamed up to.
             current_token = stream.current_token()
@@ -326,6 +333,12 @@ class ReplicationCommandHandler:
             #   between then and now.
             missing_updates = cmd.token != current_token
             while missing_updates:
+                logger.info(
+                    "Fetching replication rows for '%s' between %i and %i",
+                    stream_name,
+                    current_token,
+                    cmd.token,
+                )
                 (
                     updates,
                     current_token,
@@ -341,16 +354,18 @@ class ReplicationCommandHandler:
                 for token, rows in _batch_updates(updates):
                     await self.on_rdata(
-                        cmd.stream_name,
+                        stream_name,
                         cmd.instance_name,
                         token,
                         [stream.parse_row(row) for row in rows],
                     )

+            logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
+
             # We've now caught up to position sent to us, notify handler.
-            await self._replication_data_handler.on_position(cmd.stream_name, cmd.token)
+            await self._replication_data_handler.on_position(stream_name, cmd.token)

-            self._streams_connected.add(cmd.stream_name)
+            self._streams_by_connection.setdefault(conn, set()).add(stream_name)

     async def on_REMOTE_SERVER_UP(
         self, conn: AbstractConnection, cmd: RemoteServerUpCommand
@@ -408,6 +423,12 @@ class ReplicationCommandHandler:
     def lost_connection(self, connection: AbstractConnection):
        """Called when a connection is closed/lost.
        """
+        # we no longer need _streams_by_connection for this connection.
+        streams = self._streams_by_connection.pop(connection, None)
+        if streams:
+            logger.info(
+                "Lost replication connection; streams now disconnected: %s", streams
+            )
         try:
             self._connections.remove(connection)
         except ValueError:

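The change above keys the "caught up" state by connection rather than globally: RDATA for a stream is only trusted once a POSITION for that same stream has been seen on the *same* connection, and the state is discarded when the connection drops. A self-contained sketch of that bookkeeping (illustrative names, not Synapse's actual class):

    from typing import Dict, Set

    class StreamTracker:
        def __init__(self) -> None:
            self._streams_by_connection: Dict[object, Set[str]] = {}

        def on_position(self, conn: object, stream_name: str) -> None:
            # a new POSITION invalidates any other connection's claim to
            # this stream, then records it against the current connection
            for streams in self._streams_by_connection.values():
                streams.discard(stream_name)
            self._streams_by_connection.setdefault(conn, set()).add(stream_name)

        def can_process_rdata(self, conn: object, stream_name: str) -> bool:
            # only trust rows arriving on a connection that has seen a
            # POSITION for this stream
            return stream_name in self._streams_by_connection.get(conn, set())

        def on_connection_lost(self, conn: object) -> None:
            self._streams_by_connection.pop(conn, None)
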
synapse/replication/tcp/redis.py

@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING
 import txredisapi

-from synapse.logging.context import PreserveLoggingContext
+from synapse.logging.context import make_deferred_yieldable
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.commands import (
     Command,
@@ -41,8 +41,14 @@ logger = logging.getLogger(__name__)
 class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
     """Connection to redis subscribed to replication stream.

-    Parses incoming messages from redis into replication commands, and passes
-    them to `ReplicationCommandHandler`
+    This class fulfils two functions:
+
+    (a) it implements the twisted Protocol API, where it handles the SUBSCRIBEd redis
+        connection, parsing *incoming* messages into replication commands, and passing them
+        to `ReplicationCommandHandler`
+
+    (b) it implements the AbstractConnection API, where it sends *outgoing* commands
+        onto outbound_redis_connection.

     Due to the vagaries of `txredisapi` we don't want to have a custom
     constructor, so instead we expect the defined attributes below to be set
@@ -50,8 +56,8 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):

     Attributes:
         handler: The command handler to handle incoming commands.
-        stream_name: The *redis* stream name to subscribe to (not anything to
-            do with Synapse replication streams).
+        stream_name: The *redis* stream name to subscribe to and publish from
+            (not anything to do with Synapse replication streams).
         outbound_redis_connection: The connection to redis to use to send
             commands.
     """
@@ -61,13 +67,23 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
     outbound_redis_connection = None  # type: txredisapi.RedisProtocol

     def connectionMade(self):
+        logger.info("Connected to redis")
         super().connectionMade()
-        logger.info("Connected to redis instance")
-        self.subscribe(self.stream_name)
-        self.send_command(ReplicateCommand())
-
+        run_as_background_process("subscribe-replication", self._send_subscribe)
         self.handler.new_connection(self)

+    async def _send_subscribe(self):
+        # it's important to make sure that we only send the REPLICATE command once we
+        # have successfully subscribed to the stream - otherwise we might miss the
+        # POSITION response sent back by the other end.
+        logger.info("Sending redis SUBSCRIBE for %s", self.stream_name)
+        await make_deferred_yieldable(self.subscribe(self.stream_name))
+        logger.info(
+            "Successfully subscribed to redis stream, sending REPLICATE command"
+        )
+        await self._async_send_command(ReplicateCommand())
+        logger.info("REPLICATE successfully sent")
+
     def messageReceived(self, pattern: str, channel: str, message: str):
         """Received a message from redis.
         """
@@ -120,8 +136,8 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
             logger.warning("Unhandled command: %r", cmd)

     def connectionLost(self, reason):
+        logger.info("Lost connection to redis")
         super().connectionLost(reason)
-        logger.info("Lost connection to redis instance")
         self.handler.lost_connection(self)

     def send_command(self, cmd: Command):
@@ -130,6 +146,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
         Args:
             cmd (Command)
         """
+        run_as_background_process("send-cmd", self._async_send_command, cmd)
+
+    async def _async_send_command(self, cmd: Command):
+        """Encode a replication command and send it over our outbound connection"""
         string = "%s %s" % (cmd.NAME, cmd.to_line())
         if "\n" in string:
             raise Exception("Unexpected newline in command: %r", string)
@@ -140,16 +160,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
         # remote instances.
         tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()

-        async def _send():
-            with PreserveLoggingContext():
-                # Note that we use the other connection as we can't send
-                # commands using the subscription connection.
-                await self.outbound_redis_connection.publish(
-                    self.stream_name, encoded_string
-                )
-
-        run_as_background_process("send-cmd", _send)
+        await make_deferred_yieldable(
+            self.outbound_redis_connection.publish(self.stream_name, encoded_string)
+        )


 class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
     """This is a reconnecting factory that connects to redis and immediately

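The ordering constraint that _send_subscribe enforces above — the SUBSCRIBE must complete before REPLICATE is sent, otherwise the POSITION replies could arrive on a channel we are not yet listening to — can be sketched with plain asyncio (subscribe and send_command here are stand-ins, not txredisapi or Synapse APIs):

    import asyncio

    async def start_replication(subscribe, send_command):
        # wait for the subscription to be live before asking for POSITIONs,
        # so none of the replies can be missed
        await subscribe("replication")
        await send_command("REPLICATE")

    async def main():
        async def subscribe(channel):
            print("subscribed to", channel)

        async def send_command(cmd):
            print("sent", cmd)

        await start_replication(subscribe, send_command)

    asyncio.run(main())
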
synapse/storage/data_stores/main/devices.py

@@ -536,8 +536,8 @@ class DeviceWorkerStore(SQLBaseStore):

         # Get set of users who *may* have changed. Users not in the returned
         # list have definitely not changed.
-        to_check = list(
-            self._device_list_stream_cache.get_entities_changed(user_ids, from_key)
+        to_check = self._device_list_stream_cache.get_entities_changed(
+            user_ids, from_key
         )

         if not to_check:

synapse/util/caches/stream_change_cache.py

@@ -14,12 +14,13 @@
 # limitations under the License.

 import logging
-from typing import Dict, Iterable, List, Mapping, Optional, Set
+from typing import Dict, FrozenSet, List, Mapping, Optional, Set, Union

 from six import integer_types
 from sortedcontainers import SortedDict

+from synapse.types import Collection
 from synapse.util import caches

 logger = logging.getLogger(__name__)
@@ -85,8 +86,8 @@ class StreamChangeCache:
         return False

     def get_entities_changed(
-        self, entities: Iterable[EntityType], stream_pos: int
-    ) -> Set[EntityType]:
+        self, entities: Collection[EntityType], stream_pos: int
+    ) -> Union[Set[EntityType], FrozenSet[EntityType]]:
         """
         Returns subset of entities that have had new things since the given
         position. Entities unknown to the cache will be returned. If the
@@ -94,7 +95,17 @@ class StreamChangeCache:
         """
         changed_entities = self.get_all_entities_changed(stream_pos)
         if changed_entities is not None:
-            result = set(changed_entities).intersection(entities)
+            # We now do an intersection, trying to do so in the most efficient
+            # way possible (some of these sets are *large*). First check if the
+            # given iterable is already a set that we can reuse, otherwise we
+            # create a set of the *smallest* of the two iterables and call
+            # `intersection(..)` on it (this can be twice as fast as the reverse).
+            if isinstance(entities, (set, frozenset)):
+                result = entities.intersection(changed_entities)
+            elif len(changed_entities) < len(entities):
+                result = set(changed_entities).intersection(entities)
+            else:
+                result = set(entities).intersection(changed_entities)

             self.metrics.inc_hits()
         else:
             result = set(entities)

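The same intersection strategy, extracted into a standalone function for clarity (a hypothetical helper, not part of Synapse): reuse the caller's set when one was passed, otherwise build the set from the smaller of the two operands before intersecting, since that avoids an unnecessary copy of the larger side.

    from typing import AbstractSet, Collection

    def fast_intersection(
        entities: Collection[str], changed: Collection[str]
    ) -> AbstractSet[str]:
        if isinstance(entities, (set, frozenset)):
            # reuse the existing set; no copy needed
            return entities.intersection(changed)
        if len(changed) < len(entities):
            # build the set from the smaller side, then intersect
            return set(changed).intersection(entities)
        return set(entities).intersection(changed)

    assert fast_intersection(["a", "b", "c"], {"b", "c", "d"}) == {"b", "c"}
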
tests/api/test_auth.py

@@ -52,6 +52,10 @@ class AuthTestCase(unittest.TestCase):
         self.hs.handlers = TestHandlers(self.hs)
         self.auth = Auth(self.hs)

+        # AuthBlocking reads from the hs' config on initialization. We need to
+        # modify its config instead of the hs'
+        self.auth_blocking = self.auth._auth_blocking
+
         self.test_user = "@foo:bar"
         self.test_token = b"_test_token_"
@@ -321,15 +325,15 @@ class AuthTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_blocking_mau(self):
-        self.hs.config.limit_usage_by_mau = False
-        self.hs.config.max_mau_value = 50
+        self.auth_blocking._limit_usage_by_mau = False
+        self.auth_blocking._max_mau_value = 50

         lots_of_users = 100
         small_number_of_users = 1

         # Ensure no error thrown
         yield defer.ensureDeferred(self.auth.check_auth_blocking())

-        self.hs.config.limit_usage_by_mau = True
+        self.auth_blocking._limit_usage_by_mau = True

         self.store.get_monthly_active_count = Mock(
             return_value=defer.succeed(lots_of_users)
@@ -349,8 +353,8 @@ class AuthTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_blocking_mau__depending_on_user_type(self):
-        self.hs.config.max_mau_value = 50
-        self.hs.config.limit_usage_by_mau = True
+        self.auth_blocking._max_mau_value = 50
+        self.auth_blocking._limit_usage_by_mau = True

         self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))

         # Support users allowed
@@ -370,12 +374,12 @@ class AuthTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_reserved_threepid(self):
-        self.hs.config.limit_usage_by_mau = True
-        self.hs.config.max_mau_value = 1
+        self.auth_blocking._limit_usage_by_mau = True
+        self.auth_blocking._max_mau_value = 1
         self.store.get_monthly_active_count = lambda: defer.succeed(2)
         threepid = {"medium": "email", "address": "reserved@server.com"}
         unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
-        self.hs.config.mau_limits_reserved_threepids = [threepid]
+        self.auth_blocking._mau_limits_reserved_threepids = [threepid]

         with self.assertRaises(ResourceLimitError):
             yield defer.ensureDeferred(self.auth.check_auth_blocking())
@@ -389,8 +393,8 @@ class AuthTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_hs_disabled(self):
-        self.hs.config.hs_disabled = True
-        self.hs.config.hs_disabled_message = "Reason for being disabled"
+        self.auth_blocking._hs_disabled = True
+        self.auth_blocking._hs_disabled_message = "Reason for being disabled"
         with self.assertRaises(ResourceLimitError) as e:
             yield defer.ensureDeferred(self.auth.check_auth_blocking())
         self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
@@ -404,10 +408,10 @@ class AuthTestCase(unittest.TestCase):
         """
         # this should be the default, but we had a bug where the test was doing the wrong
         # thing, so let's make it explicit
-        self.hs.config.server_notices_mxid = None
+        self.auth_blocking._server_notices_mxid = None

-        self.hs.config.hs_disabled = True
-        self.hs.config.hs_disabled_message = "Reason for being disabled"
+        self.auth_blocking._hs_disabled = True
+        self.auth_blocking._hs_disabled_message = "Reason for being disabled"
         with self.assertRaises(ResourceLimitError) as e:
             yield defer.ensureDeferred(self.auth.check_auth_blocking())
         self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
@@ -416,8 +420,8 @@ class AuthTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_server_notices_mxid_special_cased(self):
-        self.hs.config.hs_disabled = True
+        self.auth_blocking._hs_disabled = True
         user = "@user:server"
-        self.hs.config.server_notices_mxid = user
+        self.auth_blocking._server_notices_mxid = user
         self.hs.config.hs_disabled_message = "Reason for being disabled"

         yield defer.ensureDeferred(self.auth.check_auth_blocking(user))

tests/handlers/test_auth.py

@@ -39,8 +39,13 @@ class AuthTestCase(unittest.TestCase):
         self.hs.handlers = AuthHandlers(self.hs)
         self.auth_handler = self.hs.handlers.auth_handler
         self.macaroon_generator = self.hs.get_macaroon_generator()

         # MAU tests
-        self.hs.config.max_mau_value = 50
+        # AuthBlocking reads from the hs' config on initialization. We need to
+        # modify its config instead of the hs'
+        self.auth_blocking = self.hs.get_auth()._auth_blocking
+        self.auth_blocking._max_mau_value = 50
+
         self.small_number_of_users = 1
         self.large_number_of_users = 100
@@ -119,7 +124,7 @@ class AuthTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_mau_limits_disabled(self):
-        self.hs.config.limit_usage_by_mau = False
+        self.auth_blocking._limit_usage_by_mau = False
         # Ensure does not throw exception
         yield defer.ensureDeferred(
             self.auth_handler.get_access_token_for_user_id(
@@ -135,7 +140,7 @@ class AuthTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_mau_limits_exceeded_large(self):
-        self.hs.config.limit_usage_by_mau = True
+        self.auth_blocking._limit_usage_by_mau = True
         self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=defer.succeed(self.large_number_of_users)
         )
@@ -159,11 +164,11 @@ class AuthTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_mau_limits_parity(self):
-        self.hs.config.limit_usage_by_mau = True
+        self.auth_blocking._limit_usage_by_mau = True

         # If not in monthly active cohort
         self.hs.get_datastore().get_monthly_active_count = Mock(
-            return_value=defer.succeed(self.hs.config.max_mau_value)
+            return_value=defer.succeed(self.auth_blocking._max_mau_value)
         )
         with self.assertRaises(ResourceLimitError):
             yield defer.ensureDeferred(
@@ -173,7 +178,7 @@ class AuthTestCase(unittest.TestCase):
             )

         self.hs.get_datastore().get_monthly_active_count = Mock(
-            return_value=defer.succeed(self.hs.config.max_mau_value)
+            return_value=defer.succeed(self.auth_blocking._max_mau_value)
         )
         with self.assertRaises(ResourceLimitError):
             yield defer.ensureDeferred(
@@ -186,7 +191,7 @@ class AuthTestCase(unittest.TestCase):
             return_value=defer.succeed(self.hs.get_clock().time_msec())
         )
         self.hs.get_datastore().get_monthly_active_count = Mock(
-            return_value=defer.succeed(self.hs.config.max_mau_value)
+            return_value=defer.succeed(self.auth_blocking._max_mau_value)
         )
         yield defer.ensureDeferred(
             self.auth_handler.get_access_token_for_user_id(
@@ -197,7 +202,7 @@ class AuthTestCase(unittest.TestCase):
             return_value=defer.succeed(self.hs.get_clock().time_msec())
         )
         self.hs.get_datastore().get_monthly_active_count = Mock(
-            return_value=defer.succeed(self.hs.config.max_mau_value)
+            return_value=defer.succeed(self.auth_blocking._max_mau_value)
         )
         yield defer.ensureDeferred(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(
@@ -207,7 +212,7 @@ class AuthTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_mau_limits_not_exceeded(self):
-        self.hs.config.limit_usage_by_mau = True
+        self.auth_blocking._limit_usage_by_mau = True
         self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=defer.succeed(self.small_number_of_users)
         )
tests/handlers/test_sync.py

@@ -30,28 +30,31 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
         self.sync_handler = self.hs.get_sync_handler()
         self.store = self.hs.get_datastore()

-    def test_wait_for_sync_for_user_auth_blocking(self):
+        # AuthBlocking reads from the hs' config on initialization. We need to
+        # modify its config instead of the hs'
+        self.auth_blocking = self.hs.get_auth()._auth_blocking

+    def test_wait_for_sync_for_user_auth_blocking(self):
         user_id1 = "@user1:test"
         user_id2 = "@user2:test"
         sync_config = self._generate_sync_config(user_id1)

         self.reactor.advance(100)  # So we get not 0 time
-        self.hs.config.limit_usage_by_mau = True
-        self.hs.config.max_mau_value = 1
+        self.auth_blocking._limit_usage_by_mau = True
+        self.auth_blocking._max_mau_value = 1

         # Check that the happy case does not throw errors
         self.get_success(self.store.upsert_monthly_active_user(user_id1))
         self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config))

         # Test that global lock works
-        self.hs.config.hs_disabled = True
+        self.auth_blocking._hs_disabled = True
         e = self.get_failure(
             self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
         )
         self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)

-        self.hs.config.hs_disabled = False
+        self.auth_blocking._hs_disabled = False

         sync_config = self._generate_sync_config(user_id2)

tests/test_mau.py

@@ -19,6 +19,7 @@ import json
 from mock import Mock

+from synapse.api.auth_blocking import AuthBlocking
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, HttpResponseException, SynapseError
 from synapse.rest.client.v2_alpha import register, sync
@@ -45,11 +46,17 @@ class TestMauLimit(unittest.HomeserverTestCase):
         self.hs.config.limit_usage_by_mau = True
         self.hs.config.hs_disabled = False
         self.hs.config.max_mau_value = 2
-        self.hs.config.mau_trial_days = 0
         self.hs.config.server_notices_mxid = "@server:red"
         self.hs.config.server_notices_mxid_display_name = None
         self.hs.config.server_notices_mxid_avatar_url = None
         self.hs.config.server_notices_room_name = "Test Server Notice Room"
+        self.hs.config.mau_trial_days = 0
+
+        # AuthBlocking reads config options during hs creation. Recreate the
+        # hs' copy of AuthBlocking after we've updated config values above
+        self.auth_blocking = AuthBlocking(self.hs)
+        self.hs.get_auth()._auth_blocking = self.auth_blocking

         return self.hs
@@ -121,6 +128,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
         self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)

     def test_trial_users_cant_come_back(self):
+        self.auth_blocking._mau_trial_days = 1
         self.hs.config.mau_trial_days = 1

         # We should be able to register more than the limit initially
@@ -169,8 +177,8 @@ class TestMauLimit(unittest.HomeserverTestCase):
         self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)

     def test_tracked_but_not_limited(self):
-        self.hs.config.max_mau_value = 1  # should not matter
-        self.hs.config.limit_usage_by_mau = False
+        self.auth_blocking._max_mau_value = 1  # should not matter
+        self.auth_blocking._limit_usage_by_mau = False
         self.hs.config.mau_stats_only = True

         # Simply being able to create 2 users indicates that the