mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-01-07 16:27:53 -05:00
Move register_device into handler
This commit is contained in:
parent
8b9ae6d3a6
commit
af691e415c
@ -27,6 +27,7 @@ from synapse.api.errors import (
|
|||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
from synapse.http.client import CaptchaServerHttpClient
|
from synapse.http.client import CaptchaServerHttpClient
|
||||||
|
from synapse.replication.http.login import RegisterDeviceReplicationServlet
|
||||||
from synapse.replication.http.register import ReplicationRegisterServlet
|
from synapse.replication.http.register import ReplicationRegisterServlet
|
||||||
from synapse.types import RoomAlias, RoomID, UserID, create_requester
|
from synapse.types import RoomAlias, RoomID, UserID, create_requester
|
||||||
from synapse.util.async_helpers import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
@ -64,6 +65,11 @@ class RegistrationHandler(BaseHandler):
|
|||||||
|
|
||||||
if hs.config.worker_app:
|
if hs.config.worker_app:
|
||||||
self._register_client = ReplicationRegisterServlet.make_client(hs)
|
self._register_client = ReplicationRegisterServlet.make_client(hs)
|
||||||
|
self._register_device_client = (
|
||||||
|
RegisterDeviceReplicationServlet.make_client(hs)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.device_handler = hs.get_device_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_username(self, localpart, guest_access_token=None,
|
def check_username(self, localpart, guest_access_token=None,
|
||||||
@ -159,7 +165,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
yield self.auth.check_auth_blocking(threepid=threepid)
|
yield self.auth.check_auth_blocking(threepid=threepid)
|
||||||
password_hash = None
|
password_hash = None
|
||||||
if password:
|
if password:
|
||||||
password_hash = yield self.auth_handler().hash(password)
|
password_hash = yield self._auth_handler.hash(password)
|
||||||
|
|
||||||
if localpart:
|
if localpart:
|
||||||
yield self.check_username(localpart, guest_access_token=guest_access_token)
|
yield self.check_username(localpart, guest_access_token=guest_access_token)
|
||||||
@ -516,9 +522,6 @@ class RegistrationHandler(BaseHandler):
|
|||||||
|
|
||||||
defer.returnValue((user_id, token))
|
defer.returnValue((user_id, token))
|
||||||
|
|
||||||
def auth_handler(self):
|
|
||||||
return self.hs.get_auth_handler()
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_or_register_3pid_guest(self, medium, address, inviter_user_id):
|
def get_or_register_3pid_guest(self, medium, address, inviter_user_id):
|
||||||
"""Get a guest access token for a 3PID, creating a guest account if
|
"""Get a guest access token for a 3PID, creating a guest account if
|
||||||
@ -628,3 +631,43 @@ class RegistrationHandler(BaseHandler):
|
|||||||
admin=admin,
|
admin=admin,
|
||||||
user_type=user_type,
|
user_type=user_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def register_device(self, user_id, device_id, initial_display_name,
|
||||||
|
is_guest=False):
|
||||||
|
"""Register a device for a user and generate an access token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): full canonical @user:id
|
||||||
|
device_id (str|None): The device ID to check, or None to generate
|
||||||
|
a new one.
|
||||||
|
initial_display_name (str|None): An optional display name for the
|
||||||
|
device.
|
||||||
|
is_guest (bool): Whether this is a guest account
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.hs.config.worker_app:
|
||||||
|
r = yield self._register_device_client(
|
||||||
|
user_id=user_id,
|
||||||
|
device_id=device_id,
|
||||||
|
initial_display_name=initial_display_name,
|
||||||
|
is_guest=is_guest,
|
||||||
|
)
|
||||||
|
defer.returnValue((r["device_id"], r["access_token"]))
|
||||||
|
else:
|
||||||
|
device_id = yield self.device_handler.check_device_registered(
|
||||||
|
user_id, device_id, initial_display_name
|
||||||
|
)
|
||||||
|
if is_guest:
|
||||||
|
access_token = self.macaroon_gen.generate_access_token(
|
||||||
|
user_id, ["guest = true"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
access_token = yield self._auth_handler.get_access_token_for_user_id(
|
||||||
|
user_id, device_id=device_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((device_id, access_token))
|
||||||
|
@ -35,9 +35,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
|
|||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(RegisterDeviceReplicationServlet, self).__init__(hs)
|
super(RegisterDeviceReplicationServlet, self).__init__(hs)
|
||||||
self.auth_handler = hs.get_auth_handler()
|
self.registration_handler = hs.get_handlers().registration_handler
|
||||||
self.device_handler = hs.get_device_handler()
|
|
||||||
self.macaroon_gen = hs.get_macaroon_generator()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
|
def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
|
||||||
@ -62,17 +60,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
|
|||||||
initial_display_name = content["initial_display_name"]
|
initial_display_name = content["initial_display_name"]
|
||||||
is_guest = content["is_guest"]
|
is_guest = content["is_guest"]
|
||||||
|
|
||||||
device_id = yield self.device_handler.check_device_registered(
|
device_id, access_token = yield self.registration_handler.register_device(
|
||||||
user_id, device_id, initial_display_name,
|
user_id, device_id, initial_display_name, is_guest,
|
||||||
)
|
|
||||||
|
|
||||||
if is_guest:
|
|
||||||
access_token = self.macaroon_gen.generate_access_token(
|
|
||||||
user_id, ["guest = true"]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
access_token = yield self.auth_handler.get_access_token_for_user_id(
|
|
||||||
user_id, device_id=device_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((200, {
|
defer.returnValue((200, {
|
||||||
|
@ -94,7 +94,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||||||
self.jwt_algorithm = hs.config.jwt_algorithm
|
self.jwt_algorithm = hs.config.jwt_algorithm
|
||||||
self.cas_enabled = hs.config.cas_enabled
|
self.cas_enabled = hs.config.cas_enabled
|
||||||
self.auth_handler = self.hs.get_auth_handler()
|
self.auth_handler = self.hs.get_auth_handler()
|
||||||
self.device_handler = self.hs.get_device_handler()
|
self.registration_handler = hs.get_handlers().registration_handler
|
||||||
self.handlers = hs.get_handlers()
|
self.handlers = hs.get_handlers()
|
||||||
self._well_known_builder = WellKnownBuilder(hs)
|
self._well_known_builder = WellKnownBuilder(hs)
|
||||||
|
|
||||||
@ -220,11 +220,10 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||||||
login_submission,
|
login_submission,
|
||||||
)
|
)
|
||||||
|
|
||||||
device_id = yield self._register_device(
|
device_id = login_submission.get("device_id")
|
||||||
canonical_user_id, login_submission,
|
initial_display_name = login_submission.get("initial_device_display_name")
|
||||||
)
|
device_id, access_token = yield self.registration_handler.register_device(
|
||||||
access_token = yield auth_handler.get_access_token_for_user_id(
|
canonical_user_id, device_id, initial_display_name,
|
||||||
canonical_user_id, device_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
@ -246,10 +245,13 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||||||
user_id = (
|
user_id = (
|
||||||
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
||||||
)
|
)
|
||||||
device_id = yield self._register_device(user_id, login_submission)
|
|
||||||
access_token = yield auth_handler.get_access_token_for_user_id(
|
device_id = login_submission.get("device_id")
|
||||||
user_id, device_id,
|
initial_display_name = login_submission.get("initial_device_display_name")
|
||||||
|
device_id, access_token = yield self.registration_handler.register_device(
|
||||||
|
user_id, device_id, initial_display_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"user_id": user_id, # may have changed
|
"user_id": user_id, # may have changed
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
@ -286,11 +288,10 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||||||
auth_handler = self.auth_handler
|
auth_handler = self.auth_handler
|
||||||
registered_user_id = yield auth_handler.check_user_exists(user_id)
|
registered_user_id = yield auth_handler.check_user_exists(user_id)
|
||||||
if registered_user_id:
|
if registered_user_id:
|
||||||
device_id = yield self._register_device(
|
device_id = login_submission.get("device_id")
|
||||||
registered_user_id, login_submission
|
initial_display_name = login_submission.get("initial_device_display_name")
|
||||||
)
|
device_id, access_token = yield self.registration_handler.register_device(
|
||||||
access_token = yield auth_handler.get_access_token_for_user_id(
|
registered_user_id, device_id, initial_display_name,
|
||||||
registered_user_id, device_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
@ -299,12 +300,16 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
# TODO: we should probably check that the register isn't going
|
|
||||||
# to fonx/change our user_id before registering the device
|
|
||||||
device_id = yield self._register_device(user_id, login_submission)
|
|
||||||
user_id, access_token = (
|
user_id, access_token = (
|
||||||
yield self.handlers.registration_handler.register(localpart=user)
|
yield self.handlers.registration_handler.register(localpart=user)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
device_id = login_submission.get("device_id")
|
||||||
|
initial_display_name = login_submission.get("initial_device_display_name")
|
||||||
|
device_id, access_token = yield self.registration_handler.register_device(
|
||||||
|
registered_user_id, device_id, initial_display_name,
|
||||||
|
)
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"user_id": user_id, # may have changed
|
"user_id": user_id, # may have changed
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
@ -313,26 +318,6 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||||||
|
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
def _register_device(self, user_id, login_submission):
|
|
||||||
"""Register a device for a user.
|
|
||||||
|
|
||||||
This is called after the user's credentials have been validated, but
|
|
||||||
before the access token has been issued.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
(str) user_id: full canonical @user:id
|
|
||||||
(object) login_submission: dictionary supplied to /login call, from
|
|
||||||
which we pull device_id and initial_device_name
|
|
||||||
Returns:
|
|
||||||
defer.Deferred: (str) device_id
|
|
||||||
"""
|
|
||||||
device_id = login_submission.get("device_id")
|
|
||||||
initial_display_name = login_submission.get(
|
|
||||||
"initial_device_display_name")
|
|
||||||
return self.device_handler.check_device_registered(
|
|
||||||
user_id, device_id, initial_display_name
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CasRedirectServlet(RestServlet):
|
class CasRedirectServlet(RestServlet):
|
||||||
PATTERNS = client_path_patterns("/login/(cas|sso)/redirect")
|
PATTERNS = client_path_patterns("/login/(cas|sso)/redirect")
|
||||||
|
@ -33,7 +33,6 @@ from synapse.http.servlet import (
|
|||||||
parse_json_object_from_request,
|
parse_json_object_from_request,
|
||||||
parse_string,
|
parse_string,
|
||||||
)
|
)
|
||||||
from synapse.replication.http.login import RegisterDeviceReplicationServlet
|
|
||||||
from synapse.util.msisdn import phone_number_to_msisdn
|
from synapse.util.msisdn import phone_number_to_msisdn
|
||||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||||
from synapse.util.threepids import check_3pid_allowed
|
from synapse.util.threepids import check_3pid_allowed
|
||||||
@ -193,13 +192,6 @@ class RegisterRestServlet(RestServlet):
|
|||||||
self.room_member_handler = hs.get_room_member_handler()
|
self.room_member_handler = hs.get_room_member_handler()
|
||||||
self.macaroon_gen = hs.get_macaroon_generator()
|
self.macaroon_gen = hs.get_macaroon_generator()
|
||||||
|
|
||||||
if self.hs.config.worker_app:
|
|
||||||
self._register_device_client = (
|
|
||||||
RegisterDeviceReplicationServlet.make_client(hs)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.device_handler = hs.get_device_handler()
|
|
||||||
|
|
||||||
@interactive_auth_handler
|
@interactive_auth_handler
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
@ -642,7 +634,7 @@ class RegisterRestServlet(RestServlet):
|
|||||||
if not params.get("inhibit_login", False):
|
if not params.get("inhibit_login", False):
|
||||||
device_id = params.get("device_id")
|
device_id = params.get("device_id")
|
||||||
initial_display_name = params.get("initial_device_display_name")
|
initial_display_name = params.get("initial_device_display_name")
|
||||||
device_id, access_token = yield self._register_device(
|
device_id, access_token = yield self.registration_handler.register_device(
|
||||||
user_id, device_id, initial_display_name, is_guest=False,
|
user_id, device_id, initial_display_name, is_guest=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -652,43 +644,6 @@ class RegisterRestServlet(RestServlet):
|
|||||||
})
|
})
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _register_device(self, user_id, device_id, initial_display_name,
|
|
||||||
is_guest):
|
|
||||||
"""Register a device for a user and generate an access token.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id (str): full canonical @user:id
|
|
||||||
device_id (str|None): The device ID to check, or None to generate
|
|
||||||
a new one.
|
|
||||||
initial_display_name (str|None): An optional display name for the
|
|
||||||
device.
|
|
||||||
is_guest (bool): Whether this is a guest account
|
|
||||||
Returns:
|
|
||||||
defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
|
|
||||||
"""
|
|
||||||
if self.hs.config.worker_app:
|
|
||||||
r = yield self._register_device_client(
|
|
||||||
user_id=user_id,
|
|
||||||
device_id=device_id,
|
|
||||||
initial_display_name=initial_display_name,
|
|
||||||
is_guest=is_guest,
|
|
||||||
)
|
|
||||||
defer.returnValue((r["device_id"], r["access_token"]))
|
|
||||||
else:
|
|
||||||
device_id = yield self.device_handler.check_device_registered(
|
|
||||||
user_id, device_id, initial_display_name
|
|
||||||
)
|
|
||||||
if is_guest:
|
|
||||||
access_token = self.macaroon_gen.generate_access_token(
|
|
||||||
user_id, ["guest = true"]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
access_token = yield self.auth_handler.get_access_token_for_user_id(
|
|
||||||
user_id, device_id=device_id,
|
|
||||||
)
|
|
||||||
defer.returnValue((device_id, access_token))
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _do_guest_registration(self, params):
|
def _do_guest_registration(self, params):
|
||||||
if not self.hs.config.allow_guest_access:
|
if not self.hs.config.allow_guest_access:
|
||||||
@ -702,7 +657,7 @@ class RegisterRestServlet(RestServlet):
|
|||||||
# we have nowhere to store it.
|
# we have nowhere to store it.
|
||||||
device_id = synapse.api.auth.GUEST_DEVICE_ID
|
device_id = synapse.api.auth.GUEST_DEVICE_ID
|
||||||
initial_display_name = params.get("initial_device_display_name")
|
initial_display_name = params.get("initial_device_display_name")
|
||||||
device_id, access_token = yield self._register_device(
|
device_id, access_token = yield self.registration_handler.register_device(
|
||||||
user_id, device_id, initial_display_name, is_guest=True,
|
user_id, device_id, initial_display_name, is_guest=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,10 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
from mock import Mock
|
from synapse.api.constants import LoginType
|
||||||
|
from synapse.appservice import ApplicationService
|
||||||
from twisted.python import failure
|
|
||||||
|
|
||||||
from synapse.api.errors import InteractiveAuthIncompleteError
|
|
||||||
from synapse.rest.client.v2_alpha.register import register_servlets
|
from synapse.rest.client.v2_alpha.register import register_servlets
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
@ -18,61 +15,28 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
self.url = b"/_matrix/client/r0/register"
|
self.url = b"/_matrix/client/r0/register"
|
||||||
|
|
||||||
self.appservice = None
|
|
||||||
self.auth = Mock(
|
|
||||||
get_appservice_by_req=Mock(side_effect=lambda x: self.appservice)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
|
|
||||||
self.auth_handler = Mock(
|
|
||||||
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
|
|
||||||
get_session_data=Mock(return_value=None),
|
|
||||||
)
|
|
||||||
self.registration_handler = Mock()
|
|
||||||
self.identity_handler = Mock()
|
|
||||||
self.login_handler = Mock()
|
|
||||||
self.device_handler = Mock()
|
|
||||||
|
|
||||||
def check_device_registered(user_id, device_id, initial_display_name):
|
|
||||||
# Just echo back the given device ID, or return a new "FAKE" device
|
|
||||||
# ID
|
|
||||||
if device_id:
|
|
||||||
return device_id
|
|
||||||
else:
|
|
||||||
return "FAKE"
|
|
||||||
|
|
||||||
self.device_handler.check_device_registered = Mock(
|
|
||||||
side_effect=check_device_registered,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.datastore = Mock(return_value=Mock())
|
|
||||||
self.datastore.get_current_state_deltas = Mock(return_value=[])
|
|
||||||
|
|
||||||
# do the dance to hook it up to the hs global
|
|
||||||
self.handlers = Mock(
|
|
||||||
registration_handler=self.registration_handler,
|
|
||||||
identity_handler=self.identity_handler,
|
|
||||||
login_handler=self.login_handler,
|
|
||||||
)
|
|
||||||
self.hs = self.setup_test_homeserver()
|
self.hs = self.setup_test_homeserver()
|
||||||
self.hs.get_auth = Mock(return_value=self.auth)
|
|
||||||
self.hs.get_handlers = Mock(return_value=self.handlers)
|
|
||||||
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
|
|
||||||
self.hs.get_device_handler = Mock(return_value=self.device_handler)
|
|
||||||
self.hs.get_datastore = Mock(return_value=self.datastore)
|
|
||||||
self.hs.config.enable_registration = True
|
self.hs.config.enable_registration = True
|
||||||
self.hs.config.registrations_require_3pid = []
|
self.hs.config.registrations_require_3pid = []
|
||||||
self.hs.config.auto_join_rooms = []
|
self.hs.config.auto_join_rooms = []
|
||||||
|
self.hs.config.enable_registration_captcha = False
|
||||||
|
|
||||||
return self.hs
|
return self.hs
|
||||||
|
|
||||||
def test_POST_appservice_registration_valid(self):
|
def test_POST_appservice_registration_valid(self):
|
||||||
user_id = "@kermit:muppet"
|
user_id = "@as_user_kermit:test"
|
||||||
token = "kermits_access_token"
|
as_token = "i_am_an_app_service"
|
||||||
self.appservice = {"id": "1234"}
|
|
||||||
self.registration_handler.appservice_register = Mock(return_value=user_id)
|
appservice = ApplicationService(
|
||||||
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
|
as_token, self.hs.config.hostname,
|
||||||
request_data = json.dumps({"username": "kermit"})
|
id="1234",
|
||||||
|
namespaces={
|
||||||
|
"users": [{"regex": r"@as_user.*", "exclusive": True}],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.hs.get_datastore().services_cache.append(appservice)
|
||||||
|
request_data = json.dumps({"username": "as_user_kermit"})
|
||||||
|
|
||||||
request, channel = self.make_request(
|
request, channel = self.make_request(
|
||||||
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
|
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
|
||||||
@ -82,7 +46,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
|||||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||||
det_data = {
|
det_data = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"access_token": token,
|
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
}
|
}
|
||||||
self.assertDictContainsSubset(det_data, channel.json_body)
|
self.assertDictContainsSubset(det_data, channel.json_body)
|
||||||
@ -114,37 +77,30 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
|||||||
self.assertEquals(channel.json_body["error"], "Invalid username")
|
self.assertEquals(channel.json_body["error"], "Invalid username")
|
||||||
|
|
||||||
def test_POST_user_valid(self):
|
def test_POST_user_valid(self):
|
||||||
user_id = "@kermit:muppet"
|
user_id = "@kermit:test"
|
||||||
token = "kermits_access_token"
|
|
||||||
device_id = "frogfone"
|
device_id = "frogfone"
|
||||||
params = {"username": "kermit", "password": "monkey", "device_id": device_id}
|
params = {
|
||||||
|
"username": "kermit",
|
||||||
|
"password": "monkey",
|
||||||
|
"device_id": device_id,
|
||||||
|
"auth": {"type": LoginType.DUMMY},
|
||||||
|
}
|
||||||
request_data = json.dumps(params)
|
request_data = json.dumps(params)
|
||||||
self.registration_handler.check_username = Mock(return_value=True)
|
|
||||||
self.auth_result = (None, params, None)
|
|
||||||
self.registration_handler.register = Mock(return_value=(user_id, None))
|
|
||||||
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
|
|
||||||
|
|
||||||
request, channel = self.make_request(b"POST", self.url, request_data)
|
request, channel = self.make_request(b"POST", self.url, request_data)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
|
|
||||||
det_data = {
|
det_data = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"access_token": token,
|
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
}
|
}
|
||||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||||
self.assertDictContainsSubset(det_data, channel.json_body)
|
self.assertDictContainsSubset(det_data, channel.json_body)
|
||||||
self.auth_handler.get_login_tuple_for_user_id(
|
|
||||||
user_id, device_id=device_id, initial_device_display_name=None
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_POST_disabled_registration(self):
|
def test_POST_disabled_registration(self):
|
||||||
self.hs.config.enable_registration = False
|
self.hs.config.enable_registration = False
|
||||||
request_data = json.dumps({"username": "kermit", "password": "monkey"})
|
request_data = json.dumps({"username": "kermit", "password": "monkey"})
|
||||||
self.registration_handler.check_username = Mock(return_value=True)
|
|
||||||
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
|
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
|
||||||
self.registration_handler.register = Mock(return_value=("@user:id", "t"))
|
|
||||||
|
|
||||||
request, channel = self.make_request(b"POST", self.url, request_data)
|
request, channel = self.make_request(b"POST", self.url, request_data)
|
||||||
self.render(request)
|
self.render(request)
|
||||||
@ -153,16 +109,13 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
|||||||
self.assertEquals(channel.json_body["error"], "Registration has been disabled")
|
self.assertEquals(channel.json_body["error"], "Registration has been disabled")
|
||||||
|
|
||||||
def test_POST_guest_registration(self):
|
def test_POST_guest_registration(self):
|
||||||
user_id = "a@b"
|
|
||||||
self.hs.config.macaroon_secret_key = "test"
|
self.hs.config.macaroon_secret_key = "test"
|
||||||
self.hs.config.allow_guest_access = True
|
self.hs.config.allow_guest_access = True
|
||||||
self.registration_handler.register = Mock(return_value=(user_id, None))
|
|
||||||
|
|
||||||
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
||||||
self.render(request)
|
self.render(request)
|
||||||
|
|
||||||
det_data = {
|
det_data = {
|
||||||
"user_id": user_id,
|
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
"device_id": "guest_device",
|
"device_id": "guest_device",
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user