mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Include whether the requesting user has participated in a thread. (#11577)
Per updates to MSC3440. This is implement as a separate method since it needs to be cached on a per-user basis, instead of a per-thread basis.
This commit is contained in:
parent
251b5567ec
commit
68acb0a29d
1
changelog.d/11577.feature
Normal file
1
changelog.d/11577.feature
Normal file
@ -0,0 +1 @@
|
|||||||
|
Include whether the requesting user has participated in a thread when generating a summary for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440).
|
@ -537,7 +537,7 @@ class PaginationHandler:
|
|||||||
state_dict = await self.store.get_events(list(state_ids.values()))
|
state_dict = await self.store.get_events(list(state_ids.values()))
|
||||||
state = state_dict.values()
|
state = state_dict.values()
|
||||||
|
|
||||||
aggregations = await self.store.get_bundled_aggregations(events)
|
aggregations = await self.store.get_bundled_aggregations(events, user_id)
|
||||||
|
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
|
|
||||||
|
@ -1182,12 +1182,18 @@ class RoomContextHandler:
|
|||||||
results["event"] = filtered[0]
|
results["event"] = filtered[0]
|
||||||
|
|
||||||
# Fetch the aggregations.
|
# Fetch the aggregations.
|
||||||
aggregations = await self.store.get_bundled_aggregations([results["event"]])
|
aggregations = await self.store.get_bundled_aggregations(
|
||||||
aggregations.update(
|
[results["event"]], user.to_string()
|
||||||
await self.store.get_bundled_aggregations(results["events_before"])
|
|
||||||
)
|
)
|
||||||
aggregations.update(
|
aggregations.update(
|
||||||
await self.store.get_bundled_aggregations(results["events_after"])
|
await self.store.get_bundled_aggregations(
|
||||||
|
results["events_before"], user.to_string()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
aggregations.update(
|
||||||
|
await self.store.get_bundled_aggregations(
|
||||||
|
results["events_after"], user.to_string()
|
||||||
|
)
|
||||||
)
|
)
|
||||||
results["aggregations"] = aggregations
|
results["aggregations"] = aggregations
|
||||||
|
|
||||||
|
@ -637,7 +637,9 @@ class SyncHandler:
|
|||||||
# as clients will have all the necessary information.
|
# as clients will have all the necessary information.
|
||||||
bundled_aggregations = None
|
bundled_aggregations = None
|
||||||
if limited or newly_joined_room:
|
if limited or newly_joined_room:
|
||||||
bundled_aggregations = await self.store.get_bundled_aggregations(recents)
|
bundled_aggregations = await self.store.get_bundled_aggregations(
|
||||||
|
recents, sync_config.user.to_string()
|
||||||
|
)
|
||||||
|
|
||||||
return TimelineBatch(
|
return TimelineBatch(
|
||||||
events=recents,
|
events=recents,
|
||||||
|
@ -118,7 +118,9 @@ class RelationPaginationServlet(RestServlet):
|
|||||||
)
|
)
|
||||||
# The relations returned for the requested event do include their
|
# The relations returned for the requested event do include their
|
||||||
# bundled aggregations.
|
# bundled aggregations.
|
||||||
aggregations = await self.store.get_bundled_aggregations(events)
|
aggregations = await self.store.get_bundled_aggregations(
|
||||||
|
events, requester.user.to_string()
|
||||||
|
)
|
||||||
serialized_events = self._event_serializer.serialize_events(
|
serialized_events = self._event_serializer.serialize_events(
|
||||||
events, now, bundle_aggregations=aggregations
|
events, now, bundle_aggregations=aggregations
|
||||||
)
|
)
|
||||||
|
@ -663,7 +663,9 @@ class RoomEventServlet(RestServlet):
|
|||||||
|
|
||||||
if event:
|
if event:
|
||||||
# Ensure there are bundled aggregations available.
|
# Ensure there are bundled aggregations available.
|
||||||
aggregations = await self._store.get_bundled_aggregations([event])
|
aggregations = await self._store.get_bundled_aggregations(
|
||||||
|
[event], requester.user.to_string()
|
||||||
|
)
|
||||||
|
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
event_dict = self._event_serializer.serialize_event(
|
event_dict = self._event_serializer.serialize_event(
|
||||||
|
@ -1793,6 +1793,13 @@ class PersistEventsStore:
|
|||||||
txn.call_after(
|
txn.call_after(
|
||||||
self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
|
self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
|
||||||
)
|
)
|
||||||
|
# It should be safe to only invalidate the cache if the user has not
|
||||||
|
# previously participated in the thread, but that's difficult (and
|
||||||
|
# potentially error-prone) so it is always invalidated.
|
||||||
|
txn.call_after(
|
||||||
|
self.store.get_thread_participated.invalidate,
|
||||||
|
(parent_id, event.room_id, event.sender),
|
||||||
|
)
|
||||||
|
|
||||||
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
|
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
|
||||||
"""Handles keeping track of insertion events and edges/connections.
|
"""Handles keeping track of insertion events and edges/connections.
|
||||||
|
@ -384,8 +384,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||||||
async def get_thread_summary(
|
async def get_thread_summary(
|
||||||
self, event_id: str, room_id: str
|
self, event_id: str, room_id: str
|
||||||
) -> Tuple[int, Optional[EventBase]]:
|
) -> Tuple[int, Optional[EventBase]]:
|
||||||
"""Get the number of threaded replies, the senders of those replies, and
|
"""Get the number of threaded replies and the latest reply (if any) for the given event.
|
||||||
the latest reply (if any) for the given event.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event_id: Summarize the thread related to this event ID.
|
event_id: Summarize the thread related to this event ID.
|
||||||
@ -398,7 +397,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||||||
def _get_thread_summary_txn(
|
def _get_thread_summary_txn(
|
||||||
txn: LoggingTransaction,
|
txn: LoggingTransaction,
|
||||||
) -> Tuple[int, Optional[str]]:
|
) -> Tuple[int, Optional[str]]:
|
||||||
# Fetch the count of threaded events and the latest event ID.
|
# Fetch the latest event ID in the thread.
|
||||||
# TODO Should this only allow m.room.message events.
|
# TODO Should this only allow m.room.message events.
|
||||||
sql = """
|
sql = """
|
||||||
SELECT event_id
|
SELECT event_id
|
||||||
@ -419,6 +418,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
latest_event_id = row[0]
|
latest_event_id = row[0]
|
||||||
|
|
||||||
|
# Fetch the number of threaded replies.
|
||||||
sql = """
|
sql = """
|
||||||
SELECT COUNT(event_id)
|
SELECT COUNT(event_id)
|
||||||
FROM event_relations
|
FROM event_relations
|
||||||
@ -443,6 +443,44 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
return count, latest_event
|
return count, latest_event
|
||||||
|
|
||||||
|
@cached()
|
||||||
|
async def get_thread_participated(
|
||||||
|
self, event_id: str, room_id: str, user_id: str
|
||||||
|
) -> bool:
|
||||||
|
"""Get whether the requesting user participated in a thread.
|
||||||
|
|
||||||
|
This is separate from get_thread_summary since that can be cached across
|
||||||
|
all users while this value is specific to the requeser.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_id: The thread related to this event ID.
|
||||||
|
room_id: The room the event belongs to.
|
||||||
|
user_id: The user requesting the summary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the requesting user participated in the thread, otherwise false.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _get_thread_summary_txn(txn: LoggingTransaction) -> bool:
|
||||||
|
# Fetch whether the requester has participated or not.
|
||||||
|
sql = """
|
||||||
|
SELECT 1
|
||||||
|
FROM event_relations
|
||||||
|
INNER JOIN events USING (event_id)
|
||||||
|
WHERE
|
||||||
|
relates_to_id = ?
|
||||||
|
AND room_id = ?
|
||||||
|
AND relation_type = ?
|
||||||
|
AND sender = ?
|
||||||
|
"""
|
||||||
|
|
||||||
|
txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id))
|
||||||
|
return bool(txn.fetchone())
|
||||||
|
|
||||||
|
return await self.db_pool.runInteraction(
|
||||||
|
"get_thread_summary", _get_thread_summary_txn
|
||||||
|
)
|
||||||
|
|
||||||
async def events_have_relations(
|
async def events_have_relations(
|
||||||
self,
|
self,
|
||||||
parent_ids: List[str],
|
parent_ids: List[str],
|
||||||
@ -546,7 +584,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _get_bundled_aggregation_for_event(
|
async def _get_bundled_aggregation_for_event(
|
||||||
self, event: EventBase
|
self, event: EventBase, user_id: str
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[Dict[str, Any]]:
|
||||||
"""Generate bundled aggregations for an event.
|
"""Generate bundled aggregations for an event.
|
||||||
|
|
||||||
@ -554,6 +592,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
event: The event to calculate bundled aggregations for.
|
event: The event to calculate bundled aggregations for.
|
||||||
|
user_id: The user requesting the bundled aggregations.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The bundled aggregations for an event, if bundled aggregations are
|
The bundled aggregations for an event, if bundled aggregations are
|
||||||
@ -598,27 +637,32 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
# If this event is the start of a thread, include a summary of the replies.
|
# If this event is the start of a thread, include a summary of the replies.
|
||||||
if self._msc3440_enabled:
|
if self._msc3440_enabled:
|
||||||
(
|
thread_count, latest_thread_event = await self.get_thread_summary(
|
||||||
thread_count,
|
event_id, room_id
|
||||||
latest_thread_event,
|
)
|
||||||
) = await self.get_thread_summary(event_id, room_id)
|
participated = await self.get_thread_participated(
|
||||||
|
event_id, room_id, user_id
|
||||||
|
)
|
||||||
if latest_thread_event:
|
if latest_thread_event:
|
||||||
aggregations[RelationTypes.THREAD] = {
|
aggregations[RelationTypes.THREAD] = {
|
||||||
# Don't bundle aggregations as this could recurse forever.
|
|
||||||
"latest_event": latest_thread_event,
|
"latest_event": latest_thread_event,
|
||||||
"count": thread_count,
|
"count": thread_count,
|
||||||
|
"current_user_participated": participated,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Store the bundled aggregations in the event metadata for later use.
|
# Store the bundled aggregations in the event metadata for later use.
|
||||||
return aggregations
|
return aggregations
|
||||||
|
|
||||||
async def get_bundled_aggregations(
|
async def get_bundled_aggregations(
|
||||||
self, events: Iterable[EventBase]
|
self,
|
||||||
|
events: Iterable[EventBase],
|
||||||
|
user_id: str,
|
||||||
) -> Dict[str, Dict[str, Any]]:
|
) -> Dict[str, Dict[str, Any]]:
|
||||||
"""Generate bundled aggregations for events.
|
"""Generate bundled aggregations for events.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
events: The iterable of events to calculate bundled aggregations for.
|
events: The iterable of events to calculate bundled aggregations for.
|
||||||
|
user_id: The user requesting the bundled aggregations.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A map of event ID to the bundled aggregation for the event. Not all
|
A map of event ID to the bundled aggregation for the event. Not all
|
||||||
@ -631,7 +675,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||||||
# TODO Parallelize.
|
# TODO Parallelize.
|
||||||
results = {}
|
results = {}
|
||||||
for event in events:
|
for event in events:
|
||||||
event_result = await self._get_bundled_aggregation_for_event(event)
|
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
|
||||||
if event_result is not None:
|
if event_result is not None:
|
||||||
results[event.event_id] = event_result
|
results[event.event_id] = event_result
|
||||||
|
|
||||||
|
@ -515,6 +515,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
|||||||
2,
|
2,
|
||||||
actual[RelationTypes.THREAD].get("count"),
|
actual[RelationTypes.THREAD].get("count"),
|
||||||
)
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
actual[RelationTypes.THREAD].get("current_user_participated")
|
||||||
|
)
|
||||||
# The latest thread event has some fields that don't matter.
|
# The latest thread event has some fields that don't matter.
|
||||||
self.assert_dict(
|
self.assert_dict(
|
||||||
{
|
{
|
||||||
|
Loading…
Reference in New Issue
Block a user