mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-08-12 03:40:02 -04:00
Reduce the number of "untyped defs" (#12716)
This commit is contained in:
parent
de1e599b9d
commit
17e1eb7749
16 changed files with 142 additions and 69 deletions
|
@ -31,6 +31,7 @@ from typing import (
|
|||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
cast,
|
||||
overload,
|
||||
|
@ -41,6 +42,7 @@ from prometheus_client import Histogram
|
|||
from typing_extensions import Concatenate, Literal, ParamSpec
|
||||
|
||||
from twisted.enterprise import adbapi
|
||||
from twisted.internet.interfaces import IReactorCore
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.config.database import DatabaseConnectionConfig
|
||||
|
@ -92,7 +94,9 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
|
|||
|
||||
|
||||
def make_pool(
|
||||
reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
|
||||
reactor: IReactorCore,
|
||||
db_config: DatabaseConnectionConfig,
|
||||
engine: BaseDatabaseEngine,
|
||||
) -> adbapi.ConnectionPool:
|
||||
"""Get the connection pool for the database."""
|
||||
|
||||
|
@ -101,7 +105,7 @@ def make_pool(
|
|||
db_args = dict(db_config.config.get("args", {}))
|
||||
db_args.setdefault("cp_reconnect", True)
|
||||
|
||||
def _on_new_connection(conn):
|
||||
def _on_new_connection(conn: Connection) -> None:
|
||||
# Ensure we have a logging context so we can correctly track queries,
|
||||
# etc.
|
||||
with LoggingContext("db.on_new_connection"):
|
||||
|
@ -157,7 +161,11 @@ class LoggingDatabaseConnection:
|
|||
default_txn_name: str
|
||||
|
||||
def cursor(
|
||||
self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
|
||||
self,
|
||||
*,
|
||||
txn_name: Optional[str] = None,
|
||||
after_callbacks: Optional[List["_CallbackListEntry"]] = None,
|
||||
exception_callbacks: Optional[List["_CallbackListEntry"]] = None,
|
||||
) -> "LoggingTransaction":
|
||||
if not txn_name:
|
||||
txn_name = self.default_txn_name
|
||||
|
@ -183,11 +191,16 @@ class LoggingDatabaseConnection:
|
|||
self.conn.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_value: Optional[BaseException],
|
||||
traceback: Optional[types.TracebackType],
|
||||
) -> Optional[bool]:
|
||||
return self.conn.__exit__(exc_type, exc_value, traceback)
|
||||
|
||||
# Proxy through any unknown lookups to the DB conn class.
|
||||
def __getattr__(self, name):
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self.conn, name)
|
||||
|
||||
|
||||
|
@ -391,17 +404,22 @@ class LoggingTransaction:
|
|||
def __enter__(self) -> "LoggingTransaction":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_value: Optional[BaseException],
|
||||
traceback: Optional[types.TracebackType],
|
||||
) -> None:
|
||||
self.close()
|
||||
|
||||
|
||||
class PerformanceCounters:
|
||||
def __init__(self):
|
||||
self.current_counters = {}
|
||||
self.previous_counters = {}
|
||||
def __init__(self) -> None:
|
||||
self.current_counters: Dict[str, Tuple[int, float]] = {}
|
||||
self.previous_counters: Dict[str, Tuple[int, float]] = {}
|
||||
|
||||
def update(self, key: str, duration_secs: float) -> None:
|
||||
count, cum_time = self.current_counters.get(key, (0, 0))
|
||||
count, cum_time = self.current_counters.get(key, (0, 0.0))
|
||||
count += 1
|
||||
cum_time += duration_secs
|
||||
self.current_counters[key] = (count, cum_time)
|
||||
|
@ -527,7 +545,7 @@ class DatabasePool:
|
|||
def start_profiling(self) -> None:
|
||||
self._previous_loop_ts = monotonic_time()
|
||||
|
||||
def loop():
|
||||
def loop() -> None:
|
||||
curr = self._current_txn_total_time
|
||||
prev = self._previous_txn_total_time
|
||||
self._previous_txn_total_time = curr
|
||||
|
@ -1186,7 +1204,7 @@ class DatabasePool:
|
|||
if lock:
|
||||
self.engine.lock_table(txn, table)
|
||||
|
||||
def _getwhere(key):
|
||||
def _getwhere(key: str) -> str:
|
||||
# If the value we're passing in is None (aka NULL), we need to use
|
||||
# IS, not =, as NULL = NULL equals NULL (False).
|
||||
if keyvalues[key] is None:
|
||||
|
@ -2258,7 +2276,7 @@ class DatabasePool:
|
|||
term: Optional[str],
|
||||
col: str,
|
||||
retcols: Collection[str],
|
||||
desc="simple_search_list",
|
||||
desc: str = "simple_search_list",
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
more rows, returning the result as a list of dicts.
|
||||
|
|
|
@ -23,6 +23,7 @@ from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
|||
from synapse.storage.databases.main.event_push_actions import (
|
||||
EventPushActionsWorkerStore,
|
||||
)
|
||||
from synapse.storage.types import Cursor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
@ -71,7 +72,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
|||
self._last_user_visit_update = self._get_start_of_day()
|
||||
|
||||
@wrap_as_background_process("read_forward_extremities")
|
||||
async def _read_forward_extremities(self):
|
||||
async def _read_forward_extremities(self) -> None:
|
||||
def fetch(txn):
|
||||
txn.execute(
|
||||
"""
|
||||
|
@ -95,7 +96,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
|||
(x[0] - 1) * x[1] for x in res if x[1]
|
||||
)
|
||||
|
||||
async def count_daily_e2ee_messages(self):
|
||||
async def count_daily_e2ee_messages(self) -> int:
|
||||
"""
|
||||
Returns an estimate of the number of messages sent in the last day.
|
||||
|
||||
|
@ -115,7 +116,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
|||
|
||||
return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
|
||||
|
||||
async def count_daily_sent_e2ee_messages(self):
|
||||
async def count_daily_sent_e2ee_messages(self) -> int:
|
||||
def _count_messages(txn):
|
||||
# This is good enough as if you have silly characters in your own
|
||||
# hostname then that's your own fault.
|
||||
|
@ -136,7 +137,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
|||
"count_daily_sent_e2ee_messages", _count_messages
|
||||
)
|
||||
|
||||
async def count_daily_active_e2ee_rooms(self):
|
||||
async def count_daily_active_e2ee_rooms(self) -> int:
|
||||
def _count(txn):
|
||||
sql = """
|
||||
SELECT COUNT(DISTINCT room_id) FROM events
|
||||
|
@ -151,7 +152,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
|||
"count_daily_active_e2ee_rooms", _count
|
||||
)
|
||||
|
||||
async def count_daily_messages(self):
|
||||
async def count_daily_messages(self) -> int:
|
||||
"""
|
||||
Returns an estimate of the number of messages sent in the last day.
|
||||
|
||||
|
@ -171,7 +172,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
|||
|
||||
return await self.db_pool.runInteraction("count_messages", _count_messages)
|
||||
|
||||
async def count_daily_sent_messages(self):
|
||||
async def count_daily_sent_messages(self) -> int:
|
||||
def _count_messages(txn):
|
||||
# This is good enough as if you have silly characters in your own
|
||||
# hostname then that's your own fault.
|
||||
|
@ -192,7 +193,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
|||
"count_daily_sent_messages", _count_messages
|
||||
)
|
||||
|
||||
async def count_daily_active_rooms(self):
|
||||
async def count_daily_active_rooms(self) -> int:
|
||||
def _count(txn):
|
||||
sql = """
|
||||
SELECT COUNT(DISTINCT room_id) FROM events
|
||||
|
@ -226,7 +227,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
|||
"count_monthly_users", self._count_users, thirty_days_ago
|
||||
)
|
||||
|
||||
def _count_users(self, txn, time_from):
|
||||
def _count_users(self, txn: Cursor, time_from: int) -> int:
|
||||
"""
|
||||
Returns number of users seen in the past time_from period
|
||||
"""
|
||||
|
@ -238,7 +239,10 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
|||
) u
|
||||
"""
|
||||
txn.execute(sql, (time_from,))
|
||||
(count,) = txn.fetchone()
|
||||
# Mypy knows that fetchone() might return None if there are no rows.
|
||||
# We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
|
||||
# returns exactly one row.
|
||||
(count,) = txn.fetchone() # type: ignore[misc]
|
||||
return count
|
||||
|
||||
async def count_r30_users(self) -> Dict[str, int]:
|
||||
|
@ -453,7 +457,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
|||
"count_r30v2_users", _count_r30v2_users
|
||||
)
|
||||
|
||||
def _get_start_of_day(self):
|
||||
def _get_start_of_day(self) -> int:
|
||||
"""
|
||||
Returns millisecond unixtime for start of UTC day.
|
||||
"""
|
||||
|
|
|
@ -798,9 +798,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
self,
|
||||
txn: LoggingTransaction,
|
||||
event_id: str,
|
||||
allow_none=False,
|
||||
) -> int:
|
||||
return self.db_pool.simple_select_one_onecol_txn(
|
||||
allow_none: bool = False,
|
||||
) -> Optional[int]:
|
||||
# Type ignore: we pass keyvalues a Dict[str, str]; the function wants
|
||||
# Dict[str, Any]. I think mypy is unhappy because Dict is invariant?
|
||||
return self.db_pool.simple_select_one_onecol_txn( # type: ignore[call-overload]
|
||||
txn=txn,
|
||||
table="events",
|
||||
keyvalues={"event_id": event_id},
|
||||
|
|
|
@ -25,6 +25,7 @@ from typing import (
|
|||
Collection,
|
||||
Deque,
|
||||
Dict,
|
||||
Generator,
|
||||
Generic,
|
||||
Iterable,
|
||||
List,
|
||||
|
@ -207,7 +208,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
|
|||
|
||||
return res
|
||||
|
||||
def _handle_queue(self, room_id):
|
||||
def _handle_queue(self, room_id: str) -> None:
|
||||
"""Attempts to handle the queue for a room if not already being handled.
|
||||
|
||||
The queue's callback will be invoked with for each item in the queue,
|
||||
|
@ -227,7 +228,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
|
|||
|
||||
self._currently_persisting_rooms.add(room_id)
|
||||
|
||||
async def handle_queue_loop():
|
||||
async def handle_queue_loop() -> None:
|
||||
try:
|
||||
queue = self._get_drainining_queue(room_id)
|
||||
for item in queue:
|
||||
|
@ -250,15 +251,17 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
|
|||
with PreserveLoggingContext():
|
||||
item.deferred.callback(ret)
|
||||
finally:
|
||||
queue = self._event_persist_queues.pop(room_id, None)
|
||||
if queue:
|
||||
self._event_persist_queues[room_id] = queue
|
||||
remaining_queue = self._event_persist_queues.pop(room_id, None)
|
||||
if remaining_queue:
|
||||
self._event_persist_queues[room_id] = remaining_queue
|
||||
self._currently_persisting_rooms.discard(room_id)
|
||||
|
||||
# set handle_queue_loop off in the background
|
||||
run_as_background_process("persist_events", handle_queue_loop)
|
||||
|
||||
def _get_drainining_queue(self, room_id):
|
||||
def _get_drainining_queue(
|
||||
self, room_id: str
|
||||
) -> Generator[_EventPersistQueueItem, None, None]:
|
||||
queue = self._event_persist_queues.setdefault(room_id, deque())
|
||||
|
||||
try:
|
||||
|
@ -317,7 +320,9 @@ class EventsPersistenceStorage:
|
|||
for event, ctx in events_and_contexts:
|
||||
partitioned.setdefault(event.room_id, []).append((event, ctx))
|
||||
|
||||
async def enqueue(item):
|
||||
async def enqueue(
|
||||
item: Tuple[str, List[Tuple[EventBase, EventContext]]]
|
||||
) -> Dict[str, str]:
|
||||
room_id, evs_ctxs = item
|
||||
return await self._event_persist_queue.add_to_queue(
|
||||
room_id, evs_ctxs, backfilled=backfilled
|
||||
|
@ -1102,7 +1107,7 @@ class EventsPersistenceStorage:
|
|||
|
||||
return False
|
||||
|
||||
async def _handle_potentially_left_users(self, user_ids: Set[str]):
|
||||
async def _handle_potentially_left_users(self, user_ids: Set[str]) -> None:
|
||||
"""Given a set of remote users check if the server still shares a room with
|
||||
them. If not then mark those users' device cache as stale.
|
||||
"""
|
||||
|
|
|
@ -85,7 +85,7 @@ def prepare_database(
|
|||
database_engine: BaseDatabaseEngine,
|
||||
config: Optional[HomeServerConfig],
|
||||
databases: Collection[str] = ("main", "state"),
|
||||
):
|
||||
) -> None:
|
||||
"""Prepares a physical database for usage. Will either create all necessary tables
|
||||
or upgrade from an older schema version.
|
||||
|
||||
|
|
|
@ -62,7 +62,7 @@ class StateFilter:
|
|||
types: "frozendict[str, Optional[FrozenSet[str]]]"
|
||||
include_others: bool = False
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
def __attrs_post_init__(self) -> None:
|
||||
# If `include_others` is set we canonicalise the filter by removing
|
||||
# wildcards from the types dictionary
|
||||
if self.include_others:
|
||||
|
@ -138,7 +138,9 @@ class StateFilter:
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
def freeze(types: Mapping[str, Optional[Collection[str]]], include_others: bool):
|
||||
def freeze(
|
||||
types: Mapping[str, Optional[Collection[str]]], include_others: bool
|
||||
) -> "StateFilter":
|
||||
"""
|
||||
Returns a (frozen) StateFilter with the same contents as the parameters
|
||||
specified here, which can be made of mutable types.
|
||||
|
|
|
@ -11,7 +11,8 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
|
||||
from types import TracebackType
|
||||
from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
|
||||
|
||||
from typing_extensions import Protocol
|
||||
|
||||
|
@ -86,5 +87,10 @@ class Connection(Protocol):
|
|||
def __enter__(self) -> "Connection":
|
||||
...
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_value: Optional[BaseException],
|
||||
traceback: Optional[TracebackType],
|
||||
) -> Optional[bool]:
|
||||
...
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue