diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 8e1780036..d22db0a0b 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -19,6 +19,8 @@ from twisted.internet import defer from synapse.api.errors import StoreError from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks + logger = logging.getLogger(__name__) @@ -144,6 +146,7 @@ class DeviceStore(SQLBaseStore): defer.returnValue({d["device_id"]: d for d in devices}) + @cached(max_entries=10000) def get_device_list_last_stream_id_for_remote(self, user_id): """Get the last stream_id we got for a user. May be None if we haven't got any information for them. @@ -156,16 +159,36 @@ class DeviceStore(SQLBaseStore): allow_none=True, ) + @cachedList(cached_method_name="get_device_list_last_stream_id_for_remote", + list_name="user_ids", inlineCallbacks=True) + def get_device_list_last_stream_id_for_remotes(self, user_ids): + rows = yield self._simple_select_many_batch( + table="device_lists_remote_extremeties", + column="user_id", + iterable=user_ids, + retcols=("user_id", "stream_id",), + desc="get_user_devices_from_cache", + ) + + results = {user_id: None for user_id in user_ids} + results.update({ + row["user_id"]: row["stream_id"] for row in rows + }) + + defer.returnValue(results) + + @defer.inlineCallbacks def mark_remote_user_device_list_as_unsubscribed(self, user_id): """Mark that we no longer track device lists for remote user. """ - return self._simple_delete( + yield self._simple_delete( table="device_lists_remote_extremeties", keyvalues={ "user_id": user_id, }, desc="mark_remote_user_device_list_as_unsubscribed", ) + self.get_device_list_last_stream_id_for_remote.invalidate((user_id,)) def update_remote_device_list_cache_entry(self, user_id, device_id, content, stream_id): @@ -191,6 +214,12 @@ class DeviceStore(SQLBaseStore): } ) + txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,)) + txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,)) + txn.call_after( + self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) + ) + self._simple_upsert_txn( txn, table="device_lists_remote_extremeties", @@ -234,6 +263,12 @@ class DeviceStore(SQLBaseStore): ] ) + txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,)) + txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,)) + txn.call_after( + self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) + ) + self._simple_upsert_txn( txn, table="device_lists_remote_extremeties", @@ -320,6 +355,7 @@ class DeviceStore(SQLBaseStore): return (now_stream_id, results) + @defer.inlineCallbacks def get_user_devices_from_cache(self, query_list): """Get the devices (and keys if any) for remote users from the cache. @@ -332,27 +368,11 @@ class DeviceStore(SQLBaseStore): a set of user_ids and results_map is a mapping of user_id -> device_id -> device_info """ - return self.runInteraction( - "get_user_devices_from_cache", self._get_user_devices_from_cache_txn, - query_list, + user_ids = set(user_id for user_id, _ in query_list) + user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids)) + user_ids_in_cache = set( + user_id for user_id, stream_id in user_map.items() if stream_id ) - - def _get_user_devices_from_cache_txn(self, txn, query_list): - user_ids = {user_id for user_id, _ in query_list} - - user_ids_in_cache = set() - for user_id in user_ids: - stream_ids = self._simple_select_onecol_txn( - txn, - table="device_lists_remote_extremeties", - keyvalues={ - "user_id": user_id, - }, - retcol="stream_id", - ) - if stream_ids: - user_ids_in_cache.add(user_id) - user_ids_not_in_cache = user_ids - user_ids_in_cache results = {} @@ -361,32 +381,40 @@ class DeviceStore(SQLBaseStore): continue if device_id: - content = self._simple_select_one_onecol_txn( - txn, - table="device_lists_remote_cache", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - }, - retcol="content", - ) - results.setdefault(user_id, {})[device_id] = json.loads(content) + device = yield self._get_cached_user_device(user_id, device_id) + results.setdefault(user_id, {})[device_id] = device else: - devices = self._simple_select_list_txn( - txn, - table="device_lists_remote_cache", - keyvalues={ - "user_id": user_id, - }, - retcols=("device_id", "content"), - ) - results[user_id] = { - device["device_id"]: json.loads(device["content"]) - for device in devices - } - user_ids_in_cache.discard(user_id) + results[user_id] = yield self._get_cached_devices_for_user(user_id) - return user_ids_not_in_cache, results + defer.returnValue((user_ids_not_in_cache, results)) + + @cachedInlineCallbacks(num_args=2, tree=True) + def _get_cached_user_device(self, user_id, device_id): + content = yield self._simple_select_one_onecol( + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + retcol="content", + desc="_get_cached_user_device", + ) + defer.returnValue(json.loads(content)) + + @cachedInlineCallbacks() + def _get_cached_devices_for_user(self, user_id): + devices = yield self._simple_select_list( + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + }, + retcols=("device_id", "content"), + desc="_get_cached_devices_for_user", + ) + defer.returnValue({ + device["device_id"]: json.loads(device["content"]) + for device in devices + }) def get_devices_with_keys_by_user(self, user_id): """Get all devices (with any device keys) for a user