Type annotations in synapse.databases.main.devices (#13025)

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
This commit is contained in:
David Robertson 2022-06-15 16:20:04 +01:00 committed by GitHub
parent 0d1d3e0708
commit 97e9fbe1b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 36 additions and 21 deletions

1
changelog.d/13025.misc Normal file
View File

@ -0,0 +1 @@
Add type annotations to `synapse.storage.databases.main.devices`.

View File

@ -27,7 +27,6 @@ exclude = (?x)
^(
|synapse/storage/databases/__init__.py
|synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py
|synapse/storage/schema/
|tests/api/test_auth.py

View File

@ -19,13 +19,12 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.devices import DeviceWorkerStore
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
if TYPE_CHECKING:
from synapse.server import HomeServer
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
class SlavedDeviceStore(DeviceWorkerStore, BaseSlavedStore):
def __init__(
self,
database: DatabasePool,

View File

@ -195,6 +195,7 @@ class DataStore(
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
def get_device_stream_token(self) -> int:
# TODO: shouldn't this be moved to `DeviceWorkerStore`?
return self._device_list_id_gen.get_current_token()
async def get_users(self) -> List[JsonDict]:

View File

@ -28,6 +28,8 @@ from typing import (
cast,
)
from typing_extensions import Literal
from synapse.api.constants import EduTypes
from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import (
@ -44,6 +46,8 @@ from synapse.storage.database import (
LoggingTransaction,
make_tuple_comparison_clause,
)
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.storage.types import Cursor
from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@ -65,7 +69,7 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
class DeviceWorkerStore(SQLBaseStore):
class DeviceWorkerStore(EndToEndKeyWorkerStore):
def __init__(
self,
database: DatabasePool,
@ -74,7 +78,9 @@ class DeviceWorkerStore(SQLBaseStore):
):
super().__init__(database, db_conn, hs)
device_list_max = self._device_list_id_gen.get_current_token()
# Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a
# StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker).
device_list_max = self._device_list_id_gen.get_current_token() # type: ignore[attr-defined]
device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
db_conn,
"device_lists_stream",
@ -339,8 +345,9 @@ class DeviceWorkerStore(SQLBaseStore):
# following this stream later.
last_processed_stream_id = from_stream_id
query_map = {}
cross_signing_keys_by_user = {}
# A map of (user ID, device ID) to (stream ID, context).
query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]] = {}
cross_signing_keys_by_user: Dict[str, Dict[str, object]] = {}
for user_id, device_id, update_stream_id, update_context in updates:
# Calculate the remaining length budget.
# Note that, for now, each entry in `cross_signing_keys_by_user`
@ -596,7 +603,7 @@ class DeviceWorkerStore(SQLBaseStore):
txn=txn,
table="device_lists_outbound_last_success",
key_names=("destination", "user_id"),
key_values=((destination, user_id) for user_id, _ in rows),
key_values=[(destination, user_id) for user_id, _ in rows],
value_names=("stream_id",),
value_values=((stream_id,) for _, stream_id in rows),
)
@ -621,7 +628,9 @@ class DeviceWorkerStore(SQLBaseStore):
The new stream ID.
"""
async with self._device_list_id_gen.get_next() as stream_id:
# TODO: this looks like it's _writing_. Should this be on DeviceStore rather
# than DeviceWorkerStore?
async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
await self.db_pool.runInteraction(
"add_user_sig_change_to_streams",
self._add_user_signature_change_txn,
@ -686,7 +695,7 @@ class DeviceWorkerStore(SQLBaseStore):
} - users_needing_resync
user_ids_not_in_cache = user_ids - user_ids_in_cache
results = {}
results: Dict[str, Dict[str, JsonDict]] = {}
for user_id, device_id in query_list:
if user_id not in user_ids_in_cache:
continue
@ -727,7 +736,7 @@ class DeviceWorkerStore(SQLBaseStore):
def get_cached_device_list_changes(
self,
from_key: int,
) -> Optional[Set[str]]:
) -> Optional[List[str]]:
"""Get set of users whose devices have changed since `from_key`, or None
if that information is not in our cache.
"""
@ -737,7 +746,7 @@ class DeviceWorkerStore(SQLBaseStore):
async def get_users_whose_devices_changed(
self,
from_key: int,
user_ids: Optional[Iterable[str]] = None,
user_ids: Optional[Collection[str]] = None,
to_key: Optional[int] = None,
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
@ -757,6 +766,7 @@ class DeviceWorkerStore(SQLBaseStore):
"""
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
user_ids_to_check: Optional[Collection[str]]
if user_ids is None:
# Get set of all users that have had device list changes since 'from_key'
user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
@ -772,7 +782,7 @@ class DeviceWorkerStore(SQLBaseStore):
return set()
def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
changes = set()
changes: Set[str] = set()
stream_id_where_clause = "stream_id > ?"
sql_args = [from_key]
@ -788,6 +798,9 @@ class DeviceWorkerStore(SQLBaseStore):
"""
# Query device changes with a batch of users at a time
# Assertion for mypy's benefit; see also
# https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions
assert user_ids_to_check is not None
for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
@ -854,7 +867,9 @@ class DeviceWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
def _get_all_device_list_changes_for_remotes(txn):
def _get_all_device_list_changes_for_remotes(
txn: Cursor,
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
# This query Does The Right Thing where it'll correctly apply the
# bounds to the inner queries.
sql = """
@ -913,7 +928,7 @@ class DeviceWorkerStore(SQLBaseStore):
desc="get_device_list_last_stream_id_for_remotes",
)
results = {user_id: None for user_id in user_ids}
results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids}
results.update({row["user_id"]: row["stream_id"] for row in rows})
return results
@ -1337,9 +1352,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
self.device_id_exists_cache = LruCache(
cache_name="device_id_exists", max_size=10000
)
self.device_id_exists_cache: LruCache[
Tuple[str, str], Literal[True]
] = LruCache(cache_name="device_id_exists", max_size=10000)
async def store_device(
self,
@ -1651,7 +1666,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
context,
)
async with self._device_list_id_gen.get_next_mult(
async with self._device_list_id_gen.get_next_mult( # type: ignore[attr-defined]
len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
@ -1704,7 +1719,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device_ids: Iterable[str],
hosts: Collection[str],
stream_ids: List[int],
context: Dict[str, str],
context: Optional[Dict[str, str]],
) -> None:
for host in hosts:
txn.call_after(
@ -1875,7 +1890,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
[],
)
async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: # type: ignore[attr-defined]
return await self.db_pool.runInteraction(
"add_device_list_outbound_pokes",
add_device_list_outbound_pokes_txn,