mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-05-05 08:24:55 -04:00
Run Black. (#5482)
This commit is contained in:
parent
7dcf984075
commit
32e7c9e7f2
376 changed files with 9142 additions and 10388 deletions
|
@ -44,10 +44,7 @@ def login_submission_legacy_convert(submission):
|
|||
to a typed object.
|
||||
"""
|
||||
if "user" in submission:
|
||||
submission["identifier"] = {
|
||||
"type": "m.id.user",
|
||||
"user": submission["user"],
|
||||
}
|
||||
submission["identifier"] = {"type": "m.id.user", "user": submission["user"]}
|
||||
del submission["user"]
|
||||
|
||||
if "medium" in submission and "address" in submission:
|
||||
|
@ -73,11 +70,7 @@ def login_id_thirdparty_from_phone(identifier):
|
|||
|
||||
msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"])
|
||||
|
||||
return {
|
||||
"type": "m.id.thirdparty",
|
||||
"medium": "msisdn",
|
||||
"address": msisdn,
|
||||
}
|
||||
return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}
|
||||
|
||||
|
||||
class LoginRestServlet(RestServlet):
|
||||
|
@ -120,9 +113,9 @@ class LoginRestServlet(RestServlet):
|
|||
# login flow types returned.
|
||||
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
|
||||
|
||||
flows.extend((
|
||||
{"type": t} for t in self.auth_handler.get_supported_login_types()
|
||||
))
|
||||
flows.extend(
|
||||
({"type": t} for t in self.auth_handler.get_supported_login_types())
|
||||
)
|
||||
|
||||
return (200, {"flows": flows})
|
||||
|
||||
|
@ -132,7 +125,8 @@ class LoginRestServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
self._address_ratelimiter.ratelimit(
|
||||
request.getClientIP(), time_now_s=self.hs.clock.time(),
|
||||
request.getClientIP(),
|
||||
time_now_s=self.hs.clock.time(),
|
||||
rate_hz=self.hs.config.rc_login_address.per_second,
|
||||
burst_count=self.hs.config.rc_login_address.burst_count,
|
||||
update=True,
|
||||
|
@ -140,8 +134,9 @@ class LoginRestServlet(RestServlet):
|
|||
|
||||
login_submission = parse_json_object_from_request(request)
|
||||
try:
|
||||
if self.jwt_enabled and (login_submission["type"] ==
|
||||
LoginRestServlet.JWT_TYPE):
|
||||
if self.jwt_enabled and (
|
||||
login_submission["type"] == LoginRestServlet.JWT_TYPE
|
||||
):
|
||||
result = yield self.do_jwt_login(login_submission)
|
||||
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
||||
result = yield self.do_token_login(login_submission)
|
||||
|
@ -170,10 +165,10 @@ class LoginRestServlet(RestServlet):
|
|||
# field)
|
||||
logger.info(
|
||||
"Got login request with identifier: %r, medium: %r, address: %r, user: %r",
|
||||
login_submission.get('identifier'),
|
||||
login_submission.get('medium'),
|
||||
login_submission.get('address'),
|
||||
login_submission.get('user'),
|
||||
login_submission.get("identifier"),
|
||||
login_submission.get("medium"),
|
||||
login_submission.get("address"),
|
||||
login_submission.get("user"),
|
||||
)
|
||||
login_submission_legacy_convert(login_submission)
|
||||
|
||||
|
@ -190,13 +185,13 @@ class LoginRestServlet(RestServlet):
|
|||
|
||||
# convert threepid identifiers to user IDs
|
||||
if identifier["type"] == "m.id.thirdparty":
|
||||
address = identifier.get('address')
|
||||
medium = identifier.get('medium')
|
||||
address = identifier.get("address")
|
||||
medium = identifier.get("medium")
|
||||
|
||||
if medium is None or address is None:
|
||||
raise SynapseError(400, "Invalid thirdparty identifier")
|
||||
|
||||
if medium == 'email':
|
||||
if medium == "email":
|
||||
# For emails, transform the address to lowercase.
|
||||
# We store all email addreses as lowercase in the DB.
|
||||
# (See add_threepid in synapse/handlers/auth.py)
|
||||
|
@ -205,34 +200,28 @@ class LoginRestServlet(RestServlet):
|
|||
# Check for login providers that support 3pid login types
|
||||
canonical_user_id, callback_3pid = (
|
||||
yield self.auth_handler.check_password_provider_3pid(
|
||||
medium,
|
||||
address,
|
||||
login_submission["password"],
|
||||
medium, address, login_submission["password"]
|
||||
)
|
||||
)
|
||||
if canonical_user_id:
|
||||
# Authentication through password provider and 3pid succeeded
|
||||
result = yield self._register_device_with_callback(
|
||||
canonical_user_id, login_submission, callback_3pid,
|
||||
canonical_user_id, login_submission, callback_3pid
|
||||
)
|
||||
defer.returnValue(result)
|
||||
|
||||
# No password providers were able to handle this 3pid
|
||||
# Check local store
|
||||
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
medium, address,
|
||||
medium, address
|
||||
)
|
||||
if not user_id:
|
||||
logger.warn(
|
||||
"unknown 3pid identifier medium %s, address %r",
|
||||
medium, address,
|
||||
"unknown 3pid identifier medium %s, address %r", medium, address
|
||||
)
|
||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||
|
||||
identifier = {
|
||||
"type": "m.id.user",
|
||||
"user": user_id,
|
||||
}
|
||||
identifier = {"type": "m.id.user", "user": user_id}
|
||||
|
||||
# by this point, the identifier should be an m.id.user: if it's anything
|
||||
# else, we haven't understood it.
|
||||
|
@ -242,22 +231,16 @@ class LoginRestServlet(RestServlet):
|
|||
raise SynapseError(400, "User identifier is missing 'user' key")
|
||||
|
||||
canonical_user_id, callback = yield self.auth_handler.validate_login(
|
||||
identifier["user"],
|
||||
login_submission,
|
||||
identifier["user"], login_submission
|
||||
)
|
||||
|
||||
result = yield self._register_device_with_callback(
|
||||
canonical_user_id, login_submission, callback,
|
||||
canonical_user_id, login_submission, callback
|
||||
)
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _register_device_with_callback(
|
||||
self,
|
||||
user_id,
|
||||
login_submission,
|
||||
callback=None,
|
||||
):
|
||||
def _register_device_with_callback(self, user_id, login_submission, callback=None):
|
||||
""" Registers a device with a given user_id. Optionally run a callback
|
||||
function after registration has completed.
|
||||
|
||||
|
@ -273,7 +256,7 @@ class LoginRestServlet(RestServlet):
|
|||
device_id = login_submission.get("device_id")
|
||||
initial_display_name = login_submission.get("initial_device_display_name")
|
||||
device_id, access_token = yield self.registration_handler.register_device(
|
||||
user_id, device_id, initial_display_name,
|
||||
user_id, device_id, initial_display_name
|
||||
)
|
||||
|
||||
result = {
|
||||
|
@ -290,7 +273,7 @@ class LoginRestServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def do_token_login(self, login_submission):
|
||||
token = login_submission['token']
|
||||
token = login_submission["token"]
|
||||
auth_handler = self.auth_handler
|
||||
user_id = (
|
||||
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
||||
|
@ -299,7 +282,7 @@ class LoginRestServlet(RestServlet):
|
|||
device_id = login_submission.get("device_id")
|
||||
initial_display_name = login_submission.get("initial_device_display_name")
|
||||
device_id, access_token = yield self.registration_handler.register_device(
|
||||
user_id, device_id, initial_display_name,
|
||||
user_id, device_id, initial_display_name
|
||||
)
|
||||
|
||||
result = {
|
||||
|
@ -316,15 +299,16 @@ class LoginRestServlet(RestServlet):
|
|||
token = login_submission.get("token", None)
|
||||
if token is None:
|
||||
raise LoginError(
|
||||
401, "Token field for JWT is missing",
|
||||
errcode=Codes.UNAUTHORIZED
|
||||
401, "Token field for JWT is missing", errcode=Codes.UNAUTHORIZED
|
||||
)
|
||||
|
||||
import jwt
|
||||
from jwt.exceptions import InvalidTokenError
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
|
||||
payload = jwt.decode(
|
||||
token, self.jwt_secret, algorithms=[self.jwt_algorithm]
|
||||
)
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
|
||||
except InvalidTokenError:
|
||||
|
@ -342,7 +326,7 @@ class LoginRestServlet(RestServlet):
|
|||
device_id = login_submission.get("device_id")
|
||||
initial_display_name = login_submission.get("initial_device_display_name")
|
||||
device_id, access_token = yield self.registration_handler.register_device(
|
||||
registered_user_id, device_id, initial_display_name,
|
||||
registered_user_id, device_id, initial_display_name
|
||||
)
|
||||
|
||||
result = {
|
||||
|
@ -358,7 +342,7 @@ class LoginRestServlet(RestServlet):
|
|||
device_id = login_submission.get("device_id")
|
||||
initial_display_name = login_submission.get("initial_device_display_name")
|
||||
device_id, access_token = yield self.registration_handler.register_device(
|
||||
registered_user_id, device_id, initial_display_name,
|
||||
registered_user_id, device_id, initial_display_name
|
||||
)
|
||||
|
||||
result = {
|
||||
|
@ -375,21 +359,20 @@ class CasRedirectServlet(RestServlet):
|
|||
|
||||
def __init__(self, hs):
|
||||
super(CasRedirectServlet, self).__init__()
|
||||
self.cas_server_url = hs.config.cas_server_url.encode('ascii')
|
||||
self.cas_service_url = hs.config.cas_service_url.encode('ascii')
|
||||
self.cas_server_url = hs.config.cas_server_url.encode("ascii")
|
||||
self.cas_service_url = hs.config.cas_service_url.encode("ascii")
|
||||
|
||||
def on_GET(self, request):
|
||||
args = request.args
|
||||
if b"redirectUrl" not in args:
|
||||
return (400, "Redirect URL not specified for CAS auth")
|
||||
client_redirect_url_param = urllib.parse.urlencode({
|
||||
b"redirectUrl": args[b"redirectUrl"][0]
|
||||
}).encode('ascii')
|
||||
hs_redirect_url = (self.cas_service_url +
|
||||
b"/_matrix/client/r0/login/cas/ticket")
|
||||
service_param = urllib.parse.urlencode({
|
||||
b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)
|
||||
}).encode('ascii')
|
||||
client_redirect_url_param = urllib.parse.urlencode(
|
||||
{b"redirectUrl": args[b"redirectUrl"][0]}
|
||||
).encode("ascii")
|
||||
hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket"
|
||||
service_param = urllib.parse.urlencode(
|
||||
{b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)}
|
||||
).encode("ascii")
|
||||
request.redirect(b"%s/login?%s" % (self.cas_server_url, service_param))
|
||||
finish_request(request)
|
||||
|
||||
|
@ -411,7 +394,7 @@ class CasTicketServlet(RestServlet):
|
|||
uri = self.cas_server_url + "/proxyValidate"
|
||||
args = {
|
||||
"ticket": parse_string(request, "ticket", required=True),
|
||||
"service": self.cas_service_url
|
||||
"service": self.cas_service_url,
|
||||
}
|
||||
try:
|
||||
body = yield self._http_client.get_raw(uri, args)
|
||||
|
@ -438,7 +421,7 @@ class CasTicketServlet(RestServlet):
|
|||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
return self._sso_auth_handler.on_successful_auth(
|
||||
user, request, client_redirect_url,
|
||||
user, request, client_redirect_url
|
||||
)
|
||||
|
||||
def parse_cas_response(self, cas_response_body):
|
||||
|
@ -448,7 +431,7 @@ class CasTicketServlet(RestServlet):
|
|||
root = ET.fromstring(cas_response_body)
|
||||
if not root.tag.endswith("serviceResponse"):
|
||||
raise Exception("root of CAS response is not serviceResponse")
|
||||
success = (root[0].tag.endswith("authenticationSuccess"))
|
||||
success = root[0].tag.endswith("authenticationSuccess")
|
||||
for child in root[0]:
|
||||
if child.tag.endswith("user"):
|
||||
user = child.text
|
||||
|
@ -466,11 +449,11 @@ class CasTicketServlet(RestServlet):
|
|||
raise Exception("CAS response does not contain user")
|
||||
except Exception:
|
||||
logger.error("Error parsing CAS response", exc_info=1)
|
||||
raise LoginError(401, "Invalid CAS response",
|
||||
errcode=Codes.UNAUTHORIZED)
|
||||
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
||||
if not success:
|
||||
raise LoginError(401, "Unsuccessful CAS response",
|
||||
errcode=Codes.UNAUTHORIZED)
|
||||
raise LoginError(
|
||||
401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
|
||||
)
|
||||
return user, attributes
|
||||
|
||||
|
||||
|
@ -482,6 +465,7 @@ class SSOAuthHandler(object):
|
|||
Args:
|
||||
hs (synapse.server.HomeServer)
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self._hostname = hs.hostname
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
|
@ -490,8 +474,7 @@ class SSOAuthHandler(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_successful_auth(
|
||||
self, username, request, client_redirect_url,
|
||||
user_display_name=None,
|
||||
self, username, request, client_redirect_url, user_display_name=None
|
||||
):
|
||||
"""Called once the user has successfully authenticated with the SSO.
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue