Add missing type hints for tests.unittest. (#13397)

This commit is contained in:
Patrick Cloke 2022-07-27 13:18:41 -04:00 committed by GitHub
parent 502f075e96
commit 922b771337
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 66 additions and 52 deletions

1
changelog.d/13397.misc Normal file
View File

@ -0,0 +1 @@
Adding missing type hints to tests.

View File

@ -481,17 +481,13 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
return config return config
def prepare( def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
) -> HomeServer:
self.allowed_user_id = self.register_user(self.allowed_localpart, "pass") self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
self.allowed_access_token = self.login(self.allowed_localpart, "pass") self.allowed_access_token = self.login(self.allowed_localpart, "pass")
self.denied_user_id = self.register_user("denied", "pass") self.denied_user_id = self.register_user("denied", "pass")
self.denied_access_token = self.login("denied", "pass") self.denied_access_token = self.login("denied", "pass")
return hs
def test_denied_without_publication_permission(self) -> None: def test_denied_without_publication_permission(self) -> None:
""" """
Try to create a room, register an alias for it, and publish it, Try to create a room, register an alias for it, and publish it,
@ -575,9 +571,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
servlets = [directory.register_servlets, room.register_servlets] servlets = [directory.register_servlets, room.register_servlets]
def prepare( def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
) -> HomeServer:
room_id = self.helper.create_room_as(self.user_id) room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request( channel = self.make_request(
@ -588,8 +582,6 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
self.room_list_handler = hs.get_room_list_handler() self.room_list_handler = hs.get_room_list_handler()
self.directory_handler = hs.get_directory_handler() self.directory_handler = hs.get_directory_handler()
return hs
def test_disabling_room_list(self) -> None: def test_disabling_room_list(self) -> None:
self.room_list_handler.enable_room_list_search = True self.room_list_handler.enable_room_list_search = True
self.directory_handler.enable_room_list_search = True self.directory_handler.enable_room_list_search = True

View File

@ -1060,6 +1060,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
participated, bundled_aggregations.get("current_user_participated") participated, bundled_aggregations.get("current_user_participated")
) )
# The latest thread event has some fields that don't matter. # The latest thread event has some fields that don't matter.
self.assertIn("latest_event", bundled_aggregations)
self.assert_dict( self.assert_dict(
{ {
"content": { "content": {
@ -1072,7 +1073,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
"sender": self.user2_id, "sender": self.user2_id,
"type": "m.room.test", "type": "m.room.test",
}, },
bundled_aggregations.get("latest_event"), bundled_aggregations["latest_event"],
) )
return assert_thread return assert_thread
@ -1112,6 +1113,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
self.assertEqual(2, bundled_aggregations.get("count")) self.assertEqual(2, bundled_aggregations.get("count"))
self.assertTrue(bundled_aggregations.get("current_user_participated")) self.assertTrue(bundled_aggregations.get("current_user_participated"))
# The latest thread event has some fields that don't matter. # The latest thread event has some fields that don't matter.
self.assertIn("latest_event", bundled_aggregations)
self.assert_dict( self.assert_dict(
{ {
"content": { "content": {
@ -1124,7 +1126,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
"sender": self.user_id, "sender": self.user_id,
"type": "m.room.test", "type": "m.room.test",
}, },
bundled_aggregations.get("latest_event"), bundled_aggregations["latest_event"],
) )
# Check the unsigned field on the latest event. # Check the unsigned field on the latest event.
self.assert_dict( self.assert_dict(

View File

@ -496,7 +496,7 @@ class RoomStateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"]) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertCountEqual( self.assertCountEqual(
[state_event["type"] for state_event in channel.json_body], [state_event["type"] for state_event in channel.json_list],
{ {
"m.room.create", "m.room.create",
"m.room.power_levels", "m.room.power_levels",

View File

@ -25,6 +25,7 @@ from typing import (
Callable, Callable,
Dict, Dict,
Iterable, Iterable,
List,
MutableMapping, MutableMapping,
Optional, Optional,
Tuple, Tuple,
@ -121,7 +122,15 @@ class FakeChannel:
@property @property
def json_body(self) -> JsonDict: def json_body(self) -> JsonDict:
return json.loads(self.text_body) body = json.loads(self.text_body)
assert isinstance(body, dict)
return body
@property
def json_list(self) -> List[JsonDict]:
body = json.loads(self.text_body)
assert isinstance(body, list)
return body
@property @property
def text_body(self) -> str: def text_body(self) -> str:

View File

@ -28,6 +28,7 @@ from typing import (
Generic, Generic,
Iterable, Iterable,
List, List,
NoReturn,
Optional, Optional,
Tuple, Tuple,
Type, Type,
@ -39,7 +40,7 @@ from unittest.mock import Mock, patch
import canonicaljson import canonicaljson
import signedjson.key import signedjson.key
import unpaddedbase64 import unpaddedbase64
from typing_extensions import Protocol from typing_extensions import Concatenate, ParamSpec, Protocol
from twisted.internet.defer import Deferred, ensureDeferred from twisted.internet.defer import Deferred, ensureDeferred
from twisted.python.failure import Failure from twisted.python.failure import Failure
@ -67,7 +68,7 @@ from synapse.logging.context import (
from synapse.rest import RegisterServletsFunc from synapse.rest import RegisterServletsFunc
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.keys import FetchKeyResult from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict, UserID, create_requester from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util import Clock from synapse.util import Clock
from synapse.util.httpresourcetree import create_resource_tree from synapse.util.httpresourcetree import create_resource_tree
@ -88,6 +89,10 @@ setup_logging()
TV = TypeVar("TV") TV = TypeVar("TV")
_ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True) _ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True)
P = ParamSpec("P")
R = TypeVar("R")
S = TypeVar("S")
class _TypedFailure(Generic[_ExcType], Protocol): class _TypedFailure(Generic[_ExcType], Protocol):
"""Extension to twisted.Failure, where the 'value' has a certain type.""" """Extension to twisted.Failure, where the 'value' has a certain type."""
@ -97,7 +102,7 @@ class _TypedFailure(Generic[_ExcType], Protocol):
... ...
def around(target): def around(target: TV) -> Callable[[Callable[Concatenate[S, P], R]], None]:
"""A CLOS-style 'around' modifier, which wraps the original method of the """A CLOS-style 'around' modifier, which wraps the original method of the
given instance with another piece of code. given instance with another piece of code.
@ -106,11 +111,11 @@ def around(target):
return orig(*args, **kwargs) return orig(*args, **kwargs)
""" """
def _around(code): def _around(code: Callable[Concatenate[S, P], R]) -> None:
name = code.__name__ name = code.__name__
orig = getattr(target, name) orig = getattr(target, name)
def new(*args, **kwargs): def new(*args: P.args, **kwargs: P.kwargs) -> R:
return code(orig, *args, **kwargs) return code(orig, *args, **kwargs)
setattr(target, name, new) setattr(target, name, new)
@ -131,7 +136,7 @@ class TestCase(unittest.TestCase):
level = getattr(method, "loglevel", getattr(self, "loglevel", None)) level = getattr(method, "loglevel", getattr(self, "loglevel", None))
@around(self) @around(self)
def setUp(orig): def setUp(orig: Callable[[], R]) -> R:
# if we're not starting in the sentinel logcontext, then to be honest # if we're not starting in the sentinel logcontext, then to be honest
# all future bets are off. # all future bets are off.
if current_context(): if current_context():
@ -144,7 +149,7 @@ class TestCase(unittest.TestCase):
if level is not None and old_level != level: if level is not None and old_level != level:
@around(self) @around(self)
def tearDown(orig): def tearDown(orig: Callable[[], R]) -> R:
ret = orig() ret = orig()
logging.getLogger().setLevel(old_level) logging.getLogger().setLevel(old_level)
return ret return ret
@ -158,7 +163,7 @@ class TestCase(unittest.TestCase):
return orig() return orig()
@around(self) @around(self)
def tearDown(orig): def tearDown(orig: Callable[[], R]) -> R:
ret = orig() ret = orig()
# force a GC to workaround problems with deferreds leaking logcontexts when # force a GC to workaround problems with deferreds leaking logcontexts when
# they are GCed (see the logcontext docs) # they are GCed (see the logcontext docs)
@ -167,7 +172,7 @@ class TestCase(unittest.TestCase):
return ret return ret
def assertObjectHasAttributes(self, attrs, obj): def assertObjectHasAttributes(self, attrs: Dict[str, object], obj: object) -> None:
"""Asserts that the given object has each of the attributes given, and """Asserts that the given object has each of the attributes given, and
that the value of each matches according to assertEqual.""" that the value of each matches according to assertEqual."""
for key in attrs.keys(): for key in attrs.keys():
@ -178,12 +183,12 @@ class TestCase(unittest.TestCase):
except AssertionError as e: except AssertionError as e:
raise (type(e))(f"Assert error for '.{key}':") from e raise (type(e))(f"Assert error for '.{key}':") from e
def assert_dict(self, required, actual): def assert_dict(self, required: dict, actual: dict) -> None:
"""Does a partial assert of a dict. """Does a partial assert of a dict.
Args: Args:
required (dict): The keys and value which MUST be in 'actual'. required: The keys and value which MUST be in 'actual'.
actual (dict): The test result. Extra keys will not be checked. actual: The test result. Extra keys will not be checked.
""" """
for key in required: for key in required:
self.assertEqual( self.assertEqual(
@ -191,31 +196,31 @@ class TestCase(unittest.TestCase):
) )
def DEBUG(target): def DEBUG(target: TV) -> TV:
"""A decorator to set the .loglevel attribute to logging.DEBUG. """A decorator to set the .loglevel attribute to logging.DEBUG.
Can apply to either a TestCase or an individual test method.""" Can apply to either a TestCase or an individual test method."""
target.loglevel = logging.DEBUG target.loglevel = logging.DEBUG # type: ignore[attr-defined]
return target return target
def INFO(target): def INFO(target: TV) -> TV:
"""A decorator to set the .loglevel attribute to logging.INFO. """A decorator to set the .loglevel attribute to logging.INFO.
Can apply to either a TestCase or an individual test method.""" Can apply to either a TestCase or an individual test method."""
target.loglevel = logging.INFO target.loglevel = logging.INFO # type: ignore[attr-defined]
return target return target
def logcontext_clean(target): def logcontext_clean(target: TV) -> TV:
"""A decorator which marks the TestCase or method as 'logcontext_clean' """A decorator which marks the TestCase or method as 'logcontext_clean'
... ie, any logcontext errors should cause a test failure ... ie, any logcontext errors should cause a test failure
""" """
def logcontext_error(msg): def logcontext_error(msg: str) -> NoReturn:
raise AssertionError("logcontext error: %s" % (msg)) raise AssertionError("logcontext error: %s" % (msg))
patcher = patch("synapse.logging.context.logcontext_error", new=logcontext_error) patcher = patch("synapse.logging.context.logcontext_error", new=logcontext_error)
return patcher(target) return patcher(target) # type: ignore[call-overload]
class HomeserverTestCase(TestCase): class HomeserverTestCase(TestCase):
@ -255,7 +260,7 @@ class HomeserverTestCase(TestCase):
method = getattr(self, methodName) method = getattr(self, methodName)
self._extra_config = getattr(method, "_extra_config", None) self._extra_config = getattr(method, "_extra_config", None)
def setUp(self): def setUp(self) -> None:
""" """
Set up the TestCase by calling the homeserver constructor, optionally Set up the TestCase by calling the homeserver constructor, optionally
hijacking the authentication system to return a fixed user, and then hijacking the authentication system to return a fixed user, and then
@ -306,7 +311,9 @@ class HomeserverTestCase(TestCase):
) )
) )
async def get_user_by_access_token(token=None, allow_guest=False): async def get_user_by_access_token(
token: Optional[str] = None, allow_guest: bool = False
) -> JsonDict:
assert self.helper.auth_user_id is not None assert self.helper.auth_user_id is not None
return { return {
"user": UserID.from_string(self.helper.auth_user_id), "user": UserID.from_string(self.helper.auth_user_id),
@ -314,7 +321,11 @@ class HomeserverTestCase(TestCase):
"is_guest": False, "is_guest": False,
} }
async def get_user_by_req(request, allow_guest=False): async def get_user_by_req(
request: SynapseRequest,
allow_guest: bool = False,
allow_expired: bool = False,
) -> Requester:
assert self.helper.auth_user_id is not None assert self.helper.auth_user_id is not None
return create_requester( return create_requester(
UserID.from_string(self.helper.auth_user_id), UserID.from_string(self.helper.auth_user_id),
@ -339,11 +350,11 @@ class HomeserverTestCase(TestCase):
if hasattr(self, "prepare"): if hasattr(self, "prepare"):
self.prepare(self.reactor, self.clock, self.hs) self.prepare(self.reactor, self.clock, self.hs)
def tearDown(self): def tearDown(self) -> None:
# Reset to not use frozen dicts. # Reset to not use frozen dicts.
events.USE_FROZEN_DICTS = False events.USE_FROZEN_DICTS = False
def wait_on_thread(self, deferred, timeout=10): def wait_on_thread(self, deferred: Deferred, timeout: int = 10) -> None:
""" """
Wait until a Deferred is done, where it's waiting on a real thread. Wait until a Deferred is done, where it's waiting on a real thread.
""" """
@ -374,7 +385,7 @@ class HomeserverTestCase(TestCase):
clock (synapse.util.Clock): The Clock, associated with the reactor. clock (synapse.util.Clock): The Clock, associated with the reactor.
Returns: Returns:
A homeserver (synapse.server.HomeServer) suitable for testing. A homeserver suitable for testing.
Function to be overridden in subclasses. Function to be overridden in subclasses.
""" """
@ -408,7 +419,7 @@ class HomeserverTestCase(TestCase):
"/_synapse/admin": servlet_resource, "/_synapse/admin": servlet_resource,
} }
def default_config(self): def default_config(self) -> JsonDict:
""" """
Get a default HomeServer config dict. Get a default HomeServer config dict.
""" """
@ -421,7 +432,9 @@ class HomeserverTestCase(TestCase):
return config return config
def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
""" """
Prepare for the test. This involves things like mocking out parts of Prepare for the test. This involves things like mocking out parts of
the homeserver, or building test data common across the whole test the homeserver, or building test data common across the whole test
@ -519,7 +532,7 @@ class HomeserverTestCase(TestCase):
config_obj.parse_config_dict(config, "", "") config_obj.parse_config_dict(config, "", "")
kwargs["config"] = config_obj kwargs["config"] = config_obj
async def run_bg_updates(): async def run_bg_updates() -> None:
with LoggingContext("run_bg_updates"): with LoggingContext("run_bg_updates"):
self.get_success(stor.db_pool.updates.run_background_updates(False)) self.get_success(stor.db_pool.updates.run_background_updates(False))
@ -538,11 +551,7 @@ class HomeserverTestCase(TestCase):
""" """
self.reactor.pump([by] * 100) self.reactor.pump([by] * 100)
def get_success( def get_success(self, d: Awaitable[TV], by: float = 0.0) -> TV:
self,
d: Awaitable[TV],
by: float = 0.0,
) -> TV:
deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type] deferred: Deferred[TV] = ensureDeferred(d) # type: ignore[arg-type]
self.pump(by=by) self.pump(by=by)
return self.successResultOf(deferred) return self.successResultOf(deferred)
@ -755,7 +764,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
OTHER_SERVER_NAME = "other.example.com" OTHER_SERVER_NAME = "other.example.com"
OTHER_SERVER_SIGNATURE_KEY = signedjson.key.generate_signing_key("test") OTHER_SERVER_SIGNATURE_KEY = signedjson.key.generate_signing_key("test")
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs) super().prepare(reactor, clock, hs)
# poke the other server's signing key into the key store, so that we don't # poke the other server's signing key into the key store, so that we don't
@ -879,7 +888,7 @@ def _auth_header_for_request(
) )
def override_config(extra_config): def override_config(extra_config: JsonDict) -> Callable[[TV], TV]:
"""A decorator which can be applied to test functions to give additional HS config """A decorator which can be applied to test functions to give additional HS config
For use For use
@ -892,12 +901,13 @@ def override_config(extra_config):
... ...
Args: Args:
extra_config(dict): Additional config settings to be merged into the default extra_config: Additional config settings to be merged into the default
config dict before instantiating the test homeserver. config dict before instantiating the test homeserver.
""" """
def decorator(func): def decorator(func: TV) -> TV:
func._extra_config = extra_config # This attribute is being defined.
func._extra_config = extra_config # type: ignore[attr-defined]
return func return func
return decorator return decorator