Fix overzealous cache invalidation

Fixes an issue where a cache invalidation would invalidate *all* pending
entries, rather than just the entry that we intended to invalidate.
This commit is contained in:
Richard van der Hoff 2018-04-05 16:24:04 +01:00
parent 7d0f712348
commit 01afc563c3
2 changed files with 84 additions and 26 deletions

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -39,12 +40,11 @@ _CacheSentinel = object()
class CacheEntry(object): class CacheEntry(object):
__slots__ = [ __slots__ = [
"deferred", "sequence", "callbacks", "invalidated" "deferred", "callbacks", "invalidated"
] ]
def __init__(self, deferred, sequence, callbacks): def __init__(self, deferred, callbacks):
self.deferred = deferred self.deferred = deferred
self.sequence = sequence
self.callbacks = set(callbacks) self.callbacks = set(callbacks)
self.invalidated = False self.invalidated = False
@ -62,7 +62,6 @@ class Cache(object):
"max_entries", "max_entries",
"name", "name",
"keylen", "keylen",
"sequence",
"thread", "thread",
"metrics", "metrics",
"_pending_deferred_cache", "_pending_deferred_cache",
@ -80,7 +79,6 @@ class Cache(object):
self.name = name self.name = name
self.keylen = keylen self.keylen = keylen
self.sequence = 0
self.thread = None self.thread = None
self.metrics = register_cache(name, self.cache) self.metrics = register_cache(name, self.cache)
@ -113,11 +111,10 @@ class Cache(object):
callbacks = [callback] if callback else [] callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel) val = self._pending_deferred_cache.get(key, _CacheSentinel)
if val is not _CacheSentinel: if val is not _CacheSentinel:
if val.sequence == self.sequence: val.callbacks.update(callbacks)
val.callbacks.update(callbacks) if update_metrics:
if update_metrics: self.metrics.inc_hits()
self.metrics.inc_hits() return val.deferred
return val.deferred
val = self.cache.get(key, _CacheSentinel, callbacks=callbacks) val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
if val is not _CacheSentinel: if val is not _CacheSentinel:
@ -137,12 +134,9 @@ class Cache(object):
self.check_thread() self.check_thread()
entry = CacheEntry( entry = CacheEntry(
deferred=value, deferred=value,
sequence=self.sequence,
callbacks=callbacks, callbacks=callbacks,
) )
entry.callbacks.update(callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None) existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry: if existing_entry:
existing_entry.invalidate() existing_entry.invalidate()
@ -150,13 +144,25 @@ class Cache(object):
self._pending_deferred_cache[key] = entry self._pending_deferred_cache[key] = entry
def shuffle(result): def shuffle(result):
if self.sequence == entry.sequence: existing_entry = self._pending_deferred_cache.pop(key, None)
existing_entry = self._pending_deferred_cache.pop(key, None) if existing_entry is entry:
if existing_entry is entry: self.cache.set(key, result, entry.callbacks)
self.cache.set(key, result, entry.callbacks)
else:
entry.invalidate()
else: else:
# oops, the _pending_deferred_cache has been updated since
# we started our query, so we are out of date.
#
# Better put back whatever we took out. (We do it this way
# round, rather than peeking into the _pending_deferred_cache
# and then removing on a match, to make the common case faster)
if existing_entry is not None:
self._pending_deferred_cache[key] = existing_entry
# we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was
# updated, but it's possible that `set` was called without
# `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now.
entry.invalidate() entry.invalidate()
return result return result
@ -168,25 +174,29 @@ class Cache(object):
def invalidate(self, key): def invalidate(self, key):
self.check_thread() self.check_thread()
self.cache.pop(key, None)
# Increment the sequence number so that any SELECT statements that # if we have a pending lookup for this key, remove it from the
# raced with the INSERT don't update the cache (SYN-369) # _pending_deferred_cache, which will (a) stop it being returned
self.sequence += 1 # for future queries and (b) stop it being persisted as a proper entry
# in self.cache.
entry = self._pending_deferred_cache.pop(key, None) entry = self._pending_deferred_cache.pop(key, None)
# run the invalidation callbacks now, rather than waiting for the
# deferred to resolve.
if entry: if entry:
entry.invalidate() entry.invalidate()
self.cache.pop(key, None)
def invalidate_many(self, key): def invalidate_many(self, key):
self.check_thread() self.check_thread()
if not isinstance(key, tuple): if not isinstance(key, tuple):
raise TypeError( raise TypeError(
"The cache key must be a tuple not %r" % (type(key),) "The cache key must be a tuple not %r" % (type(key),)
) )
self.sequence += 1
self.cache.del_multi(key) self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, as above
entry_dict = self._pending_deferred_cache.pop(key, None) entry_dict = self._pending_deferred_cache.pop(key, None)
if entry_dict is not None: if entry_dict is not None:
for entry in iterate_tree_cache_entry(entry_dict): for entry in iterate_tree_cache_entry(entry_dict):
@ -194,8 +204,10 @@ class Cache(object):
def invalidate_all(self): def invalidate_all(self):
self.check_thread() self.check_thread()
self.sequence += 1
self.cache.clear() self.cache.clear()
for entry in self._pending_deferred_cache.itervalues():
entry.invalidate()
self._pending_deferred_cache.clear()
class _CacheDescriptorBase(object): class _CacheDescriptorBase(object):

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd # Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from functools import partial
import logging import logging
import mock import mock
@ -25,6 +27,50 @@ from tests import unittest
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CacheTestCase(unittest.TestCase):
def test_invalidate_all(self):
cache = descriptors.Cache("testcache")
callback_record = [False, False]
def record_callback(idx):
callback_record[idx] = True
# add a couple of pending entries
d1 = defer.Deferred()
cache.set("key1", d1, partial(record_callback, 0))
d2 = defer.Deferred()
cache.set("key2", d2, partial(record_callback, 1))
# lookup should return the deferreds
self.assertIs(cache.get("key1"), d1)
self.assertIs(cache.get("key2"), d2)
# let one of the lookups complete
d2.callback("result2")
self.assertEqual(cache.get("key2"), "result2")
# now do the invalidation
cache.invalidate_all()
# lookup should return none
self.assertIsNone(cache.get("key1", None))
self.assertIsNone(cache.get("key2", None))
# both callbacks should have been callbacked
self.assertTrue(
callback_record[0], "Invalidation callback for key1 not called",
)
self.assertTrue(
callback_record[1], "Invalidation callback for key2 not called",
)
# letting the other lookup complete should do nothing
d1.callback("result1")
self.assertIsNone(cache.get("key1", None))
class DescriptorTestCase(unittest.TestCase): class DescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_cache(self): def test_cache(self):