mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Fix overzealous cache invalidation
Fixes an issue where a cache invalidation would invalidate *all* pending entries, rather than just the entry that we intended to invalidate.
This commit is contained in:
parent
7d0f712348
commit
01afc563c3
@ -1,5 +1,6 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2015, 2016 OpenMarket Ltd
|
# Copyright 2015, 2016 OpenMarket Ltd
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@ -39,12 +40,11 @@ _CacheSentinel = object()
|
|||||||
|
|
||||||
class CacheEntry(object):
|
class CacheEntry(object):
|
||||||
__slots__ = [
|
__slots__ = [
|
||||||
"deferred", "sequence", "callbacks", "invalidated"
|
"deferred", "callbacks", "invalidated"
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(self, deferred, sequence, callbacks):
|
def __init__(self, deferred, callbacks):
|
||||||
self.deferred = deferred
|
self.deferred = deferred
|
||||||
self.sequence = sequence
|
|
||||||
self.callbacks = set(callbacks)
|
self.callbacks = set(callbacks)
|
||||||
self.invalidated = False
|
self.invalidated = False
|
||||||
|
|
||||||
@ -62,7 +62,6 @@ class Cache(object):
|
|||||||
"max_entries",
|
"max_entries",
|
||||||
"name",
|
"name",
|
||||||
"keylen",
|
"keylen",
|
||||||
"sequence",
|
|
||||||
"thread",
|
"thread",
|
||||||
"metrics",
|
"metrics",
|
||||||
"_pending_deferred_cache",
|
"_pending_deferred_cache",
|
||||||
@ -80,7 +79,6 @@ class Cache(object):
|
|||||||
|
|
||||||
self.name = name
|
self.name = name
|
||||||
self.keylen = keylen
|
self.keylen = keylen
|
||||||
self.sequence = 0
|
|
||||||
self.thread = None
|
self.thread = None
|
||||||
self.metrics = register_cache(name, self.cache)
|
self.metrics = register_cache(name, self.cache)
|
||||||
|
|
||||||
@ -113,11 +111,10 @@ class Cache(object):
|
|||||||
callbacks = [callback] if callback else []
|
callbacks = [callback] if callback else []
|
||||||
val = self._pending_deferred_cache.get(key, _CacheSentinel)
|
val = self._pending_deferred_cache.get(key, _CacheSentinel)
|
||||||
if val is not _CacheSentinel:
|
if val is not _CacheSentinel:
|
||||||
if val.sequence == self.sequence:
|
val.callbacks.update(callbacks)
|
||||||
val.callbacks.update(callbacks)
|
if update_metrics:
|
||||||
if update_metrics:
|
self.metrics.inc_hits()
|
||||||
self.metrics.inc_hits()
|
return val.deferred
|
||||||
return val.deferred
|
|
||||||
|
|
||||||
val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
|
val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
|
||||||
if val is not _CacheSentinel:
|
if val is not _CacheSentinel:
|
||||||
@ -137,12 +134,9 @@ class Cache(object):
|
|||||||
self.check_thread()
|
self.check_thread()
|
||||||
entry = CacheEntry(
|
entry = CacheEntry(
|
||||||
deferred=value,
|
deferred=value,
|
||||||
sequence=self.sequence,
|
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
|
|
||||||
entry.callbacks.update(callbacks)
|
|
||||||
|
|
||||||
existing_entry = self._pending_deferred_cache.pop(key, None)
|
existing_entry = self._pending_deferred_cache.pop(key, None)
|
||||||
if existing_entry:
|
if existing_entry:
|
||||||
existing_entry.invalidate()
|
existing_entry.invalidate()
|
||||||
@ -150,13 +144,25 @@ class Cache(object):
|
|||||||
self._pending_deferred_cache[key] = entry
|
self._pending_deferred_cache[key] = entry
|
||||||
|
|
||||||
def shuffle(result):
|
def shuffle(result):
|
||||||
if self.sequence == entry.sequence:
|
existing_entry = self._pending_deferred_cache.pop(key, None)
|
||||||
existing_entry = self._pending_deferred_cache.pop(key, None)
|
if existing_entry is entry:
|
||||||
if existing_entry is entry:
|
self.cache.set(key, result, entry.callbacks)
|
||||||
self.cache.set(key, result, entry.callbacks)
|
|
||||||
else:
|
|
||||||
entry.invalidate()
|
|
||||||
else:
|
else:
|
||||||
|
# oops, the _pending_deferred_cache has been updated since
|
||||||
|
# we started our query, so we are out of date.
|
||||||
|
#
|
||||||
|
# Better put back whatever we took out. (We do it this way
|
||||||
|
# round, rather than peeking into the _pending_deferred_cache
|
||||||
|
# and then removing on a match, to make the common case faster)
|
||||||
|
if existing_entry is not None:
|
||||||
|
self._pending_deferred_cache[key] = existing_entry
|
||||||
|
|
||||||
|
# we're not going to put this entry into the cache, so need
|
||||||
|
# to make sure that the invalidation callbacks are called.
|
||||||
|
# That was probably done when _pending_deferred_cache was
|
||||||
|
# updated, but it's possible that `set` was called without
|
||||||
|
# `invalidate` being previously called, in which case it may
|
||||||
|
# not have been. Either way, let's double-check now.
|
||||||
entry.invalidate()
|
entry.invalidate()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -168,25 +174,29 @@ class Cache(object):
|
|||||||
|
|
||||||
def invalidate(self, key):
|
def invalidate(self, key):
|
||||||
self.check_thread()
|
self.check_thread()
|
||||||
|
self.cache.pop(key, None)
|
||||||
|
|
||||||
# Increment the sequence number so that any SELECT statements that
|
# if we have a pending lookup for this key, remove it from the
|
||||||
# raced with the INSERT don't update the cache (SYN-369)
|
# _pending_deferred_cache, which will (a) stop it being returned
|
||||||
self.sequence += 1
|
# for future queries and (b) stop it being persisted as a proper entry
|
||||||
|
# in self.cache.
|
||||||
entry = self._pending_deferred_cache.pop(key, None)
|
entry = self._pending_deferred_cache.pop(key, None)
|
||||||
|
|
||||||
|
# run the invalidation callbacks now, rather than waiting for the
|
||||||
|
# deferred to resolve.
|
||||||
if entry:
|
if entry:
|
||||||
entry.invalidate()
|
entry.invalidate()
|
||||||
|
|
||||||
self.cache.pop(key, None)
|
|
||||||
|
|
||||||
def invalidate_many(self, key):
|
def invalidate_many(self, key):
|
||||||
self.check_thread()
|
self.check_thread()
|
||||||
if not isinstance(key, tuple):
|
if not isinstance(key, tuple):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"The cache key must be a tuple not %r" % (type(key),)
|
"The cache key must be a tuple not %r" % (type(key),)
|
||||||
)
|
)
|
||||||
self.sequence += 1
|
|
||||||
self.cache.del_multi(key)
|
self.cache.del_multi(key)
|
||||||
|
|
||||||
|
# if we have a pending lookup for this key, remove it from the
|
||||||
|
# _pending_deferred_cache, as above
|
||||||
entry_dict = self._pending_deferred_cache.pop(key, None)
|
entry_dict = self._pending_deferred_cache.pop(key, None)
|
||||||
if entry_dict is not None:
|
if entry_dict is not None:
|
||||||
for entry in iterate_tree_cache_entry(entry_dict):
|
for entry in iterate_tree_cache_entry(entry_dict):
|
||||||
@ -194,8 +204,10 @@ class Cache(object):
|
|||||||
|
|
||||||
def invalidate_all(self):
|
def invalidate_all(self):
|
||||||
self.check_thread()
|
self.check_thread()
|
||||||
self.sequence += 1
|
|
||||||
self.cache.clear()
|
self.cache.clear()
|
||||||
|
for entry in self._pending_deferred_cache.itervalues():
|
||||||
|
entry.invalidate()
|
||||||
|
self._pending_deferred_cache.clear()
|
||||||
|
|
||||||
|
|
||||||
class _CacheDescriptorBase(object):
|
class _CacheDescriptorBase(object):
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2016 OpenMarket Ltd
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
# Copyright 2018 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@ -12,6 +13,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from functools import partial
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import mock
|
import mock
|
||||||
@ -25,6 +27,50 @@ from tests import unittest
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CacheTestCase(unittest.TestCase):
|
||||||
|
def test_invalidate_all(self):
|
||||||
|
cache = descriptors.Cache("testcache")
|
||||||
|
|
||||||
|
callback_record = [False, False]
|
||||||
|
|
||||||
|
def record_callback(idx):
|
||||||
|
callback_record[idx] = True
|
||||||
|
|
||||||
|
# add a couple of pending entries
|
||||||
|
d1 = defer.Deferred()
|
||||||
|
cache.set("key1", d1, partial(record_callback, 0))
|
||||||
|
|
||||||
|
d2 = defer.Deferred()
|
||||||
|
cache.set("key2", d2, partial(record_callback, 1))
|
||||||
|
|
||||||
|
# lookup should return the deferreds
|
||||||
|
self.assertIs(cache.get("key1"), d1)
|
||||||
|
self.assertIs(cache.get("key2"), d2)
|
||||||
|
|
||||||
|
# let one of the lookups complete
|
||||||
|
d2.callback("result2")
|
||||||
|
self.assertEqual(cache.get("key2"), "result2")
|
||||||
|
|
||||||
|
# now do the invalidation
|
||||||
|
cache.invalidate_all()
|
||||||
|
|
||||||
|
# lookup should return none
|
||||||
|
self.assertIsNone(cache.get("key1", None))
|
||||||
|
self.assertIsNone(cache.get("key2", None))
|
||||||
|
|
||||||
|
# both callbacks should have been callbacked
|
||||||
|
self.assertTrue(
|
||||||
|
callback_record[0], "Invalidation callback for key1 not called",
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
callback_record[1], "Invalidation callback for key2 not called",
|
||||||
|
)
|
||||||
|
|
||||||
|
# letting the other lookup complete should do nothing
|
||||||
|
d1.callback("result1")
|
||||||
|
self.assertIsNone(cache.get("key1", None))
|
||||||
|
|
||||||
|
|
||||||
class DescriptorTestCase(unittest.TestCase):
|
class DescriptorTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_cache(self):
|
def test_cache(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user