Speed up @cachedList (#13591)

This speeds things up by ~2x.

The vast majority of the time is now spent in `LruCache` moving things around the linked lists.

We do this via two things:
1. Don't create a deferred per-key during bulk set operations in `DeferredCache`. Instead, only create them if a subsequent caller asks for the key.
2. Add a bulk lookup API to `DeferredCache` rather than use a loop.
This commit is contained in:
Erik Johnston 2022-08-23 15:53:27 +01:00 committed by GitHub
parent 05c9c7363b
commit f7ddfe17a3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 298 additions and 141 deletions

View file

@ -14,15 +14,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import enum
import threading
from typing import (
Callable,
Collection,
Dict,
Generic,
Iterable,
MutableMapping,
Optional,
Set,
Sized,
Tuple,
TypeVar,
Union,
cast,
@ -31,7 +35,6 @@ from typing import (
from prometheus_client import Gauge
from twisted.internet import defer
from twisted.python import failure
from twisted.python.failure import Failure
from synapse.util.async_helpers import ObservableDeferred
@ -94,7 +97,7 @@ class DeferredCache(Generic[KT, VT]):
# _pending_deferred_cache maps from the key value to a `CacheEntry` object.
self._pending_deferred_cache: Union[
TreeCache, "MutableMapping[KT, CacheEntry]"
TreeCache, "MutableMapping[KT, CacheEntry[KT, VT]]"
] = cache_type()
def metrics_cb() -> None:
@ -159,15 +162,16 @@ class DeferredCache(Generic[KT, VT]):
Raises:
KeyError if the key is not found in the cache
"""
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
if val is not _Sentinel.sentinel:
val.callbacks.update(callbacks)
val.add_invalidation_callback(key, callback)
if update_metrics:
m = self.cache.metrics
assert m # we always have a name, so should always have metrics
m.inc_hits()
return val.deferred.observe()
return val.deferred(key)
callbacks = (callback,) if callback else ()
val2 = self.cache.get(
key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
@ -177,6 +181,73 @@ class DeferredCache(Generic[KT, VT]):
else:
return defer.succeed(val2)
def get_bulk(
self,
keys: Collection[KT],
callback: Optional[Callable[[], None]] = None,
) -> Tuple[Dict[KT, VT], Optional["defer.Deferred[Dict[KT, VT]]"], Collection[KT]]:
"""Bulk lookup of items in the cache.
Returns:
A 3-tuple of:
1. a dict of key/value of items already cached;
2. a deferred that resolves to a dict of key/value of items
we're already fetching; and
3. a collection of keys that don't appear in the previous two.
"""
# The cached results
cached = {}
# List of pending deferreds
pending = []
# Dict that gets filled out when the pending deferreds complete
pending_results = {}
# List of keys that aren't in either cache
missing = []
callbacks = (callback,) if callback else ()
for key in keys:
# Check if its in the main cache.
immediate_value = self.cache.get(
key,
_Sentinel.sentinel,
callbacks=callbacks,
)
if immediate_value is not _Sentinel.sentinel:
cached[key] = immediate_value
continue
# Check if its in the pending cache
pending_value = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
if pending_value is not _Sentinel.sentinel:
pending_value.add_invalidation_callback(key, callback)
def completed_cb(value: VT, key: KT) -> VT:
pending_results[key] = value
return value
# Add a callback to fill out `pending_results` when that completes
d = pending_value.deferred(key).addCallback(completed_cb, key)
pending.append(d)
continue
# Not in either cache
missing.append(key)
# If we've got pending deferreds, squash them into a single one that
# returns `pending_results`.
pending_deferred = None
if pending:
pending_deferred = defer.gatherResults(
pending, consumeErrors=True
).addCallback(lambda _: pending_results)
return (cached, pending_deferred, missing)
def get_immediate(
self, key: KT, default: T, update_metrics: bool = True
) -> Union[VT, T]:
@ -218,84 +289,89 @@ class DeferredCache(Generic[KT, VT]):
value: a deferred which will complete with a result to add to the cache
callback: An optional callback to be called when the entry is invalidated
"""
if not isinstance(value, defer.Deferred):
raise TypeError("not a Deferred")
callbacks = [callback] if callback else []
self.check_thread()
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
existing_entry.invalidate()
self._pending_deferred_cache.pop(key, None)
# XXX: why don't we invalidate the entry in `self.cache` yet?
# we can save a whole load of effort if the deferred is ready.
if value.called:
result = value.result
if not isinstance(result, failure.Failure):
self.cache.set(key, cast(VT, result), callbacks)
return value
# otherwise, we'll add an entry to the _pending_deferred_cache for now,
# and add callbacks to add it to the cache properly later.
observable = ObservableDeferred(value, consumeErrors=True)
observer = observable.observe()
entry = CacheEntry(deferred=observable, callbacks=callbacks)
entry = CacheEntrySingle[KT, VT](value)
entry.add_invalidation_callback(key, callback)
self._pending_deferred_cache[key] = entry
def compare_and_pop() -> bool:
"""Check if our entry is still the one in _pending_deferred_cache, and
if so, pop it.
Returns true if the entries matched.
"""
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
return True
# oops, the _pending_deferred_cache has been updated since
# we started our query, so we are out of date.
#
# Better put back whatever we took out. (We do it this way
# round, rather than peeking into the _pending_deferred_cache
# and then removing on a match, to make the common case faster)
if existing_entry is not None:
self._pending_deferred_cache[key] = existing_entry
return False
def cb(result: VT) -> None:
if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
# we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was
# updated, but it's possible that `set` was called without
# `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now.
entry.invalidate()
def eb(_fail: Failure) -> None:
compare_and_pop()
entry.invalidate()
# once the deferred completes, we can move the entry from the
# _pending_deferred_cache to the real cache.
#
observer.addCallbacks(cb, eb)
deferred = entry.deferred(key).addCallbacks(
self._completed_callback,
self._error_callback,
callbackArgs=(entry, key),
errbackArgs=(entry, key),
)
# we return a new Deferred which will be called before any subsequent observers.
return observable.observe()
return deferred
def start_bulk_input(
self,
keys: Collection[KT],
callback: Optional[Callable[[], None]] = None,
) -> "CacheMultipleEntries[KT, VT]":
"""Bulk set API for use when fetching multiple keys at once from the DB.
Called *before* starting the fetch from the DB, and the caller *must*
call either `complete_bulk(..)` or `error_bulk(..)` on the return value.
"""
entry = CacheMultipleEntries[KT, VT]()
entry.add_global_invalidation_callback(callback)
for key in keys:
self._pending_deferred_cache[key] = entry
return entry
def _completed_callback(
self, value: VT, entry: "CacheEntry[KT, VT]", key: KT
) -> VT:
"""Called when a deferred is completed."""
# We check if the current entry matches the entry associated with the
# deferred. If they don't match then it got invalidated.
current_entry = self._pending_deferred_cache.pop(key, None)
if current_entry is not entry:
if current_entry:
self._pending_deferred_cache[key] = current_entry
return value
self.cache.set(key, value, entry.get_invalidation_callbacks(key))
return value
def _error_callback(
self,
failure: Failure,
entry: "CacheEntry[KT, VT]",
key: KT,
) -> Failure:
"""Called when a deferred errors."""
# We check if the current entry matches the entry associated with the
# deferred. If they don't match then it got invalidated.
current_entry = self._pending_deferred_cache.pop(key, None)
if current_entry is not entry:
if current_entry:
self._pending_deferred_cache[key] = current_entry
return failure
for cb in entry.get_invalidation_callbacks(key):
cb()
return failure
def prefill(
self, key: KT, value: VT, callback: Optional[Callable[[], None]] = None
) -> None:
callbacks = [callback] if callback else []
callbacks = (callback,) if callback else ()
self.cache.set(key, value, callbacks=callbacks)
self._pending_deferred_cache.pop(key, None)
def invalidate(self, key: KT) -> None:
"""Delete a key, or tree of entries
@ -311,41 +387,129 @@ class DeferredCache(Generic[KT, VT]):
self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, which will (a) stop it being returned
# for future queries and (b) stop it being persisted as a proper entry
# _pending_deferred_cache, which will (a) stop it being returned for
# future queries and (b) stop it being persisted as a proper entry
# in self.cache.
entry = self._pending_deferred_cache.pop(key, None)
# run the invalidation callbacks now, rather than waiting for the
# deferred to resolve.
if entry:
# _pending_deferred_cache.pop should either return a CacheEntry, or, in the
# case of a TreeCache, a dict of keys to cache entries. Either way calling
# iterate_tree_cache_entry on it will do the right thing.
for entry in iterate_tree_cache_entry(entry):
entry.invalidate()
for cb in entry.get_invalidation_callbacks(key):
cb()
def invalidate_all(self) -> None:
self.check_thread()
self.cache.clear()
for entry in self._pending_deferred_cache.values():
entry.invalidate()
for key, entry in self._pending_deferred_cache.items():
for cb in entry.get_invalidation_callbacks(key):
cb()
self._pending_deferred_cache.clear()
class CacheEntry:
__slots__ = ["deferred", "callbacks", "invalidated"]
class CacheEntry(Generic[KT, VT], metaclass=abc.ABCMeta):
"""Abstract class for entries in `DeferredCache[KT, VT]`"""
def __init__(
self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]]
):
self.deferred = deferred
self.callbacks = set(callbacks)
self.invalidated = False
@abc.abstractmethod
def deferred(self, key: KT) -> "defer.Deferred[VT]":
"""Get a deferred that a caller can wait on to get the value at the
given key"""
...
def invalidate(self) -> None:
if not self.invalidated:
self.invalidated = True
for callback in self.callbacks:
callback()
self.callbacks.clear()
@abc.abstractmethod
def add_invalidation_callback(
self, key: KT, callback: Optional[Callable[[], None]]
) -> None:
"""Add an invalidation callback"""
...
@abc.abstractmethod
def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
"""Get all invalidation callbacks"""
...
class CacheEntrySingle(CacheEntry[KT, VT]):
"""An implementation of `CacheEntry` wrapping a deferred that results in a
single cache entry.
"""
__slots__ = ["_deferred", "_callbacks"]
def __init__(self, deferred: "defer.Deferred[VT]") -> None:
self._deferred = ObservableDeferred(deferred, consumeErrors=True)
self._callbacks: Set[Callable[[], None]] = set()
def deferred(self, key: KT) -> "defer.Deferred[VT]":
return self._deferred.observe()
def add_invalidation_callback(
self, key: KT, callback: Optional[Callable[[], None]]
) -> None:
if callback is None:
return
self._callbacks.add(callback)
def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
return self._callbacks
class CacheMultipleEntries(CacheEntry[KT, VT]):
"""Cache entry that is used for bulk lookups and insertions."""
__slots__ = ["_deferred", "_callbacks", "_global_callbacks"]
def __init__(self) -> None:
self._deferred: Optional[ObservableDeferred[Dict[KT, VT]]] = None
self._callbacks: Dict[KT, Set[Callable[[], None]]] = {}
self._global_callbacks: Set[Callable[[], None]] = set()
def deferred(self, key: KT) -> "defer.Deferred[VT]":
if not self._deferred:
self._deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
return self._deferred.observe().addCallback(lambda res: res.get(key))
def add_invalidation_callback(
self, key: KT, callback: Optional[Callable[[], None]]
) -> None:
if callback is None:
return
self._callbacks.setdefault(key, set()).add(callback)
def get_invalidation_callbacks(self, key: KT) -> Collection[Callable[[], None]]:
return self._callbacks.get(key, set()) | self._global_callbacks
def add_global_invalidation_callback(
self, callback: Optional[Callable[[], None]]
) -> None:
"""Add a callback for when any keys get invalidated."""
if callback is None:
return
self._global_callbacks.add(callback)
def complete_bulk(
self,
cache: DeferredCache[KT, VT],
result: Dict[KT, VT],
) -> None:
"""Called when there is a result"""
for key, value in result.items():
cache._completed_callback(value, self, key)
if self._deferred:
self._deferred.callback(result)
def error_bulk(
self, cache: DeferredCache[KT, VT], keys: Collection[KT], failure: Failure
) -> None:
"""Called when bulk lookup failed."""
for key in keys:
cache._error_callback(failure, self, key)
if self._deferred:
self._deferred.errback(failure)