Minor @cachedList enhancements (#9975)

- use a tuple rather than a list for the iterable that is passed into the
  wrapped function, for performance

- test that we can pass an iterable and that keys are correctly deduped.
This commit is contained in:
Richard van der Hoff 2021-05-14 11:12:36 +01:00 committed by GitHub
parent 52ed9655ed
commit 5090f26b63
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 31 additions and 20 deletions

View file

@ -665,7 +665,7 @@ class DeviceWorkerStore(SQLBaseStore):
cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids",
)
async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]):
rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",

View file

@ -473,7 +473,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
num_args=1,
)
async def _get_bare_e2e_cross_signing_keys_bulk(
self, user_ids: List[str]
self, user_ids: Iterable[str]
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
@ -497,7 +497,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
def _get_bare_e2e_cross_signing_keys_bulk_txn(
self,
txn: Connection,
user_ids: List[str],
user_ids: Iterable[str],
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if

View file

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Iterable
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList
@ -37,21 +39,16 @@ class UserErasureWorkerStore(SQLBaseStore):
return bool(result)
@cachedList(cached_method_name="is_user_erased", list_name="user_ids")
async def are_users_erased(self, user_ids):
async def are_users_erased(self, user_ids: Iterable[str]) -> Dict[str, bool]:
"""
Checks which users in a list have requested erasure
Args:
user_ids (iterable[str]): full user id to check
user_ids: full user ids to check
Returns:
dict[str, bool]:
for each user, whether the user has requested erasure.
for each user, whether the user has requested erasure.
"""
# this serves the dual purpose of (a) making sure we can do len and
# iterate it multiple times, and (b) avoiding duplicates.
user_ids = tuple(set(user_ids))
rows = await self.db_pool.simple_select_many_batch(
table="erased_users",
column="user_id",