Move get_bundled_aggregations to relations handler. (#12237)

The get_bundled_aggregations code is fairly high-level and uses
a lot of store methods, we move it into the handler as that seems
like a better fit.
This commit is contained in:
Patrick Cloke 2022-03-18 13:49:32 -04:00 committed by GitHub
parent 80e0e1f35e
commit 8fe930c215
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 173 additions and 157 deletions

1
changelog.d/12237.misc Normal file
View File

@ -0,0 +1 @@
Refactor the relations endpoints to add a `RelationsHandler`.

View File

@ -38,8 +38,8 @@ from synapse.util.frozenutils import unfreeze
from . import EventBase from . import EventBase
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.handlers.relations import BundledAggregations
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.databases.main.relations import BundledAggregations
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'

View File

@ -134,6 +134,7 @@ class PaginationHandler:
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._server_name = hs.hostname self._server_name = hs.hostname
self._room_shutdown_handler = hs.get_room_shutdown_handler() self._room_shutdown_handler = hs.get_room_shutdown_handler()
self._relations_handler = hs.get_relations_handler()
self.pagination_lock = ReadWriteLock() self.pagination_lock = ReadWriteLock()
# IDs of rooms in which there currently an active purge *or delete* operation. # IDs of rooms in which there currently an active purge *or delete* operation.
@ -539,7 +540,9 @@ class PaginationHandler:
state_dict = await self.store.get_events(list(state_ids.values())) state_dict = await self.store.get_events(list(state_ids.values()))
state = state_dict.values() state = state_dict.values()
aggregations = await self.store.get_bundled_aggregations(events, user_id) aggregations = await self._relations_handler.get_bundled_aggregations(
events, user_id
)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()

View File

@ -12,18 +12,53 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast
import attr
from frozendict import frozendict
from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.types import JsonDict, Requester, StreamToken from synapse.types import JsonDict, Requester, StreamToken
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _ThreadAggregation:
# The latest event in the thread.
latest_event: EventBase
# The latest edit to the latest event in the thread.
latest_edit: Optional[EventBase]
# The total number of events in the thread.
count: int
# True if the current user has sent an event to the thread.
current_user_participated: bool
@attr.s(slots=True, auto_attribs=True)
class BundledAggregations:
"""
The bundled aggregations for an event.
Some values require additional processing during serialization.
"""
annotations: Optional[JsonDict] = None
references: Optional[JsonDict] = None
replace: Optional[EventBase] = None
thread: Optional[_ThreadAggregation] = None
def __bool__(self) -> bool:
return bool(self.annotations or self.references or self.replace or self.thread)
class RelationsHandler: class RelationsHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self._main_store = hs.get_datastores().main self._main_store = hs.get_datastores().main
@ -103,7 +138,7 @@ class RelationsHandler:
) )
# The relations returned for the requested event do include their # The relations returned for the requested event do include their
# bundled aggregations. # bundled aggregations.
aggregations = await self._main_store.get_bundled_aggregations( aggregations = await self.get_bundled_aggregations(
events, requester.user.to_string() events, requester.user.to_string()
) )
serialized_events = self._event_serializer.serialize_events( serialized_events = self._event_serializer.serialize_events(
@ -115,3 +150,115 @@ class RelationsHandler:
return_value["original_event"] = original_event return_value["original_event"] = original_event
return return_value return return_value
async def _get_bundled_aggregation_for_event(
self, event: EventBase, user_id: str
) -> Optional[BundledAggregations]:
"""Generate bundled aggregations for an event.
Note that this does not use a cache, but depends on cached methods.
Args:
event: The event to calculate bundled aggregations for.
user_id: The user requesting the bundled aggregations.
Returns:
The bundled aggregations for an event, if bundled aggregations are
enabled and the event can have bundled aggregations.
"""
# Do not bundle aggregations for an event which represents an edit or an
# annotation. It does not make sense for them to have related events.
relates_to = event.content.get("m.relates_to")
if isinstance(relates_to, (dict, frozendict)):
relation_type = relates_to.get("rel_type")
if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
return None
event_id = event.event_id
room_id = event.room_id
# The bundled aggregations to include, a mapping of relation type to a
# type-specific value. Some types include the direct return type here
# while others need more processing during serialization.
aggregations = BundledAggregations()
annotations = await self._main_store.get_aggregation_groups_for_event(
event_id, room_id
)
if annotations.chunk:
aggregations.annotations = await annotations.to_dict(
cast("DataStore", self)
)
references = await self._main_store.get_relations_for_event(
event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
aggregations.references = await references.to_dict(cast("DataStore", self))
# Store the bundled aggregations in the event metadata for later use.
return aggregations
async def get_bundled_aggregations(
self, events: Iterable[EventBase], user_id: str
) -> Dict[str, BundledAggregations]:
"""Generate bundled aggregations for events.
Args:
events: The iterable of events to calculate bundled aggregations for.
user_id: The user requesting the bundled aggregations.
Returns:
A map of event ID to the bundled aggregation for the event. Not all
events may have bundled aggregations in the results.
"""
# De-duplicate events by ID to handle the same event requested multiple times.
#
# State events do not get bundled aggregations.
events_by_id = {
event.event_id: event for event in events if not event.is_state()
}
# event ID -> bundled aggregation in non-serialized form.
results: Dict[str, BundledAggregations] = {}
# Fetch other relations per event.
for event in events_by_id.values():
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
if event_result:
results[event.event_id] = event_result
# Fetch any edits (but not for redacted events).
edits = await self._main_store.get_applicable_edits(
[
event_id
for event_id, event in events_by_id.items()
if not event.internal_metadata.is_redacted()
]
)
for event_id, edit in edits.items():
results.setdefault(event_id, BundledAggregations()).replace = edit
# Fetch thread summaries.
summaries = await self._main_store.get_thread_summaries(events_by_id.keys())
# Only fetch participated for a limited selection based on what had
# summaries.
participated = await self._main_store.get_threads_participated(
[event_id for event_id, summary in summaries.items() if summary], user_id
)
for event_id, summary in summaries.items():
if summary:
thread_count, latest_thread_event, edit = summary
results.setdefault(
event_id, BundledAggregations()
).thread = _ThreadAggregation(
latest_event=latest_thread_event,
latest_edit=edit,
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
current_user_participated=participated[event_id],
)
return results

View File

@ -60,8 +60,8 @@ from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents from synapse.events.utils import copy_power_levels_contents
from synapse.federation.federation_client import InvalidResponseError from synapse.federation.federation_client import InvalidResponseError
from synapse.handlers.federation import get_domains_from_state from synapse.handlers.federation import get_domains_from_state
from synapse.handlers.relations import BundledAggregations
from synapse.rest.admin._base import assert_user_is_admin from synapse.rest.admin._base import assert_user_is_admin
from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.streams import EventSource from synapse.streams import EventSource
from synapse.types import ( from synapse.types import (
@ -1118,6 +1118,7 @@ class RoomContextHandler:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.storage = hs.get_storage() self.storage = hs.get_storage()
self.state_store = self.storage.state self.state_store = self.storage.state
self._relations_handler = hs.get_relations_handler()
async def get_event_context( async def get_event_context(
self, self,
@ -1190,7 +1191,7 @@ class RoomContextHandler:
event = filtered[0] event = filtered[0]
# Fetch the aggregations. # Fetch the aggregations.
aggregations = await self.store.get_bundled_aggregations( aggregations = await self._relations_handler.get_bundled_aggregations(
itertools.chain(events_before, (event,), events_after), itertools.chain(events_before, (event,), events_after),
user.to_string(), user.to_string(),
) )

View File

@ -54,6 +54,7 @@ class SearchHandler:
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.hs = hs self.hs = hs
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self._relations_handler = hs.get_relations_handler()
self.storage = hs.get_storage() self.storage = hs.get_storage()
self.state_store = self.storage.state self.state_store = self.storage.state
self.auth = hs.get_auth() self.auth = hs.get_auth()
@ -354,7 +355,7 @@ class SearchHandler:
aggregations = None aggregations = None
if self._msc3666_enabled: if self._msc3666_enabled:
aggregations = await self.store.get_bundled_aggregations( aggregations = await self._relations_handler.get_bundled_aggregations(
# Generate an iterable of EventBase for all the events that will be # Generate an iterable of EventBase for all the events that will be
# returned, including contextual events. # returned, including contextual events.
itertools.chain( itertools.chain(

View File

@ -33,11 +33,11 @@ from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase from synapse.events import EventBase
from synapse.handlers.relations import BundledAggregations
from synapse.logging.context import current_context from synapse.logging.context import current_context
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
from synapse.push.clientformat import format_push_rules_for_user from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.databases.main.event_push_actions import NotifCounts from synapse.storage.databases.main.event_push_actions import NotifCounts
from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.roommember import MemberSummary from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import ( from synapse.types import (
@ -269,6 +269,7 @@ class SyncHandler:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
self._relations_handler = hs.get_relations_handler()
self.event_sources = hs.get_event_sources() self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
@ -638,8 +639,10 @@ class SyncHandler:
# as clients will have all the necessary information. # as clients will have all the necessary information.
bundled_aggregations = None bundled_aggregations = None
if limited or newly_joined_room: if limited or newly_joined_room:
bundled_aggregations = await self.store.get_bundled_aggregations( bundled_aggregations = (
recents, sync_config.user.to_string() await self._relations_handler.get_bundled_aggregations(
recents, sync_config.user.to_string()
)
) )
return TimelineBatch( return TimelineBatch(

View File

@ -645,6 +645,7 @@ class RoomEventServlet(RestServlet):
self._store = hs.get_datastores().main self._store = hs.get_datastores().main
self.event_handler = hs.get_event_handler() self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self._relations_handler = hs.get_relations_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_GET( async def on_GET(
@ -663,7 +664,7 @@ class RoomEventServlet(RestServlet):
if event: if event:
# Ensure there are bundled aggregations available. # Ensure there are bundled aggregations available.
aggregations = await self._store.get_bundled_aggregations( aggregations = await self._relations_handler.get_bundled_aggregations(
[event], requester.user.to_string() [event], requester.user.to_string()
) )

View File

@ -27,7 +27,6 @@ from typing import (
) )
import attr import attr
from frozendict import frozendict
from synapse.api.constants import RelationTypes from synapse.api.constants import RelationTypes
from synapse.events import EventBase from synapse.events import EventBase
@ -41,45 +40,15 @@ from synapse.storage.database import (
from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.relations import AggregationPaginationToken, PaginationChunk from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
from synapse.types import JsonDict, RoomStreamToken, StreamToken from synapse.types import RoomStreamToken, StreamToken
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _ThreadAggregation:
# The latest event in the thread.
latest_event: EventBase
# The latest edit to the latest event in the thread.
latest_edit: Optional[EventBase]
# The total number of events in the thread.
count: int
# True if the current user has sent an event to the thread.
current_user_participated: bool
@attr.s(slots=True, auto_attribs=True)
class BundledAggregations:
"""
The bundled aggregations for an event.
Some values require additional processing during serialization.
"""
annotations: Optional[JsonDict] = None
references: Optional[JsonDict] = None
replace: Optional[EventBase] = None
thread: Optional[_ThreadAggregation] = None
def __bool__(self) -> bool:
return bool(self.annotations or self.references or self.replace or self.thread)
class RelationsWorkerStore(SQLBaseStore): class RelationsWorkerStore(SQLBaseStore):
def __init__( def __init__(
self, self,
@ -384,7 +353,7 @@ class RelationsWorkerStore(SQLBaseStore):
raise NotImplementedError() raise NotImplementedError()
@cachedList(cached_method_name="get_applicable_edit", list_name="event_ids") @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
async def _get_applicable_edits( async def get_applicable_edits(
self, event_ids: Collection[str] self, event_ids: Collection[str]
) -> Dict[str, Optional[EventBase]]: ) -> Dict[str, Optional[EventBase]]:
"""Get the most recent edit (if any) that has happened for the given """Get the most recent edit (if any) that has happened for the given
@ -473,7 +442,7 @@ class RelationsWorkerStore(SQLBaseStore):
raise NotImplementedError() raise NotImplementedError()
@cachedList(cached_method_name="get_thread_summary", list_name="event_ids") @cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
async def _get_thread_summaries( async def get_thread_summaries(
self, event_ids: Collection[str] self, event_ids: Collection[str]
) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]: ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
"""Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event. """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
@ -587,7 +556,7 @@ class RelationsWorkerStore(SQLBaseStore):
latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined] latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]
# Check to see if any of those events are edited. # Check to see if any of those events are edited.
latest_edits = await self._get_applicable_edits(latest_event_ids.values()) latest_edits = await self.get_applicable_edits(latest_event_ids.values())
# Map to the event IDs to the thread summary. # Map to the event IDs to the thread summary.
# #
@ -610,7 +579,7 @@ class RelationsWorkerStore(SQLBaseStore):
raise NotImplementedError() raise NotImplementedError()
@cachedList(cached_method_name="get_thread_participated", list_name="event_ids") @cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
async def _get_threads_participated( async def get_threads_participated(
self, event_ids: Collection[str], user_id: str self, event_ids: Collection[str], user_id: str
) -> Dict[str, bool]: ) -> Dict[str, bool]:
"""Get whether the requesting user participated in the given threads. """Get whether the requesting user participated in the given threads.
@ -766,116 +735,6 @@ class RelationsWorkerStore(SQLBaseStore):
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
) )
async def _get_bundled_aggregation_for_event(
self, event: EventBase, user_id: str
) -> Optional[BundledAggregations]:
"""Generate bundled aggregations for an event.
Note that this does not use a cache, but depends on cached methods.
Args:
event: The event to calculate bundled aggregations for.
user_id: The user requesting the bundled aggregations.
Returns:
The bundled aggregations for an event, if bundled aggregations are
enabled and the event can have bundled aggregations.
"""
# Do not bundle aggregations for an event which represents an edit or an
# annotation. It does not make sense for them to have related events.
relates_to = event.content.get("m.relates_to")
if isinstance(relates_to, (dict, frozendict)):
relation_type = relates_to.get("rel_type")
if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
return None
event_id = event.event_id
room_id = event.room_id
# The bundled aggregations to include, a mapping of relation type to a
# type-specific value. Some types include the direct return type here
# while others need more processing during serialization.
aggregations = BundledAggregations()
annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
if annotations.chunk:
aggregations.annotations = await annotations.to_dict(
cast("DataStore", self)
)
references = await self.get_relations_for_event(
event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
aggregations.references = await references.to_dict(cast("DataStore", self))
# Store the bundled aggregations in the event metadata for later use.
return aggregations
async def get_bundled_aggregations(
self, events: Iterable[EventBase], user_id: str
) -> Dict[str, BundledAggregations]:
"""Generate bundled aggregations for events.
Args:
events: The iterable of events to calculate bundled aggregations for.
user_id: The user requesting the bundled aggregations.
Returns:
A map of event ID to the bundled aggregation for the event. Not all
events may have bundled aggregations in the results.
"""
# De-duplicate events by ID to handle the same event requested multiple times.
#
# State events do not get bundled aggregations.
events_by_id = {
event.event_id: event for event in events if not event.is_state()
}
# event ID -> bundled aggregation in non-serialized form.
results: Dict[str, BundledAggregations] = {}
# Fetch other relations per event.
for event in events_by_id.values():
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
if event_result:
results[event.event_id] = event_result
# Fetch any edits (but not for redacted events).
edits = await self._get_applicable_edits(
[
event_id
for event_id, event in events_by_id.items()
if not event.internal_metadata.is_redacted()
]
)
for event_id, edit in edits.items():
results.setdefault(event_id, BundledAggregations()).replace = edit
# Fetch thread summaries.
summaries = await self._get_thread_summaries(events_by_id.keys())
# Only fetch participated for a limited selection based on what had
# summaries.
participated = await self._get_threads_participated(
[event_id for event_id, summary in summaries.items() if summary], user_id
)
for event_id, summary in summaries.items():
if summary:
thread_count, latest_thread_event, edit = summary
results.setdefault(
event_id, BundledAggregations()
).thread = _ThreadAggregation(
latest_event=latest_thread_event,
latest_edit=edit,
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
current_user_participated=participated[event_id],
)
return results
class RelationsStore(RelationsWorkerStore): class RelationsStore(RelationsWorkerStore):
pass pass