Update the auth providers to be async. (#7935)

This commit is contained in:
Patrick Cloke 2020-07-23 15:45:39 -04:00 committed by GitHub
parent 7078866969
commit 83434df381
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 106 additions and 100 deletions

View file

@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
import time
import unicodedata
@ -863,11 +864,15 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this
for provider in self.password_providers:
if hasattr(provider, "on_logged_out"):
await provider.on_logged_out(
# This might return an awaitable, if it does block the log out
# until it completes.
result = provider.on_logged_out(
user_id=str(user_info["user"]),
device_id=user_info["device_id"],
access_token=access_token,
)
if inspect.isawaitable(result):
await result
# delete pushers associated with this access token
if user_info["token_id"] is not None:

View file

@ -14,10 +14,10 @@
# limitations under the License.
import logging
from typing import Any
from canonicaljson import json
from twisted.internet import defer
from twisted.web.client import PartialDownloadError
from synapse.api.constants import LoginType
@ -33,25 +33,25 @@ class UserInteractiveAuthChecker:
def __init__(self, hs):
pass
def is_enabled(self):
def is_enabled(self) -> bool:
"""Check if the configuration of the homeserver allows this checker to work
Returns:
bool: True if this login type is enabled.
True if this login type is enabled.
"""
def check_auth(self, authdict, clientip):
async def check_auth(self, authdict: dict, clientip: str) -> Any:
"""Given the authentication dict from the client, attempt to check this step
Args:
authdict (dict): authentication dictionary from the client
clientip (str): The IP address of the client.
authdict: authentication dictionary from the client
clientip: The IP address of the client.
Raises:
SynapseError if authentication failed
Returns:
Deferred: the result of authentication (to pass back to the client?)
The result of authentication (to pass back to the client?)
"""
raise NotImplementedError()
@ -62,8 +62,8 @@ class DummyAuthChecker(UserInteractiveAuthChecker):
def is_enabled(self):
return True
def check_auth(self, authdict, clientip):
return defer.succeed(True)
async def check_auth(self, authdict, clientip):
return True
class TermsAuthChecker(UserInteractiveAuthChecker):
@ -72,8 +72,8 @@ class TermsAuthChecker(UserInteractiveAuthChecker):
def is_enabled(self):
return True
def check_auth(self, authdict, clientip):
return defer.succeed(True)
async def check_auth(self, authdict, clientip):
return True
class RecaptchaAuthChecker(UserInteractiveAuthChecker):
@ -89,8 +89,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
def is_enabled(self):
return self._enabled
@defer.inlineCallbacks
def check_auth(self, authdict, clientip):
async def check_auth(self, authdict, clientip):
try:
user_response = authdict["response"]
except KeyError:
@ -107,7 +106,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
# TODO: get this from the homeserver rather than creating a new one for
# each request
try:
resp_body = yield self._http_client.post_urlencoded_get_json(
resp_body = await self._http_client.post_urlencoded_get_json(
self._url,
args={
"secret": self._secret,
@ -219,8 +218,8 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec
ThreepidBehaviour.LOCAL,
)
def check_auth(self, authdict, clientip):
return defer.ensureDeferred(self._check_threepid("email", authdict))
async def check_auth(self, authdict, clientip):
return await self._check_threepid("email", authdict)
class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
@ -233,8 +232,8 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
def is_enabled(self):
return bool(self.hs.config.account_threepid_delegate_msisdn)
def check_auth(self, authdict, clientip):
return defer.ensureDeferred(self._check_threepid("msisdn", authdict))
async def check_auth(self, authdict, clientip):
return await self._check_threepid("msisdn", authdict)
INTERACTIVE_AUTH_CHECKERS = [