Add type hints to synapse/storage/databases/main/stats.py (#11653)

This commit is contained in:
Dirk Klimpel 2021-12-29 14:01:13 +01:00 committed by GitHub
parent fcfe67578f
commit 15bb1c8511
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 57 additions and 42 deletions

1
changelog.d/11653.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints to storage classes.

View File

@ -39,7 +39,6 @@ exclude = (?x)
|synapse/storage/databases/main/roommember.py |synapse/storage/databases/main/roommember.py
|synapse/storage/databases/main/search.py |synapse/storage/databases/main/search.py
|synapse/storage/databases/main/state.py |synapse/storage/databases/main/state.py
|synapse/storage/databases/main/stats.py
|synapse/storage/databases/main/user_directory.py |synapse/storage/databases/main/user_directory.py
|synapse/storage/schema/ |synapse/storage/schema/
@ -214,6 +213,9 @@ disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.profile] [mypy-synapse.storage.databases.main.profile]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.stats]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.state_deltas] [mypy-synapse.storage.databases.main.state_deltas]
disallow_untyped_defs = True disallow_untyped_defs = True

View File

@ -16,7 +16,7 @@
import logging import logging
from enum import Enum from enum import Enum
from itertools import chain from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
from typing_extensions import Counter from typing_extensions import Counter
@ -24,7 +24,11 @@ from twisted.internet.defer import DeferredLock
from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.state_deltas import StateDeltasStore from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
@ -122,7 +126,9 @@ class StatsStore(StateDeltasStore):
self.db_pool.updates.register_noop_background_update("populate_stats_cleanup") self.db_pool.updates.register_noop_background_update("populate_stats_cleanup")
self.db_pool.updates.register_noop_background_update("populate_stats_prepare") self.db_pool.updates.register_noop_background_update("populate_stats_prepare")
async def _populate_stats_process_users(self, progress, batch_size): async def _populate_stats_process_users(
self, progress: JsonDict, batch_size: int
) -> int:
""" """
This is a background update which regenerates statistics for users. This is a background update which regenerates statistics for users.
""" """
@ -134,7 +140,7 @@ class StatsStore(StateDeltasStore):
last_user_id = progress.get("last_user_id", "") last_user_id = progress.get("last_user_id", "")
def _get_next_batch(txn): def _get_next_batch(txn: LoggingTransaction) -> List[str]:
sql = """ sql = """
SELECT DISTINCT name FROM users SELECT DISTINCT name FROM users
WHERE name > ? WHERE name > ?
@ -168,7 +174,9 @@ class StatsStore(StateDeltasStore):
return len(users_to_work_on) return len(users_to_work_on)
async def _populate_stats_process_rooms(self, progress, batch_size): async def _populate_stats_process_rooms(
self, progress: JsonDict, batch_size: int
) -> int:
"""This is a background update which regenerates statistics for rooms.""" """This is a background update which regenerates statistics for rooms."""
if not self.stats_enabled: if not self.stats_enabled:
await self.db_pool.updates._end_background_update( await self.db_pool.updates._end_background_update(
@ -178,7 +186,7 @@ class StatsStore(StateDeltasStore):
last_room_id = progress.get("last_room_id", "") last_room_id = progress.get("last_room_id", "")
def _get_next_batch(txn): def _get_next_batch(txn: LoggingTransaction) -> List[str]:
sql = """ sql = """
SELECT DISTINCT room_id FROM current_state_events SELECT DISTINCT room_id FROM current_state_events
WHERE room_id > ? WHERE room_id > ?
@ -307,7 +315,7 @@ class StatsStore(StateDeltasStore):
stream_id: Current position. stream_id: Current position.
""" """
def _bulk_update_stats_delta_txn(txn): def _bulk_update_stats_delta_txn(txn: LoggingTransaction) -> None:
for stats_type, stats_updates in updates.items(): for stats_type, stats_updates in updates.items():
for stats_id, fields in stats_updates.items(): for stats_id, fields in stats_updates.items():
logger.debug( logger.debug(
@ -339,7 +347,7 @@ class StatsStore(StateDeltasStore):
stats_type: str, stats_type: str,
stats_id: str, stats_id: str,
fields: Dict[str, int], fields: Dict[str, int],
complete_with_stream_id: Optional[int], complete_with_stream_id: int,
absolute_field_overrides: Optional[Dict[str, int]] = None, absolute_field_overrides: Optional[Dict[str, int]] = None,
) -> None: ) -> None:
""" """
@ -372,14 +380,14 @@ class StatsStore(StateDeltasStore):
def _update_stats_delta_txn( def _update_stats_delta_txn(
self, self,
txn, txn: LoggingTransaction,
ts, ts: int,
stats_type, stats_type: str,
stats_id, stats_id: str,
fields, fields: Dict[str, int],
complete_with_stream_id, complete_with_stream_id: int,
absolute_field_overrides=None, absolute_field_overrides: Optional[Dict[str, int]] = None,
): ) -> None:
if absolute_field_overrides is None: if absolute_field_overrides is None:
absolute_field_overrides = {} absolute_field_overrides = {}
@ -422,20 +430,23 @@ class StatsStore(StateDeltasStore):
) )
def _upsert_with_additive_relatives_txn( def _upsert_with_additive_relatives_txn(
self, txn, table, keyvalues, absolutes, additive_relatives self,
): txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
absolutes: Dict[str, Any],
additive_relatives: Dict[str, int],
) -> None:
"""Used to update values in the stats tables. """Used to update values in the stats tables.
This is basically a slightly convoluted upsert that *adds* to any This is basically a slightly convoluted upsert that *adds* to any
existing rows. existing rows.
Args: Args:
txn table: Table name
table (str): Table name keyvalues: Row-identifying key values
keyvalues (dict[str, any]): Row-identifying key values absolutes: Absolute (set) fields
absolutes (dict[str, any]): Absolute (set) fields additive_relatives: Fields that will be added onto if existing row present.
additive_relatives (dict[str, int]): Fields that will be added onto
if existing row present.
""" """
if self.database_engine.can_native_upsert: if self.database_engine.can_native_upsert:
absolute_updates = [ absolute_updates = [
@ -491,20 +502,17 @@ class StatsStore(StateDeltasStore):
current_row.update(absolutes) current_row.update(absolutes)
self.db_pool.simple_update_one_txn(txn, table, keyvalues, current_row) self.db_pool.simple_update_one_txn(txn, table, keyvalues, current_row)
async def _calculate_and_set_initial_state_for_room( async def _calculate_and_set_initial_state_for_room(self, room_id: str) -> None:
self, room_id: str
) -> Tuple[dict, dict, int]:
"""Calculate and insert an entry into room_stats_current. """Calculate and insert an entry into room_stats_current.
Args: Args:
room_id: The room ID under calculation. room_id: The room ID under calculation.
Returns:
A tuple of room state, membership counts and stream position.
""" """
def _fetch_current_state_stats(txn): def _fetch_current_state_stats(
pos = self.get_room_max_stream_ordering() txn: LoggingTransaction,
) -> Tuple[List[str], Dict[str, int], int, List[str], int]:
pos = self.get_room_max_stream_ordering() # type: ignore[attr-defined]
rows = self.db_pool.simple_select_many_txn( rows = self.db_pool.simple_select_many_txn(
txn, txn,
@ -524,7 +532,7 @@ class StatsStore(StateDeltasStore):
retcols=["event_id"], retcols=["event_id"],
) )
event_ids = [row["event_id"] for row in rows] event_ids = cast(List[str], [row["event_id"] for row in rows])
txn.execute( txn.execute(
""" """
@ -544,9 +552,9 @@ class StatsStore(StateDeltasStore):
(room_id,), (room_id,),
) )
(current_state_events_count,) = txn.fetchone() current_state_events_count = cast(Tuple[int], txn.fetchone())[0]
users_in_room = self.get_users_in_room_txn(txn, room_id) users_in_room = self.get_users_in_room_txn(txn, room_id) # type: ignore[attr-defined]
return ( return (
event_ids, event_ids,
@ -566,7 +574,7 @@ class StatsStore(StateDeltasStore):
"get_initial_state_for_room", _fetch_current_state_stats "get_initial_state_for_room", _fetch_current_state_stats
) )
state_event_map = await self.get_events(event_ids, get_prev_content=False) state_event_map = await self.get_events(event_ids, get_prev_content=False) # type: ignore[attr-defined]
room_state = { room_state = {
"join_rules": None, "join_rules": None,
@ -622,8 +630,10 @@ class StatsStore(StateDeltasStore):
}, },
) )
async def _calculate_and_set_initial_state_for_user(self, user_id): async def _calculate_and_set_initial_state_for_user(self, user_id: str) -> None:
def _calculate_and_set_initial_state_for_user_txn(txn): def _calculate_and_set_initial_state_for_user_txn(
txn: LoggingTransaction,
) -> Tuple[int, int]:
pos = self._get_max_stream_id_in_current_state_deltas_txn(txn) pos = self._get_max_stream_id_in_current_state_deltas_txn(txn)
txn.execute( txn.execute(
@ -634,7 +644,7 @@ class StatsStore(StateDeltasStore):
""", """,
(user_id,), (user_id,),
) )
(count,) = txn.fetchone() count = cast(Tuple[int], txn.fetchone())[0]
return count, pos return count, pos
joined_rooms, pos = await self.db_pool.runInteraction( joined_rooms, pos = await self.db_pool.runInteraction(
@ -678,7 +688,9 @@ class StatsStore(StateDeltasStore):
users that exist given this query users that exist given this query
""" """
def get_users_media_usage_paginate_txn(txn): def get_users_media_usage_paginate_txn(
txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]:
filters = [] filters = []
args = [self.hs.config.server.server_name] args = [self.hs.config.server.server_name]
@ -733,7 +745,7 @@ class StatsStore(StateDeltasStore):
sql_base=sql_base, sql_base=sql_base,
) )
txn.execute(sql, args) txn.execute(sql, args)
count = txn.fetchone()[0] count = cast(Tuple[int], txn.fetchone())[0]
sql = """ sql = """
SELECT SELECT