Fix additional type hints from Twisted upgrade. (#9518)

This commit is contained in:
Patrick Cloke 2021-03-03 15:47:38 -05:00 committed by GitHub
parent 4db07f9aef
commit 33a02f0f52
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 96 additions and 61 deletions

1
changelog.d/9518.misc Normal file
View File

@ -0,0 +1 @@
Fix incorrect type hints.

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import urllib.parse import urllib.parse
from typing import List, Optional from typing import Any, Generator, List, Optional
from netaddr import AddrFormatError, IPAddress, IPSet from netaddr import AddrFormatError, IPAddress, IPSet
from zope.interface import implementer from zope.interface import implementer
@ -116,7 +116,7 @@ class MatrixFederationAgent:
uri: bytes, uri: bytes,
headers: Optional[Headers] = None, headers: Optional[Headers] = None,
bodyProducer: Optional[IBodyProducer] = None, bodyProducer: Optional[IBodyProducer] = None,
) -> defer.Deferred: ) -> Generator[defer.Deferred, Any, defer.Deferred]:
""" """
Args: Args:
method: HTTP method: GET/POST/etc method: HTTP method: GET/POST/etc
@ -177,17 +177,17 @@ class MatrixFederationAgent:
# We need to make sure the host header is set to the netloc of the # We need to make sure the host header is set to the netloc of the
# server and that a user-agent is provided. # server and that a user-agent is provided.
if headers is None: if headers is None:
headers = Headers() request_headers = Headers()
else: else:
headers = headers.copy() request_headers = headers.copy()
if not headers.hasHeader(b"host"): if not request_headers.hasHeader(b"host"):
headers.addRawHeader(b"host", parsed_uri.netloc) request_headers.addRawHeader(b"host", parsed_uri.netloc)
if not headers.hasHeader(b"user-agent"): if not request_headers.hasHeader(b"user-agent"):
headers.addRawHeader(b"user-agent", self.user_agent) request_headers.addRawHeader(b"user-agent", self.user_agent)
res = yield make_deferred_yieldable( res = yield make_deferred_yieldable(
self._agent.request(method, uri, headers, bodyProducer) self._agent.request(method, uri, request_headers, bodyProducer)
) )
return res return res

View File

@ -1049,14 +1049,14 @@ def check_content_type_is_json(headers: Headers) -> None:
RequestSendFailed: if the Content-Type header is missing or isn't JSON RequestSendFailed: if the Content-Type header is missing or isn't JSON
""" """
c_type = headers.getRawHeaders(b"Content-Type") content_type_headers = headers.getRawHeaders(b"Content-Type")
if c_type is None: if content_type_headers is None:
raise RequestSendFailed( raise RequestSendFailed(
RuntimeError("No Content-Type header received from remote server"), RuntimeError("No Content-Type header received from remote server"),
can_retry=False, can_retry=False,
) )
c_type = c_type[0].decode("ascii") # only the first header c_type = content_type_headers[0].decode("ascii") # only the first header
val, options = cgi.parse_header(c_type) val, options = cgi.parse_header(c_type)
if val != "application/json": if val != "application/json":
raise RequestSendFailed( raise RequestSendFailed(

View File

@ -21,6 +21,7 @@ import logging
import types import types
import urllib import urllib
from http import HTTPStatus from http import HTTPStatus
from inspect import isawaitable
from io import BytesIO from io import BytesIO
from typing import ( from typing import (
Any, Any,
@ -30,6 +31,7 @@ from typing import (
Iterable, Iterable,
Iterator, Iterator,
List, List,
Optional,
Pattern, Pattern,
Tuple, Tuple,
Union, Union,
@ -79,10 +81,12 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
"""Sends a JSON error response to clients.""" """Sends a JSON error response to clients."""
if f.check(SynapseError): if f.check(SynapseError):
error_code = f.value.code # mypy doesn't understand that f.check asserts the type.
error_dict = f.value.error_dict() exc = f.value # type: SynapseError # type: ignore
error_code = exc.code
error_dict = exc.error_dict()
logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg) logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
else: else:
error_code = 500 error_code = 500
error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN} error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
@ -91,7 +95,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
"Failed handle request via %r: %r", "Failed handle request via %r: %r",
request.request_metrics.name, request.request_metrics.name,
request, request,
exc_info=(f.type, f.value, f.getTracebackObject()), exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
) )
# Only respond with an error response if we haven't already started writing, # Only respond with an error response if we haven't already started writing,
@ -128,7 +132,8 @@ def return_html_error(
`{msg}` placeholders), or a jinja2 template `{msg}` placeholders), or a jinja2 template
""" """
if f.check(CodeMessageException): if f.check(CodeMessageException):
cme = f.value # mypy doesn't understand that f.check asserts the type.
cme = f.value # type: CodeMessageException # type: ignore
code = cme.code code = cme.code
msg = cme.msg msg = cme.msg
@ -142,7 +147,7 @@ def return_html_error(
logger.error( logger.error(
"Failed handle request %r", "Failed handle request %r",
request, request,
exc_info=(f.type, f.value, f.getTracebackObject()), exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
) )
else: else:
code = HTTPStatus.INTERNAL_SERVER_ERROR code = HTTPStatus.INTERNAL_SERVER_ERROR
@ -151,7 +156,7 @@ def return_html_error(
logger.error( logger.error(
"Failed handle request %r", "Failed handle request %r",
request, request,
exc_info=(f.type, f.value, f.getTracebackObject()), exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
) )
if isinstance(error_template, str): if isinstance(error_template, str):
@ -278,7 +283,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
raw_callback_return = method_handler(request) raw_callback_return = method_handler(request)
# Is it synchronous? We'll allow this for now. # Is it synchronous? We'll allow this for now.
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)): if isawaitable(raw_callback_return):
callback_return = await raw_callback_return callback_return = await raw_callback_return
else: else:
callback_return = raw_callback_return # type: ignore callback_return = raw_callback_return # type: ignore
@ -399,8 +404,10 @@ class JsonResource(DirectServeJsonResource):
A tuple of the callback to use, the name of the servlet, and the A tuple of the callback to use, the name of the servlet, and the
key word arguments to pass to the callback key word arguments to pass to the callback
""" """
# At this point the path must be bytes.
request_path_bytes = request.path # type: bytes # type: ignore
request_path = request_path_bytes.decode("ascii")
# Treat HEAD requests as GET requests. # Treat HEAD requests as GET requests.
request_path = request.path.decode("ascii")
request_method = request.method request_method = request.method
if request_method == b"HEAD": if request_method == b"HEAD":
request_method = b"GET" request_method = b"GET"
@ -551,7 +558,7 @@ class _ByteProducer:
request: Request, request: Request,
iterator: Iterator[bytes], iterator: Iterator[bytes],
): ):
self._request = request self._request = request # type: Optional[Request]
self._iterator = iterator self._iterator = iterator
self._paused = False self._paused = False
@ -563,7 +570,7 @@ class _ByteProducer:
""" """
Send a list of bytes as a chunk of a response. Send a list of bytes as a chunk of a response.
""" """
if not data: if not data or not self._request:
return return
self._request.write(b"".join(data)) self._request.write(b"".join(data))

View File

@ -14,7 +14,7 @@
import contextlib import contextlib
import logging import logging
import time import time
from typing import Optional, Union from typing import Optional, Type, Union
import attr import attr
from zope.interface import implementer from zope.interface import implementer
@ -57,7 +57,7 @@ class SynapseRequest(Request):
def __init__(self, channel, *args, **kw): def __init__(self, channel, *args, **kw):
Request.__init__(self, channel, *args, **kw) Request.__init__(self, channel, *args, **kw)
self.site = channel.site self.site = channel.site # type: SynapseSite
self._channel = channel # this is used by the tests self._channel = channel # this is used by the tests
self.start_time = 0.0 self.start_time = 0.0
@ -96,25 +96,34 @@ class SynapseRequest(Request):
def get_request_id(self): def get_request_id(self):
return "%s-%i" % (self.get_method(), self.request_seq) return "%s-%i" % (self.get_method(), self.request_seq)
def get_redacted_uri(self): def get_redacted_uri(self) -> str:
uri = self.uri """Gets the redacted URI associated with the request (or placeholder if the URI
has not yet been received).
Note: This is necessary as the placeholder value in twisted is str
rather than bytes, so we need to sanitise `self.uri`.
Returns:
The redacted URI as a string.
"""
uri = self.uri # type: Union[bytes, str]
if isinstance(uri, bytes): if isinstance(uri, bytes):
uri = self.uri.decode("ascii", errors="replace") uri = uri.decode("ascii", errors="replace")
return redact_uri(uri) return redact_uri(uri)
def get_method(self): def get_method(self) -> str:
"""Gets the method associated with the request (or placeholder if not """Gets the method associated with the request (or placeholder if method
method has yet been received). has not yet been received).
Note: This is necessary as the placeholder value in twisted is str Note: This is necessary as the placeholder value in twisted is str
rather than bytes, so we need to sanitise `self.method`. rather than bytes, so we need to sanitise `self.method`.
Returns: Returns:
str The request method as a string.
""" """
method = self.method method = self.method # type: Union[bytes, str]
if isinstance(method, bytes): if isinstance(method, bytes):
method = self.method.decode("ascii") return self.method.decode("ascii")
return method return method
def render(self, resrc): def render(self, resrc):
@ -432,7 +441,9 @@ class SynapseSite(Site):
assert config.http_options is not None assert config.http_options is not None
proxied = config.http_options.x_forwarded proxied = config.http_options.x_forwarded
self.requestFactory = XForwardedForRequest if proxied else SynapseRequest self.requestFactory = (
XForwardedForRequest if proxied else SynapseRequest
) # type: Type[Request]
self.access_logger = logging.getLogger(logger_name) self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string.encode("ascii") self.server_version_string = server_version_string.encode("ascii")

View File

@ -32,7 +32,7 @@ from twisted.internet.endpoints import (
TCP4ClientEndpoint, TCP4ClientEndpoint,
TCP6ClientEndpoint, TCP6ClientEndpoint,
) )
from twisted.internet.interfaces import IPushProducer, ITransport from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint, ITransport
from twisted.internet.protocol import Factory, Protocol from twisted.internet.protocol import Factory, Protocol
from twisted.python.failure import Failure from twisted.python.failure import Failure
@ -121,7 +121,9 @@ class RemoteHandler(logging.Handler):
try: try:
ip = ip_address(self.host) ip = ip_address(self.host)
if isinstance(ip, IPv4Address): if isinstance(ip, IPv4Address):
endpoint = TCP4ClientEndpoint(_reactor, self.host, self.port) endpoint = TCP4ClientEndpoint(
_reactor, self.host, self.port
) # type: IStreamClientEndpoint
elif isinstance(ip, IPv6Address): elif isinstance(ip, IPv6Address):
endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port) endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port)
else: else:

View File

@ -527,7 +527,7 @@ class ReactorLastSeenMetric:
REGISTRY.register(ReactorLastSeenMetric()) REGISTRY.register(ReactorLastSeenMetric())
def runUntilCurrentTimer(func): def runUntilCurrentTimer(reactor, func):
@functools.wraps(func) @functools.wraps(func)
def f(*args, **kwargs): def f(*args, **kwargs):
now = reactor.seconds() now = reactor.seconds()
@ -590,13 +590,14 @@ def runUntilCurrentTimer(func):
try: try:
# Ensure the reactor has all the attributes we expect # Ensure the reactor has all the attributes we expect
reactor.runUntilCurrent reactor.seconds # type: ignore
reactor._newTimedCalls reactor.runUntilCurrent # type: ignore
reactor.threadCallQueue reactor._newTimedCalls # type: ignore
reactor.threadCallQueue # type: ignore
# runUntilCurrent is called when we have pending calls. It is called once # runUntilCurrent is called when we have pending calls. It is called once
# per iteratation after fd polling. # per iteratation after fd polling.
reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent) reactor.runUntilCurrent = runUntilCurrentTimer(reactor, reactor.runUntilCurrent) # type: ignore
# We manually run the GC each reactor tick so that we can get some metrics # We manually run the GC each reactor tick so that we can get some metrics
# about time spent doing GC, # about time spent doing GC,

View File

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Iterable, Optional, Tuple from typing import TYPE_CHECKING, Any, Generator, Iterable, Optional, Tuple
from twisted.internet import defer from twisted.internet import defer
@ -307,7 +307,7 @@ class ModuleApi:
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_events_in_room( def get_state_events_in_room(
self, room_id: str, types: Iterable[Tuple[str, Optional[str]]] self, room_id: str, types: Iterable[Tuple[str, Optional[str]]]
) -> defer.Deferred: ) -> Generator[defer.Deferred, Any, defer.Deferred]:
"""Gets current state events for the given room. """Gets current state events for the given room.
(This is exposed for compatibility with the old SpamCheckerApi. We should (This is exposed for compatibility with the old SpamCheckerApi. We should

View File

@ -15,11 +15,12 @@
# limitations under the License. # limitations under the License.
import logging import logging
import urllib.parse import urllib.parse
from typing import TYPE_CHECKING, Any, Dict, Iterable, Union from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union
from prometheus_client import Counter from prometheus_client import Counter
from twisted.internet.error import AlreadyCalled, AlreadyCancelled from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from twisted.internet.interfaces import IDelayedCall
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.events import EventBase from synapse.events import EventBase
@ -71,7 +72,7 @@ class HttpPusher(Pusher):
self.data = pusher_config.data self.data = pusher_config.data
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.failing_since = pusher_config.failing_since self.failing_since = pusher_config.failing_since
self.timed_call = None self.timed_call = None # type: Optional[IDelayedCall]
self._is_processing = False self._is_processing = False
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
self._pusherpool = hs.get_pusherpool() self._pusherpool = hs.get_pusherpool()

View File

@ -108,9 +108,7 @@ class ReplicationDataHandler:
# Map from stream to list of deferreds waiting for the stream to # Map from stream to list of deferreds waiting for the stream to
# arrive at a particular position. The lists are sorted by stream position. # arrive at a particular position. The lists are sorted by stream position.
self._streams_to_waiters = ( self._streams_to_waiters = {} # type: Dict[str, List[Tuple[int, Deferred]]]
{}
) # type: Dict[str, List[Tuple[int, Deferred[None]]]]
async def on_rdata( async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list self, stream_name: str, instance_name: str, token: int, rows: list

View File

@ -38,6 +38,7 @@ from typing import (
import twisted.internet.base import twisted.internet.base
import twisted.internet.tcp import twisted.internet.tcp
from twisted.internet import defer
from twisted.mail.smtp import sendmail from twisted.mail.smtp import sendmail
from twisted.web.iweb import IPolicyForHTTPS from twisted.web.iweb import IPolicyForHTTPS
@ -403,7 +404,7 @@ class HomeServer(metaclass=abc.ABCMeta):
return RoomShutdownHandler(self) return RoomShutdownHandler(self)
@cache_in_self @cache_in_self
def get_sendmail(self) -> sendmail: def get_sendmail(self) -> Callable[..., defer.Deferred]:
return sendmail return sendmail
@cache_in_self @cache_in_self

View File

@ -522,7 +522,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
shorthand=False, shorthand=False,
) )
self.assertEqual(channel.code, 302, channel.result) self.assertEqual(channel.code, 302, channel.result)
cas_uri = channel.headers.getRawHeaders("Location")[0] location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
cas_uri = location_headers[0]
cas_uri_path, cas_uri_query = cas_uri.split("?", 1) cas_uri_path, cas_uri_query = cas_uri.split("?", 1)
# it should redirect us to the login page of the cas server # it should redirect us to the login page of the cas server
@ -545,7 +547,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ "&idp=saml", + "&idp=saml",
) )
self.assertEqual(channel.code, 302, channel.result) self.assertEqual(channel.code, 302, channel.result)
saml_uri = channel.headers.getRawHeaders("Location")[0] location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
saml_uri = location_headers[0]
saml_uri_path, saml_uri_query = saml_uri.split("?", 1) saml_uri_path, saml_uri_query = saml_uri.split("?", 1)
# it should redirect us to the login page of the SAML server # it should redirect us to the login page of the SAML server
@ -567,17 +571,21 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ "&idp=oidc", + "&idp=oidc",
) )
self.assertEqual(channel.code, 302, channel.result) self.assertEqual(channel.code, 302, channel.result)
oidc_uri = channel.headers.getRawHeaders("Location")[0] location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
oidc_uri = location_headers[0]
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
# it should redirect us to the auth page of the OIDC server # it should redirect us to the auth page of the OIDC server
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
# ... and should have set a cookie including the redirect url # ... and should have set a cookie including the redirect url
cookies = dict( cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
h.split(";")[0].split("=", maxsplit=1) assert cookie_headers
for h in channel.headers.getRawHeaders("Set-Cookie") cookies = {} # type: Dict[str, str]
) for h in cookie_headers:
key, value = h.split(";")[0].split("=", maxsplit=1)
cookies[key] = value
oidc_session_cookie = cookies["oidc_session"] oidc_session_cookie = cookies["oidc_session"]
macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie) macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie)
@ -590,9 +598,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
# that should serve a confirmation page # that should serve a confirmation page
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, 200, channel.result)
self.assertTrue( content_type_headers = channel.headers.getRawHeaders("Content-Type")
channel.headers.getRawHeaders("Content-Type")[-1].startswith("text/html") assert content_type_headers
) self.assertTrue(content_type_headers[-1].startswith("text/html"))
p = TestHtmlParser() p = TestHtmlParser()
p.feed(channel.text_body) p.feed(channel.text_body)
p.close() p.close()
@ -806,6 +814,7 @@ class CASTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 302) self.assertEqual(channel.code, 302)
location_headers = channel.headers.getRawHeaders("Location") location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url) self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url)
@override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}}) @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}})
@ -1248,7 +1257,9 @@ class UsernamePickerTestCase(HomeserverTestCase):
# that should redirect to the username picker # that should redirect to the username picker
self.assertEqual(channel.code, 302, channel.result) self.assertEqual(channel.code, 302, channel.result)
picker_url = channel.headers.getRawHeaders("Location")[0] location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
picker_url = location_headers[0]
self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details") self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details")
# ... with a username_mapping_session cookie # ... with a username_mapping_session cookie
@ -1291,6 +1302,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
) )
self.assertEqual(chan.code, 302, chan.result) self.assertEqual(chan.code, 302, chan.result)
location_headers = chan.headers.getRawHeaders("Location") location_headers = chan.headers.getRawHeaders("Location")
assert location_headers
# send a request to the completion page, which should 302 to the client redirectUrl # send a request to the completion page, which should 302 to the client redirectUrl
chan = self.make_request( chan = self.make_request(
@ -1300,6 +1312,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
) )
self.assertEqual(chan.code, 302, chan.result) self.assertEqual(chan.code, 302, chan.result)
location_headers = chan.headers.getRawHeaders("Location") location_headers = chan.headers.getRawHeaders("Location")
assert location_headers
# ensure that the returned location matches the requested redirect URL # ensure that the returned location matches the requested redirect URL
path, query = location_headers[0].split("?", 1) path, query = location_headers[0].split("?", 1)