Create user with expiry

- Add unittests for client, api and handler

Signed-off-by: Negar Fazeli <negar.fazeli@ericsson.com>
This commit is contained in:
Negi Fazeli 2016-04-20 16:21:40 +02:00 committed by Negar Fazeli
parent ae1af262f6
commit 40aa6e8349
10 changed files with 301 additions and 9 deletions

View file

@ -612,7 +612,8 @@ class Auth(object):
def get_user_from_macaroon(self, macaroon_str):
try:
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
self.validate_macaroon(macaroon, "access", False)
self.validate_macaroon(macaroon, "access", self.hs.config.expire_access_token)
user_prefix = "user_id = "
user = None

View file

@ -57,6 +57,8 @@ class KeyConfig(Config):
seed = self.signing_key[0].seed
self.macaroon_secret_key = hashlib.sha256(seed)
self.expire_access_token = config.get("expire_access_token", False)
def default_config(self, config_dir_path, server_name, is_generating_file=False,
**kwargs):
base_key_name = os.path.join(config_dir_path, server_name)
@ -69,6 +71,9 @@ class KeyConfig(Config):
return """\
macaroon_secret_key: "%(macaroon_secret_key)s"
# Used to enable access token expiration.
expire_access_token: False
## Signing Keys ##
# Path to the signing key to sign messages with

View file

@ -32,6 +32,7 @@ class RegistrationConfig(Config):
)
self.registration_shared_secret = config.get("registration_shared_secret")
self.user_creation_max_duration = int(config["user_creation_max_duration"])
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
@ -54,6 +55,11 @@ class RegistrationConfig(Config):
# secret, even if registration is otherwise disabled.
registration_shared_secret: "%(registration_shared_secret)s"
# Sets the expiry for the short term user creation in
# milliseconds. For instance the bellow duration is two weeks
# in milliseconds.
user_creation_max_duration: 1209600000
# Set the number of bcrypt rounds used to generate password hash.
# Larger numbers increase the work factor needed to generate the hash.
# The default number of rounds is 12.

View file

@ -521,11 +521,11 @@ class AuthHandler(BaseHandler):
))
return m.serialize()
def generate_short_term_login_token(self, user_id):
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
expiry = now + (2 * 60 * 1000)
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()

View file

@ -358,6 +358,59 @@ class RegistrationHandler(BaseHandler):
)
defer.returnValue(data)
@defer.inlineCallbacks
def get_or_create_user(self, localpart, displayname, duration_seconds):
"""Creates a new user or returns an access token for an existing one
Args:
localpart : The local part of the user ID to register. If None,
one will be randomly generated.
Returns:
A tuple of (user_id, access_token).
Raises:
RegistrationError if there was a problem registering.
"""
yield run_on_reactor()
if localpart is None:
raise SynapseError(400, "Request must include user id")
need_register = True
try:
yield self.check_username(localpart)
except SynapseError as e:
if e.errcode == Codes.USER_IN_USE:
need_register = False
else:
raise
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
auth_handler = self.hs.get_handlers().auth_handler
token = auth_handler.generate_short_term_login_token(user_id, duration_seconds)
if need_register:
yield self.store.register(
user_id=user_id,
token=token,
password_hash=None
)
yield registered_user(self.distributor, user)
else:
yield self.store.flush_user(user_id=user_id)
yield self.store.add_access_token_to_user(user_id=user_id, token=token)
if displayname is not None:
logger.info("setting user display name: %s -> %s", user_id, displayname)
profile_handler = self.hs.get_handlers().profile_handler
yield profile_handler.set_displayname(
user, user, displayname
)
defer.returnValue((user_id, token))
def auth_handler(self):
return self.hs.get_handlers().auth_handler

View file

@ -355,5 +355,76 @@ class RegisterRestServlet(ClientV1RestServlet):
)
class CreateUserRestServlet(ClientV1RestServlet):
"""Handles user creation via a server-to-server interface
"""
PATTERNS = client_path_patterns("/createUser$", releases=())
def __init__(self, hs):
super(CreateUserRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
@defer.inlineCallbacks
def on_POST(self, request):
user_json = parse_json_object_from_request(request)
if "access_token" not in request.args:
raise SynapseError(400, "Expected application service token.")
app_service = yield self.store.get_app_service_by_token(
request.args["access_token"][0]
)
if not app_service:
raise SynapseError(403, "Invalid application service token.")
logger.debug("creating user: %s", user_json)
response = yield self._do_create(user_json)
defer.returnValue((200, response))
def on_OPTIONS(self, request):
return 403, {}
@defer.inlineCallbacks
def _do_create(self, user_json):
yield run_on_reactor()
if "localpart" not in user_json:
raise SynapseError(400, "Expected 'localpart' key.")
if "displayname" not in user_json:
raise SynapseError(400, "Expected 'displayname' key.")
if "duration_seconds" not in user_json:
raise SynapseError(400, "Expected 'duration_seconds' key.")
localpart = user_json["localpart"].encode("utf-8")
displayname = user_json["displayname"].encode("utf-8")
duration_seconds = 0
try:
duration_seconds = int(user_json["duration_seconds"])
except ValueError:
raise SynapseError(400, "Failed to parse 'duration_seconds'")
if duration_seconds > self.direct_user_creation_max_duration:
duration_seconds = self.direct_user_creation_max_duration
handler = self.handlers.registration_handler
user_id, token = yield handler.get_or_create_user(
localpart=localpart,
displayname=displayname,
duration_seconds=duration_seconds
)
defer.returnValue({
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
})
def register_servlets(hs, http_server):
RegisterRestServlet(hs).register(http_server)
CreateUserRestServlet(hs).register(http_server)