Use new Federation Query API to implement HS->HS fetching of remote users' profile information instead of (ab)using the client-side REST API

This commit is contained in:
Paul "LeoNerd" Evans 2014-08-13 17:12:50 +01:00
parent 827de7cee9
commit 505917cb97
3 changed files with 71 additions and 33 deletions

View File

@ -26,15 +26,16 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
PREFIX = "/matrix/client/api/v1"
class ProfileHandler(BaseHandler): class ProfileHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(ProfileHandler, self).__init__(hs) super(ProfileHandler, self).__init__(hs)
self.client = hs.get_http_client() self.federation = hs.get_replication_layer()
self.federation.register_query_handler(
"profile", self.on_profile_query
)
distributor = hs.get_distributor() distributor = hs.get_distributor()
self.distributor = distributor self.distributor = distributor
@ -57,17 +58,14 @@ class ProfileHandler(BaseHandler):
defer.returnValue(displayname) defer.returnValue(displayname)
elif not local_only: elif not local_only:
# TODO(paul): This should use the server-server API to ask another
# HS. For now we'll just have it use the http client to talk to the
# other HS's REST client API
path = PREFIX + "/profile/%s/displayname?local_only=1" % (
target_user.to_string()
)
try: try:
result = yield self.client.get_json( result = yield self.federation.make_query(
destination=target_user.domain, destination=target_user.domain,
path=path query_type="profile",
args={
"user_id": target_user.to_string(),
"field": "displayname",
}
) )
except CodeMessageException as e: except CodeMessageException as e:
if e.code != 404: if e.code != 404:
@ -76,8 +74,8 @@ class ProfileHandler(BaseHandler):
raise raise
except: except:
logger.exception("Failed to get displayname") logger.exception("Failed to get displayname")
else:
defer.returnValue(result["displayname"]) defer.returnValue(result["displayname"])
else: else:
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
@ -110,18 +108,14 @@ class ProfileHandler(BaseHandler):
defer.returnValue(avatar_url) defer.returnValue(avatar_url)
elif not local_only: elif not local_only:
# TODO(paul): This should use the server-server API to ask another
# HS. For now we'll just have it use the http client to talk to the
# other HS's REST client API
destination = target_user.domain
path = PREFIX + "/profile/%s/avatar_url?local_only=1" % (
target_user.to_string(),
)
try: try:
result = yield self.client.get_json( result = yield self.federation.make_query(
destination=destination, destination=target_user.domain,
path=path query_type="profile",
args={
"user_id": target_user.to_string(),
"field": "avatar_url",
}
) )
except CodeMessageException as e: except CodeMessageException as e:
if e.code != 404: if e.code != 404:
@ -168,3 +162,25 @@ class ProfileHandler(BaseHandler):
state["avatar_url"] = avatar_url state["avatar_url"] = avatar_url
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks
def on_profile_query(self, args):
user = self.hs.parse_userid(args["user_id"])
if not user.is_mine:
raise SynapseError(400, "User is not hosted on this Home Server")
just_field = args.get("field", None)
response = {}
if just_field is None or just_field == "displayname":
response["displayname"] = yield self.store.get_profile_displayname(
user.localpart
)
if just_field is None or just_field == "avatar_url":
response["avatar_url"] = yield self.store.get_profile_avatar_url(
user.localpart
)
defer.returnValue(response)

View File

@ -43,6 +43,9 @@ class MockReplication(object):
def register_edu_handler(self, edu_type, handler): def register_edu_handler(self, edu_type, handler):
self.edu_handlers[edu_type] = handler self.edu_handlers[edu_type] = handler
def register_query_handler(self, query_type, handler):
pass
def received_edu(self, origin, edu_type, content): def received_edu(self, origin, edu_type, content):
self.edu_handlers[edu_type](origin, content) self.edu_handlers[edu_type](origin, content)

View File

@ -37,13 +37,18 @@ class ProfileTestCase(unittest.TestCase):
""" Tests profile management. """ """ Tests profile management. """
def setUp(self): def setUp(self):
self.mock_client = Mock(spec=[ self.mock_federation = Mock(spec=[
"get_json", "make_query",
]) ])
self.query_handlers = {}
def register_query_handler(query_type, handler):
self.query_handlers[query_type] = handler
self.mock_federation.register_query_handler = register_query_handler
hs = HomeServer("test", hs = HomeServer("test",
db_pool=None, db_pool=None,
http_client=self.mock_client, http_client=None,
datastore=Mock(spec=[ datastore=Mock(spec=[
"get_profile_displayname", "get_profile_displayname",
"set_profile_displayname", "set_profile_displayname",
@ -52,6 +57,7 @@ class ProfileTestCase(unittest.TestCase):
]), ]),
handlers=None, handlers=None,
http_server=Mock(), http_server=Mock(),
replication_layer=self.mock_federation,
) )
hs.handlers = ProfileHandlers(hs) hs.handlers = ProfileHandlers(hs)
@ -93,18 +99,31 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_other_name(self): def test_get_other_name(self):
self.mock_client.get_json.return_value = defer.succeed( self.mock_federation.make_query.return_value = defer.succeed(
{"displayname": "Alice"}) {"displayname": "Alice"}
)
displayname = yield self.handler.get_displayname(self.alice) displayname = yield self.handler.get_displayname(self.alice)
self.assertEquals(displayname, "Alice") self.assertEquals(displayname, "Alice")
self.mock_client.get_json.assert_called_with( self.mock_federation.make_query.assert_called_with(
destination="remote", destination="remote",
path="/matrix/client/api/v1/profile/@alice:remote/displayname" query_type="profile",
"?local_only=1" args={"user_id": "@alice:remote", "field": "displayname"}
) )
@defer.inlineCallbacks
def test_incoming_fed_query(self):
mocked_get = self.datastore.get_profile_displayname
mocked_get.return_value = defer.succeed("Caroline")
response = yield self.query_handlers["profile"](
{"user_id": "@caroline:test", "field": "displayname"}
)
self.assertEquals({"displayname": "Caroline"}, response)
mocked_get.assert_called_with("caroline")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_my_avatar(self): def test_get_my_avatar(self):
mocked_get = self.datastore.get_profile_avatar_url mocked_get = self.datastore.get_profile_avatar_url