mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2024-10-01 11:49:51 -04:00
Merge pull request #2075 from matrix-org/erikj/cache_speed
Speed up cached function access
This commit is contained in:
commit
9cee0ce7db
@ -17,15 +17,12 @@ from twisted.internet import defer
|
||||
from synapse.push.presentable_names import (
|
||||
calculate_room_name, name_from_member_event
|
||||
)
|
||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_badge_count(store, user_id):
|
||||
invites, joins = yield preserve_context_over_deferred(defer.gatherResults([
|
||||
preserve_fn(store.get_invited_rooms_for_user)(user_id),
|
||||
preserve_fn(store.get_rooms_for_user)(user_id),
|
||||
], consumeErrors=True))
|
||||
invites = yield store.get_invited_rooms_for_user(user_id)
|
||||
joins = yield store.get_rooms_for_user(user_id)
|
||||
|
||||
my_receipts_by_room = yield store.get_receipts_for_user(
|
||||
user_id, "m.read",
|
||||
|
@ -89,6 +89,11 @@ class ObservableDeferred(object):
|
||||
deferred.addCallbacks(callback, errback)
|
||||
|
||||
def observe(self):
|
||||
"""Observe the underlying deferred.
|
||||
|
||||
Can return either a deferred if the underlying deferred is still pending
|
||||
(or has failed), or the actual value. Callers may need to use maybeDeferred.
|
||||
"""
|
||||
if not self._result:
|
||||
d = defer.Deferred()
|
||||
|
||||
@ -101,7 +106,7 @@ class ObservableDeferred(object):
|
||||
return d
|
||||
else:
|
||||
success, res = self._result
|
||||
return defer.succeed(res) if success else defer.fail(res)
|
||||
return res if success else defer.fail(res)
|
||||
|
||||
def observers(self):
|
||||
return self._observers
|
||||
|
@ -224,8 +224,20 @@ class _CacheDescriptorBase(object):
|
||||
)
|
||||
|
||||
self.num_args = num_args
|
||||
|
||||
# list of the names of the args used as the cache key
|
||||
self.arg_names = all_args[1:num_args + 1]
|
||||
|
||||
# self.arg_defaults is a map of arg name to its default value for each
|
||||
# argument that has a default value
|
||||
if arg_spec.defaults:
|
||||
self.arg_defaults = dict(zip(
|
||||
all_args[-len(arg_spec.defaults):],
|
||||
arg_spec.defaults
|
||||
))
|
||||
else:
|
||||
self.arg_defaults = {}
|
||||
|
||||
if "cache_context" in self.arg_names:
|
||||
raise Exception(
|
||||
"cache_context arg cannot be included among the cache keys"
|
||||
@ -289,18 +301,31 @@ class CacheDescriptor(_CacheDescriptorBase):
|
||||
iterable=self.iterable,
|
||||
)
|
||||
|
||||
def get_cache_key(args, kwargs):
|
||||
"""Given some args/kwargs return a generator that resolves into
|
||||
the cache_key.
|
||||
|
||||
We loop through each arg name, looking up if its in the `kwargs`,
|
||||
otherwise using the next argument in `args`. If there are no more
|
||||
args then we try looking the arg name up in the defaults
|
||||
"""
|
||||
pos = 0
|
||||
for nm in self.arg_names:
|
||||
if nm in kwargs:
|
||||
yield kwargs[nm]
|
||||
elif pos < len(args):
|
||||
yield args[pos]
|
||||
pos += 1
|
||||
else:
|
||||
yield self.arg_defaults[nm]
|
||||
|
||||
@functools.wraps(self.orig)
|
||||
def wrapped(*args, **kwargs):
|
||||
# If we're passed a cache_context then we'll want to call its invalidate()
|
||||
# whenever we are invalidated
|
||||
invalidate_callback = kwargs.pop("on_invalidate", None)
|
||||
|
||||
# Add temp cache_context so inspect.getcallargs doesn't explode
|
||||
if self.add_cache_context:
|
||||
kwargs["cache_context"] = None
|
||||
|
||||
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
|
||||
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
|
||||
cache_key = tuple(get_cache_key(args, kwargs))
|
||||
|
||||
# Add our own `cache_context` to argument list if the wrapped function
|
||||
# has asked for one
|
||||
@ -341,7 +366,10 @@ class CacheDescriptor(_CacheDescriptorBase):
|
||||
cache.set(cache_key, result_d, callback=invalidate_callback)
|
||||
observer = result_d.observe()
|
||||
|
||||
if isinstance(observer, defer.Deferred):
|
||||
return logcontext.make_deferred_yieldable(observer)
|
||||
else:
|
||||
return observer
|
||||
|
||||
wrapped.invalidate = cache.invalidate
|
||||
wrapped.invalidate_all = cache.invalidate_all
|
||||
|
@ -56,7 +56,8 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
|
||||
events ([synapse.events.EventBase]): list of events to filter
|
||||
"""
|
||||
forgotten = yield preserve_context_over_deferred(defer.gatherResults([
|
||||
preserve_fn(store.who_forgot_in_room)(
|
||||
defer.maybeDeferred(
|
||||
preserve_fn(store.who_forgot_in_room),
|
||||
room_id,
|
||||
)
|
||||
for room_id in frozenset(e.room_id for e in events)
|
||||
|
@ -199,7 +199,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
|
||||
|
||||
a.func.prefill(("foo",), ObservableDeferred(d))
|
||||
|
||||
self.assertEquals(a.func("foo").result, d.result)
|
||||
self.assertEquals(a.func("foo"), d.result)
|
||||
self.assertEquals(callcount[0], 0)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -175,3 +175,41 @@ class DescriptorTestCase(unittest.TestCase):
|
||||
logcontext.LoggingContext.sentinel)
|
||||
|
||||
return d1
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_cache_default_args(self):
|
||||
class Cls(object):
|
||||
def __init__(self):
|
||||
self.mock = mock.Mock()
|
||||
|
||||
@descriptors.cached()
|
||||
def fn(self, arg1, arg2=2, arg3=3):
|
||||
return self.mock(arg1, arg2, arg3)
|
||||
|
||||
obj = Cls()
|
||||
|
||||
obj.mock.return_value = 'fish'
|
||||
r = yield obj.fn(1, 2, 3)
|
||||
self.assertEqual(r, 'fish')
|
||||
obj.mock.assert_called_once_with(1, 2, 3)
|
||||
obj.mock.reset_mock()
|
||||
|
||||
# a call with same params shouldn't call the mock again
|
||||
r = yield obj.fn(1, 2)
|
||||
self.assertEqual(r, 'fish')
|
||||
obj.mock.assert_not_called()
|
||||
obj.mock.reset_mock()
|
||||
|
||||
# a call with different params should call the mock again
|
||||
obj.mock.return_value = 'chips'
|
||||
r = yield obj.fn(2, 3)
|
||||
self.assertEqual(r, 'chips')
|
||||
obj.mock.assert_called_once_with(2, 3, 3)
|
||||
obj.mock.reset_mock()
|
||||
|
||||
# the two values should now be cached
|
||||
r = yield obj.fn(1, 2)
|
||||
self.assertEqual(r, 'fish')
|
||||
r = yield obj.fn(2, 3)
|
||||
self.assertEqual(r, 'chips')
|
||||
obj.mock.assert_not_called()
|
||||
|
@ -53,6 +53,8 @@ class SnapshotCacheTestCase(unittest.TestCase):
|
||||
# before the cache expires returns a resolved deferred.
|
||||
get_result_at_11 = self.cache.get(11, "key")
|
||||
self.assertIsNotNone(get_result_at_11)
|
||||
if isinstance(get_result_at_11, Deferred):
|
||||
# The cache may return the actual result rather than a deferred
|
||||
self.assertTrue(get_result_at_11.called)
|
||||
|
||||
# Check that getting the key after the deferred has resolved
|
||||
|
Loading…
Reference in New Issue
Block a user