Update the auth providers to be async. (#7935)

This commit is contained in:
Patrick Cloke 2020-07-23 15:45:39 -04:00 committed by GitHub
parent 7078866969
commit 83434df381
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 106 additions and 100 deletions

1
changelog.d/7935.misc Normal file
View File

@ -0,0 +1 @@
Convert the auth providers to be async/await.

View File

@ -19,102 +19,103 @@ password auth provider module implementations:
Password auth provider classes must provide the following methods: Password auth provider classes must provide the following methods:
*class* `SomeProvider.parse_config`(*config*) * `parse_config(config)`
This method is passed the `config` object for this module from the
homeserver configuration file.
> This method is passed the `config` object for this module from the It should perform any appropriate sanity checks on the provided
> homeserver configuration file. configuration, and return an object which is then passed into
>
> It should perform any appropriate sanity checks on the provided
> configuration, and return an object which is then passed into
> `__init__`.
*class* `SomeProvider`(*config*, *account_handler*) This method should have the `@staticmethod` decoration.
> The constructor is passed the config object returned by * `__init__(self, config, account_handler)`
> `parse_config`, and a `synapse.module_api.ModuleApi` object which
> allows the password provider to check if accounts exist and/or create The constructor is passed the config object returned by
> new ones. `parse_config`, and a `synapse.module_api.ModuleApi` object which
allows the password provider to check if accounts exist and/or create
new ones.
## Optional methods ## Optional methods
Password auth provider classes may optionally provide the following Password auth provider classes may optionally provide the following methods:
methods.
*class* `SomeProvider.get_db_schema_files`() * `get_db_schema_files(self)`
> This method, if implemented, should return an Iterable of This method, if implemented, should return an Iterable of
> `(name, stream)` pairs of database schema files. Each file is applied `(name, stream)` pairs of database schema files. Each file is applied
> in turn at initialisation, and a record is then made in the database in turn at initialisation, and a record is then made in the database
> so that it is not re-applied on the next start. so that it is not re-applied on the next start.
`someprovider.get_supported_login_types`() * `get_supported_login_types(self)`
> This method, if implemented, should return a `dict` mapping from a This method, if implemented, should return a `dict` mapping from a
> login type identifier (such as `m.login.password`) to an iterable login type identifier (such as `m.login.password`) to an iterable
> giving the fields which must be provided by the user in the submission giving the fields which must be provided by the user in the submission
> to the `/login` api. These fields are passed in the `login_dict` to [the `/login` API](https://matrix.org/docs/spec/client_server/latest#post-matrix-client-r0-login).
> dictionary to `check_auth`. These fields are passed in the `login_dict` dictionary to `check_auth`.
>
> For example, if a password auth provider wants to implement a custom
> login type of `com.example.custom_login`, where the client is expected
> to pass the fields `secret1` and `secret2`, the provider should
> implement this method and return the following dict:
>
> {"com.example.custom_login": ("secret1", "secret2")}
`someprovider.check_auth`(*username*, *login_type*, *login_dict*) For example, if a password auth provider wants to implement a custom
login type of `com.example.custom_login`, where the client is expected
to pass the fields `secret1` and `secret2`, the provider should
implement this method and return the following dict:
> This method is the one that does the real work. If implemented, it ```python
> will be called for each login attempt where the login type matches one {"com.example.custom_login": ("secret1", "secret2")}
> of the keys returned by `get_supported_login_types`. ```
>
> It is passed the (possibly UNqualified) `user` provided by the client,
> the login type, and a dictionary of login secrets passed by the
> client.
>
> The method should return a Twisted `Deferred` object, which resolves
> to the canonical `@localpart:domain` user id if authentication is
> successful, and `None` if not.
>
> Alternatively, the `Deferred` can resolve to a `(str, func)` tuple, in
> which case the second field is a callback which will be called with
> the result from the `/login` call (including `access_token`,
> `device_id`, etc.)
`someprovider.check_3pid_auth`(*medium*, *address*, *password*) * `check_auth(self, username, login_type, login_dict)`
> This method, if implemented, is called when a user attempts to This method does the real work. If implemented, it
> register or log in with a third party identifier, such as email. It is will be called for each login attempt where the login type matches one
> passed the medium (ex. "email"), an address (ex. of the keys returned by `get_supported_login_types`.
> "<jdoe@example.com>") and the user's password.
>
> The method should return a Twisted `Deferred` object, which resolves
> to a `str` containing the user's (canonical) User ID if
> authentication was successful, and `None` if not.
>
> As with `check_auth`, the `Deferred` may alternatively resolve to a
> `(user_id, callback)` tuple.
`someprovider.check_password`(*user_id*, *password*) It is passed the (possibly unqualified) `user` field provided by the client,
the login type, and a dictionary of login secrets passed by the
client.
> This method provides a simpler interface than The method should return an `Awaitable` object, which resolves
> `get_supported_login_types` and `check_auth` for password auth to the canonical `@localpart:domain` user ID if authentication is
> providers that just want to provide a mechanism for validating successful, and `None` if not.
> `m.login.password` logins.
>
> Iif implemented, it will be called to check logins with an
> `m.login.password` login type. It is passed a qualified
> `@localpart:domain` user id, and the password provided by the user.
>
> The method should return a Twisted `Deferred` object, which resolves
> to `True` if authentication is successful, and `False` if not.
`someprovider.on_logged_out`(*user_id*, *device_id*, *access_token*) Alternatively, the `Awaitable` can resolve to a `(str, func)` tuple, in
which case the second field is a callback which will be called with
the result from the `/login` call (including `access_token`,
`device_id`, etc.)
> This method, if implemented, is called when a user logs out. It is * `check_3pid_auth(self, medium, address, password)`
> passed the qualified user ID, the ID of the deactivated device (if
> any: access tokens are occasionally created without an associated This method, if implemented, is called when a user attempts to
> device ID), and the (now deactivated) access token. register or log in with a third party identifier, such as email. It is
> passed the medium (ex. "email"), an address (ex.
> It may return a Twisted `Deferred` object; the logout request will "<jdoe@example.com>") and the user's password.
> wait for the deferred to complete but the result is ignored.
The method should return an `Awaitable` object, which resolves
to a `str` containing the user's (canonical) User id if
authentication was successful, and `None` if not.
As with `check_auth`, the `Awaitable` may alternatively resolve to a
`(user_id, callback)` tuple.
* `check_password(self, user_id, password)`
This method provides a simpler interface than
`get_supported_login_types` and `check_auth` for password auth
providers that just want to provide a mechanism for validating
`m.login.password` logins.
If implemented, it will be called to check logins with an
`m.login.password` login type. It is passed a qualified
`@localpart:domain` user id, and the password provided by the user.
The method should return an `Awaitable` object, which resolves
to `True` if authentication is successful, and `False` if not.
* `on_logged_out(self, user_id, device_id, access_token)`
This method, if implemented, is called when a user logs out. It is
passed the qualified user ID, the ID of the deactivated device (if
any: access tokens are occasionally created without an associated
device ID), and the (now deactivated) access token.
It may return an `Awaitable` object; the logout request will
wait for the `Awaitable` to complete, but the result is ignored.

View File

@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect
import logging import logging
import time import time
import unicodedata import unicodedata
@ -863,11 +864,15 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this # see if any of our auth providers want to know about this
for provider in self.password_providers: for provider in self.password_providers:
if hasattr(provider, "on_logged_out"): if hasattr(provider, "on_logged_out"):
await provider.on_logged_out( # This might return an awaitable, if it does block the log out
# until it completes.
result = provider.on_logged_out(
user_id=str(user_info["user"]), user_id=str(user_info["user"]),
device_id=user_info["device_id"], device_id=user_info["device_id"],
access_token=access_token, access_token=access_token,
) )
if inspect.isawaitable(result):
await result
# delete pushers associated with this access token # delete pushers associated with this access token
if user_info["token_id"] is not None: if user_info["token_id"] is not None:

View File

@ -14,10 +14,10 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any
from canonicaljson import json from canonicaljson import json
from twisted.internet import defer
from twisted.web.client import PartialDownloadError from twisted.web.client import PartialDownloadError
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
@ -33,25 +33,25 @@ class UserInteractiveAuthChecker:
def __init__(self, hs): def __init__(self, hs):
pass pass
def is_enabled(self): def is_enabled(self) -> bool:
"""Check if the configuration of the homeserver allows this checker to work """Check if the configuration of the homeserver allows this checker to work
Returns: Returns:
bool: True if this login type is enabled. True if this login type is enabled.
""" """
def check_auth(self, authdict, clientip): async def check_auth(self, authdict: dict, clientip: str) -> Any:
"""Given the authentication dict from the client, attempt to check this step """Given the authentication dict from the client, attempt to check this step
Args: Args:
authdict (dict): authentication dictionary from the client authdict: authentication dictionary from the client
clientip (str): The IP address of the client. clientip: The IP address of the client.
Raises: Raises:
SynapseError if authentication failed SynapseError if authentication failed
Returns: Returns:
Deferred: the result of authentication (to pass back to the client?) The result of authentication (to pass back to the client?)
""" """
raise NotImplementedError() raise NotImplementedError()
@ -62,8 +62,8 @@ class DummyAuthChecker(UserInteractiveAuthChecker):
def is_enabled(self): def is_enabled(self):
return True return True
def check_auth(self, authdict, clientip): async def check_auth(self, authdict, clientip):
return defer.succeed(True) return True
class TermsAuthChecker(UserInteractiveAuthChecker): class TermsAuthChecker(UserInteractiveAuthChecker):
@ -72,8 +72,8 @@ class TermsAuthChecker(UserInteractiveAuthChecker):
def is_enabled(self): def is_enabled(self):
return True return True
def check_auth(self, authdict, clientip): async def check_auth(self, authdict, clientip):
return defer.succeed(True) return True
class RecaptchaAuthChecker(UserInteractiveAuthChecker): class RecaptchaAuthChecker(UserInteractiveAuthChecker):
@ -89,8 +89,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
def is_enabled(self): def is_enabled(self):
return self._enabled return self._enabled
@defer.inlineCallbacks async def check_auth(self, authdict, clientip):
def check_auth(self, authdict, clientip):
try: try:
user_response = authdict["response"] user_response = authdict["response"]
except KeyError: except KeyError:
@ -107,7 +106,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
# TODO: get this from the homeserver rather than creating a new one for # TODO: get this from the homeserver rather than creating a new one for
# each request # each request
try: try:
resp_body = yield self._http_client.post_urlencoded_get_json( resp_body = await self._http_client.post_urlencoded_get_json(
self._url, self._url,
args={ args={
"secret": self._secret, "secret": self._secret,
@ -219,8 +218,8 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec
ThreepidBehaviour.LOCAL, ThreepidBehaviour.LOCAL,
) )
def check_auth(self, authdict, clientip): async def check_auth(self, authdict, clientip):
return defer.ensureDeferred(self._check_threepid("email", authdict)) return await self._check_threepid("email", authdict)
class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
@ -233,8 +232,8 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
def is_enabled(self): def is_enabled(self):
return bool(self.hs.config.account_threepid_delegate_msisdn) return bool(self.hs.config.account_threepid_delegate_msisdn)
def check_auth(self, authdict, clientip): async def check_auth(self, authdict, clientip):
return defer.ensureDeferred(self._check_threepid("msisdn", authdict)) return await self._check_threepid("msisdn", authdict)
INTERACTIVE_AUTH_CHECKERS = [ INTERACTIVE_AUTH_CHECKERS = [