Fix incorrect thread summaries when the latest event is edited. (#11992)

If the latest event in a thread was edited than the original
event content was included in bundled aggregation for
threads instead of the edited event content.
This commit is contained in:
Patrick Cloke 2022-02-15 08:26:57 -05:00 committed by GitHub
parent 85e24d9d2b
commit 45f45404de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 107 additions and 31 deletions

1
changelog.d/11992.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a bug introduced in Synapse v1.48.0 where an edit of the latest event in a thread would not be properly applied to the thread summary.

View File

@ -425,6 +425,33 @@ class EventClientSerializer:
return serialized_event return serialized_event
def _apply_edit(
self, orig_event: EventBase, serialized_event: JsonDict, edit: EventBase
) -> None:
"""Replace the content, preserving existing relations of the serialized event.
Args:
orig_event: The original event.
serialized_event: The original event, serialized. This is modified.
edit: The event which edits the above.
"""
# Ensure we take copies of the edit content, otherwise we risk modifying
# the original event.
edit_content = edit.content.copy()
# Unfreeze the event content if necessary, so that we may modify it below
edit_content = unfreeze(edit_content)
serialized_event["content"] = edit_content.get("m.new_content", {})
# Check for existing relations
relates_to = orig_event.content.get("m.relates_to")
if relates_to:
# Keep the relations, ensuring we use a dict copy of the original
serialized_event["content"]["m.relates_to"] = relates_to.copy()
else:
serialized_event["content"].pop("m.relates_to", None)
def _inject_bundled_aggregations( def _inject_bundled_aggregations(
self, self,
event: EventBase, event: EventBase,
@ -450,26 +477,11 @@ class EventClientSerializer:
serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references
if aggregations.replace: if aggregations.replace:
# If there is an edit replace the content, preserving existing # If there is an edit, apply it to the event.
# relations.
edit = aggregations.replace edit = aggregations.replace
self._apply_edit(event, serialized_event, edit)
# Ensure we take copies of the edit content, otherwise we risk modifying # Include information about it in the relations dict.
# the original event.
edit_content = edit.content.copy()
# Unfreeze the event content if necessary, so that we may modify it below
edit_content = unfreeze(edit_content)
serialized_event["content"] = edit_content.get("m.new_content", {})
# Check for existing relations
relates_to = event.content.get("m.relates_to")
if relates_to:
# Keep the relations, ensuring we use a dict copy of the original
serialized_event["content"]["m.relates_to"] = relates_to.copy()
else:
serialized_event["content"].pop("m.relates_to", None)
serialized_aggregations[RelationTypes.REPLACE] = { serialized_aggregations[RelationTypes.REPLACE] = {
"event_id": edit.event_id, "event_id": edit.event_id,
"origin_server_ts": edit.origin_server_ts, "origin_server_ts": edit.origin_server_ts,
@ -478,13 +490,22 @@ class EventClientSerializer:
# If this event is the start of a thread, include a summary of the replies. # If this event is the start of a thread, include a summary of the replies.
if aggregations.thread: if aggregations.thread:
thread = aggregations.thread
# Don't bundle aggregations as this could recurse forever.
serialized_latest_event = self.serialize_event(
thread.latest_event, time_now, bundle_aggregations=None
)
# Manually apply an edit, if one exists.
if thread.latest_edit:
self._apply_edit(
thread.latest_event, serialized_latest_event, thread.latest_edit
)
serialized_aggregations[RelationTypes.THREAD] = { serialized_aggregations[RelationTypes.THREAD] = {
# Don't bundle aggregations as this could recurse forever. "latest_event": serialized_latest_event,
"latest_event": self.serialize_event( "count": thread.count,
aggregations.thread.latest_event, time_now, bundle_aggregations=None "current_user_participated": thread.current_user_participated,
),
"count": aggregations.thread.count,
"current_user_participated": aggregations.thread.current_user_participated,
} }
# Include the bundled aggregations in the event. # Include the bundled aggregations in the event.

View File

@ -408,7 +408,7 @@ class EventsWorkerStore(SQLBaseStore):
include the previous states content in the unsigned field. include the previous states content in the unsigned field.
allow_rejected: If True, return rejected events. Otherwise, allow_rejected: If True, return rejected events. Otherwise,
omits rejeted events from the response. omits rejected events from the response.
Returns: Returns:
A mapping from event_id to event. A mapping from event_id to event.

View File

@ -53,8 +53,13 @@ logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
class _ThreadAggregation: class _ThreadAggregation:
# The latest event in the thread.
latest_event: EventBase latest_event: EventBase
# The latest edit to the latest event in the thread.
latest_edit: Optional[EventBase]
# The total number of events in the thread.
count: int count: int
# True if the current user has sent an event to the thread.
current_user_participated: bool current_user_participated: bool
@ -461,8 +466,8 @@ class RelationsWorkerStore(SQLBaseStore):
@cachedList(cached_method_name="get_thread_summary", list_name="event_ids") @cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
async def _get_thread_summaries( async def _get_thread_summaries(
self, event_ids: Collection[str] self, event_ids: Collection[str]
) -> Dict[str, Optional[Tuple[int, EventBase]]]: ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
"""Get the number of threaded replies and the latest reply (if any) for the given event. """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
Args: Args:
event_ids: Summarize the thread related to this event ID. event_ids: Summarize the thread related to this event ID.
@ -471,8 +476,10 @@ class RelationsWorkerStore(SQLBaseStore):
A map of the thread summary each event. A missing event implies there A map of the thread summary each event. A missing event implies there
are no threaded replies. are no threaded replies.
Each summary includes the number of items in the thread and the most Each summary is a tuple of:
recent response. The number of events in the thread.
The most recent event in the thread.
The most recent edit to the most recent event in the thread, if applicable.
""" """
def _get_thread_summaries_txn( def _get_thread_summaries_txn(
@ -558,6 +565,9 @@ class RelationsWorkerStore(SQLBaseStore):
latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined] latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]
# Check to see if any of those events are edited.
latest_edits = await self._get_applicable_edits(latest_event_ids.values())
# Map to the event IDs to the thread summary. # Map to the event IDs to the thread summary.
# #
# There might not be a summary due to there not being a thread or # There might not be a summary due to there not being a thread or
@ -568,7 +578,8 @@ class RelationsWorkerStore(SQLBaseStore):
summary = None summary = None
if latest_event: if latest_event:
summary = (counts[parent_event_id], latest_event) latest_edit = latest_edits.get(latest_event_id)
summary = (counts[parent_event_id], latest_event, latest_edit)
summaries[parent_event_id] = summary summaries[parent_event_id] = summary
return summaries return summaries
@ -828,11 +839,12 @@ class RelationsWorkerStore(SQLBaseStore):
) )
for event_id, summary in summaries.items(): for event_id, summary in summaries.items():
if summary: if summary:
thread_count, latest_thread_event = summary thread_count, latest_thread_event, edit = summary
results.setdefault( results.setdefault(
event_id, BundledAggregations() event_id, BundledAggregations()
).thread = _ThreadAggregation( ).thread = _ThreadAggregation(
latest_event=latest_thread_event, latest_event=latest_thread_event,
latest_edit=edit,
count=thread_count, count=thread_count,
# If there's a thread summary it must also exist in the # If there's a thread summary it must also exist in the
# participated dictionary. # participated dictionary.

View File

@ -1123,6 +1123,48 @@ class RelationsTestCase(unittest.HomeserverTestCase):
{"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
) )
@unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_edit_thread(self):
"""Test that editing a thread works."""
# Create a thread and edit the last event.
channel = self._send_relation(
RelationTypes.THREAD,
"m.room.message",
content={"msgtype": "m.text", "body": "A threaded reply!"},
)
self.assertEquals(200, channel.code, channel.json_body)
threaded_event_id = channel.json_body["event_id"]
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
channel = self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
parent_id=threaded_event_id,
)
self.assertEquals(200, channel.code, channel.json_body)
# Fetch the thread root, to get the bundled aggregation for the thread.
channel = self.make_request(
"GET",
f"/rooms/{self.room}/event/{self.parent_id}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
# We expect that the edit message appears in the thread summary in the
# unsigned relations section.
relations_dict = channel.json_body["unsigned"].get("m.relations")
self.assertIn(RelationTypes.THREAD, relations_dict)
thread_summary = relations_dict[RelationTypes.THREAD]
self.assertIn("latest_event", thread_summary)
latest_event_in_thread = thread_summary["latest_event"]
self.assertEquals(
latest_event_in_thread["content"]["body"], "I've been edited!"
)
def test_edit_edit(self): def test_edit_edit(self):
"""Test that an edit cannot be edited.""" """Test that an edit cannot be edited."""
new_body = {"msgtype": "m.text", "body": "Initial edit"} new_body = {"msgtype": "m.text", "body": "Initial edit"}