Use batch_iter and correct docstring

This commit is contained in:
Erik Johnston 2019-06-26 19:10:38 +01:00
parent 806a06daf2
commit f335e77d53

View File

@ -24,6 +24,7 @@ from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import Cache, SQLBaseStore, db_to_json from synapse.storage._base import Cache, SQLBaseStore, db_to_json
from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.util import batch_iter
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -396,8 +397,8 @@ class DeviceWorkerStore(SQLBaseStore):
are in the given list of user_ids. are in the given list of user_ids.
Args: Args:
from_key (str): The device lists stream token
user_ids (Iterable[str]) user_ids (Iterable[str])
from_key: The device lists stream token
Returns: Returns:
Deferred[set[str]]: The set of user_ids whose devices have changed Deferred[set[str]]: The set of user_ids whose devices have changed
@ -414,23 +415,19 @@ class DeviceWorkerStore(SQLBaseStore):
if not to_check: if not to_check:
return defer.succeed(set()) return defer.succeed(set())
# We now check the database for all users in `to_check`, in batches.
batch_size = 100
chunks = [
to_check[i : i + batch_size] for i in range(0, len(to_check), batch_size)
]
sql = """
SELECT DISTINCT user_id FROM device_lists_stream
WHERE stream_id > ?
AND user_id IN (%s)
"""
def _get_users_whose_devices_changed_txn(txn): def _get_users_whose_devices_changed_txn(txn):
changes = set() changes = set()
for chunk in chunks: sql = """
txn.execute(sql % (",".join("?" for _ in chunk),), [from_key] + chunk) SELECT DISTINCT user_id FROM device_lists_stream
WHERE stream_id > ?
AND user_id IN (%s)
"""
for chunk in batch_iter(to_check, 100):
txn.execute(
sql % (",".join("?" for _ in chunk),), [from_key] + list(chunk)
)
changes.update(user_id for user_id, in txn) changes.update(user_id for user_id, in txn)
return changes return changes