Add some type hints to datastore. (#12477)

This commit is contained in:
Dirk Klimpel 2022-05-10 20:07:48 +02:00 committed by GitHub
parent 147f098fb4
commit 989fa33096
4 changed files with 122 additions and 71 deletions

changelog.d/12477.misc (new file)

@@ -0,0 +1 @@
+Add some type hints to datastore.

synapse/events/snapshot.py

@@ -15,6 +15,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 import attr
 from frozendict import frozendict
+from typing_extensions import Literal
 from twisted.internet.defer import Deferred

@@ -106,7 +107,7 @@ class EventContext:
         incomplete state.
     """
-    rejected: Union[bool, str] = False
+    rejected: Union[Literal[False], str] = False
    _state_group: Optional[int] = None
    state_group_before_event: Optional[int] = None
    prev_group: Optional[int] = None
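
A note on the `Union[Literal[False], str]` change: `rejected` is either `False` or a rejection reason string, never `True`, so `bool` was wider than the field's real domain. `Literal[False]` also lets mypy narrow the field. A minimal standalone sketch (the `render` function is hypothetical, not Synapse code):

    from typing import Union

    from typing_extensions import Literal


    def render(rejected: Union[Literal[False], str]) -> str:
        if rejected is not False:
            # mypy narrows `rejected` to `str` here: the only non-False
            # member of the union is the rejection reason string.
            return "rejected: %s" % (rejected,)
        return "accepted"


    render(False)           # ok
    render("unauthorized")  # ok
    render(True)            # error: True is not Literal[False]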

synapse/storage/databases/main/events.py

@@ -49,7 +49,7 @@ from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.engines.postgres import PostgresEngine
 from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.storage.util.sequence import SequenceGenerator
-from synapse.types import StateMap, get_domain_from_id
+from synapse.types import JsonDict, StateMap, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.iterutils import batch_iter, sorted_topologically
@@ -235,7 +235,9 @@ class PersistEventsStore:
         """
         results: List[str] = []

-        def _get_events_which_are_prevs_txn(txn, batch):
+        def _get_events_which_are_prevs_txn(
+            txn: LoggingTransaction, batch: Collection[str]
+        ) -> None:
             sql = """
                 SELECT prev_event_id, internal_metadata
                 FROM event_edges

@@ -285,7 +287,9 @@ class PersistEventsStore:
         # and their prev events.
         existing_prevs = set()

-        def _get_prevs_before_rejected_txn(txn, batch):
+        def _get_prevs_before_rejected_txn(
+            txn: LoggingTransaction, batch: Collection[str]
+        ) -> None:
             to_recursively_check = batch
             while to_recursively_check:
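
These inner `*_txn` callbacks are handed to the database layer, which runs them inside a transaction. A rough, self-contained sketch of that pattern using sqlite3 as a stand-in (the `run_interaction` helper below is illustrative, not Synapse's actual `DatabasePool.runInteraction`):

    import sqlite3
    from typing import Callable, List, TypeVar

    R = TypeVar("R")


    def run_interaction(
        conn: sqlite3.Connection, func: Callable[[sqlite3.Cursor], R]
    ) -> R:
        # Hand the callback a cursor (the "txn") and commit when it returns.
        cur = conn.cursor()
        try:
            result = func(cur)
            conn.commit()
            return result
        finally:
            cur.close()


    def _get_event_ids_txn(txn: sqlite3.Cursor) -> List[str]:
        # Annotating the txn parameter and the return type, as the hunks
        # above do, lets mypy check every call site of the callback.
        txn.execute("SELECT event_id FROM events")
        return [row[0] for row in txn.fetchall()]


    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE events (event_id TEXT)")
    conn.execute("INSERT INTO events VALUES ('$abc')")
    print(run_interaction(conn, _get_event_ids_txn))  # ['$abc']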
@@ -515,7 +519,7 @@ class PersistEventsStore:
     @classmethod
     def _add_chain_cover_index(
         cls,
-        txn,
+        txn: LoggingTransaction,
         db_pool: DatabasePool,
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],

@@ -809,7 +813,7 @@ class PersistEventsStore:
     @staticmethod
     def _allocate_chain_ids(
-        txn,
+        txn: LoggingTransaction,
         db_pool: DatabasePool,
         event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
@@ -943,7 +947,7 @@ class PersistEventsStore:
         self,
         txn: LoggingTransaction,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
-    ):
+    ) -> None:
         """Persist the mapping from transaction IDs to event IDs (if defined)."""
         to_insert = []

@@ -997,7 +1001,7 @@ class PersistEventsStore:
         txn: LoggingTransaction,
         state_delta_by_room: Dict[str, DeltaState],
         stream_id: int,
-    ):
+    ) -> None:
         for room_id, delta_state in state_delta_by_room.items():
             to_delete = delta_state.to_delete
             to_insert = delta_state.to_insert
@@ -1155,7 +1159,7 @@ class PersistEventsStore:
             txn, room_id, members_changed
         )

-    def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str):
+    def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None:
         """Update the room version in the database based off current state
         events.

@@ -1189,7 +1193,7 @@ class PersistEventsStore:
         txn: LoggingTransaction,
         new_forward_extremities: Dict[str, Set[str]],
         max_stream_order: int,
-    ):
+    ) -> None:
         for room_id in new_forward_extremities.keys():
             self.db_pool.simple_delete_txn(
                 txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
@@ -1254,9 +1258,9 @@ class PersistEventsStore:
     def _update_room_depths_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
-    ):
+    ) -> None:
         """Update min_depth for each room

         Args:

@@ -1385,7 +1389,7 @@ class PersistEventsStore:
             # nothing to do here
             return

-        def event_dict(event):
+        def event_dict(event: EventBase) -> JsonDict:
             d = event.get_dict()
             d.pop("redacted", None)
             d.pop("redacted_because", None)
@@ -1476,18 +1480,20 @@ class PersistEventsStore:
             ),
         )

-    def _store_rejected_events_txn(self, txn, events_and_contexts):
+    def _store_rejected_events_txn(
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ) -> List[Tuple[EventBase, EventContext]]:
         """Add rows to the 'rejections' table for received events which were
         rejected

         Args:
-            txn (twisted.enterprise.adbapi.Connection): db connection
-            events_and_contexts (list[(EventBase, EventContext)]): events
-                we are persisting
+            txn: db connection
+            events_and_contexts: events we are persisting

         Returns:
-            list[(EventBase, EventContext)] new list, without the rejected
-                events.
+            new list, without the rejected events.
         """
         # Remove the rejected events from the list now that we've added them
         # to the events table and the events_json table.
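
The docstring edits here follow the same convention used throughout the commit: once the signature carries the annotations, the `(type)` suffixes are dropped from the `Args:` entries so signature and docstring cannot drift apart. A small hypothetical example of the target style:

    from typing import List


    def persist_labels(event_id: str, labels: List[str]) -> int:
        """Store labels for an event.

        Args:
            event_id: The event's ID.
            labels: A list of text labels.

        Returns:
            The number of rows written.
        """
        return len(labels)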
@@ -1508,7 +1514,7 @@ class PersistEventsStore:
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         all_events_and_contexts: List[Tuple[EventBase, EventContext]],
         inhibit_local_membership_updates: bool = False,
-    ):
+    ) -> None:
         """Update all the miscellaneous tables for new events

         Args:

@@ -1602,7 +1608,11 @@ class PersistEventsStore:
         # Prefill the event cache
         self._add_to_cache(txn, events_and_contexts)

-    def _add_to_cache(self, txn, events_and_contexts):
+    def _add_to_cache(
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ) -> None:
         to_prefill = []
         rows = []

@@ -1633,7 +1643,7 @@ class PersistEventsStore:
             if not row["rejects"] and not row["redacts"]:
                 to_prefill.append(EventCacheEntry(event=event, redacted_event=None))

-        def prefill():
+        def prefill() -> None:
             for cache_entry in to_prefill:
                 self.store._get_event_cache.set(
                     (cache_entry.event.event_id,), cache_entry
@@ -1663,19 +1673,24 @@ class PersistEventsStore:
         )

     def insert_labels_for_event_txn(
-        self, txn, event_id, labels, room_id, topological_ordering
-    ):
+        self,
+        txn: LoggingTransaction,
+        event_id: str,
+        labels: List[str],
+        room_id: str,
+        topological_ordering: int,
+    ) -> None:
         """Store the mapping between an event's ID and its labels, with one row per
         (event_id, label) tuple.

         Args:
-            txn (LoggingTransaction): The transaction to execute.
-            event_id (str): The event's ID.
-            labels (list[str]): A list of text labels.
-            room_id (str): The ID of the room the event was sent to.
-            topological_ordering (int): The position of the event in the room's topology.
+            txn: The transaction to execute.
+            event_id: The event's ID.
+            labels: A list of text labels.
+            room_id: The ID of the room the event was sent to.
+            topological_ordering: The position of the event in the room's topology.
         """
-        return self.db_pool.simple_insert_many_txn(
+        self.db_pool.simple_insert_many_txn(
             txn=txn,
             table="event_labels",
             keys=("event_id", "label", "room_id", "topological_ordering"),
@@ -1684,25 +1699,32 @@ class PersistEventsStore:
             ],
         )

-    def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
+    def _insert_event_expiry_txn(
+        self, txn: LoggingTransaction, event_id: str, expiry_ts: int
+    ) -> None:
         """Save the expiry timestamp associated with a given event ID.

         Args:
-            txn (LoggingTransaction): The database transaction to use.
-            event_id (str): The event ID the expiry timestamp is associated with.
-            expiry_ts (int): The timestamp at which to expire (delete) the event.
+            txn: The database transaction to use.
+            event_id: The event ID the expiry timestamp is associated with.
+            expiry_ts: The timestamp at which to expire (delete) the event.
         """
-        return self.db_pool.simple_insert_txn(
+        self.db_pool.simple_insert_txn(
             txn=txn,
             table="event_expiry",
             values={"event_id": event_id, "expiry_ts": expiry_ts},
         )

     def _store_room_members_txn(
-        self, txn, events, *, inhibit_local_membership_updates: bool = False
-    ):
+        self,
+        txn: LoggingTransaction,
+        events: List[EventBase],
+        *,
+        inhibit_local_membership_updates: bool = False,
+    ) -> None:
         """
         Store a room member in the database.

         Args:
             txn: The transaction to use.
             events: List of events to store.
@@ -1742,6 +1764,7 @@ class PersistEventsStore:
         )

         for event in events:
+            assert event.internal_metadata.stream_ordering is not None
             txn.call_after(
                 self.store._membership_stream_cache.entity_has_changed,
                 event.state_key,
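
The new `assert` exists for mypy's benefit: `stream_ordering` is typed `Optional[int]`, and the assert both documents the invariant (events being persisted always have one) and narrows the type so the call below type-checks. A self-contained sketch of that narrowing (both functions hypothetical):

    from typing import Optional


    def entity_has_changed(entity: str, stream_ordering: int) -> None:
        pass  # stand-in for the stream-cache update


    def notify(entity: str, stream_ordering: Optional[int]) -> None:
        # Without the assert, mypy rejects passing Optional[int] where
        # an int is required; after it, the type is narrowed to int.
        assert stream_ordering is not None
        entity_has_changed(entity, stream_ordering)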
@@ -1838,7 +1861,9 @@ class PersistEventsStore:
             (parent_id, event.sender),
         )

-    def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
+    def _handle_insertion_event(
+        self, txn: LoggingTransaction, event: EventBase
+    ) -> None:
         """Handles keeping track of insertion events and edges/connections.
         Part of MSC2716.

@@ -1899,7 +1924,7 @@ class PersistEventsStore:
             },
         )

-    def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase):
+    def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase) -> None:
         """Handles inserting the batch edges/connections between the batch event
         and an insertion event. Part of MSC2716.
@@ -1999,25 +2024,29 @@ class PersistEventsStore:
             txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
         )

-    def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase):
+    def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
         if isinstance(event.content.get("topic"), str):
             self.store_event_search_txn(
                 txn, event, "content.topic", event.content["topic"]
             )

-    def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase):
+    def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
         if isinstance(event.content.get("name"), str):
             self.store_event_search_txn(
                 txn, event, "content.name", event.content["name"]
             )

-    def _store_room_message_txn(self, txn: LoggingTransaction, event: EventBase):
+    def _store_room_message_txn(
+        self, txn: LoggingTransaction, event: EventBase
+    ) -> None:
         if isinstance(event.content.get("body"), str):
             self.store_event_search_txn(
                 txn, event, "content.body", event.content["body"]
             )

-    def _store_retention_policy_for_room_txn(self, txn, event):
+    def _store_retention_policy_for_room_txn(
+        self, txn: LoggingTransaction, event: EventBase
+    ) -> None:
         if not event.is_state():
             logger.debug("Ignoring non-state m.room.retention event")
             return
@@ -2077,8 +2106,11 @@ class PersistEventsStore:
         )

     def _set_push_actions_for_event_and_users_txn(
-        self, txn, events_and_contexts, all_events_and_contexts
-    ):
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+        all_events_and_contexts: List[Tuple[EventBase, EventContext]],
+    ) -> None:
         """Handles moving push actions from staging table to main
         event_push_actions table for all events in `events_and_contexts`.

@@ -2086,12 +2118,10 @@ class PersistEventsStore:
         from the push action staging area.

         Args:
-            events_and_contexts (list[(EventBase, EventContext)]): events
-                we are persisting
-            all_events_and_contexts (list[(EventBase, EventContext)]): all
-                events that we were going to persist. This includes events
-                we've already persisted, etc, that wouldn't appear in
-                events_and_context.
+            events_and_contexts: events we are persisting
+            all_events_and_contexts: all events that we were going to persist.
+                This includes events we've already persisted, etc, that wouldn't
+                appear in events_and_context.
         """

         # Only non outlier events will have push actions associated with them,
@@ -2160,7 +2190,9 @@ class PersistEventsStore:
             ),
         )

-    def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
+    def _remove_push_actions_for_event_id_txn(
+        self, txn: LoggingTransaction, room_id: str, event_id: str
+    ) -> None:
         # Sad that we have to blow away the cache for the whole room here
         txn.call_after(
             self.store.get_unread_event_push_actions_by_room_for_user.invalidate,

@@ -2171,7 +2203,9 @@ class PersistEventsStore:
             (room_id, event_id),
         )

-    def _store_rejections_txn(self, txn, event_id, reason):
+    def _store_rejections_txn(
+        self, txn: LoggingTransaction, event_id: str, reason: str
+    ) -> None:
         self.db_pool.simple_insert_txn(
             txn,
             table="rejections",
@@ -2183,8 +2217,10 @@ class PersistEventsStore:
         )

     def _store_event_state_mappings_txn(
-        self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
-    ):
+        self,
+        txn: LoggingTransaction,
+        events_and_contexts: Collection[Tuple[EventBase, EventContext]],
+    ) -> None:
         state_groups = {}
         for event, context in events_and_contexts:
             if event.internal_metadata.is_outlier():
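
Note the argument type also changes from `Iterable` to `Collection` here, presumably because the method walks the argument more than once; a one-shot generator would satisfy `Iterable` but silently yield nothing on the second pass. A small sketch of the hazard (functions hypothetical):

    from typing import Collection, Iterable


    def count_twice_bad(items: Iterable[str]) -> int:
        # If `items` is a generator, the second loop sees nothing.
        first = sum(1 for _ in items)
        second = sum(1 for _ in items)
        return first + second


    def count_twice_good(items: Collection[str]) -> int:
        # Collection guarantees len() and repeatable iteration.
        return 2 * len(items)


    gen = (s for s in ["a", "b"])
    print(count_twice_bad(gen))          # 2, not 4: the generator ran dry
    print(count_twice_good(["a", "b"]))  # 4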
@@ -2241,7 +2277,9 @@ class PersistEventsStore:
                 state_group_id,
             )

-    def _update_min_depth_for_room_txn(self, txn, room_id, depth):
+    def _update_min_depth_for_room_txn(
+        self, txn: LoggingTransaction, room_id: str, depth: int
+    ) -> None:
         min_depth = self.store._get_min_depth_interaction(txn, room_id)

         if min_depth is not None and depth >= min_depth:

@@ -2254,7 +2292,9 @@ class PersistEventsStore:
             values={"min_depth": depth},
         )
-    def _handle_mult_prev_events(self, txn, events):
+    def _handle_mult_prev_events(
+        self, txn: LoggingTransaction, events: List[EventBase]
+    ) -> None:
         """
         For the given event, update the event edges table and forward and
         backward extremities tables.

@@ -2272,7 +2312,9 @@ class PersistEventsStore:
         self._update_backward_extremeties(txn, events)

-    def _update_backward_extremeties(self, txn, events):
+    def _update_backward_extremeties(
+        self, txn: LoggingTransaction, events: List[EventBase]
+    ) -> None:
         """Updates the event_backward_extremities tables based on the new/updated
         events being persisted.

synapse/storage/databases/main/search.py

@@ -14,7 +14,7 @@
 import logging
 import re
-from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple

 import attr
@@ -27,7 +27,7 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.types import JsonDict

 if TYPE_CHECKING:
@@ -149,7 +149,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
             self.EVENT_SEARCH_DELETE_NON_STRINGS, self._background_delete_non_strings
         )

-    async def _background_reindex_search(self, progress, batch_size):
+    async def _background_reindex_search(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         # we work through the events table from highest stream id to lowest
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
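
The `(progress: JsonDict, batch_size: int) -> int` signature repeated across these handlers reflects the background-update shape: read the resume position from `progress`, process up to `batch_size` rows, persist the new position, and return how much work was done so the updater can pace itself. A toy sketch of that shape (not Synapse's actual updater machinery; names and the completion convention here are illustrative):

    from typing import Any, Dict

    JsonDict = Dict[str, Any]

    TOTAL_ROWS = 100  # pretend table size, for illustration


    async def background_reindex_sketch(progress: JsonDict, batch_size: int) -> int:
        # Resume from wherever the last batch left off.
        done = progress.get("rows_done", 0)
        handled = min(batch_size, TOTAL_ROWS - done)
        # ... reindex rows [done, done + handled) here ...
        progress["rows_done"] = done + handled
        # In this toy, a return of 0 means there is nothing left to do.
        return handled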
@@ -157,7 +159,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
         TYPES = ["m.room.name", "m.room.message", "m.room.topic"]

-        def reindex_search_txn(txn):
+        def reindex_search_txn(txn: LoggingTransaction) -> int:
             sql = (
                 "SELECT stream_ordering, event_id, room_id, type, json, "
                 " origin_server_ts FROM events"
@@ -255,12 +257,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
         return result

-    async def _background_reindex_gin_search(self, progress, batch_size):
+    async def _background_reindex_gin_search(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """This handles old synapses which used GIST indexes, if any;
         converting them back to be GIN as per the actual schema.
         """

-        def create_index(conn):
+        def create_index(conn: LoggingDatabaseConnection) -> None:
             conn.rollback()

             # we have to set autocommit, because postgres refuses to
@@ -299,7 +303,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
             )

         return 1

-    async def _background_reindex_search_order(self, progress, batch_size):
+    async def _background_reindex_search_order(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
@@ -307,7 +313,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
         if not have_added_index:

-            def create_index(conn):
+            def create_index(conn: LoggingDatabaseConnection) -> None:
                 conn.rollback()
                 conn.set_session(autocommit=True)
                 c = conn.cursor()
@@ -336,7 +342,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
                 pg,
             )

-        def reindex_search_txn(txn):
+        def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]:
             sql = (
                 "UPDATE event_search AS es SET stream_ordering = e.stream_ordering,"
                 " origin_server_ts = e.origin_server_ts"
@@ -644,7 +650,8 @@ class SearchStore(SearchBackgroundUpdateStore):
         else:
             raise Exception("Unrecognized database engine")

-        args.append(limit)
+        # mypy expects to append only a `str`, not an `int`
+        args.append(limit)  # type: ignore[arg-type]

         results = await self.db_pool.execute(
             "search_rooms", self.db_pool.cursor_to_dict, sql, *args
@@ -705,7 +712,7 @@ class SearchStore(SearchBackgroundUpdateStore):
             A set of strings.
         """

-        def f(txn):
+        def f(txn: LoggingTransaction) -> Set[str]:
            highlight_words = set()
            for event in events:
                # As a hack we simply join values of all possible keys. This is
@@ -759,11 +766,11 @@ class SearchStore(SearchBackgroundUpdateStore):
         return await self.db_pool.runInteraction("_find_highlights", f)

-def _to_postgres_options(options_dict):
+def _to_postgres_options(options_dict: JsonDict) -> str:
     return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),)
-def _parse_query(database_engine, search_term):
+def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str:
     """Takes a plain unicode string from the user and converts it into a form
     that can be passed to database.
     We use this so that we can add prefix matching, which isn't something