limit register and sign in on number of monthly users

This commit is contained in:
Neil Johnson 2018-07-30 15:55:57 +01:00
parent e9b2d047f6
commit 251e6c1210
7 changed files with 166 additions and 3 deletions

View File

@ -55,6 +55,7 @@ class Codes(object):
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED" SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN" CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN"
CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM" CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM"
MAU_LIMIT_EXCEEDED = "M_MAU_LIMIT_EXCEEDED"
class CodeMessageException(RuntimeError): class CodeMessageException(RuntimeError):

View File

@ -67,6 +67,11 @@ class ServerConfig(Config):
"block_non_admin_invites", False, "block_non_admin_invites", False,
) )
# Options to control access by tracking MAU
self.limit_usage_by_mau = config.get("limit_usage_by_mau", False)
self.max_mau_value = config.get(
"max_mau_value", 0,
)
# FIXME: federation_domain_whitelist needs sytests # FIXME: federation_domain_whitelist needs sytests
self.federation_domain_whitelist = None self.federation_domain_whitelist = None
federation_domain_whitelist = config.get( federation_domain_whitelist = config.get(

View File

@ -519,6 +519,7 @@ class AuthHandler(BaseHandler):
""" """
logger.info("Logging in user %s on device %s", user_id, device_id) logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id)
self._check_mau_limits()
# the device *should* have been registered before we got here; however, # the device *should* have been registered before we got here; however,
# it's possible we raced against a DELETE operation. The thing we # it's possible we raced against a DELETE operation. The thing we
@ -729,6 +730,7 @@ class AuthHandler(BaseHandler):
defer.returnValue(access_token) defer.returnValue(access_token)
def validate_short_term_login_token_and_get_user_id(self, login_token): def validate_short_term_login_token_and_get_user_id(self, login_token):
self._check_mau_limits()
auth_api = self.hs.get_auth() auth_api = self.hs.get_auth()
try: try:
macaroon = pymacaroons.Macaroon.deserialize(login_token) macaroon = pymacaroons.Macaroon.deserialize(login_token)
@ -892,6 +894,17 @@ class AuthHandler(BaseHandler):
else: else:
return defer.succeed(False) return defer.succeed(False)
def _check_mau_limits(self):
"""
Ensure that if mau blocking is enabled that invalid users cannot
log in.
"""
if self.hs.config.limit_usage_by_mau is True:
current_mau = self.store.count_monthly_users()
if current_mau >= self.hs.config.max_mau_value:
raise AuthError(
403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
)
@attr.s @attr.s
class MacaroonGenerator(object): class MacaroonGenerator(object):

View File

@ -45,7 +45,7 @@ class RegistrationHandler(BaseHandler):
hs (synapse.server.HomeServer): hs (synapse.server.HomeServer):
""" """
super(RegistrationHandler, self).__init__(hs) super(RegistrationHandler, self).__init__(hs)
self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
@ -144,6 +144,7 @@ class RegistrationHandler(BaseHandler):
Raises: Raises:
RegistrationError if there was a problem registering. RegistrationError if there was a problem registering.
""" """
self._check_mau_limits()
password_hash = None password_hash = None
if password: if password:
password_hash = yield self.auth_handler().hash(password) password_hash = yield self.auth_handler().hash(password)
@ -288,6 +289,7 @@ class RegistrationHandler(BaseHandler):
400, 400,
"User ID can only contain characters a-z, 0-9, or '=_-./'", "User ID can only contain characters a-z, 0-9, or '=_-./'",
) )
self._check_mau_limits()
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
@ -437,7 +439,7 @@ class RegistrationHandler(BaseHandler):
""" """
if localpart is None: if localpart is None:
raise SynapseError(400, "Request must include user id") raise SynapseError(400, "Request must include user id")
self._check_mau_limits()
need_register = True need_register = True
try: try:
@ -531,3 +533,15 @@ class RegistrationHandler(BaseHandler):
remote_room_hosts=remote_room_hosts, remote_room_hosts=remote_room_hosts,
action="join", action="join",
) )
def _check_mau_limits(self):
"""
Do not accept registrations if monthly active user limits exceeded
and limiting is enabled
"""
if self.hs.config.limit_usage_by_mau is True:
current_mau = self.store.count_monthly_users()
if current_mau >= self.hs.config.max_mau_value:
raise RegistrationError(
403, "MAU Limit Exceeded", Codes.MAU_LIMIT_EXCEEDED
)

View File

@ -19,6 +19,7 @@ import logging
import time import time
from dateutil import tz from dateutil import tz
from prometheus_client import Gauge
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.storage.devices import DeviceStore from synapse.storage.devices import DeviceStore
@ -60,6 +61,13 @@ from .util.id_generators import ChainedIdGenerator, IdGenerator, StreamIdGenerat
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Gauges to expose monthly active user control metrics
current_mau_gauge = Gauge("synapse_admin_current_mau", "Current MAU")
max_mau_value_gauge = Gauge("synapse_admin_max_mau_value", "MAU Limit")
limit_usage_by_mau_gauge = Gauge(
"synapse_admin_limit_usage_by_mau", "MAU Limiting enabled"
)
class DataStore(RoomMemberStore, RoomStore, class DataStore(RoomMemberStore, RoomStore,
RegistrationStore, StreamStore, ProfileStore, RegistrationStore, StreamStore, ProfileStore,
@ -266,6 +274,32 @@ class DataStore(RoomMemberStore, RoomStore,
return self.runInteraction("count_users", _count_users) return self.runInteraction("count_users", _count_users)
def count_monthly_users(self):
"""
Counts the number of users who used this homeserver in the last 30 days
This method should be refactored with count_daily_users - the only
reason not to is waiting on definition of mau
returns:
int: count of current monthly active users
"""
def _count_monthly_users(txn):
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
sql = """
SELECT COUNT(*) FROM user_ips
WHERE last_seen > ?
"""
txn.execute(sql, (thirty_days_ago,))
count, = txn.fetchone()
self._current_mau = count
current_mau_gauge.set(self._current_mau)
max_mau_value_gauge.set(self.hs.config.max_mau_value)
limit_usage_by_mau_gauge.set(self.hs.config.limit_usage_by_mau)
logger.info("calling mau stats")
return count
return self.runInteraction("count_monthly_users", _count_monthly_users)
def count_r30_users(self): def count_r30_users(self):
""" """
Counts the number of 30 day retained users, defined as:- Counts the number of 30 day retained users, defined as:-

View File

@ -12,15 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from mock import Mock
import pymacaroons import pymacaroons
from twisted.internet import defer from twisted.internet import defer
import synapse import synapse
from synapse.api.errors import AuthError
import synapse.api.errors import synapse.api.errors
from synapse.handlers.auth import AuthHandler from synapse.handlers.auth import AuthHandler
from tests import unittest from tests import unittest
from tests.utils import setup_test_homeserver from tests.utils import setup_test_homeserver
@ -37,6 +39,10 @@ class AuthTestCase(unittest.TestCase):
self.hs.handlers = AuthHandlers(self.hs) self.hs.handlers = AuthHandlers(self.hs)
self.auth_handler = self.hs.handlers.auth_handler self.auth_handler = self.hs.handlers.auth_handler
self.macaroon_generator = self.hs.get_macaroon_generator() self.macaroon_generator = self.hs.get_macaroon_generator()
# MAU tests
self.hs.config.max_mau_value = 50
self.small_number_of_users = 1
self.large_number_of_users = 100
def test_token_is_a_macaroon(self): def test_token_is_a_macaroon(self):
token = self.macaroon_generator.generate_access_token("some_user") token = self.macaroon_generator.generate_access_token("some_user")
@ -113,3 +119,44 @@ class AuthTestCase(unittest.TestCase):
self.auth_handler.validate_short_term_login_token_and_get_user_id( self.auth_handler.validate_short_term_login_token_and_get_user_id(
macaroon.serialize() macaroon.serialize()
) )
@defer.inlineCallbacks
def test_mau_limits_disabled(self):
self.hs.config.limit_usage_by_mau = False
# Ensure does not throw exception
yield self.auth_handler.get_access_token_for_user_id('user_a')
self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
@defer.inlineCallbacks
def test_mau_limits_exceeded(self):
self.hs.config.limit_usage_by_mau = True
self.hs.get_datastore().count_monthly_users = Mock(
return_value=self.large_number_of_users
)
with self.assertRaises(AuthError):
yield self.auth_handler.get_access_token_for_user_id('user_a')
with self.assertRaises(AuthError):
self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
@defer.inlineCallbacks
def test_mau_limits_not_exceeded(self):
self.hs.config.limit_usage_by_mau = True
self.hs.get_datastore().count_monthly_users = Mock(
return_value=self.small_number_of_users
)
# Ensure does not raise exception
yield self.auth_handler.get_access_token_for_user_id('user_a')
self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize()
)
def _get_macaroon(self):
token = self.macaroon_generator.generate_short_term_login_token(
"user_a", 5000
)
return pymacaroons.Macaroon.deserialize(token)

View File

@ -17,6 +17,7 @@ from mock import Mock
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import RegistrationError
from synapse.handlers.register import RegistrationHandler from synapse.handlers.register import RegistrationHandler
from synapse.types import UserID, create_requester from synapse.types import UserID, create_requester
@ -77,3 +78,51 @@ class RegistrationTestCase(unittest.TestCase):
requester, local_part, display_name) requester, local_part, display_name)
self.assertEquals(result_user_id, user_id) self.assertEquals(result_user_id, user_id)
self.assertEquals(result_token, 'secret') self.assertEquals(result_token, 'secret')
@defer.inlineCallbacks
def test_cannot_register_when_mau_limits_exceeded(self):
local_part = "someone"
display_name = "someone"
requester = create_requester("@as:test")
store = self.hs.get_datastore()
self.hs.config.limit_usage_by_mau = False
self.hs.config.max_mau_value = 50
lots_of_users = 100
small_number_users = 1
store.count_monthly_users = Mock(return_value=lots_of_users)
# Ensure does not throw exception
yield self.handler.get_or_create_user(requester, 'a', display_name)
self.hs.config.limit_usage_by_mau = True
with self.assertRaises(RegistrationError):
yield self.handler.get_or_create_user(requester, 'b', display_name)
store.count_monthly_users = Mock(return_value=small_number_users)
self._macaroon_mock_generator("another_secret")
# Ensure does not throw exception
yield self.handler.get_or_create_user("@neil:matrix.org", 'c', "Neil")
self._macaroon_mock_generator("another another secret")
store.count_monthly_users = Mock(return_value=lots_of_users)
with self.assertRaises(RegistrationError):
yield self.handler.register(localpart=local_part)
self._macaroon_mock_generator("another another secret")
store.count_monthly_users = Mock(return_value=lots_of_users)
with self.assertRaises(RegistrationError):
yield self.handler.register_saml2(local_part)
def _macaroon_mock_generator(self, secret):
"""
Reset macaroon generator in the case where the test creates multiple users
"""
macaroon_generator = Mock(
generate_access_token=Mock(return_value=secret))
self.hs.get_macaroon_generator = Mock(return_value=macaroon_generator)
self.hs.handlers = RegistrationHandlers(self.hs)
self.handler = self.hs.get_handlers().registration_handler