Move register_device into handler

This commit is contained in:
Erik Johnston 2019-02-18 16:49:38 +00:00
parent 8b9ae6d3a6
commit af691e415c
5 changed files with 97 additions and 172 deletions

View File

@ -27,6 +27,7 @@ from synapse.api.errors import (
SynapseError, SynapseError,
) )
from synapse.http.client import CaptchaServerHttpClient from synapse.http.client import CaptchaServerHttpClient
from synapse.replication.http.login import RegisterDeviceReplicationServlet
from synapse.replication.http.register import ReplicationRegisterServlet from synapse.replication.http.register import ReplicationRegisterServlet
from synapse.types import RoomAlias, RoomID, UserID, create_requester from synapse.types import RoomAlias, RoomID, UserID, create_requester
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
@ -64,6 +65,11 @@ class RegistrationHandler(BaseHandler):
if hs.config.worker_app: if hs.config.worker_app:
self._register_client = ReplicationRegisterServlet.make_client(hs) self._register_client = ReplicationRegisterServlet.make_client(hs)
self._register_device_client = (
RegisterDeviceReplicationServlet.make_client(hs)
)
else:
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def check_username(self, localpart, guest_access_token=None, def check_username(self, localpart, guest_access_token=None,
@ -159,7 +165,7 @@ class RegistrationHandler(BaseHandler):
yield self.auth.check_auth_blocking(threepid=threepid) yield self.auth.check_auth_blocking(threepid=threepid)
password_hash = None password_hash = None
if password: if password:
password_hash = yield self.auth_handler().hash(password) password_hash = yield self._auth_handler.hash(password)
if localpart: if localpart:
yield self.check_username(localpart, guest_access_token=guest_access_token) yield self.check_username(localpart, guest_access_token=guest_access_token)
@ -516,9 +522,6 @@ class RegistrationHandler(BaseHandler):
defer.returnValue((user_id, token)) defer.returnValue((user_id, token))
def auth_handler(self):
return self.hs.get_auth_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_or_register_3pid_guest(self, medium, address, inviter_user_id): def get_or_register_3pid_guest(self, medium, address, inviter_user_id):
"""Get a guest access token for a 3PID, creating a guest account if """Get a guest access token for a 3PID, creating a guest account if
@ -628,3 +631,43 @@ class RegistrationHandler(BaseHandler):
admin=admin, admin=admin,
user_type=user_type, user_type=user_type,
) )
@defer.inlineCallbacks
def register_device(self, user_id, device_id, initial_display_name,
is_guest=False):
"""Register a device for a user and generate an access token.
Args:
user_id (str): full canonical @user:id
device_id (str|None): The device ID to check, or None to generate
a new one.
initial_display_name (str|None): An optional display name for the
device.
is_guest (bool): Whether this is a guest account
Returns:
defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
"""
if self.hs.config.worker_app:
r = yield self._register_device_client(
user_id=user_id,
device_id=device_id,
initial_display_name=initial_display_name,
is_guest=is_guest,
)
defer.returnValue((r["device_id"], r["access_token"]))
else:
device_id = yield self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
if is_guest:
access_token = self.macaroon_gen.generate_access_token(
user_id, ["guest = true"]
)
else:
access_token = yield self._auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id,
)
defer.returnValue((device_id, access_token))

View File

@ -35,9 +35,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
def __init__(self, hs): def __init__(self, hs):
super(RegisterDeviceReplicationServlet, self).__init__(hs) super(RegisterDeviceReplicationServlet, self).__init__(hs)
self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_handlers().registration_handler
self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator()
@staticmethod @staticmethod
def _serialize_payload(user_id, device_id, initial_display_name, is_guest): def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
@ -62,17 +60,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
initial_display_name = content["initial_display_name"] initial_display_name = content["initial_display_name"]
is_guest = content["is_guest"] is_guest = content["is_guest"]
device_id = yield self.device_handler.check_device_registered( device_id, access_token = yield self.registration_handler.register_device(
user_id, device_id, initial_display_name, user_id, device_id, initial_display_name, is_guest,
)
if is_guest:
access_token = self.macaroon_gen.generate_access_token(
user_id, ["guest = true"]
)
else:
access_token = yield self.auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id,
) )
defer.returnValue((200, { defer.returnValue((200, {

View File

@ -94,7 +94,7 @@ class LoginRestServlet(ClientV1RestServlet):
self.jwt_algorithm = hs.config.jwt_algorithm self.jwt_algorithm = hs.config.jwt_algorithm
self.cas_enabled = hs.config.cas_enabled self.cas_enabled = hs.config.cas_enabled
self.auth_handler = self.hs.get_auth_handler() self.auth_handler = self.hs.get_auth_handler()
self.device_handler = self.hs.get_device_handler() self.registration_handler = hs.get_handlers().registration_handler
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self._well_known_builder = WellKnownBuilder(hs) self._well_known_builder = WellKnownBuilder(hs)
@ -220,11 +220,10 @@ class LoginRestServlet(ClientV1RestServlet):
login_submission, login_submission,
) )
device_id = yield self._register_device( device_id = login_submission.get("device_id")
canonical_user_id, login_submission, initial_display_name = login_submission.get("initial_device_display_name")
) device_id, access_token = yield self.registration_handler.register_device(
access_token = yield auth_handler.get_access_token_for_user_id( canonical_user_id, device_id, initial_display_name,
canonical_user_id, device_id,
) )
result = { result = {
@ -246,10 +245,13 @@ class LoginRestServlet(ClientV1RestServlet):
user_id = ( user_id = (
yield auth_handler.validate_short_term_login_token_and_get_user_id(token) yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
) )
device_id = yield self._register_device(user_id, login_submission)
access_token = yield auth_handler.get_access_token_for_user_id( device_id = login_submission.get("device_id")
user_id, device_id, initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
user_id, device_id, initial_display_name,
) )
result = { result = {
"user_id": user_id, # may have changed "user_id": user_id, # may have changed
"access_token": access_token, "access_token": access_token,
@ -286,11 +288,10 @@ class LoginRestServlet(ClientV1RestServlet):
auth_handler = self.auth_handler auth_handler = self.auth_handler
registered_user_id = yield auth_handler.check_user_exists(user_id) registered_user_id = yield auth_handler.check_user_exists(user_id)
if registered_user_id: if registered_user_id:
device_id = yield self._register_device( device_id = login_submission.get("device_id")
registered_user_id, login_submission initial_display_name = login_submission.get("initial_device_display_name")
) device_id, access_token = yield self.registration_handler.register_device(
access_token = yield auth_handler.get_access_token_for_user_id( registered_user_id, device_id, initial_display_name,
registered_user_id, device_id,
) )
result = { result = {
@ -299,12 +300,16 @@ class LoginRestServlet(ClientV1RestServlet):
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
} }
else: else:
# TODO: we should probably check that the register isn't going
# to fonx/change our user_id before registering the device
device_id = yield self._register_device(user_id, login_submission)
user_id, access_token = ( user_id, access_token = (
yield self.handlers.registration_handler.register(localpart=user) yield self.handlers.registration_handler.register(localpart=user)
) )
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
registered_user_id, device_id, initial_display_name,
)
result = { result = {
"user_id": user_id, # may have changed "user_id": user_id, # may have changed
"access_token": access_token, "access_token": access_token,
@ -313,26 +318,6 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue(result) defer.returnValue(result)
def _register_device(self, user_id, login_submission):
"""Register a device for a user.
This is called after the user's credentials have been validated, but
before the access token has been issued.
Args:
(str) user_id: full canonical @user:id
(object) login_submission: dictionary supplied to /login call, from
which we pull device_id and initial_device_name
Returns:
defer.Deferred: (str) device_id
"""
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get(
"initial_device_display_name")
return self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
class CasRedirectServlet(RestServlet): class CasRedirectServlet(RestServlet):
PATTERNS = client_path_patterns("/login/(cas|sso)/redirect") PATTERNS = client_path_patterns("/login/(cas|sso)/redirect")

View File

@ -33,7 +33,6 @@ from synapse.http.servlet import (
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
) )
from synapse.replication.http.login import RegisterDeviceReplicationServlet
from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.threepids import check_3pid_allowed from synapse.util.threepids import check_3pid_allowed
@ -193,13 +192,6 @@ class RegisterRestServlet(RestServlet):
self.room_member_handler = hs.get_room_member_handler() self.room_member_handler = hs.get_room_member_handler()
self.macaroon_gen = hs.get_macaroon_generator() self.macaroon_gen = hs.get_macaroon_generator()
if self.hs.config.worker_app:
self._register_device_client = (
RegisterDeviceReplicationServlet.make_client(hs)
)
else:
self.device_handler = hs.get_device_handler()
@interactive_auth_handler @interactive_auth_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -642,7 +634,7 @@ class RegisterRestServlet(RestServlet):
if not params.get("inhibit_login", False): if not params.get("inhibit_login", False):
device_id = params.get("device_id") device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name") initial_display_name = params.get("initial_device_display_name")
device_id, access_token = yield self._register_device( device_id, access_token = yield self.registration_handler.register_device(
user_id, device_id, initial_display_name, is_guest=False, user_id, device_id, initial_display_name, is_guest=False,
) )
@ -652,43 +644,6 @@ class RegisterRestServlet(RestServlet):
}) })
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks
def _register_device(self, user_id, device_id, initial_display_name,
is_guest):
"""Register a device for a user and generate an access token.
Args:
user_id (str): full canonical @user:id
device_id (str|None): The device ID to check, or None to generate
a new one.
initial_display_name (str|None): An optional display name for the
device.
is_guest (bool): Whether this is a guest account
Returns:
defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
"""
if self.hs.config.worker_app:
r = yield self._register_device_client(
user_id=user_id,
device_id=device_id,
initial_display_name=initial_display_name,
is_guest=is_guest,
)
defer.returnValue((r["device_id"], r["access_token"]))
else:
device_id = yield self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
if is_guest:
access_token = self.macaroon_gen.generate_access_token(
user_id, ["guest = true"]
)
else:
access_token = yield self.auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id,
)
defer.returnValue((device_id, access_token))
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_guest_registration(self, params): def _do_guest_registration(self, params):
if not self.hs.config.allow_guest_access: if not self.hs.config.allow_guest_access:
@ -702,7 +657,7 @@ class RegisterRestServlet(RestServlet):
# we have nowhere to store it. # we have nowhere to store it.
device_id = synapse.api.auth.GUEST_DEVICE_ID device_id = synapse.api.auth.GUEST_DEVICE_ID
initial_display_name = params.get("initial_device_display_name") initial_display_name = params.get("initial_device_display_name")
device_id, access_token = yield self._register_device( device_id, access_token = yield self.registration_handler.register_device(
user_id, device_id, initial_display_name, is_guest=True, user_id, device_id, initial_display_name, is_guest=True,
) )

View File

@ -1,10 +1,7 @@
import json import json
from mock import Mock from synapse.api.constants import LoginType
from synapse.appservice import ApplicationService
from twisted.python import failure
from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.rest.client.v2_alpha.register import register_servlets from synapse.rest.client.v2_alpha.register import register_servlets
from tests import unittest from tests import unittest
@ -18,61 +15,28 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.url = b"/_matrix/client/r0/register" self.url = b"/_matrix/client/r0/register"
self.appservice = None
self.auth = Mock(
get_appservice_by_req=Mock(side_effect=lambda x: self.appservice)
)
self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
self.auth_handler = Mock(
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
get_session_data=Mock(return_value=None),
)
self.registration_handler = Mock()
self.identity_handler = Mock()
self.login_handler = Mock()
self.device_handler = Mock()
def check_device_registered(user_id, device_id, initial_display_name):
# Just echo back the given device ID, or return a new "FAKE" device
# ID
if device_id:
return device_id
else:
return "FAKE"
self.device_handler.check_device_registered = Mock(
side_effect=check_device_registered,
)
self.datastore = Mock(return_value=Mock())
self.datastore.get_current_state_deltas = Mock(return_value=[])
# do the dance to hook it up to the hs global
self.handlers = Mock(
registration_handler=self.registration_handler,
identity_handler=self.identity_handler,
login_handler=self.login_handler,
)
self.hs = self.setup_test_homeserver() self.hs = self.setup_test_homeserver()
self.hs.get_auth = Mock(return_value=self.auth)
self.hs.get_handlers = Mock(return_value=self.handlers)
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
self.hs.get_device_handler = Mock(return_value=self.device_handler)
self.hs.get_datastore = Mock(return_value=self.datastore)
self.hs.config.enable_registration = True self.hs.config.enable_registration = True
self.hs.config.registrations_require_3pid = [] self.hs.config.registrations_require_3pid = []
self.hs.config.auto_join_rooms = [] self.hs.config.auto_join_rooms = []
self.hs.config.enable_registration_captcha = False
return self.hs return self.hs
def test_POST_appservice_registration_valid(self): def test_POST_appservice_registration_valid(self):
user_id = "@kermit:muppet" user_id = "@as_user_kermit:test"
token = "kermits_access_token" as_token = "i_am_an_app_service"
self.appservice = {"id": "1234"}
self.registration_handler.appservice_register = Mock(return_value=user_id) appservice = ApplicationService(
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token) as_token, self.hs.config.hostname,
request_data = json.dumps({"username": "kermit"}) id="1234",
namespaces={
"users": [{"regex": r"@as_user.*", "exclusive": True}],
},
)
self.hs.get_datastore().services_cache.append(appservice)
request_data = json.dumps({"username": "as_user_kermit"})
request, channel = self.make_request( request, channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
@ -82,7 +46,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
det_data = { det_data = {
"user_id": user_id, "user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
} }
self.assertDictContainsSubset(det_data, channel.json_body) self.assertDictContainsSubset(det_data, channel.json_body)
@ -114,37 +77,30 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.json_body["error"], "Invalid username") self.assertEquals(channel.json_body["error"], "Invalid username")
def test_POST_user_valid(self): def test_POST_user_valid(self):
user_id = "@kermit:muppet" user_id = "@kermit:test"
token = "kermits_access_token"
device_id = "frogfone" device_id = "frogfone"
params = {"username": "kermit", "password": "monkey", "device_id": device_id} params = {
"username": "kermit",
"password": "monkey",
"device_id": device_id,
"auth": {"type": LoginType.DUMMY},
}
request_data = json.dumps(params) request_data = json.dumps(params)
self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (None, params, None)
self.registration_handler.register = Mock(return_value=(user_id, None))
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
request, channel = self.make_request(b"POST", self.url, request_data) request, channel = self.make_request(b"POST", self.url, request_data)
self.render(request) self.render(request)
det_data = { det_data = {
"user_id": user_id, "user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"device_id": device_id, "device_id": device_id,
} }
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertDictContainsSubset(det_data, channel.json_body) self.assertDictContainsSubset(det_data, channel.json_body)
self.auth_handler.get_login_tuple_for_user_id(
user_id, device_id=device_id, initial_device_display_name=None
)
def test_POST_disabled_registration(self): def test_POST_disabled_registration(self):
self.hs.config.enable_registration = False self.hs.config.enable_registration = False
request_data = json.dumps({"username": "kermit", "password": "monkey"}) request_data = json.dumps({"username": "kermit", "password": "monkey"})
self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
self.registration_handler.register = Mock(return_value=("@user:id", "t"))
request, channel = self.make_request(b"POST", self.url, request_data) request, channel = self.make_request(b"POST", self.url, request_data)
self.render(request) self.render(request)
@ -153,16 +109,13 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.json_body["error"], "Registration has been disabled") self.assertEquals(channel.json_body["error"], "Registration has been disabled")
def test_POST_guest_registration(self): def test_POST_guest_registration(self):
user_id = "a@b"
self.hs.config.macaroon_secret_key = "test" self.hs.config.macaroon_secret_key = "test"
self.hs.config.allow_guest_access = True self.hs.config.allow_guest_access = True
self.registration_handler.register = Mock(return_value=(user_id, None))
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request) self.render(request)
det_data = { det_data = {
"user_id": user_id,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"device_id": "guest_device", "device_id": "guest_device",
} }