Better return type for get_all_entities_changed (#14604)

Help callers from using the return value incorrectly by ensuring
that callers explicitly check if there was a cache hit or not.
This commit is contained in:
Erik Johnston 2022-12-05 20:19:14 +00:00 committed by GitHub
parent 6a8310f3df
commit cee9445884
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 138 additions and 76 deletions

View file

@ -16,6 +16,7 @@ import logging
import math
from typing import Collection, Dict, FrozenSet, List, Mapping, Optional, Set, Union
import attr
from sortedcontainers import SortedDict
from synapse.util import caches
@ -26,6 +27,29 @@ logger = logging.getLogger(__name__)
EntityType = str
@attr.s(auto_attribs=True, frozen=True, slots=True)
class AllEntitiesChangedResult:
"""Return type of `get_all_entities_changed`.
Callers must check that there was a cache hit, via `result.hit`, before
using the entities in `result.entities`.
This specifically does *not* implement helpers such as `__bool__` to ensure
that callers do the correct checks.
"""
_entities: Optional[List[EntityType]]
@property
def hit(self) -> bool:
return self._entities is not None
@property
def entities(self) -> List[EntityType]:
assert self._entities is not None
return self._entities
class StreamChangeCache:
"""
Keeps track of the stream positions of the latest change in a set of entities.
@ -153,19 +177,19 @@ class StreamChangeCache:
This will be all entities if the given stream position is at or earlier
than the earliest known stream position.
"""
changed_entities = self.get_all_entities_changed(stream_pos)
if changed_entities is not None:
cache_result = self.get_all_entities_changed(stream_pos)
if cache_result.hit:
# We now do an intersection, trying to do so in the most efficient
# way possible (some of these sets are *large*). First check in the
# given iterable is already a set that we can reuse, otherwise we
# create a set of the *smallest* of the two iterables and call
# `intersection(..)` on it (this can be twice as fast as the reverse).
if isinstance(entities, (set, frozenset)):
result = entities.intersection(changed_entities)
elif len(changed_entities) < len(entities):
result = set(changed_entities).intersection(entities)
result = entities.intersection(cache_result.entities)
elif len(cache_result.entities) < len(entities):
result = set(cache_result.entities).intersection(entities)
else:
result = set(entities).intersection(changed_entities)
result = set(entities).intersection(cache_result.entities)
self.metrics.inc_hits()
else:
result = set(entities)
@ -202,12 +226,12 @@ class StreamChangeCache:
self.metrics.inc_hits()
return stream_pos < self._cache.peekitem()[0]
def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]:
def get_all_entities_changed(self, stream_pos: int) -> AllEntitiesChangedResult:
"""
Returns all entities that have had changes after the given position.
If the stream change cache does not go far enough back, i.e. the position
is too old, it will return None.
If the stream change cache does not go far enough back, i.e. the
position is too old, it will return None.
Returns the entities in the order that they were changed.
@ -215,23 +239,21 @@ class StreamChangeCache:
stream_pos: The stream position to check for changes after.
Return:
Entities which have changed after the given stream position.
None if the given stream position is at or earlier than the earliest
known stream position.
A class indicating if we have the requested data cached, and if so
includes the entities in the order they were changed.
"""
assert isinstance(stream_pos, int)
# _cache is not valid at or before the earliest known stream position, so
# return None to mark that it is unknown if an entity has changed.
if stream_pos <= self._earliest_known_stream_pos:
return None
return AllEntitiesChangedResult(None)
changed_entities: List[EntityType] = []
for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)):
changed_entities.extend(self._cache[k])
return changed_entities
return AllEntitiesChangedResult(changed_entities)
def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None:
"""