Fix message duplication if something goes wrong after persisting the event (#8476)

Should fix #3365.
This commit is contained in:
Erik Johnston 2020-10-13 12:07:56 +01:00 committed by GitHub
parent a9a8f29729
commit b2486f6656
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 481 additions and 32 deletions

View file

@ -361,6 +361,8 @@ class PersistEventsStore:
self._store_event_txn(txn, events_and_contexts=events_and_contexts)
self._persist_transaction_ids_txn(txn, events_and_contexts)
# Insert into event_to_state_groups.
self._store_event_state_mappings_txn(txn, events_and_contexts)
@ -405,6 +407,35 @@ class PersistEventsStore:
# room_memberships, where applicable.
self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
def _persist_transaction_ids_txn(
self,
txn: LoggingTransaction,
events_and_contexts: List[Tuple[EventBase, EventContext]],
):
"""Persist the mapping from transaction IDs to event IDs (if defined).
"""
to_insert = []
for event, _ in events_and_contexts:
token_id = getattr(event.internal_metadata, "token_id", None)
txn_id = getattr(event.internal_metadata, "txn_id", None)
if token_id and txn_id:
to_insert.append(
{
"event_id": event.event_id,
"room_id": event.room_id,
"user_id": event.sender,
"token_id": token_id,
"txn_id": txn_id,
"inserted_ts": self._clock.time_msec(),
}
)
if to_insert:
self.db_pool.simple_insert_many_txn(
txn, table="event_txn_id", values=to_insert,
)
def _update_current_state_txn(
self,
txn: LoggingTransaction,

View file

@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
import threading
@ -137,6 +136,15 @@ class EventsWorkerStore(SQLBaseStore):
db_conn, "events", "stream_ordering", step=-1
)
if not hs.config.worker.worker_app:
# We periodically clean out old transaction ID mappings
self._clock.looping_call(
run_as_background_process,
5 * 60 * 1000,
"_cleanup_old_transaction_ids",
self._cleanup_old_transaction_ids,
)
self._get_event_cache = Cache(
"*getEvent*",
keylen=3,
@ -1308,3 +1316,76 @@ class EventsWorkerStore(SQLBaseStore):
return await self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
)
async def get_event_id_from_transaction_id(
self, room_id: str, user_id: str, token_id: int, txn_id: str
) -> Optional[str]:
"""Look up if we have already persisted an event for the transaction ID,
returning the event ID if so.
"""
return await self.db_pool.simple_select_one_onecol(
table="event_txn_id",
keyvalues={
"room_id": room_id,
"user_id": user_id,
"token_id": token_id,
"txn_id": txn_id,
},
retcol="event_id",
allow_none=True,
desc="get_event_id_from_transaction_id",
)
async def get_already_persisted_events(
self, events: Iterable[EventBase]
) -> Dict[str, str]:
"""Look up if we have already persisted an event for the transaction ID,
returning a mapping from event ID in the given list to the event ID of
an existing event.
Also checks if there are duplicates in the given events, if there are
will map duplicates to the *first* event.
"""
mapping = {}
txn_id_to_event = {} # type: Dict[Tuple[str, int, str], str]
for event in events:
token_id = getattr(event.internal_metadata, "token_id", None)
txn_id = getattr(event.internal_metadata, "txn_id", None)
if token_id and txn_id:
# Check if this is a duplicate of an event in the given events.
existing = txn_id_to_event.get((event.room_id, token_id, txn_id))
if existing:
mapping[event.event_id] = existing
continue
# Check if this is a duplicate of an event we've already
# persisted.
existing = await self.get_event_id_from_transaction_id(
event.room_id, event.sender, token_id, txn_id
)
if existing:
mapping[event.event_id] = existing
txn_id_to_event[(event.room_id, token_id, txn_id)] = existing
else:
txn_id_to_event[(event.room_id, token_id, txn_id)] = event.event_id
return mapping
async def _cleanup_old_transaction_ids(self):
"""Cleans out transaction id mappings older than 24hrs.
"""
def _cleanup_old_transaction_ids_txn(txn):
sql = """
DELETE FROM event_txn_id
WHERE inserted_ts < ?
"""
one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
txn.execute(sql, (one_day_ago,))
return await self.db_pool.runInteraction(
"_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn,
)

View file

@ -1003,7 +1003,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
token: str,
device_id: Optional[str],
valid_until_ms: Optional[int],
) -> None:
) -> int:
"""Adds an access token for the given user.
Args:
@ -1013,6 +1013,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
valid_until_ms: when the token is valid until. None for no expiry.
Raises:
StoreError if there was a problem adding this.
Returns:
The token ID
"""
next_id = self._access_tokens_id_gen.get_next()
@ -1028,6 +1030,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="add_access_token_to_user",
)
return next_id
def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn, "access_tokens", {"token": token}, "device_id"

View file

@ -0,0 +1,40 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- A map of recent events persisted with transaction IDs. Used to deduplicate
-- send event requests with the same transaction ID.
--
-- Note: transaction IDs are scoped to the room ID/user ID/access token that was
-- used to make the request.
--
-- Note: The foreign key constraints are ON DELETE CASCADE, as if we delete the
-- events or access token we don't want to try and de-duplicate the event.
CREATE TABLE IF NOT EXISTS event_txn_id (
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
user_id TEXT NOT NULL,
token_id BIGINT NOT NULL,
txn_id TEXT NOT NULL,
inserted_ts BIGINT NOT NULL,
FOREIGN KEY (event_id)
REFERENCES events (event_id) ON DELETE CASCADE,
FOREIGN KEY (token_id)
REFERENCES access_tokens (id) ON DELETE CASCADE
);
CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_event_id ON event_txn_id(event_id);
CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_txn_id ON event_txn_id(room_id, user_id, token_id, txn_id);
CREATE INDEX IF NOT EXISTS event_txn_id_ts ON event_txn_id(inserted_ts);

View file

@ -96,7 +96,9 @@ class _EventPeristenceQueue:
Returns:
defer.Deferred: a deferred which will resolve once the events are
persisted. Runs its callbacks *without* a logcontext.
persisted. Runs its callbacks *without* a logcontext. The result
is the same as that returned by the callback passed to
`handle_queue`.
"""
queue = self._event_persist_queues.setdefault(room_id, deque())
if queue:
@ -199,7 +201,7 @@ class EventsPersistenceStorage:
self,
events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
backfilled: bool = False,
) -> RoomStreamToken:
) -> Tuple[List[EventBase], RoomStreamToken]:
"""
Write events to the database
Args:
@ -209,7 +211,11 @@ class EventsPersistenceStorage:
which might update the current state etc.
Returns:
the stream ordering of the latest persisted event
List of events persisted, the current position room stream position.
The list of events persisted may not be the same as those passed in
if they were deduplicated due to an event already existing that
matched the transcation ID; the existing event is returned in such
a case.
"""
partitioned = {} # type: Dict[str, List[Tuple[EventBase, EventContext]]]
for event, ctx in events_and_contexts:
@ -225,19 +231,41 @@ class EventsPersistenceStorage:
for room_id in partitioned:
self._maybe_start_persisting(room_id)
await make_deferred_yieldable(
# Each deferred returns a map from event ID to existing event ID if the
# event was deduplicated. (The dict may also include other entries if
# the event was persisted in a batch with other events).
#
# Since we use `defer.gatherResults` we need to merge the returned list
# of dicts into one.
ret_vals = await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
replaced_events = {}
for d in ret_vals:
replaced_events.update(d)
return self.main_store.get_room_max_token()
events = []
for event, _ in events_and_contexts:
existing_event_id = replaced_events.get(event.event_id)
if existing_event_id:
events.append(await self.main_store.get_event(existing_event_id))
else:
events.append(event)
return (
events,
self.main_store.get_room_max_token(),
)
async def persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
) -> Tuple[PersistedEventPosition, RoomStreamToken]:
) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]:
"""
Returns:
The stream ordering of `event`, and the stream ordering of the
latest persisted event
The event, stream ordering of `event`, and the stream ordering of the
latest persisted event. The returned event may not match the given
event if it was deduplicated due to an existing event matching the
transaction ID.
"""
deferred = self._event_persist_queue.add_to_queue(
event.room_id, [(event, context)], backfilled=backfilled
@ -245,19 +273,33 @@ class EventsPersistenceStorage:
self._maybe_start_persisting(event.room_id)
await make_deferred_yieldable(deferred)
# The deferred returns a map from event ID to existing event ID if the
# event was deduplicated. (The dict may also include other entries if
# the event was persisted in a batch with other events.)
replaced_events = await make_deferred_yieldable(deferred)
replaced_event = replaced_events.get(event.event_id)
if replaced_event:
event = await self.main_store.get_event(replaced_event)
event_stream_id = event.internal_metadata.stream_ordering
# stream ordering should have been assigned by now
assert event_stream_id
pos = PersistedEventPosition(self._instance_name, event_stream_id)
return pos, self.main_store.get_room_max_token()
return event, pos, self.main_store.get_room_max_token()
def _maybe_start_persisting(self, room_id: str):
"""Pokes the `_event_persist_queue` to start handling new items in the
queue, if not already in progress.
Causes the deferreds returned by `add_to_queue` to resolve with: a
dictionary of event ID to event ID we didn't persist as we already had
another event persisted with the same TXN ID.
"""
async def persisting_queue(item):
with Measure(self._clock, "persist_events"):
await self._persist_events(
return await self._persist_events(
item.events_and_contexts, backfilled=item.backfilled
)
@ -267,12 +309,38 @@ class EventsPersistenceStorage:
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool = False,
):
) -> Dict[str, str]:
"""Calculates the change to current state and forward extremities, and
persists the given events and with those updates.
Returns:
A dictionary of event ID to event ID we didn't persist as we already
had another event persisted with the same TXN ID.
"""
replaced_events = {} # type: Dict[str, str]
if not events_and_contexts:
return
return replaced_events
# Check if any of the events have a transaction ID that has already been
# persisted, and if so we don't persist it again.
#
# We should have checked this a long time before we get here, but it's
# possible that different send event requests race in such a way that
# they both pass the earlier checks. Checking here isn't racey as we can
# have only one `_persist_events` per room being called at a time.
replaced_events = await self.main_store.get_already_persisted_events(
(event for event, _ in events_and_contexts)
)
if replaced_events:
events_and_contexts = [
(e, ctx)
for e, ctx in events_and_contexts
if e.event_id not in replaced_events
]
if not events_and_contexts:
return replaced_events
chunks = [
events_and_contexts[x : x + 100]
@ -441,6 +509,8 @@ class EventsPersistenceStorage:
await self._handle_potentially_left_users(potentially_left_users)
return replaced_events
async def _calculate_new_extremities(
self,
room_id: str,