Do not allow cross-room relations, per MSC2674. (#11516)

This commit is contained in:
Patrick Cloke 2021-12-09 13:16:01 -05:00 committed by GitHub
parent 0cc3bf97b4
commit 3b8872299a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 161 additions and 17 deletions

1
changelog.d/11516.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a long-standing bug where relations from other rooms could be included in the bundled aggregations of an event.

View File

@ -454,23 +454,26 @@ class EventClientSerializer:
return return
event_id = event.event_id event_id = event.event_id
room_id = event.room_id
# The bundled aggregations to include. # The bundled aggregations to include.
aggregations = {} aggregations = {}
annotations = await self.store.get_aggregation_groups_for_event(event_id) annotations = await self.store.get_aggregation_groups_for_event(
event_id, room_id
)
if annotations.chunk: if annotations.chunk:
aggregations[RelationTypes.ANNOTATION] = annotations.to_dict() aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
references = await self.store.get_relations_for_event( references = await self.store.get_relations_for_event(
event_id, RelationTypes.REFERENCE, direction="f" event_id, room_id, RelationTypes.REFERENCE, direction="f"
) )
if references.chunk: if references.chunk:
aggregations[RelationTypes.REFERENCE] = references.to_dict() aggregations[RelationTypes.REFERENCE] = references.to_dict()
edit = None edit = None
if event.type == EventTypes.Message: if event.type == EventTypes.Message:
edit = await self.store.get_applicable_edit(event_id) edit = await self.store.get_applicable_edit(event_id, room_id)
if edit: if edit:
# If there is an edit replace the content, preserving existing # If there is an edit replace the content, preserving existing
@ -503,7 +506,7 @@ class EventClientSerializer:
( (
thread_count, thread_count,
latest_thread_event, latest_thread_event,
) = await self.store.get_thread_summary(event_id) ) = await self.store.get_thread_summary(event_id, room_id)
if latest_thread_event: if latest_thread_event:
aggregations[RelationTypes.THREAD] = { aggregations[RelationTypes.THREAD] = {
# Don't bundle aggregations as this could recurse forever. # Don't bundle aggregations as this could recurse forever.

View File

@ -212,6 +212,7 @@ class RelationPaginationServlet(RestServlet):
pagination_chunk = await self.store.get_relations_for_event( pagination_chunk = await self.store.get_relations_for_event(
event_id=parent_id, event_id=parent_id,
room_id=room_id,
relation_type=relation_type, relation_type=relation_type,
event_type=event_type, event_type=event_type,
limit=limit, limit=limit,
@ -317,6 +318,7 @@ class RelationAggregationPaginationServlet(RestServlet):
pagination_chunk = await self.store.get_aggregation_groups_for_event( pagination_chunk = await self.store.get_aggregation_groups_for_event(
event_id=parent_id, event_id=parent_id,
room_id=room_id,
event_type=event_type, event_type=event_type,
limit=limit, limit=limit,
from_token=from_token, from_token=from_token,
@ -383,7 +385,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
# This checks that a) the event exists and b) the user is allowed to # This checks that a) the event exists and b) the user is allowed to
# view it. # view it.
await self.event_handler.get_event(requester.user, room_id, parent_id) event = await self.event_handler.get_event(requester.user, room_id, parent_id)
if event is None:
raise SynapseError(404, "Unknown parent event.")
if relation_type != RelationTypes.ANNOTATION: if relation_type != RelationTypes.ANNOTATION:
raise SynapseError(400, "Relation type must be 'annotation'") raise SynapseError(400, "Relation type must be 'annotation'")
@ -402,6 +406,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
result = await self.store.get_relations_for_event( result = await self.store.get_relations_for_event(
event_id=parent_id, event_id=parent_id,
room_id=room_id,
relation_type=relation_type, relation_type=relation_type,
event_type=event_type, event_type=event_type,
aggregation_key=key, aggregation_key=key,

View File

@ -1780,10 +1780,14 @@ class PersistEventsStore:
) )
if rel_type == RelationTypes.REPLACE: if rel_type == RelationTypes.REPLACE:
txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) txn.call_after(
self.store.get_applicable_edit.invalidate, (parent_id, event.room_id)
)
if rel_type == RelationTypes.THREAD: if rel_type == RelationTypes.THREAD:
txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,)) txn.call_after(
self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
)
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase): def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
"""Handles keeping track of insertion events and edges/connections. """Handles keeping track of insertion events and edges/connections.

View File

@ -37,6 +37,7 @@ class RelationsWorkerStore(SQLBaseStore):
async def get_relations_for_event( async def get_relations_for_event(
self, self,
event_id: str, event_id: str,
room_id: str,
relation_type: Optional[str] = None, relation_type: Optional[str] = None,
event_type: Optional[str] = None, event_type: Optional[str] = None,
aggregation_key: Optional[str] = None, aggregation_key: Optional[str] = None,
@ -49,6 +50,7 @@ class RelationsWorkerStore(SQLBaseStore):
Args: Args:
event_id: Fetch events that relate to this event ID. event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
relation_type: Only fetch events with this relation type, if given. relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given. event_type: Only fetch events with this event type, if given.
aggregation_key: Only fetch events with this aggregation key, if given. aggregation_key: Only fetch events with this aggregation key, if given.
@ -63,8 +65,8 @@ class RelationsWorkerStore(SQLBaseStore):
the form `{"event_id": "..."}`. the form `{"event_id": "..."}`.
""" """
where_clause = ["relates_to_id = ?"] where_clause = ["relates_to_id = ?", "room_id = ?"]
where_args: List[Union[str, int]] = [event_id] where_args: List[Union[str, int]] = [event_id, room_id]
if relation_type is not None: if relation_type is not None:
where_clause.append("relation_type = ?") where_clause.append("relation_type = ?")
@ -199,6 +201,7 @@ class RelationsWorkerStore(SQLBaseStore):
async def get_aggregation_groups_for_event( async def get_aggregation_groups_for_event(
self, self,
event_id: str, event_id: str,
room_id: str,
event_type: Optional[str] = None, event_type: Optional[str] = None,
limit: int = 5, limit: int = 5,
direction: str = "b", direction: str = "b",
@ -213,6 +216,7 @@ class RelationsWorkerStore(SQLBaseStore):
Args: Args:
event_id: Fetch events that relate to this event ID. event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
event_type: Only fetch events with this event type, if given. event_type: Only fetch events with this event type, if given.
limit: Only fetch the `limit` groups. limit: Only fetch the `limit` groups.
direction: Whether to fetch the highest count first (`"b"`) or direction: Whether to fetch the highest count first (`"b"`) or
@ -225,8 +229,12 @@ class RelationsWorkerStore(SQLBaseStore):
`type`, `key` and `count` fields. `type`, `key` and `count` fields.
""" """
where_clause = ["relates_to_id = ?", "relation_type = ?"] where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"]
where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION] where_args: List[Union[str, int]] = [
event_id,
room_id,
RelationTypes.ANNOTATION,
]
if event_type: if event_type:
where_clause.append("type = ?") where_clause.append("type = ?")
@ -288,7 +296,9 @@ class RelationsWorkerStore(SQLBaseStore):
) )
@cached() @cached()
async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: async def get_applicable_edit(
self, event_id: str, room_id: str
) -> Optional[EventBase]:
"""Get the most recent edit (if any) that has happened for the given """Get the most recent edit (if any) that has happened for the given
event. event.
@ -296,6 +306,7 @@ class RelationsWorkerStore(SQLBaseStore):
Args: Args:
event_id: The original event ID event_id: The original event ID
room_id: The original event's room ID
Returns: Returns:
The most recent edit, if any. The most recent edit, if any.
@ -317,13 +328,14 @@ class RelationsWorkerStore(SQLBaseStore):
WHERE WHERE
relates_to_id = ? relates_to_id = ?
AND relation_type = ? AND relation_type = ?
AND edit.room_id = ?
AND edit.type = 'm.room.message' AND edit.type = 'm.room.message'
ORDER by edit.origin_server_ts DESC, edit.event_id DESC ORDER by edit.origin_server_ts DESC, edit.event_id DESC
LIMIT 1 LIMIT 1
""" """
def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]: def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
txn.execute(sql, (event_id, RelationTypes.REPLACE)) txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id))
row = txn.fetchone() row = txn.fetchone()
if row: if row:
return row[0] return row[0]
@ -340,13 +352,14 @@ class RelationsWorkerStore(SQLBaseStore):
@cached() @cached()
async def get_thread_summary( async def get_thread_summary(
self, event_id: str self, event_id: str, room_id: str
) -> Tuple[int, Optional[EventBase]]: ) -> Tuple[int, Optional[EventBase]]:
"""Get the number of threaded replies, the senders of those replies, and """Get the number of threaded replies, the senders of those replies, and
the latest reply (if any) for the given event. the latest reply (if any) for the given event.
Args: Args:
event_id: The original event ID event_id: Summarize the thread related to this event ID.
room_id: The room the event belongs to.
Returns: Returns:
The number of items in the thread and the most recent response, if any. The number of items in the thread and the most recent response, if any.
@ -363,12 +376,13 @@ class RelationsWorkerStore(SQLBaseStore):
INNER JOIN events USING (event_id) INNER JOIN events USING (event_id)
WHERE WHERE
relates_to_id = ? relates_to_id = ?
AND room_id = ?
AND relation_type = ? AND relation_type = ?
ORDER BY topological_ordering DESC, stream_ordering DESC ORDER BY topological_ordering DESC, stream_ordering DESC
LIMIT 1 LIMIT 1
""" """
txn.execute(sql, (event_id, RelationTypes.THREAD)) txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
row = txn.fetchone() row = txn.fetchone()
if row is None: if row is None:
return 0, None return 0, None
@ -378,11 +392,13 @@ class RelationsWorkerStore(SQLBaseStore):
sql = """ sql = """
SELECT COALESCE(COUNT(event_id), 0) SELECT COALESCE(COUNT(event_id), 0)
FROM event_relations FROM event_relations
INNER JOIN events USING (event_id)
WHERE WHERE
relates_to_id = ? relates_to_id = ?
AND room_id = ?
AND relation_type = ? AND relation_type = ?
""" """
txn.execute(sql, (event_id, RelationTypes.THREAD)) txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
count = txn.fetchone()[0] # type: ignore[index] count = txn.fetchone()[0] # type: ignore[index]
return count, latest_event_id return count, latest_event_id

View File

@ -16,6 +16,7 @@
import itertools import itertools
import urllib.parse import urllib.parse
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
from unittest.mock import patch
from synapse.api.constants import EventTypes, RelationTypes from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin from synapse.rest import admin
@ -23,6 +24,8 @@ from synapse.rest.client import login, register, relations, room, sync
from tests import unittest from tests import unittest
from tests.server import FakeChannel from tests.server import FakeChannel
from tests.test_utils import make_awaitable
from tests.test_utils.event_injection import inject_event
class RelationsTestCase(unittest.HomeserverTestCase): class RelationsTestCase(unittest.HomeserverTestCase):
@ -651,6 +654,118 @@ class RelationsTestCase(unittest.HomeserverTestCase):
}, },
) )
@unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_ignore_invalid_room(self):
"""Test that we ignore invalid relations over federation."""
# Create another room and send a message in it.
room2 = self.helper.create_room_as(self.user_id, tok=self.user_token)
res = self.helper.send(room2, body="Hi!", tok=self.user_token)
parent_id = res["event_id"]
# Disable the validation to pretend this came over federation.
with patch(
"synapse.handlers.message.EventCreationHandler._validate_event_relation",
new=lambda self, event: make_awaitable(None),
):
# Generate a various relations from a different room.
self.get_success(
inject_event(
self.hs,
room_id=self.room,
type="m.reaction",
sender=self.user_id,
content={
"m.relates_to": {
"rel_type": RelationTypes.ANNOTATION,
"event_id": parent_id,
"key": "A",
}
},
)
)
self.get_success(
inject_event(
self.hs,
room_id=self.room,
type="m.room.message",
sender=self.user_id,
content={
"body": "foo",
"msgtype": "m.text",
"m.relates_to": {
"rel_type": RelationTypes.REFERENCE,
"event_id": parent_id,
},
},
)
)
self.get_success(
inject_event(
self.hs,
room_id=self.room,
type="m.room.message",
sender=self.user_id,
content={
"body": "foo",
"msgtype": "m.text",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": parent_id,
},
},
)
)
self.get_success(
inject_event(
self.hs,
room_id=self.room,
type="m.room.message",
sender=self.user_id,
content={
"body": "foo",
"msgtype": "m.text",
"new_content": {
"body": "new content",
"msgtype": "m.text",
},
"m.relates_to": {
"rel_type": RelationTypes.REPLACE,
"event_id": parent_id,
},
},
)
)
# They should be ignored when fetching relations.
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/rooms/{room2}/relations/{parent_id}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEqual(channel.json_body["chunk"], [])
# And when fetching aggregations.
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/rooms/{room2}/aggregations/{parent_id}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
self.assertEqual(channel.json_body["chunk"], [])
# And for bundled aggregations.
channel = self.make_request(
"GET",
f"/rooms/{room2}/event/{parent_id}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
self.assertNotIn("m.relations", channel.json_body["unsigned"])
def test_edit(self): def test_edit(self):
"""Test that a simple edit works.""" """Test that a simple edit works."""