mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-05-07 22:24:56 -04:00
Move get_bundled_aggregations to relations handler. (#12237)
The get_bundled_aggregations code is fairly high-level and uses a lot of store methods, we move it into the handler as that seems like a better fit.
This commit is contained in:
parent
80e0e1f35e
commit
8fe930c215
9 changed files with 173 additions and 157 deletions
|
@ -27,7 +27,6 @@ from typing import (
|
|||
)
|
||||
|
||||
import attr
|
||||
from frozendict import frozendict
|
||||
|
||||
from synapse.api.constants import RelationTypes
|
||||
from synapse.events import EventBase
|
||||
|
@ -41,45 +40,15 @@ from synapse.storage.database import (
|
|||
from synapse.storage.databases.main.stream import generate_pagination_where_clause
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
|
||||
from synapse.types import JsonDict, RoomStreamToken, StreamToken
|
||||
from synapse.types import RoomStreamToken, StreamToken
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.databases.main import DataStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class _ThreadAggregation:
|
||||
# The latest event in the thread.
|
||||
latest_event: EventBase
|
||||
# The latest edit to the latest event in the thread.
|
||||
latest_edit: Optional[EventBase]
|
||||
# The total number of events in the thread.
|
||||
count: int
|
||||
# True if the current user has sent an event to the thread.
|
||||
current_user_participated: bool
|
||||
|
||||
|
||||
@attr.s(slots=True, auto_attribs=True)
|
||||
class BundledAggregations:
|
||||
"""
|
||||
The bundled aggregations for an event.
|
||||
|
||||
Some values require additional processing during serialization.
|
||||
"""
|
||||
|
||||
annotations: Optional[JsonDict] = None
|
||||
references: Optional[JsonDict] = None
|
||||
replace: Optional[EventBase] = None
|
||||
thread: Optional[_ThreadAggregation] = None
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.annotations or self.references or self.replace or self.thread)
|
||||
|
||||
|
||||
class RelationsWorkerStore(SQLBaseStore):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -384,7 +353,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
raise NotImplementedError()
|
||||
|
||||
@cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
|
||||
async def _get_applicable_edits(
|
||||
async def get_applicable_edits(
|
||||
self, event_ids: Collection[str]
|
||||
) -> Dict[str, Optional[EventBase]]:
|
||||
"""Get the most recent edit (if any) that has happened for the given
|
||||
|
@ -473,7 +442,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
raise NotImplementedError()
|
||||
|
||||
@cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
|
||||
async def _get_thread_summaries(
|
||||
async def get_thread_summaries(
|
||||
self, event_ids: Collection[str]
|
||||
) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
|
||||
"""Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
|
||||
|
@ -587,7 +556,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]
|
||||
|
||||
# Check to see if any of those events are edited.
|
||||
latest_edits = await self._get_applicable_edits(latest_event_ids.values())
|
||||
latest_edits = await self.get_applicable_edits(latest_event_ids.values())
|
||||
|
||||
# Map to the event IDs to the thread summary.
|
||||
#
|
||||
|
@ -610,7 +579,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
raise NotImplementedError()
|
||||
|
||||
@cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
|
||||
async def _get_threads_participated(
|
||||
async def get_threads_participated(
|
||||
self, event_ids: Collection[str], user_id: str
|
||||
) -> Dict[str, bool]:
|
||||
"""Get whether the requesting user participated in the given threads.
|
||||
|
@ -766,116 +735,6 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
|
||||
)
|
||||
|
||||
async def _get_bundled_aggregation_for_event(
|
||||
self, event: EventBase, user_id: str
|
||||
) -> Optional[BundledAggregations]:
|
||||
"""Generate bundled aggregations for an event.
|
||||
|
||||
Note that this does not use a cache, but depends on cached methods.
|
||||
|
||||
Args:
|
||||
event: The event to calculate bundled aggregations for.
|
||||
user_id: The user requesting the bundled aggregations.
|
||||
|
||||
Returns:
|
||||
The bundled aggregations for an event, if bundled aggregations are
|
||||
enabled and the event can have bundled aggregations.
|
||||
"""
|
||||
|
||||
# Do not bundle aggregations for an event which represents an edit or an
|
||||
# annotation. It does not make sense for them to have related events.
|
||||
relates_to = event.content.get("m.relates_to")
|
||||
if isinstance(relates_to, (dict, frozendict)):
|
||||
relation_type = relates_to.get("rel_type")
|
||||
if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
|
||||
return None
|
||||
|
||||
event_id = event.event_id
|
||||
room_id = event.room_id
|
||||
|
||||
# The bundled aggregations to include, a mapping of relation type to a
|
||||
# type-specific value. Some types include the direct return type here
|
||||
# while others need more processing during serialization.
|
||||
aggregations = BundledAggregations()
|
||||
|
||||
annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
|
||||
if annotations.chunk:
|
||||
aggregations.annotations = await annotations.to_dict(
|
||||
cast("DataStore", self)
|
||||
)
|
||||
|
||||
references = await self.get_relations_for_event(
|
||||
event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
|
||||
)
|
||||
if references.chunk:
|
||||
aggregations.references = await references.to_dict(cast("DataStore", self))
|
||||
|
||||
# Store the bundled aggregations in the event metadata for later use.
|
||||
return aggregations
|
||||
|
||||
async def get_bundled_aggregations(
|
||||
self, events: Iterable[EventBase], user_id: str
|
||||
) -> Dict[str, BundledAggregations]:
|
||||
"""Generate bundled aggregations for events.
|
||||
|
||||
Args:
|
||||
events: The iterable of events to calculate bundled aggregations for.
|
||||
user_id: The user requesting the bundled aggregations.
|
||||
|
||||
Returns:
|
||||
A map of event ID to the bundled aggregation for the event. Not all
|
||||
events may have bundled aggregations in the results.
|
||||
"""
|
||||
# De-duplicate events by ID to handle the same event requested multiple times.
|
||||
#
|
||||
# State events do not get bundled aggregations.
|
||||
events_by_id = {
|
||||
event.event_id: event for event in events if not event.is_state()
|
||||
}
|
||||
|
||||
# event ID -> bundled aggregation in non-serialized form.
|
||||
results: Dict[str, BundledAggregations] = {}
|
||||
|
||||
# Fetch other relations per event.
|
||||
for event in events_by_id.values():
|
||||
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
|
||||
if event_result:
|
||||
results[event.event_id] = event_result
|
||||
|
||||
# Fetch any edits (but not for redacted events).
|
||||
edits = await self._get_applicable_edits(
|
||||
[
|
||||
event_id
|
||||
for event_id, event in events_by_id.items()
|
||||
if not event.internal_metadata.is_redacted()
|
||||
]
|
||||
)
|
||||
for event_id, edit in edits.items():
|
||||
results.setdefault(event_id, BundledAggregations()).replace = edit
|
||||
|
||||
# Fetch thread summaries.
|
||||
summaries = await self._get_thread_summaries(events_by_id.keys())
|
||||
# Only fetch participated for a limited selection based on what had
|
||||
# summaries.
|
||||
participated = await self._get_threads_participated(
|
||||
[event_id for event_id, summary in summaries.items() if summary], user_id
|
||||
)
|
||||
for event_id, summary in summaries.items():
|
||||
if summary:
|
||||
thread_count, latest_thread_event, edit = summary
|
||||
results.setdefault(
|
||||
event_id, BundledAggregations()
|
||||
).thread = _ThreadAggregation(
|
||||
latest_event=latest_thread_event,
|
||||
latest_edit=edit,
|
||||
count=thread_count,
|
||||
# If there's a thread summary it must also exist in the
|
||||
# participated dictionary.
|
||||
current_user_participated=participated[event_id],
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class RelationsStore(RelationsWorkerStore):
|
||||
pass
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue