Convert identity handler to async/await. (#7561)

This commit is contained in:
Patrick Cloke 2020-05-26 13:46:22 -04:00 committed by GitHub
parent edd9a7214c
commit ef884f6d04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 47 additions and 63 deletions

1
changelog.d/7561.misc Normal file
View File

@ -0,0 +1 @@
Convert the identity handler to async/await.

View File

@ -25,7 +25,6 @@ from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
from twisted.internet import defer
from twisted.internet.error import TimeoutError from twisted.internet.error import TimeoutError
from synapse.api.errors import ( from synapse.api.errors import (
@ -60,8 +59,7 @@ class IdentityHandler(BaseHandler):
self.federation_http_client = hs.get_http_client() self.federation_http_client = hs.get_http_client()
self.hs = hs self.hs = hs
@defer.inlineCallbacks async def threepid_from_creds(self, id_server, creds):
def threepid_from_creds(self, id_server, creds):
""" """
Retrieve and validate a threepid identifier from a "credentials" dictionary against a Retrieve and validate a threepid identifier from a "credentials" dictionary against a
given identity server given identity server
@ -97,7 +95,7 @@ class IdentityHandler(BaseHandler):
url = id_server + "/_matrix/identity/api/v1/3pid/getValidated3pid" url = id_server + "/_matrix/identity/api/v1/3pid/getValidated3pid"
try: try:
data = yield self.http_client.get_json(url, query_params) data = await self.http_client.get_json(url, query_params)
except TimeoutError: except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server") raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e: except HttpResponseException as e:
@ -120,8 +118,7 @@ class IdentityHandler(BaseHandler):
logger.info("%s reported non-validated threepid: %s", id_server, creds) logger.info("%s reported non-validated threepid: %s", id_server, creds)
return None return None
@defer.inlineCallbacks async def bind_threepid(
def bind_threepid(
self, client_secret, sid, mxid, id_server, id_access_token=None, use_v2=True self, client_secret, sid, mxid, id_server, id_access_token=None, use_v2=True
): ):
"""Bind a 3PID to an identity server """Bind a 3PID to an identity server
@ -161,12 +158,12 @@ class IdentityHandler(BaseHandler):
try: try:
# Use the blacklisting http client as this call is only to identity servers # Use the blacklisting http client as this call is only to identity servers
# provided by a client # provided by a client
data = yield self.blacklisting_http_client.post_json_get_json( data = await self.blacklisting_http_client.post_json_get_json(
bind_url, bind_data, headers=headers bind_url, bind_data, headers=headers
) )
# Remember where we bound the threepid # Remember where we bound the threepid
yield self.store.add_user_bound_threepid( await self.store.add_user_bound_threepid(
user_id=mxid, user_id=mxid,
medium=data["medium"], medium=data["medium"],
address=data["address"], address=data["address"],
@ -185,13 +182,12 @@ class IdentityHandler(BaseHandler):
return data return data
logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url) logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url)
res = yield self.bind_threepid( res = await self.bind_threepid(
client_secret, sid, mxid, id_server, id_access_token, use_v2=False client_secret, sid, mxid, id_server, id_access_token, use_v2=False
) )
return res return res
@defer.inlineCallbacks async def try_unbind_threepid(self, mxid, threepid):
def try_unbind_threepid(self, mxid, threepid):
"""Attempt to remove a 3PID from an identity server, or if one is not provided, all """Attempt to remove a 3PID from an identity server, or if one is not provided, all
identity servers we're aware the binding is present on identity servers we're aware the binding is present on
@ -211,7 +207,7 @@ class IdentityHandler(BaseHandler):
if threepid.get("id_server"): if threepid.get("id_server"):
id_servers = [threepid["id_server"]] id_servers = [threepid["id_server"]]
else: else:
id_servers = yield self.store.get_id_servers_user_bound( id_servers = await self.store.get_id_servers_user_bound(
user_id=mxid, medium=threepid["medium"], address=threepid["address"] user_id=mxid, medium=threepid["medium"], address=threepid["address"]
) )
@ -221,14 +217,13 @@ class IdentityHandler(BaseHandler):
changed = True changed = True
for id_server in id_servers: for id_server in id_servers:
changed &= yield self.try_unbind_threepid_with_id_server( changed &= await self.try_unbind_threepid_with_id_server(
mxid, threepid, id_server mxid, threepid, id_server
) )
return changed return changed
@defer.inlineCallbacks async def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
"""Removes a binding from an identity server """Removes a binding from an identity server
Args: Args:
@ -266,7 +261,7 @@ class IdentityHandler(BaseHandler):
try: try:
# Use the blacklisting http client as this call is only to identity servers # Use the blacklisting http client as this call is only to identity servers
# provided by a client # provided by a client
yield self.blacklisting_http_client.post_json_get_json( await self.blacklisting_http_client.post_json_get_json(
url, content, headers url, content, headers
) )
changed = True changed = True
@ -281,7 +276,7 @@ class IdentityHandler(BaseHandler):
except TimeoutError: except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server") raise SynapseError(500, "Timed out contacting identity server")
yield self.store.remove_user_bound_threepid( await self.store.remove_user_bound_threepid(
user_id=mxid, user_id=mxid,
medium=threepid["medium"], medium=threepid["medium"],
address=threepid["address"], address=threepid["address"],
@ -376,8 +371,7 @@ class IdentityHandler(BaseHandler):
return session_id return session_id
@defer.inlineCallbacks async def requestEmailToken(
def requestEmailToken(
self, id_server, email, client_secret, send_attempt, next_link=None self, id_server, email, client_secret, send_attempt, next_link=None
): ):
""" """
@ -412,7 +406,7 @@ class IdentityHandler(BaseHandler):
) )
try: try:
data = yield self.http_client.post_json_get_json( data = await self.http_client.post_json_get_json(
id_server + "/_matrix/identity/api/v1/validate/email/requestToken", id_server + "/_matrix/identity/api/v1/validate/email/requestToken",
params, params,
) )
@ -423,8 +417,7 @@ class IdentityHandler(BaseHandler):
except TimeoutError: except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server") raise SynapseError(500, "Timed out contacting identity server")
@defer.inlineCallbacks async def requestMsisdnToken(
def requestMsisdnToken(
self, self,
id_server, id_server,
country, country,
@ -466,7 +459,7 @@ class IdentityHandler(BaseHandler):
) )
try: try:
data = yield self.http_client.post_json_get_json( data = await self.http_client.post_json_get_json(
id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken", id_server + "/_matrix/identity/api/v1/validate/msisdn/requestToken",
params, params,
) )
@ -487,8 +480,7 @@ class IdentityHandler(BaseHandler):
) )
return data return data
@defer.inlineCallbacks async def validate_threepid_session(self, client_secret, sid):
def validate_threepid_session(self, client_secret, sid):
"""Validates a threepid session with only the client secret and session ID """Validates a threepid session with only the client secret and session ID
Tries validating against any configured account_threepid_delegates as well as locally. Tries validating against any configured account_threepid_delegates as well as locally.
@ -510,12 +502,12 @@ class IdentityHandler(BaseHandler):
# Try to validate as email # Try to validate as email
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
# Ask our delegated email identity server # Ask our delegated email identity server
validation_session = yield self.threepid_from_creds( validation_session = await self.threepid_from_creds(
self.hs.config.account_threepid_delegate_email, threepid_creds self.hs.config.account_threepid_delegate_email, threepid_creds
) )
elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
# Get a validated session matching these details # Get a validated session matching these details
validation_session = yield self.store.get_threepid_validation_session( validation_session = await self.store.get_threepid_validation_session(
"email", client_secret, sid=sid, validated=True "email", client_secret, sid=sid, validated=True
) )
@ -525,14 +517,13 @@ class IdentityHandler(BaseHandler):
# Try to validate as msisdn # Try to validate as msisdn
if self.hs.config.account_threepid_delegate_msisdn: if self.hs.config.account_threepid_delegate_msisdn:
# Ask our delegated msisdn identity server # Ask our delegated msisdn identity server
validation_session = yield self.threepid_from_creds( validation_session = await self.threepid_from_creds(
self.hs.config.account_threepid_delegate_msisdn, threepid_creds self.hs.config.account_threepid_delegate_msisdn, threepid_creds
) )
return validation_session return validation_session
@defer.inlineCallbacks async def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token):
def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token):
"""Proxy a POST submitToken request to an identity server for verification purposes """Proxy a POST submitToken request to an identity server for verification purposes
Args: Args:
@ -553,11 +544,9 @@ class IdentityHandler(BaseHandler):
body = {"client_secret": client_secret, "sid": sid, "token": token} body = {"client_secret": client_secret, "sid": sid, "token": token}
try: try:
return ( return await self.http_client.post_json_get_json(
yield self.http_client.post_json_get_json( id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken",
id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken", body,
body,
)
) )
except TimeoutError: except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server") raise SynapseError(500, "Timed out contacting identity server")
@ -565,8 +554,7 @@ class IdentityHandler(BaseHandler):
logger.warning("Error contacting msisdn account_threepid_delegate: %s", e) logger.warning("Error contacting msisdn account_threepid_delegate: %s", e)
raise SynapseError(400, "Error contacting the identity server") raise SynapseError(400, "Error contacting the identity server")
@defer.inlineCallbacks async def lookup_3pid(self, id_server, medium, address, id_access_token=None):
def lookup_3pid(self, id_server, medium, address, id_access_token=None):
"""Looks up a 3pid in the passed identity server. """Looks up a 3pid in the passed identity server.
Args: Args:
@ -582,7 +570,7 @@ class IdentityHandler(BaseHandler):
""" """
if id_access_token is not None: if id_access_token is not None:
try: try:
results = yield self._lookup_3pid_v2( results = await self._lookup_3pid_v2(
id_server, id_access_token, medium, address id_server, id_access_token, medium, address
) )
return results return results
@ -601,10 +589,9 @@ class IdentityHandler(BaseHandler):
logger.warning("Error when looking up hashing details: %s", e) logger.warning("Error when looking up hashing details: %s", e)
return None return None
return (yield self._lookup_3pid_v1(id_server, medium, address)) return await self._lookup_3pid_v1(id_server, medium, address)
@defer.inlineCallbacks async def _lookup_3pid_v1(self, id_server, medium, address):
def _lookup_3pid_v1(self, id_server, medium, address):
"""Looks up a 3pid in the passed identity server using v1 lookup. """Looks up a 3pid in the passed identity server using v1 lookup.
Args: Args:
@ -617,7 +604,7 @@ class IdentityHandler(BaseHandler):
str: the matrix ID of the 3pid, or None if it is not recognized. str: the matrix ID of the 3pid, or None if it is not recognized.
""" """
try: try:
data = yield self.blacklisting_http_client.get_json( data = await self.blacklisting_http_client.get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server), "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
{"medium": medium, "address": address}, {"medium": medium, "address": address},
) )
@ -625,7 +612,7 @@ class IdentityHandler(BaseHandler):
if "mxid" in data: if "mxid" in data:
if "signatures" not in data: if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding") raise AuthError(401, "No signatures on 3pid binding")
yield self._verify_any_signature(data, id_server) await self._verify_any_signature(data, id_server)
return data["mxid"] return data["mxid"]
except TimeoutError: except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server") raise SynapseError(500, "Timed out contacting identity server")
@ -634,8 +621,7 @@ class IdentityHandler(BaseHandler):
return None return None
@defer.inlineCallbacks async def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
"""Looks up a 3pid in the passed identity server using v2 lookup. """Looks up a 3pid in the passed identity server using v2 lookup.
Args: Args:
@ -650,7 +636,7 @@ class IdentityHandler(BaseHandler):
""" """
# Check what hashing details are supported by this identity server # Check what hashing details are supported by this identity server
try: try:
hash_details = yield self.blacklisting_http_client.get_json( hash_details = await self.blacklisting_http_client.get_json(
"%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server), "%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server),
{"access_token": id_access_token}, {"access_token": id_access_token},
) )
@ -717,7 +703,7 @@ class IdentityHandler(BaseHandler):
headers = {"Authorization": create_id_access_token_header(id_access_token)} headers = {"Authorization": create_id_access_token_header(id_access_token)}
try: try:
lookup_results = yield self.blacklisting_http_client.post_json_get_json( lookup_results = await self.blacklisting_http_client.post_json_get_json(
"%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server), "%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server),
{ {
"addresses": [lookup_value], "addresses": [lookup_value],
@ -745,13 +731,12 @@ class IdentityHandler(BaseHandler):
mxid = lookup_results["mappings"].get(lookup_value) mxid = lookup_results["mappings"].get(lookup_value)
return mxid return mxid
@defer.inlineCallbacks async def _verify_any_signature(self, data, server_hostname):
def _verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]: if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,)) raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items(): for key_name, signature in data["signatures"][server_hostname].items():
try: try:
key_data = yield self.blacklisting_http_client.get_json( key_data = await self.blacklisting_http_client.get_json(
"%s%s/_matrix/identity/api/v1/pubkey/%s" "%s%s/_matrix/identity/api/v1/pubkey/%s"
% (id_server_scheme, server_hostname, key_name) % (id_server_scheme, server_hostname, key_name)
) )
@ -770,8 +755,7 @@ class IdentityHandler(BaseHandler):
) )
return return
@defer.inlineCallbacks async def ask_id_server_for_third_party_invite(
def ask_id_server_for_third_party_invite(
self, self,
requester, requester,
id_server, id_server,
@ -844,7 +828,7 @@ class IdentityHandler(BaseHandler):
# Attempt a v2 lookup # Attempt a v2 lookup
url = base_url + "/v2/store-invite" url = base_url + "/v2/store-invite"
try: try:
data = yield self.blacklisting_http_client.post_json_get_json( data = await self.blacklisting_http_client.post_json_get_json(
url, url,
invite_config, invite_config,
{"Authorization": create_id_access_token_header(id_access_token)}, {"Authorization": create_id_access_token_header(id_access_token)},
@ -864,7 +848,7 @@ class IdentityHandler(BaseHandler):
url = base_url + "/api/v1/store-invite" url = base_url + "/api/v1/store-invite"
try: try:
data = yield self.blacklisting_http_client.post_json_get_json( data = await self.blacklisting_http_client.post_json_get_json(
url, invite_config url, invite_config
) )
except TimeoutError: except TimeoutError:
@ -882,7 +866,7 @@ class IdentityHandler(BaseHandler):
# types. This is especially true with old instances of Sydent, see # types. This is especially true with old instances of Sydent, see
# https://github.com/matrix-org/sydent/pull/170 # https://github.com/matrix-org/sydent/pull/170
try: try:
data = yield self.blacklisting_http_client.post_urlencoded_get_json( data = await self.blacklisting_http_client.post_urlencoded_get_json(
url, invite_config url, invite_config
) )
except HttpResponseException as e: except HttpResponseException as e:

View File

@ -138,8 +138,7 @@ class _BaseThreepidAuthChecker:
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks async def _check_threepid(self, medium, authdict):
def _check_threepid(self, medium, authdict):
if "threepid_creds" not in authdict: if "threepid_creds" not in authdict:
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM) raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
@ -155,18 +154,18 @@ class _BaseThreepidAuthChecker:
raise SynapseError( raise SynapseError(
400, "Phone number verification is not enabled on this homeserver" 400, "Phone number verification is not enabled on this homeserver"
) )
threepid = yield identity_handler.threepid_from_creds( threepid = await identity_handler.threepid_from_creds(
self.hs.config.account_threepid_delegate_msisdn, threepid_creds self.hs.config.account_threepid_delegate_msisdn, threepid_creds
) )
elif medium == "email": elif medium == "email":
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
assert self.hs.config.account_threepid_delegate_email assert self.hs.config.account_threepid_delegate_email
threepid = yield identity_handler.threepid_from_creds( threepid = await identity_handler.threepid_from_creds(
self.hs.config.account_threepid_delegate_email, threepid_creds self.hs.config.account_threepid_delegate_email, threepid_creds
) )
elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
threepid = None threepid = None
row = yield self.store.get_threepid_validation_session( row = await self.store.get_threepid_validation_session(
medium, medium,
threepid_creds["client_secret"], threepid_creds["client_secret"],
sid=threepid_creds["sid"], sid=threepid_creds["sid"],
@ -181,7 +180,7 @@ class _BaseThreepidAuthChecker:
} }
# Valid threepid returned, delete from the db # Valid threepid returned, delete from the db
yield self.store.delete_threepid_session(threepid_creds["sid"]) await self.store.delete_threepid_session(threepid_creds["sid"])
else: else:
raise SynapseError( raise SynapseError(
400, "Email address verification is not enabled on this homeserver" 400, "Email address verification is not enabled on this homeserver"
@ -220,7 +219,7 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec
) )
def check_auth(self, authdict, clientip): def check_auth(self, authdict, clientip):
return self._check_threepid("email", authdict) return defer.ensureDeferred(self._check_threepid("email", authdict))
class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
@ -234,7 +233,7 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
return bool(self.hs.config.account_threepid_delegate_msisdn) return bool(self.hs.config.account_threepid_delegate_msisdn)
def check_auth(self, authdict, clientip): def check_auth(self, authdict, clientip):
return self._check_threepid("msisdn", authdict) return defer.ensureDeferred(self._check_threepid("msisdn", authdict))
INTERACTIVE_AUTH_CHECKERS = [ INTERACTIVE_AUTH_CHECKERS = [