diff --git a/changelog.d/9299.misc b/changelog.d/9299.misc new file mode 100644 index 000000000..c883a677e --- /dev/null +++ b/changelog.d/9299.misc @@ -0,0 +1 @@ +Update the `Cursor` type hints to better match PEP 249. diff --git a/synapse/storage/database.py b/synapse/storage/database.py index d2ba4bd2f..ae4bf1a54 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -158,8 +158,8 @@ class LoggingDatabaseConnection: def commit(self) -> None: self.conn.commit() - def rollback(self, *args, **kwargs) -> None: - self.conn.rollback(*args, **kwargs) + def rollback(self) -> None: + self.conn.rollback() def __enter__(self) -> "Connection": self.conn.__enter__() @@ -244,12 +244,15 @@ class LoggingTransaction: assert self.exception_callbacks is not None self.exception_callbacks.append((callback, args, kwargs)) + def fetchone(self) -> Optional[Tuple]: + return self.txn.fetchone() + + def fetchmany(self, size: Optional[int] = None) -> List[Tuple]: + return self.txn.fetchmany(size=size) + def fetchall(self) -> List[Tuple]: return self.txn.fetchall() - def fetchone(self) -> Tuple: - return self.txn.fetchone() - def __iter__(self) -> Iterator[Tuple]: return self.txn.__iter__() @@ -754,6 +757,7 @@ class DatabasePool: Returns: A list of dicts where the key is the column header. """ + assert cursor.description is not None, "cursor.description was None" col_headers = [intern(str(column[0])) for column in cursor.description] results = [dict(zip(col_headers, row)) for row in cursor] return results diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 566ea19ba..28bb2eb66 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -619,9 +619,9 @@ def _get_or_create_schema_state( txn.execute("SELECT version, upgraded FROM schema_version") row = txn.fetchone() - current_version = int(row[0]) if row else None - if current_version: + if row is not None: + current_version = int(row[0]) txn.execute( "SELECT file FROM applied_schema_deltas WHERE version >= ?", (current_version,), diff --git a/synapse/storage/types.py b/synapse/storage/types.py index 9cadcba18..17291c9d5 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Iterable, Iterator, List, Optional, Tuple +from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union from typing_extensions import Protocol @@ -20,23 +20,44 @@ from typing_extensions import Protocol Some very basic protocol definitions for the DB-API2 classes specified in PEP-249 """ +_Parameters = Union[Sequence[Any], Mapping[str, Any]] + class Cursor(Protocol): - def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any: + def execute(self, sql: str, parameters: _Parameters = ...) -> Any: ... - def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any: + def executemany(self, sql: str, parameters: Sequence[_Parameters]) -> Any: + ... + + def fetchone(self) -> Optional[Tuple]: + ... + + def fetchmany(self, size: Optional[int] = ...) -> List[Tuple]: ... def fetchall(self) -> List[Tuple]: ... - def fetchone(self) -> Tuple: - ... - @property - def description(self) -> Any: - return None + def description( + self, + ) -> Optional[ + Sequence[ + # Note that this is an approximate typing based on sqlite3 and other + # drivers, and may not be entirely accurate. + Tuple[ + str, + Optional[Any], + Optional[int], + Optional[int], + Optional[int], + Optional[int], + Optional[int], + ] + ] + ]: + ... @property def rowcount(self) -> int: @@ -59,7 +80,7 @@ class Connection(Protocol): def commit(self) -> None: ... - def rollback(self, *args, **kwargs) -> None: + def rollback(self) -> None: ... def __enter__(self) -> "Connection": diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index 0ec4dc291..e2b316a21 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -106,7 +106,9 @@ class PostgresSequenceGenerator(SequenceGenerator): def get_next_id_txn(self, txn: Cursor) -> int: txn.execute("SELECT nextval(?)", (self._sequence_name,)) - return txn.fetchone()[0] + fetch_res = txn.fetchone() + assert fetch_res is not None + return fetch_res[0] def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]: txn.execute( @@ -147,7 +149,9 @@ class PostgresSequenceGenerator(SequenceGenerator): txn.execute( "SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name} ) - last_value, is_called = txn.fetchone() + fetch_res = txn.fetchone() + assert fetch_res is not None + last_value, is_called = fetch_res # If we have an associated stream check the stream_positions table. max_in_stream_positions = None