Add some type hints to datastore (#12485)

This commit is contained in:
Dirk Klimpel 2022-04-27 14:05:00 +02:00 committed by GitHub
parent 63ba9ba38b
commit b76f1a4d5f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 188 additions and 84 deletions

View file

@ -14,11 +14,25 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
Iterator,
List,
Optional,
Tuple,
cast,
)
from synapse.push import PusherConfig, ThrottleParams
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
@ -117,7 +131,7 @@ class PusherWorkerStore(SQLBaseStore):
return self._decode_pushers_rows(ret)
async def get_all_pushers(self) -> Iterator[PusherConfig]:
def get_pushers(txn):
def get_pushers(txn: LoggingTransaction) -> Iterator[PusherConfig]:
txn.execute("SELECT * FROM pushers")
rows = self.db_pool.cursor_to_dict(txn)
@ -152,7 +166,9 @@ class PusherWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
def get_all_updated_pushers_rows_txn(txn):
def get_all_updated_pushers_rows_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
sql = """
SELECT id, user_name, app_id, pushkey
FROM pushers
@ -160,10 +176,13 @@ class PusherWorkerStore(SQLBaseStore):
ORDER BY id ASC LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
updates = [
(stream_id, (user_name, app_id, pushkey, False))
for stream_id, user_name, app_id, pushkey in txn
]
updates = cast(
List[Tuple[int, tuple]],
[
(stream_id, (user_name, app_id, pushkey, False))
for stream_id, user_name, app_id, pushkey in txn
],
)
sql = """
SELECT stream_id, user_id, app_id, pushkey
@ -192,12 +211,12 @@ class PusherWorkerStore(SQLBaseStore):
)
@cached(num_args=1, max_entries=15000)
async def get_if_user_has_pusher(self, user_id: str):
async def get_if_user_has_pusher(self, user_id: str) -> None:
# This only exists for the cachedList decorator
raise NotImplementedError()
async def update_pusher_last_stream_ordering(
self, app_id, pushkey, user_id, last_stream_ordering
self, app_id: str, pushkey: str, user_id: str, last_stream_ordering: int
) -> None:
await self.db_pool.simple_update_one(
"pushers",
@ -291,7 +310,7 @@ class PusherWorkerStore(SQLBaseStore):
last_user = progress.get("last_user", "")
def _delete_pushers(txn) -> int:
def _delete_pushers(txn: LoggingTransaction) -> int:
sql = """
SELECT name FROM users
@ -339,7 +358,7 @@ class PusherWorkerStore(SQLBaseStore):
last_pusher = progress.get("last_pusher", 0)
def _delete_pushers(txn) -> int:
def _delete_pushers(txn: LoggingTransaction) -> int:
sql = """
SELECT p.id, access_token FROM pushers AS p
@ -396,7 +415,7 @@ class PusherWorkerStore(SQLBaseStore):
last_pusher = progress.get("last_pusher", 0)
def _delete_pushers(txn) -> int:
def _delete_pushers(txn: LoggingTransaction) -> int:
sql = """
SELECT p.id, p.user_name, p.app_id, p.pushkey
@ -502,7 +521,7 @@ class PusherStore(PusherWorkerStore):
async def delete_pusher_by_app_id_pushkey_user_id(
self, app_id: str, pushkey: str, user_id: str
) -> None:
def delete_pusher_txn(txn, stream_id):
def delete_pusher_txn(txn: LoggingTransaction, stream_id: int) -> None:
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_if_user_has_pusher, (user_id,)
)
@ -547,7 +566,7 @@ class PusherStore(PusherWorkerStore):
# account.
pushers = list(await self.get_pushers_by_user_id(user_id))
def delete_pushers_txn(txn, stream_ids):
def delete_pushers_txn(txn: LoggingTransaction, stream_ids: List[int]) -> None:
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_if_user_has_pusher, (user_id,)
)