Sequence the modifications to the cache so that selects don't race with inserts

This commit is contained in:
Mark Haines 2015-05-05 14:08:03 +01:00
parent d9cc5de9e5
commit 261d809a47

View File

@ -31,6 +31,7 @@ import functools
import simplejson as json import simplejson as json
import sys import sys
import time import time
import threading
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -68,9 +69,20 @@ class Cache(object):
self.name = name self.name = name
self.keylen = keylen self.keylen = keylen
self.sequence = 0
self.thread = None
caches_by_name[name] = self.cache caches_by_name[name] = self.cache
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
self.thread = threading.current_thread()
else:
if expected_thread is not threading.current_thread():
raise ValueError(
"Cache objects can only be accessed from the main thread"
)
def get(self, *keyargs): def get(self, *keyargs):
if len(keyargs) != self.keylen: if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen) raise ValueError("Expected a key to have %d items", self.keylen)
@ -82,6 +94,11 @@ class Cache(object):
cache_counter.inc_misses(self.name) cache_counter.inc_misses(self.name)
raise KeyError() raise KeyError()
def update(self, sequence, *args):
self.check_thread()
if self.sequence == sequence:
self.prefill(*args)
def prefill(self, *args): # because I can't *keyargs, value def prefill(self, *args): # because I can't *keyargs, value
keyargs = args[:-1] keyargs = args[:-1]
value = args[-1] value = args[-1]
@ -96,9 +113,10 @@ class Cache(object):
self.cache[keyargs] = value self.cache[keyargs] = value
def invalidate(self, *keyargs): def invalidate(self, *keyargs):
self.check_thread()
if len(keyargs) != self.keylen: if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen) raise ValueError("Expected a key to have %d items", self.keylen)
self.sequence += 1
self.cache.pop(keyargs, None) self.cache.pop(keyargs, None)
@ -130,9 +148,11 @@ def cached(max_entries=1000, num_args=1, lru=False):
try: try:
defer.returnValue(cache.get(*keyargs)) defer.returnValue(cache.get(*keyargs))
except KeyError: except KeyError:
sequence = cache.sequence
ret = yield orig(self, *keyargs) ret = yield orig(self, *keyargs)
cache.prefill(*keyargs + (ret,)) cache.update(sequence, *keyargs + (ret,))
defer.returnValue(ret) defer.returnValue(ret)