Add some type hints to datastore. (#12477)

This commit is contained in:
Dirk Klimpel 2022-05-10 20:07:48 +02:00 committed by GitHub
parent 147f098fb4
commit 989fa33096
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 122 additions and 71 deletions

View file

@ -14,7 +14,7 @@
import logging
import re
from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set
from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple
import attr
@ -27,7 +27,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict
if TYPE_CHECKING:
@ -149,7 +149,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
self.EVENT_SEARCH_DELETE_NON_STRINGS, self._background_delete_non_strings
)
async def _background_reindex_search(self, progress, batch_size):
async def _background_reindex_search(
self, progress: JsonDict, batch_size: int
) -> int:
# we work through the events table from highest stream id to lowest
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
@ -157,7 +159,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
def reindex_search_txn(txn):
def reindex_search_txn(txn: LoggingTransaction) -> int:
sql = (
"SELECT stream_ordering, event_id, room_id, type, json, "
" origin_server_ts FROM events"
@ -255,12 +257,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
return result
async def _background_reindex_gin_search(self, progress, batch_size):
async def _background_reindex_gin_search(
self, progress: JsonDict, batch_size: int
) -> int:
"""This handles old synapses which used GIST indexes, if any;
converting them back to be GIN as per the actual schema.
"""
def create_index(conn):
def create_index(conn: LoggingDatabaseConnection) -> None:
conn.rollback()
# we have to set autocommit, because postgres refuses to
@ -299,7 +303,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
)
return 1
async def _background_reindex_search_order(self, progress, batch_size):
async def _background_reindex_search_order(
self, progress: JsonDict, batch_size: int
) -> int:
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
@ -307,7 +313,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
if not have_added_index:
def create_index(conn):
def create_index(conn: LoggingDatabaseConnection) -> None:
conn.rollback()
conn.set_session(autocommit=True)
c = conn.cursor()
@ -336,7 +342,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
pg,
)
def reindex_search_txn(txn):
def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]:
sql = (
"UPDATE event_search AS es SET stream_ordering = e.stream_ordering,"
" origin_server_ts = e.origin_server_ts"
@ -644,7 +650,8 @@ class SearchStore(SearchBackgroundUpdateStore):
else:
raise Exception("Unrecognized database engine")
args.append(limit)
# mypy expects to append only a `str`, not an `int`
args.append(limit) # type: ignore[arg-type]
results = await self.db_pool.execute(
"search_rooms", self.db_pool.cursor_to_dict, sql, *args
@ -705,7 +712,7 @@ class SearchStore(SearchBackgroundUpdateStore):
A set of strings.
"""
def f(txn):
def f(txn: LoggingTransaction) -> Set[str]:
highlight_words = set()
for event in events:
# As a hack we simply join values of all possible keys. This is
@ -759,11 +766,11 @@ class SearchStore(SearchBackgroundUpdateStore):
return await self.db_pool.runInteraction("_find_highlights", f)
def _to_postgres_options(options_dict):
def _to_postgres_options(options_dict: JsonDict) -> str:
return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),)
def _parse_query(database_engine, search_term):
def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str:
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to database.
We use this so that we can add prefix matching, which isn't something