mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-07-26 21:15:18 -04:00
Better return type for get_all_entities_changed
(#14604)
Prevent callers from using the return value incorrectly by ensuring that they explicitly check whether there was a cache hit.
This commit is contained in:
parent
6a8310f3df
commit
cee9445884
8 changed files with 138 additions and 76 deletions
|
@ -58,7 +58,10 @@ from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
|
|||
from synapse.util import json_decoder, json_encoder
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
from synapse.util.caches.stream_change_cache import (
|
||||
AllEntitiesChangedResult,
|
||||
StreamChangeCache,
|
||||
)
|
||||
from synapse.util.cancellation import cancellable
|
||||
from synapse.util.iterutils import batch_iter
|
||||
from synapse.util.stringutils import shortstr
|
||||
|
@ -799,18 +802,66 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||
def get_cached_device_list_changes(
    self,
    from_key: int,
) -> AllEntitiesChangedResult:
    """Ask the device list stream cache which users changed since `from_key`.

    This is a thin wrapper around the stream change cache: it performs no
    database access and returns the cache's answer unmodified.

    Args:
        from_key: The device lists stream token to look for changes since.

    Returns:
        The `AllEntitiesChangedResult` produced by the cache. Callers must
        check its `hit` flag before reading `entities`; only on a cache hit
        does `entities` describe the users whose devices may have changed.
        This method never returns None.
    """

    return self._device_list_stream_cache.get_all_entities_changed(from_key)
|
||||
|
||||
@cancellable
async def get_all_devices_changed(
    self,
    from_key: int,
    to_key: int,
) -> Set[str]:
    """Find every user with a device list change in `(from_key, to_key]`.

    The stream change cache is consulted first. On a cache hit only the
    users it names are checked; on a miss the whole stream range is read
    from the database.

    Args:
        from_key: The minimum device lists stream token to query device list
            changes for, exclusive.
        to_key: The maximum device lists stream token to query device list
            changes for, inclusive.

    Returns:
        The set of user_ids whose devices have changed since `from_key`
        (exclusive) until `to_key` (inclusive).
    """

    cache_result = self._device_list_stream_cache.get_all_entities_changed(
        from_key
    )

    if not cache_result.hit:
        # The cache could not tell us anything for this `from_key`, so fall
        # back to scanning the full range in the database.
        sql = """
            SELECT DISTINCT user_id FROM device_lists_stream
            WHERE ? < stream_id AND stream_id <= ?
        """

        rows = await self.db_pool.execute(
            "get_all_devices_changed",
            None,
            sql,
            from_key,
            to_key,
        )
        return {user_id for (user_id,) in rows}

    # Cache hit: we know which users *might* have changed devices.
    candidates = cache_result.entities
    if not candidates:
        # Nobody changed, so there is nothing to filter.
        return set()

    # Narrow the candidates down to those that changed within the range.
    return await self.get_users_whose_devices_changed(
        from_key, candidates, to_key
    )
|
||||
|
||||
@cancellable
|
||||
async def get_users_whose_devices_changed(
|
||||
self,
|
||||
from_key: int,
|
||||
user_ids: Optional[Collection[str]] = None,
|
||||
user_ids: Collection[str],
|
||||
to_key: Optional[int] = None,
|
||||
) -> Set[str]:
|
||||
"""Get set of users whose devices have changed since `from_key` that
|
||||
|
@ -830,52 +881,32 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
|||
"""
|
||||
# Get set of users who *may* have changed. Users not in the returned
|
||||
# list have definitely not changed.
|
||||
user_ids_to_check: Optional[Collection[str]]
|
||||
if user_ids is None:
|
||||
# Get set of all users that have had device list changes since 'from_key'
|
||||
user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
|
||||
from_key
|
||||
)
|
||||
else:
|
||||
# The same as above, but filter results to only those users in 'user_ids'
|
||||
user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
|
||||
user_ids, from_key
|
||||
)
|
||||
user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
|
||||
user_ids, from_key
|
||||
)
|
||||
|
||||
# If an empty set was returned, there's nothing to do.
|
||||
if user_ids_to_check is not None and not user_ids_to_check:
|
||||
if not user_ids_to_check:
|
||||
return set()
|
||||
|
||||
if to_key is None:
|
||||
to_key = self._device_list_id_gen.get_current_token()
|
||||
|
||||
def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
|
||||
stream_id_where_clause = "stream_id > ?"
|
||||
sql_args = [from_key]
|
||||
|
||||
if to_key:
|
||||
stream_id_where_clause += " AND stream_id <= ?"
|
||||
sql_args.append(to_key)
|
||||
|
||||
sql = f"""
|
||||
sql = """
|
||||
SELECT DISTINCT user_id FROM device_lists_stream
|
||||
WHERE {stream_id_where_clause}
|
||||
WHERE ? < stream_id AND stream_id <= ? AND %s
|
||||
"""
|
||||
|
||||
# If the stream change cache gave us no information, fetch *all*
|
||||
# users between the stream IDs.
|
||||
if user_ids_to_check is None:
|
||||
txn.execute(sql, sql_args)
|
||||
return {user_id for user_id, in txn}
|
||||
changes: Set[str] = set()
|
||||
|
||||
# Otherwise, fetch changes for the given users.
|
||||
else:
|
||||
changes: Set[str] = set()
|
||||
|
||||
# Query device changes with a batch of users at a time
|
||||
for chunk in batch_iter(user_ids_to_check, 100):
|
||||
clause, args = make_in_list_sql_clause(
|
||||
txn.database_engine, "user_id", chunk
|
||||
)
|
||||
txn.execute(sql + " AND " + clause, sql_args + args)
|
||||
changes.update(user_id for user_id, in txn)
|
||||
# Query device changes with a batch of users at a time
|
||||
for chunk in batch_iter(user_ids_to_check, 100):
|
||||
clause, args = make_in_list_sql_clause(
|
||||
txn.database_engine, "user_id", chunk
|
||||
)
|
||||
txn.execute(sql % (clause,), [from_key, to_key] + args)
|
||||
changes.update(user_id for user_id, in txn)
|
||||
|
||||
return changes
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue