mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-12-11 11:24:21 -05:00
Avoid so much copypasta between 3PU and 3PL query by unifying around a ThirdPartyEntityKind enumeration
This commit is contained in:
parent
2a91799fcc
commit
b515f844ee
@ -17,6 +17,7 @@ from twisted.internet import defer
|
|||||||
from synapse.api.errors import CodeMessageException
|
from synapse.api.errors import CodeMessageException
|
||||||
from synapse.http.client import SimpleHttpClient
|
from synapse.http.client import SimpleHttpClient
|
||||||
from synapse.events.utils import serialize_event
|
from synapse.events.utils import serialize_event
|
||||||
|
from synapse.types import ThirdPartyEntityKind
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import urllib
|
import urllib
|
||||||
@ -72,25 +73,21 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||||||
defer.returnValue(False)
|
defer.returnValue(False)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def query_3pu(self, service, protocol, fields):
|
def query_3pe(self, service, kind, protocol, fields):
|
||||||
|
if kind == ThirdPartyEntityKind.USER:
|
||||||
uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
|
uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
|
||||||
response = None
|
elif kind == ThirdPartyEntityKind.LOCATION:
|
||||||
try:
|
|
||||||
response = yield self.get_json(uri, fields)
|
|
||||||
defer.returnValue(response)
|
|
||||||
except Exception as ex:
|
|
||||||
logger.warning("query_3pu to %s threw exception %s", uri, ex)
|
|
||||||
defer.returnValue([])
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def query_3pl(self, service, protocol, fields):
|
|
||||||
uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
|
uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
|
||||||
response = None
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Unrecognised 'kind' argument %r to query_3pe()", kind
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = yield self.get_json(uri, fields)
|
response = yield self.get_json(uri, fields)
|
||||||
defer.returnValue(response)
|
defer.returnValue(response)
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logger.warning("query_3pl to %s threw exception %s", uri, ex)
|
logger.warning("query_3pe to %s threw exception %s", uri, ex)
|
||||||
defer.returnValue([])
|
defer.returnValue([])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -18,6 +18,7 @@ from twisted.internet import defer
|
|||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
from synapse.util.logcontext import preserve_fn
|
from synapse.util.logcontext import preserve_fn
|
||||||
|
from synapse.types import ThirdPartyEntityKind
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@ -169,14 +170,20 @@ class ApplicationServicesHandler(object):
|
|||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def query_3pu(self, protocol, fields):
|
def query_3pe(self, kind, protocol, fields):
|
||||||
services = yield self._get_services_for_3pn(protocol)
|
services = yield self._get_services_for_3pn(protocol)
|
||||||
|
|
||||||
results = yield defer.DeferredList([
|
results = yield defer.DeferredList([
|
||||||
self.appservice_api.query_3pu(service, protocol, fields)
|
self.appservice_api.query_3pe(service, kind, protocol, fields)
|
||||||
for service in services
|
for service in services
|
||||||
], consumeErrors=True)
|
], consumeErrors=True)
|
||||||
|
|
||||||
|
required_field = (
|
||||||
|
"userid" if kind == ThirdPartyEntityKind.USER else
|
||||||
|
"alias" if kind == ThirdPartyEntityKind.LOCATION else
|
||||||
|
None
|
||||||
|
)
|
||||||
|
|
||||||
ret = []
|
ret = []
|
||||||
for (success, result) in results:
|
for (success, result) in results:
|
||||||
if not success:
|
if not success:
|
||||||
@ -184,31 +191,7 @@ class ApplicationServicesHandler(object):
|
|||||||
if not isinstance(result, list):
|
if not isinstance(result, list):
|
||||||
continue
|
continue
|
||||||
for r in result:
|
for r in result:
|
||||||
if _is_valid_3pentity_result(r, field="userid"):
|
if _is_valid_3pentity_result(r, field=required_field):
|
||||||
ret.append(r)
|
|
||||||
else:
|
|
||||||
logger.warn("Application service returned an " +
|
|
||||||
"invalid result %r", r)
|
|
||||||
|
|
||||||
defer.returnValue(ret)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def query_3pl(self, protocol, fields):
|
|
||||||
services = yield self._get_services_for_3pn(protocol)
|
|
||||||
|
|
||||||
results = yield defer.DeferredList([
|
|
||||||
self.appservice_api.query_3pl(service, protocol, fields)
|
|
||||||
for service in services
|
|
||||||
], consumeErrors=True)
|
|
||||||
|
|
||||||
ret = []
|
|
||||||
for (success, result) in results:
|
|
||||||
if not success:
|
|
||||||
continue
|
|
||||||
if not isinstance(result, list):
|
|
||||||
continue
|
|
||||||
for r in result:
|
|
||||||
if _is_valid_3pentity_result(r, field="alias"):
|
|
||||||
ret.append(r)
|
ret.append(r)
|
||||||
else:
|
else:
|
||||||
logger.warn("Application service returned an " +
|
logger.warn("Application service returned an " +
|
||||||
|
@ -19,6 +19,7 @@ import logging
|
|||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.http.servlet import RestServlet
|
from synapse.http.servlet import RestServlet
|
||||||
|
from synapse.types import ThirdPartyEntityKind
|
||||||
from ._base import client_v2_patterns
|
from ._base import client_v2_patterns
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -41,7 +42,9 @@ class ThirdPartyUserServlet(RestServlet):
|
|||||||
fields = request.args
|
fields = request.args
|
||||||
del fields["access_token"]
|
del fields["access_token"]
|
||||||
|
|
||||||
results = yield self.appservice_handler.query_3pu(protocol, fields)
|
results = yield self.appservice_handler.query_3pe(
|
||||||
|
ThirdPartyEntityKind.USER, protocol, fields
|
||||||
|
)
|
||||||
|
|
||||||
defer.returnValue((200, results))
|
defer.returnValue((200, results))
|
||||||
|
|
||||||
@ -63,7 +66,9 @@ class ThirdPartyLocationServlet(RestServlet):
|
|||||||
fields = request.args
|
fields = request.args
|
||||||
del fields["access_token"]
|
del fields["access_token"]
|
||||||
|
|
||||||
results = yield self.appservice_handler.query_3pl(protocol, fields)
|
results = yield self.appservice_handler.query_3pe(
|
||||||
|
ThirdPartyEntityKind.LOCATION, protocol, fields
|
||||||
|
)
|
||||||
|
|
||||||
defer.returnValue((200, results))
|
defer.returnValue((200, results))
|
||||||
|
|
||||||
|
@ -269,3 +269,10 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
|
|||||||
return "t%d-%d" % (self.topological, self.stream)
|
return "t%d-%d" % (self.topological, self.stream)
|
||||||
else:
|
else:
|
||||||
return "s%d" % (self.stream,)
|
return "s%d" % (self.stream,)
|
||||||
|
|
||||||
|
|
||||||
|
# Some arbitrary constants used for internal API enumerations. Don't rely on
|
||||||
|
# exact values; always pass or compare symbolically
|
||||||
|
class ThirdPartyEntityKind(object):
|
||||||
|
USER = 'user'
|
||||||
|
LOCATION = 'location'
|
||||||
|
Loading…
Reference in New Issue
Block a user