Merge remote-tracking branch 'origin/erikj/as_mau_block' into develop

This commit is contained in:
Erik Johnston 2020-12-18 09:51:56 +00:00
commit a7a913918c
7 changed files with 86 additions and 9 deletions

1
changelog.d/8962.bugfix Normal file
View File

@ -0,0 +1 @@
Fix bug where application services couldn't register new ghost users if the server had reached its MAU limit.

View File

@ -36,6 +36,7 @@ class AuthBlocking:
self._limit_usage_by_mau = hs.config.limit_usage_by_mau self._limit_usage_by_mau = hs.config.limit_usage_by_mau
self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
self._server_name = hs.hostname self._server_name = hs.hostname
self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
async def check_auth_blocking( async def check_auth_blocking(
self, self,
@ -76,6 +77,12 @@ class AuthBlocking:
# We never block the server from doing actions on behalf of # We never block the server from doing actions on behalf of
# users. # users.
return return
elif requester.app_service and not self._track_appservice_user_ips:
# If we're authenticated as an appservice then we only block
# auth if `track_appservice_user_ips` is set, as that option
# implicitly means that application services are part of MAU
# limits.
return
# Never fail an auth check for the server notices users or support user # Never fail an auth check for the server notices users or support user
# This can be a problem where event creation is prohibited due to blocking # This can be a problem where event creation is prohibited due to blocking

View File

@ -738,6 +738,7 @@ class AuthHandler(BaseHandler):
device_id: Optional[str], device_id: Optional[str],
valid_until_ms: Optional[int], valid_until_ms: Optional[int],
puppets_user_id: Optional[str] = None, puppets_user_id: Optional[str] = None,
is_appservice_ghost: bool = False,
) -> str: ) -> str:
""" """
Creates a new access token for the user with the given user ID. Creates a new access token for the user with the given user ID.
@ -754,6 +755,7 @@ class AuthHandler(BaseHandler):
we should always have a device ID) we should always have a device ID)
valid_until_ms: when the token is valid until. None for valid_until_ms: when the token is valid until. None for
no expiry. no expiry.
is_appservice_ghost: Whether the user is an application ghost user
Returns: Returns:
The access token for the user's session. The access token for the user's session.
Raises: Raises:
@ -774,6 +776,10 @@ class AuthHandler(BaseHandler):
"Logging in user %s on device %s%s", user_id, device_id, fmt_expiry "Logging in user %s on device %s%s", user_id, device_id, fmt_expiry
) )
if (
not is_appservice_ghost
or self.hs.config.appservice.track_appservice_user_ips
):
await self.auth.check_auth_blocking(user_id) await self.auth.check_auth_blocking(user_id)
access_token = self.macaroon_gen.generate_access_token(user_id) access_token = self.macaroon_gen.generate_access_token(user_id)

View File

@ -630,6 +630,7 @@ class RegistrationHandler(BaseHandler):
device_id: Optional[str], device_id: Optional[str],
initial_display_name: Optional[str], initial_display_name: Optional[str],
is_guest: bool = False, is_guest: bool = False,
is_appservice_ghost: bool = False,
) -> Tuple[str, str]: ) -> Tuple[str, str]:
"""Register a device for a user and generate an access token. """Register a device for a user and generate an access token.
@ -651,6 +652,7 @@ class RegistrationHandler(BaseHandler):
device_id=device_id, device_id=device_id,
initial_display_name=initial_display_name, initial_display_name=initial_display_name,
is_guest=is_guest, is_guest=is_guest,
is_appservice_ghost=is_appservice_ghost,
) )
return r["device_id"], r["access_token"] return r["device_id"], r["access_token"]
@ -672,7 +674,10 @@ class RegistrationHandler(BaseHandler):
) )
else: else:
access_token = await self._auth_handler.get_access_token_for_user_id( access_token = await self._auth_handler.get_access_token_for_user_id(
user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms user_id,
device_id=registered_device_id,
valid_until_ms=valid_until_ms,
is_appservice_ghost=is_appservice_ghost,
) )
return (registered_device_id, access_token) return (registered_device_id, access_token)

View File

@ -36,7 +36,9 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
self.registration_handler = hs.get_registration_handler() self.registration_handler = hs.get_registration_handler()
@staticmethod @staticmethod
async def _serialize_payload(user_id, device_id, initial_display_name, is_guest): async def _serialize_payload(
user_id, device_id, initial_display_name, is_guest, is_appservice_ghost
):
""" """
Args: Args:
device_id (str|None): Device ID to use, if None a new one is device_id (str|None): Device ID to use, if None a new one is
@ -48,6 +50,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
"device_id": device_id, "device_id": device_id,
"initial_display_name": initial_display_name, "initial_display_name": initial_display_name,
"is_guest": is_guest, "is_guest": is_guest,
"is_appservice_ghost": is_appservice_ghost,
} }
async def _handle_request(self, request, user_id): async def _handle_request(self, request, user_id):
@ -56,9 +59,14 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
device_id = content["device_id"] device_id = content["device_id"]
initial_display_name = content["initial_display_name"] initial_display_name = content["initial_display_name"]
is_guest = content["is_guest"] is_guest = content["is_guest"]
is_appservice_ghost = content["is_appservice_ghost"]
device_id, access_token = await self.registration_handler.register_device( device_id, access_token = await self.registration_handler.register_device(
user_id, device_id, initial_display_name, is_guest user_id,
device_id,
initial_display_name,
is_guest,
is_appservice_ghost=is_appservice_ghost,
) )
return 200, {"device_id": device_id, "access_token": access_token} return 200, {"device_id": device_id, "access_token": access_token}

View File

@ -655,9 +655,13 @@ class RegisterRestServlet(RestServlet):
user_id = await self.registration_handler.appservice_register( user_id = await self.registration_handler.appservice_register(
username, as_token username, as_token
) )
return await self._create_registration_details(user_id, body) return await self._create_registration_details(
user_id, body, is_appservice_ghost=True,
)
async def _create_registration_details(self, user_id, params): async def _create_registration_details(
self, user_id, params, is_appservice_ghost=False
):
"""Complete registration of newly-registered user """Complete registration of newly-registered user
Allocates device_id if one was not given; also creates access_token. Allocates device_id if one was not given; also creates access_token.
@ -674,7 +678,11 @@ class RegisterRestServlet(RestServlet):
device_id = params.get("device_id") device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name") initial_display_name = params.get("initial_device_display_name")
device_id, access_token = await self.registration_handler.register_device( device_id, access_token = await self.registration_handler.register_device(
user_id, device_id, initial_display_name, is_guest=False user_id,
device_id,
initial_display_name,
is_guest=False,
is_appservice_ghost=is_appservice_ghost,
) )
result.update({"access_token": access_token, "device_id": device_id}) result.update({"access_token": access_token, "device_id": device_id})

View File

@ -19,6 +19,7 @@ import json
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.appservice import ApplicationService
from synapse.rest.client.v2_alpha import register, sync from synapse.rest.client.v2_alpha import register, sync
from tests import unittest from tests import unittest
@ -75,6 +76,45 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.code, 403) self.assertEqual(e.code, 403)
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def test_as_ignores_mau(self):
"""Test that application services can still create users when the MAU
limit has been reached. This only works when application service
user ip tracking is disabled.
"""
# Create and sync so that the MAU counts get updated
token1 = self.create_user("kermit1")
self.do_sync_for_user(token1)
token2 = self.create_user("kermit2")
self.do_sync_for_user(token2)
# check we're testing what we think we are: there should be two active users
self.assertEqual(self.get_success(self.store.get_monthly_active_count()), 2)
# We've created and activated two users, we shouldn't be able to
# register new users
with self.assertRaises(SynapseError) as cm:
self.create_user("kermit3")
e = cm.exception
self.assertEqual(e.code, 403)
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
# Cheekily add an application service that we use to register a new user
# with.
as_token = "foobartoken"
self.store.services_cache.append(
ApplicationService(
token=as_token,
hostname=self.hs.hostname,
id="SomeASID",
sender="@as_sender:test",
namespaces={"users": [{"regex": "@as_*", "exclusive": True}]},
)
)
self.create_user("as_kermit4", token=as_token)
def test_allowed_after_a_month_mau(self): def test_allowed_after_a_month_mau(self):
# Create and sync so that the MAU counts get updated # Create and sync so that the MAU counts get updated
token1 = self.create_user("kermit1") token1 = self.create_user("kermit1")
@ -192,7 +232,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.reactor.advance(100) self.reactor.advance(100)
self.assertEqual(2, self.successResultOf(count)) self.assertEqual(2, self.successResultOf(count))
def create_user(self, localpart): def create_user(self, localpart, token=None):
request_data = json.dumps( request_data = json.dumps(
{ {
"username": localpart, "username": localpart,
@ -201,7 +241,9 @@ class TestMauLimit(unittest.HomeserverTestCase):
} }
) )
channel = self.make_request("POST", "/register", request_data) channel = self.make_request(
"POST", "/register", request_data, access_token=token,
)
if channel.code != 200: if channel.code != 200:
raise HttpResponseException( raise HttpResponseException(