Add type hints to state database module. (#10823)

Patrick Cloke 2021-09-15 09:54:13 -04:00 committed by GitHub
parent b93259082c
commit 3eba047d38
6 changed files with 133 additions and 72 deletions

changelog.d/10823.misc (new file)

@@ -0,0 +1 @@
+Add type hints to the state database.

mypy.ini

@@ -60,6 +60,7 @@ files =
   synapse/storage/databases/main/session.py,
   synapse/storage/databases/main/stream.py,
   synapse/storage/databases/main/ui_auth.py,
+  synapse/storage/databases/state,
   synapse/storage/database.py,
   synapse/storage/engines,
   synapse/storage/keys.py,
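
Adding `synapse/storage/databases/state` to the `files =` list brings the whole package under mypy's checks. A minimal sketch (hypothetical names, not from this commit) of the `Optional[int]` loop-variable pattern the next file annotates, and the class of error mypy can now flag:

```python
from typing import Optional

def lookup_prev_group(group: int) -> Optional[int]:
    # Stand-in for the state_group_edges lookup; returns None at the root.
    return group - 1 if group > 0 else None

def count_hops(state_group: int) -> int:
    # Without the explicit annotation, mypy infers `int` from the initializer
    # and rejects the later assignment of a possibly-None lookup result.
    next_group: Optional[int] = state_group
    count = 0
    while next_group is not None:
        next_group = lookup_prev_group(next_group)
        count += 1
    return count
```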

synapse/storage/databases/state/bg_updates.py

@@ -13,12 +13,20 @@
 # limitations under the License.

 import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union

 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.state import StateFilter
+from synapse.types import MutableStateMap, StateMap
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer

 logger = logging.getLogger(__name__)

@@ -31,7 +39,9 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
     updates.
     """

-    def _count_state_group_hops_txn(self, txn, state_group):
+    def _count_state_group_hops_txn(
+        self, txn: LoggingTransaction, state_group: int
+    ) -> int:
         """Given a state group, count how many hops there are in the tree.

         This is used to ensure the delta chains don't get too long.

@@ -56,7 +66,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
         else:
             # We don't use WITH RECURSIVE on sqlite3 as there are distributions
             # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
-            next_group = state_group
+            next_group: Optional[int] = state_group
             count = 0

             while next_group:

@@ -73,11 +83,14 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
         return count

     def _get_state_groups_from_groups_txn(
-        self, txn, groups, state_filter: Optional[StateFilter] = None
-    ):
+        self,
+        txn: LoggingTransaction,
+        groups: List[int],
+        state_filter: Optional[StateFilter] = None,
+    ) -> Mapping[int, StateMap[str]]:
         state_filter = state_filter or StateFilter.all()

-        results = {group: {} for group in groups}
+        results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups}

         where_clause, where_args = state_filter.make_sql_filter_clause()

@@ -117,7 +130,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
             """

             for group in groups:
-                args = [group]
+                args: List[Union[int, str]] = [group]
                 args.extend(where_args)

                 txn.execute(sql % (where_clause,), args)

@@ -131,7 +144,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
             # We don't use WITH RECURSIVE on sqlite3 as there are distributions
             # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
             for group in groups:
-                next_group = group
+                next_group: Optional[int] = group

                 while next_group:
                     # We did this before by getting the list of group ids, and

@@ -173,6 +186,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
                         allow_none=True,
                     )

+        # The results shouldn't be considered mutable.
         return results

@@ -182,7 +196,12 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
     STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
     STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"

-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
         self.db_pool.updates.register_background_update_handler(
             self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,

@@ -198,7 +217,9 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
             columns=["room_id"],
         )

-    async def _background_deduplicate_state(self, progress, batch_size):
+    async def _background_deduplicate_state(
+        self, progress: dict, batch_size: int
+    ) -> int:
         """This background update will slowly deduplicate state by reencoding
         them as deltas.
         """

@@ -218,7 +239,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
         )
         max_group = rows[0][0]

-        def reindex_txn(txn):
+        def reindex_txn(txn: LoggingTransaction) -> Tuple[bool, int]:
             new_last_state_group = last_state_group
             for count in range(batch_size):
                 txn.execute(

@@ -251,7 +272,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
                         " WHERE id < ? AND room_id = ?",
                         (state_group, room_id),
                     )
-                    (prev_group,) = txn.fetchone()
+                    # There will be a result due to the coalesce.
+                    (prev_group,) = txn.fetchone()  # type: ignore
                     new_last_state_group = state_group

                     if prev_group:

@@ -261,15 +283,15 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
                         # otherwise read performance degrades.
                         continue

-                    prev_state = self._get_state_groups_from_groups_txn(
+                    prev_state_by_group = self._get_state_groups_from_groups_txn(
                         txn, [prev_group]
                     )
-                    prev_state = prev_state[prev_group]
+                    prev_state = prev_state_by_group[prev_group]

-                    curr_state = self._get_state_groups_from_groups_txn(
+                    curr_state_by_group = self._get_state_groups_from_groups_txn(
                         txn, [state_group]
                     )
-                    curr_state = curr_state[state_group]
+                    curr_state = curr_state_by_group[state_group]

                     if not set(prev_state.keys()) - set(curr_state.keys()):
                         # We can only do a delta if the current has a strict super set

@@ -340,8 +362,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
         return result * BATCH_SIZE_SCALE_FACTOR

-    async def _background_index_state(self, progress, batch_size):
-        def reindex_txn(conn):
+    async def _background_index_state(self, progress: dict, batch_size: int) -> int:
+        def reindex_txn(conn: LoggingDatabaseConnection) -> None:
             conn.rollback()
             if isinstance(self.database_engine, PostgresEngine):
                 # postgres insists on autocommit for the index
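
Two typing idioms recur in this file: `_get_state_groups_from_groups_txn` builds a mutable `Dict` internally but declares a read-only `Mapping` return type (hence the new "shouldn't be considered mutable" comment), and `# type: ignore` silences mypy on `fetchone()` where the SQL `COALESCE` guarantees a row the type checker cannot see. A minimal, self-contained sketch of the first idiom, with hypothetical names:

```python
from typing import Dict, List, Mapping

def summarise(groups: List[int]) -> Mapping[int, Mapping[str, str]]:
    # Build with mutable dicts internally...
    results: Dict[int, Dict[str, str]] = {group: {} for group in groups}
    for group in groups:
        results[group]["example_key"] = f"value-{group}"
    # ...but advertise only the read-only Mapping protocol, so a caller
    # writing `summarise([1])[1]["k"] = "v"` is rejected by mypy.
    return results

print(summarise([1, 2]))
```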

synapse/storage/databases/state/store.py

@@ -13,43 +13,56 @@
 # limitations under the License.

 import logging
-from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+
+import attr

 from synapse.api.constants import EventTypes
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
 from synapse.storage.state import StateFilter
 from synapse.storage.types import Cursor
 from synapse.storage.util.sequence import build_sequence_generator
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateKey, StateMap
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.dictionary_cache import DictionaryCache

+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)

 MAX_STATE_DELTA_HOPS = 100

-class _GetStateGroupDelta(
-    namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
-):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _GetStateGroupDelta:
     """Return type of get_state_group_delta that implements __len__, which lets
-    us use the itrable flag when caching
+    us use the iterable flag when caching
     """

-    __slots__ = []
+    prev_group: Optional[int]
+    delta_ids: Optional[StateMap[str]]

-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.delta_ids) if self.delta_ids else 0

 class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
     """A data store for fetching/storing state groups."""

-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)

         # Originally the state store used a single DictionaryCache to cache the

@@ -81,19 +94,21 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         # We size the non-members cache to be smaller than the members cache as the
         # vast majority of state in Matrix (today) is member events.
-        self._state_group_cache = DictionaryCache(
+        self._state_group_cache: DictionaryCache[int, StateKey, str] = DictionaryCache(
             "*stateGroupCache*",
             # TODO: this hasn't been tuned yet
             50000,
         )
-        self._state_group_members_cache = DictionaryCache(
+        self._state_group_members_cache: DictionaryCache[
+            int, StateKey, str
+        ] = DictionaryCache(
             "*stateGroupMembersCache*",
             500000,
         )

-        def get_max_state_group_txn(txn: Cursor):
+        def get_max_state_group_txn(txn: Cursor) -> int:
             txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
-            return txn.fetchone()[0]
+            return txn.fetchone()[0]  # type: ignore

         self._state_group_seq_gen = build_sequence_generator(
             db_conn,

@@ -105,15 +120,15 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         )

     @cached(max_entries=10000, iterable=True)
-    async def get_state_group_delta(self, state_group):
+    async def get_state_group_delta(self, state_group: int) -> _GetStateGroupDelta:
         """Given a state group try to return a previous group and a delta between
         the old and the new.

         Returns:
-            (prev_group, delta_ids), where both may be None.
+            _GetStateGroupDelta containing prev_group and delta_ids, where both may be None.
         """

-        def _get_state_group_delta_txn(txn):
+        def _get_state_group_delta_txn(txn: LoggingTransaction) -> _GetStateGroupDelta:
             prev_group = self.db_pool.simple_select_one_onecol_txn(
                 txn,
                 table="state_group_edges",

@@ -154,7 +169,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         Returns:
             Dict of state group to state map.
         """
-        results = {}
+        results: Dict[int, StateMap[str]] = {}

         chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
         for chunk in chunks:

@@ -168,19 +183,24 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         return results

-    def _get_state_for_group_using_cache(self, cache, group, state_filter):
+    def _get_state_for_group_using_cache(
+        self,
+        cache: DictionaryCache[int, StateKey, str],
+        group: int,
+        state_filter: StateFilter,
+    ) -> Tuple[MutableStateMap[str], bool]:
         """Checks if group is in cache. See `_get_state_for_groups`

         Args:
-            cache(DictionaryCache): the state group cache to use
-            group(int): The state group to lookup
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
+            cache: the state group cache to use
+            group: The state group to lookup
+            state_filter: The state filter used to fetch state from the database.

-        Returns 2-tuple (`state_dict`, `got_all`).
-            `got_all` is a bool indicating if we successfully retrieved all
-            requests state from the cache, if False we need to query the DB for the
-            missing state.
+        Returns:
+            2-tuple (`state_dict`, `got_all`).
+            `got_all` is a bool indicating if we successfully retrieved all
+            requests state from the cache, if False we need to query the DB for the
+            missing state.
         """
         cache_entry = cache.get(group)
         state_dict_ids = cache_entry.value

@@ -277,8 +297,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         return state

     def _get_state_for_groups_using_cache(
-        self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter
-    ) -> Tuple[Dict[int, StateMap[str]], Set[int]]:
+        self,
+        groups: Iterable[int],
+        cache: DictionaryCache[int, StateKey, str],
+        state_filter: StateFilter,
+    ) -> Tuple[Dict[int, MutableStateMap[str]], Set[int]]:
         """Gets the state at each of a list of state groups, optionally
         filtering by type/state_key, querying from a specific cache.

@@ -310,21 +333,21 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
     def _insert_into_cache(
         self,
-        group_to_state_dict,
-        state_filter,
-        cache_seq_num_members,
-        cache_seq_num_non_members,
-    ):
+        group_to_state_dict: Dict[int, StateMap[str]],
+        state_filter: StateFilter,
+        cache_seq_num_members: int,
+        cache_seq_num_non_members: int,
+    ) -> None:
         """Inserts results from querying the database into the relevant cache.

         Args:
-            group_to_state_dict (dict): The new entries pulled from database.
-                Map from state group to state dict
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
-            cache_seq_num_members (int): Sequence number of member cache since
-                last lookup in cache
-            cache_seq_num_non_members (int): Sequence number of member cache since
-                last lookup in cache
+            group_to_state_dict: The new entries pulled from database.
+                Map from state group to state dict
+            state_filter: The state filter used to fetch state
+                from the database.
+            cache_seq_num_members: Sequence number of member cache since
+                last lookup in cache
+            cache_seq_num_non_members: Sequence number of member cache since
+                last lookup in cache
         """

@@ -395,7 +418,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             The state group ID
         """

-        def _store_state_group_txn(txn):
+        def _store_state_group_txn(txn: LoggingTransaction) -> int:
             if current_state_ids is None:
                 # AFAIK, this can never happen
                 raise Exception("current_state_ids cannot be None")

@@ -426,6 +449,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 potential_hops = self._count_state_group_hops_txn(txn, prev_group)
             if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
+                assert delta_ids is not None
+
                 self.db_pool.simple_insert_txn(
                     txn,
                     table="state_group_edges",

@@ -498,7 +523,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         )

     async def purge_unreferenced_state_groups(
-        self, room_id: str, state_groups_to_delete
+        self, room_id: str, state_groups_to_delete: Collection[int]
     ) -> None:
         """Deletes no longer referenced state groups and de-deltas any state
         groups that reference them.

@@ -506,8 +531,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         Args:
             room_id: The room the state groups belong to (must all be in the
                 same room).
-            state_groups_to_delete (Collection[int]): Set of all state groups
-                to delete.
+            state_groups_to_delete: Set of all state groups to delete.
         """

         await self.db_pool.runInteraction(

@@ -517,7 +541,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             state_groups_to_delete,
         )

-    def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete):
+    def _purge_unreferenced_state_groups(
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        state_groups_to_delete: Collection[int],
+    ) -> None:
         logger.info(
             "[purge] found %i state groups to delete", len(state_groups_to_delete)
         )

@@ -546,8 +575,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         # groups to non delta versions.
         for sg in remaining_state_groups:
             logger.info("[purge] de-delta-ing remaining state group %s", sg)
-            curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
-            curr_state = curr_state[sg]
+            curr_state_by_group = self._get_state_groups_from_groups_txn(txn, [sg])
+            curr_state = curr_state_by_group[sg]

             self.db_pool.simple_delete_txn(
                 txn, table="state_groups_state", keyvalues={"state_group": sg}

@@ -605,12 +634,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         return {row["state_group"]: row["prev_state_group"] for row in rows}

-    async def purge_room_state(self, room_id, state_groups_to_delete):
+    async def purge_room_state(
+        self, room_id: str, state_groups_to_delete: Collection[int]
+    ) -> None:
         """Deletes all record of a room from state tables

         Args:
-            room_id (str):
-            state_groups_to_delete (list[int]): State groups to delete
+            room_id:
+            state_groups_to_delete: State groups to delete
         """

         await self.db_pool.runInteraction(

@@ -620,7 +651,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             state_groups_to_delete,
         )

-    def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete):
+    def _purge_room_state_txn(
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        state_groups_to_delete: Collection[int],
+    ) -> None:
         # first we have to delete the state groups states
         logger.info("[purge] removing %s from state_groups_state", room_id)

synapse/storage/state.py

@@ -377,7 +377,8 @@ class StateGroupStorage:
            make up the delta between the old and new state groups.
         """

-        return await self.stores.state.get_state_group_delta(state_group)
+        state_group_delta = await self.stores.state.get_state_group_delta(state_group)
+        return state_group_delta.prev_group, state_group_delta.delta_ids

     async def get_state_groups_ids(
         self, _room_id: str, event_ids: Iterable[str]
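
Because the store-level method now returns `_GetStateGroupDelta` rather than a plain tuple, `StateGroupStorage.get_state_group_delta` unpacks it so its own callers keep receiving `(prev_group, delta_ids)`. A sketch of this adapter pattern, under assumed names:

```python
import asyncio
from typing import Mapping, Optional, Tuple

import attr

@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeltaResult:
    prev_group: Optional[int]
    delta_ids: Optional[Mapping[str, str]]

async def store_level_lookup(state_group: int) -> DeltaResult:
    # Stand-in for the typed store method.
    return DeltaResult(prev_group=state_group - 1, delta_ids={})

async def get_state_group_delta(
    state_group: int,
) -> Tuple[Optional[int], Optional[Mapping[str, str]]]:
    # Unpack the attrs object so the public API keeps its tuple contract.
    delta = await store_level_lookup(state_group)
    return delta.prev_group, delta.delta_ids

print(asyncio.run(get_state_group_delta(42)))
```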

synapse/util/caches/dictionary_cache.py

@@ -130,7 +130,7 @@ class DictionaryCache(Generic[KT, DKT, DV]):
         sequence: int,
         key: KT,
         value: Dict[DKT, DV],
-        fetched_keys: Optional[Set[DKT]] = None,
+        fetched_keys: Optional[Iterable[DKT]] = None,
     ) -> None:
         """Updates the entry in the cache

@@ -155,7 +155,7 @@ class DictionaryCache(Generic[KT, DKT, DV]):
             self._update_or_insert(key, value, fetched_keys)

     def _update_or_insert(
-        self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]
+        self, key: KT, value: Dict[DKT, DV], known_absent: Iterable[DKT]
     ) -> None:
         # We pop and reinsert as we need to tell the cache the size may have
         # changed
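
Widening `fetched_keys` and `known_absent` from `Set[DKT]` to `Iterable[DKT]` follows the usual advice to accept the most general type a function actually needs: these methods only iterate over the keys, so requiring a `Set` would needlessly reject lists or generators. A minimal illustration:

```python
from typing import Dict, Iterable

def mark_known_absent(cache: Dict[str, bool], known_absent: Iterable[str]) -> None:
    # Only iteration is needed, so Iterable is the honest requirement.
    for key in known_absent:
        cache.setdefault(key, False)

cache: Dict[str, bool] = {}
mark_known_absent(cache, ["a", "b"])         # a list works
mark_known_absent(cache, {"c"})              # so does a set
mark_known_absent(cache, (k for k in "de"))  # and a generator
print(sorted(cache))  # ['a', 'b', 'c', 'd', 'e']
```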