Support pagination tokens from /sync and /messages in the relations API. (#11952)

Patrick Cloke 2022-02-10 10:52:48 -05:00 committed by GitHub
parent 337f38cac3
commit df36945ff0
5 changed files with 217 additions and 53 deletions
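
To illustrate what this change enables, here is a hedged sketch (not part of the diff below) of the client-side flow: the `prev_batch` token from a limited /sync response, or the `end` token from /messages, can now be passed directly as the `from` parameter of the relations endpoint. The homeserver URL, access token, room ID, and parent event ID are placeholder values, and the `requests` library stands in for whatever HTTP client is actually in use; the unstable relations path matches the one exercised by the tests below.

import requests

HOMESERVER = "https://example.org"   # placeholder homeserver
ACCESS_TOKEN = "syt_placeholder"     # placeholder access token
ROOM_ID = "!room:example.org"        # placeholder room ID
PARENT_EVENT_ID = "$parent"          # placeholder parent event ID

headers = {"Authorization": f"Bearer {ACCESS_TOKEN}"}

# Sync with the timeline limited to one event (as the tests do) and grab
# the room's pagination token.
sync = requests.get(
    f"{HOMESERVER}/_matrix/client/v3/sync",
    params={"filter": '{"room": {"timeline": {"limit": 1}}}'},
    headers=headers,
).json()
prev_batch = sync["rooms"]["join"][ROOM_ID]["timeline"]["prev_batch"]

# With this change, the /sync token (or the "end" token from /messages)
# is accepted as the "from" parameter of the relations API.
relations = requests.get(
    f"{HOMESERVER}/_matrix/client/unstable/rooms/{ROOM_ID}"
    f"/relations/{PARENT_EVENT_ID}",
    params={"from": prev_batch, "limit": 5},
    headers=headers,
).json()
for event in relations["chunk"]:
    print(event["event_id"])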


@@ -21,7 +21,8 @@ from unittest.mock import patch
from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
from synapse.rest.client import login, register, relations, room, sync
-from synapse.types import JsonDict
+from synapse.storage.relations import RelationPaginationToken
+from synapse.types import JsonDict, StreamToken
from tests import unittest
from tests.server import FakeChannel
@@ -200,6 +201,15 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel.json_body.get("next_batch"), str, channel.json_body
)
def _stream_token_to_relation_token(self, token: str) -> str:
"""Convert a StreamToken into a legacy token (RelationPaginationToken)."""
room_key = self.get_success(StreamToken.from_string(self.store, token)).room_key
return self.get_success(
RelationPaginationToken(
topological=room_key.topological, stream=room_key.stream
).to_string(self.store)
)
def test_repeated_paginate_relations(self):
"""Test that if we paginate using a limit and tokens then we get the
expected events.
@@ -213,7 +223,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
expected_event_ids.append(channel.json_body["event_id"])
-prev_token: Optional[str] = None
+prev_token = ""
found_event_ids: List[str] = []
for _ in range(20):
from_token = ""
@@ -222,8 +232,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s"
% (self.room, self.parent_id, from_token),
f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
@@ -241,6 +250,93 @@ class RelationsTestCase(unittest.HomeserverTestCase):
found_event_ids.reverse()
self.assertEquals(found_event_ids, expected_event_ids)
# Reset and try again, but convert the tokens to the legacy format.
prev_token = ""
found_event_ids = []
for _ in range(20):
from_token = ""
if prev_token:
from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
next_batch = channel.json_body.get("next_batch")
self.assertNotEquals(prev_token, next_batch)
prev_token = next_batch
if not prev_token:
break
# We paginated backwards, so reverse
found_event_ids.reverse()
self.assertEquals(found_event_ids, expected_event_ids)
def test_pagination_from_sync_and_messages(self):
"""Pagination tokens from /sync and /messages can be used to paginate /relations."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
self.assertEquals(200, channel.code, channel.json_body)
annotation_id = channel.json_body["event_id"]
# Send an event after the relation events.
self.helper.send(self.room, body="Latest event", tok=self.user_token)
# Request /sync, limiting it such that only the latest event is returned
# (and not the relation).
filter = urllib.parse.quote_plus(
'{"room": {"timeline": {"limit": 1}}}'.encode()
)
channel = self.make_request(
"GET", f"/sync?filter={filter}", access_token=self.user_token
)
self.assertEquals(200, channel.code, channel.json_body)
room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
sync_prev_batch = room_timeline["prev_batch"]
self.assertIsNotNone(sync_prev_batch)
# Ensure the relation event is not in the batch returned from /sync.
self.assertNotIn(
annotation_id, [ev["event_id"] for ev in room_timeline["events"]]
)
# Request /messages, limiting it such that only the latest event is
# returned (and not the relation).
channel = self.make_request(
"GET",
f"/rooms/{self.room}/messages?dir=b&limit=1",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
messages_end = channel.json_body["end"]
self.assertIsNotNone(messages_end)
# Ensure the relation event is not in the chunk returned from /messages.
self.assertNotIn(
annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
)
# Request /relations with the pagination tokens received from both the
# /sync and /messages responses above, in turn.
#
# This is a tiny bit silly since the client wouldn't know the parent ID
# from the requests above; consider the parent ID to be known from a
# previous /sync.
for from_token in (sync_prev_batch, messages_end):
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
# The relation should be in the returned chunk.
self.assertIn(
annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
)
def test_aggregation_pagination_groups(self):
"""Test that we can paginate annotation groups correctly."""
@@ -337,7 +433,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
self.assertEquals(200, channel.code, channel.json_body)
-prev_token: Optional[str] = None
+prev_token = ""
found_event_ids: List[str] = []
encoded_key = urllib.parse.quote_plus("👍".encode())
for _ in range(20):
@@ -347,15 +443,42 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s"
"/aggregations/%s/%s/m.reaction/%s?limit=1%s"
% (
self.room,
self.parent_id,
RelationTypes.ANNOTATION,
encoded_key,
from_token,
),
f"/_matrix/client/unstable/rooms/{self.room}"
f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
f"/m.reaction/{encoded_key}?limit=1{from_token}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
next_batch = channel.json_body.get("next_batch")
self.assertNotEquals(prev_token, next_batch)
prev_token = next_batch
if not prev_token:
break
# We paginated backwards, so reverse
found_event_ids.reverse()
self.assertEquals(found_event_ids, expected_event_ids)
# Reset and try again, but convert the tokens to the legacy format.
prev_token = ""
found_event_ids = []
for _ in range(20):
from_token = ""
if prev_token:
from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/rooms/{self.room}"
f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
f"/m.reaction/{encoded_key}?limit=1{from_token}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)