Check that event is visible in new APIs

This commit is contained in:
Erik Johnston 2019-05-16 14:19:06 +01:00
parent b5c62c6b26
commit 95f3fcda3c
2 changed files with 16 additions and 3 deletions

View File

@ -131,6 +131,7 @@ class RelationPaginationServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None):
@ -140,6 +141,10 @@ class RelationPaginationServlet(RestServlet):
room_id, requester.user.to_string() room_id, requester.user.to_string()
) )
# This checks that a) the event exists and b) the user is allowed to
# view it.
yield self.event_handler.get_event(requester.user, room_id, parent_id)
limit = parse_integer(request, "limit", default=5) limit = parse_integer(request, "limit", default=5)
from_token = parse_string(request, "from") from_token = parse_string(request, "from")
to_token = parse_string(request, "to") to_token = parse_string(request, "to")
@ -195,6 +200,7 @@ class RelationAggregationPaginationServlet(RestServlet):
super(RelationAggregationPaginationServlet, self).__init__() super(RelationAggregationPaginationServlet, self).__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.event_handler = hs.get_event_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None):
@ -204,6 +210,10 @@ class RelationAggregationPaginationServlet(RestServlet):
room_id, requester.user.to_string() room_id, requester.user.to_string()
) )
# This checks that a) the event exists and b) the user is allowed to
# view it.
yield self.event_handler.get_event(requester.user, room_id, parent_id)
if relation_type not in (RelationTypes.ANNOTATION, None): if relation_type not in (RelationTypes.ANNOTATION, None):
raise SynapseError(400, "Relation type must be 'annotation'") raise SynapseError(400, "Relation type must be 'annotation'")
@ -258,6 +268,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): def on_GET(self, request, room_id, parent_id, relation_type, event_type, key):
@ -267,6 +278,10 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
room_id, requester.user.to_string() room_id, requester.user.to_string()
) )
# This checks that a) the event exists and b) the user is allowed to
# view it.
yield self.event_handler.get_event(requester.user, room_id, parent_id)
if relation_type != RelationTypes.ANNOTATION: if relation_type != RelationTypes.ANNOTATION:
raise SynapseError(400, "Relation type must be 'annotation'") raise SynapseError(400, "Relation type must be 'annotation'")
@ -296,8 +311,6 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
defer.returnValue((200, return_value)) defer.returnValue((200, return_value))
defer.returnValue((200, return_value))
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
RelationSendServlet(hs).register(http_server) RelationSendServlet(hs).register(http_server)

View File

@ -296,7 +296,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request( request, channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/rooms/%s/aggregations/m.replaces/%s?limit=1" "/_matrix/client/unstable/rooms/%s/aggregations/%s/m.replace?limit=1"
% (self.room, self.parent_id), % (self.room, self.parent_id),
) )
self.render(request) self.render(request)