Implement MSC3912: Relation-based redactions (#14260)

Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com>
This commit is contained in:
Brendan Abolivier 2022-11-03 16:21:31 +00:00 committed by GitHub
parent e5cd278f3f
commit 86c5a710d8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 486 additions and 28 deletions

View file

@ -125,6 +125,8 @@ class EventTypes:
MSC2716_BATCH: Final = "org.matrix.msc2716.batch"
MSC2716_MARKER: Final = "org.matrix.msc2716.marker"
Reaction: Final = "m.reaction"
class ToDeviceEventTypes:
RoomKeyRequest: Final = "m.room_key_request"

View file

@ -128,3 +128,6 @@ class ExperimentalConfig(Config):
self.msc3886_endpoint: Optional[str] = experimental.get(
"msc3886_endpoint", None
)
# MSC3912: Relation-based redactions.
self.msc3912_enabled: bool = experimental.get("msc3912_enabled", False)

View file

@ -877,6 +877,36 @@ class EventCreationHandler:
return prev_event
return None
async def get_event_from_transaction(
self,
requester: Requester,
txn_id: str,
room_id: str,
) -> Optional[EventBase]:
"""For the given transaction ID and room ID, check if there is a matching event.
If so, fetch it and return it.
Args:
requester: The requester making the request in the context of which we want
to fetch the event.
txn_id: The transaction ID.
room_id: The room ID.
Returns:
An event if one could be found, None otherwise.
"""
if requester.access_token_id:
existing_event_id = await self.store.get_event_id_from_transaction_id(
room_id,
requester.user.to_string(),
requester.access_token_id,
txn_id,
)
if existing_event_id:
return await self.store.get_event(existing_event_id)
return None
async def create_and_send_nonmember_event(
self,
requester: Requester,
@ -956,18 +986,17 @@ class EventCreationHandler:
# extremities to pile up, which in turn leads to state resolution
# taking longer.
async with self.limiter.queue(event_dict["room_id"]):
if txn_id and requester.access_token_id:
existing_event_id = await self.store.get_event_id_from_transaction_id(
event_dict["room_id"],
requester.user.to_string(),
requester.access_token_id,
txn_id,
if txn_id:
event = await self.get_event_from_transaction(
requester, txn_id, event_dict["room_id"]
)
if existing_event_id:
event = await self.store.get_event(existing_event_id)
if event:
# we know it was persisted, so must have a stream ordering
assert event.internal_metadata.stream_ordering
return event, event.internal_metadata.stream_ordering
return (
event,
event.internal_metadata.stream_ordering,
)
event, context = await self.create_event(
requester,

View file

@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tup
import attr
from synapse.api.constants import RelationTypes
from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase, relation_from_event
from synapse.logging.opentracing import trace
@ -75,6 +75,7 @@ class RelationsHandler:
self._clock = hs.get_clock()
self._event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
self._event_creation_handler = hs.get_event_creation_handler()
async def get_relations(
self,
@ -205,6 +206,59 @@ class RelationsHandler:
return related_events, next_token
async def redact_events_related_to(
self,
requester: Requester,
event_id: str,
initial_redaction_event: EventBase,
relation_types: List[str],
) -> None:
"""Redacts all events related to the given event ID with one of the given
relation types.
This method is expected to be called when redacting the event referred to by
the given event ID.
If an event cannot be redacted (e.g. because of insufficient permissions), log
the error and try to redact the next one.
Args:
requester: The requester to redact events on behalf of.
event_id: The event IDs to look and redact relations of.
initial_redaction_event: The redaction for the event referred to by
event_id.
relation_types: The types of relations to look for.
Raises:
ShadowBanError if the requester is shadow-banned
"""
related_event_ids = (
await self._main_store.get_all_relations_for_event_with_types(
event_id, relation_types
)
)
for related_event_id in related_event_ids:
try:
await self._event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Redaction,
"content": initial_redaction_event.content,
"room_id": initial_redaction_event.room_id,
"sender": requester.user.to_string(),
"redacts": related_event_id,
},
ratelimit=False,
)
except SynapseError as e:
logger.warning(
"Failed to redact event %s (related to event %s): %s",
related_event_id,
event_id,
e.msg,
)
async def get_annotations_for_event(
self,
event_id: str,

View file

@ -52,6 +52,7 @@ from synapse.http.servlet import (
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import set_tag
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.client._base import client_patterns
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.storage.state import StateFilter
@ -1029,6 +1030,8 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
self._relation_handler = hs.get_relations_handler()
self._msc3912_enabled = hs.config.experimental.msc3912_enabled
def register(self, http_server: HttpServer) -> None:
PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
@ -1045,20 +1048,46 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
content = parse_json_object_from_request(request)
try:
(
event,
_,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Redaction,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
"redacts": event_id,
},
txn_id=txn_id,
)
with_relations = None
if self._msc3912_enabled and "org.matrix.msc3912.with_relations" in content:
with_relations = content["org.matrix.msc3912.with_relations"]
del content["org.matrix.msc3912.with_relations"]
# Check if there's an existing event for this transaction now (even though
# create_and_send_nonmember_event also does it) because, if there's one,
# then we want to skip the call to redact_events_related_to.
event = None
if txn_id:
event = await self.event_creation_handler.get_event_from_transaction(
requester, txn_id, room_id
)
if event is None:
(
event,
_,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Redaction,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
"redacts": event_id,
},
txn_id=txn_id,
)
if with_relations:
run_as_background_process(
"redact_related_events",
self._relation_handler.redact_events_related_to,
requester=requester,
event_id=event_id,
initial_redaction_event=event,
relation_types=with_relations,
)
event_id = event.event_id
except ShadowBanError:
event_id = "$" + random_string(43)

View file

@ -119,6 +119,8 @@ class VersionsRestServlet(RestServlet):
# Adds support for simple HTTP rendezvous as per MSC3886
"org.matrix.msc3886": self.config.experimental.msc3886_endpoint
is not None,
# Adds support for relation-based redactions as per MSC3912.
"org.matrix.msc3912": self.config.experimental.msc3912_enabled,
},
},
)

View file

@ -295,6 +295,42 @@ class RelationsWorkerStore(SQLBaseStore):
"get_recent_references_for_event", _get_recent_references_for_event_txn
)
async def get_all_relations_for_event_with_types(
self,
event_id: str,
relation_types: List[str],
) -> List[str]:
"""Get the event IDs of all events that have a relation to the given event with
one of the given relation types.
Args:
event_id: The event for which to look for related events.
relation_types: The types of relations to look for.
Returns:
A list of the IDs of the events that relate to the given event with one of
the given relation types.
"""
def get_all_relation_ids_for_event_with_types_txn(
txn: LoggingTransaction,
) -> List[str]:
rows = self.db_pool.simple_select_many_txn(
txn=txn,
table="event_relations",
column="relation_type",
iterable=relation_types,
keyvalues={"relates_to_id": event_id},
retcols=["event_id"],
)
return [row["event_id"] for row in rows]
return await self.db_pool.runInteraction(
desc="get_all_relation_ids_for_event_with_types",
func=get_all_relation_ids_for_event_with_types_txn,
)
async def event_includes_relation(self, event_id: str) -> bool:
"""Check if the given event relates to another event.