Include bundled aggregations in /sync and related fixes (#11478)

Due to updates to MSC2675 this includes a few fixes:

* Include bundled aggregations for /sync.
* Do not include bundled aggregations for /initialSync and /events.
* Do not bundle aggregations for state events.
* Clarify comments and variable names.
This commit is contained in:
Patrick Cloke 2021-12-06 10:51:15 -05:00 committed by GitHub
parent a77c369897
commit 494ebd7347
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 169 additions and 101 deletions

1
changelog.d/11478.bugfix Normal file
View File

@ -0,0 +1 @@
Include bundled relation aggregations during a limited `/sync` request, per [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675).

View File

@ -306,6 +306,7 @@ def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict:
def serialize_event(
e: Union[JsonDict, EventBase],
time_now_ms: int,
*,
as_client_event: bool = True,
event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1,
token_id: Optional[str] = None,
@ -393,7 +394,8 @@ class EventClientSerializer:
self,
event: Union[JsonDict, EventBase],
time_now: int,
bundle_relations: bool = True,
*,
bundle_aggregations: bool = True,
**kwargs: Any,
) -> JsonDict:
"""Serializes a single event.
@ -401,8 +403,9 @@ class EventClientSerializer:
Args:
event: The event being serialized.
time_now: The current time in milliseconds
bundle_relations: Whether to include the bundled relations for this
event.
bundle_aggregations: Whether to include the bundled aggregations for this
event. Only applies to non-state events. (State events never include
bundled aggregations.)
**kwargs: Arguments to pass to `serialize_event`
Returns:
@ -414,20 +417,27 @@ class EventClientSerializer:
serialized_event = serialize_event(event, time_now, **kwargs)
# If MSC1849 is enabled then we need to look if there are any relations
# we need to bundle in with the event.
# Do not bundle relations if the event has been redacted
if not event.internal_metadata.is_redacted() and (
self._msc1849_enabled and bundle_relations
# Check if there are any bundled aggregations to include with the event.
#
# Do not bundle aggregations if any of the following are true:
#
# * Support is disabled via the configuration or the caller.
# * The event is a state event.
# * The event has been redacted.
if (
self._msc1849_enabled
and bundle_aggregations
and not event.is_state()
and not event.internal_metadata.is_redacted()
):
await self._injected_bundled_relations(event, time_now, serialized_event)
await self._injected_bundled_aggregations(event, time_now, serialized_event)
return serialized_event
async def _injected_bundled_relations(
async def _injected_bundled_aggregations(
self, event: EventBase, time_now: int, serialized_event: JsonDict
) -> None:
"""Potentially injects bundled relations into the unsigned portion of the serialized event.
"""Potentially injects bundled aggregations into the unsigned portion of the serialized event.
Args:
event: The event being serialized.
@ -435,7 +445,7 @@ class EventClientSerializer:
serialized_event: The serialized event which may be modified.
"""
# Do not bundle relations for an event which represents an edit or an
# Do not bundle aggregations for an event which represents an edit or an
# annotation. It does not make sense for them to have related events.
relates_to = event.content.get("m.relates_to")
if isinstance(relates_to, (dict, frozendict)):
@ -445,18 +455,18 @@ class EventClientSerializer:
event_id = event.event_id
# The bundled relations to include.
relations = {}
# The bundled aggregations to include.
aggregations = {}
annotations = await self.store.get_aggregation_groups_for_event(event_id)
if annotations.chunk:
relations[RelationTypes.ANNOTATION] = annotations.to_dict()
aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
references = await self.store.get_relations_for_event(
event_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
relations[RelationTypes.REFERENCE] = references.to_dict()
aggregations[RelationTypes.REFERENCE] = references.to_dict()
edit = None
if event.type == EventTypes.Message:
@ -482,7 +492,7 @@ class EventClientSerializer:
else:
serialized_event["content"].pop("m.relates_to", None)
relations[RelationTypes.REPLACE] = {
aggregations[RelationTypes.REPLACE] = {
"event_id": edit.event_id,
"origin_server_ts": edit.origin_server_ts,
"sender": edit.sender,
@ -495,17 +505,19 @@ class EventClientSerializer:
latest_thread_event,
) = await self.store.get_thread_summary(event_id)
if latest_thread_event:
relations[RelationTypes.THREAD] = {
# Don't bundle relations as this could recurse forever.
aggregations[RelationTypes.THREAD] = {
# Don't bundle aggregations as this could recurse forever.
"latest_event": await self.serialize_event(
latest_thread_event, time_now, bundle_relations=False
latest_thread_event, time_now, bundle_aggregations=False
),
"count": thread_count,
}
# If any bundled relations were found, include them.
if relations:
serialized_event["unsigned"].setdefault("m.relations", {}).update(relations)
# If any bundled aggregations were found, include them.
if aggregations:
serialized_event["unsigned"].setdefault("m.relations", {}).update(
aggregations
)
async def serialize_events(
self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any

View File

@ -122,9 +122,8 @@ class EventStreamHandler:
events,
time_now,
as_client_event=as_client_event,
# We don't bundle "live" events, as otherwise clients
# will end up double counting annotations.
bundle_relations=False,
# Don't bundle aggregations as this is a deprecated API.
bundle_aggregations=False,
)
chunk = {

View File

@ -165,7 +165,11 @@ class InitialSyncHandler:
invite_event = await self.store.get_event(event.event_id)
d["invite"] = await self._event_serializer.serialize_event(
invite_event, time_now, as_client_event
invite_event,
time_now,
# Don't bundle aggregations as this is a deprecated API.
bundle_aggregations=False,
as_client_event=as_client_event,
)
rooms_ret.append(d)
@ -216,7 +220,11 @@ class InitialSyncHandler:
d["messages"] = {
"chunk": (
await self._event_serializer.serialize_events(
messages, time_now=time_now, as_client_event=as_client_event
messages,
time_now=time_now,
# Don't bundle aggregations as this is a deprecated API.
bundle_aggregations=False,
as_client_event=as_client_event,
)
),
"start": await start_token.to_string(self.store),
@ -226,6 +234,8 @@ class InitialSyncHandler:
d["state"] = await self._event_serializer.serialize_events(
current_state.values(),
time_now=time_now,
# Don't bundle aggregations as this is a deprecated API.
bundle_aggregations=False,
as_client_event=as_client_event,
)
@ -366,14 +376,18 @@ class InitialSyncHandler:
"room_id": room_id,
"messages": {
"chunk": (
await self._event_serializer.serialize_events(messages, time_now)
# Don't bundle aggregations as this is a deprecated API.
await self._event_serializer.serialize_events(
messages, time_now, bundle_aggregations=False
)
),
"start": await start_token.to_string(self.store),
"end": await end_token.to_string(self.store),
},
"state": (
# Don't bundle aggregations as this is a deprecated API.
await self._event_serializer.serialize_events(
room_state.values(), time_now
room_state.values(), time_now, bundle_aggregations=False
)
),
"presence": [],
@ -392,8 +406,9 @@ class InitialSyncHandler:
# TODO: Do these concurrently
time_now = self.clock.time_msec()
# Don't bundle aggregations as this is a deprecated API.
state = await self._event_serializer.serialize_events(
current_state.values(), time_now
current_state.values(), time_now, bundle_aggregations=False
)
now_token = self.hs.get_event_sources().get_current_token()
@ -467,7 +482,10 @@ class InitialSyncHandler:
"room_id": room_id,
"messages": {
"chunk": (
await self._event_serializer.serialize_events(messages, time_now)
# Don't bundle aggregations as this is a deprecated API.
await self._event_serializer.serialize_events(
messages, time_now, bundle_aggregations=False
)
),
"start": await start_token.to_string(self.store),
"end": await end_token.to_string(self.store),

View File

@ -247,13 +247,7 @@ class MessageHandler:
room_state = room_state_events[membership_event_id]
now = self.clock.time_msec()
events = await self._event_serializer.serialize_events(
room_state.values(),
now,
# We don't bother bundling aggregations in when asked for state
# events, as clients won't use them.
bundle_relations=False,
)
events = await self._event_serializer.serialize_events(room_state.values(), now)
return events
async def get_joined_members(self, requester: Requester, room_id: str) -> dict:

View File

@ -449,13 +449,7 @@ class RoomStateRestServlet(RestServlet):
event_ids = await self.store.get_current_state_ids(room_id)
events = await self.store.get_events(event_ids.values())
now = self.clock.time_msec()
room_state = await self._event_serializer.serialize_events(
events.values(),
now,
# We don't bother bundling aggregations in when asked for state
# events, as clients won't use them.
bundle_relations=False,
)
room_state = await self._event_serializer.serialize_events(events.values(), now)
ret = {"state": room_state}
return HTTPStatus.OK, ret
@ -789,10 +783,7 @@ class RoomEventContextServlet(RestServlet):
results["events_after"], time_now
)
results["state"] = await self._event_serializer.serialize_events(
results["state"],
time_now,
# No need to bundle aggregations for state events
bundle_relations=False,
results["state"], time_now
)
return HTTPStatus.OK, results

View File

@ -224,14 +224,13 @@ class RelationPaginationServlet(RestServlet):
)
now = self.clock.time_msec()
# We set bundle_relations to False when retrieving the original
# event because we want the content before relations were applied to
# it.
# Do not bundle aggregations when retrieving the original event because
# we want the content before relations are applied to it.
original_event = await self._event_serializer.serialize_event(
event, now, bundle_relations=False
event, now, bundle_aggregations=False
)
# The relations returned for the requested event do include their
# bundled relations.
# bundled aggregations.
serialized_events = await self._event_serializer.serialize_events(events, now)
return_value = pagination_chunk.to_dict()

View File

@ -716,10 +716,7 @@ class RoomEventContextServlet(RestServlet):
results["events_after"], time_now
)
results["state"] = await self._event_serializer.serialize_events(
results["state"],
time_now,
# No need to bundle aggregations for state events
bundle_relations=False,
results["state"], time_now
)
return 200, results

View File

@ -520,9 +520,9 @@ class SyncRestServlet(RestServlet):
return self._event_serializer.serialize_events(
events,
time_now=time_now,
# We don't bundle "live" events, as otherwise clients
# will end up double counting annotations.
bundle_relations=False,
# Don't bother to bundle aggregations if the timeline is unlimited,
# as clients will have all the necessary information.
bundle_aggregations=room.timeline.limited,
token_id=token_id,
event_format=event_formatter,
only_event_fields=only_fields,

View File

@ -19,7 +19,7 @@ from typing import Dict, List, Optional, Tuple
from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
from synapse.rest.client import login, register, relations, room
from synapse.rest.client import login, register, relations, room, sync
from tests import unittest
from tests.server import FakeChannel
@ -29,6 +29,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
servlets = [
relations.register_servlets,
room.register_servlets,
sync.register_servlets,
login.register_servlets,
register.register_servlets,
admin.register_servlets_for_client_rest_resource,
@ -454,11 +455,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(400, channel.code, channel.json_body)
@unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_aggregation_get_event(self):
"""Test that annotations, references, and threads get correctly bundled when
getting the parent event.
"""
def test_bundled_aggregations(self):
"""Test that annotations, references, and threads get correctly bundled."""
# Setup by sending a variety of relations.
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(200, channel.code, channel.json_body)
@ -485,29 +484,42 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
thread_2 = channel.json_body["event_id"]
channel = self.make_request(
"GET",
"/rooms/%s/event/%s" % (self.room, self.parent_id),
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
def assert_bundle(actual):
"""Assert the expected values of the bundled aggregations."""
# Ensure the fields are as expected.
self.assertCountEqual(
actual.keys(),
(
RelationTypes.ANNOTATION,
RelationTypes.REFERENCE,
RelationTypes.THREAD,
),
)
# Check the values of each field.
self.assertEquals(
channel.json_body["unsigned"].get("m.relations"),
{
RelationTypes.ANNOTATION: {
"chunk": [
{"type": "m.reaction", "key": "a", "count": 2},
{"type": "m.reaction", "key": "b", "count": 1},
]
},
RelationTypes.REFERENCE: {
"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]
},
RelationTypes.THREAD: {
"count": 2,
"latest_event": {
"age": 100,
actual[RelationTypes.ANNOTATION],
)
self.assertEquals(
{"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
actual[RelationTypes.REFERENCE],
)
self.assertEquals(
2,
actual[RelationTypes.THREAD].get("count"),
)
# The latest thread event has some fields that don't matter.
self.assert_dict(
{
"content": {
"m.relates_to": {
"event_id": self.parent_id,
@ -515,19 +527,64 @@ class RelationsTestCase(unittest.HomeserverTestCase):
}
},
"event_id": thread_2,
"origin_server_ts": 1600,
"room_id": self.room,
"sender": self.user_id,
"type": "m.room.test",
"unsigned": {"age": 100},
"user_id": self.user_id,
},
},
},
actual[RelationTypes.THREAD].get("latest_event"),
)
def _find_and_assert_event(events):
"""
Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
"""
for event in events:
if event["event_id"] == self.parent_id:
break
else:
raise AssertionError(f"Event {self.parent_id} not found in chunk")
assert_bundle(event["unsigned"].get("m.relations"))
# Request the event directly.
channel = self.make_request(
"GET",
f"/rooms/{self.room}/event/{self.parent_id}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
assert_bundle(channel.json_body["unsigned"].get("m.relations"))
# Request the room messages.
channel = self.make_request(
"GET",
f"/rooms/{self.room}/messages?dir=b",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
_find_and_assert_event(channel.json_body["chunk"])
# Request the room context.
channel = self.make_request(
"GET",
f"/rooms/{self.room}/context/{self.parent_id}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
assert_bundle(channel.json_body["event"]["unsigned"].get("m.relations"))
# Request sync.
channel = self.make_request("GET", "/sync", access_token=self.user_token)
self.assertEquals(200, channel.code, channel.json_body)
room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
self.assertTrue(room_timeline["limited"])
_find_and_assert_event(room_timeline["events"])
# Note that /relations is tested separately in test_aggregation_get_event_for_thread
# since it needs different data configured.
def test_aggregation_get_event_for_annotation(self):
"""Test that annotations do not get bundled relations included
"""Test that annotations do not get bundled aggregations included
when directly requested.
"""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
@ -549,7 +606,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertIsNone(channel.json_body["unsigned"].get("m.relations"))
def test_aggregation_get_event_for_thread(self):
"""Test that threads get bundled relations included when directly requested."""
"""Test that threads get bundled aggregations included when directly requested."""
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
self.assertEquals(200, channel.code, channel.json_body)
thread_id = channel.json_body["event_id"]