Improvements to bundling aggregations. (#11815)

This is some odds and ends found during the review of #11791
and while continuing to work in this code:

* Return attrs classes instead of dictionaries from some methods
  to improve type safety.
* Call `get_bundled_aggregations` fewer times.
* Adds a missing assertion in the tests.
* Do not return empty bundled aggregations for an event (preferring
  to not include the bundle at all, as the docstring states).
This commit is contained in:
Patrick Cloke 2022-01-26 08:27:04 -05:00 committed by GitHub
parent d8df8e6c14
commit 2897fb6b4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 212 additions and 139 deletions

1
changelog.d/11815.misc Normal file
View File

@ -0,0 +1 @@
Improve type safety of bundled aggregations code.

View File

@ -14,7 +14,17 @@
# limitations under the License. # limitations under the License.
import collections.abc import collections.abc
import re import re
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Union,
)
from frozendict import frozendict from frozendict import frozendict
@ -26,6 +36,10 @@ from synapse.util.frozenutils import unfreeze
from . import EventBase from . import EventBase
if TYPE_CHECKING:
from synapse.storage.databases.main.relations import BundledAggregations
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
# (?<!stuff) matches if the current position in the string is not preceded # (?<!stuff) matches if the current position in the string is not preceded
# by a match for 'stuff'. # by a match for 'stuff'.
@ -376,7 +390,7 @@ class EventClientSerializer:
event: Union[JsonDict, EventBase], event: Union[JsonDict, EventBase],
time_now: int, time_now: int,
*, *,
bundle_aggregations: Optional[Dict[str, JsonDict]] = None, bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None,
**kwargs: Any, **kwargs: Any,
) -> JsonDict: ) -> JsonDict:
"""Serializes a single event. """Serializes a single event.
@ -415,7 +429,7 @@ class EventClientSerializer:
self, self,
event: EventBase, event: EventBase,
time_now: int, time_now: int,
aggregations: JsonDict, aggregations: "BundledAggregations",
serialized_event: JsonDict, serialized_event: JsonDict,
) -> None: ) -> None:
"""Potentially injects bundled aggregations into the unsigned portion of the serialized event. """Potentially injects bundled aggregations into the unsigned portion of the serialized event.
@ -427,13 +441,18 @@ class EventClientSerializer:
serialized_event: The serialized event which may be modified. serialized_event: The serialized event which may be modified.
""" """
# Make a copy in-case the object is cached. serialized_aggregations = {}
aggregations = aggregations.copy()
if RelationTypes.REPLACE in aggregations: if aggregations.annotations:
serialized_aggregations[RelationTypes.ANNOTATION] = aggregations.annotations
if aggregations.references:
serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references
if aggregations.replace:
# If there is an edit replace the content, preserving existing # If there is an edit replace the content, preserving existing
# relations. # relations.
edit = aggregations[RelationTypes.REPLACE] edit = aggregations.replace
# Ensure we take copies of the edit content, otherwise we risk modifying # Ensure we take copies of the edit content, otherwise we risk modifying
# the original event. # the original event.
@ -451,24 +470,28 @@ class EventClientSerializer:
else: else:
serialized_event["content"].pop("m.relates_to", None) serialized_event["content"].pop("m.relates_to", None)
aggregations[RelationTypes.REPLACE] = { serialized_aggregations[RelationTypes.REPLACE] = {
"event_id": edit.event_id, "event_id": edit.event_id,
"origin_server_ts": edit.origin_server_ts, "origin_server_ts": edit.origin_server_ts,
"sender": edit.sender, "sender": edit.sender,
} }
# If this event is the start of a thread, include a summary of the replies. # If this event is the start of a thread, include a summary of the replies.
if RelationTypes.THREAD in aggregations: if aggregations.thread:
# Serialize the latest thread event. serialized_aggregations[RelationTypes.THREAD] = {
latest_thread_event = aggregations[RelationTypes.THREAD]["latest_event"]
# Don't bundle aggregations as this could recurse forever. # Don't bundle aggregations as this could recurse forever.
aggregations[RelationTypes.THREAD]["latest_event"] = self.serialize_event( "latest_event": self.serialize_event(
latest_thread_event, time_now, bundle_aggregations=None aggregations.thread.latest_event, time_now, bundle_aggregations=None
) ),
"count": aggregations.thread.count,
"current_user_participated": aggregations.thread.current_user_participated,
}
# Include the bundled aggregations in the event. # Include the bundled aggregations in the event.
serialized_event["unsigned"].setdefault("m.relations", {}).update(aggregations) if serialized_aggregations:
serialized_event["unsigned"].setdefault("m.relations", {}).update(
serialized_aggregations
)
def serialize_events( def serialize_events(
self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any

View File

@ -30,6 +30,7 @@ from typing import (
Tuple, Tuple,
) )
import attr
from typing_extensions import TypedDict from typing_extensions import TypedDict
from synapse.api.constants import ( from synapse.api.constants import (
@ -60,6 +61,7 @@ from synapse.events.utils import copy_power_levels_contents
from synapse.federation.federation_client import InvalidResponseError from synapse.federation.federation_client import InvalidResponseError
from synapse.handlers.federation import get_domains_from_state from synapse.handlers.federation import get_domains_from_state
from synapse.rest.admin._base import assert_user_is_admin from synapse.rest.admin._base import assert_user_is_admin
from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.streams import EventSource from synapse.streams import EventSource
from synapse.types import ( from synapse.types import (
@ -90,6 +92,17 @@ id_server_scheme = "https://"
FIVE_MINUTES_IN_MS = 5 * 60 * 1000 FIVE_MINUTES_IN_MS = 5 * 60 * 1000
@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventContext:
events_before: List[EventBase]
event: EventBase
events_after: List[EventBase]
state: List[EventBase]
aggregations: Dict[str, BundledAggregations]
start: str
end: str
class RoomCreationHandler: class RoomCreationHandler:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -1119,7 +1132,7 @@ class RoomContextHandler:
limit: int, limit: int,
event_filter: Optional[Filter], event_filter: Optional[Filter],
use_admin_priviledge: bool = False, use_admin_priviledge: bool = False,
) -> Optional[JsonDict]: ) -> Optional[EventContext]:
"""Retrieves events, pagination tokens and state around a given event """Retrieves events, pagination tokens and state around a given event
in a room. in a room.
@ -1167,38 +1180,28 @@ class RoomContextHandler:
results = await self.store.get_events_around( results = await self.store.get_events_around(
room_id, event_id, before_limit, after_limit, event_filter room_id, event_id, before_limit, after_limit, event_filter
) )
events_before = results.events_before
events_after = results.events_after
if event_filter: if event_filter:
results["events_before"] = await event_filter.filter( events_before = await event_filter.filter(events_before)
results["events_before"] events_after = await event_filter.filter(events_after)
)
results["events_after"] = await event_filter.filter(results["events_after"])
results["events_before"] = await filter_evts(results["events_before"]) events_before = await filter_evts(events_before)
results["events_after"] = await filter_evts(results["events_after"]) events_after = await filter_evts(events_after)
# filter_evts can return a pruned event in case the user is allowed to see that # filter_evts can return a pruned event in case the user is allowed to see that
# there's something there but not see the content, so use the event that's in # there's something there but not see the content, so use the event that's in
# `filtered` rather than the event we retrieved from the datastore. # `filtered` rather than the event we retrieved from the datastore.
results["event"] = filtered[0] event = filtered[0]
# Fetch the aggregations. # Fetch the aggregations.
aggregations = await self.store.get_bundled_aggregations( aggregations = await self.store.get_bundled_aggregations(
[results["event"]], user.to_string() itertools.chain(events_before, (event,), events_after),
user.to_string(),
) )
aggregations.update(
await self.store.get_bundled_aggregations(
results["events_before"], user.to_string()
)
)
aggregations.update(
await self.store.get_bundled_aggregations(
results["events_after"], user.to_string()
)
)
results["aggregations"] = aggregations
if results["events_after"]: if events_after:
last_event_id = results["events_after"][-1].event_id last_event_id = events_after[-1].event_id
else: else:
last_event_id = event_id last_event_id = event_id
@ -1206,9 +1209,9 @@ class RoomContextHandler:
state_filter = StateFilter.from_lazy_load_member_list( state_filter = StateFilter.from_lazy_load_member_list(
ev.sender ev.sender
for ev in itertools.chain( for ev in itertools.chain(
results["events_before"], events_before,
(results["event"],), (event,),
results["events_after"], events_after,
) )
) )
else: else:
@ -1226,21 +1229,23 @@ class RoomContextHandler:
if event_filter: if event_filter:
state_events = await event_filter.filter(state_events) state_events = await event_filter.filter(state_events)
results["state"] = await filter_evts(state_events)
# We use a dummy token here as we only care about the room portion of # We use a dummy token here as we only care about the room portion of
# the token, which we replace. # the token, which we replace.
token = StreamToken.START token = StreamToken.START
results["start"] = await token.copy_and_replace( return EventContext(
"room_key", results["start"] events_before=events_before,
).to_string(self.store) event=event,
events_after=events_after,
results["end"] = await token.copy_and_replace( state=await filter_evts(state_events),
"room_key", results["end"] aggregations=aggregations,
).to_string(self.store) start=await token.copy_and_replace("room_key", results.start).to_string(
self.store
return results ),
end=await token.copy_and_replace("room_key", results.end).to_string(
self.store
),
)
class TimestampLookupHandler: class TimestampLookupHandler:

View File

@ -361,36 +361,37 @@ class SearchHandler:
logger.info( logger.info(
"Context for search returned %d and %d events", "Context for search returned %d and %d events",
len(res["events_before"]), len(res.events_before),
len(res["events_after"]), len(res.events_after),
) )
res["events_before"] = await filter_events_for_client( events_before = await filter_events_for_client(
self.storage, user.to_string(), res["events_before"] self.storage, user.to_string(), res.events_before
) )
res["events_after"] = await filter_events_for_client( events_after = await filter_events_for_client(
self.storage, user.to_string(), res["events_after"] self.storage, user.to_string(), res.events_after
) )
res["start"] = await now_token.copy_and_replace( context = {
"room_key", res["start"] "events_before": events_before,
).to_string(self.store) "events_after": events_after,
"start": await now_token.copy_and_replace(
res["end"] = await now_token.copy_and_replace( "room_key", res.start
"room_key", res["end"] ).to_string(self.store),
).to_string(self.store) "end": await now_token.copy_and_replace(
"room_key", res.end
).to_string(self.store),
}
if include_profile: if include_profile:
senders = { senders = {
ev.sender ev.sender
for ev in itertools.chain( for ev in itertools.chain(events_before, [event], events_after)
res["events_before"], [event], res["events_after"]
)
} }
if res["events_after"]: if events_after:
last_event_id = res["events_after"][-1].event_id last_event_id = events_after[-1].event_id
else: else:
last_event_id = event.event_id last_event_id = event.event_id
@ -402,7 +403,7 @@ class SearchHandler:
last_event_id, state_filter last_event_id, state_filter
) )
res["profile_info"] = { context["profile_info"] = {
s.state_key: { s.state_key: {
"displayname": s.content.get("displayname", None), "displayname": s.content.get("displayname", None),
"avatar_url": s.content.get("avatar_url", None), "avatar_url": s.content.get("avatar_url", None),
@ -411,7 +412,7 @@ class SearchHandler:
if s.type == EventTypes.Member and s.state_key in senders if s.type == EventTypes.Member and s.state_key in senders
} }
contexts[event.event_id] = res contexts[event.event_id] = context
else: else:
contexts = {} contexts = {}
@ -421,10 +422,10 @@ class SearchHandler:
for context in contexts.values(): for context in contexts.values():
context["events_before"] = self._event_serializer.serialize_events( context["events_before"] = self._event_serializer.serialize_events(
context["events_before"], time_now context["events_before"], time_now # type: ignore[arg-type]
) )
context["events_after"] = self._event_serializer.serialize_events( context["events_after"] = self._event_serializer.serialize_events(
context["events_after"], time_now context["events_after"], time_now # type: ignore[arg-type]
) )
state_results = {} state_results = {}

View File

@ -37,6 +37,7 @@ from synapse.logging.context import current_context
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
from synapse.push.clientformat import format_push_rules_for_user from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.databases.main.event_push_actions import NotifCounts from synapse.storage.databases.main.event_push_actions import NotifCounts
from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.roommember import MemberSummary from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import ( from synapse.types import (
@ -100,7 +101,7 @@ class TimelineBatch:
limited: bool limited: bool
# A mapping of event ID to the bundled aggregations for the above events. # A mapping of event ID to the bundled aggregations for the above events.
# This is only calculated if limited is true. # This is only calculated if limited is true.
bundled_aggregations: Optional[Dict[str, Dict[str, Any]]] = None bundled_aggregations: Optional[Dict[str, BundledAggregations]] = None
def __bool__(self) -> bool: def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used """Make the result appear empty if there are no updates. This is used

View File

@ -455,7 +455,7 @@ class Mailer:
} }
the_events = await filter_events_for_client( the_events = await filter_events_for_client(
self.storage, user_id, results["events_before"] self.storage, user_id, results.events_before
) )
the_events.append(notif_event) the_events.append(notif_event)

View File

@ -729,7 +729,7 @@ class RoomEventContextServlet(RestServlet):
else: else:
event_filter = None event_filter = None
results = await self.room_context_handler.get_event_context( event_context = await self.room_context_handler.get_event_context(
requester, requester,
room_id, room_id,
event_id, event_id,
@ -738,25 +738,34 @@ class RoomEventContextServlet(RestServlet):
use_admin_priviledge=True, use_admin_priviledge=True,
) )
if not results: if not event_context:
raise SynapseError( raise SynapseError(
HTTPStatus.NOT_FOUND, "Event not found.", errcode=Codes.NOT_FOUND HTTPStatus.NOT_FOUND, "Event not found.", errcode=Codes.NOT_FOUND
) )
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
aggregations = results.pop("aggregations", None) results = {
results["events_before"] = self._event_serializer.serialize_events( "events_before": self._event_serializer.serialize_events(
results["events_before"], time_now, bundle_aggregations=aggregations event_context.events_before,
) time_now,
results["event"] = self._event_serializer.serialize_event( bundle_aggregations=event_context.aggregations,
results["event"], time_now, bundle_aggregations=aggregations ),
) "event": self._event_serializer.serialize_event(
results["events_after"] = self._event_serializer.serialize_events( event_context.event,
results["events_after"], time_now, bundle_aggregations=aggregations time_now,
) bundle_aggregations=event_context.aggregations,
results["state"] = self._event_serializer.serialize_events( ),
results["state"], time_now "events_after": self._event_serializer.serialize_events(
) event_context.events_after,
time_now,
bundle_aggregations=event_context.aggregations,
),
"state": self._event_serializer.serialize_events(
event_context.state, time_now
),
"start": event_context.start,
"end": event_context.end,
}
return HTTPStatus.OK, results return HTTPStatus.OK, results

View File

@ -706,27 +706,36 @@ class RoomEventContextServlet(RestServlet):
else: else:
event_filter = None event_filter = None
results = await self.room_context_handler.get_event_context( event_context = await self.room_context_handler.get_event_context(
requester, room_id, event_id, limit, event_filter requester, room_id, event_id, limit, event_filter
) )
if not results: if not event_context:
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
aggregations = results.pop("aggregations", None) results = {
results["events_before"] = self._event_serializer.serialize_events( "events_before": self._event_serializer.serialize_events(
results["events_before"], time_now, bundle_aggregations=aggregations event_context.events_before,
) time_now,
results["event"] = self._event_serializer.serialize_event( bundle_aggregations=event_context.aggregations,
results["event"], time_now, bundle_aggregations=aggregations ),
) "event": self._event_serializer.serialize_event(
results["events_after"] = self._event_serializer.serialize_events( event_context.event,
results["events_after"], time_now, bundle_aggregations=aggregations time_now,
) bundle_aggregations=event_context.aggregations,
results["state"] = self._event_serializer.serialize_events( ),
results["state"], time_now "events_after": self._event_serializer.serialize_events(
) event_context.events_after,
time_now,
bundle_aggregations=event_context.aggregations,
),
"state": self._event_serializer.serialize_events(
event_context.state, time_now
),
"start": event_context.start,
"end": event_context.end,
}
return 200, results return 200, results

View File

@ -48,6 +48,7 @@ from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import trace from synapse.logging.opentracing import trace
from synapse.storage.databases.main.relations import BundledAggregations
from synapse.types import JsonDict, StreamToken from synapse.types import JsonDict, StreamToken
from synapse.util import json_decoder from synapse.util import json_decoder
@ -526,7 +527,7 @@ class SyncRestServlet(RestServlet):
def serialize( def serialize(
events: Iterable[EventBase], events: Iterable[EventBase],
aggregations: Optional[Dict[str, Dict[str, Any]]] = None, aggregations: Optional[Dict[str, BundledAggregations]] = None,
) -> List[JsonDict]: ) -> List[JsonDict]:
return self._event_serializer.serialize_events( return self._event_serializer.serialize_events(
events, events,

View File

@ -13,17 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import ( from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union, cast
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
Optional,
Tuple,
Union,
cast,
)
import attr import attr
from frozendict import frozendict from frozendict import frozendict
@ -43,6 +33,7 @@ from synapse.storage.relations import (
PaginationChunk, PaginationChunk,
RelationPaginationToken, RelationPaginationToken,
) )
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
if TYPE_CHECKING: if TYPE_CHECKING:
@ -51,6 +42,30 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _ThreadAggregation:
latest_event: EventBase
count: int
current_user_participated: bool
@attr.s(slots=True, auto_attribs=True)
class BundledAggregations:
"""
The bundled aggregations for an event.
Some values require additional processing during serialization.
"""
annotations: Optional[JsonDict] = None
references: Optional[JsonDict] = None
replace: Optional[EventBase] = None
thread: Optional[_ThreadAggregation] = None
def __bool__(self) -> bool:
return bool(self.annotations or self.references or self.replace or self.thread)
class RelationsWorkerStore(SQLBaseStore): class RelationsWorkerStore(SQLBaseStore):
def __init__( def __init__(
self, self,
@ -585,7 +600,7 @@ class RelationsWorkerStore(SQLBaseStore):
async def _get_bundled_aggregation_for_event( async def _get_bundled_aggregation_for_event(
self, event: EventBase, user_id: str self, event: EventBase, user_id: str
) -> Optional[Dict[str, Any]]: ) -> Optional[BundledAggregations]:
"""Generate bundled aggregations for an event. """Generate bundled aggregations for an event.
Note that this does not use a cache, but depends on cached methods. Note that this does not use a cache, but depends on cached methods.
@ -616,24 +631,24 @@ class RelationsWorkerStore(SQLBaseStore):
# The bundled aggregations to include, a mapping of relation type to a # The bundled aggregations to include, a mapping of relation type to a
# type-specific value. Some types include the direct return type here # type-specific value. Some types include the direct return type here
# while others need more processing during serialization. # while others need more processing during serialization.
aggregations: Dict[str, Any] = {} aggregations = BundledAggregations()
annotations = await self.get_aggregation_groups_for_event(event_id, room_id) annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
if annotations.chunk: if annotations.chunk:
aggregations[RelationTypes.ANNOTATION] = annotations.to_dict() aggregations.annotations = annotations.to_dict()
references = await self.get_relations_for_event( references = await self.get_relations_for_event(
event_id, room_id, RelationTypes.REFERENCE, direction="f" event_id, room_id, RelationTypes.REFERENCE, direction="f"
) )
if references.chunk: if references.chunk:
aggregations[RelationTypes.REFERENCE] = references.to_dict() aggregations.references = references.to_dict()
edit = None edit = None
if event.type == EventTypes.Message: if event.type == EventTypes.Message:
edit = await self.get_applicable_edit(event_id, room_id) edit = await self.get_applicable_edit(event_id, room_id)
if edit: if edit:
aggregations[RelationTypes.REPLACE] = edit aggregations.replace = edit
# If this event is the start of a thread, include a summary of the replies. # If this event is the start of a thread, include a summary of the replies.
if self._msc3440_enabled: if self._msc3440_enabled:
@ -644,11 +659,11 @@ class RelationsWorkerStore(SQLBaseStore):
event_id, room_id, user_id event_id, room_id, user_id
) )
if latest_thread_event: if latest_thread_event:
aggregations[RelationTypes.THREAD] = { aggregations.thread = _ThreadAggregation(
"latest_event": latest_thread_event, latest_event=latest_thread_event,
"count": thread_count, count=thread_count,
"current_user_participated": participated, current_user_participated=participated,
} )
# Store the bundled aggregations in the event metadata for later use. # Store the bundled aggregations in the event metadata for later use.
return aggregations return aggregations
@ -657,7 +672,7 @@ class RelationsWorkerStore(SQLBaseStore):
self, self,
events: Iterable[EventBase], events: Iterable[EventBase],
user_id: str, user_id: str,
) -> Dict[str, Dict[str, Any]]: ) -> Dict[str, BundledAggregations]:
"""Generate bundled aggregations for events. """Generate bundled aggregations for events.
Args: Args:
@ -676,7 +691,7 @@ class RelationsWorkerStore(SQLBaseStore):
results = {} results = {}
for event in events: for event in events:
event_result = await self._get_bundled_aggregation_for_event(event, user_id) event_result = await self._get_bundled_aggregation_for_event(event, user_id)
if event_result is not None: if event_result:
results[event.event_id] = event_result results[event.event_id] = event_result
return results return results

View File

@ -81,6 +81,14 @@ class _EventDictReturn:
stream_ordering: int stream_ordering: int
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _EventsAround:
events_before: List[EventBase]
events_after: List[EventBase]
start: RoomStreamToken
end: RoomStreamToken
def generate_pagination_where_clause( def generate_pagination_where_clause(
direction: str, direction: str,
column_names: Tuple[str, str], column_names: Tuple[str, str],
@ -846,7 +854,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
before_limit: int, before_limit: int,
after_limit: int, after_limit: int,
event_filter: Optional[Filter] = None, event_filter: Optional[Filter] = None,
) -> dict: ) -> _EventsAround:
"""Retrieve events and pagination tokens around a given event in a """Retrieve events and pagination tokens around a given event in a
room. room.
""" """
@ -869,12 +877,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
list(results["after"]["event_ids"]), get_prev_content=True list(results["after"]["event_ids"]), get_prev_content=True
) )
return { return _EventsAround(
"events_before": events_before, events_before=events_before,
"events_after": events_after, events_after=events_after,
"start": results["before"]["token"], start=results["before"]["token"],
"end": results["after"]["token"], end=results["after"]["token"],
} )
def _get_events_around_txn( def _get_events_around_txn(
self, self,

View File

@ -577,7 +577,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body) self.assertEquals(200, channel.code, channel.json_body)
room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
self.assertTrue(room_timeline["limited"]) self.assertTrue(room_timeline["limited"])
self._find_event_in_chunk(room_timeline["events"]) assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
def test_aggregation_get_event_for_annotation(self): def test_aggregation_get_event_for_annotation(self):
"""Test that annotations do not get bundled aggregations included """Test that annotations do not get bundled aggregations included