Implement MSC3912: Relation-based redactions (#14260)

Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com>
This commit is contained in:
Brendan Abolivier 2022-11-03 16:21:31 +00:00 committed by GitHub
parent e5cd278f3f
commit 86c5a710d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 486 additions and 28 deletions

View File

@ -0,0 +1 @@
Add experimental support for [MSC3912](https://github.com/matrix-org/matrix-spec-proposals/pull/3912): Relation-based redactions.

View File

@ -125,6 +125,8 @@ class EventTypes:
MSC2716_BATCH: Final = "org.matrix.msc2716.batch" MSC2716_BATCH: Final = "org.matrix.msc2716.batch"
MSC2716_MARKER: Final = "org.matrix.msc2716.marker" MSC2716_MARKER: Final = "org.matrix.msc2716.marker"
Reaction: Final = "m.reaction"
class ToDeviceEventTypes: class ToDeviceEventTypes:
RoomKeyRequest: Final = "m.room_key_request" RoomKeyRequest: Final = "m.room_key_request"

View File

@ -128,3 +128,6 @@ class ExperimentalConfig(Config):
self.msc3886_endpoint: Optional[str] = experimental.get( self.msc3886_endpoint: Optional[str] = experimental.get(
"msc3886_endpoint", None "msc3886_endpoint", None
) )
# MSC3912: Relation-based redactions.
self.msc3912_enabled: bool = experimental.get("msc3912_enabled", False)

View File

@ -877,6 +877,36 @@ class EventCreationHandler:
return prev_event return prev_event
return None return None
async def get_event_from_transaction(
self,
requester: Requester,
txn_id: str,
room_id: str,
) -> Optional[EventBase]:
"""For the given transaction ID and room ID, check if there is a matching event.
If so, fetch it and return it.
Args:
requester: The requester making the request in the context of which we want
to fetch the event.
txn_id: The transaction ID.
room_id: The room ID.
Returns:
An event if one could be found, None otherwise.
"""
if requester.access_token_id:
existing_event_id = await self.store.get_event_id_from_transaction_id(
room_id,
requester.user.to_string(),
requester.access_token_id,
txn_id,
)
if existing_event_id:
return await self.store.get_event(existing_event_id)
return None
async def create_and_send_nonmember_event( async def create_and_send_nonmember_event(
self, self,
requester: Requester, requester: Requester,
@ -956,18 +986,17 @@ class EventCreationHandler:
# extremities to pile up, which in turn leads to state resolution # extremities to pile up, which in turn leads to state resolution
# taking longer. # taking longer.
async with self.limiter.queue(event_dict["room_id"]): async with self.limiter.queue(event_dict["room_id"]):
if txn_id and requester.access_token_id: if txn_id:
existing_event_id = await self.store.get_event_id_from_transaction_id( event = await self.get_event_from_transaction(
event_dict["room_id"], requester, txn_id, event_dict["room_id"]
requester.user.to_string(),
requester.access_token_id,
txn_id,
) )
if existing_event_id: if event:
event = await self.store.get_event(existing_event_id)
# we know it was persisted, so must have a stream ordering # we know it was persisted, so must have a stream ordering
assert event.internal_metadata.stream_ordering assert event.internal_metadata.stream_ordering
return event, event.internal_metadata.stream_ordering return (
event,
event.internal_metadata.stream_ordering,
)
event, context = await self.create_event( event, context = await self.create_event(
requester, requester,

View File

@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tup
import attr import attr
from synapse.api.constants import RelationTypes from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.events import EventBase, relation_from_event from synapse.events import EventBase, relation_from_event
from synapse.logging.opentracing import trace from synapse.logging.opentracing import trace
@ -75,6 +75,7 @@ class RelationsHandler:
self._clock = hs.get_clock() self._clock = hs.get_clock()
self._event_handler = hs.get_event_handler() self._event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self._event_creation_handler = hs.get_event_creation_handler()
async def get_relations( async def get_relations(
self, self,
@ -205,6 +206,59 @@ class RelationsHandler:
return related_events, next_token return related_events, next_token
async def redact_events_related_to(
self,
requester: Requester,
event_id: str,
initial_redaction_event: EventBase,
relation_types: List[str],
) -> None:
"""Redacts all events related to the given event ID with one of the given
relation types.
This method is expected to be called when redacting the event referred to by
the given event ID.
If an event cannot be redacted (e.g. because of insufficient permissions), log
the error and try to redact the next one.
Args:
requester: The requester to redact events on behalf of.
event_id: The event IDs to look and redact relations of.
initial_redaction_event: The redaction for the event referred to by
event_id.
relation_types: The types of relations to look for.
Raises:
ShadowBanError if the requester is shadow-banned
"""
related_event_ids = (
await self._main_store.get_all_relations_for_event_with_types(
event_id, relation_types
)
)
for related_event_id in related_event_ids:
try:
await self._event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Redaction,
"content": initial_redaction_event.content,
"room_id": initial_redaction_event.room_id,
"sender": requester.user.to_string(),
"redacts": related_event_id,
},
ratelimit=False,
)
except SynapseError as e:
logger.warning(
"Failed to redact event %s (related to event %s): %s",
related_event_id,
event_id,
e.msg,
)
async def get_annotations_for_event( async def get_annotations_for_event(
self, self,
event_id: str, event_id: str,

View File

@ -52,6 +52,7 @@ from synapse.http.servlet import (
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import set_tag from synapse.logging.opentracing import set_tag
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.client._base import client_patterns from synapse.rest.client._base import client_patterns
from synapse.rest.client.transactions import HttpTransactionCache from synapse.rest.client.transactions import HttpTransactionCache
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
@ -1029,6 +1030,8 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
super().__init__(hs) super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._relation_handler = hs.get_relations_handler()
self._msc3912_enabled = hs.config.experimental.msc3912_enabled
def register(self, http_server: HttpServer) -> None: def register(self, http_server: HttpServer) -> None:
PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)" PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
@ -1045,6 +1048,21 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
try: try:
with_relations = None
if self._msc3912_enabled and "org.matrix.msc3912.with_relations" in content:
with_relations = content["org.matrix.msc3912.with_relations"]
del content["org.matrix.msc3912.with_relations"]
# Check if there's an existing event for this transaction now (even though
# create_and_send_nonmember_event also does it) because, if there's one,
# then we want to skip the call to redact_events_related_to.
event = None
if txn_id:
event = await self.event_creation_handler.get_event_from_transaction(
requester, txn_id, room_id
)
if event is None:
( (
event, event,
_, _,
@ -1059,6 +1077,17 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
}, },
txn_id=txn_id, txn_id=txn_id,
) )
if with_relations:
run_as_background_process(
"redact_related_events",
self._relation_handler.redact_events_related_to,
requester=requester,
event_id=event_id,
initial_redaction_event=event,
relation_types=with_relations,
)
event_id = event.event_id event_id = event.event_id
except ShadowBanError: except ShadowBanError:
event_id = "$" + random_string(43) event_id = "$" + random_string(43)

View File

@ -119,6 +119,8 @@ class VersionsRestServlet(RestServlet):
# Adds support for simple HTTP rendezvous as per MSC3886 # Adds support for simple HTTP rendezvous as per MSC3886
"org.matrix.msc3886": self.config.experimental.msc3886_endpoint "org.matrix.msc3886": self.config.experimental.msc3886_endpoint
is not None, is not None,
# Adds support for relation-based redactions as per MSC3912.
"org.matrix.msc3912": self.config.experimental.msc3912_enabled,
}, },
}, },
) )

View File

@ -295,6 +295,42 @@ class RelationsWorkerStore(SQLBaseStore):
"get_recent_references_for_event", _get_recent_references_for_event_txn "get_recent_references_for_event", _get_recent_references_for_event_txn
) )
async def get_all_relations_for_event_with_types(
self,
event_id: str,
relation_types: List[str],
) -> List[str]:
"""Get the event IDs of all events that have a relation to the given event with
one of the given relation types.
Args:
event_id: The event for which to look for related events.
relation_types: The types of relations to look for.
Returns:
A list of the IDs of the events that relate to the given event with one of
the given relation types.
"""
def get_all_relation_ids_for_event_with_types_txn(
txn: LoggingTransaction,
) -> List[str]:
rows = self.db_pool.simple_select_many_txn(
txn=txn,
table="event_relations",
column="relation_type",
iterable=relation_types,
keyvalues={"relates_to_id": event_id},
retcols=["event_id"],
)
return [row["event_id"] for row in rows]
return await self.db_pool.runInteraction(
desc="get_all_relation_ids_for_event_with_types",
func=get_all_relation_ids_for_event_with_types_txn,
)
async def event_includes_relation(self, event_id: str) -> bool: async def event_includes_relation(self, event_id: str) -> bool:
"""Check if the given event relates to another event. """Check if the given event relates to another event.

View File

@ -11,17 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List from typing import List, Optional
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room, sync from synapse.rest.client import login, room, sync
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase, override_config
class RedactionsTestCase(HomeserverTestCase): class RedactionsTestCase(HomeserverTestCase):
@ -67,7 +68,12 @@ class RedactionsTestCase(HomeserverTestCase):
) )
def _redact_event( def _redact_event(
self, access_token: str, room_id: str, event_id: str, expect_code: int = 200 self,
access_token: str,
room_id: str,
event_id: str,
expect_code: int = 200,
with_relations: Optional[List[str]] = None,
) -> JsonDict: ) -> JsonDict:
"""Helper function to send a redaction event. """Helper function to send a redaction event.
@ -75,7 +81,13 @@ class RedactionsTestCase(HomeserverTestCase):
""" """
path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id) path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id)
channel = self.make_request("POST", path, content={}, access_token=access_token) request_content = {}
if with_relations:
request_content["org.matrix.msc3912.with_relations"] = with_relations
channel = self.make_request(
"POST", path, request_content, access_token=access_token
)
self.assertEqual(channel.code, expect_code) self.assertEqual(channel.code, expect_code)
return channel.json_body return channel.json_body
@ -201,3 +213,256 @@ class RedactionsTestCase(HomeserverTestCase):
# These should all succeed, even though this would be denied by # These should all succeed, even though this would be denied by
# the standard message ratelimiter # the standard message ratelimiter
self._redact_event(self.mod_access_token, self.room_id, msg_id) self._redact_event(self.mod_access_token, self.room_id, msg_id)
@override_config({"experimental_features": {"msc3912_enabled": True}})
def test_redact_relations(self) -> None:
"""Tests that we can redact the relations of an event at the same time as the
event itself.
"""
# Send a root event.
res = self.helper.send_event(
room_id=self.room_id,
type=EventTypes.Message,
content={"msgtype": "m.text", "body": "hello"},
tok=self.mod_access_token,
)
root_event_id = res["event_id"]
# Send an edit to this root event.
res = self.helper.send_event(
room_id=self.room_id,
type=EventTypes.Message,
content={
"body": " * hello world",
"m.new_content": {
"body": "hello world",
"msgtype": "m.text",
},
"m.relates_to": {
"event_id": root_event_id,
"rel_type": RelationTypes.REPLACE,
},
"msgtype": "m.text",
},
tok=self.mod_access_token,
)
edit_event_id = res["event_id"]
# Also send a threaded message whose root is the same as the edit's.
res = self.helper.send_event(
room_id=self.room_id,
type=EventTypes.Message,
content={
"msgtype": "m.text",
"body": "message 1",
"m.relates_to": {
"event_id": root_event_id,
"rel_type": RelationTypes.THREAD,
},
},
tok=self.mod_access_token,
)
threaded_event_id = res["event_id"]
# Also send a reaction, again with the same root.
res = self.helper.send_event(
room_id=self.room_id,
type=EventTypes.Reaction,
content={
"m.relates_to": {
"rel_type": RelationTypes.ANNOTATION,
"event_id": root_event_id,
"key": "👍",
}
},
tok=self.mod_access_token,
)
reaction_event_id = res["event_id"]
# Redact the root event, specifying that we also want to delete events that
# relate to it with m.replace.
self._redact_event(
self.mod_access_token,
self.room_id,
root_event_id,
with_relations=[
RelationTypes.REPLACE,
RelationTypes.THREAD,
],
)
# Check that the root event got redacted.
event_dict = self.helper.get_event(
self.room_id, root_event_id, self.mod_access_token
)
self.assertIn("redacted_because", event_dict, event_dict)
# Check that the edit got redacted.
event_dict = self.helper.get_event(
self.room_id, edit_event_id, self.mod_access_token
)
self.assertIn("redacted_because", event_dict, event_dict)
# Check that the threaded message got redacted.
event_dict = self.helper.get_event(
self.room_id, threaded_event_id, self.mod_access_token
)
self.assertIn("redacted_because", event_dict, event_dict)
# Check that the reaction did not get redacted.
event_dict = self.helper.get_event(
self.room_id, reaction_event_id, self.mod_access_token
)
self.assertNotIn("redacted_because", event_dict, event_dict)
@override_config({"experimental_features": {"msc3912_enabled": True}})
def test_redact_relations_no_perms(self) -> None:
"""Tests that, when redacting a message along with its relations, if not all
the related messages can be redacted because of insufficient permissions, the
server still redacts all the ones that can be.
"""
# Send a root event.
res = self.helper.send_event(
room_id=self.room_id,
type=EventTypes.Message,
content={
"msgtype": "m.text",
"body": "root",
},
tok=self.other_access_token,
)
root_event_id = res["event_id"]
# Send a first threaded message, this one from the moderator. We do this for the
# first message with the m.thread relation (and not the last one) to ensure
# that, when the server fails to redact it, it doesn't stop there, and it
# instead goes on to redact the other one.
res = self.helper.send_event(
room_id=self.room_id,
type=EventTypes.Message,
content={
"msgtype": "m.text",
"body": "message 1",
"m.relates_to": {
"event_id": root_event_id,
"rel_type": RelationTypes.THREAD,
},
},
tok=self.mod_access_token,
)
first_threaded_event_id = res["event_id"]
# Send a second threaded message, this time from the user who'll perform the
# redaction.
res = self.helper.send_event(
room_id=self.room_id,
type=EventTypes.Message,
content={
"msgtype": "m.text",
"body": "message 2",
"m.relates_to": {
"event_id": root_event_id,
"rel_type": RelationTypes.THREAD,
},
},
tok=self.other_access_token,
)
second_threaded_event_id = res["event_id"]
# Redact the thread's root, and request that all threaded messages are also
# redacted. Send that request from the non-mod user, so that the first threaded
# event cannot be redacted.
self._redact_event(
self.other_access_token,
self.room_id,
root_event_id,
with_relations=[RelationTypes.THREAD],
)
# Check that the thread root got redacted.
event_dict = self.helper.get_event(
self.room_id, root_event_id, self.other_access_token
)
self.assertIn("redacted_because", event_dict, event_dict)
# Check that the last message in the thread got redacted, despite failing to
# redact the one before it.
event_dict = self.helper.get_event(
self.room_id, second_threaded_event_id, self.other_access_token
)
self.assertIn("redacted_because", event_dict, event_dict)
# Check that the message that was sent into the tread by the mod user is not
# redacted.
event_dict = self.helper.get_event(
self.room_id, first_threaded_event_id, self.other_access_token
)
self.assertIn("body", event_dict["content"], event_dict)
self.assertEqual("message 1", event_dict["content"]["body"])
@override_config({"experimental_features": {"msc3912_enabled": True}})
def test_redact_relations_txn_id_reuse(self) -> None:
"""Tests that redacting a message using a transaction ID, then reusing the same
transaction ID but providing an additional list of relations to redact, is
effectively a no-op.
"""
# Send a root event.
res = self.helper.send_event(
room_id=self.room_id,
type=EventTypes.Message,
content={
"msgtype": "m.text",
"body": "root",
},
tok=self.mod_access_token,
)
root_event_id = res["event_id"]
# Send a first threaded message.
res = self.helper.send_event(
room_id=self.room_id,
type=EventTypes.Message,
content={
"msgtype": "m.text",
"body": "I'm in a thread!",
"m.relates_to": {
"event_id": root_event_id,
"rel_type": RelationTypes.THREAD,
},
},
tok=self.mod_access_token,
)
threaded_event_id = res["event_id"]
# Send a first redaction request which redacts only the root event.
channel = self.make_request(
method="PUT",
path=f"/rooms/{self.room_id}/redact/{root_event_id}/foo",
content={},
access_token=self.mod_access_token,
)
self.assertEqual(channel.code, 200)
# Send a second redaction request which redacts the root event as well as
# threaded messages.
channel = self.make_request(
method="PUT",
path=f"/rooms/{self.room_id}/redact/{root_event_id}/foo",
content={"org.matrix.msc3912.with_relations": [RelationTypes.THREAD]},
access_token=self.mod_access_token,
)
self.assertEqual(channel.code, 200)
# Check that the root event got redacted.
event_dict = self.helper.get_event(
self.room_id, root_event_id, self.mod_access_token
)
self.assertIn("redacted_because", event_dict)
# Check that the threaded message didn't get redacted (since that wasn't part of
# the original redaction).
event_dict = self.helper.get_event(
self.room_id, threaded_event_id, self.mod_access_token
)
self.assertIn("body", event_dict["content"], event_dict)
self.assertEqual("I'm in a thread!", event_dict["content"]["body"])

View File

@ -410,6 +410,43 @@ class RestHelper:
return channel.json_body return channel.json_body
def get_event(
self,
room_id: str,
event_id: str,
tok: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
) -> JsonDict:
"""Request a specific event from the server.
Args:
room_id: the room in which the event was sent.
event_id: the event's ID.
tok: the token to request the event with.
expect_code: the expected HTTP status for the response.
Returns:
The event as a dict.
"""
path = f"/_matrix/client/v3/rooms/{room_id}/event/{event_id}"
if tok:
path = path + f"?access_token={tok}"
channel = make_request(
self.hs.get_reactor(),
self.site,
"GET",
path,
)
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
expect_code,
channel.code,
channel.result["body"],
)
return channel.json_body
def _read_write_state( def _read_write_state(
self, self,
room_id: str, room_id: str,