Reduce the number of "untyped defs" (#12716)

David Robertson 2022-05-12 15:33:50 +01:00 committed by GitHub
parent de1e599b9d
commit 17e1eb7749
16 changed files with 142 additions and 69 deletions

changelog.d/12716.misc Normal file

@@ -0,0 +1 @@
Add type annotations to increase the number of modules passing `disallow-untyped-defs`.
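
For context, a minimal sketch (not part of this commit, with invented function names) of what the `disallow_untyped_defs` flag enforces in the covered modules:

# With disallow_untyped_defs = True, mypy rejects any function
# definition that lacks annotations.

def scale(value, factor):  # error: Function is missing a type annotation
    return value * factor

def scale_typed(value: float, factor: float) -> float:  # passes
    return value * factor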

mypy.ini

@@ -119,9 +119,18 @@ disallow_untyped_defs = True
[mypy-synapse.federation.transport.client]
disallow_untyped_defs = False
[mypy-synapse.groups.*]
disallow_untyped_defs = True
[mypy-synapse.handlers.*]
disallow_untyped_defs = True
[mypy-synapse.http.federation.*]
disallow_untyped_defs = True
[mypy-synapse.http.request_metrics]
disallow_untyped_defs = True
[mypy-synapse.http.server]
disallow_untyped_defs = True
@@ -196,12 +205,27 @@ disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.state_deltas]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.stream]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.transactions]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.user_erasure_store]
disallow_untyped_defs = True
[mypy-synapse.storage.prepare_database]
disallow_untyped_defs = True
[mypy-synapse.storage.persist_events]
disallow_untyped_defs = True
[mypy-synapse.storage.state]
disallow_untyped_defs = True
[mypy-synapse.storage.types]
disallow_untyped_defs = True
[mypy-synapse.storage.util.*]
disallow_untyped_defs = True

synapse/groups/groups_server.py

@@ -934,7 +934,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
# Before deleting the group let's kick everyone out of it
users = await self.store.get_users_in_group(group_id, include_private=True)
async def _kick_user_from_group(user_id):
async def _kick_user_from_group(user_id: str) -> None:
if self.hs.is_mine_id(user_id):
groups_local = self.hs.get_groups_local_handler()
assert isinstance(

synapse/http/client.py

@@ -43,8 +43,10 @@ from twisted.internet import defer, error as twisted_error, protocol, ssl
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.interfaces import (
IAddress,
IDelayedCall,
IHostResolution,
IReactorPluggableNameResolver,
IReactorTime,
IResolutionReceiver,
ITCPTransport,
)
@@ -121,13 +123,15 @@ def check_against_blacklist(
_EPSILON = 0.00000001
def _make_scheduler(reactor):
def _make_scheduler(
reactor: IReactorTime,
) -> Callable[[Callable[[], object]], IDelayedCall]:
"""Makes a schedular suitable for a Cooperator using the given reactor.
(This is effectively just a copy from `twisted.internet.task`)
"""
def _scheduler(x):
def _scheduler(x: Callable[[], object]) -> IDelayedCall:
return reactor.callLater(_EPSILON, x)
return _scheduler
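
As an illustrative aside (not from this diff): a `Cooperator` calls its scheduler with a zero-argument callable and expects back something it can cancel, which is exactly what `reactor.callLater` returns. A hedged usage sketch, assuming the `_make_scheduler` defined above:

from twisted.internet import reactor, task

# The Cooperator schedules each slice of work via the returned scheduler,
# mirroring the MatrixFederationHttpClient change further down.
cooperator = task.Cooperator(scheduler=_make_scheduler(reactor))

# cooperate() consumes the iterator one item per scheduled call.
cooperator.cooperate(iter(range(3)))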
@@ -775,7 +779,7 @@ class SimpleHttpClient:
)
def _timeout_to_request_timed_out_error(f: Failure):
def _timeout_to_request_timed_out_error(f: Failure) -> Failure:
if f.check(twisted_error.TimeoutError, twisted_error.ConnectingCancelledError):
# The TCP connection has its own timeout (set by the 'connectTimeout' param
# on the Agent), which raises twisted_error.TimeoutError exception.
@@ -809,7 +813,7 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
def __init__(self, deferred: defer.Deferred):
self.deferred = deferred
def _maybe_fail(self):
def _maybe_fail(self) -> None:
"""
Report a max size exceeded error and disconnect the first time this is called.
"""
@@ -933,12 +937,12 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
Do not use this since it allows an attacker to intercept your communications.
"""
def __init__(self):
def __init__(self) -> None:
self._context = SSL.Context(SSL.SSLv23_METHOD)
self._context.set_verify(VERIFY_NONE, lambda *_: False)
def getContext(self, hostname=None, port=None):
return self._context
def creatorForNetloc(self, hostname, port):
def creatorForNetloc(self, hostname: bytes, port: int):
return self

synapse/http/federation/matrix_federation_agent.py

@@ -239,7 +239,7 @@ class MatrixHostnameEndpointFactory:
self._srv_resolver = srv_resolver
def endpointForURI(self, parsed_uri: URI):
def endpointForURI(self, parsed_uri: URI) -> "MatrixHostnameEndpoint":
return MatrixHostnameEndpoint(
self._reactor,
self._proxy_reactor,

synapse/http/federation/srv_resolver.py

@@ -16,7 +16,7 @@
import logging
import random
import time
from typing import Callable, Dict, List
from typing import Any, Callable, Dict, List
import attr
@@ -109,7 +109,7 @@ class SrvResolver:
def __init__(
self,
dns_client=client,
dns_client: Any = client,
cache: Dict[bytes, List[Server]] = SERVER_CACHE,
get_time: Callable[[], float] = time.time,
):

synapse/http/federation/well_known_resolver.py

@@ -74,9 +74,9 @@ _well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache("well-known")
_had_valid_well_known_cache: TTLCache[bytes, bool] = TTLCache("had-valid-well-known")
@attr.s(slots=True, frozen=True)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class WellKnownLookupResult:
delegated_server = attr.ib()
delegated_server: Optional[bytes]
class WellKnownResolver:
@@ -336,4 +336,4 @@ def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
class _FetchWellKnownFailure(Exception):
# True if we didn't get a non-5xx HTTP response, i.e. this may or may not be
# a temporary failure.
temporary = attr.ib()
temporary: bool = attr.ib()
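
An illustrative note (not part of this commit): with `auto_attribs=True`, attrs turns each annotated class attribute into a field, so `delegated_server: Optional[bytes]` generates the same `__init__` as the old untyped `attr.ib()` while giving mypy a real type; annotating an explicit `attr.ib()`, as `temporary: bool` does, has the same typing effect. A minimal sketch with an invented class name:

from typing import Optional
import attr

@attr.s(slots=True, frozen=True, auto_attribs=True)
class LookupResult:
    delegated_server: Optional[bytes]

# Same constructor as the attr.ib() spelling, but now fully typed.
result = LookupResult(delegated_server=b"synapse.example.com")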

synapse/http/matrixfederationclient.py

@@ -23,6 +23,8 @@ from http import HTTPStatus
from io import BytesIO, StringIO
from typing import (
TYPE_CHECKING,
Any,
BinaryIO,
Callable,
Dict,
Generic,
@@ -44,7 +46,7 @@ from typing_extensions import Literal
from twisted.internet import defer
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IReactorTime
from twisted.internet.task import _EPSILON, Cooperator
from twisted.internet.task import Cooperator
from twisted.web.client import ResponseFailed
from twisted.web.http_headers import Headers
from twisted.web.iweb import IBodyProducer, IResponse
@@ -58,11 +60,13 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http import QuieterFileBodyProducer
from synapse.http.client import (
BlacklistingAgentWrapper,
BodyExceededMaxSize,
ByteWriteable,
_make_scheduler,
encode_query_args,
read_body_with_max_size,
)
@@ -181,7 +185,7 @@ class JsonParser(ByteParser[Union[JsonDict, list]]):
CONTENT_TYPE = "application/json"
def __init__(self):
def __init__(self) -> None:
self._buffer = StringIO()
self._binary_wrapper = BinaryIOWrapper(self._buffer)
@@ -299,7 +303,9 @@ async def _handle_response(
class BinaryIOWrapper:
"""A wrapper for a TextIO which converts from bytes on the fly."""
def __init__(self, file: typing.TextIO, encoding="utf-8", errors="strict"):
def __init__(
self, file: typing.TextIO, encoding: str = "utf-8", errors: str = "strict"
):
self.decoder = codecs.getincrementaldecoder(encoding)(errors)
self.file = file
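
For illustration (not from this diff): the incremental decoder that `BinaryIOWrapper` relies on buffers partial multi-byte sequences across calls, so arbitrary network chunks decode cleanly:

import codecs

decoder = codecs.getincrementaldecoder("utf-8")("strict")
decoder.decode(b"\xc3")  # returns "" - incomplete sequence is buffered
decoder.decode(b"\xa9")  # returns "é" - completed by the next chunk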
@@ -317,7 +323,11 @@ class MatrixFederationHttpClient:
requests.
"""
def __init__(self, hs: "HomeServer", tls_client_options_factory):
def __init__(
self,
hs: "HomeServer",
tls_client_options_factory: Optional[FederationPolicyForHTTPS],
):
self.hs = hs
self.signing_key = hs.signing_key
self.server_name = hs.hostname
@@ -348,10 +358,7 @@ class MatrixFederationHttpClient:
self.version_string_bytes = hs.version_string.encode("ascii")
self.default_timeout = 60
def schedule(x):
self.reactor.callLater(_EPSILON, x)
self._cooperator = Cooperator(scheduler=schedule)
self._cooperator = Cooperator(scheduler=_make_scheduler(self.reactor))
self._sleeper = AwakenableSleeper(self.reactor)
@@ -364,7 +371,7 @@
self,
request: MatrixFederationRequest,
try_trailing_slash_on_400: bool = False,
**send_request_args,
**send_request_args: Any,
) -> IResponse:
"""Wrapper for _send_request which can optionally retry the request
upon receiving a combination of a 400 HTTP response code and a
@@ -1159,7 +1166,7 @@ class MatrixFederationHttpClient:
self,
destination: str,
path: str,
output_stream,
output_stream: BinaryIO,
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
max_size: Optional[int] = None,
@@ -1250,10 +1257,10 @@ class MatrixFederationHttpClient:
return length, headers
def _flatten_response_never_received(e):
def _flatten_response_never_received(e: BaseException) -> str:
if hasattr(e, "reasons"):
reasons = ", ".join(
_flatten_response_never_received(f.value) for f in e.reasons
_flatten_response_never_received(f.value) for f in e.reasons # type: ignore[attr-defined]
)
return "%s:[%s]" % (type(e).__name__, reasons)

synapse/http/request_metrics.py

@@ -162,7 +162,7 @@ class RequestMetrics:
with _in_flight_requests_lock:
_in_flight_requests.add(self)
def stop(self, time_sec, response_code, sent_bytes):
def stop(self, time_sec: float, response_code: int, sent_bytes: int) -> None:
with _in_flight_requests_lock:
_in_flight_requests.discard(self)
@@ -186,13 +186,13 @@ class RequestMetrics:
)
return
response_code = str(response_code)
response_code_str = str(response_code)
outgoing_responses_counter.labels(self.method, response_code).inc()
outgoing_responses_counter.labels(self.method, response_code_str).inc()
response_count.labels(self.method, self.name, tag).inc()
response_timer.labels(self.method, self.name, tag, response_code).observe(
response_timer.labels(self.method, self.name, tag, response_code_str).observe(
time_sec - self.start_ts
)
@@ -221,7 +221,7 @@ class RequestMetrics:
# flight.
self.update_metrics()
def update_metrics(self):
def update_metrics(self) -> None:
"""Updates the in flight metrics with values from this request."""
if not self.start_context:
logger.error(
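
A side note on the `response_code_str` rename above (illustrative, not from this diff): re-binding one variable from `int` to `str` trips mypy's redefinition check, while a fresh name keeps each binding a single type:

def stop(response_code: int) -> None:
    # response_code = str(response_code)
    #   error: Incompatible types in assignment (expression has type
    #   "str", variable has type "int")
    response_code_str = str(response_code)  # fresh name: no error
    print(response_code_str)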

synapse/storage/database.py

@@ -31,6 +31,7 @@ from typing import (
List,
Optional,
Tuple,
Type,
TypeVar,
cast,
overload,
@@ -41,6 +42,7 @@ from prometheus_client import Histogram
from typing_extensions import Concatenate, Literal, ParamSpec
from twisted.enterprise import adbapi
from twisted.internet.interfaces import IReactorCore
from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
@@ -92,7 +94,9 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
def make_pool(
reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
reactor: IReactorCore,
db_config: DatabaseConnectionConfig,
engine: BaseDatabaseEngine,
) -> adbapi.ConnectionPool:
"""Get the connection pool for the database."""
@@ -101,7 +105,7 @@ def make_pool(
db_args = dict(db_config.config.get("args", {}))
db_args.setdefault("cp_reconnect", True)
def _on_new_connection(conn):
def _on_new_connection(conn: Connection) -> None:
# Ensure we have a logging context so we can correctly track queries,
# etc.
with LoggingContext("db.on_new_connection"):
@@ -157,7 +161,11 @@ class LoggingDatabaseConnection:
default_txn_name: str
def cursor(
self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
self,
*,
txn_name: Optional[str] = None,
after_callbacks: Optional[List["_CallbackListEntry"]] = None,
exception_callbacks: Optional[List["_CallbackListEntry"]] = None,
) -> "LoggingTransaction":
if not txn_name:
txn_name = self.default_txn_name
@@ -183,11 +191,16 @@ class LoggingDatabaseConnection:
self.conn.__enter__()
return self
def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[types.TracebackType],
) -> Optional[bool]:
return self.conn.__exit__(exc_type, exc_value, traceback)
# Proxy through any unknown lookups to the DB conn class.
def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
return getattr(self.conn, name)
@@ -391,17 +404,22 @@ class LoggingTransaction:
def __enter__(self) -> "LoggingTransaction":
return self
def __exit__(self, exc_type, exc_value, traceback):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[types.TracebackType],
) -> None:
self.close()
class PerformanceCounters:
def __init__(self):
self.current_counters = {}
self.previous_counters = {}
def __init__(self) -> None:
self.current_counters: Dict[str, Tuple[int, float]] = {}
self.previous_counters: Dict[str, Tuple[int, float]] = {}
def update(self, key: str, duration_secs: float) -> None:
count, cum_time = self.current_counters.get(key, (0, 0))
count, cum_time = self.current_counters.get(key, (0, 0.0))
count += 1
cum_time += duration_secs
self.current_counters[key] = (count, cum_time)
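
Illustrative note on the `(0, 0.0)` default above (not from this diff, with an invented counter key): the literal `(0, 0)` is inferred as `Tuple[int, int]`, while `(0, 0.0)` matches the declared value type `Tuple[int, float]` exactly rather than leaning on mypy's int-to-float promotion:

from typing import Dict, Tuple

current_counters: Dict[str, Tuple[int, float]] = {}

# The default's inferred type is Tuple[int, float], matching the dict's
# value type exactly.
count, cum_time = current_counters.get("db.select", (0, 0.0))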
@@ -527,7 +545,7 @@ class DatabasePool:
def start_profiling(self) -> None:
self._previous_loop_ts = monotonic_time()
def loop():
def loop() -> None:
curr = self._current_txn_total_time
prev = self._previous_txn_total_time
self._previous_txn_total_time = curr
@@ -1186,7 +1204,7 @@ class DatabasePool:
if lock:
self.engine.lock_table(txn, table)
def _getwhere(key):
def _getwhere(key: str) -> str:
# If the value we're passing in is None (aka NULL), we need to use
# IS, not =, as NULL = NULL equals NULL (False).
if keyvalues[key] is None:
@@ -2258,7 +2276,7 @@ class DatabasePool:
term: Optional[str],
col: str,
retcols: Collection[str],
desc="simple_search_list",
desc: str = "simple_search_list",
) -> Optional[List[Dict[str, Any]]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.

synapse/storage/databases/main/metrics.py

@@ -23,6 +23,7 @@ from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
from synapse.storage.types import Cursor
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -71,7 +72,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
self._last_user_visit_update = self._get_start_of_day()
@wrap_as_background_process("read_forward_extremities")
async def _read_forward_extremities(self):
async def _read_forward_extremities(self) -> None:
def fetch(txn):
txn.execute(
"""
@@ -95,7 +96,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
(x[0] - 1) * x[1] for x in res if x[1]
)
async def count_daily_e2ee_messages(self):
async def count_daily_e2ee_messages(self) -> int:
"""
Returns an estimate of the number of messages sent in the last day.
@@ -115,7 +116,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
async def count_daily_sent_e2ee_messages(self):
async def count_daily_sent_e2ee_messages(self) -> int:
def _count_messages(txn):
# This is good enough as if you have silly characters in your own
# hostname then that's your own fault.
@@ -136,7 +137,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_daily_sent_e2ee_messages", _count_messages
)
async def count_daily_active_e2ee_rooms(self):
async def count_daily_active_e2ee_rooms(self) -> int:
def _count(txn):
sql = """
SELECT COUNT(DISTINCT room_id) FROM events
@@ -151,7 +152,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_daily_active_e2ee_rooms", _count
)
async def count_daily_messages(self):
async def count_daily_messages(self) -> int:
"""
Returns an estimate of the number of messages sent in the last day.
@@ -171,7 +172,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
return await self.db_pool.runInteraction("count_messages", _count_messages)
async def count_daily_sent_messages(self):
async def count_daily_sent_messages(self) -> int:
def _count_messages(txn):
# This is good enough as if you have silly characters in your own
# hostname then that's your own fault.
@@ -192,7 +193,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_daily_sent_messages", _count_messages
)
async def count_daily_active_rooms(self):
async def count_daily_active_rooms(self) -> int:
def _count(txn):
sql = """
SELECT COUNT(DISTINCT room_id) FROM events
@@ -226,7 +227,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_monthly_users", self._count_users, thirty_days_ago
)
def _count_users(self, txn, time_from):
def _count_users(self, txn: Cursor, time_from: int) -> int:
"""
Returns number of users seen in the past time_from period
"""
@@ -238,7 +239,10 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
) u
"""
txn.execute(sql, (time_from,))
(count,) = txn.fetchone()
# Mypy knows that fetchone() might return None if there are no rows.
# We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
# returns exactly one row.
(count,) = txn.fetchone() # type: ignore[misc]
return count
async def count_r30_users(self) -> Dict[str, int]:
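
A hedged alternative to the `type: ignore[misc]` above (not what this commit does, with an invented helper name) is to bind the row and assert it is not None, which narrows the Optional for mypy:

from typing import Optional, Tuple

def one_count(row: Optional[Tuple[int]]) -> int:
    # assert narrows Optional[Tuple[int]] to Tuple[int], so the unpack
    # below type-checks without an ignore.
    assert row is not None, "COUNT(...) always returns exactly one row"
    (count,) = row
    return count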
@@ -453,7 +457,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_r30v2_users", _count_r30v2_users
)
def _get_start_of_day(self):
def _get_start_of_day(self) -> int:
"""
Returns millisecond unixtime for start of UTC day.
"""

synapse/storage/databases/main/stream.py

@@ -798,9 +798,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self,
txn: LoggingTransaction,
event_id: str,
allow_none=False,
) -> int:
return self.db_pool.simple_select_one_onecol_txn(
allow_none: bool = False,
) -> Optional[int]:
# Type ignore: we pass keyvalues a Dict[str, str]; the function wants
# Dict[str, Any]. I think mypy is unhappy because Dict is invariant?
return self.db_pool.simple_select_one_onecol_txn( # type: ignore[call-overload]
txn=txn,
table="events",
keyvalues={"event_id": event_id},
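
The invariance question in the comment above can be shown in isolation (illustrative, with invented names and values; `object` is used here because `Any` is compatible in both directions and would mask the effect):

from typing import Dict, Mapping

def wants_dict(keyvalues: Dict[str, object]) -> None: ...
def wants_mapping(keyvalues: Mapping[str, object]) -> None: ...

kv: Dict[str, str] = {"event_id": "$abc"}
wants_dict(kv)     # error: Dict is invariant - the callee could insert
                   # a non-str value into kv
wants_mapping(kv)  # accepted: Mapping is read-only and covariant in
                   # its value type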

synapse/storage/persist_events.py

@@ -25,6 +25,7 @@ from typing import (
Collection,
Deque,
Dict,
Generator,
Generic,
Iterable,
List,
@@ -207,7 +208,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
return res
def _handle_queue(self, room_id):
def _handle_queue(self, room_id: str) -> None:
"""Attempts to handle the queue for a room if not already being handled.
The queue's callback will be invoked for each item in the queue,
@@ -227,7 +228,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
self._currently_persisting_rooms.add(room_id)
async def handle_queue_loop():
async def handle_queue_loop() -> None:
try:
queue = self._get_drainining_queue(room_id)
for item in queue:
@@ -250,15 +251,17 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
with PreserveLoggingContext():
item.deferred.callback(ret)
finally:
queue = self._event_persist_queues.pop(room_id, None)
if queue:
self._event_persist_queues[room_id] = queue
remaining_queue = self._event_persist_queues.pop(room_id, None)
if remaining_queue:
self._event_persist_queues[room_id] = remaining_queue
self._currently_persisting_rooms.discard(room_id)
# set handle_queue_loop off in the background
run_as_background_process("persist_events", handle_queue_loop)
def _get_drainining_queue(self, room_id):
def _get_drainining_queue(
self, room_id: str
) -> Generator[_EventPersistQueueItem, None, None]:
queue = self._event_persist_queues.setdefault(room_id, deque())
try:
@@ -317,7 +320,9 @@ class EventsPersistenceStorage:
for event, ctx in events_and_contexts:
partitioned.setdefault(event.room_id, []).append((event, ctx))
async def enqueue(item):
async def enqueue(
item: Tuple[str, List[Tuple[EventBase, EventContext]]]
) -> Dict[str, str]:
room_id, evs_ctxs = item
return await self._event_persist_queue.add_to_queue(
room_id, evs_ctxs, backfilled=backfilled
@@ -1102,7 +1107,7 @@ class EventsPersistenceStorage:
return False
async def _handle_potentially_left_users(self, user_ids: Set[str]):
async def _handle_potentially_left_users(self, user_ids: Set[str]) -> None:
"""Given a set of remote users check if the server still shares a room with
them. If not then mark those users' device cache as stale.
"""

synapse/storage/prepare_database.py

@@ -85,7 +85,7 @@ def prepare_database(
database_engine: BaseDatabaseEngine,
config: Optional[HomeServerConfig],
databases: Collection[str] = ("main", "state"),
):
) -> None:
"""Prepares a physical database for usage. Will either create all necessary tables
or upgrade from an older schema version.

synapse/storage/state.py

@@ -62,7 +62,7 @@ class StateFilter:
types: "frozendict[str, Optional[FrozenSet[str]]]"
include_others: bool = False
def __attrs_post_init__(self):
def __attrs_post_init__(self) -> None:
# If `include_others` is set we canonicalise the filter by removing
# wildcards from the types dictionary
if self.include_others:
@@ -138,7 +139,9 @@
)
@staticmethod
def freeze(types: Mapping[str, Optional[Collection[str]]], include_others: bool):
def freeze(
types: Mapping[str, Optional[Collection[str]]], include_others: bool
) -> "StateFilter":
"""
Returns a (frozen) StateFilter with the same contents as the parameters
specified here, which can be made of mutable types.

synapse/storage/types.py

@@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
from types import TracebackType
from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
from typing_extensions import Protocol
@@ -86,5 +87,10 @@ class Connection(Protocol):
def __enter__(self) -> "Connection":
...
def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> Optional[bool]:
...
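
A closing illustrative note (not part of this commit): because `Connection` is a `Protocol`, any driver connection that structurally provides these methods matches it without subclassing. A minimal sketch with invented names and a smaller protocol:

from typing_extensions import Protocol

class SupportsClose(Protocol):
    def close(self) -> None:
        ...

class FakeConnection:
    def close(self) -> None:
        print("closed")

def shutdown(conn: SupportsClose) -> None:
    conn.close()

shutdown(FakeConnection())  # accepted: structural match, no inheritance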