Include whether the requesting user has participated in a thread. (#11577)

Per updates to MSC3440.

This is implement as a separate method since it needs to be cached
on a per-user basis, instead of a per-thread basis.
This commit is contained in:
Patrick Cloke 2022-01-18 11:38:57 -05:00 committed by GitHub
parent 251b5567ec
commit 68acb0a29d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 86 additions and 19 deletions

View File

@ -0,0 +1 @@
Include whether the requesting user has participated in a thread when generating a summary for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440).

View File

@ -537,7 +537,7 @@ class PaginationHandler:
state_dict = await self.store.get_events(list(state_ids.values())) state_dict = await self.store.get_events(list(state_ids.values()))
state = state_dict.values() state = state_dict.values()
aggregations = await self.store.get_bundled_aggregations(events) aggregations = await self.store.get_bundled_aggregations(events, user_id)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()

View File

@ -1182,12 +1182,18 @@ class RoomContextHandler:
results["event"] = filtered[0] results["event"] = filtered[0]
# Fetch the aggregations. # Fetch the aggregations.
aggregations = await self.store.get_bundled_aggregations([results["event"]]) aggregations = await self.store.get_bundled_aggregations(
aggregations.update( [results["event"]], user.to_string()
await self.store.get_bundled_aggregations(results["events_before"])
) )
aggregations.update( aggregations.update(
await self.store.get_bundled_aggregations(results["events_after"]) await self.store.get_bundled_aggregations(
results["events_before"], user.to_string()
)
)
aggregations.update(
await self.store.get_bundled_aggregations(
results["events_after"], user.to_string()
)
) )
results["aggregations"] = aggregations results["aggregations"] = aggregations

View File

@ -637,7 +637,9 @@ class SyncHandler:
# as clients will have all the necessary information. # as clients will have all the necessary information.
bundled_aggregations = None bundled_aggregations = None
if limited or newly_joined_room: if limited or newly_joined_room:
bundled_aggregations = await self.store.get_bundled_aggregations(recents) bundled_aggregations = await self.store.get_bundled_aggregations(
recents, sync_config.user.to_string()
)
return TimelineBatch( return TimelineBatch(
events=recents, events=recents,

View File

@ -118,7 +118,9 @@ class RelationPaginationServlet(RestServlet):
) )
# The relations returned for the requested event do include their # The relations returned for the requested event do include their
# bundled aggregations. # bundled aggregations.
aggregations = await self.store.get_bundled_aggregations(events) aggregations = await self.store.get_bundled_aggregations(
events, requester.user.to_string()
)
serialized_events = self._event_serializer.serialize_events( serialized_events = self._event_serializer.serialize_events(
events, now, bundle_aggregations=aggregations events, now, bundle_aggregations=aggregations
) )

View File

@ -663,7 +663,9 @@ class RoomEventServlet(RestServlet):
if event: if event:
# Ensure there are bundled aggregations available. # Ensure there are bundled aggregations available.
aggregations = await self._store.get_bundled_aggregations([event]) aggregations = await self._store.get_bundled_aggregations(
[event], requester.user.to_string()
)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
event_dict = self._event_serializer.serialize_event( event_dict = self._event_serializer.serialize_event(

View File

@ -1793,6 +1793,13 @@ class PersistEventsStore:
txn.call_after( txn.call_after(
self.store.get_thread_summary.invalidate, (parent_id, event.room_id) self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
) )
# It should be safe to only invalidate the cache if the user has not
# previously participated in the thread, but that's difficult (and
# potentially error-prone) so it is always invalidated.
txn.call_after(
self.store.get_thread_participated.invalidate,
(parent_id, event.room_id, event.sender),
)
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase): def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
"""Handles keeping track of insertion events and edges/connections. """Handles keeping track of insertion events and edges/connections.

View File

@ -384,8 +384,7 @@ class RelationsWorkerStore(SQLBaseStore):
async def get_thread_summary( async def get_thread_summary(
self, event_id: str, room_id: str self, event_id: str, room_id: str
) -> Tuple[int, Optional[EventBase]]: ) -> Tuple[int, Optional[EventBase]]:
"""Get the number of threaded replies, the senders of those replies, and """Get the number of threaded replies and the latest reply (if any) for the given event.
the latest reply (if any) for the given event.
Args: Args:
event_id: Summarize the thread related to this event ID. event_id: Summarize the thread related to this event ID.
@ -398,7 +397,7 @@ class RelationsWorkerStore(SQLBaseStore):
def _get_thread_summary_txn( def _get_thread_summary_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> Tuple[int, Optional[str]]: ) -> Tuple[int, Optional[str]]:
# Fetch the count of threaded events and the latest event ID. # Fetch the latest event ID in the thread.
# TODO Should this only allow m.room.message events. # TODO Should this only allow m.room.message events.
sql = """ sql = """
SELECT event_id SELECT event_id
@ -419,6 +418,7 @@ class RelationsWorkerStore(SQLBaseStore):
latest_event_id = row[0] latest_event_id = row[0]
# Fetch the number of threaded replies.
sql = """ sql = """
SELECT COUNT(event_id) SELECT COUNT(event_id)
FROM event_relations FROM event_relations
@ -443,6 +443,44 @@ class RelationsWorkerStore(SQLBaseStore):
return count, latest_event return count, latest_event
@cached()
async def get_thread_participated(
self, event_id: str, room_id: str, user_id: str
) -> bool:
"""Get whether the requesting user participated in a thread.
This is separate from get_thread_summary since that can be cached across
all users while this value is specific to the requeser.
Args:
event_id: The thread related to this event ID.
room_id: The room the event belongs to.
user_id: The user requesting the summary.
Returns:
True if the requesting user participated in the thread, otherwise false.
"""
def _get_thread_summary_txn(txn: LoggingTransaction) -> bool:
# Fetch whether the requester has participated or not.
sql = """
SELECT 1
FROM event_relations
INNER JOIN events USING (event_id)
WHERE
relates_to_id = ?
AND room_id = ?
AND relation_type = ?
AND sender = ?
"""
txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id))
return bool(txn.fetchone())
return await self.db_pool.runInteraction(
"get_thread_summary", _get_thread_summary_txn
)
async def events_have_relations( async def events_have_relations(
self, self,
parent_ids: List[str], parent_ids: List[str],
@ -546,7 +584,7 @@ class RelationsWorkerStore(SQLBaseStore):
) )
async def _get_bundled_aggregation_for_event( async def _get_bundled_aggregation_for_event(
self, event: EventBase self, event: EventBase, user_id: str
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
"""Generate bundled aggregations for an event. """Generate bundled aggregations for an event.
@ -554,6 +592,7 @@ class RelationsWorkerStore(SQLBaseStore):
Args: Args:
event: The event to calculate bundled aggregations for. event: The event to calculate bundled aggregations for.
user_id: The user requesting the bundled aggregations.
Returns: Returns:
The bundled aggregations for an event, if bundled aggregations are The bundled aggregations for an event, if bundled aggregations are
@ -598,27 +637,32 @@ class RelationsWorkerStore(SQLBaseStore):
# If this event is the start of a thread, include a summary of the replies. # If this event is the start of a thread, include a summary of the replies.
if self._msc3440_enabled: if self._msc3440_enabled:
( thread_count, latest_thread_event = await self.get_thread_summary(
thread_count, event_id, room_id
latest_thread_event, )
) = await self.get_thread_summary(event_id, room_id) participated = await self.get_thread_participated(
event_id, room_id, user_id
)
if latest_thread_event: if latest_thread_event:
aggregations[RelationTypes.THREAD] = { aggregations[RelationTypes.THREAD] = {
# Don't bundle aggregations as this could recurse forever.
"latest_event": latest_thread_event, "latest_event": latest_thread_event,
"count": thread_count, "count": thread_count,
"current_user_participated": participated,
} }
# Store the bundled aggregations in the event metadata for later use. # Store the bundled aggregations in the event metadata for later use.
return aggregations return aggregations
async def get_bundled_aggregations( async def get_bundled_aggregations(
self, events: Iterable[EventBase] self,
events: Iterable[EventBase],
user_id: str,
) -> Dict[str, Dict[str, Any]]: ) -> Dict[str, Dict[str, Any]]:
"""Generate bundled aggregations for events. """Generate bundled aggregations for events.
Args: Args:
events: The iterable of events to calculate bundled aggregations for. events: The iterable of events to calculate bundled aggregations for.
user_id: The user requesting the bundled aggregations.
Returns: Returns:
A map of event ID to the bundled aggregation for the event. Not all A map of event ID to the bundled aggregation for the event. Not all
@ -631,7 +675,7 @@ class RelationsWorkerStore(SQLBaseStore):
# TODO Parallelize. # TODO Parallelize.
results = {} results = {}
for event in events: for event in events:
event_result = await self._get_bundled_aggregation_for_event(event) event_result = await self._get_bundled_aggregation_for_event(event, user_id)
if event_result is not None: if event_result is not None:
results[event.event_id] = event_result results[event.event_id] = event_result

View File

@ -515,6 +515,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
2, 2,
actual[RelationTypes.THREAD].get("count"), actual[RelationTypes.THREAD].get("count"),
) )
self.assertTrue(
actual[RelationTypes.THREAD].get("current_user_participated")
)
# The latest thread event has some fields that don't matter. # The latest thread event has some fields that don't matter.
self.assert_dict( self.assert_dict(
{ {