recursively fetch redactions

This commit is contained in:
Richard van der Hoff 2019-07-24 16:44:10 +01:00
parent e6a6c4fbab
commit 448bcfd0f9

View File

@ -17,7 +17,6 @@ from __future__ import division
import itertools
import logging
import operator
from collections import namedtuple
from canonicaljson import json
@ -30,12 +29,7 @@ from synapse.api.room_versions import EventFormatVersions
from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
from synapse.events.utils import prune_event
from synapse.logging.context import (
LoggingContext,
PreserveLoggingContext,
make_deferred_yieldable,
run_in_background,
)
from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import get_domain_from_id
from synapse.util import batch_iter
@ -468,39 +462,49 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
Deferred[Dict[str, _EventCacheEntry]]:
map from event id to result.
map from event id to result. May return extra events which
weren't asked for.
"""
if not event_ids:
return {}
fetched_events = {}
events_to_fetch = event_ids
row_map = yield self._enqueue_events(event_ids)
while events_to_fetch:
row_map = yield self._enqueue_events(events_to_fetch)
rows = (row_map.get(event_id) for event_id in event_ids)
# we need to recursively fetch any redactions of those events
redaction_ids = set()
for event_id in events_to_fetch:
row = row_map.get(event_id)
fetched_events[event_id] = row
if row:
redaction_ids.update(row["redactions"])
# filter out absent rows
rows = filter(operator.truth, rows)
events_to_fetch = redaction_ids.difference(fetched_events.keys())
if events_to_fetch:
logger.debug("Also fetching redaction events %s", events_to_fetch)
if not allow_rejected:
rows = (r for r in rows if r["rejected_reason"] is None)
result_map = {}
for event_id, row in fetched_events.items():
if not row:
continue
assert row["event_id"] == event_id
res = yield make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self._get_event_from_row,
rejected_reason = row["rejected_reason"]
if not allow_rejected and rejected_reason:
continue
cache_entry = yield self._get_event_from_row(
row["internal_metadata"],
row["json"],
row["redactions"],
rejected_reason=row["rejected_reason"],
format_version=row["format_version"],
)
for row in rows
],
consumeErrors=True,
)
)
return {e.event.event_id: e for e in res if e}
result_map[event_id] = cache_entry
return result_map
@defer.inlineCallbacks
def _enqueue_events(self, events):