Fix a bug where redactions were not being sent over federation if we did not have the original event. (#13813)

This commit is contained in:
Shay 2022-10-11 11:18:45 -07:00 committed by GitHub
parent 6a92944854
commit a86b2f6837
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 62 additions and 38 deletions

1
changelog.d/13813.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a long-standing bug where redactions were not being sent over federation if we did not have the original event.

View File

@ -353,21 +353,25 @@ class FederationSender(AbstractFederationSender):
last_token = await self.store.get_federation_out_pos("events") last_token = await self.store.get_federation_out_pos("events")
( (
next_token, next_token,
events,
event_to_received_ts, event_to_received_ts,
) = await self.store.get_all_new_events_stream( ) = await self.store.get_all_new_event_ids_stream(
last_token, self._last_poked_id, limit=100 last_token, self._last_poked_id, limit=100
) )
event_ids = event_to_received_ts.keys()
event_entries = await self.store.get_unredacted_events_from_cache_or_db(
event_ids
)
logger.debug( logger.debug(
"Handling %i -> %i: %i events to send (current id %i)", "Handling %i -> %i: %i events to send (current id %i)",
last_token, last_token,
next_token, next_token,
len(events), len(event_entries),
self._last_poked_id, self._last_poked_id,
) )
if not events and next_token >= self._last_poked_id: if not event_entries and next_token >= self._last_poked_id:
logger.debug("All events processed") logger.debug("All events processed")
break break
@ -508,8 +512,14 @@ class FederationSender(AbstractFederationSender):
await handle_event(event) await handle_event(event)
events_by_room: Dict[str, List[EventBase]] = {} events_by_room: Dict[str, List[EventBase]] = {}
for event in events:
events_by_room.setdefault(event.room_id, []).append(event) for event_id in event_ids:
# `event_entries` is unsorted, so we have to iterate over `event_ids`
# to ensure the events are in the right order
event_cache = event_entries.get(event_id)
if event_cache:
event = event_cache.event
events_by_room.setdefault(event.room_id, []).append(event)
await make_deferred_yieldable( await make_deferred_yieldable(
defer.gatherResults( defer.gatherResults(
@ -524,9 +534,10 @@ class FederationSender(AbstractFederationSender):
logger.debug("Successfully handled up to %i", next_token) logger.debug("Successfully handled up to %i", next_token)
await self.store.update_federation_out_pos("events", next_token) await self.store.update_federation_out_pos("events", next_token)
if events: if event_entries:
now = self.clock.time_msec() now = self.clock.time_msec()
ts = event_to_received_ts[events[-1].event_id] last_id = next(reversed(event_ids))
ts = event_to_received_ts[last_id]
assert ts is not None assert ts is not None
synapse.metrics.event_processing_lag.labels( synapse.metrics.event_processing_lag.labels(
@ -536,7 +547,7 @@ class FederationSender(AbstractFederationSender):
"federation_sender" "federation_sender"
).set(ts) ).set(ts)
events_processed_counter.inc(len(events)) events_processed_counter.inc(len(event_entries))
event_processing_loop_room_count.labels("federation_sender").inc( event_processing_loop_room_count.labels("federation_sender").inc(
len(events_by_room) len(events_by_room)

View File

@ -109,10 +109,13 @@ class ApplicationServicesHandler:
last_token = await self.store.get_appservice_last_pos() last_token = await self.store.get_appservice_last_pos()
( (
upper_bound, upper_bound,
events,
event_to_received_ts, event_to_received_ts,
) = await self.store.get_all_new_events_stream( ) = await self.store.get_all_new_event_ids_stream(
last_token, self.current_max, limit=100, get_prev_content=True last_token, self.current_max, limit=100
)
events = await self.store.get_events_as_list(
event_to_received_ts.keys(), get_prev_content=True
) )
events_by_room: Dict[str, List[EventBase]] = {} events_by_room: Dict[str, List[EventBase]] = {}

View File

@ -474,7 +474,7 @@ class EventsWorkerStore(SQLBaseStore):
return [] return []
# there may be duplicates so we cast the list to a set # there may be duplicates so we cast the list to a set
event_entry_map = await self._get_events_from_cache_or_db( event_entry_map = await self.get_unredacted_events_from_cache_or_db(
set(event_ids), allow_rejected=allow_rejected set(event_ids), allow_rejected=allow_rejected
) )
@ -509,7 +509,9 @@ class EventsWorkerStore(SQLBaseStore):
continue continue
redacted_event_id = entry.event.redacts redacted_event_id = entry.event.redacts
event_map = await self._get_events_from_cache_or_db([redacted_event_id]) event_map = await self.get_unredacted_events_from_cache_or_db(
[redacted_event_id]
)
original_event_entry = event_map.get(redacted_event_id) original_event_entry = event_map.get(redacted_event_id)
if not original_event_entry: if not original_event_entry:
# we don't have the redacted event (or it was rejected). # we don't have the redacted event (or it was rejected).
@ -588,11 +590,16 @@ class EventsWorkerStore(SQLBaseStore):
return events return events
@cancellable @cancellable
async def _get_events_from_cache_or_db( async def get_unredacted_events_from_cache_or_db(
self, event_ids: Iterable[str], allow_rejected: bool = False self,
event_ids: Iterable[str],
allow_rejected: bool = False,
) -> Dict[str, EventCacheEntry]: ) -> Dict[str, EventCacheEntry]:
"""Fetch a bunch of events from the cache or the database. """Fetch a bunch of events from the cache or the database.
Note that the events pulled by this function will not have any redactions
applied, and no guarantee is made about the ordering of the events returned.
If events are pulled from the database, they will be cached for future lookups. If events are pulled from the database, they will be cached for future lookups.
Unknown events are omitted from the response. Unknown events are omitted from the response.

View File

@ -1024,28 +1024,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"after": {"event_ids": events_after, "token": end_token}, "after": {"event_ids": events_after, "token": end_token},
} }
async def get_all_new_events_stream( async def get_all_new_event_ids_stream(
self, from_id: int, current_id: int, limit: int, get_prev_content: bool = False self,
) -> Tuple[int, List[EventBase], Dict[str, Optional[int]]]: from_id: int,
current_id: int,
limit: int,
) -> Tuple[int, Dict[str, Optional[int]]]:
"""Get all new events """Get all new events
Returns all events with from_id < stream_ordering <= current_id. Returns all event ids with from_id < stream_ordering <= current_id.
Args: Args:
from_id: the stream_ordering of the last event we processed from_id: the stream_ordering of the last event we processed
current_id: the stream_ordering of the most recently processed event current_id: the stream_ordering of the most recently processed event
limit: the maximum number of events to return limit: the maximum number of events to return
get_prev_content: whether to fetch previous event content
Returns: Returns:
A tuple of (next_id, events, event_to_received_ts), where `next_id` A tuple of (next_id, event_to_received_ts), where `next_id`
is the next value to pass as `from_id` (it will either be the is the next value to pass as `from_id` (it will either be the
stream_ordering of the last returned event, or, if fewer than `limit` stream_ordering of the last returned event, or, if fewer than `limit`
events were found, the `current_id`). The `event_to_received_ts` is events were found, the `current_id`). The `event_to_received_ts` is
a dictionary mapping event ID to the event `received_ts`. a dictionary mapping event ID to the event `received_ts`, sorted by ascending
stream_ordering.
""" """
def get_all_new_events_stream_txn( def get_all_new_event_ids_stream_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> Tuple[int, Dict[str, Optional[int]]]: ) -> Tuple[int, Dict[str, Optional[int]]]:
sql = ( sql = (
@ -1070,15 +1073,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, event_to_received_ts return upper_bound, event_to_received_ts
upper_bound, event_to_received_ts = await self.db_pool.runInteraction( upper_bound, event_to_received_ts = await self.db_pool.runInteraction(
"get_all_new_events_stream", get_all_new_events_stream_txn "get_all_new_event_ids_stream", get_all_new_event_ids_stream_txn
) )
events = await self.get_events_as_list( return upper_bound, event_to_received_ts
event_to_received_ts.keys(),
get_prev_content=get_prev_content,
)
return upper_bound, events, event_to_received_ts
async def get_federation_out_pos(self, typ: str) -> int: async def get_federation_out_pos(self, typ: str) -> int:
if self._need_to_reset_federation_stream_positions: if self._need_to_reset_federation_stream_positions:

View File

@ -76,9 +76,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
event = Mock( event = Mock(
sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar" sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
) )
self.mock_store.get_all_new_events_stream.side_effect = [ self.mock_store.get_all_new_event_ids_stream.side_effect = [
make_awaitable((0, [], {})), make_awaitable((0, {})),
make_awaitable((1, [event], {event.event_id: 0})), make_awaitable((1, {event.event_id: 0})),
]
self.mock_store.get_events_as_list.side_effect = [
make_awaitable([]),
make_awaitable([event]),
] ]
self.handler.notify_interested_services(RoomStreamToken(None, 1)) self.handler.notify_interested_services(RoomStreamToken(None, 1))
@ -95,10 +99,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
self.mock_as_api.query_user.return_value = make_awaitable(True) self.mock_as_api.query_user.return_value = make_awaitable(True)
self.mock_store.get_all_new_events_stream.side_effect = [ self.mock_store.get_all_new_event_ids_stream.side_effect = [
make_awaitable((0, [event], {event.event_id: 0})), make_awaitable((0, {event.event_id: 0})),
] ]
self.mock_store.get_events_as_list.side_effect = [make_awaitable([event])]
self.handler.notify_interested_services(RoomStreamToken(None, 0)) self.handler.notify_interested_services(RoomStreamToken(None, 0))
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id) self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
@ -112,7 +116,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
self.mock_as_api.query_user.return_value = make_awaitable(True) self.mock_as_api.query_user.return_value = make_awaitable(True)
self.mock_store.get_all_new_events_stream.side_effect = [ self.mock_store.get_all_new_event_ids_stream.side_effect = [
make_awaitable((0, [event], {event.event_id: 0})), make_awaitable((0, [event], {event.event_id: 0})),
] ]