Add type hints to state database module. (#10823)

This commit is contained in:
Patrick Cloke 2021-09-15 09:54:13 -04:00 committed by GitHub
parent b93259082c
commit 3eba047d38
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 133 additions and 72 deletions

View file

@ -13,12 +13,20 @@
# limitations under the License.
import logging
from typing import Optional
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.state import StateFilter
from synapse.types import MutableStateMap, StateMap
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -31,7 +39,9 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
updates.
"""
def _count_state_group_hops_txn(self, txn, state_group):
def _count_state_group_hops_txn(
self, txn: LoggingTransaction, state_group: int
) -> int:
"""Given a state group, count how many hops there are in the tree.
This is used to ensure the delta chains don't get too long.
@ -56,7 +66,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
else:
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
next_group = state_group
next_group: Optional[int] = state_group
count = 0
while next_group:
@ -73,11 +83,14 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
return count
def _get_state_groups_from_groups_txn(
self, txn, groups, state_filter: Optional[StateFilter] = None
):
self,
txn: LoggingTransaction,
groups: List[int],
state_filter: Optional[StateFilter] = None,
) -> Mapping[int, StateMap[str]]:
state_filter = state_filter or StateFilter.all()
results = {group: {} for group in groups}
results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups}
where_clause, where_args = state_filter.make_sql_filter_clause()
@ -117,7 +130,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
"""
for group in groups:
args = [group]
args: List[Union[int, str]] = [group]
args.extend(where_args)
txn.execute(sql % (where_clause,), args)
@ -131,7 +144,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups:
next_group = group
next_group: Optional[int] = group
while next_group:
# We did this before by getting the list of group ids, and
@ -173,6 +186,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
allow_none=True,
)
# The results shouldn't be considered mutable.
return results
@ -182,7 +196,12 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
@ -198,7 +217,9 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
columns=["room_id"],
)
async def _background_deduplicate_state(self, progress, batch_size):
async def _background_deduplicate_state(
self, progress: dict, batch_size: int
) -> int:
"""This background update will slowly deduplicate state by reencoding
them as deltas.
"""
@ -218,7 +239,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
)
max_group = rows[0][0]
def reindex_txn(txn):
def reindex_txn(txn: LoggingTransaction) -> Tuple[bool, int]:
new_last_state_group = last_state_group
for count in range(batch_size):
txn.execute(
@ -251,7 +272,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
" WHERE id < ? AND room_id = ?",
(state_group, room_id),
)
(prev_group,) = txn.fetchone()
# There will be a result due to the coalesce.
(prev_group,) = txn.fetchone() # type: ignore
new_last_state_group = state_group
if prev_group:
@ -261,15 +283,15 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
# otherwise read performance degrades.
continue
prev_state = self._get_state_groups_from_groups_txn(
prev_state_by_group = self._get_state_groups_from_groups_txn(
txn, [prev_group]
)
prev_state = prev_state[prev_group]
prev_state = prev_state_by_group[prev_group]
curr_state = self._get_state_groups_from_groups_txn(
curr_state_by_group = self._get_state_groups_from_groups_txn(
txn, [state_group]
)
curr_state = curr_state[state_group]
curr_state = curr_state_by_group[state_group]
if not set(prev_state.keys()) - set(curr_state.keys()):
# We can only do a delta if the current has a strict super set
@ -340,8 +362,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
return result * BATCH_SIZE_SCALE_FACTOR
async def _background_index_state(self, progress, batch_size):
def reindex_txn(conn):
async def _background_index_state(self, progress: dict, batch_size: int) -> int:
def reindex_txn(conn: LoggingDatabaseConnection) -> None:
conn.rollback()
if isinstance(self.database_engine, PostgresEngine):
# postgres insists on autocommit for the index