Run Black. (#5482)

This commit is contained in:
Amber Brown 2019-06-20 19:32:02 +10:00 committed by GitHub
parent 7dcf984075
commit 32e7c9e7f2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
376 changed files with 9142 additions and 10388 deletions

View file

@ -44,10 +44,7 @@ def login_submission_legacy_convert(submission):
to a typed object.
"""
if "user" in submission:
submission["identifier"] = {
"type": "m.id.user",
"user": submission["user"],
}
submission["identifier"] = {"type": "m.id.user", "user": submission["user"]}
del submission["user"]
if "medium" in submission and "address" in submission:
@ -73,11 +70,7 @@ def login_id_thirdparty_from_phone(identifier):
msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"])
return {
"type": "m.id.thirdparty",
"medium": "msisdn",
"address": msisdn,
}
return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}
class LoginRestServlet(RestServlet):
@ -120,9 +113,9 @@ class LoginRestServlet(RestServlet):
# login flow types returned.
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
flows.extend((
{"type": t} for t in self.auth_handler.get_supported_login_types()
))
flows.extend(
({"type": t} for t in self.auth_handler.get_supported_login_types())
)
return (200, {"flows": flows})
@ -132,7 +125,8 @@ class LoginRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
self._address_ratelimiter.ratelimit(
request.getClientIP(), time_now_s=self.hs.clock.time(),
request.getClientIP(),
time_now_s=self.hs.clock.time(),
rate_hz=self.hs.config.rc_login_address.per_second,
burst_count=self.hs.config.rc_login_address.burst_count,
update=True,
@ -140,8 +134,9 @@ class LoginRestServlet(RestServlet):
login_submission = parse_json_object_from_request(request)
try:
if self.jwt_enabled and (login_submission["type"] ==
LoginRestServlet.JWT_TYPE):
if self.jwt_enabled and (
login_submission["type"] == LoginRestServlet.JWT_TYPE
):
result = yield self.do_jwt_login(login_submission)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
result = yield self.do_token_login(login_submission)
@ -170,10 +165,10 @@ class LoginRestServlet(RestServlet):
# field)
logger.info(
"Got login request with identifier: %r, medium: %r, address: %r, user: %r",
login_submission.get('identifier'),
login_submission.get('medium'),
login_submission.get('address'),
login_submission.get('user'),
login_submission.get("identifier"),
login_submission.get("medium"),
login_submission.get("address"),
login_submission.get("user"),
)
login_submission_legacy_convert(login_submission)
@ -190,13 +185,13 @@ class LoginRestServlet(RestServlet):
# convert threepid identifiers to user IDs
if identifier["type"] == "m.id.thirdparty":
address = identifier.get('address')
medium = identifier.get('medium')
address = identifier.get("address")
medium = identifier.get("medium")
if medium is None or address is None:
raise SynapseError(400, "Invalid thirdparty identifier")
if medium == 'email':
if medium == "email":
# For emails, transform the address to lowercase.
# We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
@ -205,34 +200,28 @@ class LoginRestServlet(RestServlet):
# Check for login providers that support 3pid login types
canonical_user_id, callback_3pid = (
yield self.auth_handler.check_password_provider_3pid(
medium,
address,
login_submission["password"],
medium, address, login_submission["password"]
)
)
if canonical_user_id:
# Authentication through password provider and 3pid succeeded
result = yield self._register_device_with_callback(
canonical_user_id, login_submission, callback_3pid,
canonical_user_id, login_submission, callback_3pid
)
defer.returnValue(result)
# No password providers were able to handle this 3pid
# Check local store
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
medium, address,
medium, address
)
if not user_id:
logger.warn(
"unknown 3pid identifier medium %s, address %r",
medium, address,
"unknown 3pid identifier medium %s, address %r", medium, address
)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
identifier = {
"type": "m.id.user",
"user": user_id,
}
identifier = {"type": "m.id.user", "user": user_id}
# by this point, the identifier should be an m.id.user: if it's anything
# else, we haven't understood it.
@ -242,22 +231,16 @@ class LoginRestServlet(RestServlet):
raise SynapseError(400, "User identifier is missing 'user' key")
canonical_user_id, callback = yield self.auth_handler.validate_login(
identifier["user"],
login_submission,
identifier["user"], login_submission
)
result = yield self._register_device_with_callback(
canonical_user_id, login_submission, callback,
canonical_user_id, login_submission, callback
)
defer.returnValue(result)
@defer.inlineCallbacks
def _register_device_with_callback(
self,
user_id,
login_submission,
callback=None,
):
def _register_device_with_callback(self, user_id, login_submission, callback=None):
""" Registers a device with a given user_id. Optionally run a callback
function after registration has completed.
@ -273,7 +256,7 @@ class LoginRestServlet(RestServlet):
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
user_id, device_id, initial_display_name,
user_id, device_id, initial_display_name
)
result = {
@ -290,7 +273,7 @@ class LoginRestServlet(RestServlet):
@defer.inlineCallbacks
def do_token_login(self, login_submission):
token = login_submission['token']
token = login_submission["token"]
auth_handler = self.auth_handler
user_id = (
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
@ -299,7 +282,7 @@ class LoginRestServlet(RestServlet):
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
user_id, device_id, initial_display_name,
user_id, device_id, initial_display_name
)
result = {
@ -316,15 +299,16 @@ class LoginRestServlet(RestServlet):
token = login_submission.get("token", None)
if token is None:
raise LoginError(
401, "Token field for JWT is missing",
errcode=Codes.UNAUTHORIZED
401, "Token field for JWT is missing", errcode=Codes.UNAUTHORIZED
)
import jwt
from jwt.exceptions import InvalidTokenError
try:
payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
payload = jwt.decode(
token, self.jwt_secret, algorithms=[self.jwt_algorithm]
)
except jwt.ExpiredSignatureError:
raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
except InvalidTokenError:
@ -342,7 +326,7 @@ class LoginRestServlet(RestServlet):
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
registered_user_id, device_id, initial_display_name,
registered_user_id, device_id, initial_display_name
)
result = {
@ -358,7 +342,7 @@ class LoginRestServlet(RestServlet):
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
registered_user_id, device_id, initial_display_name,
registered_user_id, device_id, initial_display_name
)
result = {
@ -375,21 +359,20 @@ class CasRedirectServlet(RestServlet):
def __init__(self, hs):
super(CasRedirectServlet, self).__init__()
self.cas_server_url = hs.config.cas_server_url.encode('ascii')
self.cas_service_url = hs.config.cas_service_url.encode('ascii')
self.cas_server_url = hs.config.cas_server_url.encode("ascii")
self.cas_service_url = hs.config.cas_service_url.encode("ascii")
def on_GET(self, request):
args = request.args
if b"redirectUrl" not in args:
return (400, "Redirect URL not specified for CAS auth")
client_redirect_url_param = urllib.parse.urlencode({
b"redirectUrl": args[b"redirectUrl"][0]
}).encode('ascii')
hs_redirect_url = (self.cas_service_url +
b"/_matrix/client/r0/login/cas/ticket")
service_param = urllib.parse.urlencode({
b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)
}).encode('ascii')
client_redirect_url_param = urllib.parse.urlencode(
{b"redirectUrl": args[b"redirectUrl"][0]}
).encode("ascii")
hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket"
service_param = urllib.parse.urlencode(
{b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)}
).encode("ascii")
request.redirect(b"%s/login?%s" % (self.cas_server_url, service_param))
finish_request(request)
@ -411,7 +394,7 @@ class CasTicketServlet(RestServlet):
uri = self.cas_server_url + "/proxyValidate"
args = {
"ticket": parse_string(request, "ticket", required=True),
"service": self.cas_service_url
"service": self.cas_service_url,
}
try:
body = yield self._http_client.get_raw(uri, args)
@ -438,7 +421,7 @@ class CasTicketServlet(RestServlet):
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
return self._sso_auth_handler.on_successful_auth(
user, request, client_redirect_url,
user, request, client_redirect_url
)
def parse_cas_response(self, cas_response_body):
@ -448,7 +431,7 @@ class CasTicketServlet(RestServlet):
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise Exception("root of CAS response is not serviceResponse")
success = (root[0].tag.endswith("authenticationSuccess"))
success = root[0].tag.endswith("authenticationSuccess")
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
@ -466,11 +449,11 @@ class CasTicketServlet(RestServlet):
raise Exception("CAS response does not contain user")
except Exception:
logger.error("Error parsing CAS response", exc_info=1)
raise LoginError(401, "Invalid CAS response",
errcode=Codes.UNAUTHORIZED)
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not success:
raise LoginError(401, "Unsuccessful CAS response",
errcode=Codes.UNAUTHORIZED)
raise LoginError(
401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
)
return user, attributes
@ -482,6 +465,7 @@ class SSOAuthHandler(object):
Args:
hs (synapse.server.HomeServer)
"""
def __init__(self, hs):
self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler()
@ -490,8 +474,7 @@ class SSOAuthHandler(object):
@defer.inlineCallbacks
def on_successful_auth(
self, username, request, client_redirect_url,
user_display_name=None,
self, username, request, client_redirect_url, user_display_name=None
):
"""Called once the user has successfully authenticated with the SSO.