Recursively fetch the thread for receipts & notifications. (#13824)

Consider an event to be part of a thread if you can follow a
chain of relations up to a thread root.

Part of MSC3773 & MSC3771.
This commit is contained in:
Patrick Cloke 2022-10-04 11:36:16 -04:00 committed by GitHub
parent 3e74ad20db
commit 2b6d41ebd6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 162 additions and 2 deletions

View file

@ -832,6 +832,42 @@ class RelationsWorkerStore(SQLBaseStore):
"get_event_relations", _get_event_relations
)
@cached()
async def get_thread_id(self, event_id: str) -> Optional[str]:
"""
Get the thread ID for an event. This considers multi-level relations,
e.g. an annotation to an event which is part of a thread.
Args:
event_id: The event ID to fetch the thread ID for.
Returns:
The event ID of the root event in the thread, if this event is part
of a thread. None, otherwise.
"""
# Since event relations form a tree, we should only ever find 0 or 1
# results from the below query.
sql = """
WITH RECURSIVE related_events AS (
SELECT event_id, relates_to_id, relation_type
FROM event_relations
WHERE event_id = ?
UNION SELECT e.event_id, e.relates_to_id, e.relation_type
FROM event_relations e
INNER JOIN related_events r ON r.relates_to_id = e.event_id
) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread';
"""
def _get_thread_id(txn: LoggingTransaction) -> Optional[str]:
txn.execute(sql, (event_id,))
# TODO Should we ensure there's only a single result here?
row = txn.fetchone()
if row:
return row[0]
return None
return await self.db_pool.runInteraction("get_thread_id", _get_thread_id)
class RelationsStore(RelationsWorkerStore):
pass