Fix remaining mypy issues due to Twisted upgrade. (#9608)

This commit is contained in:
Patrick Cloke 2021-03-15 11:14:39 -04:00 committed by GitHub
parent 026503fa3b
commit d29b71aa50
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 42 additions and 34 deletions

View file

@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple
import attr
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
@ -158,10 +156,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
request_factory = OneShotRequestFactory()
# Set up the server side protocol
channel = _PushHTTPChannel(self.reactor, request_factory, self.site)
channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site)
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@ -183,7 +179,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
server_to_client_transport.loseConnection()
client_to_server_transport.loseConnection()
return request_factory.request
return channel.request
def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str
@ -237,7 +233,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
if self.hs.config.redis.redis_enabled:
# Handle attempts to connect to fake redis server.
self.reactor.add_tcp_client_callback(
"localhost",
b"localhost",
6379,
self.connect_any_redis_attempts,
)
@ -392,10 +388,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
request_factory = OneShotRequestFactory()
# Set up the server side protocol
channel = _PushHTTPChannel(self.reactor, request_factory, self._hs_to_site[hs])
channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs])
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@ -421,7 +415,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
clients = self.reactor.tcpClients
while clients:
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
self.assertEqual(host, "localhost")
self.assertEqual(host, b"localhost")
self.assertEqual(port, 6379)
client_protocol = client_factory.buildProtocol(None)
@ -453,21 +447,6 @@ class TestReplicationDataHandler(GenericWorkerReplicationHandler):
self.received_rdata_rows.append((stream_name, token, r))
@attr.s()
class OneShotRequestFactory:
"""A simple request factory that generates a single `SynapseRequest` and
stores it for future use. Can only be used once.
"""
request = attr.ib(default=None)
def __call__(self, *args, **kwargs):
assert self.request is None
self.request = SynapseRequest(*args, **kwargs)
return self.request
class _PushHTTPChannel(HTTPChannel):
"""A HTTPChannel that wraps pull producers to push producers.
@ -479,7 +458,7 @@ class _PushHTTPChannel(HTTPChannel):
"""
def __init__(
self, reactor: IReactorTime, request_factory: Callable[..., Request], site: Site
self, reactor: IReactorTime, request_factory: Type[Request], site: Site
):
super().__init__()
self.reactor = reactor
@ -510,6 +489,11 @@ class _PushHTTPChannel(HTTPChannel):
request.responseHeaders.setRawHeaders(b"connection", [b"close"])
return False
def requestDone(self, request):
# Store the request for inspection.
self.request = request
super().requestDone(request)
class _PullToPushProducer:
"""A push producer that wraps a pull producer."""
@ -597,6 +581,8 @@ class FakeRedisPubSubServer:
class FakeRedisPubSubProtocol(Protocol):
"""A connection from a client talking to the fake Redis server."""
transport = None # type: Optional[FakeTransport]
def __init__(self, server: FakeRedisPubSubServer):
self._server = server
self._reader = hiredis.Reader()
@ -641,6 +627,8 @@ class FakeRedisPubSubProtocol(Protocol):
def send(self, msg):
"""Send a message back to the client."""
assert self.transport is not None
raw = self.encode(msg).encode("utf-8")
self.transport.write(raw)