From af691e415c3247b912137227a06a68d4c4356586 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 18 Feb 2019 16:49:38 +0000 Subject: [PATCH] Move register_device into handler --- synapse/handlers/register.py | 51 ++++++++++- synapse/replication/http/login.py | 17 +--- synapse/rest/client/v1/login.py | 59 +++++-------- synapse/rest/client/v2_alpha/register.py | 49 +---------- tests/rest/client/v2_alpha/test_register.py | 93 +++++---------------- 5 files changed, 97 insertions(+), 172 deletions(-) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 8ea557a00..f92ab4d52 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -27,6 +27,7 @@ from synapse.api.errors import ( SynapseError, ) from synapse.http.client import CaptchaServerHttpClient +from synapse.replication.http.login import RegisterDeviceReplicationServlet from synapse.replication.http.register import ReplicationRegisterServlet from synapse.types import RoomAlias, RoomID, UserID, create_requester from synapse.util.async_helpers import Linearizer @@ -64,6 +65,11 @@ class RegistrationHandler(BaseHandler): if hs.config.worker_app: self._register_client = ReplicationRegisterServlet.make_client(hs) + self._register_device_client = ( + RegisterDeviceReplicationServlet.make_client(hs) + ) + else: + self.device_handler = hs.get_device_handler() @defer.inlineCallbacks def check_username(self, localpart, guest_access_token=None, @@ -159,7 +165,7 @@ class RegistrationHandler(BaseHandler): yield self.auth.check_auth_blocking(threepid=threepid) password_hash = None if password: - password_hash = yield self.auth_handler().hash(password) + password_hash = yield self._auth_handler.hash(password) if localpart: yield self.check_username(localpart, guest_access_token=guest_access_token) @@ -516,9 +522,6 @@ class RegistrationHandler(BaseHandler): defer.returnValue((user_id, token)) - def auth_handler(self): - return self.hs.get_auth_handler() - @defer.inlineCallbacks def get_or_register_3pid_guest(self, medium, address, inviter_user_id): """Get a guest access token for a 3PID, creating a guest account if @@ -628,3 +631,43 @@ class RegistrationHandler(BaseHandler): admin=admin, user_type=user_type, ) + + @defer.inlineCallbacks + def register_device(self, user_id, device_id, initial_display_name, + is_guest=False): + """Register a device for a user and generate an access token. + + Args: + user_id (str): full canonical @user:id + device_id (str|None): The device ID to check, or None to generate + a new one. + initial_display_name (str|None): An optional display name for the + device. + is_guest (bool): Whether this is a guest account + + Returns: + defer.Deferred[tuple[str, str]]: Tuple of device ID and access token + """ + + if self.hs.config.worker_app: + r = yield self._register_device_client( + user_id=user_id, + device_id=device_id, + initial_display_name=initial_display_name, + is_guest=is_guest, + ) + defer.returnValue((r["device_id"], r["access_token"])) + else: + device_id = yield self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) + if is_guest: + access_token = self.macaroon_gen.generate_access_token( + user_id, ["guest = true"] + ) + else: + access_token = yield self._auth_handler.get_access_token_for_user_id( + user_id, device_id=device_id, + ) + + defer.returnValue((device_id, access_token)) diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py index 797f6aabd..1590eca31 100644 --- a/synapse/replication/http/login.py +++ b/synapse/replication/http/login.py @@ -35,9 +35,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): def __init__(self, hs): super(RegisterDeviceReplicationServlet, self).__init__(hs) - self.auth_handler = hs.get_auth_handler() - self.device_handler = hs.get_device_handler() - self.macaroon_gen = hs.get_macaroon_generator() + self.registration_handler = hs.get_handlers().registration_handler @staticmethod def _serialize_payload(user_id, device_id, initial_display_name, is_guest): @@ -62,19 +60,10 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): initial_display_name = content["initial_display_name"] is_guest = content["is_guest"] - device_id = yield self.device_handler.check_device_registered( - user_id, device_id, initial_display_name, + device_id, access_token = yield self.registration_handler.register_device( + user_id, device_id, initial_display_name, is_guest, ) - if is_guest: - access_token = self.macaroon_gen.generate_access_token( - user_id, ["guest = true"] - ) - else: - access_token = yield self.auth_handler.get_access_token_for_user_id( - user_id, device_id=device_id, - ) - defer.returnValue((200, { "device_id": device_id, "access_token": access_token, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 942e4d381..4a5775083 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -94,7 +94,7 @@ class LoginRestServlet(ClientV1RestServlet): self.jwt_algorithm = hs.config.jwt_algorithm self.cas_enabled = hs.config.cas_enabled self.auth_handler = self.hs.get_auth_handler() - self.device_handler = self.hs.get_device_handler() + self.registration_handler = hs.get_handlers().registration_handler self.handlers = hs.get_handlers() self._well_known_builder = WellKnownBuilder(hs) @@ -220,11 +220,10 @@ class LoginRestServlet(ClientV1RestServlet): login_submission, ) - device_id = yield self._register_device( - canonical_user_id, login_submission, - ) - access_token = yield auth_handler.get_access_token_for_user_id( - canonical_user_id, device_id, + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get("initial_device_display_name") + device_id, access_token = yield self.registration_handler.register_device( + canonical_user_id, device_id, initial_display_name, ) result = { @@ -246,10 +245,13 @@ class LoginRestServlet(ClientV1RestServlet): user_id = ( yield auth_handler.validate_short_term_login_token_and_get_user_id(token) ) - device_id = yield self._register_device(user_id, login_submission) - access_token = yield auth_handler.get_access_token_for_user_id( - user_id, device_id, + + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get("initial_device_display_name") + device_id, access_token = yield self.registration_handler.register_device( + user_id, device_id, initial_display_name, ) + result = { "user_id": user_id, # may have changed "access_token": access_token, @@ -286,11 +288,10 @@ class LoginRestServlet(ClientV1RestServlet): auth_handler = self.auth_handler registered_user_id = yield auth_handler.check_user_exists(user_id) if registered_user_id: - device_id = yield self._register_device( - registered_user_id, login_submission - ) - access_token = yield auth_handler.get_access_token_for_user_id( - registered_user_id, device_id, + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get("initial_device_display_name") + device_id, access_token = yield self.registration_handler.register_device( + registered_user_id, device_id, initial_display_name, ) result = { @@ -299,12 +300,16 @@ class LoginRestServlet(ClientV1RestServlet): "home_server": self.hs.hostname, } else: - # TODO: we should probably check that the register isn't going - # to fonx/change our user_id before registering the device - device_id = yield self._register_device(user_id, login_submission) user_id, access_token = ( yield self.handlers.registration_handler.register(localpart=user) ) + + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get("initial_device_display_name") + device_id, access_token = yield self.registration_handler.register_device( + registered_user_id, device_id, initial_display_name, + ) + result = { "user_id": user_id, # may have changed "access_token": access_token, @@ -313,26 +318,6 @@ class LoginRestServlet(ClientV1RestServlet): defer.returnValue(result) - def _register_device(self, user_id, login_submission): - """Register a device for a user. - - This is called after the user's credentials have been validated, but - before the access token has been issued. - - Args: - (str) user_id: full canonical @user:id - (object) login_submission: dictionary supplied to /login call, from - which we pull device_id and initial_device_name - Returns: - defer.Deferred: (str) device_id - """ - device_id = login_submission.get("device_id") - initial_display_name = login_submission.get( - "initial_device_display_name") - return self.device_handler.check_device_registered( - user_id, device_id, initial_display_name - ) - class CasRedirectServlet(RestServlet): PATTERNS = client_path_patterns("/login/(cas|sso)/redirect") diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index c52280c50..c1cdb8f9c 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -33,7 +33,6 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) -from synapse.replication.http.login import RegisterDeviceReplicationServlet from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.threepids import check_3pid_allowed @@ -193,13 +192,6 @@ class RegisterRestServlet(RestServlet): self.room_member_handler = hs.get_room_member_handler() self.macaroon_gen = hs.get_macaroon_generator() - if self.hs.config.worker_app: - self._register_device_client = ( - RegisterDeviceReplicationServlet.make_client(hs) - ) - else: - self.device_handler = hs.get_device_handler() - @interactive_auth_handler @defer.inlineCallbacks def on_POST(self, request): @@ -642,7 +634,7 @@ class RegisterRestServlet(RestServlet): if not params.get("inhibit_login", False): device_id = params.get("device_id") initial_display_name = params.get("initial_device_display_name") - device_id, access_token = yield self._register_device( + device_id, access_token = yield self.registration_handler.register_device( user_id, device_id, initial_display_name, is_guest=False, ) @@ -652,43 +644,6 @@ class RegisterRestServlet(RestServlet): }) defer.returnValue(result) - @defer.inlineCallbacks - def _register_device(self, user_id, device_id, initial_display_name, - is_guest): - """Register a device for a user and generate an access token. - - Args: - user_id (str): full canonical @user:id - device_id (str|None): The device ID to check, or None to generate - a new one. - initial_display_name (str|None): An optional display name for the - device. - is_guest (bool): Whether this is a guest account - Returns: - defer.Deferred[tuple[str, str]]: Tuple of device ID and access token - """ - if self.hs.config.worker_app: - r = yield self._register_device_client( - user_id=user_id, - device_id=device_id, - initial_display_name=initial_display_name, - is_guest=is_guest, - ) - defer.returnValue((r["device_id"], r["access_token"])) - else: - device_id = yield self.device_handler.check_device_registered( - user_id, device_id, initial_display_name - ) - if is_guest: - access_token = self.macaroon_gen.generate_access_token( - user_id, ["guest = true"] - ) - else: - access_token = yield self.auth_handler.get_access_token_for_user_id( - user_id, device_id=device_id, - ) - defer.returnValue((device_id, access_token)) - @defer.inlineCallbacks def _do_guest_registration(self, params): if not self.hs.config.allow_guest_access: @@ -702,7 +657,7 @@ class RegisterRestServlet(RestServlet): # we have nowhere to store it. device_id = synapse.api.auth.GUEST_DEVICE_ID initial_display_name = params.get("initial_device_display_name") - device_id, access_token = yield self._register_device( + device_id, access_token = yield self.registration_handler.register_device( user_id, device_id, initial_display_name, is_guest=True, ) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 18080ebfd..906b348d3 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -1,10 +1,7 @@ import json -from mock import Mock - -from twisted.python import failure - -from synapse.api.errors import InteractiveAuthIncompleteError +from synapse.api.constants import LoginType +from synapse.appservice import ApplicationService from synapse.rest.client.v2_alpha.register import register_servlets from tests import unittest @@ -18,61 +15,28 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.url = b"/_matrix/client/r0/register" - self.appservice = None - self.auth = Mock( - get_appservice_by_req=Mock(side_effect=lambda x: self.appservice) - ) - - self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None)) - self.auth_handler = Mock( - check_auth=Mock(side_effect=lambda x, y, z: self.auth_result), - get_session_data=Mock(return_value=None), - ) - self.registration_handler = Mock() - self.identity_handler = Mock() - self.login_handler = Mock() - self.device_handler = Mock() - - def check_device_registered(user_id, device_id, initial_display_name): - # Just echo back the given device ID, or return a new "FAKE" device - # ID - if device_id: - return device_id - else: - return "FAKE" - - self.device_handler.check_device_registered = Mock( - side_effect=check_device_registered, - ) - - self.datastore = Mock(return_value=Mock()) - self.datastore.get_current_state_deltas = Mock(return_value=[]) - - # do the dance to hook it up to the hs global - self.handlers = Mock( - registration_handler=self.registration_handler, - identity_handler=self.identity_handler, - login_handler=self.login_handler, - ) self.hs = self.setup_test_homeserver() - self.hs.get_auth = Mock(return_value=self.auth) - self.hs.get_handlers = Mock(return_value=self.handlers) - self.hs.get_auth_handler = Mock(return_value=self.auth_handler) - self.hs.get_device_handler = Mock(return_value=self.device_handler) - self.hs.get_datastore = Mock(return_value=self.datastore) self.hs.config.enable_registration = True self.hs.config.registrations_require_3pid = [] self.hs.config.auto_join_rooms = [] + self.hs.config.enable_registration_captcha = False return self.hs def test_POST_appservice_registration_valid(self): - user_id = "@kermit:muppet" - token = "kermits_access_token" - self.appservice = {"id": "1234"} - self.registration_handler.appservice_register = Mock(return_value=user_id) - self.auth_handler.get_access_token_for_user_id = Mock(return_value=token) - request_data = json.dumps({"username": "kermit"}) + user_id = "@as_user_kermit:test" + as_token = "i_am_an_app_service" + + appservice = ApplicationService( + as_token, self.hs.config.hostname, + id="1234", + namespaces={ + "users": [{"regex": r"@as_user.*", "exclusive": True}], + }, + ) + + self.hs.get_datastore().services_cache.append(appservice) + request_data = json.dumps({"username": "as_user_kermit"}) request, channel = self.make_request( b"POST", self.url + b"?access_token=i_am_an_app_service", request_data @@ -82,7 +46,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.result["code"], b"200", channel.result) det_data = { "user_id": user_id, - "access_token": token, "home_server": self.hs.hostname, } self.assertDictContainsSubset(det_data, channel.json_body) @@ -114,37 +77,30 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.json_body["error"], "Invalid username") def test_POST_user_valid(self): - user_id = "@kermit:muppet" - token = "kermits_access_token" + user_id = "@kermit:test" device_id = "frogfone" - params = {"username": "kermit", "password": "monkey", "device_id": device_id} + params = { + "username": "kermit", + "password": "monkey", + "device_id": device_id, + "auth": {"type": LoginType.DUMMY}, + } request_data = json.dumps(params) - self.registration_handler.check_username = Mock(return_value=True) - self.auth_result = (None, params, None) - self.registration_handler.register = Mock(return_value=(user_id, None)) - self.auth_handler.get_access_token_for_user_id = Mock(return_value=token) - request, channel = self.make_request(b"POST", self.url, request_data) self.render(request) det_data = { "user_id": user_id, - "access_token": token, "home_server": self.hs.hostname, "device_id": device_id, } self.assertEquals(channel.result["code"], b"200", channel.result) self.assertDictContainsSubset(det_data, channel.json_body) - self.auth_handler.get_login_tuple_for_user_id( - user_id, device_id=device_id, initial_device_display_name=None - ) def test_POST_disabled_registration(self): self.hs.config.enable_registration = False request_data = json.dumps({"username": "kermit", "password": "monkey"}) - self.registration_handler.check_username = Mock(return_value=True) self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) - self.registration_handler.register = Mock(return_value=("@user:id", "t")) request, channel = self.make_request(b"POST", self.url, request_data) self.render(request) @@ -153,16 +109,13 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.json_body["error"], "Registration has been disabled") def test_POST_guest_registration(self): - user_id = "a@b" self.hs.config.macaroon_secret_key = "test" self.hs.config.allow_guest_access = True - self.registration_handler.register = Mock(return_value=(user_id, None)) request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") self.render(request) det_data = { - "user_id": user_id, "home_server": self.hs.hostname, "device_id": "guest_device", }