Add type annotations and comments to auth handler (#7063)

This commit is contained in:
Patrick Cloke 2020-03-12 11:36:27 -04:00 committed by GitHub
parent bd5e555b0d
commit 77d0a4507b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 106 additions and 89 deletions

1
changelog.d/7063.misc Normal file
View File

@ -0,0 +1 @@
Add type annotations and comments to the auth handler.

View File

@ -18,10 +18,10 @@ import logging
import time import time
import unicodedata import unicodedata
import urllib.parse import urllib.parse
from typing import Any from typing import Any, Dict, Iterable, List, Optional
import attr import attr
import bcrypt import bcrypt # type: ignore[import]
import pymacaroons import pymacaroons
from twisted.internet import defer from twisted.internet import defer
@ -45,7 +45,7 @@ from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread from synapse.logging.context import defer_to_thread
from synapse.module_api import ModuleApi from synapse.module_api import ModuleApi
from synapse.push.mailer import load_jinja2_templates from synapse.push.mailer import load_jinja2_templates
from synapse.types import UserID from synapse.types import Requester, UserID
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from ._base import BaseHandler from ._base import BaseHandler
@ -63,11 +63,11 @@ class AuthHandler(BaseHandler):
""" """
super(AuthHandler, self).__init__(hs) super(AuthHandler, self).__init__(hs)
self.checkers = {} # type: dict[str, UserInteractiveAuthChecker] self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS: for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
inst = auth_checker_class(hs) inst = auth_checker_class(hs)
if inst.is_enabled(): if inst.is_enabled():
self.checkers[inst.AUTH_TYPE] = inst self.checkers[inst.AUTH_TYPE] = inst # type: ignore
self.bcrypt_rounds = hs.config.bcrypt_rounds self.bcrypt_rounds = hs.config.bcrypt_rounds
@ -124,7 +124,9 @@ class AuthHandler(BaseHandler):
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist) self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
@defer.inlineCallbacks @defer.inlineCallbacks
def validate_user_via_ui_auth(self, requester, request_body, clientip): def validate_user_via_ui_auth(
self, requester: Requester, request_body: Dict[str, Any], clientip: str
):
""" """
Checks that the user is who they claim to be, via a UI auth. Checks that the user is who they claim to be, via a UI auth.
@ -133,11 +135,11 @@ class AuthHandler(BaseHandler):
that it isn't stolen by re-authenticating them. that it isn't stolen by re-authenticating them.
Args: Args:
requester (Requester): The user, as given by the access token requester: The user, as given by the access token
request_body (dict): The body of the request sent by the client request_body: The body of the request sent by the client
clientip (str): The IP address of the client. clientip: The IP address of the client.
Returns: Returns:
defer.Deferred[dict]: the parameters for this request (which may defer.Deferred[dict]: the parameters for this request (which may
@ -208,7 +210,9 @@ class AuthHandler(BaseHandler):
return self.checkers.keys() return self.checkers.keys()
@defer.inlineCallbacks @defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip): def check_auth(
self, flows: List[List[str]], clientdict: Dict[str, Any], clientip: str
):
""" """
Takes a dictionary sent by the client in the login / registration Takes a dictionary sent by the client in the login / registration
protocol and handles the User-Interactive Auth flow. protocol and handles the User-Interactive Auth flow.
@ -223,14 +227,14 @@ class AuthHandler(BaseHandler):
decorator. decorator.
Args: Args:
flows (list): A list of login flows. Each flow is an ordered list of flows: A list of login flows. Each flow is an ordered list of
strings representing auth-types. At least one full strings representing auth-types. At least one full
flow must be completed in order for auth to be successful. flow must be completed in order for auth to be successful.
clientdict: The dictionary from the client root level, not the clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent. 'auth' key: this method prompts for auth if none is sent.
clientip (str): The IP address of the client. clientip: The IP address of the client.
Returns: Returns:
defer.Deferred[dict, dict, str]: a deferred tuple of defer.Deferred[dict, dict, str]: a deferred tuple of
@ -250,7 +254,7 @@ class AuthHandler(BaseHandler):
""" """
authdict = None authdict = None
sid = None sid = None # type: Optional[str]
if clientdict and "auth" in clientdict: if clientdict and "auth" in clientdict:
authdict = clientdict["auth"] authdict = clientdict["auth"]
del clientdict["auth"] del clientdict["auth"]
@ -283,9 +287,9 @@ class AuthHandler(BaseHandler):
creds = session["creds"] creds = session["creds"]
# check auth type currently being presented # check auth type currently being presented
errordict = {} errordict = {} # type: Dict[str, Any]
if "type" in authdict: if "type" in authdict:
login_type = authdict["type"] login_type = authdict["type"] # type: str
try: try:
result = yield self._check_auth_dict(authdict, clientip) result = yield self._check_auth_dict(authdict, clientip)
if result: if result:
@ -326,7 +330,7 @@ class AuthHandler(BaseHandler):
raise InteractiveAuthIncompleteError(ret) raise InteractiveAuthIncompleteError(ret)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_oob_auth(self, stagetype, authdict, clientip): def add_oob_auth(self, stagetype: str, authdict: Dict[str, Any], clientip: str):
""" """
Adds the result of out-of-band authentication into an existing auth Adds the result of out-of-band authentication into an existing auth
session. Currently used for adding the result of fallback auth. session. Currently used for adding the result of fallback auth.
@ -348,7 +352,7 @@ class AuthHandler(BaseHandler):
return True return True
return False return False
def get_session_id(self, clientdict): def get_session_id(self, clientdict: Dict[str, Any]) -> Optional[str]:
""" """
Gets the session ID for a client given the client dictionary Gets the session ID for a client given the client dictionary
@ -356,7 +360,7 @@ class AuthHandler(BaseHandler):
clientdict: The dictionary sent by the client in the request clientdict: The dictionary sent by the client in the request
Returns: Returns:
str|None: The string session ID the client sent. If the client did The string session ID the client sent. If the client did
not send a session ID, returns None. not send a session ID, returns None.
""" """
sid = None sid = None
@ -366,40 +370,42 @@ class AuthHandler(BaseHandler):
sid = authdict["session"] sid = authdict["session"]
return sid return sid
def set_session_data(self, session_id, key, value): def set_session_data(self, session_id: str, key: str, value: Any) -> None:
""" """
Store a key-value pair into the sessions data associated with this Store a key-value pair into the sessions data associated with this
request. This data is stored server-side and cannot be modified by request. This data is stored server-side and cannot be modified by
the client. the client.
Args: Args:
session_id (string): The ID of this session as returned from check_auth session_id: The ID of this session as returned from check_auth
key (string): The key to store the data under key: The key to store the data under
value (any): The data to store value: The data to store
""" """
sess = self._get_session_info(session_id) sess = self._get_session_info(session_id)
sess.setdefault("serverdict", {})[key] = value sess.setdefault("serverdict", {})[key] = value
self._save_session(sess) self._save_session(sess)
def get_session_data(self, session_id, key, default=None): def get_session_data(
self, session_id: str, key: str, default: Optional[Any] = None
) -> Any:
""" """
Retrieve data stored with set_session_data Retrieve data stored with set_session_data
Args: Args:
session_id (string): The ID of this session as returned from check_auth session_id: The ID of this session as returned from check_auth
key (string): The key to store the data under key: The key to store the data under
default (any): Value to return if the key has not been set default: Value to return if the key has not been set
""" """
sess = self._get_session_info(session_id) sess = self._get_session_info(session_id)
return sess.setdefault("serverdict", {}).get(key, default) return sess.setdefault("serverdict", {}).get(key, default)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_auth_dict(self, authdict, clientip): def _check_auth_dict(self, authdict: Dict[str, Any], clientip: str):
"""Attempt to validate the auth dict provided by a client """Attempt to validate the auth dict provided by a client
Args: Args:
authdict (object): auth dict provided by the client authdict: auth dict provided by the client
clientip (str): IP address of the client clientip: IP address of the client
Returns: Returns:
Deferred: result of the stage verification. Deferred: result of the stage verification.
@ -425,10 +431,10 @@ class AuthHandler(BaseHandler):
(canonical_id, callback) = yield self.validate_login(user_id, authdict) (canonical_id, callback) = yield self.validate_login(user_id, authdict)
return canonical_id return canonical_id
def _get_params_recaptcha(self): def _get_params_recaptcha(self) -> dict:
return {"public_key": self.hs.config.recaptcha_public_key} return {"public_key": self.hs.config.recaptcha_public_key}
def _get_params_terms(self): def _get_params_terms(self) -> dict:
return { return {
"policies": { "policies": {
"privacy_policy": { "privacy_policy": {
@ -445,7 +451,9 @@ class AuthHandler(BaseHandler):
} }
} }
def _auth_dict_for_flows(self, flows, session): def _auth_dict_for_flows(
self, flows: List[List[str]], session: Dict[str, Any]
) -> Dict[str, Any]:
public_flows = [] public_flows = []
for f in flows: for f in flows:
public_flows.append(f) public_flows.append(f)
@ -455,7 +463,7 @@ class AuthHandler(BaseHandler):
LoginType.TERMS: self._get_params_terms, LoginType.TERMS: self._get_params_terms,
} }
params = {} params = {} # type: Dict[str, Any]
for f in public_flows: for f in public_flows:
for stage in f: for stage in f:
@ -468,7 +476,13 @@ class AuthHandler(BaseHandler):
"params": params, "params": params,
} }
def _get_session_info(self, session_id): def _get_session_info(self, session_id: Optional[str]) -> dict:
"""
Gets or creates a session given a session ID.
The session can be used to track data across multiple requests, e.g. for
interactive authentication.
"""
if session_id not in self.sessions: if session_id not in self.sessions:
session_id = None session_id = None
@ -481,7 +495,9 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id] return self.sessions[session_id]
@defer.inlineCallbacks @defer.inlineCallbacks
def get_access_token_for_user_id(self, user_id, device_id, valid_until_ms): def get_access_token_for_user_id(
self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
):
""" """
Creates a new access token for the user with the given user ID. Creates a new access token for the user with the given user ID.
@ -491,11 +507,11 @@ class AuthHandler(BaseHandler):
The device will be recorded in the table if it is not there already. The device will be recorded in the table if it is not there already.
Args: Args:
user_id (str): canonical User ID user_id: canonical User ID
device_id (str|None): the device ID to associate with the tokens. device_id: the device ID to associate with the tokens.
None to leave the tokens unassociated with a device (deprecated: None to leave the tokens unassociated with a device (deprecated:
we should always have a device ID) we should always have a device ID)
valid_until_ms (int|None): when the token is valid until. None for valid_until_ms: when the token is valid until. None for
no expiry. no expiry.
Returns: Returns:
The access token for the user's session. The access token for the user's session.
@ -530,13 +546,13 @@ class AuthHandler(BaseHandler):
return access_token return access_token
@defer.inlineCallbacks @defer.inlineCallbacks
def check_user_exists(self, user_id): def check_user_exists(self, user_id: str):
""" """
Checks to see if a user with the given id exists. Will check case Checks to see if a user with the given id exists. Will check case
insensitively, but return None if there are multiple inexact matches. insensitively, but return None if there are multiple inexact matches.
Args: Args:
(unicode|bytes) user_id: complete @user:id user_id: complete @user:id
Returns: Returns:
defer.Deferred: (unicode) canonical_user_id, or None if zero or defer.Deferred: (unicode) canonical_user_id, or None if zero or
@ -551,7 +567,7 @@ class AuthHandler(BaseHandler):
return None return None
@defer.inlineCallbacks @defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id): def _find_user_id_and_pwd_hash(self, user_id: str):
"""Checks to see if a user with the given id exists. Will check case """Checks to see if a user with the given id exists. Will check case
insensitively, but will return None if there are multiple inexact insensitively, but will return None if there are multiple inexact
matches. matches.
@ -581,7 +597,7 @@ class AuthHandler(BaseHandler):
) )
return result return result
def get_supported_login_types(self): def get_supported_login_types(self) -> Iterable[str]:
"""Get a the login types supported for the /login API """Get a the login types supported for the /login API
By default this is just 'm.login.password' (unless password_enabled is By default this is just 'm.login.password' (unless password_enabled is
@ -589,20 +605,20 @@ class AuthHandler(BaseHandler):
other login types. other login types.
Returns: Returns:
Iterable[str]: login types login types
""" """
return self._supported_login_types return self._supported_login_types
@defer.inlineCallbacks @defer.inlineCallbacks
def validate_login(self, username, login_submission): def validate_login(self, username: str, login_submission: Dict[str, Any]):
"""Authenticates the user for the /login API """Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate Also used by the user-interactive auth flow to validate
m.login.password auth types. m.login.password auth types.
Args: Args:
username (str): username supplied by the user username: username supplied by the user
login_submission (dict): the whole of the login submission login_submission: the whole of the login submission
(including 'type' and other relevant fields) (including 'type' and other relevant fields)
Returns: Returns:
Deferred[str, func]: canonical user id, and optional callback Deferred[str, func]: canonical user id, and optional callback
@ -690,13 +706,13 @@ class AuthHandler(BaseHandler):
raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN) raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_password_provider_3pid(self, medium, address, password): def check_password_provider_3pid(self, medium: str, address: str, password: str):
"""Check if a password provider is able to validate a thirdparty login """Check if a password provider is able to validate a thirdparty login
Args: Args:
medium (str): The medium of the 3pid (ex. email). medium: The medium of the 3pid (ex. email).
address (str): The address of the 3pid (ex. jdoe@example.com). address: The address of the 3pid (ex. jdoe@example.com).
password (str): The password of the user. password: The password of the user.
Returns: Returns:
Deferred[(str|None, func|None)]: A tuple of `(user_id, Deferred[(str|None, func|None)]: A tuple of `(user_id,
@ -724,15 +740,15 @@ class AuthHandler(BaseHandler):
return None, None return None, None
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_local_password(self, user_id, password): def _check_local_password(self, user_id: str, password: str):
"""Authenticate a user against the local password database. """Authenticate a user against the local password database.
user_id is checked case insensitively, but will return None if there are user_id is checked case insensitively, but will return None if there are
multiple inexact matches. multiple inexact matches.
Args: Args:
user_id (unicode): complete @user:id user_id: complete @user:id
password (unicode): the provided password password: the provided password
Returns: Returns:
Deferred[unicode] the canonical_user_id, or Deferred[None] if Deferred[unicode] the canonical_user_id, or Deferred[None] if
unknown user/bad password unknown user/bad password
@ -755,7 +771,7 @@ class AuthHandler(BaseHandler):
return user_id return user_id
@defer.inlineCallbacks @defer.inlineCallbacks
def validate_short_term_login_token_and_get_user_id(self, login_token): def validate_short_term_login_token_and_get_user_id(self, login_token: str):
auth_api = self.hs.get_auth() auth_api = self.hs.get_auth()
user_id = None user_id = None
try: try:
@ -769,11 +785,11 @@ class AuthHandler(BaseHandler):
return user_id return user_id
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_access_token(self, access_token): def delete_access_token(self, access_token: str):
"""Invalidate a single access token """Invalidate a single access token
Args: Args:
access_token (str): access token to be deleted access_token: access token to be deleted
Returns: Returns:
Deferred Deferred
@ -798,15 +814,17 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_access_tokens_for_user( def delete_access_tokens_for_user(
self, user_id, except_token_id=None, device_id=None self,
user_id: str,
except_token_id: Optional[str] = None,
device_id: Optional[str] = None,
): ):
"""Invalidate access tokens belonging to a user """Invalidate access tokens belonging to a user
Args: Args:
user_id (str): ID of user the tokens belong to user_id: ID of user the tokens belong to
except_token_id (str|None): access_token ID which should *not* be except_token_id: access_token ID which should *not* be deleted
deleted device_id: ID of device the tokens are associated with.
device_id (str|None): ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will If None, tokens associated with any device (or no device) will
be deleted be deleted
Returns: Returns:
@ -830,7 +848,7 @@ class AuthHandler(BaseHandler):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at): def add_threepid(self, user_id: str, medium: str, address: str, validated_at: int):
# check if medium has a valid value # check if medium has a valid value
if medium not in ["email", "msisdn"]: if medium not in ["email", "msisdn"]:
raise SynapseError( raise SynapseError(
@ -856,19 +874,20 @@ class AuthHandler(BaseHandler):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_threepid(self, user_id, medium, address, id_server=None): def delete_threepid(
self, user_id: str, medium: str, address: str, id_server: Optional[str] = None
):
"""Attempts to unbind the 3pid on the identity servers and deletes it """Attempts to unbind the 3pid on the identity servers and deletes it
from the local database. from the local database.
Args: Args:
user_id (str) user_id: ID of user to remove the 3pid from.
medium (str) medium: The medium of the 3pid being removed: "email" or "msisdn".
address (str) address: The 3pid address to remove.
id_server (str|None): Use the given identity server when unbinding id_server: Use the given identity server when unbinding
any threepids. If None then will attempt to unbind using the any threepids. If None then will attempt to unbind using the
identity server specified when binding (if known). identity server specified when binding (if known).
Returns: Returns:
Deferred[bool]: Returns True if successfully unbound the 3pid on Deferred[bool]: Returns True if successfully unbound the 3pid on
the identity server, False if identity server doesn't support the the identity server, False if identity server doesn't support the
@ -887,17 +906,18 @@ class AuthHandler(BaseHandler):
yield self.store.user_delete_threepid(user_id, medium, address) yield self.store.user_delete_threepid(user_id, medium, address)
return result return result
def _save_session(self, session): def _save_session(self, session: Dict[str, Any]) -> None:
"""Update the last used time on the session to now and add it back to the session store."""
# TODO: Persistent storage # TODO: Persistent storage
logger.debug("Saving session %s", session) logger.debug("Saving session %s", session)
session["last_used"] = self.hs.get_clock().time_msec() session["last_used"] = self.hs.get_clock().time_msec()
self.sessions[session["id"]] = session self.sessions[session["id"]] = session
def hash(self, password): def hash(self, password: str):
"""Computes a secure hash of password. """Computes a secure hash of password.
Args: Args:
password (unicode): Password to hash. password: Password to hash.
Returns: Returns:
Deferred(unicode): Hashed password. Deferred(unicode): Hashed password.
@ -914,12 +934,12 @@ class AuthHandler(BaseHandler):
return defer_to_thread(self.hs.get_reactor(), _do_hash) return defer_to_thread(self.hs.get_reactor(), _do_hash)
def validate_hash(self, password, stored_hash): def validate_hash(self, password: str, stored_hash: bytes):
"""Validates that self.hash(password) == stored_hash. """Validates that self.hash(password) == stored_hash.
Args: Args:
password (unicode): Password to hash. password: Password to hash.
stored_hash (bytes): Expected hash value. stored_hash: Expected hash value.
Returns: Returns:
Deferred(bool): Whether self.hash(password) == stored_hash. Deferred(bool): Whether self.hash(password) == stored_hash.
@ -1007,7 +1027,9 @@ class MacaroonGenerator(object):
hs = attr.ib() hs = attr.ib()
def generate_access_token(self, user_id, extra_caveats=None): def generate_access_token(
self, user_id: str, extra_caveats: Optional[List[str]] = None
) -> str:
extra_caveats = extra_caveats or [] extra_caveats = extra_caveats or []
macaroon = self._generate_base_macaroon(user_id) macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
@ -1020,16 +1042,9 @@ class MacaroonGenerator(object):
macaroon.add_first_party_caveat(caveat) macaroon.add_first_party_caveat(caveat)
return macaroon.serialize() return macaroon.serialize()
def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)): def generate_short_term_login_token(
""" self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
) -> str:
Args:
user_id (unicode):
duration_in_ms (int):
Returns:
unicode
"""
macaroon = self._generate_base_macaroon(user_id) macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login") macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec() now = self.hs.get_clock().time_msec()
@ -1037,12 +1052,12 @@ class MacaroonGenerator(object):
macaroon.add_first_party_caveat("time < %d" % (expiry,)) macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize() return macaroon.serialize()
def generate_delete_pusher_token(self, user_id): def generate_delete_pusher_token(self, user_id: str) -> str:
macaroon = self._generate_base_macaroon(user_id) macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = delete_pusher") macaroon.add_first_party_caveat("type = delete_pusher")
return macaroon.serialize() return macaroon.serialize()
def _generate_base_macaroon(self, user_id): def _generate_base_macaroon(self, user_id: str) -> pymacaroons.Macaroon:
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name, location=self.hs.config.server_name,
identifier="key", identifier="key",

View File

@ -185,6 +185,7 @@ commands = mypy \
synapse/federation/federation_client.py \ synapse/federation/federation_client.py \
synapse/federation/sender \ synapse/federation/sender \
synapse/federation/transport \ synapse/federation/transport \
synapse/handlers/auth.py \
synapse/handlers/directory.py \ synapse/handlers/directory.py \
synapse/handlers/presence.py \ synapse/handlers/presence.py \
synapse/handlers/sync.py \ synapse/handlers/sync.py \