More precise type for LoggingTransaction.execute (#15432)

* More precise type for LoggingTransaction.execute
* Add an annotation for stream_ordering_month_ago

This would have spotted the error that was fixed in "Add comma missing from #15382. (#15429)"
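
For illustration only (the helper below is hypothetical, not the call site fixed in #15429), this is the kind of slip the tighter signature lets mypy catch: dropping the trailing comma of a one-element tuple leaves a bare value, which the old *args: Any signature accepted without complaint.

    from synapse.storage.database import LoggingTransaction

    def hypothetical_cleanup(txn: LoggingTransaction, cutoff: int) -> None:
        # Correct: a single sequence of bind values.
        txn.execute(
            "DELETE FROM stream_ordering_to_exterm WHERE stream_ordering < ?",
            (cutoff,),
        )

        # Missing comma: (cutoff) is just an int, not a 1-tuple. With
        # `parameters: SQLQueryParameters = ()` mypy reports an incompatible
        # argument type here, since an int is neither a Sequence nor a Mapping.
        txn.execute(
            "DELETE FROM stream_ordering_to_exterm WHERE stream_ordering < ?",
            (cutoff),
        )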
David Robertson 2023-04-14 19:04:49 +01:00 committed by GitHub
parent 24b61f32ff
commit 8a47d6e3a6
4 changed files with 32 additions and 14 deletions

changelog.d/15432.misc (new file)

@@ -0,0 +1 @@
+Improve type hints.


@@ -58,7 +58,7 @@ from synapse.metrics import register_threadpool
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
-from synapse.storage.types import Connection, Cursor
+from synapse.storage.types import Connection, Cursor, SQLQueryParameters
from synapse.util.async_helpers import delay_cancellation
from synapse.util.iterutils import batch_iter
@@ -371,10 +371,18 @@ class LoggingTransaction:
if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch
+# TODO: is it safe for values to be Iterable[Iterable[Any]] here?
+# https://www.psycopg.org/docs/extras.html?highlight=execute_batch#psycopg2.extras.execute_batch
+# suggests each arg in args should be a sequence or mapping
self._do_execute(
lambda the_sql: execute_batch(self.txn, the_sql, args), sql
)
else:
+# TODO: is it safe for values to be Iterable[Iterable[Any]] here?
+# https://docs.python.org/3/library/sqlite3.html?highlight=sqlite3#sqlite3.Cursor.executemany
+# suggests that the outer collection may be iterable, but
+# https://docs.python.org/3/library/sqlite3.html?highlight=sqlite3#how-to-use-placeholders-to-bind-values-in-sql-queries
+# suggests that the inner collection should be a sequence or dict.
self.executemany(sql, args)
def execute_values(
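
A minimal sketch of the parameter shape the TODOs above describe: the outer container only needs to be iterable, while each element is itself a sequence (or mapping) of bind values. The helper and table below are made up for illustration.

    from typing import Iterable, Tuple

    from synapse.storage.database import LoggingTransaction

    def insert_counts_sketch(
        txn: LoggingTransaction, rows: Iterable[Tuple[str, int]]
    ) -> None:
        # Each element of `rows` is a sequence of bind values, matching the
        # psycopg2 and sqlite3 documentation linked in the TODOs above.
        txn.execute_batch(
            "INSERT INTO user_counts (user_id, count) VALUES (?, ?)",
            rows,
        )

    # e.g. insert_counts_sketch(txn, [("@alice:example.org", 1), ("@bob:example.org", 2)])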
@@ -390,14 +398,20 @@ class LoggingTransaction:
from psycopg2.extras import execute_values
return self._do_execute(
+# TODO: is it safe for values to be Iterable[Iterable[Any]] here?
+# https://www.psycopg.org/docs/extras.html?highlight=execute_batch#psycopg2.extras.execute_values says values should be Sequence[Sequence]
lambda the_sql: execute_values(self.txn, the_sql, values, fetch=fetch),
sql,
)
-def execute(self, sql: str, *args: Any) -> None:
-self._do_execute(self.txn.execute, sql, *args)
+def execute(self, sql: str, parameters: SQLQueryParameters = ()) -> None:
+self._do_execute(self.txn.execute, sql, parameters)
def executemany(self, sql: str, *args: Any) -> None:
+# TODO: we should add a type for *args here. Looking at Cursor.executemany
+# and DBAPI2 it ought to be Sequence[_Parameter], but we pass in
+# Iterable[Iterable[Any]] in execute_batch and execute_values above, which mypy
+# complains about.
self._do_execute(self.txn.executemany, sql, *args)
def executescript(self, sql: str) -> None:

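A trimmed-down sketch of the execute change itself, using an illustrative stand-in rather than the real LoggingTransaction (which also handles logging and metrics): one precisely typed parameters argument, defaulting to an empty sequence, instead of *args: Any.

    from typing import Any, Mapping, Sequence, Union

    SQLQueryParameters = Union[Sequence[Any], Mapping[str, Any]]

    class LoggingTransactionSketch:
        """Illustrative stand-in for LoggingTransaction."""

        def __init__(self, txn: Any) -> None:
            self.txn = txn  # the underlying DB-API2 cursor

        # Before: def execute(self, sql: str, *args: Any) -> None
        # After: a single, typed parameters argument forwarded unchanged.
        def execute(self, sql: str, parameters: SQLQueryParameters = ()) -> None:
            self.txn.execute(sql, parameters)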

@@ -114,6 +114,10 @@ class _NoChainCoverIndex(Exception):
class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBaseStore):
+# TODO: this attribute comes from EventPushActionWorkerStore. Should we inherit from
+# that store so that mypy can deduce this for itself?
+stream_ordering_month_ago: Optional[int]
def __init__(
self,
database: DatabasePool,
@@ -1182,8 +1186,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBaseStore):
Throws a StoreError if we have since purged the index for
stream_orderings from that point.
"""
-if stream_ordering <= self.stream_ordering_month_ago: # type: ignore[attr-defined]
+assert self.stream_ordering_month_ago is not None
+if stream_ordering <= self.stream_ordering_month_ago:
raise StoreError(400, f"stream_ordering too old {stream_ordering}")
sql = """
@@ -1231,7 +1235,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBaseStore):
# provided the last_change is recent enough, we now clamp the requested
# stream_ordering to it.
-if last_change > self.stream_ordering_month_ago: # type: ignore[attr-defined]
+assert self.stream_ordering_month_ago is not None
+if last_change > self.stream_ordering_month_ago:
stream_ordering = min(last_change, stream_ordering)
return await self._get_forward_extremeties_for_room(room_id, stream_ordering)
@@ -1246,8 +1251,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBaseStore):
Throws a StoreError if we have since purged the index for
stream_orderings from that point.
"""
-if stream_ordering <= self.stream_ordering_month_ago: # type: ignore[attr-defined]
+assert self.stream_ordering_month_ago is not None
+if stream_ordering <= self.stream_ordering_month_ago:
raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
sql = """
@@ -1707,9 +1712,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBaseStore):
DELETE FROM stream_ordering_to_exterm
WHERE stream_ordering < ?
"""
-txn.execute(
-sql, (self.stream_ordering_month_ago,) # type: ignore[attr-defined]
-)
+txn.execute(sql, (self.stream_ordering_month_ago,))
await self.db_pool.runInteraction(
"_delete_old_forward_extrem_cache",

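A self-contained sketch of the pattern used above (the class and method names are illustrative): the class-level annotation tells mypy about an attribute that another store class assigns at runtime, removing the need for `# type: ignore[attr-defined]`, and the assert narrows the Optional so the comparison type-checks.

    from typing import Optional

    class WorkerStoreSketch:
        # Annotated but not assigned here: in Synapse the value is populated by
        # another store in the class hierarchy, so without this annotation mypy
        # would not know the attribute exists.
        stream_ordering_month_ago: Optional[int]

        def check_not_purged(self, stream_ordering: int) -> None:
            # Narrow Optional[int] to int before comparing.
            assert self.stream_ordering_month_ago is not None
            if stream_ordering <= self.stream_ordering_month_ago:
                raise ValueError(f"stream_ordering too old {stream_ordering}")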

@@ -31,14 +31,14 @@ from typing_extensions import Protocol
Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
"""
-_Parameters = Union[Sequence[Any], Mapping[str, Any]]
+SQLQueryParameters = Union[Sequence[Any], Mapping[str, Any]]
class Cursor(Protocol):
-def execute(self, sql: str, parameters: _Parameters = ...) -> Any:
+def execute(self, sql: str, parameters: SQLQueryParameters = ...) -> Any:
...
-def executemany(self, sql: str, parameters: Sequence[_Parameters]) -> Any:
+def executemany(self, sql: str, parameters: Sequence[SQLQueryParameters]) -> Any:
...
def fetchone(self) -> Optional[Tuple]:
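
For illustration, the renamed SQLQueryParameters alias covers both DB-API2 parameter styles, a positional sequence and a named mapping; the sqlite3 table below is made up.

    import sqlite3
    from typing import Any, Mapping, Sequence, Union

    # Same shape as the alias now exported from synapse.storage.types.
    SQLQueryParameters = Union[Sequence[Any], Mapping[str, Any]]

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()
    cur.execute("CREATE TABLE users (user_id TEXT, posts INTEGER)")

    positional: SQLQueryParameters = ("@alice:example.org", 10)
    cur.execute("INSERT INTO users VALUES (?, ?)", positional)

    named: SQLQueryParameters = {"user_id": "@alice:example.org"}
    cur.execute("SELECT posts FROM users WHERE user_id = :user_id", named)
    print(cur.fetchone())  # (10,)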