mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Merge pull request #209 from matrix-org/erikj/cached_keyword_args
Add support for using keyword arguments with cached functions
This commit is contained in:
commit
8049c9a71e
@ -27,6 +27,7 @@ from twisted.internet import defer
|
|||||||
from collections import namedtuple, OrderedDict
|
from collections import namedtuple, OrderedDict
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
|
import inspect
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import threading
|
import threading
|
||||||
@ -141,13 +142,28 @@ class CacheDescriptor(object):
|
|||||||
which can be used to insert values into the cache specifically, without
|
which can be used to insert values into the cache specifically, without
|
||||||
calling the calculation function.
|
calling the calculation function.
|
||||||
"""
|
"""
|
||||||
def __init__(self, orig, max_entries=1000, num_args=1, lru=True):
|
def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
|
||||||
|
inlineCallbacks=False):
|
||||||
self.orig = orig
|
self.orig = orig
|
||||||
|
|
||||||
|
if inlineCallbacks:
|
||||||
|
self.function_to_call = defer.inlineCallbacks(orig)
|
||||||
|
else:
|
||||||
|
self.function_to_call = orig
|
||||||
|
|
||||||
self.max_entries = max_entries
|
self.max_entries = max_entries
|
||||||
self.num_args = num_args
|
self.num_args = num_args
|
||||||
self.lru = lru
|
self.lru = lru
|
||||||
|
|
||||||
|
self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
|
||||||
|
|
||||||
|
if len(self.arg_names) < self.num_args:
|
||||||
|
raise Exception(
|
||||||
|
"Not enough explicit positional arguments to key off of for %r."
|
||||||
|
" (@cached cannot key off of *args or **kwars)"
|
||||||
|
% (orig.__name__,)
|
||||||
|
)
|
||||||
|
|
||||||
def __get__(self, obj, objtype=None):
|
def __get__(self, obj, objtype=None):
|
||||||
cache = Cache(
|
cache = Cache(
|
||||||
name=self.orig.__name__,
|
name=self.orig.__name__,
|
||||||
@ -158,11 +174,13 @@ class CacheDescriptor(object):
|
|||||||
|
|
||||||
@functools.wraps(self.orig)
|
@functools.wraps(self.orig)
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def wrapped(*keyargs):
|
def wrapped(*args, **kwargs):
|
||||||
|
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
|
||||||
|
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
|
||||||
try:
|
try:
|
||||||
cached_result = cache.get(*keyargs[:self.num_args])
|
cached_result = cache.get(*keyargs)
|
||||||
if DEBUG_CACHES:
|
if DEBUG_CACHES:
|
||||||
actual_result = yield self.orig(obj, *keyargs)
|
actual_result = yield self.function_to_call(obj, *args, **kwargs)
|
||||||
if actual_result != cached_result:
|
if actual_result != cached_result:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Stale cache entry %s%r: cached: %r, actual %r",
|
"Stale cache entry %s%r: cached: %r, actual %r",
|
||||||
@ -177,9 +195,9 @@ class CacheDescriptor(object):
|
|||||||
# while the SELECT is executing (SYN-369)
|
# while the SELECT is executing (SYN-369)
|
||||||
sequence = cache.sequence
|
sequence = cache.sequence
|
||||||
|
|
||||||
ret = yield self.orig(obj, *keyargs)
|
ret = yield self.function_to_call(obj, *args, **kwargs)
|
||||||
|
|
||||||
cache.update(sequence, *keyargs[:self.num_args] + (ret,))
|
cache.update(sequence, *(keyargs + [ret]))
|
||||||
|
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
|
|
||||||
@ -201,6 +219,16 @@ def cached(max_entries=1000, num_args=1, lru=True):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
|
||||||
|
return lambda orig: CacheDescriptor(
|
||||||
|
orig,
|
||||||
|
max_entries=max_entries,
|
||||||
|
num_args=num_args,
|
||||||
|
lru=lru,
|
||||||
|
inlineCallbacks=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LoggingTransaction(object):
|
class LoggingTransaction(object):
|
||||||
"""An object that almost-transparently proxies for the 'txn' object
|
"""An object that almost-transparently proxies for the 'txn' object
|
||||||
passed to the constructor. Adds logging and metrics to the .execute()
|
passed to the constructor. Adds logging and metrics to the .execute()
|
||||||
|
@ -13,7 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from _base import SQLBaseStore, cached
|
from _base import SQLBaseStore, cachedInlineCallbacks
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
@ -71,8 +71,7 @@ class KeyStore(SQLBaseStore):
|
|||||||
desc="store_server_certificate",
|
desc="store_server_certificate",
|
||||||
)
|
)
|
||||||
|
|
||||||
@cached()
|
@cachedInlineCallbacks()
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_all_server_verify_keys(self, server_name):
|
def get_all_server_verify_keys(self, server_name):
|
||||||
rows = yield self._simple_select_list(
|
rows = yield self._simple_select_list(
|
||||||
table="server_signature_keys",
|
table="server_signature_keys",
|
||||||
|
@ -13,7 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import SQLBaseStore, cached
|
from ._base import SQLBaseStore, cachedInlineCallbacks
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@ -23,8 +23,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class PushRuleStore(SQLBaseStore):
|
class PushRuleStore(SQLBaseStore):
|
||||||
@cached()
|
@cachedInlineCallbacks()
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_push_rules_for_user(self, user_name):
|
def get_push_rules_for_user(self, user_name):
|
||||||
rows = yield self._simple_select_list(
|
rows = yield self._simple_select_list(
|
||||||
table=PushRuleTable.table_name,
|
table=PushRuleTable.table_name,
|
||||||
@ -41,8 +40,7 @@ class PushRuleStore(SQLBaseStore):
|
|||||||
|
|
||||||
defer.returnValue(rows)
|
defer.returnValue(rows)
|
||||||
|
|
||||||
@cached()
|
@cachedInlineCallbacks()
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_push_rules_enabled_for_user(self, user_name):
|
def get_push_rules_enabled_for_user(self, user_name):
|
||||||
results = yield self._simple_select_list(
|
results = yield self._simple_select_list(
|
||||||
table=PushRuleEnableTable.table_name,
|
table=PushRuleEnableTable.table_name,
|
||||||
|
@ -13,7 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import SQLBaseStore, cached
|
from ._base import SQLBaseStore, cachedInlineCallbacks
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
@ -128,8 +128,7 @@ class ReceiptsStore(SQLBaseStore):
|
|||||||
def get_max_receipt_stream_id(self):
|
def get_max_receipt_stream_id(self):
|
||||||
return self._receipts_id_gen.get_max_token(self)
|
return self._receipts_id_gen.get_max_token(self)
|
||||||
|
|
||||||
@cached
|
@cachedInlineCallbacks()
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_graph_receipts_for_room(self, room_id):
|
def get_graph_receipts_for_room(self, room_id):
|
||||||
"""Get receipts for sending to remote servers.
|
"""Get receipts for sending to remote servers.
|
||||||
"""
|
"""
|
||||||
|
@ -17,7 +17,7 @@ from twisted.internet import defer
|
|||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
|
|
||||||
from ._base import SQLBaseStore, cached
|
from ._base import SQLBaseStore, cachedInlineCallbacks
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
import logging
|
import logging
|
||||||
@ -186,8 +186,7 @@ class RoomStore(SQLBaseStore):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@cached()
|
@cachedInlineCallbacks()
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_room_name_and_aliases(self, room_id):
|
def get_room_name_and_aliases(self, room_id):
|
||||||
def f(txn):
|
def f(txn):
|
||||||
sql = (
|
sql = (
|
||||||
|
@ -13,7 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import SQLBaseStore, cached
|
from ._base import SQLBaseStore, cached, cachedInlineCallbacks
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
@ -189,8 +189,7 @@ class StateStore(SQLBaseStore):
|
|||||||
events = yield self._get_events(event_ids, get_prev_content=False)
|
events = yield self._get_events(event_ids, get_prev_content=False)
|
||||||
defer.returnValue(events)
|
defer.returnValue(events)
|
||||||
|
|
||||||
@cached(num_args=3)
|
@cachedInlineCallbacks(num_args=3)
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_current_state_for_key(self, room_id, event_type, state_key):
|
def get_current_state_for_key(self, room_id, event_type, state_key):
|
||||||
def f(txn):
|
def f(txn):
|
||||||
sql = (
|
sql = (
|
||||||
|
Loading…
Reference in New Issue
Block a user