only block on sync where user is not part of the mau cohort

This commit is contained in:
Neil Johnson 2018-08-09 17:39:12 +01:00
parent c6b28fb479
commit 09cf130898
3 changed files with 48 additions and 12 deletions

View File

@ -775,15 +775,24 @@ class Auth(object):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def check_auth_blocking(self): def check_auth_blocking(self, user_id=None):
"""Checks if the user should be rejected for some external reason, """Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag such as monthly active user limiting or global disable flag
Args:
user_id(str): If present, checks for presence against existing MAU cohort
""" """
if self.hs.config.hs_disabled: if self.hs.config.hs_disabled:
raise AuthError( raise AuthError(
403, self.hs.config.hs_disabled_message, errcode=Codes.HS_DISABLED 403, self.hs.config.hs_disabled_message, errcode=Codes.HS_DISABLED
) )
if self.hs.config.limit_usage_by_mau is True: if self.hs.config.limit_usage_by_mau is True:
# If the user is already part of the MAU cohort
if user_id:
timestamp = yield self.store._user_last_seen_monthly_active(user_id)
if timestamp:
return
# Else if there is no room in the MAU bucket, bail
current_mau = yield self.store.get_monthly_active_count() current_mau = yield self.store.get_monthly_active_count()
if current_mau >= self.hs.config.max_mau_value: if current_mau >= self.hs.config.max_mau_value:
raise AuthError( raise AuthError(

View File

@ -208,7 +208,12 @@ class SyncHandler(object):
Returns: Returns:
Deferred[SyncResult] Deferred[SyncResult]
""" """
yield self.auth.check_auth_blocking() # If the user is not part of the mau group, then check that limits have
# not been exceeded (if not part of the group by this point, almost certain
# auth_blocking will occur)
user_id = sync_config.user.to_string()
yield self.auth.check_auth_blocking(user_id)
res = yield self.response_cache.wrap( res = yield self.response_cache.wrap(
sync_config.request_key, sync_config.request_key,
self._wait_for_sync_for_user, self._wait_for_sync_for_user,

View File

@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError, Codes
from synapse.api.errors import AuthError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
from synapse.handlers.sync import SyncConfig, SyncHandler from synapse.handlers.sync import SyncConfig, SyncHandler
from synapse.types import UserID from synapse.types import UserID
@ -31,19 +31,41 @@ class SyncTestCase(tests.unittest.TestCase):
def setUp(self): def setUp(self):
self.hs = yield setup_test_homeserver() self.hs = yield setup_test_homeserver()
self.sync_handler = SyncHandler(self.hs) self.sync_handler = SyncHandler(self.hs)
self.store = self.hs.get_datastore()
@defer.inlineCallbacks @defer.inlineCallbacks
def test_wait_for_sync_for_user_auth_blocking(self): def test_wait_for_sync_for_user_auth_blocking(self):
sync_config = SyncConfig(
user=UserID("@user", "server"), user_id1 = "@user1:server"
user_id2 = "@user2:server"
sync_config = self._generate_sync_config(user_id1)
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 1
# Check that the happy case does not throw errors
yield self.store.upsert_monthly_active_user(user_id1)
yield self.sync_handler.wait_for_sync_for_user(sync_config)
# Test that global lock works
self.hs.config.hs_disabled = True
with self.assertRaises(AuthError) as e:
yield self.sync_handler.wait_for_sync_for_user(sync_config)
self.assertEquals(e.exception.errcode, Codes.HS_DISABLED)
self.hs.config.hs_disabled = False
sync_config = self._generate_sync_config(user_id2)
with self.assertRaises(AuthError) as e:
yield self.sync_handler.wait_for_sync_for_user(sync_config)
self.assertEquals(e.exception.errcode, Codes.MAU_LIMIT_EXCEEDED)
def _generate_sync_config(self, user_id):
return SyncConfig(
user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]),
filter_collection=DEFAULT_FILTER_COLLECTION, filter_collection=DEFAULT_FILTER_COLLECTION,
is_guest=False, is_guest=False,
request_key="request_key", request_key="request_key",
device_id="device_id", device_id="device_id",
) )
# Ensure that an exception is not thrown
yield self.sync_handler.wait_for_sync_for_user(sync_config)
self.hs.config.hs_disabled = True
with self.assertRaises(AuthError):
yield self.sync_handler.wait_for_sync_for_user(sync_config)