diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index d4cad1b1e..dd5e762e0 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -17,6 +17,7 @@ from twisted.internet import defer from synapse.api.errors import CodeMessageException from synapse.http.client import SimpleHttpClient from synapse.events.utils import serialize_event +from synapse.types import ThirdPartyEntityKind import logging import urllib @@ -72,25 +73,21 @@ class ApplicationServiceApi(SimpleHttpClient): defer.returnValue(False) @defer.inlineCallbacks - def query_3pu(self, service, protocol, fields): - uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol)) - response = None - try: - response = yield self.get_json(uri, fields) - defer.returnValue(response) - except Exception as ex: - logger.warning("query_3pu to %s threw exception %s", uri, ex) - defer.returnValue([]) + def query_3pe(self, service, kind, protocol, fields): + if kind == ThirdPartyEntityKind.USER: + uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol)) + elif kind == ThirdPartyEntityKind.LOCATION: + uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol)) + else: + raise ValueError( + "Unrecognised 'kind' argument %r to query_3pe()", kind + ) - @defer.inlineCallbacks - def query_3pl(self, service, protocol, fields): - uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol)) - response = None try: response = yield self.get_json(uri, fields) defer.returnValue(response) except Exception as ex: - logger.warning("query_3pl to %s threw exception %s", uri, ex) + logger.warning("query_3pe to %s threw exception %s", uri, ex) defer.returnValue([]) @defer.inlineCallbacks diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 03452f6bb..52c127d2c 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes from synapse.util.metrics import Measure from synapse.util.logcontext import preserve_fn +from synapse.types import ThirdPartyEntityKind import logging @@ -169,14 +170,20 @@ class ApplicationServicesHandler(object): defer.returnValue(result) @defer.inlineCallbacks - def query_3pu(self, protocol, fields): + def query_3pe(self, kind, protocol, fields): services = yield self._get_services_for_3pn(protocol) results = yield defer.DeferredList([ - self.appservice_api.query_3pu(service, protocol, fields) + self.appservice_api.query_3pe(service, kind, protocol, fields) for service in services ], consumeErrors=True) + required_field = ( + "userid" if kind == ThirdPartyEntityKind.USER else + "alias" if kind == ThirdPartyEntityKind.LOCATION else + None + ) + ret = [] for (success, result) in results: if not success: @@ -184,31 +191,7 @@ class ApplicationServicesHandler(object): if not isinstance(result, list): continue for r in result: - if _is_valid_3pentity_result(r, field="userid"): - ret.append(r) - else: - logger.warn("Application service returned an " + - "invalid result %r", r) - - defer.returnValue(ret) - - @defer.inlineCallbacks - def query_3pl(self, protocol, fields): - services = yield self._get_services_for_3pn(protocol) - - results = yield defer.DeferredList([ - self.appservice_api.query_3pl(service, protocol, fields) - for service in services - ], consumeErrors=True) - - ret = [] - for (success, result) in results: - if not success: - continue - if not isinstance(result, list): - continue - for r in result: - if _is_valid_3pentity_result(r, field="alias"): + if _is_valid_3pentity_result(r, field=required_field): ret.append(r) else: logger.warn("Application service returned an " + diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index d229e4b81..9abca3a8a 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -19,6 +19,7 @@ import logging from twisted.internet import defer from synapse.http.servlet import RestServlet +from synapse.types import ThirdPartyEntityKind from ._base import client_v2_patterns logger = logging.getLogger(__name__) @@ -41,7 +42,9 @@ class ThirdPartyUserServlet(RestServlet): fields = request.args del fields["access_token"] - results = yield self.appservice_handler.query_3pu(protocol, fields) + results = yield self.appservice_handler.query_3pe( + ThirdPartyEntityKind.USER, protocol, fields + ) defer.returnValue((200, results)) @@ -63,7 +66,9 @@ class ThirdPartyLocationServlet(RestServlet): fields = request.args del fields["access_token"] - results = yield self.appservice_handler.query_3pl(protocol, fields) + results = yield self.appservice_handler.query_3pe( + ThirdPartyEntityKind.LOCATION, protocol, fields + ) defer.returnValue((200, results)) diff --git a/synapse/types.py b/synapse/types.py index 5349b0c45..fd17ecbbe 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -269,3 +269,10 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): return "t%d-%d" % (self.topological, self.stream) else: return "s%d" % (self.stream,) + + +# Some arbitrary constants used for internal API enumerations. Don't rely on +# exact values; always pass or compare symbolically +class ThirdPartyEntityKind(object): + USER = 'user' + LOCATION = 'location'