Better return type for get_all_entities_changed (#14604)

Help callers from using the return value incorrectly by ensuring
that callers explicitly check if there was a cache hit or not.
This commit is contained in:
Erik Johnston 2022-12-05 20:19:14 +00:00 committed by GitHub
parent 6a8310f3df
commit cee9445884
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 138 additions and 76 deletions

View file

@ -58,7 +58,10 @@ from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.caches.stream_change_cache import (
AllEntitiesChangedResult,
StreamChangeCache,
)
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@ -799,18 +802,66 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
def get_cached_device_list_changes(
self,
from_key: int,
) -> Optional[List[str]]:
) -> AllEntitiesChangedResult:
"""Get set of users whose devices have changed since `from_key`, or None
if that information is not in our cache.
"""
return self._device_list_stream_cache.get_all_entities_changed(from_key)
@cancellable
async def get_all_devices_changed(
self,
from_key: int,
to_key: int,
) -> Set[str]:
"""Get all users whose devices have changed in the given range.
Args:
from_key: The minimum device lists stream token to query device list
changes for, exclusive.
to_key: The maximum device lists stream token to query device list
changes for, inclusive.
Returns:
The set of user_ids whose devices have changed since `from_key`
(exclusive) until `to_key` (inclusive).
"""
result = self._device_list_stream_cache.get_all_entities_changed(from_key)
if result.hit:
# We know which users might have changed devices.
if not result.entities:
# If no users then we can return early.
return set()
# Otherwise we need to filter down the list
return await self.get_users_whose_devices_changed(
from_key, result.entities, to_key
)
# If the cache didn't tell us anything, we just need to query the full
# range.
sql = """
SELECT DISTINCT user_id FROM device_lists_stream
WHERE ? < stream_id AND stream_id <= ?
"""
rows = await self.db_pool.execute(
"get_all_devices_changed",
None,
sql,
from_key,
to_key,
)
return {u for u, in rows}
@cancellable
async def get_users_whose_devices_changed(
self,
from_key: int,
user_ids: Optional[Collection[str]] = None,
user_ids: Collection[str],
to_key: Optional[int] = None,
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
@ -830,52 +881,32 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
"""
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
user_ids_to_check: Optional[Collection[str]]
if user_ids is None:
# Get set of all users that have had device list changes since 'from_key'
user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
from_key
)
else:
# The same as above, but filter results to only those users in 'user_ids'
user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
user_ids, from_key
)
user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
user_ids, from_key
)
# If an empty set was returned, there's nothing to do.
if user_ids_to_check is not None and not user_ids_to_check:
if not user_ids_to_check:
return set()
if to_key is None:
to_key = self._device_list_id_gen.get_current_token()
def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
stream_id_where_clause = "stream_id > ?"
sql_args = [from_key]
if to_key:
stream_id_where_clause += " AND stream_id <= ?"
sql_args.append(to_key)
sql = f"""
sql = """
SELECT DISTINCT user_id FROM device_lists_stream
WHERE {stream_id_where_clause}
WHERE ? < stream_id AND stream_id <= ? AND %s
"""
# If the stream change cache gave us no information, fetch *all*
# users between the stream IDs.
if user_ids_to_check is None:
txn.execute(sql, sql_args)
return {user_id for user_id, in txn}
changes: Set[str] = set()
# Otherwise, fetch changes for the given users.
else:
changes: Set[str] = set()
# Query device changes with a batch of users at a time
for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
)
txn.execute(sql + " AND " + clause, sql_args + args)
changes.update(user_id for user_id, in txn)
# Query device changes with a batch of users at a time
for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
)
txn.execute(sql % (clause,), [from_key, to_key] + args)
changes.update(user_id for user_id, in txn)
return changes