Reduce the number of "untyped defs" (#12716)

This commit is contained in:
David Robertson 2022-05-12 15:33:50 +01:00 committed by GitHub
parent de1e599b9d
commit 17e1eb7749
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 142 additions and 69 deletions

View file

@ -31,6 +31,7 @@ from typing import (
List,
Optional,
Tuple,
Type,
TypeVar,
cast,
overload,
@ -41,6 +42,7 @@ from prometheus_client import Histogram
from typing_extensions import Concatenate, Literal, ParamSpec
from twisted.enterprise import adbapi
from twisted.internet.interfaces import IReactorCore
from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
@ -92,7 +94,9 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
def make_pool(
reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
reactor: IReactorCore,
db_config: DatabaseConnectionConfig,
engine: BaseDatabaseEngine,
) -> adbapi.ConnectionPool:
"""Get the connection pool for the database."""
@ -101,7 +105,7 @@ def make_pool(
db_args = dict(db_config.config.get("args", {}))
db_args.setdefault("cp_reconnect", True)
def _on_new_connection(conn):
def _on_new_connection(conn: Connection) -> None:
# Ensure we have a logging context so we can correctly track queries,
# etc.
with LoggingContext("db.on_new_connection"):
@ -157,7 +161,11 @@ class LoggingDatabaseConnection:
default_txn_name: str
def cursor(
self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
self,
*,
txn_name: Optional[str] = None,
after_callbacks: Optional[List["_CallbackListEntry"]] = None,
exception_callbacks: Optional[List["_CallbackListEntry"]] = None,
) -> "LoggingTransaction":
if not txn_name:
txn_name = self.default_txn_name
@ -183,11 +191,16 @@ class LoggingDatabaseConnection:
self.conn.__enter__()
return self
def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[types.TracebackType],
) -> Optional[bool]:
return self.conn.__exit__(exc_type, exc_value, traceback)
# Proxy through any unknown lookups to the DB conn class.
def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
return getattr(self.conn, name)
@ -391,17 +404,22 @@ class LoggingTransaction:
def __enter__(self) -> "LoggingTransaction":
return self
def __exit__(self, exc_type, exc_value, traceback):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[types.TracebackType],
) -> None:
self.close()
class PerformanceCounters:
def __init__(self):
self.current_counters = {}
self.previous_counters = {}
def __init__(self) -> None:
self.current_counters: Dict[str, Tuple[int, float]] = {}
self.previous_counters: Dict[str, Tuple[int, float]] = {}
def update(self, key: str, duration_secs: float) -> None:
count, cum_time = self.current_counters.get(key, (0, 0))
count, cum_time = self.current_counters.get(key, (0, 0.0))
count += 1
cum_time += duration_secs
self.current_counters[key] = (count, cum_time)
@ -527,7 +545,7 @@ class DatabasePool:
def start_profiling(self) -> None:
self._previous_loop_ts = monotonic_time()
def loop():
def loop() -> None:
curr = self._current_txn_total_time
prev = self._previous_txn_total_time
self._previous_txn_total_time = curr
@ -1186,7 +1204,7 @@ class DatabasePool:
if lock:
self.engine.lock_table(txn, table)
def _getwhere(key):
def _getwhere(key: str) -> str:
# If the value we're passing in is None (aka NULL), we need to use
# IS, not =, as NULL = NULL equals NULL (False).
if keyvalues[key] is None:
@ -2258,7 +2276,7 @@ class DatabasePool:
term: Optional[str],
col: str,
retcols: Collection[str],
desc="simple_search_list",
desc: str = "simple_search_list",
) -> Optional[List[Dict[str, Any]]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.

View file

@ -23,6 +23,7 @@ from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
from synapse.storage.types import Cursor
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -71,7 +72,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
self._last_user_visit_update = self._get_start_of_day()
@wrap_as_background_process("read_forward_extremities")
async def _read_forward_extremities(self):
async def _read_forward_extremities(self) -> None:
def fetch(txn):
txn.execute(
"""
@ -95,7 +96,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
(x[0] - 1) * x[1] for x in res if x[1]
)
async def count_daily_e2ee_messages(self):
async def count_daily_e2ee_messages(self) -> int:
"""
Returns an estimate of the number of messages sent in the last day.
@ -115,7 +116,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
async def count_daily_sent_e2ee_messages(self):
async def count_daily_sent_e2ee_messages(self) -> int:
def _count_messages(txn):
# This is good enough as if you have silly characters in your own
# hostname then that's your own fault.
@ -136,7 +137,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_daily_sent_e2ee_messages", _count_messages
)
async def count_daily_active_e2ee_rooms(self):
async def count_daily_active_e2ee_rooms(self) -> int:
def _count(txn):
sql = """
SELECT COUNT(DISTINCT room_id) FROM events
@ -151,7 +152,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_daily_active_e2ee_rooms", _count
)
async def count_daily_messages(self):
async def count_daily_messages(self) -> int:
"""
Returns an estimate of the number of messages sent in the last day.
@ -171,7 +172,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
return await self.db_pool.runInteraction("count_messages", _count_messages)
async def count_daily_sent_messages(self):
async def count_daily_sent_messages(self) -> int:
def _count_messages(txn):
# This is good enough as if you have silly characters in your own
# hostname then that's your own fault.
@ -192,7 +193,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_daily_sent_messages", _count_messages
)
async def count_daily_active_rooms(self):
async def count_daily_active_rooms(self) -> int:
def _count(txn):
sql = """
SELECT COUNT(DISTINCT room_id) FROM events
@ -226,7 +227,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_monthly_users", self._count_users, thirty_days_ago
)
def _count_users(self, txn, time_from):
def _count_users(self, txn: Cursor, time_from: int) -> int:
"""
Returns number of users seen in the past time_from period
"""
@ -238,7 +239,10 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
) u
"""
txn.execute(sql, (time_from,))
(count,) = txn.fetchone()
# Mypy knows that fetchone() might return None if there are no rows.
# We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
# returns exactly one row.
(count,) = txn.fetchone() # type: ignore[misc]
return count
async def count_r30_users(self) -> Dict[str, int]:
@ -453,7 +457,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_r30v2_users", _count_r30v2_users
)
def _get_start_of_day(self):
def _get_start_of_day(self) -> int:
"""
Returns millisecond unixtime for start of UTC day.
"""

View file

@ -798,9 +798,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self,
txn: LoggingTransaction,
event_id: str,
allow_none=False,
) -> int:
return self.db_pool.simple_select_one_onecol_txn(
allow_none: bool = False,
) -> Optional[int]:
# Type ignore: we pass keyvalues a Dict[str, str]; the function wants
# Dict[str, Any]. I think mypy is unhappy because Dict is invariant?
return self.db_pool.simple_select_one_onecol_txn( # type: ignore[call-overload]
txn=txn,
table="events",
keyvalues={"event_id": event_id},

View file

@ -25,6 +25,7 @@ from typing import (
Collection,
Deque,
Dict,
Generator,
Generic,
Iterable,
List,
@ -207,7 +208,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
return res
def _handle_queue(self, room_id):
def _handle_queue(self, room_id: str) -> None:
"""Attempts to handle the queue for a room if not already being handled.
The queue's callback will be invoked with for each item in the queue,
@ -227,7 +228,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
self._currently_persisting_rooms.add(room_id)
async def handle_queue_loop():
async def handle_queue_loop() -> None:
try:
queue = self._get_drainining_queue(room_id)
for item in queue:
@ -250,15 +251,17 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
with PreserveLoggingContext():
item.deferred.callback(ret)
finally:
queue = self._event_persist_queues.pop(room_id, None)
if queue:
self._event_persist_queues[room_id] = queue
remaining_queue = self._event_persist_queues.pop(room_id, None)
if remaining_queue:
self._event_persist_queues[room_id] = remaining_queue
self._currently_persisting_rooms.discard(room_id)
# set handle_queue_loop off in the background
run_as_background_process("persist_events", handle_queue_loop)
def _get_drainining_queue(self, room_id):
def _get_drainining_queue(
self, room_id: str
) -> Generator[_EventPersistQueueItem, None, None]:
queue = self._event_persist_queues.setdefault(room_id, deque())
try:
@ -317,7 +320,9 @@ class EventsPersistenceStorage:
for event, ctx in events_and_contexts:
partitioned.setdefault(event.room_id, []).append((event, ctx))
async def enqueue(item):
async def enqueue(
item: Tuple[str, List[Tuple[EventBase, EventContext]]]
) -> Dict[str, str]:
room_id, evs_ctxs = item
return await self._event_persist_queue.add_to_queue(
room_id, evs_ctxs, backfilled=backfilled
@ -1102,7 +1107,7 @@ class EventsPersistenceStorage:
return False
async def _handle_potentially_left_users(self, user_ids: Set[str]):
async def _handle_potentially_left_users(self, user_ids: Set[str]) -> None:
"""Given a set of remote users check if the server still shares a room with
them. If not then mark those users' device cache as stale.
"""

View file

@ -85,7 +85,7 @@ def prepare_database(
database_engine: BaseDatabaseEngine,
config: Optional[HomeServerConfig],
databases: Collection[str] = ("main", "state"),
):
) -> None:
"""Prepares a physical database for usage. Will either create all necessary tables
or upgrade from an older schema version.

View file

@ -62,7 +62,7 @@ class StateFilter:
types: "frozendict[str, Optional[FrozenSet[str]]]"
include_others: bool = False
def __attrs_post_init__(self):
def __attrs_post_init__(self) -> None:
# If `include_others` is set we canonicalise the filter by removing
# wildcards from the types dictionary
if self.include_others:
@ -138,7 +138,9 @@ class StateFilter:
)
@staticmethod
def freeze(types: Mapping[str, Optional[Collection[str]]], include_others: bool):
def freeze(
types: Mapping[str, Optional[Collection[str]]], include_others: bool
) -> "StateFilter":
"""
Returns a (frozen) StateFilter with the same contents as the parameters
specified here, which can be made of mutable types.

View file

@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
from types import TracebackType
from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
from typing_extensions import Protocol
@ -86,5 +87,10 @@ class Connection(Protocol):
def __enter__(self) -> "Connection":
...
def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> Optional[bool]:
...