Merge branch 'develop' into rav/saml2_client

This commit is contained in:
Richard van der Hoff 2019-06-26 22:34:41 +01:00
commit a4daa899ec
478 changed files with 18927 additions and 11500 deletions

View file

@ -44,10 +44,7 @@ def login_submission_legacy_convert(submission):
to a typed object.
"""
if "user" in submission:
submission["identifier"] = {
"type": "m.id.user",
"user": submission["user"],
}
submission["identifier"] = {"type": "m.id.user", "user": submission["user"]}
del submission["user"]
if "medium" in submission and "address" in submission:
@ -73,11 +70,7 @@ def login_id_thirdparty_from_phone(identifier):
msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"])
return {
"type": "m.id.thirdparty",
"medium": "msisdn",
"address": msisdn,
}
return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}
class LoginRestServlet(RestServlet):
@ -124,9 +117,9 @@ class LoginRestServlet(RestServlet):
# login flow types returned.
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
flows.extend((
{"type": t} for t in self.auth_handler.get_supported_login_types()
))
flows.extend(
({"type": t} for t in self.auth_handler.get_supported_login_types())
)
return (200, {"flows": flows})
@ -136,7 +129,8 @@ class LoginRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
self._address_ratelimiter.ratelimit(
request.getClientIP(), time_now_s=self.hs.clock.time(),
request.getClientIP(),
time_now_s=self.hs.clock.time(),
rate_hz=self.hs.config.rc_login_address.per_second,
burst_count=self.hs.config.rc_login_address.burst_count,
update=True,
@ -144,8 +138,9 @@ class LoginRestServlet(RestServlet):
login_submission = parse_json_object_from_request(request)
try:
if self.jwt_enabled and (login_submission["type"] ==
LoginRestServlet.JWT_TYPE):
if self.jwt_enabled and (
login_submission["type"] == LoginRestServlet.JWT_TYPE
):
result = yield self.do_jwt_login(login_submission)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
result = yield self.do_token_login(login_submission)
@ -174,10 +169,10 @@ class LoginRestServlet(RestServlet):
# field)
logger.info(
"Got login request with identifier: %r, medium: %r, address: %r, user: %r",
login_submission.get('identifier'),
login_submission.get('medium'),
login_submission.get('address'),
login_submission.get('user'),
login_submission.get("identifier"),
login_submission.get("medium"),
login_submission.get("address"),
login_submission.get("user"),
)
login_submission_legacy_convert(login_submission)
@ -194,13 +189,13 @@ class LoginRestServlet(RestServlet):
# convert threepid identifiers to user IDs
if identifier["type"] == "m.id.thirdparty":
address = identifier.get('address')
medium = identifier.get('medium')
address = identifier.get("address")
medium = identifier.get("medium")
if medium is None or address is None:
raise SynapseError(400, "Invalid thirdparty identifier")
if medium == 'email':
if medium == "email":
# For emails, transform the address to lowercase.
# We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
@ -209,34 +204,28 @@ class LoginRestServlet(RestServlet):
# Check for login providers that support 3pid login types
canonical_user_id, callback_3pid = (
yield self.auth_handler.check_password_provider_3pid(
medium,
address,
login_submission["password"],
medium, address, login_submission["password"]
)
)
if canonical_user_id:
# Authentication through password provider and 3pid succeeded
result = yield self._register_device_with_callback(
canonical_user_id, login_submission, callback_3pid,
canonical_user_id, login_submission, callback_3pid
)
defer.returnValue(result)
# No password providers were able to handle this 3pid
# Check local store
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
medium, address,
medium, address
)
if not user_id:
logger.warn(
"unknown 3pid identifier medium %s, address %r",
medium, address,
"unknown 3pid identifier medium %s, address %r", medium, address
)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
identifier = {
"type": "m.id.user",
"user": user_id,
}
identifier = {"type": "m.id.user", "user": user_id}
# by this point, the identifier should be an m.id.user: if it's anything
# else, we haven't understood it.
@ -246,22 +235,16 @@ class LoginRestServlet(RestServlet):
raise SynapseError(400, "User identifier is missing 'user' key")
canonical_user_id, callback = yield self.auth_handler.validate_login(
identifier["user"],
login_submission,
identifier["user"], login_submission
)
result = yield self._register_device_with_callback(
canonical_user_id, login_submission, callback,
canonical_user_id, login_submission, callback
)
defer.returnValue(result)
@defer.inlineCallbacks
def _register_device_with_callback(
self,
user_id,
login_submission,
callback=None,
):
def _register_device_with_callback(self, user_id, login_submission, callback=None):
""" Registers a device with a given user_id. Optionally run a callback
function after registration has completed.
@ -277,7 +260,7 @@ class LoginRestServlet(RestServlet):
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
user_id, device_id, initial_display_name,
user_id, device_id, initial_display_name
)
result = {
@ -294,7 +277,7 @@ class LoginRestServlet(RestServlet):
@defer.inlineCallbacks
def do_token_login(self, login_submission):
token = login_submission['token']
token = login_submission["token"]
auth_handler = self.auth_handler
user_id = (
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
@ -303,7 +286,7 @@ class LoginRestServlet(RestServlet):
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
user_id, device_id, initial_display_name,
user_id, device_id, initial_display_name
)
result = {
@ -320,15 +303,16 @@ class LoginRestServlet(RestServlet):
token = login_submission.get("token", None)
if token is None:
raise LoginError(
401, "Token field for JWT is missing",
errcode=Codes.UNAUTHORIZED
401, "Token field for JWT is missing", errcode=Codes.UNAUTHORIZED
)
import jwt
from jwt.exceptions import InvalidTokenError
try:
payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
payload = jwt.decode(
token, self.jwt_secret, algorithms=[self.jwt_algorithm]
)
except jwt.ExpiredSignatureError:
raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
except InvalidTokenError:
@ -346,7 +330,7 @@ class LoginRestServlet(RestServlet):
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
registered_user_id, device_id, initial_display_name,
registered_user_id, device_id, initial_display_name
)
result = {
@ -362,7 +346,7 @@ class LoginRestServlet(RestServlet):
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
registered_user_id, device_id, initial_display_name,
registered_user_id, device_id, initial_display_name
)
result = {
@ -376,6 +360,7 @@ class LoginRestServlet(RestServlet):
class BaseSsoRedirectServlet(RestServlet):
"""Common base class for /login/sso/redirect impls"""
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
def on_GET(self, request):
@ -401,21 +386,20 @@ class BaseSsoRedirectServlet(RestServlet):
raise NotImplementedError()
class CasRedirectServlet(RestServlet):
class CasRedirectServlet(BaseSsoRedirectServlet):
def __init__(self, hs):
super(CasRedirectServlet, self).__init__()
self.cas_server_url = hs.config.cas_server_url.encode('ascii')
self.cas_service_url = hs.config.cas_service_url.encode('ascii')
self.cas_server_url = hs.config.cas_server_url.encode("ascii")
self.cas_service_url = hs.config.cas_service_url.encode("ascii")
def get_sso_url(self, client_redirect_url):
client_redirect_url_param = urllib.parse.urlencode({
b"redirectUrl": client_redirect_url
}).encode('ascii')
hs_redirect_url = (self.cas_service_url +
b"/_matrix/client/r0/login/cas/ticket")
service_param = urllib.parse.urlencode({
b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)
}).encode('ascii')
client_redirect_url_param = urllib.parse.urlencode(
{b"redirectUrl": client_redirect_url}
).encode("ascii")
hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket"
service_param = urllib.parse.urlencode(
{b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)}
).encode("ascii")
return b"%s/login?%s" % (self.cas_server_url, service_param)
@ -436,7 +420,7 @@ class CasTicketServlet(RestServlet):
uri = self.cas_server_url + "/proxyValidate"
args = {
"ticket": parse_string(request, "ticket", required=True),
"service": self.cas_service_url
"service": self.cas_service_url,
}
try:
body = yield self._http_client.get_raw(uri, args)
@ -463,7 +447,7 @@ class CasTicketServlet(RestServlet):
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
return self._sso_auth_handler.on_successful_auth(
user, request, client_redirect_url,
user, request, client_redirect_url
)
def parse_cas_response(self, cas_response_body):
@ -473,7 +457,7 @@ class CasTicketServlet(RestServlet):
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise Exception("root of CAS response is not serviceResponse")
success = (root[0].tag.endswith("authenticationSuccess"))
success = root[0].tag.endswith("authenticationSuccess")
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
@ -491,11 +475,11 @@ class CasTicketServlet(RestServlet):
raise Exception("CAS response does not contain user")
except Exception:
logger.error("Error parsing CAS response", exc_info=1)
raise LoginError(401, "Invalid CAS response",
errcode=Codes.UNAUTHORIZED)
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not success:
raise LoginError(401, "Unsuccessful CAS response",
errcode=Codes.UNAUTHORIZED)
raise LoginError(
401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
)
return user, attributes
@ -507,11 +491,11 @@ class SAMLRedirectServlet(BaseSsoRedirectServlet):
def get_sso_url(self, client_redirect_url):
reqid, info = self._saml_client.prepare_for_authenticate(
relay_state=client_redirect_url,
relay_state=client_redirect_url
)
for key, value in info['headers']:
if key == 'Location':
for key, value in info["headers"]:
if key == "Location":
return value
# this shouldn't happen!
@ -526,6 +510,7 @@ class SSOAuthHandler(object):
Args:
hs (synapse.server.HomeServer)
"""
def __init__(self, hs):
self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler()
@ -534,8 +519,7 @@ class SSOAuthHandler(object):
@defer.inlineCallbacks
def on_successful_auth(
self, username, request, client_redirect_url,
user_display_name=None,
self, username, request, client_redirect_url, user_display_name=None
):
"""Called once the user has successfully authenticated with the SSO.