Include bundled aggregations in /sync and related fixes (#11478)

Due to updates to MSC2675 this includes a few fixes:

* Include bundled aggregations for /sync.
* Do not include bundled aggregations for /initialSync and /events.
* Do not bundle aggregations for state events.
* Clarifies comments and variable names.
This commit is contained in:
Patrick Cloke 2021-12-06 10:51:15 -05:00 committed by GitHub
parent a77c369897
commit 494ebd7347
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 169 additions and 101 deletions

1
changelog.d/11478.bugfix Normal file
View File

@ -0,0 +1 @@
Include bundled relation aggregations during a limited `/sync` request, per [MSC2675](https://github.com/matrix-org/matrix-doc/pull/2675).

View File

@ -306,6 +306,7 @@ def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict:
def serialize_event( def serialize_event(
e: Union[JsonDict, EventBase], e: Union[JsonDict, EventBase],
time_now_ms: int, time_now_ms: int,
*,
as_client_event: bool = True, as_client_event: bool = True,
event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1, event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1,
token_id: Optional[str] = None, token_id: Optional[str] = None,
@ -393,7 +394,8 @@ class EventClientSerializer:
self, self,
event: Union[JsonDict, EventBase], event: Union[JsonDict, EventBase],
time_now: int, time_now: int,
bundle_relations: bool = True, *,
bundle_aggregations: bool = True,
**kwargs: Any, **kwargs: Any,
) -> JsonDict: ) -> JsonDict:
"""Serializes a single event. """Serializes a single event.
@ -401,8 +403,9 @@ class EventClientSerializer:
Args: Args:
event: The event being serialized. event: The event being serialized.
time_now: The current time in milliseconds time_now: The current time in milliseconds
bundle_relations: Whether to include the bundled relations for this bundle_aggregations: Whether to include the bundled aggregations for this
event. event. Only applies to non-state events. (State events never include
bundled aggregations.)
**kwargs: Arguments to pass to `serialize_event` **kwargs: Arguments to pass to `serialize_event`
Returns: Returns:
@ -414,20 +417,27 @@ class EventClientSerializer:
serialized_event = serialize_event(event, time_now, **kwargs) serialized_event = serialize_event(event, time_now, **kwargs)
# If MSC1849 is enabled then we need to look if there are any relations # Check if there are any bundled aggregations to include with the event.
# we need to bundle in with the event. #
# Do not bundle relations if the event has been redacted # Do not bundle aggregations if any of the following at true:
if not event.internal_metadata.is_redacted() and ( #
self._msc1849_enabled and bundle_relations # * Support is disabled via the configuration or the caller.
# * The event is a state event.
# * The event has been redacted.
if (
self._msc1849_enabled
and bundle_aggregations
and not event.is_state()
and not event.internal_metadata.is_redacted()
): ):
await self._injected_bundled_relations(event, time_now, serialized_event) await self._injected_bundled_aggregations(event, time_now, serialized_event)
return serialized_event return serialized_event
async def _injected_bundled_relations( async def _injected_bundled_aggregations(
self, event: EventBase, time_now: int, serialized_event: JsonDict self, event: EventBase, time_now: int, serialized_event: JsonDict
) -> None: ) -> None:
"""Potentially injects bundled relations into the unsigned portion of the serialized event. """Potentially injects bundled aggregations into the unsigned portion of the serialized event.
Args: Args:
event: The event being serialized. event: The event being serialized.
@ -435,7 +445,7 @@ class EventClientSerializer:
serialized_event: The serialized event which may be modified. serialized_event: The serialized event which may be modified.
""" """
# Do not bundle relations for an event which represents an edit or an # Do not bundle aggregations for an event which represents an edit or an
# annotation. It does not make sense for them to have related events. # annotation. It does not make sense for them to have related events.
relates_to = event.content.get("m.relates_to") relates_to = event.content.get("m.relates_to")
if isinstance(relates_to, (dict, frozendict)): if isinstance(relates_to, (dict, frozendict)):
@ -445,18 +455,18 @@ class EventClientSerializer:
event_id = event.event_id event_id = event.event_id
# The bundled relations to include. # The bundled aggregations to include.
relations = {} aggregations = {}
annotations = await self.store.get_aggregation_groups_for_event(event_id) annotations = await self.store.get_aggregation_groups_for_event(event_id)
if annotations.chunk: if annotations.chunk:
relations[RelationTypes.ANNOTATION] = annotations.to_dict() aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
references = await self.store.get_relations_for_event( references = await self.store.get_relations_for_event(
event_id, RelationTypes.REFERENCE, direction="f" event_id, RelationTypes.REFERENCE, direction="f"
) )
if references.chunk: if references.chunk:
relations[RelationTypes.REFERENCE] = references.to_dict() aggregations[RelationTypes.REFERENCE] = references.to_dict()
edit = None edit = None
if event.type == EventTypes.Message: if event.type == EventTypes.Message:
@ -482,7 +492,7 @@ class EventClientSerializer:
else: else:
serialized_event["content"].pop("m.relates_to", None) serialized_event["content"].pop("m.relates_to", None)
relations[RelationTypes.REPLACE] = { aggregations[RelationTypes.REPLACE] = {
"event_id": edit.event_id, "event_id": edit.event_id,
"origin_server_ts": edit.origin_server_ts, "origin_server_ts": edit.origin_server_ts,
"sender": edit.sender, "sender": edit.sender,
@ -495,17 +505,19 @@ class EventClientSerializer:
latest_thread_event, latest_thread_event,
) = await self.store.get_thread_summary(event_id) ) = await self.store.get_thread_summary(event_id)
if latest_thread_event: if latest_thread_event:
relations[RelationTypes.THREAD] = { aggregations[RelationTypes.THREAD] = {
# Don't bundle relations as this could recurse forever. # Don't bundle aggregations as this could recurse forever.
"latest_event": await self.serialize_event( "latest_event": await self.serialize_event(
latest_thread_event, time_now, bundle_relations=False latest_thread_event, time_now, bundle_aggregations=False
), ),
"count": thread_count, "count": thread_count,
} }
# If any bundled relations were found, include them. # If any bundled aggregations were found, include them.
if relations: if aggregations:
serialized_event["unsigned"].setdefault("m.relations", {}).update(relations) serialized_event["unsigned"].setdefault("m.relations", {}).update(
aggregations
)
async def serialize_events( async def serialize_events(
self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any

View File

@ -122,9 +122,8 @@ class EventStreamHandler:
events, events,
time_now, time_now,
as_client_event=as_client_event, as_client_event=as_client_event,
# We don't bundle "live" events, as otherwise clients # Don't bundle aggregations as this is a deprecated API.
# will end up double counting annotations. bundle_aggregations=False,
bundle_relations=False,
) )
chunk = { chunk = {

View File

@ -165,7 +165,11 @@ class InitialSyncHandler:
invite_event = await self.store.get_event(event.event_id) invite_event = await self.store.get_event(event.event_id)
d["invite"] = await self._event_serializer.serialize_event( d["invite"] = await self._event_serializer.serialize_event(
invite_event, time_now, as_client_event invite_event,
time_now,
# Don't bundle aggregations as this is a deprecated API.
bundle_aggregations=False,
as_client_event=as_client_event,
) )
rooms_ret.append(d) rooms_ret.append(d)
@ -216,7 +220,11 @@ class InitialSyncHandler:
d["messages"] = { d["messages"] = {
"chunk": ( "chunk": (
await self._event_serializer.serialize_events( await self._event_serializer.serialize_events(
messages, time_now=time_now, as_client_event=as_client_event messages,
time_now=time_now,
# Don't bundle aggregations as this is a deprecated API.
bundle_aggregations=False,
as_client_event=as_client_event,
) )
), ),
"start": await start_token.to_string(self.store), "start": await start_token.to_string(self.store),
@ -226,6 +234,8 @@ class InitialSyncHandler:
d["state"] = await self._event_serializer.serialize_events( d["state"] = await self._event_serializer.serialize_events(
current_state.values(), current_state.values(),
time_now=time_now, time_now=time_now,
# Don't bundle aggregations as this is a deprecated API.
bundle_aggregations=False,
as_client_event=as_client_event, as_client_event=as_client_event,
) )
@ -366,14 +376,18 @@ class InitialSyncHandler:
"room_id": room_id, "room_id": room_id,
"messages": { "messages": {
"chunk": ( "chunk": (
await self._event_serializer.serialize_events(messages, time_now) # Don't bundle aggregations as this is a deprecated API.
await self._event_serializer.serialize_events(
messages, time_now, bundle_aggregations=False
)
), ),
"start": await start_token.to_string(self.store), "start": await start_token.to_string(self.store),
"end": await end_token.to_string(self.store), "end": await end_token.to_string(self.store),
}, },
"state": ( "state": (
# Don't bundle aggregations as this is a deprecated API.
await self._event_serializer.serialize_events( await self._event_serializer.serialize_events(
room_state.values(), time_now room_state.values(), time_now, bundle_aggregations=False
) )
), ),
"presence": [], "presence": [],
@ -392,8 +406,9 @@ class InitialSyncHandler:
# TODO: These concurrently # TODO: These concurrently
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
# Don't bundle aggregations as this is a deprecated API.
state = await self._event_serializer.serialize_events( state = await self._event_serializer.serialize_events(
current_state.values(), time_now current_state.values(), time_now, bundle_aggregations=False
) )
now_token = self.hs.get_event_sources().get_current_token() now_token = self.hs.get_event_sources().get_current_token()
@ -467,7 +482,10 @@ class InitialSyncHandler:
"room_id": room_id, "room_id": room_id,
"messages": { "messages": {
"chunk": ( "chunk": (
await self._event_serializer.serialize_events(messages, time_now) # Don't bundle aggregations as this is a deprecated API.
await self._event_serializer.serialize_events(
messages, time_now, bundle_aggregations=False
)
), ),
"start": await start_token.to_string(self.store), "start": await start_token.to_string(self.store),
"end": await end_token.to_string(self.store), "end": await end_token.to_string(self.store),

View File

@ -247,13 +247,7 @@ class MessageHandler:
room_state = room_state_events[membership_event_id] room_state = room_state_events[membership_event_id]
now = self.clock.time_msec() now = self.clock.time_msec()
events = await self._event_serializer.serialize_events( events = await self._event_serializer.serialize_events(room_state.values(), now)
room_state.values(),
now,
# We don't bother bundling aggregations in when asked for state
# events, as clients won't use them.
bundle_relations=False,
)
return events return events
async def get_joined_members(self, requester: Requester, room_id: str) -> dict: async def get_joined_members(self, requester: Requester, room_id: str) -> dict:

View File

@ -449,13 +449,7 @@ class RoomStateRestServlet(RestServlet):
event_ids = await self.store.get_current_state_ids(room_id) event_ids = await self.store.get_current_state_ids(room_id)
events = await self.store.get_events(event_ids.values()) events = await self.store.get_events(event_ids.values())
now = self.clock.time_msec() now = self.clock.time_msec()
room_state = await self._event_serializer.serialize_events( room_state = await self._event_serializer.serialize_events(events.values(), now)
events.values(),
now,
# We don't bother bundling aggregations in when asked for state
# events, as clients won't use them.
bundle_relations=False,
)
ret = {"state": room_state} ret = {"state": room_state}
return HTTPStatus.OK, ret return HTTPStatus.OK, ret
@ -789,10 +783,7 @@ class RoomEventContextServlet(RestServlet):
results["events_after"], time_now results["events_after"], time_now
) )
results["state"] = await self._event_serializer.serialize_events( results["state"] = await self._event_serializer.serialize_events(
results["state"], results["state"], time_now
time_now,
# No need to bundle aggregations for state events
bundle_relations=False,
) )
return HTTPStatus.OK, results return HTTPStatus.OK, results

View File

@ -224,14 +224,13 @@ class RelationPaginationServlet(RestServlet):
) )
now = self.clock.time_msec() now = self.clock.time_msec()
# We set bundle_relations to False when retrieving the original # Do not bundle aggregations when retrieving the original event because
# event because we want the content before relations were applied to # we want the content before relations are applied to it.
# it.
original_event = await self._event_serializer.serialize_event( original_event = await self._event_serializer.serialize_event(
event, now, bundle_relations=False event, now, bundle_aggregations=False
) )
# The relations returned for the requested event do include their # The relations returned for the requested event do include their
# bundled relations. # bundled aggregations.
serialized_events = await self._event_serializer.serialize_events(events, now) serialized_events = await self._event_serializer.serialize_events(events, now)
return_value = pagination_chunk.to_dict() return_value = pagination_chunk.to_dict()

View File

@ -716,10 +716,7 @@ class RoomEventContextServlet(RestServlet):
results["events_after"], time_now results["events_after"], time_now
) )
results["state"] = await self._event_serializer.serialize_events( results["state"] = await self._event_serializer.serialize_events(
results["state"], results["state"], time_now
time_now,
# No need to bundle aggregations for state events
bundle_relations=False,
) )
return 200, results return 200, results

View File

@ -520,9 +520,9 @@ class SyncRestServlet(RestServlet):
return self._event_serializer.serialize_events( return self._event_serializer.serialize_events(
events, events,
time_now=time_now, time_now=time_now,
# We don't bundle "live" events, as otherwise clients # Don't bother to bundle aggregations if the timeline is unlimited,
# will end up double counting annotations. # as clients will have all the necessary information.
bundle_relations=False, bundle_aggregations=room.timeline.limited,
token_id=token_id, token_id=token_id,
event_format=event_formatter, event_format=event_formatter,
only_event_fields=only_fields, only_event_fields=only_fields,

View File

@ -19,7 +19,7 @@ from typing import Dict, List, Optional, Tuple
from synapse.api.constants import EventTypes, RelationTypes from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, register, relations, room from synapse.rest.client import login, register, relations, room, sync
from tests import unittest from tests import unittest
from tests.server import FakeChannel from tests.server import FakeChannel
@ -29,6 +29,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
servlets = [ servlets = [
relations.register_servlets, relations.register_servlets,
room.register_servlets, room.register_servlets,
sync.register_servlets,
login.register_servlets, login.register_servlets,
register.register_servlets, register.register_servlets,
admin.register_servlets_for_client_rest_resource, admin.register_servlets_for_client_rest_resource,
@ -454,11 +455,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(400, channel.code, channel.json_body) self.assertEquals(400, channel.code, channel.json_body)
@unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_aggregation_get_event(self): def test_bundled_aggregations(self):
"""Test that annotations, references, and threads get correctly bundled when """Test that annotations, references, and threads get correctly bundled."""
getting the parent event. # Setup by sending a variety of relations.
"""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
@ -485,29 +484,42 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
thread_2 = channel.json_body["event_id"] thread_2 = channel.json_body["event_id"]
channel = self.make_request( def assert_bundle(actual):
"GET", """Assert the expected values of the bundled aggregations."""
"/rooms/%s/event/%s" % (self.room, self.parent_id),
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
# Ensure the fields are as expected.
self.assertCountEqual(
actual.keys(),
(
RelationTypes.ANNOTATION,
RelationTypes.REFERENCE,
RelationTypes.THREAD,
),
)
# Check the values of each field.
self.assertEquals( self.assertEquals(
channel.json_body["unsigned"].get("m.relations"),
{ {
RelationTypes.ANNOTATION: {
"chunk": [ "chunk": [
{"type": "m.reaction", "key": "a", "count": 2}, {"type": "m.reaction", "key": "a", "count": 2},
{"type": "m.reaction", "key": "b", "count": 1}, {"type": "m.reaction", "key": "b", "count": 1},
] ]
}, },
RelationTypes.REFERENCE: { actual[RelationTypes.ANNOTATION],
"chunk": [{"event_id": reply_1}, {"event_id": reply_2}] )
},
RelationTypes.THREAD: { self.assertEquals(
"count": 2, {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
"latest_event": { actual[RelationTypes.REFERENCE],
"age": 100, )
self.assertEquals(
2,
actual[RelationTypes.THREAD].get("count"),
)
# The latest thread event has some fields that don't matter.
self.assert_dict(
{
"content": { "content": {
"m.relates_to": { "m.relates_to": {
"event_id": self.parent_id, "event_id": self.parent_id,
@ -515,19 +527,64 @@ class RelationsTestCase(unittest.HomeserverTestCase):
} }
}, },
"event_id": thread_2, "event_id": thread_2,
"origin_server_ts": 1600,
"room_id": self.room, "room_id": self.room,
"sender": self.user_id, "sender": self.user_id,
"type": "m.room.test", "type": "m.room.test",
"unsigned": {"age": 100},
"user_id": self.user_id, "user_id": self.user_id,
}, },
}, actual[RelationTypes.THREAD].get("latest_event"),
},
) )
def _find_and_assert_event(events):
"""
Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
"""
for event in events:
if event["event_id"] == self.parent_id:
break
else:
raise AssertionError(f"Event {self.parent_id} not found in chunk")
assert_bundle(event["unsigned"].get("m.relations"))
# Request the event directly.
channel = self.make_request(
"GET",
f"/rooms/{self.room}/event/{self.parent_id}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
assert_bundle(channel.json_body["unsigned"].get("m.relations"))
# Request the room messages.
channel = self.make_request(
"GET",
f"/rooms/{self.room}/messages?dir=b",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
_find_and_assert_event(channel.json_body["chunk"])
# Request the room context.
channel = self.make_request(
"GET",
f"/rooms/{self.room}/context/{self.parent_id}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
assert_bundle(channel.json_body["event"]["unsigned"].get("m.relations"))
# Request sync.
channel = self.make_request("GET", "/sync", access_token=self.user_token)
self.assertEquals(200, channel.code, channel.json_body)
room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
self.assertTrue(room_timeline["limited"])
_find_and_assert_event(room_timeline["events"])
# Note that /relations is tested separately in test_aggregation_get_event_for_thread
# since it needs different data configured.
def test_aggregation_get_event_for_annotation(self): def test_aggregation_get_event_for_annotation(self):
"""Test that annotations do not get bundled relations included """Test that annotations do not get bundled aggregations included
when directly requested. when directly requested.
""" """
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
@ -549,7 +606,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) self.assertIsNone(channel.json_body["unsigned"].get("m.relations"))
def test_aggregation_get_event_for_thread(self): def test_aggregation_get_event_for_thread(self):
"""Test that threads get bundled relations included when directly requested.""" """Test that threads get bundled aggregations included when directly requested."""
channel = self._send_relation(RelationTypes.THREAD, "m.room.test") channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
thread_id = channel.json_body["event_id"] thread_id = channel.json_body["event_id"]