mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-10-01 08:25:44 -04:00
Avoid so much copypasta between 3PU and 3PL query by unifying around a ThirdPartyEntityKind enumeration
This commit is contained in:
parent
2a91799fcc
commit
b515f844ee
@ -17,6 +17,7 @@ from twisted.internet import defer
|
||||
from synapse.api.errors import CodeMessageException
|
||||
from synapse.http.client import SimpleHttpClient
|
||||
from synapse.events.utils import serialize_event
|
||||
from synapse.types import ThirdPartyEntityKind
|
||||
|
||||
import logging
|
||||
import urllib
|
||||
@ -72,25 +73,21 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||
defer.returnValue(False)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def query_3pu(self, service, protocol, fields):
|
||||
uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
|
||||
response = None
|
||||
try:
|
||||
response = yield self.get_json(uri, fields)
|
||||
defer.returnValue(response)
|
||||
except Exception as ex:
|
||||
logger.warning("query_3pu to %s threw exception %s", uri, ex)
|
||||
defer.returnValue([])
|
||||
def query_3pe(self, service, kind, protocol, fields):
|
||||
if kind == ThirdPartyEntityKind.USER:
|
||||
uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
|
||||
elif kind == ThirdPartyEntityKind.LOCATION:
|
||||
uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unrecognised 'kind' argument %r to query_3pe()", kind
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def query_3pl(self, service, protocol, fields):
|
||||
uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
|
||||
response = None
|
||||
try:
|
||||
response = yield self.get_json(uri, fields)
|
||||
defer.returnValue(response)
|
||||
except Exception as ex:
|
||||
logger.warning("query_3pl to %s threw exception %s", uri, ex)
|
||||
logger.warning("query_3pe to %s threw exception %s", uri, ex)
|
||||
defer.returnValue([])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -18,6 +18,7 @@ from twisted.internet import defer
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.util.metrics import Measure
|
||||
from synapse.util.logcontext import preserve_fn
|
||||
from synapse.types import ThirdPartyEntityKind
|
||||
|
||||
import logging
|
||||
|
||||
@ -169,14 +170,20 @@ class ApplicationServicesHandler(object):
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def query_3pu(self, protocol, fields):
|
||||
def query_3pe(self, kind, protocol, fields):
|
||||
services = yield self._get_services_for_3pn(protocol)
|
||||
|
||||
results = yield defer.DeferredList([
|
||||
self.appservice_api.query_3pu(service, protocol, fields)
|
||||
self.appservice_api.query_3pe(service, kind, protocol, fields)
|
||||
for service in services
|
||||
], consumeErrors=True)
|
||||
|
||||
required_field = (
|
||||
"userid" if kind == ThirdPartyEntityKind.USER else
|
||||
"alias" if kind == ThirdPartyEntityKind.LOCATION else
|
||||
None
|
||||
)
|
||||
|
||||
ret = []
|
||||
for (success, result) in results:
|
||||
if not success:
|
||||
@ -184,31 +191,7 @@ class ApplicationServicesHandler(object):
|
||||
if not isinstance(result, list):
|
||||
continue
|
||||
for r in result:
|
||||
if _is_valid_3pentity_result(r, field="userid"):
|
||||
ret.append(r)
|
||||
else:
|
||||
logger.warn("Application service returned an " +
|
||||
"invalid result %r", r)
|
||||
|
||||
defer.returnValue(ret)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def query_3pl(self, protocol, fields):
|
||||
services = yield self._get_services_for_3pn(protocol)
|
||||
|
||||
results = yield defer.DeferredList([
|
||||
self.appservice_api.query_3pl(service, protocol, fields)
|
||||
for service in services
|
||||
], consumeErrors=True)
|
||||
|
||||
ret = []
|
||||
for (success, result) in results:
|
||||
if not success:
|
||||
continue
|
||||
if not isinstance(result, list):
|
||||
continue
|
||||
for r in result:
|
||||
if _is_valid_3pentity_result(r, field="alias"):
|
||||
if _is_valid_3pentity_result(r, field=required_field):
|
||||
ret.append(r)
|
||||
else:
|
||||
logger.warn("Application service returned an " +
|
||||
|
@ -19,6 +19,7 @@ import logging
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.http.servlet import RestServlet
|
||||
from synapse.types import ThirdPartyEntityKind
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -41,7 +42,9 @@ class ThirdPartyUserServlet(RestServlet):
|
||||
fields = request.args
|
||||
del fields["access_token"]
|
||||
|
||||
results = yield self.appservice_handler.query_3pu(protocol, fields)
|
||||
results = yield self.appservice_handler.query_3pe(
|
||||
ThirdPartyEntityKind.USER, protocol, fields
|
||||
)
|
||||
|
||||
defer.returnValue((200, results))
|
||||
|
||||
@ -63,7 +66,9 @@ class ThirdPartyLocationServlet(RestServlet):
|
||||
fields = request.args
|
||||
del fields["access_token"]
|
||||
|
||||
results = yield self.appservice_handler.query_3pl(protocol, fields)
|
||||
results = yield self.appservice_handler.query_3pe(
|
||||
ThirdPartyEntityKind.LOCATION, protocol, fields
|
||||
)
|
||||
|
||||
defer.returnValue((200, results))
|
||||
|
||||
|
@ -269,3 +269,10 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
|
||||
return "t%d-%d" % (self.topological, self.stream)
|
||||
else:
|
||||
return "s%d" % (self.stream,)
|
||||
|
||||
|
||||
# Some arbitrary constants used for internal API enumerations. Don't rely on
|
||||
# exact values; always pass or compare symbolically
|
||||
class ThirdPartyEntityKind(object):
|
||||
USER = 'user'
|
||||
LOCATION = 'location'
|
||||
|
Loading…
Reference in New Issue
Block a user