Add more missing type hints to tests. (#15028)

This commit is contained in:
Patrick Cloke 2023-02-08 16:29:49 -05:00 committed by GitHub
parent 4eed7b2ede
commit 30509a1010
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 124 additions and 111 deletions

View file

@ -20,12 +20,13 @@ import sys
import warnings
from asyncio import Future
from binascii import unhexlify
from typing import Awaitable, Callable, Tuple, TypeVar
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar
from unittest.mock import Mock
import attr
import zope.interface
from twisted.internet.interfaces import IProtocol
from twisted.python.failure import Failure
from twisted.web.client import ResponseDone
from twisted.web.http import RESPONSES
@ -34,6 +35,9 @@ from twisted.web.iweb import IResponse
from synapse.types import JsonDict
if TYPE_CHECKING:
from sys import UnraisableHookArgs
TV = TypeVar("TV")
@ -78,25 +82,29 @@ def setup_awaitable_errors() -> Callable[[], None]:
unraisable_exceptions = []
orig_unraisablehook = sys.unraisablehook
def unraisablehook(unraisable):
def unraisablehook(unraisable: "UnraisableHookArgs") -> None:
unraisable_exceptions.append(unraisable.exc_value)
def cleanup():
def cleanup() -> None:
"""
A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions.
"""
sys.unraisablehook = orig_unraisablehook
if unraisable_exceptions:
raise unraisable_exceptions.pop()
exc = unraisable_exceptions.pop()
assert exc is not None
raise exc
sys.unraisablehook = unraisablehook
return cleanup
def simple_async_mock(return_value=None, raises=None) -> Mock:
def simple_async_mock(
return_value: Optional[TV] = None, raises: Optional[Exception] = None
) -> Mock:
# AsyncMock is not available in python3.5, this mimics part of its behaviour
async def cb(*args, **kwargs):
async def cb(*args: Any, **kwargs: Any) -> Optional[TV]:
if raises:
raise raises
return return_value
@ -125,14 +133,14 @@ class FakeResponse: # type: ignore[misc]
headers: Headers = attr.Factory(Headers)
@property
def phrase(self):
def phrase(self) -> bytes:
return RESPONSES.get(self.code, b"Unknown Status")
@property
def length(self):
def length(self) -> int:
return len(self.body)
def deliverBody(self, protocol):
def deliverBody(self, protocol: IProtocol) -> None:
protocol.dataReceived(self.body)
protocol.connectionLost(Failure(ResponseDone()))