Add some type hints to datastore. (#12255)

This commit is contained in:
Dirk Klimpel 2022-03-28 20:11:14 +02:00 committed by GitHub
parent 4ba55a620f
commit ac95167d2f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 61 additions and 42 deletions

View file

@ -14,7 +14,7 @@
# limitations under the License.
import collections.abc
import logging
from typing import TYPE_CHECKING, Iterable, Optional, Set
from typing import TYPE_CHECKING, Collection, Iterable, Optional, Set, Tuple
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
@ -29,7 +29,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter
from synapse.types import StateMap
from synapse.types import JsonDict, StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
@ -241,7 +241,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# We delegate to the cached version
return await self.get_current_state_ids(room_id)
def _get_filtered_current_state_ids_txn(txn):
def _get_filtered_current_state_ids_txn(
txn: LoggingTransaction,
) -> StateMap[str]:
results = {}
sql = """
SELECT type, state_key, event_id FROM current_state_events
@ -281,11 +283,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
event_id = state.get((EventTypes.CanonicalAlias, ""))
if not event_id:
return
return None
event = await self.get_event(event_id, allow_none=True)
if not event:
return
return None
return event.content.get("canonical_alias")
@ -304,7 +306,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
list_name="event_ids",
num_args=1,
)
async def _get_state_group_for_events(self, event_ids):
async def _get_state_group_for_events(self, event_ids: Collection[str]) -> JsonDict:
"""Returns mapping event_id -> state_group"""
rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
@ -355,7 +357,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
self.server_name: str = hs.hostname
self.db_pool.updates.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME,
@ -375,7 +377,9 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
self._background_remove_left_rooms,
)
async def _background_remove_left_rooms(self, progress, batch_size):
async def _background_remove_left_rooms(
self, progress: JsonDict, batch_size: int
) -> int:
"""Background update to delete rows from `current_state_events` and
`event_forward_extremities` tables of rooms that the server is no
longer joined to.
@ -383,7 +387,9 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
last_room_id = progress.get("last_room_id", "")
def _background_remove_left_rooms_txn(txn):
def _background_remove_left_rooms_txn(
txn: LoggingTransaction,
) -> Tuple[bool, Set[str]]:
# get a batch of room ids to consider
sql = """
SELECT DISTINCT room_id FROM current_state_events