Take named arguments to @cached() decorator, add a 'max_entries' limit

This commit is contained in:
Paul "LeoNerd" Evans 2015-02-19 18:36:02 +00:00
parent 077d200342
commit ebc3db295b
3 changed files with 114 additions and 16 deletions

View File

@ -39,8 +39,8 @@ transaction_logger = logging.getLogger("synapse.storage.txn")
# * Move this somewhere higher-level, shared; # * Move this somewhere higher-level, shared;
# * more generic key management # * more generic key management
# * export monitoring stats # * export monitoring stats
# * maximum size; just evict things at random, or consider LRU? # * consider other eviction strategies - LRU?
def cached(orig): def cached(max_entries=1000):
""" A method decorator that applies a memoizing cache around the function. """ A method decorator that applies a memoizing cache around the function.
The function is presumed to take one additional argument, which is used as The function is presumed to take one additional argument, which is used as
@ -50,24 +50,33 @@ def cached(orig):
The wrapped function has an additional member, a callable called The wrapped function has an additional member, a callable called
"invalidate". This can be used to remove individual entries from the cache. "invalidate". This can be used to remove individual entries from the cache.
""" """
cache = {} def wrap(orig):
cache = {}
@defer.inlineCallbacks @defer.inlineCallbacks
def wrapped(self, key): def wrapped(self, key):
if key in cache: if key in cache:
defer.returnValue(cache[key]) defer.returnValue(cache[key])
ret = yield orig(self, key) ret = yield orig(self, key)
cache[key] = ret; while len(cache) > max_entries:
defer.returnValue(ret) # TODO(paul): This feels too biased. However, a random index
# would be a bit inefficient, walking the list of keys just
# to ignore most of them?
del cache[cache.keys()[0]]
def invalidate(key): cache[key] = ret;
if key in cache: defer.returnValue(ret)
del cache[key]
wrapped.invalidate = invalidate def invalidate(key):
return wrapped if key in cache:
del cache[key]
wrapped.invalidate = invalidate
return wrapped
return wrap
class LoggingTransaction(object): class LoggingTransaction(object):

View File

@ -247,7 +247,7 @@ class RoomMemberStore(SQLBaseStore):
results = self._parse_events_txn(txn, rows) results = self._parse_events_txn(txn, rows)
return results return results
@cached @cached()
def get_rooms_for_user(self, user_id): def get_rooms_for_user(self, user_id):
return self.get_rooms_for_user_where_membership_is( return self.get_rooms_for_user_where_membership_is(
user_id, membership_list=[Membership.JOIN], user_id, membership_list=[Membership.JOIN],

View File

@ -0,0 +1,89 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from twisted.internet import defer
from synapse.storage._base import cached
class CacheDecoratorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_passthrough(self):
@cached()
def func(self, key):
return key
self.assertEquals((yield func(self, "foo")), "foo")
self.assertEquals((yield func(self, "bar")), "bar")
@defer.inlineCallbacks
def test_hit(self):
callcount = [0]
@cached()
def func(self, key):
callcount[0] += 1
return key
yield func(self, "foo")
self.assertEquals(callcount[0], 1)
self.assertEquals((yield func(self, "foo")), "foo")
self.assertEquals(callcount[0], 1)
@defer.inlineCallbacks
def test_invalidate(self):
callcount = [0]
@cached()
def func(self, key):
callcount[0] += 1
return key
yield func(self, "foo")
self.assertEquals(callcount[0], 1)
func.invalidate("foo")
yield func(self, "foo")
self.assertEquals(callcount[0], 2)
@defer.inlineCallbacks
def test_max_entries(self):
callcount = [0]
@cached(max_entries=10)
def func(self, key):
callcount[0] += 1
return key
for k in range(0,12):
yield func(self, k)
self.assertEquals(callcount[0], 12)
# There must have been at least 2 evictions, meaning if we calculate
# all 12 values again, we must get called at least 2 more times
for k in range(0,12):
yield func(self, k)
self.assertTrue(callcount[0] >= 14,
msg="Expected callcount >= 14, got %d" % (callcount[0]))