Refactor some logic from LoginRestServlet into AuthHandler

I'm going to need some more flexibility in handling login types in password
auth providers, so as a first step, move some stuff from LoginRestServlet into
AuthHandler.

In particular, we pass everything other than SAML, JWT and token logins down to
the AuthHandler, which now has responsibility for checking the login type and
fishing the password out of the login dictionary, as well as qualifying the
user_id if need be. Ideally SAML, JWT and token would go that way too, but
there's no real need for it right now and I'm trying to minimise impact.

This commit *should* be non-functional.
This commit is contained in:
Richard van der Hoff 2017-10-31 10:38:40 +00:00
parent e2f4190209
commit 1b65ae00ac
2 changed files with 79 additions and 58 deletions

View file

@ -85,7 +85,6 @@ def login_id_thirdparty_from_phone(identifier):
class LoginRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login$")
PASS_TYPE = "m.login.password"
SAML2_TYPE = "m.login.saml2"
CAS_TYPE = "m.login.cas"
TOKEN_TYPE = "m.login.token"
@ -94,7 +93,6 @@ class LoginRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs)
self.idp_redirect_url = hs.config.saml2_idp_redirect_url
self.password_enabled = hs.config.password_enabled
self.saml2_enabled = hs.config.saml2_enabled
self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret
@ -121,8 +119,10 @@ class LoginRestServlet(ClientV1RestServlet):
# fall back to the fallback API if they don't understand one of the
# login flow types returned.
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
if self.password_enabled:
flows.append({"type": LoginRestServlet.PASS_TYPE})
flows.extend((
{"type": t} for t in self.auth_handler.get_supported_login_types()
))
return (200, {"flows": flows})
@ -133,14 +133,8 @@ class LoginRestServlet(ClientV1RestServlet):
def on_POST(self, request):
login_submission = parse_json_object_from_request(request)
try:
if login_submission["type"] == LoginRestServlet.PASS_TYPE:
if not self.password_enabled:
raise SynapseError(400, "Password login has been disabled.")
result = yield self.do_password_login(login_submission)
defer.returnValue(result)
elif self.saml2_enabled and (login_submission["type"] ==
LoginRestServlet.SAML2_TYPE):
if self.saml2_enabled and (login_submission["type"] ==
LoginRestServlet.SAML2_TYPE):
relay_state = ""
if "relay_state" in login_submission:
relay_state = "&RelayState=" + urllib.quote(
@ -157,15 +151,21 @@ class LoginRestServlet(ClientV1RestServlet):
result = yield self.do_token_login(login_submission)
defer.returnValue(result)
else:
raise SynapseError(400, "Bad login type.")
result = yield self._do_other_login(login_submission)
defer.returnValue(result)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
@defer.inlineCallbacks
def do_password_login(self, login_submission):
if "password" not in login_submission:
raise SynapseError(400, "Missing parameter: password")
def _do_other_login(self, login_submission):
"""Handle non-token/saml/jwt logins
Args:
login_submission:
Returns:
(int, object): HTTP code/response
"""
login_submission_legacy_convert(login_submission)
if "identifier" not in login_submission:
@ -208,25 +208,22 @@ class LoginRestServlet(ClientV1RestServlet):
if "user" not in identifier:
raise SynapseError(400, "User identifier is missing 'user' key")
user_id = identifier["user"]
if not user_id.startswith('@'):
user_id = UserID(
user_id, self.hs.hostname
).to_string()
auth_handler = self.auth_handler
user_id = yield auth_handler.validate_password_login(
user_id=user_id,
password=login_submission["password"],
canonical_user_id = yield auth_handler.validate_login(
identifier["user"],
login_submission,
)
device_id = yield self._register_device(
canonical_user_id, login_submission,
)
device_id = yield self._register_device(user_id, login_submission)
access_token = yield auth_handler.get_access_token_for_user_id(
user_id, device_id,
canonical_user_id, device_id,
login_submission.get("initial_device_display_name"),
)
result = {
"user_id": user_id, # may have changed
"user_id": canonical_user_id,
"access_token": access_token,
"home_server": self.hs.hostname,
"device_id": device_id,