Fix message duplication if something goes wrong after persisting the event (#8476)

Should fix #3365.
This commit is contained in:
Erik Johnston 2020-10-13 12:07:56 +01:00 committed by GitHub
parent a9a8f29729
commit b2486f6656
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 481 additions and 32 deletions

1
changelog.d/8476.bugfix Normal file
View File

@ -0,0 +1 @@
Fix message duplication if something goes wrong after persisting the event.

View File

@ -2966,17 +2966,20 @@ class FederationHandler(BaseHandler):
return result["max_stream_id"] return result["max_stream_id"]
else: else:
assert self.storage.persistence assert self.storage.persistence
max_stream_token = await self.storage.persistence.persist_events(
# Note that this returns the events that were persisted, which may not be
# the same as were passed in if some were deduplicated due to transaction IDs.
events, max_stream_token = await self.storage.persistence.persist_events(
event_and_contexts, backfilled=backfilled event_and_contexts, backfilled=backfilled
) )
if self._ephemeral_messages_enabled: if self._ephemeral_messages_enabled:
for (event, context) in event_and_contexts: for event in events:
# If there's an expiry timestamp on the event, schedule its expiry. # If there's an expiry timestamp on the event, schedule its expiry.
self._message_handler.maybe_schedule_expiry(event) self._message_handler.maybe_schedule_expiry(event)
if not backfilled: # Never notify for backfilled events if not backfilled: # Never notify for backfilled events
for event, _ in event_and_contexts: for event in events:
await self._notify_persisted_event(event, max_stream_token) await self._notify_persisted_event(event, max_stream_token)
return max_stream_token.stream return max_stream_token.stream

View File

@ -689,7 +689,7 @@ class EventCreationHandler:
send this event. send this event.
Returns: Returns:
The event, and its stream ordering (if state event deduplication happened, The event, and its stream ordering (if deduplication happened,
the previous, duplicate event). the previous, duplicate event).
Raises: Raises:
@ -712,6 +712,19 @@ class EventCreationHandler:
# extremities to pile up, which in turn leads to state resolution # extremities to pile up, which in turn leads to state resolution
# taking longer. # taking longer.
with (await self.limiter.queue(event_dict["room_id"])): with (await self.limiter.queue(event_dict["room_id"])):
if txn_id and requester.access_token_id:
existing_event_id = await self.store.get_event_id_from_transaction_id(
event_dict["room_id"],
requester.user.to_string(),
requester.access_token_id,
txn_id,
)
if existing_event_id:
event = await self.store.get_event(existing_event_id)
# we know it was persisted, so must have a stream ordering
assert event.internal_metadata.stream_ordering
return event, event.internal_metadata.stream_ordering
event, context = await self.create_event( event, context = await self.create_event(
requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id
) )
@ -913,10 +926,20 @@ class EventCreationHandler:
extra_users=extra_users, extra_users=extra_users,
) )
stream_id = result["stream_id"] stream_id = result["stream_id"]
event.internal_metadata.stream_ordering = stream_id event_id = result["event_id"]
if event_id != event.event_id:
# If we get a different event back then it means that its
# been de-duplicated, so we replace the given event with the
# one already persisted.
event = await self.store.get_event(event_id)
else:
# If we newly persisted the event then we need to update its
# stream_ordering entry manually (as it was persisted on
# another worker).
event.internal_metadata.stream_ordering = stream_id
return event return event
stream_id = await self.persist_and_notify_client_event( event = await self.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users requester, event, context, ratelimit=ratelimit, extra_users=extra_users
) )
@ -965,11 +988,16 @@ class EventCreationHandler:
context: EventContext, context: EventContext,
ratelimit: bool = True, ratelimit: bool = True,
extra_users: List[UserID] = [], extra_users: List[UserID] = [],
) -> int: ) -> EventBase:
"""Called when we have fully built the event, have already """Called when we have fully built the event, have already
calculated the push actions for the event, and checked auth. calculated the push actions for the event, and checked auth.
This should only be run on the instance in charge of persisting events. This should only be run on the instance in charge of persisting events.
Returns:
The persisted event. This may be different than the given event if
it was de-duplicated (e.g. because we had already persisted an
event with the same transaction ID.)
""" """
assert self.storage.persistence is not None assert self.storage.persistence is not None
assert self._events_shard_config.should_handle( assert self._events_shard_config.should_handle(
@ -1137,9 +1165,13 @@ class EventCreationHandler:
if prev_state_ids: if prev_state_ids:
raise AuthError(403, "Changing the room create event is forbidden") raise AuthError(403, "Changing the room create event is forbidden")
event_pos, max_stream_token = await self.storage.persistence.persist_event( # Note that this returns the event that was persisted, which may not be
event, context=context # the same as we passed in if it was deduplicated due transaction IDs.
) (
event,
event_pos,
max_stream_token,
) = await self.storage.persistence.persist_event(event, context=context)
if self._ephemeral_events_enabled: if self._ephemeral_events_enabled:
# If there's an expiry timestamp on the event, schedule its expiry. # If there's an expiry timestamp on the event, schedule its expiry.
@ -1160,7 +1192,7 @@ class EventCreationHandler:
# matters as sometimes presence code can take a while. # matters as sometimes presence code can take a while.
run_in_background(self._bump_active_time, requester.user) run_in_background(self._bump_active_time, requester.user)
return event_pos.stream return event
async def _bump_active_time(self, user: UserID) -> None: async def _bump_active_time(self, user: UserID) -> None:
try: try:

View File

@ -171,6 +171,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if requester.is_guest: if requester.is_guest:
content["kind"] = "guest" content["kind"] = "guest"
# Check if we already have an event with a matching transaction ID. (We
# do this check just before we persist an event as well, but may as well
# do it up front for efficiency.)
if txn_id and requester.access_token_id:
existing_event_id = await self.store.get_event_id_from_transaction_id(
room_id, requester.user.to_string(), requester.access_token_id, txn_id,
)
if existing_event_id:
event_pos = await self.store.get_position_for_event(existing_event_id)
return existing_event_id, event_pos.stream
event, context = await self.event_creation_handler.create_event( event, context = await self.event_creation_handler.create_event(
requester, requester,
{ {
@ -679,7 +690,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if is_blocked: if is_blocked:
raise SynapseError(403, "This room has been blocked on this server") raise SynapseError(403, "This room has been blocked on this server")
await self.event_creation_handler.handle_new_client_event( event = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target_user], ratelimit=ratelimit requester, event, context, extra_users=[target_user], ratelimit=ratelimit
) )

View File

@ -46,6 +46,12 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
"ratelimit": true, "ratelimit": true,
"extra_users": [], "extra_users": [],
} }
200 OK
{ "stream_id": 12345, "event_id": "$abcdef..." }
The returned event ID may not match the sent event if it was deduplicated.
""" """
NAME = "send_event" NAME = "send_event"
@ -116,11 +122,17 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
"Got event to send with ID: %s into room: %s", event.event_id, event.room_id "Got event to send with ID: %s into room: %s", event.event_id, event.room_id
) )
stream_id = await self.event_creation_handler.persist_and_notify_client_event( event = await self.event_creation_handler.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users requester, event, context, ratelimit=ratelimit, extra_users=extra_users
) )
return 200, {"stream_id": stream_id} return (
200,
{
"stream_id": event.internal_metadata.stream_ordering,
"event_id": event.event_id,
},
)
def register_servlets(hs, http_server): def register_servlets(hs, http_server):

View File

@ -361,6 +361,8 @@ class PersistEventsStore:
self._store_event_txn(txn, events_and_contexts=events_and_contexts) self._store_event_txn(txn, events_and_contexts=events_and_contexts)
self._persist_transaction_ids_txn(txn, events_and_contexts)
# Insert into event_to_state_groups. # Insert into event_to_state_groups.
self._store_event_state_mappings_txn(txn, events_and_contexts) self._store_event_state_mappings_txn(txn, events_and_contexts)
@ -405,6 +407,35 @@ class PersistEventsStore:
# room_memberships, where applicable. # room_memberships, where applicable.
self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
def _persist_transaction_ids_txn(
self,
txn: LoggingTransaction,
events_and_contexts: List[Tuple[EventBase, EventContext]],
):
"""Persist the mapping from transaction IDs to event IDs (if defined).
"""
to_insert = []
for event, _ in events_and_contexts:
token_id = getattr(event.internal_metadata, "token_id", None)
txn_id = getattr(event.internal_metadata, "txn_id", None)
if token_id and txn_id:
to_insert.append(
{
"event_id": event.event_id,
"room_id": event.room_id,
"user_id": event.sender,
"token_id": token_id,
"txn_id": txn_id,
"inserted_ts": self._clock.time_msec(),
}
)
if to_insert:
self.db_pool.simple_insert_many_txn(
txn, table="event_txn_id", values=to_insert,
)
def _update_current_state_txn( def _update_current_state_txn(
self, self,
txn: LoggingTransaction, txn: LoggingTransaction,

View File

@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import itertools import itertools
import logging import logging
import threading import threading
@ -137,6 +136,15 @@ class EventsWorkerStore(SQLBaseStore):
db_conn, "events", "stream_ordering", step=-1 db_conn, "events", "stream_ordering", step=-1
) )
if not hs.config.worker.worker_app:
# We periodically clean out old transaction ID mappings
self._clock.looping_call(
run_as_background_process,
5 * 60 * 1000,
"_cleanup_old_transaction_ids",
self._cleanup_old_transaction_ids,
)
self._get_event_cache = Cache( self._get_event_cache = Cache(
"*getEvent*", "*getEvent*",
keylen=3, keylen=3,
@ -1308,3 +1316,76 @@ class EventsWorkerStore(SQLBaseStore):
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
) )
async def get_event_id_from_transaction_id(
self, room_id: str, user_id: str, token_id: int, txn_id: str
) -> Optional[str]:
"""Look up if we have already persisted an event for the transaction ID,
returning the event ID if so.
"""
return await self.db_pool.simple_select_one_onecol(
table="event_txn_id",
keyvalues={
"room_id": room_id,
"user_id": user_id,
"token_id": token_id,
"txn_id": txn_id,
},
retcol="event_id",
allow_none=True,
desc="get_event_id_from_transaction_id",
)
async def get_already_persisted_events(
self, events: Iterable[EventBase]
) -> Dict[str, str]:
"""Look up if we have already persisted an event for the transaction ID,
returning a mapping from event ID in the given list to the event ID of
an existing event.
Also checks if there are duplicates in the given events, if there are
will map duplicates to the *first* event.
"""
mapping = {}
txn_id_to_event = {} # type: Dict[Tuple[str, int, str], str]
for event in events:
token_id = getattr(event.internal_metadata, "token_id", None)
txn_id = getattr(event.internal_metadata, "txn_id", None)
if token_id and txn_id:
# Check if this is a duplicate of an event in the given events.
existing = txn_id_to_event.get((event.room_id, token_id, txn_id))
if existing:
mapping[event.event_id] = existing
continue
# Check if this is a duplicate of an event we've already
# persisted.
existing = await self.get_event_id_from_transaction_id(
event.room_id, event.sender, token_id, txn_id
)
if existing:
mapping[event.event_id] = existing
txn_id_to_event[(event.room_id, token_id, txn_id)] = existing
else:
txn_id_to_event[(event.room_id, token_id, txn_id)] = event.event_id
return mapping
async def _cleanup_old_transaction_ids(self):
"""Cleans out transaction id mappings older than 24hrs.
"""
def _cleanup_old_transaction_ids_txn(txn):
sql = """
DELETE FROM event_txn_id
WHERE inserted_ts < ?
"""
one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
txn.execute(sql, (one_day_ago,))
return await self.db_pool.runInteraction(
"_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn,
)

View File

@ -1003,7 +1003,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
token: str, token: str,
device_id: Optional[str], device_id: Optional[str],
valid_until_ms: Optional[int], valid_until_ms: Optional[int],
) -> None: ) -> int:
"""Adds an access token for the given user. """Adds an access token for the given user.
Args: Args:
@ -1013,6 +1013,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
valid_until_ms: when the token is valid until. None for no expiry. valid_until_ms: when the token is valid until. None for no expiry.
Raises: Raises:
StoreError if there was a problem adding this. StoreError if there was a problem adding this.
Returns:
The token ID
""" """
next_id = self._access_tokens_id_gen.get_next() next_id = self._access_tokens_id_gen.get_next()
@ -1028,6 +1030,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="add_access_token_to_user", desc="add_access_token_to_user",
) )
return next_id
def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str: def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
old_device_id = self.db_pool.simple_select_one_onecol_txn( old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn, "access_tokens", {"token": token}, "device_id" txn, "access_tokens", {"token": token}, "device_id"

View File

@ -0,0 +1,40 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- A map of recent events persisted with transaction IDs. Used to deduplicate
-- send event requests with the same transaction ID.
--
-- Note: transaction IDs are scoped to the room ID/user ID/access token that was
-- used to make the request.
--
-- Note: The foreign key constraints are ON DELETE CASCADE, as if we delete the
-- events or access token we don't want to try and de-duplicate the event.
CREATE TABLE IF NOT EXISTS event_txn_id (
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
user_id TEXT NOT NULL,
token_id BIGINT NOT NULL,
txn_id TEXT NOT NULL,
inserted_ts BIGINT NOT NULL,
FOREIGN KEY (event_id)
REFERENCES events (event_id) ON DELETE CASCADE,
FOREIGN KEY (token_id)
REFERENCES access_tokens (id) ON DELETE CASCADE
);
CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_event_id ON event_txn_id(event_id);
CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_txn_id ON event_txn_id(room_id, user_id, token_id, txn_id);
CREATE INDEX IF NOT EXISTS event_txn_id_ts ON event_txn_id(inserted_ts);

View File

@ -96,7 +96,9 @@ class _EventPeristenceQueue:
Returns: Returns:
defer.Deferred: a deferred which will resolve once the events are defer.Deferred: a deferred which will resolve once the events are
persisted. Runs its callbacks *without* a logcontext. persisted. Runs its callbacks *without* a logcontext. The result
is the same as that returned by the callback passed to
`handle_queue`.
""" """
queue = self._event_persist_queues.setdefault(room_id, deque()) queue = self._event_persist_queues.setdefault(room_id, deque())
if queue: if queue:
@ -199,7 +201,7 @@ class EventsPersistenceStorage:
self, self,
events_and_contexts: Iterable[Tuple[EventBase, EventContext]], events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
backfilled: bool = False, backfilled: bool = False,
) -> RoomStreamToken: ) -> Tuple[List[EventBase], RoomStreamToken]:
""" """
Write events to the database Write events to the database
Args: Args:
@ -209,7 +211,11 @@ class EventsPersistenceStorage:
which might update the current state etc. which might update the current state etc.
Returns: Returns:
the stream ordering of the latest persisted event List of events persisted, the current position room stream position.
The list of events persisted may not be the same as those passed in
if they were deduplicated due to an event already existing that
matched the transcation ID; the existing event is returned in such
a case.
""" """
partitioned = {} # type: Dict[str, List[Tuple[EventBase, EventContext]]] partitioned = {} # type: Dict[str, List[Tuple[EventBase, EventContext]]]
for event, ctx in events_and_contexts: for event, ctx in events_and_contexts:
@ -225,19 +231,41 @@ class EventsPersistenceStorage:
for room_id in partitioned: for room_id in partitioned:
self._maybe_start_persisting(room_id) self._maybe_start_persisting(room_id)
await make_deferred_yieldable( # Each deferred returns a map from event ID to existing event ID if the
# event was deduplicated. (The dict may also include other entries if
# the event was persisted in a batch with other events).
#
# Since we use `defer.gatherResults` we need to merge the returned list
# of dicts into one.
ret_vals = await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True) defer.gatherResults(deferreds, consumeErrors=True)
) )
replaced_events = {}
for d in ret_vals:
replaced_events.update(d)
return self.main_store.get_room_max_token() events = []
for event, _ in events_and_contexts:
existing_event_id = replaced_events.get(event.event_id)
if existing_event_id:
events.append(await self.main_store.get_event(existing_event_id))
else:
events.append(event)
return (
events,
self.main_store.get_room_max_token(),
)
async def persist_event( async def persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False self, event: EventBase, context: EventContext, backfilled: bool = False
) -> Tuple[PersistedEventPosition, RoomStreamToken]: ) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]:
""" """
Returns: Returns:
The stream ordering of `event`, and the stream ordering of the The event, stream ordering of `event`, and the stream ordering of the
latest persisted event latest persisted event. The returned event may not match the given
event if it was deduplicated due to an existing event matching the
transaction ID.
""" """
deferred = self._event_persist_queue.add_to_queue( deferred = self._event_persist_queue.add_to_queue(
event.room_id, [(event, context)], backfilled=backfilled event.room_id, [(event, context)], backfilled=backfilled
@ -245,19 +273,33 @@ class EventsPersistenceStorage:
self._maybe_start_persisting(event.room_id) self._maybe_start_persisting(event.room_id)
await make_deferred_yieldable(deferred) # The deferred returns a map from event ID to existing event ID if the
# event was deduplicated. (The dict may also include other entries if
# the event was persisted in a batch with other events.)
replaced_events = await make_deferred_yieldable(deferred)
replaced_event = replaced_events.get(event.event_id)
if replaced_event:
event = await self.main_store.get_event(replaced_event)
event_stream_id = event.internal_metadata.stream_ordering event_stream_id = event.internal_metadata.stream_ordering
# stream ordering should have been assigned by now # stream ordering should have been assigned by now
assert event_stream_id assert event_stream_id
pos = PersistedEventPosition(self._instance_name, event_stream_id) pos = PersistedEventPosition(self._instance_name, event_stream_id)
return pos, self.main_store.get_room_max_token() return event, pos, self.main_store.get_room_max_token()
def _maybe_start_persisting(self, room_id: str): def _maybe_start_persisting(self, room_id: str):
"""Pokes the `_event_persist_queue` to start handling new items in the
queue, if not already in progress.
Causes the deferreds returned by `add_to_queue` to resolve with: a
dictionary of event ID to event ID we didn't persist as we already had
another event persisted with the same TXN ID.
"""
async def persisting_queue(item): async def persisting_queue(item):
with Measure(self._clock, "persist_events"): with Measure(self._clock, "persist_events"):
await self._persist_events( return await self._persist_events(
item.events_and_contexts, backfilled=item.backfilled item.events_and_contexts, backfilled=item.backfilled
) )
@ -267,12 +309,38 @@ class EventsPersistenceStorage:
self, self,
events_and_contexts: List[Tuple[EventBase, EventContext]], events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool = False, backfilled: bool = False,
): ) -> Dict[str, str]:
"""Calculates the change to current state and forward extremities, and """Calculates the change to current state and forward extremities, and
persists the given events and with those updates. persists the given events and with those updates.
Returns:
A dictionary of event ID to event ID we didn't persist as we already
had another event persisted with the same TXN ID.
""" """
replaced_events = {} # type: Dict[str, str]
if not events_and_contexts: if not events_and_contexts:
return return replaced_events
# Check if any of the events have a transaction ID that has already been
# persisted, and if so we don't persist it again.
#
# We should have checked this a long time before we get here, but it's
# possible that different send event requests race in such a way that
# they both pass the earlier checks. Checking here isn't racey as we can
# have only one `_persist_events` per room being called at a time.
replaced_events = await self.main_store.get_already_persisted_events(
(event for event, _ in events_and_contexts)
)
if replaced_events:
events_and_contexts = [
(e, ctx)
for e, ctx in events_and_contexts
if e.event_id not in replaced_events
]
if not events_and_contexts:
return replaced_events
chunks = [ chunks = [
events_and_contexts[x : x + 100] events_and_contexts[x : x + 100]
@ -441,6 +509,8 @@ class EventsPersistenceStorage:
await self._handle_potentially_left_users(potentially_left_users) await self._handle_potentially_left_users(potentially_left_users)
return replaced_events
async def _calculate_new_extremities( async def _calculate_new_extremities(
self, self,
room_id: str, room_id: str,

View File

@ -0,0 +1,157 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Tuple
from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from synapse.types import create_requester
from synapse.util.stringutils import random_string
from tests import unittest
logger = logging.getLogger(__name__)
class EventCreationTestCase(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
room.register_servlets,
]
def prepare(self, reactor, clock, hs):
self.handler = self.hs.get_event_creation_handler()
self.persist_event_storage = self.hs.get_storage().persistence
self.user_id = self.register_user("tester", "foobar")
self.access_token = self.login("tester", "foobar")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
self.info = self.get_success(
self.hs.get_datastore().get_user_by_access_token(self.access_token,)
)
self.token_id = self.info["token_id"]
self.requester = create_requester(self.user_id, access_token_id=self.token_id)
def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]:
"""Create a new event with the given transaction ID. All events produced
by this method will be considered duplicates.
"""
# We create a new event with a random body, as otherwise we'll produce
# *exactly* the same event with the same hash, and so same event ID.
return self.get_success(
self.handler.create_event(
self.requester,
{
"type": EventTypes.Message,
"room_id": self.room_id,
"sender": self.requester.user.to_string(),
"content": {"msgtype": "m.text", "body": random_string(5)},
},
token_id=self.token_id,
txn_id=txn_id,
)
)
def test_duplicated_txn_id(self):
"""Test that attempting to handle/persist an event with a transaction ID
that has already been persisted correctly returns the old event and does
*not* produce duplicate messages.
"""
txn_id = "something_suitably_random"
event1, context = self._create_duplicate_event(txn_id)
ret_event1 = self.get_success(
self.handler.handle_new_client_event(self.requester, event1, context)
)
stream_id1 = ret_event1.internal_metadata.stream_ordering
self.assertEqual(event1.event_id, ret_event1.event_id)
event2, context = self._create_duplicate_event(txn_id)
# We want to test that the deduplication at the persit event end works,
# so we want to make sure we test with different events.
self.assertNotEqual(event1.event_id, event2.event_id)
ret_event2 = self.get_success(
self.handler.handle_new_client_event(self.requester, event2, context)
)
stream_id2 = ret_event2.internal_metadata.stream_ordering
# Assert that the returned values match those from the initial event
# rather than the new one.
self.assertEqual(ret_event1.event_id, ret_event2.event_id)
self.assertEqual(stream_id1, stream_id2)
# Let's test that calling `persist_event` directly also does the right
# thing.
event3, context = self._create_duplicate_event(txn_id)
self.assertNotEqual(event1.event_id, event3.event_id)
ret_event3, event_pos3, _ = self.get_success(
self.persist_event_storage.persist_event(event3, context)
)
# Assert that the returned values match those from the initial event
# rather than the new one.
self.assertEqual(ret_event1.event_id, ret_event3.event_id)
self.assertEqual(stream_id1, event_pos3.stream)
# Let's test that calling `persist_events` directly also does the right
# thing.
event4, context = self._create_duplicate_event(txn_id)
self.assertNotEqual(event1.event_id, event3.event_id)
events, _ = self.get_success(
self.persist_event_storage.persist_events([(event3, context)])
)
ret_event4 = events[0]
# Assert that the returned values match those from the initial event
# rather than the new one.
self.assertEqual(ret_event1.event_id, ret_event4.event_id)
def test_duplicated_txn_id_one_call(self):
"""Test that we correctly handle duplicates that we try and persist at
the same time.
"""
txn_id = "something_else_suitably_random"
# Create two duplicate events to persist at the same time
event1, context1 = self._create_duplicate_event(txn_id)
event2, context2 = self._create_duplicate_event(txn_id)
# Ensure their event IDs are different to start with
self.assertNotEqual(event1.event_id, event2.event_id)
events, _ = self.get_success(
self.persist_event_storage.persist_events(
[(event1, context1), (event2, context2)]
)
)
# Check that we've deduplicated the events.
self.assertEqual(len(events), 2)
self.assertEqual(events[0].event_id, events[1].event_id)

View File

@ -107,7 +107,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
request, channel = self.make_request( request, channel = self.make_request(
"PUT", "PUT",
"/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % self.room_id, "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/2" % self.room_id,
{}, {},
access_token=self.tok, access_token=self.tok,
) )

View File

@ -254,17 +254,24 @@ class HomeserverTestCase(TestCase):
if hasattr(self, "user_id"): if hasattr(self, "user_id"):
if self.hijack_auth: if self.hijack_auth:
# We need a valid token ID to satisfy foreign key constraints.
token_id = self.get_success(
self.hs.get_datastore().add_access_token_to_user(
self.helper.auth_user_id, "some_fake_token", None, None,
)
)
async def get_user_by_access_token(token=None, allow_guest=False): async def get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.helper.auth_user_id), "user": UserID.from_string(self.helper.auth_user_id),
"token_id": 1, "token_id": token_id,
"is_guest": False, "is_guest": False,
} }
async def get_user_by_req(request, allow_guest=False, rights="access"): async def get_user_by_req(request, allow_guest=False, rights="access"):
return create_requester( return create_requester(
UserID.from_string(self.helper.auth_user_id), UserID.from_string(self.helper.auth_user_id),
1, token_id,
False, False,
False, False,
None, None,