Type annotations in synapse.databases.main.devices (#13025)

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
This commit is contained in:
David Robertson 2022-06-15 16:20:04 +01:00 committed by GitHub
parent 0d1d3e0708
commit 97e9fbe1b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 36 additions and 21 deletions

1
changelog.d/13025.misc Normal file
View File

@ -0,0 +1 @@
Add type annotations to `synapse.storage.databases.main.devices`.

View File

@ -27,7 +27,6 @@ exclude = (?x)
^( ^(
|synapse/storage/databases/__init__.py |synapse/storage/databases/__init__.py
|synapse/storage/databases/main/cache.py |synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py
|synapse/storage/schema/ |synapse/storage/schema/
|tests/api/test_auth.py |tests/api/test_auth.py

View File

@ -19,13 +19,12 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.devices import DeviceWorkerStore from synapse.storage.databases.main.devices import DeviceWorkerStore
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): class SlavedDeviceStore(DeviceWorkerStore, BaseSlavedStore):
def __init__( def __init__(
self, self,
database: DatabasePool, database: DatabasePool,

View File

@ -195,6 +195,7 @@ class DataStore(
self._min_stream_order_on_start = self.get_room_min_stream_ordering() self._min_stream_order_on_start = self.get_room_min_stream_ordering()
def get_device_stream_token(self) -> int: def get_device_stream_token(self) -> int:
# TODO: shouldn't this be moved to `DeviceWorkerStore`?
return self._device_list_id_gen.get_current_token() return self._device_list_id_gen.get_current_token()
async def get_users(self) -> List[JsonDict]: async def get_users(self) -> List[JsonDict]:

View File

@ -28,6 +28,8 @@ from typing import (
cast, cast,
) )
from typing_extensions import Literal
from synapse.api.constants import EduTypes from synapse.api.constants import EduTypes
from synapse.api.errors import Codes, StoreError from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import ( from synapse.logging.opentracing import (
@ -44,6 +46,8 @@ from synapse.storage.database import (
LoggingTransaction, LoggingTransaction,
make_tuple_comparison_clause, make_tuple_comparison_clause,
) )
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.storage.types import Cursor
from synapse.types import JsonDict, get_verify_key_from_cross_signing_key from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
@ -65,7 +69,7 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
class DeviceWorkerStore(SQLBaseStore): class DeviceWorkerStore(EndToEndKeyWorkerStore):
def __init__( def __init__(
self, self,
database: DatabasePool, database: DatabasePool,
@ -74,7 +78,9 @@ class DeviceWorkerStore(SQLBaseStore):
): ):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
device_list_max = self._device_list_id_gen.get_current_token() # Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a
# StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker).
device_list_max = self._device_list_id_gen.get_current_token() # type: ignore[attr-defined]
device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict( device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
db_conn, db_conn,
"device_lists_stream", "device_lists_stream",
@ -339,8 +345,9 @@ class DeviceWorkerStore(SQLBaseStore):
# following this stream later. # following this stream later.
last_processed_stream_id = from_stream_id last_processed_stream_id = from_stream_id
query_map = {} # A map of (user ID, device ID) to (stream ID, context).
cross_signing_keys_by_user = {} query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]] = {}
cross_signing_keys_by_user: Dict[str, Dict[str, object]] = {}
for user_id, device_id, update_stream_id, update_context in updates: for user_id, device_id, update_stream_id, update_context in updates:
# Calculate the remaining length budget. # Calculate the remaining length budget.
# Note that, for now, each entry in `cross_signing_keys_by_user` # Note that, for now, each entry in `cross_signing_keys_by_user`
@ -596,7 +603,7 @@ class DeviceWorkerStore(SQLBaseStore):
txn=txn, txn=txn,
table="device_lists_outbound_last_success", table="device_lists_outbound_last_success",
key_names=("destination", "user_id"), key_names=("destination", "user_id"),
key_values=((destination, user_id) for user_id, _ in rows), key_values=[(destination, user_id) for user_id, _ in rows],
value_names=("stream_id",), value_names=("stream_id",),
value_values=((stream_id,) for _, stream_id in rows), value_values=((stream_id,) for _, stream_id in rows),
) )
@ -621,7 +628,9 @@ class DeviceWorkerStore(SQLBaseStore):
The new stream ID. The new stream ID.
""" """
async with self._device_list_id_gen.get_next() as stream_id: # TODO: this looks like it's _writing_. Should this be on DeviceStore rather
# than DeviceWorkerStore?
async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"add_user_sig_change_to_streams", "add_user_sig_change_to_streams",
self._add_user_signature_change_txn, self._add_user_signature_change_txn,
@ -686,7 +695,7 @@ class DeviceWorkerStore(SQLBaseStore):
} - users_needing_resync } - users_needing_resync
user_ids_not_in_cache = user_ids - user_ids_in_cache user_ids_not_in_cache = user_ids - user_ids_in_cache
results = {} results: Dict[str, Dict[str, JsonDict]] = {}
for user_id, device_id in query_list: for user_id, device_id in query_list:
if user_id not in user_ids_in_cache: if user_id not in user_ids_in_cache:
continue continue
@ -727,7 +736,7 @@ class DeviceWorkerStore(SQLBaseStore):
def get_cached_device_list_changes( def get_cached_device_list_changes(
self, self,
from_key: int, from_key: int,
) -> Optional[Set[str]]: ) -> Optional[List[str]]:
"""Get set of users whose devices have changed since `from_key`, or None """Get set of users whose devices have changed since `from_key`, or None
if that information is not in our cache. if that information is not in our cache.
""" """
@ -737,7 +746,7 @@ class DeviceWorkerStore(SQLBaseStore):
async def get_users_whose_devices_changed( async def get_users_whose_devices_changed(
self, self,
from_key: int, from_key: int,
user_ids: Optional[Iterable[str]] = None, user_ids: Optional[Collection[str]] = None,
to_key: Optional[int] = None, to_key: Optional[int] = None,
) -> Set[str]: ) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that """Get set of users whose devices have changed since `from_key` that
@ -757,6 +766,7 @@ class DeviceWorkerStore(SQLBaseStore):
""" """
# Get set of users who *may* have changed. Users not in the returned # Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed. # list have definitely not changed.
user_ids_to_check: Optional[Collection[str]]
if user_ids is None: if user_ids is None:
# Get set of all users that have had device list changes since 'from_key' # Get set of all users that have had device list changes since 'from_key'
user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed( user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
@ -772,7 +782,7 @@ class DeviceWorkerStore(SQLBaseStore):
return set() return set()
def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]: def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
changes = set() changes: Set[str] = set()
stream_id_where_clause = "stream_id > ?" stream_id_where_clause = "stream_id > ?"
sql_args = [from_key] sql_args = [from_key]
@ -788,6 +798,9 @@ class DeviceWorkerStore(SQLBaseStore):
""" """
# Query device changes with a batch of users at a time # Query device changes with a batch of users at a time
# Assertion for mypy's benefit; see also
# https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions
assert user_ids_to_check is not None
for chunk in batch_iter(user_ids_to_check, 100): for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause( clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk txn.database_engine, "user_id", chunk
@ -854,7 +867,9 @@ class DeviceWorkerStore(SQLBaseStore):
if last_id == current_id: if last_id == current_id:
return [], current_id, False return [], current_id, False
def _get_all_device_list_changes_for_remotes(txn): def _get_all_device_list_changes_for_remotes(
txn: Cursor,
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
# This query Does The Right Thing where it'll correctly apply the # This query Does The Right Thing where it'll correctly apply the
# bounds to the inner queries. # bounds to the inner queries.
sql = """ sql = """
@ -913,7 +928,7 @@ class DeviceWorkerStore(SQLBaseStore):
desc="get_device_list_last_stream_id_for_remotes", desc="get_device_list_last_stream_id_for_remotes",
) )
results = {user_id: None for user_id in user_ids} results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids}
results.update({row["user_id"]: row["stream_id"] for row in rows}) results.update({row["user_id"]: row["stream_id"] for row in rows})
return results return results
@ -1337,9 +1352,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Map of (user_id, device_id) -> bool. If there is an entry that implies # Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists. # the device exists.
self.device_id_exists_cache = LruCache( self.device_id_exists_cache: LruCache[
cache_name="device_id_exists", max_size=10000 Tuple[str, str], Literal[True]
) ] = LruCache(cache_name="device_id_exists", max_size=10000)
async def store_device( async def store_device(
self, self,
@ -1651,7 +1666,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
context, context,
) )
async with self._device_list_id_gen.get_next_mult( async with self._device_list_id_gen.get_next_mult( # type: ignore[attr-defined]
len(device_ids) len(device_ids)
) as stream_ids: ) as stream_ids:
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
@ -1704,7 +1719,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device_ids: Iterable[str], device_ids: Iterable[str],
hosts: Collection[str], hosts: Collection[str],
stream_ids: List[int], stream_ids: List[int],
context: Dict[str, str], context: Optional[Dict[str, str]],
) -> None: ) -> None:
for host in hosts: for host in hosts:
txn.call_after( txn.call_after(
@ -1875,7 +1890,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
[], [],
) )
async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: # type: ignore[attr-defined]
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"add_device_list_outbound_pokes", "add_device_list_outbound_pokes",
add_device_list_outbound_pokes_txn, add_device_list_outbound_pokes_txn,