mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2024-10-01 11:49:51 -04:00
Create user with expiry
- Add unittests for client, api and handler Signed-off-by: Negar Fazeli <negar.fazeli@ericsson.com>
This commit is contained in:
parent
ae1af262f6
commit
40aa6e8349
@ -612,7 +612,8 @@ class Auth(object):
|
|||||||
def get_user_from_macaroon(self, macaroon_str):
|
def get_user_from_macaroon(self, macaroon_str):
|
||||||
try:
|
try:
|
||||||
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
|
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
|
||||||
self.validate_macaroon(macaroon, "access", False)
|
|
||||||
|
self.validate_macaroon(macaroon, "access", self.hs.config.expire_access_token)
|
||||||
|
|
||||||
user_prefix = "user_id = "
|
user_prefix = "user_id = "
|
||||||
user = None
|
user = None
|
||||||
|
@ -57,6 +57,8 @@ class KeyConfig(Config):
|
|||||||
seed = self.signing_key[0].seed
|
seed = self.signing_key[0].seed
|
||||||
self.macaroon_secret_key = hashlib.sha256(seed)
|
self.macaroon_secret_key = hashlib.sha256(seed)
|
||||||
|
|
||||||
|
self.expire_access_token = config.get("expire_access_token", False)
|
||||||
|
|
||||||
def default_config(self, config_dir_path, server_name, is_generating_file=False,
|
def default_config(self, config_dir_path, server_name, is_generating_file=False,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
base_key_name = os.path.join(config_dir_path, server_name)
|
base_key_name = os.path.join(config_dir_path, server_name)
|
||||||
@ -69,6 +71,9 @@ class KeyConfig(Config):
|
|||||||
return """\
|
return """\
|
||||||
macaroon_secret_key: "%(macaroon_secret_key)s"
|
macaroon_secret_key: "%(macaroon_secret_key)s"
|
||||||
|
|
||||||
|
# Used to enable access token expiration.
|
||||||
|
expire_access_token: False
|
||||||
|
|
||||||
## Signing Keys ##
|
## Signing Keys ##
|
||||||
|
|
||||||
# Path to the signing key to sign messages with
|
# Path to the signing key to sign messages with
|
||||||
|
@ -32,6 +32,7 @@ class RegistrationConfig(Config):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.registration_shared_secret = config.get("registration_shared_secret")
|
self.registration_shared_secret = config.get("registration_shared_secret")
|
||||||
|
self.user_creation_max_duration = int(config["user_creation_max_duration"])
|
||||||
|
|
||||||
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
||||||
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
|
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
|
||||||
@ -54,6 +55,11 @@ class RegistrationConfig(Config):
|
|||||||
# secret, even if registration is otherwise disabled.
|
# secret, even if registration is otherwise disabled.
|
||||||
registration_shared_secret: "%(registration_shared_secret)s"
|
registration_shared_secret: "%(registration_shared_secret)s"
|
||||||
|
|
||||||
|
# Sets the expiry for the short term user creation in
|
||||||
|
# milliseconds. For instance the bellow duration is two weeks
|
||||||
|
# in milliseconds.
|
||||||
|
user_creation_max_duration: 1209600000
|
||||||
|
|
||||||
# Set the number of bcrypt rounds used to generate password hash.
|
# Set the number of bcrypt rounds used to generate password hash.
|
||||||
# Larger numbers increase the work factor needed to generate the hash.
|
# Larger numbers increase the work factor needed to generate the hash.
|
||||||
# The default number of rounds is 12.
|
# The default number of rounds is 12.
|
||||||
|
@ -521,11 +521,11 @@ class AuthHandler(BaseHandler):
|
|||||||
))
|
))
|
||||||
return m.serialize()
|
return m.serialize()
|
||||||
|
|
||||||
def generate_short_term_login_token(self, user_id):
|
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
|
||||||
macaroon = self._generate_base_macaroon(user_id)
|
macaroon = self._generate_base_macaroon(user_id)
|
||||||
macaroon.add_first_party_caveat("type = login")
|
macaroon.add_first_party_caveat("type = login")
|
||||||
now = self.hs.get_clock().time_msec()
|
now = self.hs.get_clock().time_msec()
|
||||||
expiry = now + (2 * 60 * 1000)
|
expiry = now + duration_in_ms
|
||||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||||
return macaroon.serialize()
|
return macaroon.serialize()
|
||||||
|
|
||||||
|
@ -358,6 +358,59 @@ class RegistrationHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
defer.returnValue(data)
|
defer.returnValue(data)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_or_create_user(self, localpart, displayname, duration_seconds):
|
||||||
|
"""Creates a new user or returns an access token for an existing one
|
||||||
|
|
||||||
|
Args:
|
||||||
|
localpart : The local part of the user ID to register. If None,
|
||||||
|
one will be randomly generated.
|
||||||
|
Returns:
|
||||||
|
A tuple of (user_id, access_token).
|
||||||
|
Raises:
|
||||||
|
RegistrationError if there was a problem registering.
|
||||||
|
"""
|
||||||
|
yield run_on_reactor()
|
||||||
|
|
||||||
|
if localpart is None:
|
||||||
|
raise SynapseError(400, "Request must include user id")
|
||||||
|
|
||||||
|
need_register = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield self.check_username(localpart)
|
||||||
|
except SynapseError as e:
|
||||||
|
if e.errcode == Codes.USER_IN_USE:
|
||||||
|
need_register = False
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
user = UserID(localpart, self.hs.hostname)
|
||||||
|
user_id = user.to_string()
|
||||||
|
auth_handler = self.hs.get_handlers().auth_handler
|
||||||
|
token = auth_handler.generate_short_term_login_token(user_id, duration_seconds)
|
||||||
|
|
||||||
|
if need_register:
|
||||||
|
yield self.store.register(
|
||||||
|
user_id=user_id,
|
||||||
|
token=token,
|
||||||
|
password_hash=None
|
||||||
|
)
|
||||||
|
|
||||||
|
yield registered_user(self.distributor, user)
|
||||||
|
else:
|
||||||
|
yield self.store.flush_user(user_id=user_id)
|
||||||
|
yield self.store.add_access_token_to_user(user_id=user_id, token=token)
|
||||||
|
|
||||||
|
if displayname is not None:
|
||||||
|
logger.info("setting user display name: %s -> %s", user_id, displayname)
|
||||||
|
profile_handler = self.hs.get_handlers().profile_handler
|
||||||
|
yield profile_handler.set_displayname(
|
||||||
|
user, user, displayname
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((user_id, token))
|
||||||
|
|
||||||
def auth_handler(self):
|
def auth_handler(self):
|
||||||
return self.hs.get_handlers().auth_handler
|
return self.hs.get_handlers().auth_handler
|
||||||
|
|
||||||
|
@ -355,5 +355,76 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CreateUserRestServlet(ClientV1RestServlet):
|
||||||
|
"""Handles user creation via a server-to-server interface
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATTERNS = client_path_patterns("/createUser$", releases=())
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(CreateUserRestServlet, self).__init__(hs)
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request):
|
||||||
|
user_json = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
if "access_token" not in request.args:
|
||||||
|
raise SynapseError(400, "Expected application service token.")
|
||||||
|
|
||||||
|
app_service = yield self.store.get_app_service_by_token(
|
||||||
|
request.args["access_token"][0]
|
||||||
|
)
|
||||||
|
if not app_service:
|
||||||
|
raise SynapseError(403, "Invalid application service token.")
|
||||||
|
|
||||||
|
logger.debug("creating user: %s", user_json)
|
||||||
|
|
||||||
|
response = yield self._do_create(user_json)
|
||||||
|
|
||||||
|
defer.returnValue((200, response))
|
||||||
|
|
||||||
|
def on_OPTIONS(self, request):
|
||||||
|
return 403, {}
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _do_create(self, user_json):
|
||||||
|
yield run_on_reactor()
|
||||||
|
|
||||||
|
if "localpart" not in user_json:
|
||||||
|
raise SynapseError(400, "Expected 'localpart' key.")
|
||||||
|
|
||||||
|
if "displayname" not in user_json:
|
||||||
|
raise SynapseError(400, "Expected 'displayname' key.")
|
||||||
|
|
||||||
|
if "duration_seconds" not in user_json:
|
||||||
|
raise SynapseError(400, "Expected 'duration_seconds' key.")
|
||||||
|
|
||||||
|
localpart = user_json["localpart"].encode("utf-8")
|
||||||
|
displayname = user_json["displayname"].encode("utf-8")
|
||||||
|
duration_seconds = 0
|
||||||
|
try:
|
||||||
|
duration_seconds = int(user_json["duration_seconds"])
|
||||||
|
except ValueError:
|
||||||
|
raise SynapseError(400, "Failed to parse 'duration_seconds'")
|
||||||
|
if duration_seconds > self.direct_user_creation_max_duration:
|
||||||
|
duration_seconds = self.direct_user_creation_max_duration
|
||||||
|
|
||||||
|
handler = self.handlers.registration_handler
|
||||||
|
user_id, token = yield handler.get_or_create_user(
|
||||||
|
localpart=localpart,
|
||||||
|
displayname=displayname,
|
||||||
|
duration_seconds=duration_seconds
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue({
|
||||||
|
"user_id": user_id,
|
||||||
|
"access_token": token,
|
||||||
|
"home_server": self.hs.hostname,
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
RegisterRestServlet(hs).register(http_server)
|
RegisterRestServlet(hs).register(http_server)
|
||||||
|
CreateUserRestServlet(hs).register(http_server)
|
||||||
|
@ -284,12 +284,12 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
macaroon.add_first_party_caveat("time < 1") # ms
|
macaroon.add_first_party_caveat("time < 1") # ms
|
||||||
|
|
||||||
self.hs.clock.now = 5000 # seconds
|
self.hs.clock.now = 5000 # seconds
|
||||||
|
self.hs.config.expire_access_token = True
|
||||||
yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
# yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
||||||
# TODO(daniel): Turn on the check that we validate expiration, when we
|
# TODO(daniel): Turn on the check that we validate expiration, when we
|
||||||
# validate expiration (and remove the above line, which will start
|
# validate expiration (and remove the above line, which will start
|
||||||
# throwing).
|
# throwing).
|
||||||
# with self.assertRaises(AuthError) as cm:
|
with self.assertRaises(AuthError) as cm:
|
||||||
# yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
yield self.auth.get_user_from_macaroon(macaroon.serialize())
|
||||||
# self.assertEqual(401, cm.exception.code)
|
self.assertEqual(401, cm.exception.code)
|
||||||
# self.assertIn("Invalid macaroon", cm.exception.msg)
|
self.assertIn("Invalid macaroon", cm.exception.msg)
|
||||||
|
67
tests/handlers/test_register.py
Normal file
67
tests/handlers/test_register.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2015, 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
from .. import unittest
|
||||||
|
|
||||||
|
from synapse.handlers.register import RegistrationHandler
|
||||||
|
|
||||||
|
from tests.utils import setup_test_homeserver
|
||||||
|
|
||||||
|
from mock import Mock
|
||||||
|
|
||||||
|
|
||||||
|
class RegistrationHandlers(object):
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.registration_handler = RegistrationHandler(hs)
|
||||||
|
|
||||||
|
|
||||||
|
class RegistrationTestCase(unittest.TestCase):
|
||||||
|
""" Tests the RegistrationHandler. """
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def setUp(self):
|
||||||
|
self.mock_distributor = Mock()
|
||||||
|
self.mock_distributor.declare("registered_user")
|
||||||
|
self.mock_captcha_client = Mock()
|
||||||
|
hs = yield setup_test_homeserver(
|
||||||
|
handlers=None,
|
||||||
|
http_client=None,
|
||||||
|
expire_access_token=True)
|
||||||
|
hs.handlers = RegistrationHandlers(hs)
|
||||||
|
self.handler = hs.get_handlers().registration_handler
|
||||||
|
hs.get_handlers().profile_handler = Mock()
|
||||||
|
self.mock_handler = Mock(spec=[
|
||||||
|
"generate_short_term_login_token",
|
||||||
|
])
|
||||||
|
|
||||||
|
hs.get_handlers().auth_handler = self.mock_handler
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
The user doess not exist in this case so it will register and log it in
|
||||||
|
"""
|
||||||
|
duration_ms = 200
|
||||||
|
local_part = "someone"
|
||||||
|
display_name = "someone"
|
||||||
|
user_id = "@someone:test"
|
||||||
|
mock_token = self.mock_handler.generate_short_term_login_token
|
||||||
|
mock_token.return_value = 'secret'
|
||||||
|
result_user_id, result_token = yield self.handler.get_or_create_user(
|
||||||
|
local_part, display_name, duration_ms)
|
||||||
|
self.assertEquals(result_user_id, user_id)
|
||||||
|
self.assertEquals(result_token, 'secret')
|
88
tests/rest/client/v1/test_register.py
Normal file
88
tests/rest/client/v1/test_register.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2015, 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from synapse.rest.client.v1.register import CreateUserRestServlet
|
||||||
|
from twisted.internet import defer
|
||||||
|
from mock import Mock
|
||||||
|
from tests import unittest
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
class CreateUserServletTestCase(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
# do the dance to hook up request data to self.request_data
|
||||||
|
self.request_data = ""
|
||||||
|
self.request = Mock(
|
||||||
|
content=Mock(read=Mock(side_effect=lambda: self.request_data)),
|
||||||
|
path='/_matrix/client/api/v1/createUser'
|
||||||
|
)
|
||||||
|
self.request.args = {}
|
||||||
|
|
||||||
|
self.appservice = None
|
||||||
|
self.auth = Mock(get_appservice_by_req=Mock(
|
||||||
|
side_effect=lambda x: defer.succeed(self.appservice))
|
||||||
|
)
|
||||||
|
|
||||||
|
self.auth_result = (False, None, None, None)
|
||||||
|
self.auth_handler = Mock(
|
||||||
|
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
|
||||||
|
get_session_data=Mock(return_value=None)
|
||||||
|
)
|
||||||
|
self.registration_handler = Mock()
|
||||||
|
self.identity_handler = Mock()
|
||||||
|
self.login_handler = Mock()
|
||||||
|
|
||||||
|
# do the dance to hook it up to the hs global
|
||||||
|
self.handlers = Mock(
|
||||||
|
auth_handler=self.auth_handler,
|
||||||
|
registration_handler=self.registration_handler,
|
||||||
|
identity_handler=self.identity_handler,
|
||||||
|
login_handler=self.login_handler
|
||||||
|
)
|
||||||
|
self.hs = Mock()
|
||||||
|
self.hs.hostname = "supergbig~testing~thing.com"
|
||||||
|
self.hs.get_auth = Mock(return_value=self.auth)
|
||||||
|
self.hs.get_handlers = Mock(return_value=self.handlers)
|
||||||
|
self.hs.config.enable_registration = True
|
||||||
|
# init the thing we're testing
|
||||||
|
self.servlet = CreateUserRestServlet(self.hs)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_POST_createuser_with_valid_user(self):
|
||||||
|
user_id = "@someone:interesting"
|
||||||
|
token = "my token"
|
||||||
|
self.request.args = {
|
||||||
|
"access_token": "i_am_an_app_service"
|
||||||
|
}
|
||||||
|
self.request_data = json.dumps({
|
||||||
|
"localpart": "someone",
|
||||||
|
"displayname": "someone interesting",
|
||||||
|
"duration_seconds": 200
|
||||||
|
})
|
||||||
|
|
||||||
|
self.registration_handler.get_or_create_user = Mock(
|
||||||
|
return_value=(user_id, token)
|
||||||
|
)
|
||||||
|
|
||||||
|
(code, result) = yield self.servlet.on_POST(self.request)
|
||||||
|
self.assertEquals(code, 200)
|
||||||
|
|
||||||
|
det_data = {
|
||||||
|
"user_id": user_id,
|
||||||
|
"access_token": token,
|
||||||
|
"home_server": self.hs.hostname
|
||||||
|
}
|
||||||
|
self.assertDictContainsSubset(det_data, result)
|
@ -49,6 +49,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
|
|||||||
config.event_cache_size = 1
|
config.event_cache_size = 1
|
||||||
config.enable_registration = True
|
config.enable_registration = True
|
||||||
config.macaroon_secret_key = "not even a little secret"
|
config.macaroon_secret_key = "not even a little secret"
|
||||||
|
config.expire_access_token = False
|
||||||
config.server_name = "server.under.test"
|
config.server_name = "server.under.test"
|
||||||
config.trusted_third_party_id_servers = []
|
config.trusted_third_party_id_servers = []
|
||||||
config.room_invite_state_types = []
|
config.room_invite_state_types = []
|
||||||
|
Loading…
Reference in New Issue
Block a user