Add ephemeral messages support (MSC2228) (#6409)

Implement part [MSC2228](https://github.com/matrix-org/matrix-doc/pull/2228). The parts that differ are:

* the feature is hidden behind a configuration flag (`enable_ephemeral_messages`)
* self-destruction doesn't happen for state events
* only implement support for the `m.self_destruct_after` field (not the `m.self_destruct` one)
* doesn't send synthetic redactions to clients because for this specific case we consider the clients to be able to destroy an event themselves, instead we just censor it (by pruning its JSON) in the database
This commit is contained in:
Brendan Abolivier 2019-12-03 19:19:45 +00:00 committed by GitHub
parent 620f98b65b
commit 54dd5dc12b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 379 additions and 7 deletions

1
changelog.d/6409.feature Normal file
View File

@ -0,0 +1 @@
Add ephemeral messages support by partially implementing [MSC2228](https://github.com/matrix-org/matrix-doc/pull/2228).

View File

@ -147,3 +147,7 @@ class EventContentFields(object):
# Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326
LABELS = "org.matrix.labels"
# Timestamp to delete the event after
# cf https://github.com/matrix-org/matrix-doc/pull/2228
SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after"

View File

@ -490,6 +490,8 @@ class ServerConfig(Config):
"cleanup_extremities_with_dummy_events", True
)
self.enable_ephemeral_messages = config.get("enable_ephemeral_messages", False)
def has_tls_listener(self) -> bool:
return any(l["tls"] for l in self.listeners)

View File

@ -121,6 +121,7 @@ class FederationHandler(BaseHandler):
self.pusher_pool = hs.get_pusherpool()
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
self._message_handler = hs.get_message_handler()
self._server_notices_mxid = hs.config.server_notices_mxid
self.config = hs.config
self.http_client = hs.get_simple_http_client()
@ -141,6 +142,8 @@ class FederationHandler(BaseHandler):
self.third_party_event_rules = hs.get_third_party_event_rules()
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
@defer.inlineCallbacks
def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
""" Process a PDU received via a federation /send/ transaction, or
@ -2715,6 +2718,11 @@ class FederationHandler(BaseHandler):
event_and_contexts, backfilled=backfilled
)
if self._ephemeral_messages_enabled:
for (event, context) in event_and_contexts:
# If there's an expiry timestamp on the event, schedule its expiry.
self._message_handler.maybe_schedule_expiry(event)
if not backfilled: # Never notify for backfilled events
for event, _ in event_and_contexts:
yield self._notify_persisted_event(event, max_stream_id)

View File

@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
from six import iteritems, itervalues, string_types
@ -22,9 +23,16 @@ from canonicaljson import encode_canonical_json, json
from twisted.internet import defer
from twisted.internet.defer import succeed
from twisted.internet.interfaces import IDelayedCall
from synapse import event_auth
from synapse.api.constants import EventTypes, Membership, RelationTypes, UserTypes
from synapse.api.constants import (
EventContentFields,
EventTypes,
Membership,
RelationTypes,
UserTypes,
)
from synapse.api.errors import (
AuthError,
Codes,
@ -62,6 +70,17 @@ class MessageHandler(object):
self.storage = hs.get_storage()
self.state_store = self.storage.state
self._event_serializer = hs.get_event_client_serializer()
self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
self._is_worker_app = bool(hs.config.worker_app)
# The scheduled call to self._expire_event. None if no call is currently
# scheduled.
self._scheduled_expiry = None # type: Optional[IDelayedCall]
if not hs.config.worker_app:
run_as_background_process(
"_schedule_next_expiry", self._schedule_next_expiry
)
@defer.inlineCallbacks
def get_room_data(
@ -225,6 +244,100 @@ class MessageHandler(object):
for user_id, profile in iteritems(users_with_profile)
}
def maybe_schedule_expiry(self, event):
"""Schedule the expiry of an event if there's not already one scheduled,
or if the one running is for an event that will expire after the provided
timestamp.
This function needs to invalidate the event cache, which is only possible on
the master process, and therefore needs to be run on there.
Args:
event (EventBase): The event to schedule the expiry of.
"""
assert not self._is_worker_app
expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
if not isinstance(expiry_ts, int) or event.is_state():
return
# _schedule_expiry_for_event won't actually schedule anything if there's already
# a task scheduled for a timestamp that's sooner than the provided one.
self._schedule_expiry_for_event(event.event_id, expiry_ts)
@defer.inlineCallbacks
def _schedule_next_expiry(self):
"""Retrieve the ID and the expiry timestamp of the next event to be expired,
and schedule an expiry task for it.
If there's no event left to expire, set _expiry_scheduled to None so that a
future call to save_expiry_ts can schedule a new expiry task.
"""
# Try to get the expiry timestamp of the next event to expire.
res = yield self.store.get_next_event_to_expire()
if res:
event_id, expiry_ts = res
self._schedule_expiry_for_event(event_id, expiry_ts)
def _schedule_expiry_for_event(self, event_id, expiry_ts):
"""Schedule an expiry task for the provided event if there's not already one
scheduled at a timestamp that's sooner than the provided one.
Args:
event_id (str): The ID of the event to expire.
expiry_ts (int): The timestamp at which to expire the event.
"""
if self._scheduled_expiry:
# If the provided timestamp refers to a time before the scheduled time of the
# next expiry task, cancel that task and reschedule it for this timestamp.
next_scheduled_expiry_ts = self._scheduled_expiry.getTime() * 1000
if expiry_ts < next_scheduled_expiry_ts:
self._scheduled_expiry.cancel()
else:
return
# Figure out how many seconds we need to wait before expiring the event.
now_ms = self.clock.time_msec()
delay = (expiry_ts - now_ms) / 1000
# callLater doesn't support negative delays, so trim the delay to 0 if we're
# in that case.
if delay < 0:
delay = 0
logger.info("Scheduling expiry for event %s in %.3fs", event_id, delay)
self._scheduled_expiry = self.clock.call_later(
delay,
run_as_background_process,
"_expire_event",
self._expire_event,
event_id,
)
@defer.inlineCallbacks
def _expire_event(self, event_id):
"""Retrieve and expire an event that needs to be expired from the database.
If the event doesn't exist in the database, log it and delete the expiry date
from the database (so that we don't try to expire it again).
"""
assert self._ephemeral_events_enabled
self._scheduled_expiry = None
logger.info("Expiring event %s", event_id)
try:
# Expire the event if we know about it. This function also deletes the expiry
# date from the database in the same database transaction.
yield self.store.expire_event(event_id)
except Exception as e:
logger.error("Could not expire event %s: %r", event_id, e)
# Schedule the expiry of the next event to expire.
yield self._schedule_next_expiry()
# The duration (in ms) after which rooms should be removed
# `_rooms_to_exclude_from_dummy_event_insertion` (with the effect that we will try
@ -295,6 +408,10 @@ class EventCreationHandler(object):
5 * 60 * 1000,
)
self._message_handler = hs.get_message_handler()
self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
@defer.inlineCallbacks
def create_event(
self,
@ -877,6 +994,10 @@ class EventCreationHandler(object):
event, context=context
)
if self._ephemeral_events_enabled:
# If there's an expiry timestamp on the event, schedule its expiry.
self._message_handler.maybe_schedule_expiry(event)
yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
def _notify():

View File

@ -130,6 +130,8 @@ class EventsStore(
if self.hs.config.redaction_retention_period is not None:
hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
@defer.inlineCallbacks
def _read_forward_extremities(self):
def fetch(txn):
@ -940,6 +942,12 @@ class EventsStore(
txn, event.event_id, labels, event.room_id, event.depth
)
if self._ephemeral_messages_enabled:
# If there's an expiry timestamp on the event, store it.
expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
if isinstance(expiry_ts, int) and not event.is_state():
self._insert_event_expiry_txn(txn, event.event_id, expiry_ts)
# Insert into the room_memberships table.
self._store_room_members_txn(
txn,
@ -1101,12 +1109,7 @@ class EventsStore(
def _update_censor_txn(txn):
for redaction_id, event_id, pruned_json in updates:
if pruned_json:
self._simple_update_one_txn(
txn,
table="event_json",
keyvalues={"event_id": event_id},
updatevalues={"json": pruned_json},
)
self._censor_event_txn(txn, event_id, pruned_json)
self._simple_update_one_txn(
txn,
@ -1117,6 +1120,22 @@ class EventsStore(
yield self.runInteraction("_update_censor_txn", _update_censor_txn)
def _censor_event_txn(self, txn, event_id, pruned_json):
"""Censor an event by replacing its JSON in the event_json table with the
provided pruned JSON.
Args:
txn (LoggingTransaction): The database transaction.
event_id (str): The ID of the event to censor.
pruned_json (str): The pruned JSON
"""
self._simple_update_one_txn(
txn,
table="event_json",
keyvalues={"event_id": event_id},
updatevalues={"json": pruned_json},
)
@defer.inlineCallbacks
def count_daily_messages(self):
"""
@ -1957,6 +1976,101 @@ class EventsStore(
],
)
def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
"""Save the expiry timestamp associated with a given event ID.
Args:
txn (LoggingTransaction): The database transaction to use.
event_id (str): The event ID the expiry timestamp is associated with.
expiry_ts (int): The timestamp at which to expire (delete) the event.
"""
return self._simple_insert_txn(
txn=txn,
table="event_expiry",
values={"event_id": event_id, "expiry_ts": expiry_ts},
)
@defer.inlineCallbacks
def expire_event(self, event_id):
"""Retrieve and expire an event that has expired, and delete its associated
expiry timestamp. If the event can't be retrieved, delete its associated
timestamp so we don't try to expire it again in the future.
Args:
event_id (str): The ID of the event to delete.
"""
# Try to retrieve the event's content from the database or the event cache.
event = yield self.get_event(event_id)
def delete_expired_event_txn(txn):
# Delete the expiry timestamp associated with this event from the database.
self._delete_event_expiry_txn(txn, event_id)
if not event:
# If we can't find the event, log a warning and delete the expiry date
# from the database so that we don't try to expire it again in the
# future.
logger.warning(
"Can't expire event %s because we don't have it.", event_id
)
return
# Prune the event's dict then convert it to JSON.
pruned_json = encode_json(prune_event_dict(event.get_dict()))
# Update the event_json table to replace the event's JSON with the pruned
# JSON.
self._censor_event_txn(txn, event.event_id, pruned_json)
# We need to invalidate the event cache entry for this event because we
# changed its content in the database. We can't call
# self._invalidate_cache_and_stream because self.get_event_cache isn't of the
# right type.
txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
# Send that invalidation to replication so that other workers also invalidate
# the event cache.
self._send_invalidation_to_replication(
txn, "_get_event_cache", (event.event_id,)
)
yield self.runInteraction("delete_expired_event", delete_expired_event_txn)
def _delete_event_expiry_txn(self, txn, event_id):
"""Delete the expiry timestamp associated with an event ID without deleting the
actual event.
Args:
txn (LoggingTransaction): The transaction to use to perform the deletion.
event_id (str): The event ID to delete the associated expiry timestamp of.
"""
return self._simple_delete_txn(
txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
)
def get_next_event_to_expire(self):
"""Retrieve the entry with the lowest expiry timestamp in the event_expiry
table, or None if there's no more event to expire.
Returns: Deferred[Optional[Tuple[str, int]]]
A tuple containing the event ID as its first element and an expiry timestamp
as its second one, if there's at least one row in the event_expiry table.
None otherwise.
"""
def get_next_event_to_expire_txn(txn):
txn.execute(
"""
SELECT event_id, expiry_ts FROM event_expiry
ORDER BY expiry_ts ASC LIMIT 1
"""
)
return txn.fetchone()
return self.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
)
AllNewEventsResult = namedtuple(
"AllNewEventsResult",

View File

@ -0,0 +1,21 @@
/* Copyright 2019 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS event_expiry (
event_id TEXT PRIMARY KEY,
expiry_ts BIGINT NOT NULL
);
CREATE INDEX event_expiry_expiry_ts_idx ON event_expiry(expiry_ts);

View File

@ -0,0 +1,101 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.constants import EventContentFields, EventTypes
from synapse.rest import admin
from synapse.rest.client.v1 import room
from tests import unittest
class EphemeralMessageTestCase(unittest.HomeserverTestCase):
user_id = "@user:test"
servlets = [
admin.register_servlets,
room.register_servlets,
]
def make_homeserver(self, reactor, clock):
config = self.default_config()
config["enable_ephemeral_messages"] = True
self.hs = self.setup_test_homeserver(config=config)
return self.hs
def prepare(self, reactor, clock, homeserver):
self.room_id = self.helper.create_room_as(self.user_id)
def test_message_expiry_no_delay(self):
"""Tests that sending a message sent with a m.self_destruct_after field set to the
past results in that event being deleted right away.
"""
# Send a message in the room that has expired. From here, the reactor clock is
# at 200ms, so 0 is in the past, and even if that wasn't the case and the clock
# is at 0ms the code path is the same if the event's expiry timestamp is the
# current timestamp.
res = self.helper.send_event(
room_id=self.room_id,
type=EventTypes.Message,
content={
"msgtype": "m.text",
"body": "hello",
EventContentFields.SELF_DESTRUCT_AFTER: 0,
},
)
event_id = res["event_id"]
# Check that we can't retrieve the content of the event.
event_content = self.get_event(self.room_id, event_id)["content"]
self.assertFalse(bool(event_content), event_content)
def test_message_expiry_delay(self):
"""Tests that sending a message with a m.self_destruct_after field set to the
future results in that event not being deleted right away, but advancing the
clock to after that expiry timestamp causes the event to be deleted.
"""
# Send a message in the room that'll expire in 1s.
res = self.helper.send_event(
room_id=self.room_id,
type=EventTypes.Message,
content={
"msgtype": "m.text",
"body": "hello",
EventContentFields.SELF_DESTRUCT_AFTER: self.clock.time_msec() + 1000,
},
)
event_id = res["event_id"]
# Check that we can retrieve the content of the event before it has expired.
event_content = self.get_event(self.room_id, event_id)["content"]
self.assertTrue(bool(event_content), event_content)
# Advance the clock to after the deletion.
self.reactor.advance(1)
# Check that we can't retrieve the content of the event anymore.
event_content = self.get_event(self.room_id, event_id)["content"]
self.assertFalse(bool(event_content), event_content)
def get_event(self, room_id, event_id, expected_code=200):
url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
request, channel = self.make_request("GET", url)
self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
return channel.json_body