mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-12-17 14:34:21 -05:00
Reduce the number of "untyped defs" (#12716)
This commit is contained in:
parent
de1e599b9d
commit
17e1eb7749
1
changelog.d/12716.misc
Normal file
1
changelog.d/12716.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add type annotations to increase the number of modules passing `disallow-untyped-defs`.
|
24
mypy.ini
24
mypy.ini
@ -119,9 +119,18 @@ disallow_untyped_defs = True
|
|||||||
[mypy-synapse.federation.transport.client]
|
[mypy-synapse.federation.transport.client]
|
||||||
disallow_untyped_defs = False
|
disallow_untyped_defs = False
|
||||||
|
|
||||||
|
[mypy-synapse.groups.*]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-synapse.handlers.*]
|
[mypy-synapse.handlers.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.http.federation.*]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.http.request_metrics]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-synapse.http.server]
|
[mypy-synapse.http.server]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
@ -196,12 +205,27 @@ disallow_untyped_defs = True
|
|||||||
[mypy-synapse.storage.databases.main.state_deltas]
|
[mypy-synapse.storage.databases.main.state_deltas]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.storage.databases.main.stream]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-synapse.storage.databases.main.transactions]
|
[mypy-synapse.storage.databases.main.transactions]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-synapse.storage.databases.main.user_erasure_store]
|
[mypy-synapse.storage.databases.main.user_erasure_store]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.storage.prepare_database]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.storage.persist_events]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.storage.state]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.storage.types]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-synapse.storage.util.*]
|
[mypy-synapse.storage.util.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
@ -934,7 +934,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||||||
# Before deleting the group lets kick everyone out of it
|
# Before deleting the group lets kick everyone out of it
|
||||||
users = await self.store.get_users_in_group(group_id, include_private=True)
|
users = await self.store.get_users_in_group(group_id, include_private=True)
|
||||||
|
|
||||||
async def _kick_user_from_group(user_id):
|
async def _kick_user_from_group(user_id: str) -> None:
|
||||||
if self.hs.is_mine_id(user_id):
|
if self.hs.is_mine_id(user_id):
|
||||||
groups_local = self.hs.get_groups_local_handler()
|
groups_local = self.hs.get_groups_local_handler()
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
|
@ -43,8 +43,10 @@ from twisted.internet import defer, error as twisted_error, protocol, ssl
|
|||||||
from twisted.internet.address import IPv4Address, IPv6Address
|
from twisted.internet.address import IPv4Address, IPv6Address
|
||||||
from twisted.internet.interfaces import (
|
from twisted.internet.interfaces import (
|
||||||
IAddress,
|
IAddress,
|
||||||
|
IDelayedCall,
|
||||||
IHostResolution,
|
IHostResolution,
|
||||||
IReactorPluggableNameResolver,
|
IReactorPluggableNameResolver,
|
||||||
|
IReactorTime,
|
||||||
IResolutionReceiver,
|
IResolutionReceiver,
|
||||||
ITCPTransport,
|
ITCPTransport,
|
||||||
)
|
)
|
||||||
@ -121,13 +123,15 @@ def check_against_blacklist(
|
|||||||
_EPSILON = 0.00000001
|
_EPSILON = 0.00000001
|
||||||
|
|
||||||
|
|
||||||
def _make_scheduler(reactor):
|
def _make_scheduler(
|
||||||
|
reactor: IReactorTime,
|
||||||
|
) -> Callable[[Callable[[], object]], IDelayedCall]:
|
||||||
"""Makes a schedular suitable for a Cooperator using the given reactor.
|
"""Makes a schedular suitable for a Cooperator using the given reactor.
|
||||||
|
|
||||||
(This is effectively just a copy from `twisted.internet.task`)
|
(This is effectively just a copy from `twisted.internet.task`)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _scheduler(x):
|
def _scheduler(x: Callable[[], object]) -> IDelayedCall:
|
||||||
return reactor.callLater(_EPSILON, x)
|
return reactor.callLater(_EPSILON, x)
|
||||||
|
|
||||||
return _scheduler
|
return _scheduler
|
||||||
@ -775,7 +779,7 @@ class SimpleHttpClient:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _timeout_to_request_timed_out_error(f: Failure):
|
def _timeout_to_request_timed_out_error(f: Failure) -> Failure:
|
||||||
if f.check(twisted_error.TimeoutError, twisted_error.ConnectingCancelledError):
|
if f.check(twisted_error.TimeoutError, twisted_error.ConnectingCancelledError):
|
||||||
# The TCP connection has its own timeout (set by the 'connectTimeout' param
|
# The TCP connection has its own timeout (set by the 'connectTimeout' param
|
||||||
# on the Agent), which raises twisted_error.TimeoutError exception.
|
# on the Agent), which raises twisted_error.TimeoutError exception.
|
||||||
@ -809,7 +813,7 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
|
|||||||
def __init__(self, deferred: defer.Deferred):
|
def __init__(self, deferred: defer.Deferred):
|
||||||
self.deferred = deferred
|
self.deferred = deferred
|
||||||
|
|
||||||
def _maybe_fail(self):
|
def _maybe_fail(self) -> None:
|
||||||
"""
|
"""
|
||||||
Report a max size exceed error and disconnect the first time this is called.
|
Report a max size exceed error and disconnect the first time this is called.
|
||||||
"""
|
"""
|
||||||
@ -933,12 +937,12 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
|
|||||||
Do not use this since it allows an attacker to intercept your communications.
|
Do not use this since it allows an attacker to intercept your communications.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self._context = SSL.Context(SSL.SSLv23_METHOD)
|
self._context = SSL.Context(SSL.SSLv23_METHOD)
|
||||||
self._context.set_verify(VERIFY_NONE, lambda *_: False)
|
self._context.set_verify(VERIFY_NONE, lambda *_: False)
|
||||||
|
|
||||||
def getContext(self, hostname=None, port=None):
|
def getContext(self, hostname=None, port=None):
|
||||||
return self._context
|
return self._context
|
||||||
|
|
||||||
def creatorForNetloc(self, hostname, port):
|
def creatorForNetloc(self, hostname: bytes, port: int):
|
||||||
return self
|
return self
|
||||||
|
@ -239,7 +239,7 @@ class MatrixHostnameEndpointFactory:
|
|||||||
|
|
||||||
self._srv_resolver = srv_resolver
|
self._srv_resolver = srv_resolver
|
||||||
|
|
||||||
def endpointForURI(self, parsed_uri: URI):
|
def endpointForURI(self, parsed_uri: URI) -> "MatrixHostnameEndpoint":
|
||||||
return MatrixHostnameEndpoint(
|
return MatrixHostnameEndpoint(
|
||||||
self._reactor,
|
self._reactor,
|
||||||
self._proxy_reactor,
|
self._proxy_reactor,
|
||||||
|
@ -16,7 +16,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from typing import Callable, Dict, List
|
from typing import Any, Callable, Dict, List
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
@ -109,7 +109,7 @@ class SrvResolver:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dns_client=client,
|
dns_client: Any = client,
|
||||||
cache: Dict[bytes, List[Server]] = SERVER_CACHE,
|
cache: Dict[bytes, List[Server]] = SERVER_CACHE,
|
||||||
get_time: Callable[[], float] = time.time,
|
get_time: Callable[[], float] = time.time,
|
||||||
):
|
):
|
||||||
|
@ -74,9 +74,9 @@ _well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache("well-known")
|
|||||||
_had_valid_well_known_cache: TTLCache[bytes, bool] = TTLCache("had-valid-well-known")
|
_had_valid_well_known_cache: TTLCache[bytes, bool] = TTLCache("had-valid-well-known")
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True, frozen=True)
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
class WellKnownLookupResult:
|
class WellKnownLookupResult:
|
||||||
delegated_server = attr.ib()
|
delegated_server: Optional[bytes]
|
||||||
|
|
||||||
|
|
||||||
class WellKnownResolver:
|
class WellKnownResolver:
|
||||||
@ -336,4 +336,4 @@ def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
|
|||||||
class _FetchWellKnownFailure(Exception):
|
class _FetchWellKnownFailure(Exception):
|
||||||
# True if we didn't get a non-5xx HTTP response, i.e. this may or may not be
|
# True if we didn't get a non-5xx HTTP response, i.e. this may or may not be
|
||||||
# a temporary failure.
|
# a temporary failure.
|
||||||
temporary = attr.ib()
|
temporary: bool = attr.ib()
|
||||||
|
@ -23,6 +23,8 @@ from http import HTTPStatus
|
|||||||
from io import BytesIO, StringIO
|
from io import BytesIO, StringIO
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
BinaryIO,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
Generic,
|
Generic,
|
||||||
@ -44,7 +46,7 @@ from typing_extensions import Literal
|
|||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.internet.error import DNSLookupError
|
from twisted.internet.error import DNSLookupError
|
||||||
from twisted.internet.interfaces import IReactorTime
|
from twisted.internet.interfaces import IReactorTime
|
||||||
from twisted.internet.task import _EPSILON, Cooperator
|
from twisted.internet.task import Cooperator
|
||||||
from twisted.web.client import ResponseFailed
|
from twisted.web.client import ResponseFailed
|
||||||
from twisted.web.http_headers import Headers
|
from twisted.web.http_headers import Headers
|
||||||
from twisted.web.iweb import IBodyProducer, IResponse
|
from twisted.web.iweb import IBodyProducer, IResponse
|
||||||
@ -58,11 +60,13 @@ from synapse.api.errors import (
|
|||||||
RequestSendFailed,
|
RequestSendFailed,
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
|
from synapse.crypto.context_factory import FederationPolicyForHTTPS
|
||||||
from synapse.http import QuieterFileBodyProducer
|
from synapse.http import QuieterFileBodyProducer
|
||||||
from synapse.http.client import (
|
from synapse.http.client import (
|
||||||
BlacklistingAgentWrapper,
|
BlacklistingAgentWrapper,
|
||||||
BodyExceededMaxSize,
|
BodyExceededMaxSize,
|
||||||
ByteWriteable,
|
ByteWriteable,
|
||||||
|
_make_scheduler,
|
||||||
encode_query_args,
|
encode_query_args,
|
||||||
read_body_with_max_size,
|
read_body_with_max_size,
|
||||||
)
|
)
|
||||||
@ -181,7 +185,7 @@ class JsonParser(ByteParser[Union[JsonDict, list]]):
|
|||||||
|
|
||||||
CONTENT_TYPE = "application/json"
|
CONTENT_TYPE = "application/json"
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self._buffer = StringIO()
|
self._buffer = StringIO()
|
||||||
self._binary_wrapper = BinaryIOWrapper(self._buffer)
|
self._binary_wrapper = BinaryIOWrapper(self._buffer)
|
||||||
|
|
||||||
@ -299,7 +303,9 @@ async def _handle_response(
|
|||||||
class BinaryIOWrapper:
|
class BinaryIOWrapper:
|
||||||
"""A wrapper for a TextIO which converts from bytes on the fly."""
|
"""A wrapper for a TextIO which converts from bytes on the fly."""
|
||||||
|
|
||||||
def __init__(self, file: typing.TextIO, encoding="utf-8", errors="strict"):
|
def __init__(
|
||||||
|
self, file: typing.TextIO, encoding: str = "utf-8", errors: str = "strict"
|
||||||
|
):
|
||||||
self.decoder = codecs.getincrementaldecoder(encoding)(errors)
|
self.decoder = codecs.getincrementaldecoder(encoding)(errors)
|
||||||
self.file = file
|
self.file = file
|
||||||
|
|
||||||
@ -317,7 +323,11 @@ class MatrixFederationHttpClient:
|
|||||||
requests.
|
requests.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer", tls_client_options_factory):
|
def __init__(
|
||||||
|
self,
|
||||||
|
hs: "HomeServer",
|
||||||
|
tls_client_options_factory: Optional[FederationPolicyForHTTPS],
|
||||||
|
):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.signing_key = hs.signing_key
|
self.signing_key = hs.signing_key
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
@ -348,10 +358,7 @@ class MatrixFederationHttpClient:
|
|||||||
self.version_string_bytes = hs.version_string.encode("ascii")
|
self.version_string_bytes = hs.version_string.encode("ascii")
|
||||||
self.default_timeout = 60
|
self.default_timeout = 60
|
||||||
|
|
||||||
def schedule(x):
|
self._cooperator = Cooperator(scheduler=_make_scheduler(self.reactor))
|
||||||
self.reactor.callLater(_EPSILON, x)
|
|
||||||
|
|
||||||
self._cooperator = Cooperator(scheduler=schedule)
|
|
||||||
|
|
||||||
self._sleeper = AwakenableSleeper(self.reactor)
|
self._sleeper = AwakenableSleeper(self.reactor)
|
||||||
|
|
||||||
@ -364,7 +371,7 @@ class MatrixFederationHttpClient:
|
|||||||
self,
|
self,
|
||||||
request: MatrixFederationRequest,
|
request: MatrixFederationRequest,
|
||||||
try_trailing_slash_on_400: bool = False,
|
try_trailing_slash_on_400: bool = False,
|
||||||
**send_request_args,
|
**send_request_args: Any,
|
||||||
) -> IResponse:
|
) -> IResponse:
|
||||||
"""Wrapper for _send_request which can optionally retry the request
|
"""Wrapper for _send_request which can optionally retry the request
|
||||||
upon receiving a combination of a 400 HTTP response code and a
|
upon receiving a combination of a 400 HTTP response code and a
|
||||||
@ -1159,7 +1166,7 @@ class MatrixFederationHttpClient:
|
|||||||
self,
|
self,
|
||||||
destination: str,
|
destination: str,
|
||||||
path: str,
|
path: str,
|
||||||
output_stream,
|
output_stream: BinaryIO,
|
||||||
args: Optional[QueryParams] = None,
|
args: Optional[QueryParams] = None,
|
||||||
retry_on_dns_fail: bool = True,
|
retry_on_dns_fail: bool = True,
|
||||||
max_size: Optional[int] = None,
|
max_size: Optional[int] = None,
|
||||||
@ -1250,10 +1257,10 @@ class MatrixFederationHttpClient:
|
|||||||
return length, headers
|
return length, headers
|
||||||
|
|
||||||
|
|
||||||
def _flatten_response_never_received(e):
|
def _flatten_response_never_received(e: BaseException) -> str:
|
||||||
if hasattr(e, "reasons"):
|
if hasattr(e, "reasons"):
|
||||||
reasons = ", ".join(
|
reasons = ", ".join(
|
||||||
_flatten_response_never_received(f.value) for f in e.reasons
|
_flatten_response_never_received(f.value) for f in e.reasons # type: ignore[attr-defined]
|
||||||
)
|
)
|
||||||
|
|
||||||
return "%s:[%s]" % (type(e).__name__, reasons)
|
return "%s:[%s]" % (type(e).__name__, reasons)
|
||||||
|
@ -162,7 +162,7 @@ class RequestMetrics:
|
|||||||
with _in_flight_requests_lock:
|
with _in_flight_requests_lock:
|
||||||
_in_flight_requests.add(self)
|
_in_flight_requests.add(self)
|
||||||
|
|
||||||
def stop(self, time_sec, response_code, sent_bytes):
|
def stop(self, time_sec: float, response_code: int, sent_bytes: int) -> None:
|
||||||
with _in_flight_requests_lock:
|
with _in_flight_requests_lock:
|
||||||
_in_flight_requests.discard(self)
|
_in_flight_requests.discard(self)
|
||||||
|
|
||||||
@ -186,13 +186,13 @@ class RequestMetrics:
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
response_code = str(response_code)
|
response_code_str = str(response_code)
|
||||||
|
|
||||||
outgoing_responses_counter.labels(self.method, response_code).inc()
|
outgoing_responses_counter.labels(self.method, response_code_str).inc()
|
||||||
|
|
||||||
response_count.labels(self.method, self.name, tag).inc()
|
response_count.labels(self.method, self.name, tag).inc()
|
||||||
|
|
||||||
response_timer.labels(self.method, self.name, tag, response_code).observe(
|
response_timer.labels(self.method, self.name, tag, response_code_str).observe(
|
||||||
time_sec - self.start_ts
|
time_sec - self.start_ts
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -221,7 +221,7 @@ class RequestMetrics:
|
|||||||
# flight.
|
# flight.
|
||||||
self.update_metrics()
|
self.update_metrics()
|
||||||
|
|
||||||
def update_metrics(self):
|
def update_metrics(self) -> None:
|
||||||
"""Updates the in flight metrics with values from this request."""
|
"""Updates the in flight metrics with values from this request."""
|
||||||
if not self.start_context:
|
if not self.start_context:
|
||||||
logger.error(
|
logger.error(
|
||||||
|
@ -31,6 +31,7 @@ from typing import (
|
|||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
cast,
|
cast,
|
||||||
overload,
|
overload,
|
||||||
@ -41,6 +42,7 @@ from prometheus_client import Histogram
|
|||||||
from typing_extensions import Concatenate, Literal, ParamSpec
|
from typing_extensions import Concatenate, Literal, ParamSpec
|
||||||
|
|
||||||
from twisted.enterprise import adbapi
|
from twisted.enterprise import adbapi
|
||||||
|
from twisted.internet.interfaces import IReactorCore
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.config.database import DatabaseConnectionConfig
|
from synapse.config.database import DatabaseConnectionConfig
|
||||||
@ -92,7 +94,9 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
|
|||||||
|
|
||||||
|
|
||||||
def make_pool(
|
def make_pool(
|
||||||
reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
|
reactor: IReactorCore,
|
||||||
|
db_config: DatabaseConnectionConfig,
|
||||||
|
engine: BaseDatabaseEngine,
|
||||||
) -> adbapi.ConnectionPool:
|
) -> adbapi.ConnectionPool:
|
||||||
"""Get the connection pool for the database."""
|
"""Get the connection pool for the database."""
|
||||||
|
|
||||||
@ -101,7 +105,7 @@ def make_pool(
|
|||||||
db_args = dict(db_config.config.get("args", {}))
|
db_args = dict(db_config.config.get("args", {}))
|
||||||
db_args.setdefault("cp_reconnect", True)
|
db_args.setdefault("cp_reconnect", True)
|
||||||
|
|
||||||
def _on_new_connection(conn):
|
def _on_new_connection(conn: Connection) -> None:
|
||||||
# Ensure we have a logging context so we can correctly track queries,
|
# Ensure we have a logging context so we can correctly track queries,
|
||||||
# etc.
|
# etc.
|
||||||
with LoggingContext("db.on_new_connection"):
|
with LoggingContext("db.on_new_connection"):
|
||||||
@ -157,7 +161,11 @@ class LoggingDatabaseConnection:
|
|||||||
default_txn_name: str
|
default_txn_name: str
|
||||||
|
|
||||||
def cursor(
|
def cursor(
|
||||||
self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
|
self,
|
||||||
|
*,
|
||||||
|
txn_name: Optional[str] = None,
|
||||||
|
after_callbacks: Optional[List["_CallbackListEntry"]] = None,
|
||||||
|
exception_callbacks: Optional[List["_CallbackListEntry"]] = None,
|
||||||
) -> "LoggingTransaction":
|
) -> "LoggingTransaction":
|
||||||
if not txn_name:
|
if not txn_name:
|
||||||
txn_name = self.default_txn_name
|
txn_name = self.default_txn_name
|
||||||
@ -183,11 +191,16 @@ class LoggingDatabaseConnection:
|
|||||||
self.conn.__enter__()
|
self.conn.__enter__()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional[Type[BaseException]],
|
||||||
|
exc_value: Optional[BaseException],
|
||||||
|
traceback: Optional[types.TracebackType],
|
||||||
|
) -> Optional[bool]:
|
||||||
return self.conn.__exit__(exc_type, exc_value, traceback)
|
return self.conn.__exit__(exc_type, exc_value, traceback)
|
||||||
|
|
||||||
# Proxy through any unknown lookups to the DB conn class.
|
# Proxy through any unknown lookups to the DB conn class.
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name: str) -> Any:
|
||||||
return getattr(self.conn, name)
|
return getattr(self.conn, name)
|
||||||
|
|
||||||
|
|
||||||
@ -391,17 +404,22 @@ class LoggingTransaction:
|
|||||||
def __enter__(self) -> "LoggingTransaction":
|
def __enter__(self) -> "LoggingTransaction":
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional[Type[BaseException]],
|
||||||
|
exc_value: Optional[BaseException],
|
||||||
|
traceback: Optional[types.TracebackType],
|
||||||
|
) -> None:
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
|
|
||||||
class PerformanceCounters:
|
class PerformanceCounters:
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.current_counters = {}
|
self.current_counters: Dict[str, Tuple[int, float]] = {}
|
||||||
self.previous_counters = {}
|
self.previous_counters: Dict[str, Tuple[int, float]] = {}
|
||||||
|
|
||||||
def update(self, key: str, duration_secs: float) -> None:
|
def update(self, key: str, duration_secs: float) -> None:
|
||||||
count, cum_time = self.current_counters.get(key, (0, 0))
|
count, cum_time = self.current_counters.get(key, (0, 0.0))
|
||||||
count += 1
|
count += 1
|
||||||
cum_time += duration_secs
|
cum_time += duration_secs
|
||||||
self.current_counters[key] = (count, cum_time)
|
self.current_counters[key] = (count, cum_time)
|
||||||
@ -527,7 +545,7 @@ class DatabasePool:
|
|||||||
def start_profiling(self) -> None:
|
def start_profiling(self) -> None:
|
||||||
self._previous_loop_ts = monotonic_time()
|
self._previous_loop_ts = monotonic_time()
|
||||||
|
|
||||||
def loop():
|
def loop() -> None:
|
||||||
curr = self._current_txn_total_time
|
curr = self._current_txn_total_time
|
||||||
prev = self._previous_txn_total_time
|
prev = self._previous_txn_total_time
|
||||||
self._previous_txn_total_time = curr
|
self._previous_txn_total_time = curr
|
||||||
@ -1186,7 +1204,7 @@ class DatabasePool:
|
|||||||
if lock:
|
if lock:
|
||||||
self.engine.lock_table(txn, table)
|
self.engine.lock_table(txn, table)
|
||||||
|
|
||||||
def _getwhere(key):
|
def _getwhere(key: str) -> str:
|
||||||
# If the value we're passing in is None (aka NULL), we need to use
|
# If the value we're passing in is None (aka NULL), we need to use
|
||||||
# IS, not =, as NULL = NULL equals NULL (False).
|
# IS, not =, as NULL = NULL equals NULL (False).
|
||||||
if keyvalues[key] is None:
|
if keyvalues[key] is None:
|
||||||
@ -2258,7 +2276,7 @@ class DatabasePool:
|
|||||||
term: Optional[str],
|
term: Optional[str],
|
||||||
col: str,
|
col: str,
|
||||||
retcols: Collection[str],
|
retcols: Collection[str],
|
||||||
desc="simple_search_list",
|
desc: str = "simple_search_list",
|
||||||
) -> Optional[List[Dict[str, Any]]]:
|
) -> Optional[List[Dict[str, Any]]]:
|
||||||
"""Executes a SELECT query on the named table, which may return zero or
|
"""Executes a SELECT query on the named table, which may return zero or
|
||||||
more rows, returning the result as a list of dicts.
|
more rows, returning the result as a list of dicts.
|
||||||
|
@ -23,6 +23,7 @@ from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
|||||||
from synapse.storage.databases.main.event_push_actions import (
|
from synapse.storage.databases.main.event_push_actions import (
|
||||||
EventPushActionsWorkerStore,
|
EventPushActionsWorkerStore,
|
||||||
)
|
)
|
||||||
|
from synapse.storage.types import Cursor
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
@ -71,7 +72,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
|||||||
self._last_user_visit_update = self._get_start_of_day()
|
self._last_user_visit_update = self._get_start_of_day()
|
||||||
|
|
||||||
@wrap_as_background_process("read_forward_extremities")
|
@wrap_as_background_process("read_forward_extremities")
|
||||||
async def _read_forward_extremities(self):
|
async def _read_forward_extremities(self) -> None:
|
||||||
def fetch(txn):
|
def fetch(txn):
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"""
|
"""
|
||||||
@ -95,7 +96,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
|||||||
(x[0] - 1) * x[1] for x in res if x[1]
|
(x[0] - 1) * x[1] for x in res if x[1]
|
||||||
)
|
)
|
||||||
|
|
||||||
async def count_daily_e2ee_messages(self):
|
async def count_daily_e2ee_messages(self) -> int:
|
||||||
"""
|
"""
|
||||||
Returns an estimate of the number of messages sent in the last day.
|
Returns an estimate of the number of messages sent in the last day.
|
||||||
|
|
||||||
@ -115,7 +116,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
|||||||
|
|
||||||
return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
|
return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
|
||||||
|
|
||||||
async def count_daily_sent_e2ee_messages(self):
|
async def count_daily_sent_e2ee_messages(self) -> int:
|
||||||
def _count_messages(txn):
|
def _count_messages(txn):
|
||||||
# This is good enough as if you have silly characters in your own
|
# This is good enough as if you have silly characters in your own
|
||||||
# hostname then that's your own fault.
|
# hostname then that's your own fault.
|
||||||
@ -136,7 +137,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
|||||||
"count_daily_sent_e2ee_messages", _count_messages
|
"count_daily_sent_e2ee_messages", _count_messages
|
||||||
)
|
)
|
||||||
|
|
||||||
async def count_daily_active_e2ee_rooms(self):
|
async def count_daily_active_e2ee_rooms(self) -> int:
|
||||||
def _count(txn):
|
def _count(txn):
|
||||||
sql = """
|
sql = """
|
||||||
SELECT COUNT(DISTINCT room_id) FROM events
|
SELECT COUNT(DISTINCT room_id) FROM events
|
||||||
@ -151,7 +152,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
|||||||
"count_daily_active_e2ee_rooms", _count
|
"count_daily_active_e2ee_rooms", _count
|
||||||
)
|
)
|
||||||
|
|
||||||
async def count_daily_messages(self):
|
async def count_daily_messages(self) -> int:
|
||||||
"""
|
"""
|
||||||
Returns an estimate of the number of messages sent in the last day.
|
Returns an estimate of the number of messages sent in the last day.
|
||||||
|
|
||||||
@ -171,7 +172,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
|||||||
|
|
||||||
return await self.db_pool.runInteraction("count_messages", _count_messages)
|
return await self.db_pool.runInteraction("count_messages", _count_messages)
|
||||||
|
|
||||||
async def count_daily_sent_messages(self):
|
async def count_daily_sent_messages(self) -> int:
|
||||||
def _count_messages(txn):
|
def _count_messages(txn):
|
||||||
# This is good enough as if you have silly characters in your own
|
# This is good enough as if you have silly characters in your own
|
||||||
# hostname then that's your own fault.
|
# hostname then that's your own fault.
|
||||||
@ -192,7 +193,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
|||||||
"count_daily_sent_messages", _count_messages
|
"count_daily_sent_messages", _count_messages
|
||||||
)
|
)
|
||||||
|
|
||||||
async def count_daily_active_rooms(self):
|
async def count_daily_active_rooms(self) -> int:
|
||||||
def _count(txn):
|
def _count(txn):
|
||||||
sql = """
|
sql = """
|
||||||
SELECT COUNT(DISTINCT room_id) FROM events
|
SELECT COUNT(DISTINCT room_id) FROM events
|
||||||
@ -226,7 +227,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
|||||||
"count_monthly_users", self._count_users, thirty_days_ago
|
"count_monthly_users", self._count_users, thirty_days_ago
|
||||||
)
|
)
|
||||||
|
|
||||||
def _count_users(self, txn, time_from):
|
def _count_users(self, txn: Cursor, time_from: int) -> int:
|
||||||
"""
|
"""
|
||||||
Returns number of users seen in the past time_from period
|
Returns number of users seen in the past time_from period
|
||||||
"""
|
"""
|
||||||
@ -238,7 +239,10 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
|||||||
) u
|
) u
|
||||||
"""
|
"""
|
||||||
txn.execute(sql, (time_from,))
|
txn.execute(sql, (time_from,))
|
||||||
(count,) = txn.fetchone()
|
# Mypy knows that fetchone() might return None if there are no rows.
|
||||||
|
# We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
|
||||||
|
# returns exactly one row.
|
||||||
|
(count,) = txn.fetchone() # type: ignore[misc]
|
||||||
return count
|
return count
|
||||||
|
|
||||||
async def count_r30_users(self) -> Dict[str, int]:
|
async def count_r30_users(self) -> Dict[str, int]:
|
||||||
@ -453,7 +457,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
|||||||
"count_r30v2_users", _count_r30v2_users
|
"count_r30v2_users", _count_r30v2_users
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_start_of_day(self):
|
def _get_start_of_day(self) -> int:
|
||||||
"""
|
"""
|
||||||
Returns millisecond unixtime for start of UTC day.
|
Returns millisecond unixtime for start of UTC day.
|
||||||
"""
|
"""
|
||||||
|
@ -798,9 +798,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||||||
self,
|
self,
|
||||||
txn: LoggingTransaction,
|
txn: LoggingTransaction,
|
||||||
event_id: str,
|
event_id: str,
|
||||||
allow_none=False,
|
allow_none: bool = False,
|
||||||
) -> int:
|
) -> Optional[int]:
|
||||||
return self.db_pool.simple_select_one_onecol_txn(
|
# Type ignore: we pass keyvalues a Dict[str, str]; the function wants
|
||||||
|
# Dict[str, Any]. I think mypy is unhappy because Dict is invariant?
|
||||||
|
return self.db_pool.simple_select_one_onecol_txn( # type: ignore[call-overload]
|
||||||
txn=txn,
|
txn=txn,
|
||||||
table="events",
|
table="events",
|
||||||
keyvalues={"event_id": event_id},
|
keyvalues={"event_id": event_id},
|
||||||
|
@ -25,6 +25,7 @@ from typing import (
|
|||||||
Collection,
|
Collection,
|
||||||
Deque,
|
Deque,
|
||||||
Dict,
|
Dict,
|
||||||
|
Generator,
|
||||||
Generic,
|
Generic,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
@ -207,7 +208,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def _handle_queue(self, room_id):
|
def _handle_queue(self, room_id: str) -> None:
|
||||||
"""Attempts to handle the queue for a room if not already being handled.
|
"""Attempts to handle the queue for a room if not already being handled.
|
||||||
|
|
||||||
The queue's callback will be invoked with for each item in the queue,
|
The queue's callback will be invoked with for each item in the queue,
|
||||||
@ -227,7 +228,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
|
|||||||
|
|
||||||
self._currently_persisting_rooms.add(room_id)
|
self._currently_persisting_rooms.add(room_id)
|
||||||
|
|
||||||
async def handle_queue_loop():
|
async def handle_queue_loop() -> None:
|
||||||
try:
|
try:
|
||||||
queue = self._get_drainining_queue(room_id)
|
queue = self._get_drainining_queue(room_id)
|
||||||
for item in queue:
|
for item in queue:
|
||||||
@ -250,15 +251,17 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
|
|||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
item.deferred.callback(ret)
|
item.deferred.callback(ret)
|
||||||
finally:
|
finally:
|
||||||
queue = self._event_persist_queues.pop(room_id, None)
|
remaining_queue = self._event_persist_queues.pop(room_id, None)
|
||||||
if queue:
|
if remaining_queue:
|
||||||
self._event_persist_queues[room_id] = queue
|
self._event_persist_queues[room_id] = remaining_queue
|
||||||
self._currently_persisting_rooms.discard(room_id)
|
self._currently_persisting_rooms.discard(room_id)
|
||||||
|
|
||||||
# set handle_queue_loop off in the background
|
# set handle_queue_loop off in the background
|
||||||
run_as_background_process("persist_events", handle_queue_loop)
|
run_as_background_process("persist_events", handle_queue_loop)
|
||||||
|
|
||||||
def _get_drainining_queue(self, room_id):
|
def _get_drainining_queue(
|
||||||
|
self, room_id: str
|
||||||
|
) -> Generator[_EventPersistQueueItem, None, None]:
|
||||||
queue = self._event_persist_queues.setdefault(room_id, deque())
|
queue = self._event_persist_queues.setdefault(room_id, deque())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -317,7 +320,9 @@ class EventsPersistenceStorage:
|
|||||||
for event, ctx in events_and_contexts:
|
for event, ctx in events_and_contexts:
|
||||||
partitioned.setdefault(event.room_id, []).append((event, ctx))
|
partitioned.setdefault(event.room_id, []).append((event, ctx))
|
||||||
|
|
||||||
async def enqueue(item):
|
async def enqueue(
|
||||||
|
item: Tuple[str, List[Tuple[EventBase, EventContext]]]
|
||||||
|
) -> Dict[str, str]:
|
||||||
room_id, evs_ctxs = item
|
room_id, evs_ctxs = item
|
||||||
return await self._event_persist_queue.add_to_queue(
|
return await self._event_persist_queue.add_to_queue(
|
||||||
room_id, evs_ctxs, backfilled=backfilled
|
room_id, evs_ctxs, backfilled=backfilled
|
||||||
@ -1102,7 +1107,7 @@ class EventsPersistenceStorage:
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def _handle_potentially_left_users(self, user_ids: Set[str]):
|
async def _handle_potentially_left_users(self, user_ids: Set[str]) -> None:
|
||||||
"""Given a set of remote users check if the server still shares a room with
|
"""Given a set of remote users check if the server still shares a room with
|
||||||
them. If not then mark those users' device cache as stale.
|
them. If not then mark those users' device cache as stale.
|
||||||
"""
|
"""
|
||||||
|
@ -85,7 +85,7 @@ def prepare_database(
|
|||||||
database_engine: BaseDatabaseEngine,
|
database_engine: BaseDatabaseEngine,
|
||||||
config: Optional[HomeServerConfig],
|
config: Optional[HomeServerConfig],
|
||||||
databases: Collection[str] = ("main", "state"),
|
databases: Collection[str] = ("main", "state"),
|
||||||
):
|
) -> None:
|
||||||
"""Prepares a physical database for usage. Will either create all necessary tables
|
"""Prepares a physical database for usage. Will either create all necessary tables
|
||||||
or upgrade from an older schema version.
|
or upgrade from an older schema version.
|
||||||
|
|
||||||
|
@ -62,7 +62,7 @@ class StateFilter:
|
|||||||
types: "frozendict[str, Optional[FrozenSet[str]]]"
|
types: "frozendict[str, Optional[FrozenSet[str]]]"
|
||||||
include_others: bool = False
|
include_others: bool = False
|
||||||
|
|
||||||
def __attrs_post_init__(self):
|
def __attrs_post_init__(self) -> None:
|
||||||
# If `include_others` is set we canonicalise the filter by removing
|
# If `include_others` is set we canonicalise the filter by removing
|
||||||
# wildcards from the types dictionary
|
# wildcards from the types dictionary
|
||||||
if self.include_others:
|
if self.include_others:
|
||||||
@ -138,7 +138,9 @@ class StateFilter:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def freeze(types: Mapping[str, Optional[Collection[str]]], include_others: bool):
|
def freeze(
|
||||||
|
types: Mapping[str, Optional[Collection[str]]], include_others: bool
|
||||||
|
) -> "StateFilter":
|
||||||
"""
|
"""
|
||||||
Returns a (frozen) StateFilter with the same contents as the parameters
|
Returns a (frozen) StateFilter with the same contents as the parameters
|
||||||
specified here, which can be made of mutable types.
|
specified here, which can be made of mutable types.
|
||||||
|
@ -11,7 +11,8 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
|
from types import TracebackType
|
||||||
|
from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
|
||||||
|
|
||||||
from typing_extensions import Protocol
|
from typing_extensions import Protocol
|
||||||
|
|
||||||
@ -86,5 +87,10 @@ class Connection(Protocol):
|
|||||||
def __enter__(self) -> "Connection":
|
def __enter__(self) -> "Connection":
|
||||||
...
|
...
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional[Type[BaseException]],
|
||||||
|
exc_value: Optional[BaseException],
|
||||||
|
traceback: Optional[TracebackType],
|
||||||
|
) -> Optional[bool]:
|
||||||
...
|
...
|
||||||
|
Loading…
Reference in New Issue
Block a user