Implement registering with shared secret.

This commit is contained in:
Erik Johnston 2015-03-13 15:23:37 +00:00
parent 58367a9da2
commit 69135f59aa
4 changed files with 83 additions and 5 deletions

View File

@ -60,6 +60,7 @@ class LoginType(object):
EMAIL_IDENTITY = u"m.login.email.identity"
RECAPTCHA = u"m.login.recaptcha"
APPLICATION_SERVICE = u"m.login.application_service"
SHARED_SECRET = u"org.matrix.login.shared_secret"
class EventTypes(object):

View File

@ -15,23 +15,37 @@
from ._base import Config
from synapse.util.stringutils import random_string_with_symbols
class RegistrationConfig(Config):
def __init__(self, args):
super(RegistrationConfig, self).__init__(args)
self.disable_registration = args.disable_registration
self.registration_shared_secret = args.registration_shared_secret
@classmethod
def add_arguments(cls, parser):
super(RegistrationConfig, cls).add_arguments(parser)
reg_group = parser.add_argument_group("registration")
reg_group.add_argument(
"--disable-registration",
action='store_true',
help="Disable registration of new users."
action='store_const',
const=True,
help="Disable registration of new users.",
)
reg_group.add_argument(
"--registration-shared-secret", type=str,
help="If set, allows registration by anyone who also has the shared"
" secret, even if registration is otherwise disabled.",
)
@classmethod
def generate_config(cls, args, config_dir_path):
if args.disable_registration is None:
args.disable_registration = True
if args.registration_shared_secret is None:
args.registration_shared_secret= random_string_with_symbols(50)

View File

@ -110,14 +110,22 @@ class RegisterRestServlet(ClientV1RestServlet):
login_type = register_json["type"]
is_application_server = login_type == LoginType.APPLICATION_SERVICE
if self.disable_registration and not is_application_server:
is_using_shared_secret = login_type == LoginType.SHARED_SECRET
can_register = (
not self.disable_registration
or is_application_server
or is_using_shared_secret
)
if not can_register:
raise SynapseError(403, "Registration has been disabled")
stages = {
LoginType.RECAPTCHA: self._do_recaptcha,
LoginType.PASSWORD: self._do_password,
LoginType.EMAIL_IDENTITY: self._do_email_identity,
LoginType.APPLICATION_SERVICE: self._do_app_service
LoginType.APPLICATION_SERVICE: self._do_app_service,
LoginType.SHARED_SECRET: self._do_shared_secret,
}
session_info = self._get_session_info(request, session)
@ -304,6 +312,51 @@ class RegisterRestServlet(ClientV1RestServlet):
"home_server": self.hs.hostname,
})
@defer.inlineCallbacks
def _do_shared_secret(self, request, register_json, session):
yield run_on_reactor()
if "mac" not in register_json:
raise SynapseError(400, "Expected mac.")
if "user" not in register_json:
raise SynapseError(400, "Expected 'user' key.")
if "password" not in register_json:
raise SynapseError(400, "Expected 'password' key.")
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
user = register_json["user"].encode("utf-8")
# str() because otherwise hmac complains that 'unicode' does not
# have the buffer interface
got_mac = str(register_json["mac"])
want_mac = hmac.new(
key=self.hs.config.registration_shared_secret,
msg=user,
digestmod=sha1,
).hexdigest()
password = register_json["password"].encode("utf-8")
if compare_digest(want_mac, got_mac):
handler = self.handlers.registration_handler
user_id, token = yield handler.register(
localpart=user,
password=password,
)
self._remove_session(session)
defer.returnValue({
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
})
else:
raise SynapseError(
400, "HMAC incorrect",
)
def _parse_json(request):
try:

View File

@ -16,6 +16,10 @@
import random
import string
_string_with_symbols = (
string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
)
def origin_from_ucid(ucid):
return ucid.split("@", 1)[1]
@ -23,3 +27,9 @@ def origin_from_ucid(ucid):
def random_string(length):
return ''.join(random.choice(string.ascii_letters) for _ in xrange(length))
def random_string_with_symbols(length):
return ''.join(
random.choice(_string_with_symbols) for _ in xrange(length)
)