Avoid so much copypasta between 3PU and 3PL query by unifying around a ThirdPartyEntityKind enumeration

This commit is contained in:
Paul "LeoNerd" Evans 2016-08-18 17:19:55 +01:00
parent 2a91799fcc
commit b515f844ee
4 changed files with 35 additions and 43 deletions

View File

@ -17,6 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import CodeMessageException from synapse.api.errors import CodeMessageException
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.types import ThirdPartyEntityKind
import logging import logging
import urllib import urllib
@ -72,25 +73,21 @@ class ApplicationServiceApi(SimpleHttpClient):
defer.returnValue(False) defer.returnValue(False)
@defer.inlineCallbacks @defer.inlineCallbacks
def query_3pu(self, service, protocol, fields): def query_3pe(self, service, kind, protocol, fields):
if kind == ThirdPartyEntityKind.USER:
uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol)) uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
response = None elif kind == ThirdPartyEntityKind.LOCATION:
try:
response = yield self.get_json(uri, fields)
defer.returnValue(response)
except Exception as ex:
logger.warning("query_3pu to %s threw exception %s", uri, ex)
defer.returnValue([])
@defer.inlineCallbacks
def query_3pl(self, service, protocol, fields):
uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol)) uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
response = None else:
raise ValueError(
"Unrecognised 'kind' argument %r to query_3pe()", kind
)
try: try:
response = yield self.get_json(uri, fields) response = yield self.get_json(uri, fields)
defer.returnValue(response) defer.returnValue(response)
except Exception as ex: except Exception as ex:
logger.warning("query_3pl to %s threw exception %s", uri, ex) logger.warning("query_3pe to %s threw exception %s", uri, ex)
defer.returnValue([]) defer.returnValue([])
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.util.logcontext import preserve_fn from synapse.util.logcontext import preserve_fn
from synapse.types import ThirdPartyEntityKind
import logging import logging
@ -169,14 +170,20 @@ class ApplicationServicesHandler(object):
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
def query_3pu(self, protocol, fields): def query_3pe(self, kind, protocol, fields):
services = yield self._get_services_for_3pn(protocol) services = yield self._get_services_for_3pn(protocol)
results = yield defer.DeferredList([ results = yield defer.DeferredList([
self.appservice_api.query_3pu(service, protocol, fields) self.appservice_api.query_3pe(service, kind, protocol, fields)
for service in services for service in services
], consumeErrors=True) ], consumeErrors=True)
required_field = (
"userid" if kind == ThirdPartyEntityKind.USER else
"alias" if kind == ThirdPartyEntityKind.LOCATION else
None
)
ret = [] ret = []
for (success, result) in results: for (success, result) in results:
if not success: if not success:
@ -184,31 +191,7 @@ class ApplicationServicesHandler(object):
if not isinstance(result, list): if not isinstance(result, list):
continue continue
for r in result: for r in result:
if _is_valid_3pentity_result(r, field="userid"): if _is_valid_3pentity_result(r, field=required_field):
ret.append(r)
else:
logger.warn("Application service returned an " +
"invalid result %r", r)
defer.returnValue(ret)
@defer.inlineCallbacks
def query_3pl(self, protocol, fields):
services = yield self._get_services_for_3pn(protocol)
results = yield defer.DeferredList([
self.appservice_api.query_3pl(service, protocol, fields)
for service in services
], consumeErrors=True)
ret = []
for (success, result) in results:
if not success:
continue
if not isinstance(result, list):
continue
for r in result:
if _is_valid_3pentity_result(r, field="alias"):
ret.append(r) ret.append(r)
else: else:
logger.warn("Application service returned an " + logger.warn("Application service returned an " +

View File

@ -19,6 +19,7 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.types import ThirdPartyEntityKind
from ._base import client_v2_patterns from ._base import client_v2_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -41,7 +42,9 @@ class ThirdPartyUserServlet(RestServlet):
fields = request.args fields = request.args
del fields["access_token"] del fields["access_token"]
results = yield self.appservice_handler.query_3pu(protocol, fields) results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.USER, protocol, fields
)
defer.returnValue((200, results)) defer.returnValue((200, results))
@ -63,7 +66,9 @@ class ThirdPartyLocationServlet(RestServlet):
fields = request.args fields = request.args
del fields["access_token"] del fields["access_token"]
results = yield self.appservice_handler.query_3pl(protocol, fields) results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.LOCATION, protocol, fields
)
defer.returnValue((200, results)) defer.returnValue((200, results))

View File

@ -269,3 +269,10 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
return "t%d-%d" % (self.topological, self.stream) return "t%d-%d" % (self.topological, self.stream)
else: else:
return "s%d" % (self.stream,) return "s%d" % (self.stream,)
# Some arbitrary constants used for internal API enumerations. Don't rely on
# exact values; always pass or compare symbolically
class ThirdPartyEntityKind(object):
USER = 'user'
LOCATION = 'location'