Fix additional type hints. (#9543)

Type hint fixes due to Twisted 21.2.0 adding type hints.
This commit is contained in:
Patrick Cloke 2021-03-09 07:41:32 -05:00 committed by GitHub
parent 075c16b410
commit 7fdc6cefb3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 32 additions and 18 deletions

1
changelog.d/9543.misc Normal file
View File

@ -0,0 +1 @@
Fix incorrect type hints.

View File

@ -21,8 +21,10 @@ import threading
from string import Template from string import Template
import yaml import yaml
from zope.interface import implementer
from twisted.logger import ( from twisted.logger import (
ILogObserver,
LogBeginner, LogBeginner,
STDLibLogObserver, STDLibLogObserver,
eventAsText, eventAsText,
@ -227,7 +229,8 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) ->
threadlocal = threading.local() threadlocal = threading.local()
def _log(event): @implementer(ILogObserver)
def _log(event: dict) -> None:
if "log_text" in event: if "log_text" in event:
if event["log_text"].startswith("DNSDatagramProtocol starting on "): if event["log_text"].startswith("DNSDatagramProtocol starting on "):
return return

View File

@ -361,7 +361,7 @@ class FederationServer(FederationBase):
logger.error( logger.error(
"Failed to handle PDU %s", "Failed to handle PDU %s",
event_id, event_id,
exc_info=(f.type, f.value, f.getTracebackObject()), exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
) )
await concurrently_execute( await concurrently_execute(

View File

@ -285,7 +285,7 @@ class PaginationHandler:
except Exception: except Exception:
f = Failure() f = Failure()
logger.error( logger.error(
"[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject()) "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject()) # type: ignore
) )
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED
finally: finally:

View File

@ -322,7 +322,8 @@ def _cache_period_from_headers(
def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]: def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
cache_controls = {} cache_controls = {}
for hdr in headers.getRawHeaders(b"cache-control", []): cache_control_headers = headers.getRawHeaders(b"cache-control") or []
for hdr in cache_control_headers:
for directive in hdr.split(b","): for directive in hdr.split(b","):
splits = [x.strip() for x in directive.split(b"=", 1)] splits = [x.strip() for x in directive.split(b"=", 1)]
k = splits[0].lower() k = splits[0].lower()

View File

@ -669,7 +669,7 @@ def preserve_fn(f):
return g return g
def run_in_background(f, *args, **kwargs): def run_in_background(f, *args, **kwargs) -> defer.Deferred:
"""Calls a function, ensuring that the current context is restored after """Calls a function, ensuring that the current context is restored after
return from the function, and that the sentinel context is set once the return from the function, and that the sentinel context is set once the
deferred returned by the function completes. deferred returned by the function completes.
@ -697,8 +697,10 @@ def run_in_background(f, *args, **kwargs):
if isinstance(res, types.CoroutineType): if isinstance(res, types.CoroutineType):
res = defer.ensureDeferred(res) res = defer.ensureDeferred(res)
# At this point we should have a Deferred, if not then f was a synchronous
# function, wrap it in a Deferred for consistency.
if not isinstance(res, defer.Deferred): if not isinstance(res, defer.Deferred):
return res return defer.succeed(res)
if res.called and not res.paused: if res.called and not res.paused:
# The function should have maintained the logcontext, so we can # The function should have maintained the logcontext, so we can

View File

@ -22,6 +22,7 @@ from twisted.internet.protocol import Protocol
from twisted.internet.task import LoopingCall from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel from twisted.web.http import HTTPChannel
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import Request, Site
from synapse.app.generic_worker import ( from synapse.app.generic_worker import (
GenericWorkerReplicationHandler, GenericWorkerReplicationHandler,
@ -32,7 +33,10 @@ from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.replication.tcp.resource import (
ReplicationStreamProtocolFactory,
ServerReplicationStreamProtocol,
)
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.util import Clock from synapse.util import Clock
@ -59,7 +63,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# build a replication server # build a replication server
server_factory = ReplicationStreamProtocolFactory(hs) server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer() self.streamer = hs.get_replication_streamer()
self.server = server_factory.buildProtocol(None) self.server = server_factory.buildProtocol(
None
) # type: ServerReplicationStreamProtocol
# Make a new HomeServer object for the worker # Make a new HomeServer object for the worker
self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["testserv"] = "1.2.3.4"
@ -155,9 +161,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
request_factory = OneShotRequestFactory() request_factory = OneShotRequestFactory()
# Set up the server side protocol # Set up the server side protocol
channel = _PushHTTPChannel(self.reactor) channel = _PushHTTPChannel(self.reactor, request_factory, self.site)
channel.requestFactory = request_factory
channel.site = self.site
# Connect client to server and vice versa. # Connect client to server and vice versa.
client_to_server_transport = FakeTransport( client_to_server_transport = FakeTransport(
@ -188,8 +192,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
fetching updates for given stream. fetching updates for given stream.
""" """
path = request.path # type: bytes # type: ignore
self.assertRegex( self.assertRegex(
request.path, path,
br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$" br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
% (stream_name.encode("ascii"),), % (stream_name.encode("ascii"),),
) )
@ -390,9 +395,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
request_factory = OneShotRequestFactory() request_factory = OneShotRequestFactory()
# Set up the server side protocol # Set up the server side protocol
channel = _PushHTTPChannel(self.reactor) channel = _PushHTTPChannel(self.reactor, request_factory, self._hs_to_site[hs])
channel.requestFactory = request_factory
channel.site = self._hs_to_site[hs]
# Connect client to server and vice versa. # Connect client to server and vice versa.
client_to_server_transport = FakeTransport( client_to_server_transport = FakeTransport(
@ -475,9 +478,13 @@ class _PushHTTPChannel(HTTPChannel):
makes it very hard to test. makes it very hard to test.
""" """
def __init__(self, reactor: IReactorTime): def __init__(
self, reactor: IReactorTime, request_factory: Callable[..., Request], site: Site
):
super().__init__() super().__init__()
self.reactor = reactor self.reactor = reactor
self.requestFactory = request_factory
self.site = site
self._pull_to_push_producer = None # type: Optional[_PullToPushProducer] self._pull_to_push_producer = None # type: Optional[_PullToPushProducer]

View File

@ -188,7 +188,7 @@ class FakeSite:
def make_request( def make_request(
reactor, reactor,
site: Site, site: Union[Site, FakeSite],
method, method,
path, path,
content=b"", content=b"",

View File

@ -28,7 +28,7 @@ class ToTwistedHandler(logging.Handler):
def emit(self, record): def emit(self, record):
log_entry = self.format(record) log_entry = self.format(record)
log_level = record.levelname.lower().replace("warning", "warn") log_level = record.levelname.lower().replace("warning", "warn")
self.tx_log.emit( self.tx_log.emit( # type: ignore
twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry
) )