Move validation logic for AS 3PE query response into ApplicationServiceApi class, to keep the handler logic neater

This commit is contained in:
Paul "LeoNerd" Evans 2016-08-18 17:33:56 +01:00
parent 697872cf08
commit 65201631a4
2 changed files with 44 additions and 45 deletions

View File

@ -25,6 +25,28 @@ import urllib
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _is_valid_3pe_result(r, field):
if not isinstance(r, dict):
return False
for k in (field, "protocol"):
if k not in r:
return False
if not isinstance(r[k], str):
return False
if "fields" not in r:
return False
fields = r["fields"]
if not isinstance(fields, dict):
return False
for k in fields.keys():
if not isinstance(fields[k], str):
return False
return True
class ApplicationServiceApi(SimpleHttpClient): class ApplicationServiceApi(SimpleHttpClient):
"""This class manages HS -> AS communications, including querying and """This class manages HS -> AS communications, including querying and
pushing. pushing.
@ -76,8 +98,10 @@ class ApplicationServiceApi(SimpleHttpClient):
def query_3pe(self, service, kind, protocol, fields): def query_3pe(self, service, kind, protocol, fields):
if kind == ThirdPartyEntityKind.USER: if kind == ThirdPartyEntityKind.USER:
uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol)) uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
required_field = "userid"
elif kind == ThirdPartyEntityKind.LOCATION: elif kind == ThirdPartyEntityKind.LOCATION:
uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol)) uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
required_field = "alias"
else: else:
raise ValueError( raise ValueError(
"Unrecognised 'kind' argument %r to query_3pe()", kind "Unrecognised 'kind' argument %r to query_3pe()", kind
@ -85,7 +109,24 @@ class ApplicationServiceApi(SimpleHttpClient):
try: try:
response = yield self.get_json(uri, fields) response = yield self.get_json(uri, fields)
defer.returnValue(response) if not isinstance(response, list):
logger.warning(
"query_3pe to %s returned an invalid response %r",
uri, response
)
defer.returnValue([])
ret = []
for r in response:
if _is_valid_3pe_result(r, field=required_field):
ret.append(r)
else:
logger.warning(
"query_3pe to %s returned an invalid result %r",
uri, r
)
defer.returnValue(ret)
except Exception as ex: except Exception as ex:
logger.warning("query_3pe to %s threw exception %s", uri, ex) logger.warning("query_3pe to %s threw exception %s", uri, ex)
defer.returnValue([]) defer.returnValue([])

View File

@ -18,7 +18,6 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.util.logcontext import preserve_fn from synapse.util.logcontext import preserve_fn
from synapse.types import ThirdPartyEntityKind
import logging import logging
@ -36,28 +35,6 @@ def log_failure(failure):
) )
def _is_valid_3pentity_result(r, field):
if not isinstance(r, dict):
return False
for k in (field, "protocol"):
if k not in r:
return False
if not isinstance(r[k], str):
return False
if "fields" not in r:
return False
fields = r["fields"]
if not isinstance(fields, dict):
return False
for k in fields.keys():
if not isinstance(fields[k], str):
return False
return True
class ApplicationServicesHandler(object): class ApplicationServicesHandler(object):
def __init__(self, hs): def __init__(self, hs):
@ -178,29 +155,10 @@ class ApplicationServicesHandler(object):
for service in services for service in services
], consumeErrors=True) ], consumeErrors=True)
required_field = (
"userid" if kind == ThirdPartyEntityKind.USER else
"alias" if kind == ThirdPartyEntityKind.LOCATION else
None
)
ret = [] ret = []
for (success, result) in results: for (success, result) in results:
if not success: if success:
logger.warn("Application service failed %r", result) ret.extend(result)
continue
if not isinstance(result, list):
logger.warn(
"Application service returned an invalid response %r", result
)
continue
for r in result:
if _is_valid_3pentity_result(r, field=required_field):
ret.append(r)
else:
logger.warn(
"Application service returned an invalid result %r", r
)
defer.returnValue(ret) defer.returnValue(ret)