Experimental support to include bundled aggregations in search results (MSC3666) (#11837)

This commit is contained in:
Patrick Cloke 2022-02-08 09:21:20 -05:00 committed by GitHub
parent 6c0984e3f0
commit 8c94b3abe9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 76 additions and 8 deletions

View File

@ -0,0 +1 @@
Experimental support for [MSC3666](https://github.com/matrix-org/matrix-doc/pull/3666): including bundled aggregations in server side search results.

View File

@ -26,6 +26,8 @@ class ExperimentalConfig(Config):
# MSC3440 (thread relation) # MSC3440 (thread relation)
self.msc3440_enabled: bool = experimental.get("msc3440_enabled", False) self.msc3440_enabled: bool = experimental.get("msc3440_enabled", False)
# MSC3666: including bundled relations in /search.
self.msc3666_enabled: bool = experimental.get("msc3666_enabled", False)
# MSC3026 (busy presence state) # MSC3026 (busy presence state)
self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False) self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False)

View File

@ -43,6 +43,8 @@ class SearchHandler:
self.state_store = self.storage.state self.state_store = self.storage.state
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._msc3666_enabled = hs.config.experimental.msc3666_enabled
async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]: async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]:
"""Retrieves room IDs of old rooms in the history of an upgraded room. """Retrieves room IDs of old rooms in the history of an upgraded room.
@ -238,8 +240,6 @@ class SearchHandler:
results = search_result["results"] results = search_result["results"]
results_map = {r["event"].event_id: r for r in results}
rank_map.update({r["event"].event_id: r["rank"] for r in results}) rank_map.update({r["event"].event_id: r["rank"] for r in results})
filtered_events = await search_filter.filter([r["event"] for r in results]) filtered_events = await search_filter.filter([r["event"] for r in results])
@ -420,12 +420,29 @@ class SearchHandler:
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
aggregations = None
if self._msc3666_enabled:
aggregations = await self.store.get_bundled_aggregations(
# Generate an iterable of EventBase for all the events that will be
# returned, including contextual events.
itertools.chain(
# The events_before and events_after for each context.
itertools.chain.from_iterable(
itertools.chain(context["events_before"], context["events_after"]) # type: ignore[arg-type]
for context in contexts.values()
),
# The returned events.
allowed_events,
),
user.to_string(),
)
for context in contexts.values(): for context in contexts.values():
context["events_before"] = self._event_serializer.serialize_events( context["events_before"] = self._event_serializer.serialize_events(
context["events_before"], time_now # type: ignore[arg-type] context["events_before"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type]
) )
context["events_after"] = self._event_serializer.serialize_events( context["events_after"] = self._event_serializer.serialize_events(
context["events_after"], time_now # type: ignore[arg-type] context["events_after"], time_now, bundle_aggregations=aggregations # type: ignore[arg-type]
) )
state_results = {} state_results = {}
@ -442,7 +459,9 @@ class SearchHandler:
results.append( results.append(
{ {
"rank": rank_map[e.event_id], "rank": rank_map[e.event_id],
"result": self._event_serializer.serialize_event(e, time_now), "result": self._event_serializer.serialize_event(
e, time_now, bundle_aggregations=aggregations
),
"context": contexts.get(e.event_id, {}), "context": contexts.get(e.event_id, {}),
} }
) )

View File

@ -715,6 +715,9 @@ class RelationsWorkerStore(SQLBaseStore):
A map of event ID to the bundled aggregation for the event. Not all A map of event ID to the bundled aggregation for the event. Not all
events may have bundled aggregations in the results. events may have bundled aggregations in the results.
""" """
# The already processed event IDs. Tracked separately from the result
# since the result omits events which do not have bundled aggregations.
seen_event_ids = set()
# State events and redacted events do not get bundled aggregations. # State events and redacted events do not get bundled aggregations.
events = [ events = [
@ -728,13 +731,19 @@ class RelationsWorkerStore(SQLBaseStore):
# Fetch other relations per event. # Fetch other relations per event.
for event in events: for event in events:
# De-duplicate events by ID to handle the same event requested multiple
# times. The caches that _get_bundled_aggregation_for_event use should
# capture this, but best to reduce work.
if event.event_id in seen_event_ids:
continue
seen_event_ids.add(event.event_id)
event_result = await self._get_bundled_aggregation_for_event(event, user_id) event_result = await self._get_bundled_aggregation_for_event(event, user_id)
if event_result: if event_result:
results[event.event_id] = event_result results[event.event_id] = event_result
# Fetch any edits. # Fetch any edits.
event_ids = [event.event_id for event in events] edits = await self._get_applicable_edits(seen_event_ids)
edits = await self._get_applicable_edits(event_ids)
for event_id, edit in edits.items(): for event_id, edit in edits.items():
results.setdefault(event_id, BundledAggregations()).replace = edit results.setdefault(event_id, BundledAggregations()).replace = edit

View File

@ -453,7 +453,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
) )
self.assertEquals(400, channel.code, channel.json_body) self.assertEquals(400, channel.code, channel.json_body)
@unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) @unittest.override_config(
{"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}}
)
def test_bundled_aggregations(self): def test_bundled_aggregations(self):
""" """
Test that annotations, references, and threads get correctly bundled. Test that annotations, references, and threads get correctly bundled.
@ -579,6 +581,23 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertTrue(room_timeline["limited"]) self.assertTrue(room_timeline["limited"])
assert_bundle(self._find_event_in_chunk(room_timeline["events"])) assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
# Request search.
channel = self.make_request(
"POST",
"/search",
# Search term matches the parent message.
content={"search_categories": {"room_events": {"search_term": "Hi"}}},
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
chunk = [
result["result"]
for result in channel.json_body["search_categories"]["room_events"][
"results"
]
]
assert_bundle(self._find_event_in_chunk(chunk))
def test_aggregation_get_event_for_annotation(self): def test_aggregation_get_event_for_annotation(self):
"""Test that annotations do not get bundled aggregations included """Test that annotations do not get bundled aggregations included
when directly requested. when directly requested.
@ -759,6 +778,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
self.assertNotIn("m.relations", channel.json_body["unsigned"]) self.assertNotIn("m.relations", channel.json_body["unsigned"])
@unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
def test_edit(self): def test_edit(self):
"""Test that a simple edit works.""" """Test that a simple edit works."""
@ -825,6 +845,23 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertTrue(room_timeline["limited"]) self.assertTrue(room_timeline["limited"])
assert_bundle(self._find_event_in_chunk(room_timeline["events"])) assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
# Request search.
channel = self.make_request(
"POST",
"/search",
# Search term matches the parent message.
content={"search_categories": {"room_events": {"search_term": "Hi"}}},
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
chunk = [
result["result"]
for result in channel.json_body["search_categories"]["room_events"][
"results"
]
]
assert_bundle(self._find_event_in_chunk(chunk))
def test_multi_edit(self): def test_multi_edit(self):
"""Test that multiple edits, including attempts by people who """Test that multiple edits, including attempts by people who
shouldn't be allowed, are correctly handled. shouldn't be allowed, are correctly handled.