Move all the caches into their own package, synapse.util.caches

This commit is contained in:
Erik Johnston 2015-08-11 17:59:32 +01:00
parent 53a817518b
commit 2df8dd9b37
23 changed files with 408 additions and 352 deletions

View file

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,359 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError
from synapse.util.caches.lrucache import LruCache
import synapse.metrics
from twisted.internet import defer
from collections import OrderedDict
import functools
import inspect
import threading
logger = logging.getLogger(__name__)
DEBUG_CACHES = False
metrics = synapse.metrics.get_metrics_for("synapse.util.caches")
caches_by_name = {}
cache_counter = metrics.register_cache(
"cache",
lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
labels=["name"],
)
_CacheSentinel = object()
class Cache(object):
def __init__(self, name, max_entries=1000, keylen=1, lru=True):
if lru:
self.cache = LruCache(max_size=max_entries)
self.max_entries = None
else:
self.cache = OrderedDict()
self.max_entries = max_entries
self.name = name
self.keylen = keylen
self.sequence = 0
self.thread = None
caches_by_name[name] = self.cache
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
self.thread = threading.current_thread()
else:
if expected_thread is not threading.current_thread():
raise ValueError(
"Cache objects can only be accessed from the main thread"
)
def get(self, key, default=_CacheSentinel):
val = self.cache.get(key, _CacheSentinel)
if val is not _CacheSentinel:
cache_counter.inc_hits(self.name)
return val
cache_counter.inc_misses(self.name)
if default is _CacheSentinel:
raise KeyError()
else:
return default
def update(self, sequence, key, value):
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
self.prefill(key, value)
def prefill(self, key, value):
if self.max_entries is not None:
while len(self.cache) >= self.max_entries:
self.cache.popitem(last=False)
self.cache[key] = value
def invalidate(self, key):
self.check_thread()
if not isinstance(key, tuple):
raise ValueError("keyargs must be a tuple.")
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
self.cache.pop(key, None)
def invalidate_all(self):
self.check_thread()
self.sequence += 1
self.cache.clear()
class CacheDescriptor(object):
""" A method decorator that applies a memoizing cache around the function.
This caches deferreds, rather than the results themselves. Deferreds that
fail are removed from the cache.
The function is presumed to take zero or more arguments, which are used in
a tuple as the key for the cache. Hits are served directly from the cache;
misses use the function body to generate the value.
The wrapped function has an additional member, a callable called
"invalidate". This can be used to remove individual entries from the cache.
The wrapped function has another additional callable, called "prefill",
which can be used to insert values into the cache specifically, without
calling the calculation function.
"""
def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
inlineCallbacks=False):
self.orig = orig
if inlineCallbacks:
self.function_to_call = defer.inlineCallbacks(orig)
else:
self.function_to_call = orig
self.max_entries = max_entries
self.num_args = num_args
self.lru = lru
self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
" (@cached cannot key off of *args or **kwars)"
% (orig.__name__,)
)
self.cache = Cache(
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
lru=self.lru,
)
def __get__(self, obj, objtype=None):
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
try:
cached_result_d = self.cache.get(cache_key)
observer = cached_result_d.observe()
if DEBUG_CACHES:
@defer.inlineCallbacks
def check_result(cached_result):
actual_result = yield self.function_to_call(obj, *args, **kwargs)
if actual_result != cached_result:
logger.error(
"Stale cache entry %s%r: cached: %r, actual %r",
self.orig.__name__, cache_key,
cached_result, actual_result,
)
raise ValueError("Stale cache entry")
defer.returnValue(cached_result)
observer.addCallback(check_result)
return observer
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
sequence = self.cache.sequence
ret = defer.maybeDeferred(
self.function_to_call,
obj, *args, **kwargs
)
def onErr(f):
self.cache.invalidate(cache_key)
return f
ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True)
self.cache.update(sequence, cache_key, ret)
return ret.observe()
wrapped.invalidate = self.cache.invalidate
wrapped.invalidate_all = self.cache.invalidate_all
wrapped.prefill = self.cache.prefill
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
class CacheListDescriptor(object):
"""Wraps an existing cache to support bulk fetching of keys.
Given a list of keys it looks in the cache to find any hits, then passes
the list of missing keys to the wrapped fucntion.
"""
def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False):
"""
Args:
orig (function)
cache (Cache)
list_name (str): Name of the argument which is the bulk lookup list
num_args (int)
inlineCallbacks (bool): Whether orig is a generator that should
be wrapped by defer.inlineCallbacks
"""
self.orig = orig
if inlineCallbacks:
self.function_to_call = defer.inlineCallbacks(orig)
else:
self.function_to_call = orig
self.num_args = num_args
self.list_name = list_name
self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
self.list_pos = self.arg_names.index(self.list_name)
self.cache = cache
self.sentinel = object()
if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
" (@cached cannot key off of *args or **kwars)"
% (orig.__name__,)
)
if self.list_name not in self.arg_names:
raise Exception(
"Couldn't see arguments %r for %r."
% (self.list_name, cache.name,)
)
def __get__(self, obj, objtype=None):
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name]
# cached is a dict arg -> deferred, where deferred results in a
# 2-tuple (`arg`, `result`)
cached = {}
missing = []
for arg in list_args:
key = list(keyargs)
key[self.list_pos] = arg
try:
res = self.cache.get(tuple(key)).observe()
res.addCallback(lambda r, arg: (arg, r), arg)
cached[arg] = res
except KeyError:
missing.append(arg)
if missing:
sequence = self.cache.sequence
args_to_call = dict(arg_dict)
args_to_call[self.list_name] = missing
ret_d = defer.maybeDeferred(
self.function_to_call,
**args_to_call
)
ret_d = ObservableDeferred(ret_d)
# We need to create deferreds for each arg in the list so that
# we can insert the new deferred into the cache.
for arg in missing:
observer = ret_d.observe()
observer.addCallback(lambda r, arg: r[arg], arg)
observer = ObservableDeferred(observer)
key = list(keyargs)
key[self.list_pos] = arg
self.cache.update(sequence, tuple(key), observer)
def invalidate(f, key):
self.cache.invalidate(key)
return f
observer.addErrback(invalidate, tuple(key))
res = observer.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
cached[arg] = res
return defer.gatherResults(
cached.values(),
consumeErrors=True,
).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
def cached(max_entries=1000, num_args=1, lru=True):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru
)
def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru,
inlineCallbacks=True,
)
def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
return lambda orig: CacheListDescriptor(
orig,
cache=cache,
list_name=list_name,
num_args=num_args,
inlineCallbacks=inlineCallbacks,
)

View file

@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.caches.lrucache import LruCache
from collections import namedtuple
import threading
import logging
logger = logging.getLogger(__name__)
DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value"))
class DictionaryCache(object):
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
fetching a subset of dictionary keys for a particular key.
"""
def __init__(self, name, max_entries=1000):
self.cache = LruCache(max_size=max_entries)
self.name = name
self.sequence = 0
self.thread = None
# caches_by_name[name] = self.cache
class Sentinel(object):
__slots__ = []
self.sentinel = Sentinel()
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
self.thread = threading.current_thread()
else:
if expected_thread is not threading.current_thread():
raise ValueError(
"Cache objects can only be accessed from the main thread"
)
def get(self, key, dict_keys=None):
try:
entry = self.cache.get(key, self.sentinel)
if entry is not self.sentinel:
# cache_counter.inc_hits(self.name)
if dict_keys is None:
return DictionaryEntry(entry.full, dict(entry.value))
else:
return DictionaryEntry(entry.full, {
k: entry.value[k]
for k in dict_keys
if k in entry.value
})
# cache_counter.inc_misses(self.name)
return DictionaryEntry(False, {})
except:
logger.exception("get failed")
raise
def invalidate(self, key):
self.check_thread()
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
self.cache.pop(key, None)
def invalidate_all(self):
self.check_thread()
self.sequence += 1
self.cache.clear()
def update(self, sequence, key, value, full=False):
try:
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
if full:
self._insert(key, value)
else:
self._update_or_insert(key, value)
except:
logger.exception("update failed")
raise
def _update_or_insert(self, key, value):
entry = self.cache.setdefault(key, DictionaryEntry(False, {}))
entry.value.update(value)
def _insert(self, key, value):
self.cache[key] = DictionaryEntry(True, value)

View file

@ -0,0 +1,115 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
logger = logging.getLogger(__name__)
class ExpiringCache(object):
def __init__(self, cache_name, clock, max_len=0, expiry_ms=0,
reset_expiry_on_get=False):
"""
Args:
cache_name (str): Name of this cache, used for logging.
clock (Clock)
max_len (int): Max size of dict. If the dict grows larger than this
then the oldest items get automatically evicted. Default is 0,
which indicates there is no max limit.
expiry_ms (int): How long before an item is evicted from the cache
in milliseconds. Default is 0, indicating items never get
evicted based on time.
reset_expiry_on_get (bool): If true, will reset the expiry time for
an item on access. Defaults to False.
"""
self._cache_name = cache_name
self._clock = clock
self._max_len = max_len
self._expiry_ms = expiry_ms
self._reset_expiry_on_get = reset_expiry_on_get
self._cache = {}
def start(self):
if not self._expiry_ms:
# Don't bother starting the loop if things never expire
return
def f():
self._prune_cache()
self._clock.looping_call(f, self._expiry_ms/2)
def __setitem__(self, key, value):
now = self._clock.time_msec()
self._cache[key] = _CacheEntry(now, value)
# Evict if there are now too many items
if self._max_len and len(self._cache.keys()) > self._max_len:
sorted_entries = sorted(
self._cache.items(),
key=lambda (k, v): v.time,
)
for k, _ in sorted_entries[self._max_len:]:
self._cache.pop(k)
def __getitem__(self, key):
entry = self._cache[key]
if self._reset_expiry_on_get:
entry.time = self._clock.time_msec()
return entry.value
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
def _prune_cache(self):
if not self._expiry_ms:
# zero expiry time means don't expire. This should never get called
# since we have this check in start too.
return
begin_length = len(self._cache)
now = self._clock.time_msec()
keys_to_delete = set()
for key, cache_entry in self._cache.items():
if now - cache_entry.time > self._expiry_ms:
keys_to_delete.add(key)
for k in keys_to_delete:
self._cache.pop(k)
logger.debug(
"[%s] _prune_cache before: %d, after len: %d",
self._cache_name, begin_length, len(self._cache.keys())
)
class _CacheEntry(object):
def __init__(self, time, value):
self.time = time
self.value = value

View file

@ -0,0 +1,149 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import wraps
import threading
class LruCache(object):
"""Least-recently-used cache."""
def __init__(self, max_size):
cache = {}
list_root = []
list_root[:] = [list_root, list_root, None, None]
PREV, NEXT, KEY, VALUE = 0, 1, 2, 3
lock = threading.Lock()
def synchronized(f):
@wraps(f)
def inner(*args, **kwargs):
with lock:
return f(*args, **kwargs)
return inner
def add_node(key, value):
prev_node = list_root
next_node = prev_node[NEXT]
node = [prev_node, next_node, key, value]
prev_node[NEXT] = node
next_node[PREV] = node
cache[key] = node
def move_node_to_front(node):
prev_node = node[PREV]
next_node = node[NEXT]
prev_node[NEXT] = next_node
next_node[PREV] = prev_node
prev_node = list_root
next_node = prev_node[NEXT]
node[PREV] = prev_node
node[NEXT] = next_node
prev_node[NEXT] = node
next_node[PREV] = node
def delete_node(node):
prev_node = node[PREV]
next_node = node[NEXT]
prev_node[NEXT] = next_node
next_node[PREV] = prev_node
cache.pop(node[KEY], None)
@synchronized
def cache_get(key, default=None):
node = cache.get(key, None)
if node is not None:
move_node_to_front(node)
return node[VALUE]
else:
return default
@synchronized
def cache_set(key, value):
node = cache.get(key, None)
if node is not None:
move_node_to_front(node)
node[VALUE] = value
else:
add_node(key, value)
if len(cache) > max_size:
delete_node(list_root[PREV])
@synchronized
def cache_set_default(key, value):
node = cache.get(key, None)
if node is not None:
return node[VALUE]
else:
add_node(key, value)
if len(cache) > max_size:
delete_node(list_root[PREV])
return value
@synchronized
def cache_pop(key, default=None):
node = cache.get(key, None)
if node:
delete_node(node)
return node[VALUE]
else:
return default
@synchronized
def cache_clear():
list_root[NEXT] = list_root
list_root[PREV] = list_root
cache.clear()
@synchronized
def cache_len():
return len(cache)
@synchronized
def cache_contains(key):
return key in cache
self.sentinel = object()
self.get = cache_get
self.set = cache_set
self.setdefault = cache_set_default
self.pop = cache_pop
self.len = cache_len
self.contains = cache_contains
self.clear = cache_clear
def __getitem__(self, key):
result = self.get(key, self.sentinel)
if result is self.sentinel:
raise KeyError()
else:
return result
def __setitem__(self, key, value):
self.set(key, value)
def __delitem__(self, key, value):
result = self.pop(key, self.sentinel)
if result is self.sentinel:
raise KeyError()
def __len__(self):
return self.len()
def __contains__(self, key):
return self.contains(key)