Recursively fetch the thread for receipts & notifications. (#13824)

Consider an event to be part of a thread if you can follow a
chain of relations up to a thread root.

Part of MSC3773 & MSC3771.
This commit is contained in:
Patrick Cloke 2022-10-04 11:36:16 -04:00 committed by GitHub
parent 3e74ad20db
commit 2b6d41ebd6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 162 additions and 2 deletions

View File

@ -0,0 +1 @@
Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).

View File

@ -286,8 +286,13 @@ class BulkPushRuleEvaluator:
relation.parent_id, relation.parent_id,
itertools.chain(*(r.rules() for r in rules_by_user.values())), itertools.chain(*(r.rules() for r in rules_by_user.values())),
) )
# Recursively attempt to find the thread this event relates to.
if relation.rel_type == RelationTypes.THREAD: if relation.rel_type == RelationTypes.THREAD:
thread_id = relation.parent_id thread_id = relation.parent_id
else:
# Since the event has not yet been persisted we check whether
# the parent is part of a thread.
thread_id = await self.store.get_thread_id(relation.parent_id) or "main"
evaluator = PushRuleEvaluator( evaluator = PushRuleEvaluator(
_flatten_dict(event), _flatten_dict(event),

View File

@ -16,7 +16,7 @@ import logging
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import ReceiptTypes from synapse.api.constants import ReceiptTypes
from synapse.api.errors import SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
@ -43,6 +43,7 @@ class ReceiptRestServlet(RestServlet):
self.receipts_handler = hs.get_receipts_handler() self.receipts_handler = hs.get_receipts_handler()
self.read_marker_handler = hs.get_read_marker_handler() self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
self._main_store = hs.get_datastores().main
self._known_receipt_types = { self._known_receipt_types = {
ReceiptTypes.READ, ReceiptTypes.READ,
@ -71,7 +72,24 @@ class ReceiptRestServlet(RestServlet):
thread_id = body.get("thread_id") thread_id = body.get("thread_id")
if not thread_id or not isinstance(thread_id, str): if not thread_id or not isinstance(thread_id, str):
raise SynapseError( raise SynapseError(
400, "thread_id field must be a non-empty string" 400,
"thread_id field must be a non-empty string",
Codes.INVALID_PARAM,
)
if receipt_type == ReceiptTypes.FULLY_READ:
raise SynapseError(
400,
f"thread_id is not compatible with {ReceiptTypes.FULLY_READ} receipts.",
Codes.INVALID_PARAM,
)
# Ensure the event ID roughly correlates to the thread ID.
if thread_id != await self._main_store.get_thread_id(event_id):
raise SynapseError(
400,
f"event_id {event_id} is not related to thread {thread_id}",
Codes.INVALID_PARAM,
) )
await self.presence_handler.bump_presence_active_time(requester.user) await self.presence_handler.bump_presence_active_time(requester.user)

View File

@ -832,6 +832,42 @@ class RelationsWorkerStore(SQLBaseStore):
"get_event_relations", _get_event_relations "get_event_relations", _get_event_relations
) )
@cached()
async def get_thread_id(self, event_id: str) -> Optional[str]:
"""
Get the thread ID for an event. This considers multi-level relations,
e.g. an annotation to an event which is part of a thread.
Args:
event_id: The event ID to fetch the thread ID for.
Returns:
The event ID of the root event in the thread, if this event is part
of a thread. None, otherwise.
"""
# Since event relations form a tree, we should only ever find 0 or 1
# results from the below query.
sql = """
WITH RECURSIVE related_events AS (
SELECT event_id, relates_to_id, relation_type
FROM event_relations
WHERE event_id = ?
UNION SELECT e.event_id, e.relates_to_id, e.relation_type
FROM event_relations e
INNER JOIN related_events r ON r.relates_to_id = e.event_id
) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread';
"""
def _get_thread_id(txn: LoggingTransaction) -> Optional[str]:
txn.execute(sql, (event_id,))
# TODO Should we ensure there's only a single result here?
row = txn.fetchone()
if row:
return row[0]
return None
return await self.db_pool.runInteraction("get_thread_id", _get_thread_id)
class RelationsStore(RelationsWorkerStore): class RelationsStore(RelationsWorkerStore):
pass pass

View File

@ -588,6 +588,106 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
_rotate() _rotate()
_assert_counts(0, 0, 0, 0) _assert_counts(0, 0, 0, 0)
def test_recursive_thread(self) -> None:
"""
Events related to events in a thread should still be considered part of
that thread.
"""
# Create a user to receive notifications and send receipts.
user_id = self.register_user("user1235", "pass")
token = self.login("user1235", "pass")
# And another users to send events.
other_id = self.register_user("other", "pass")
other_token = self.login("other", "pass")
# Create a room and put both users in it.
room_id = self.helper.create_room_as(user_id, tok=token)
self.helper.join(room_id, other_id, tok=other_token)
# Update the user's push rules to care about reaction events.
self.get_success(
self.store.add_push_rule(
user_id,
"related_events",
priority_class=5,
conditions=[
{"kind": "event_match", "key": "type", "pattern": "m.reaction"}
],
actions=["notify"],
)
)
def _create_event(type: str, content: JsonDict) -> str:
result = self.helper.send_event(
room_id, type=type, content=content, tok=other_token
)
return result["event_id"]
def _assert_counts(noitf_count: int, thread_notif_count: int) -> None:
counts = self.get_success(
self.store.db_pool.runInteraction(
"get-unread-counts",
self.store._get_unread_counts_by_receipt_txn,
room_id,
user_id,
)
)
self.assertEqual(
counts.main_timeline,
NotifCounts(
notify_count=noitf_count, unread_count=0, highlight_count=0
),
)
if thread_notif_count:
self.assertEqual(
counts.threads,
{
thread_id: NotifCounts(
notify_count=thread_notif_count,
unread_count=0,
highlight_count=0,
),
},
)
else:
self.assertEqual(counts.threads, {})
# Create a root event.
thread_id = _create_event(
"m.room.message", {"msgtype": "m.text", "body": "msg"}
)
_assert_counts(1, 0)
# Reply, creating a thread.
reply_id = _create_event(
"m.room.message",
{
"msgtype": "m.text",
"body": "msg",
"m.relates_to": {
"rel_type": "m.thread",
"event_id": thread_id,
},
},
)
_assert_counts(1, 1)
# Create an event related to a thread event, this should still appear in
# the thread.
_create_event(
type="m.reaction",
content={
"m.relates_to": {
"rel_type": "m.annotation",
"event_id": reply_id,
"key": "A",
}
},
)
_assert_counts(1, 2)
def test_find_first_stream_ordering_after_ts(self) -> None: def test_find_first_stream_ordering_after_ts(self) -> None:
def add_event(so: int, ts: int) -> None: def add_event(so: int, ts: int) -> None:
self.get_success( self.get_success(