Cache get_user_devices_from_cache

This commit is contained in:
Erik Johnston 2017-02-27 16:22:12 +00:00
parent b7442c3e2b
commit f58dbb02a6

View File

@ -19,6 +19,8 @@ from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -144,6 +146,7 @@ class DeviceStore(SQLBaseStore):
defer.returnValue({d["device_id"]: d for d in devices}) defer.returnValue({d["device_id"]: d for d in devices})
@cached(max_entries=10000)
def get_device_list_last_stream_id_for_remote(self, user_id): def get_device_list_last_stream_id_for_remote(self, user_id):
"""Get the last stream_id we got for a user. May be None if we haven't """Get the last stream_id we got for a user. May be None if we haven't
got any information for them. got any information for them.
@ -156,16 +159,36 @@ class DeviceStore(SQLBaseStore):
allow_none=True, allow_none=True,
) )
@cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids", inlineCallbacks=True)
def get_device_list_last_stream_id_for_remotes(self, user_ids):
rows = yield self._simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
retcols=("user_id", "stream_id",),
desc="get_user_devices_from_cache",
)
results = {user_id: None for user_id in user_ids}
results.update({
row["user_id"]: row["stream_id"] for row in rows
})
defer.returnValue(results)
@defer.inlineCallbacks
def mark_remote_user_device_list_as_unsubscribed(self, user_id): def mark_remote_user_device_list_as_unsubscribed(self, user_id):
"""Mark that we no longer track device lists for remote user. """Mark that we no longer track device lists for remote user.
""" """
return self._simple_delete( yield self._simple_delete(
table="device_lists_remote_extremeties", table="device_lists_remote_extremeties",
keyvalues={ keyvalues={
"user_id": user_id, "user_id": user_id,
}, },
desc="mark_remote_user_device_list_as_unsubscribed", desc="mark_remote_user_device_list_as_unsubscribed",
) )
self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
def update_remote_device_list_cache_entry(self, user_id, device_id, content, def update_remote_device_list_cache_entry(self, user_id, device_id, content,
stream_id): stream_id):
@ -191,6 +214,12 @@ class DeviceStore(SQLBaseStore):
} }
) )
txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,))
txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
txn.call_after(
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
)
self._simple_upsert_txn( self._simple_upsert_txn(
txn, txn,
table="device_lists_remote_extremeties", table="device_lists_remote_extremeties",
@ -234,6 +263,12 @@ class DeviceStore(SQLBaseStore):
] ]
) )
txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,))
txn.call_after(
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
)
self._simple_upsert_txn( self._simple_upsert_txn(
txn, txn,
table="device_lists_remote_extremeties", table="device_lists_remote_extremeties",
@ -320,6 +355,7 @@ class DeviceStore(SQLBaseStore):
return (now_stream_id, results) return (now_stream_id, results)
@defer.inlineCallbacks
def get_user_devices_from_cache(self, query_list): def get_user_devices_from_cache(self, query_list):
"""Get the devices (and keys if any) for remote users from the cache. """Get the devices (and keys if any) for remote users from the cache.
@ -332,27 +368,11 @@ class DeviceStore(SQLBaseStore):
a set of user_ids and results_map is a mapping of a set of user_ids and results_map is a mapping of
user_id -> device_id -> device_info user_id -> device_id -> device_info
""" """
return self.runInteraction( user_ids = set(user_id for user_id, _ in query_list)
"get_user_devices_from_cache", self._get_user_devices_from_cache_txn, user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
query_list, user_ids_in_cache = set(
user_id for user_id, stream_id in user_map.items() if stream_id
) )
def _get_user_devices_from_cache_txn(self, txn, query_list):
user_ids = {user_id for user_id, _ in query_list}
user_ids_in_cache = set()
for user_id in user_ids:
stream_ids = self._simple_select_onecol_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
retcol="stream_id",
)
if stream_ids:
user_ids_in_cache.add(user_id)
user_ids_not_in_cache = user_ids - user_ids_in_cache user_ids_not_in_cache = user_ids - user_ids_in_cache
results = {} results = {}
@ -361,32 +381,40 @@ class DeviceStore(SQLBaseStore):
continue continue
if device_id: if device_id:
content = self._simple_select_one_onecol_txn( device = yield self._get_cached_user_device(user_id, device_id)
txn, results.setdefault(user_id, {})[device_id] = device
else:
results[user_id] = yield self._get_cached_devices_for_user(user_id)
defer.returnValue((user_ids_not_in_cache, results))
@cachedInlineCallbacks(num_args=2, tree=True)
def _get_cached_user_device(self, user_id, device_id):
content = yield self._simple_select_one_onecol(
table="device_lists_remote_cache", table="device_lists_remote_cache",
keyvalues={ keyvalues={
"user_id": user_id, "user_id": user_id,
"device_id": device_id, "device_id": device_id,
}, },
retcol="content", retcol="content",
desc="_get_cached_user_device",
) )
results.setdefault(user_id, {})[device_id] = json.loads(content) defer.returnValue(json.loads(content))
else:
devices = self._simple_select_list_txn( @cachedInlineCallbacks()
txn, def _get_cached_devices_for_user(self, user_id):
devices = yield self._simple_select_list(
table="device_lists_remote_cache", table="device_lists_remote_cache",
keyvalues={ keyvalues={
"user_id": user_id, "user_id": user_id,
}, },
retcols=("device_id", "content"), retcols=("device_id", "content"),
desc="_get_cached_devices_for_user",
) )
results[user_id] = { defer.returnValue({
device["device_id"]: json.loads(device["content"]) device["device_id"]: json.loads(device["content"])
for device in devices for device in devices
} })
user_ids_in_cache.discard(user_id)
return user_ids_not_in_cache, results
def get_devices_with_keys_by_user(self, user_id): def get_devices_with_keys_by_user(self, user_id):
"""Get all devices (with any device keys) for a user """Get all devices (with any device keys) for a user