Add some type hints to datastore (#12485)

This commit is contained in:
Dirk Klimpel 2022-04-27 14:05:00 +02:00 committed by GitHub
parent 63ba9ba38b
commit b76f1a4d5f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 188 additions and 84 deletions

View file

@ -25,6 +25,7 @@ from typing import (
Optional,
Set,
Tuple,
cast,
)
from synapse.api.errors import Codes, StoreError
@ -136,7 +137,9 @@ class DeviceWorkerStore(SQLBaseStore):
Number of devices of this users.
"""
def count_devices_by_users_txn(txn, user_ids):
def count_devices_by_users_txn(
txn: LoggingTransaction, user_ids: List[str]
) -> int:
sql = """
SELECT count(*)
FROM devices
@ -149,7 +152,7 @@ class DeviceWorkerStore(SQLBaseStore):
)
txn.execute(sql + clause, args)
return txn.fetchone()[0]
return cast(Tuple[int], txn.fetchone())[0]
if not user_ids:
return 0
@ -468,7 +471,7 @@ class DeviceWorkerStore(SQLBaseStore):
"""
txn.execute(sql, (destination, from_stream_id, now_stream_id, limit))
return list(txn)
return cast(List[Tuple[str, str, int, Optional[str]]], txn.fetchall())
async def _get_device_update_edus_by_remote(
self,
@ -549,7 +552,7 @@ class DeviceWorkerStore(SQLBaseStore):
async def _get_last_device_update_for_remote_user(
self, destination: str, user_id: str, from_stream_id: int
) -> int:
def f(txn):
def f(txn: LoggingTransaction) -> int:
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
FROM device_lists_outbound_last_success
@ -767,7 +770,7 @@ class DeviceWorkerStore(SQLBaseStore):
if not user_ids_to_check:
return set()
def _get_users_whose_devices_changed_txn(txn):
def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
changes = set()
stream_id_where_clause = "stream_id > ?"
@ -966,7 +969,9 @@ class DeviceWorkerStore(SQLBaseStore):
async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
"""Mark that we no longer track device lists for remote user."""
def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
def _mark_remote_user_device_list_as_unsubscribed_txn(
txn: LoggingTransaction,
) -> None:
self.db_pool.simple_delete_txn(
txn,
table="device_lists_remote_extremeties",
@ -1004,7 +1009,7 @@ class DeviceWorkerStore(SQLBaseStore):
)
def _store_dehydrated_device_txn(
self, txn, user_id: str, device_id: str, device_data: str
self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str
) -> Optional[str]:
old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn,
@ -1081,7 +1086,7 @@ class DeviceWorkerStore(SQLBaseStore):
"""
yesterday = self._clock.time_msec() - prune_age
def _prune_txn(txn):
def _prune_txn(txn: LoggingTransaction) -> None:
# look for (user, destination) pairs which have an update older than
# the cutoff.
#
@ -1204,8 +1209,10 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
"drop_device_lists_outbound_last_success_non_unique_idx",
)
async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
def f(conn):
async def _drop_device_list_streams_non_unique_indexes(
self, progress: JsonDict, batch_size: int
) -> int:
def f(conn: LoggingDatabaseConnection) -> None:
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
@ -1217,7 +1224,9 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
)
return 1
async def _remove_duplicate_outbound_pokes(self, progress, batch_size):
async def _remove_duplicate_outbound_pokes(
self, progress: JsonDict, batch_size: int
) -> int:
# for some reason, we have accumulated duplicate entries in
# device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less
# efficient.
@ -1230,7 +1239,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
{"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
)
def _txn(txn):
def _txn(txn: LoggingTransaction) -> int:
clause, args = make_tuple_comparison_clause(
[(x, last_row[x]) for x in KEY_COLS]
)
@ -1602,7 +1611,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
context = get_active_span_text_map()
def add_device_changes_txn(txn, stream_ids):
def add_device_changes_txn(
txn: LoggingTransaction, stream_ids: List[int]
) -> None:
self._add_device_change_to_stream_txn(
txn,
user_id,
@ -1635,8 +1646,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
txn: LoggingTransaction,
user_id: str,
device_ids: Collection[str],
stream_ids: List[str],
):
stream_ids: List[int],
) -> None:
txn.call_after(
self._device_list_stream_cache.entity_has_changed,
user_id,
@ -1720,7 +1731,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: str,
device_ids: Iterable[str],
room_ids: Collection[str],
stream_ids: List[str],
stream_ids: List[int],
context: Dict[str, str],
) -> None:
"""Record the user in the room has updated their device."""
@ -1775,7 +1786,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
LIMIT ?
"""
def get_uncoverted_outbound_room_pokes_txn(txn):
def get_uncoverted_outbound_room_pokes_txn(
txn: LoggingTransaction,
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
txn.execute(sql, (limit,))
return [
@ -1808,7 +1821,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
Marks the associated row in `device_lists_changes_in_room` as handled.
"""
def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]):
def add_device_list_outbound_pokes_txn(
txn: LoggingTransaction, stream_ids: List[int]
) -> None:
if hosts:
self._add_device_outbound_poke_to_stream_txn(
txn,