Improvements to bundling aggregations. (#11815)

This is some odds and ends found during the review of #11791
and while continuing to work in this code:

* Return attrs classes instead of dictionaries from some methods
  to improve type safety.
* Call `get_bundled_aggregations` fewer times.
* Adds a missing assertion in the tests.
* Do not return empty bundled aggregations for an event (preferring
  to not include the bundle at all, as the docstring states).
This commit is contained in:
Patrick Cloke 2022-01-26 08:27:04 -05:00 committed by GitHub
parent d8df8e6c14
commit 2897fb6b4f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 212 additions and 139 deletions

View file

@ -13,17 +13,7 @@
# limitations under the License.
import logging
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
Optional,
Tuple,
Union,
cast,
)
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union, cast
import attr
from frozendict import frozendict
@ -43,6 +33,7 @@ from synapse.storage.relations import (
PaginationChunk,
RelationPaginationToken,
)
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
@ -51,6 +42,30 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _ThreadAggregation:
latest_event: EventBase
count: int
current_user_participated: bool
@attr.s(slots=True, auto_attribs=True)
class BundledAggregations:
"""
The bundled aggregations for an event.
Some values require additional processing during serialization.
"""
annotations: Optional[JsonDict] = None
references: Optional[JsonDict] = None
replace: Optional[EventBase] = None
thread: Optional[_ThreadAggregation] = None
def __bool__(self) -> bool:
return bool(self.annotations or self.references or self.replace or self.thread)
class RelationsWorkerStore(SQLBaseStore):
def __init__(
self,
@ -585,7 +600,7 @@ class RelationsWorkerStore(SQLBaseStore):
async def _get_bundled_aggregation_for_event(
self, event: EventBase, user_id: str
) -> Optional[Dict[str, Any]]:
) -> Optional[BundledAggregations]:
"""Generate bundled aggregations for an event.
Note that this does not use a cache, but depends on cached methods.
@ -616,24 +631,24 @@ class RelationsWorkerStore(SQLBaseStore):
# The bundled aggregations to include, a mapping of relation type to a
# type-specific value. Some types include the direct return type here
# while others need more processing during serialization.
aggregations: Dict[str, Any] = {}
aggregations = BundledAggregations()
annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
if annotations.chunk:
aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
aggregations.annotations = annotations.to_dict()
references = await self.get_relations_for_event(
event_id, room_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
aggregations[RelationTypes.REFERENCE] = references.to_dict()
aggregations.references = references.to_dict()
edit = None
if event.type == EventTypes.Message:
edit = await self.get_applicable_edit(event_id, room_id)
if edit:
aggregations[RelationTypes.REPLACE] = edit
aggregations.replace = edit
# If this event is the start of a thread, include a summary of the replies.
if self._msc3440_enabled:
@ -644,11 +659,11 @@ class RelationsWorkerStore(SQLBaseStore):
event_id, room_id, user_id
)
if latest_thread_event:
aggregations[RelationTypes.THREAD] = {
"latest_event": latest_thread_event,
"count": thread_count,
"current_user_participated": participated,
}
aggregations.thread = _ThreadAggregation(
latest_event=latest_thread_event,
count=thread_count,
current_user_participated=participated,
)
# Store the bundled aggregations in the event metadata for later use.
return aggregations
@ -657,7 +672,7 @@ class RelationsWorkerStore(SQLBaseStore):
self,
events: Iterable[EventBase],
user_id: str,
) -> Dict[str, Dict[str, Any]]:
) -> Dict[str, BundledAggregations]:
"""Generate bundled aggregations for events.
Args:
@ -676,7 +691,7 @@ class RelationsWorkerStore(SQLBaseStore):
results = {}
for event in events:
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
if event_result is not None:
if event_result:
results[event.event_id] = event_result
return results