Misc typing fixes for tests, part 2 of N (#11330)

This commit is contained in:
David Robertson 2021-11-16 10:41:35 +00:00 committed by GitHub
parent e72135b9d3
commit 0dda1a7968
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 66 additions and 29 deletions

1
changelog.d/11330.misc Normal file
View File

@ -0,0 +1 @@
Improve type annotations in Synapse's test suite.

View File

@ -193,7 +193,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": True}) @override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_not_blocked(self): def test_get_or_create_user_mau_not_blocked(self):
self.store.count_monthly_users = Mock( # Type ignore: mypy doesn't like us assigning to methods.
self.store.count_monthly_users = Mock( # type: ignore[assignment]
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
) )
# Ensure does not throw exception # Ensure does not throw exception
@ -201,7 +202,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": True}) @override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_blocked(self): def test_get_or_create_user_mau_blocked(self):
self.store.get_monthly_active_count = Mock( # Type ignore: mypy doesn't like us assigning to methods.
self.store.get_monthly_active_count = Mock( # type: ignore[assignment]
return_value=make_awaitable(self.lots_of_users) return_value=make_awaitable(self.lots_of_users)
) )
self.get_failure( self.get_failure(
@ -209,7 +211,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
ResourceLimitError, ResourceLimitError,
) )
self.store.get_monthly_active_count = Mock( # Type ignore: mypy doesn't like us assigning to methods.
self.store.get_monthly_active_count = Mock( # type: ignore[assignment]
return_value=make_awaitable(self.hs.config.server.max_mau_value) return_value=make_awaitable(self.hs.config.server.max_mau_value)
) )
self.get_failure( self.get_failure(

View File

@ -28,11 +28,12 @@ from typing import (
MutableMapping, MutableMapping,
Optional, Optional,
Tuple, Tuple,
Union, overload,
) )
from unittest.mock import patch from unittest.mock import patch
import attr import attr
from typing_extensions import Literal
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import Site from twisted.web.server import Site
@ -55,6 +56,32 @@ class RestHelper:
site = attr.ib(type=Site) site = attr.ib(type=Site)
auth_user_id = attr.ib() auth_user_id = attr.ib()
@overload
def create_room_as(
self,
room_creator: Optional[str] = ...,
is_public: Optional[bool] = ...,
room_version: Optional[str] = ...,
tok: Optional[str] = ...,
expect_code: Literal[200] = ...,
extra_content: Optional[Dict] = ...,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
) -> str:
...
@overload
def create_room_as(
self,
room_creator: Optional[str] = ...,
is_public: Optional[bool] = ...,
room_version: Optional[str] = ...,
tok: Optional[str] = ...,
expect_code: int = ...,
extra_content: Optional[Dict] = ...,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
) -> Optional[str]:
...
def create_room_as( def create_room_as(
self, self,
room_creator: Optional[str] = None, room_creator: Optional[str] = None,
@ -64,7 +91,7 @@ class RestHelper:
expect_code: int = 200, expect_code: int = 200,
extra_content: Optional[Dict] = None, extra_content: Optional[Dict] = None,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None, custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
) -> str: ) -> Optional[str]:
""" """
Create a room. Create a room.
@ -107,6 +134,8 @@ class RestHelper:
if expect_code == 200: if expect_code == 200:
return channel.json_body["room_id"] return channel.json_body["room_id"]
else:
return None
def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None): def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
self.change_membership( self.change_membership(
@ -176,7 +205,7 @@ class RestHelper:
extra_data: Optional[dict] = None, extra_data: Optional[dict] = None,
tok: Optional[str] = None, tok: Optional[str] = None,
expect_code: int = 200, expect_code: int = 200,
expect_errcode: str = None, expect_errcode: Optional[str] = None,
) -> None: ) -> None:
""" """
Send a membership state event into a room. Send a membership state event into a room.
@ -260,9 +289,7 @@ class RestHelper:
txn_id=None, txn_id=None,
tok=None, tok=None,
expect_code=200, expect_code=200,
custom_headers: Optional[ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
): ):
if txn_id is None: if txn_id is None:
txn_id = "m%s" % (str(time.time())) txn_id = "m%s" % (str(time.time()))
@ -509,7 +536,7 @@ class RestHelper:
went. went.
""" """
cookies = {} cookies: Dict[str, str] = {}
# if we're doing a ui auth, hit the ui auth redirect endpoint # if we're doing a ui auth, hit the ui auth redirect endpoint
if ui_auth_session_id: if ui_auth_session_id:
@ -631,7 +658,13 @@ class RestHelper:
# hit the redirect url again with the right Host header, which should now issue # hit the redirect url again with the right Host header, which should now issue
# a cookie and redirect to the SSO provider. # a cookie and redirect to the SSO provider.
location = channel.headers.getRawHeaders("Location")[0] def get_location(channel: FakeChannel) -> str:
location_values = channel.headers.getRawHeaders("Location")
# Keep mypy happy by asserting that location_values is nonempty
assert location_values
return location_values[0]
location = get_location(channel)
parts = urllib.parse.urlsplit(location) parts = urllib.parse.urlsplit(location)
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.hs.get_reactor(),
@ -645,7 +678,7 @@ class RestHelper:
assert channel.code == 302 assert channel.code == 302
channel.extract_cookies(cookies) channel.extract_cookies(cookies)
return channel.headers.getRawHeaders("Location")[0] return get_location(channel)
def initiate_sso_ui_auth( def initiate_sso_ui_auth(
self, ui_auth_session_id: str, cookies: MutableMapping[str, str] self, ui_auth_session_id: str, cookies: MutableMapping[str, str]

View File

@ -24,6 +24,7 @@ from typing import (
MutableMapping, MutableMapping,
Optional, Optional,
Tuple, Tuple,
Type,
Union, Union,
) )
@ -226,7 +227,7 @@ def make_request(
path: Union[bytes, str], path: Union[bytes, str],
content: Union[bytes, str, JsonDict] = b"", content: Union[bytes, str, JsonDict] = b"",
access_token: Optional[str] = None, access_token: Optional[str] = None,
request: Request = SynapseRequest, request: Type[Request] = SynapseRequest,
shorthand: bool = True, shorthand: bool = True,
federation_auth_origin: Optional[bytes] = None, federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False, content_is_form: bool = False,

View File

@ -44,6 +44,7 @@ from twisted.python.threadpool import ThreadPool
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from twisted.trial import unittest from twisted.trial import unittest
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import Request
from synapse import events from synapse import events
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
@ -95,16 +96,13 @@ def around(target):
return _around return _around
T = TypeVar("T")
class TestCase(unittest.TestCase): class TestCase(unittest.TestCase):
"""A subclass of twisted.trial's TestCase which looks for 'loglevel' """A subclass of twisted.trial's TestCase which looks for 'loglevel'
attributes on both itself and its individual test methods, to override the attributes on both itself and its individual test methods, to override the
root logger's logging level while that test (case|method) runs.""" root logger's logging level while that test (case|method) runs."""
def __init__(self, methodName, *args, **kwargs): def __init__(self, methodName: str):
super().__init__(methodName, *args, **kwargs) super().__init__(methodName)
method = getattr(self, methodName) method = getattr(self, methodName)
@ -220,16 +218,16 @@ class HomeserverTestCase(TestCase):
Attributes: Attributes:
servlets: List of servlet registration function. servlets: List of servlet registration function.
user_id (str): The user ID to assume if auth is hijacked. user_id (str): The user ID to assume if auth is hijacked.
hijack_auth (bool): Whether to hijack auth to return the user specified hijack_auth: Whether to hijack auth to return the user specified
in user_id. in user_id.
""" """
hijack_auth = True hijack_auth: ClassVar[bool] = True
needs_threadpool = False needs_threadpool: ClassVar[bool] = False
servlets: ClassVar[List[RegisterServletsFunc]] = [] servlets: ClassVar[List[RegisterServletsFunc]] = []
def __init__(self, methodName, *args, **kwargs): def __init__(self, methodName: str):
super().__init__(methodName, *args, **kwargs) super().__init__(methodName)
# see if we have any additional config for this test # see if we have any additional config for this test
method = getattr(self, methodName) method = getattr(self, methodName)
@ -301,9 +299,10 @@ class HomeserverTestCase(TestCase):
None, None,
) )
self.hs.get_auth().get_user_by_req = get_user_by_req # Type ignore: mypy doesn't like us assigning to methods.
self.hs.get_auth().get_user_by_access_token = get_user_by_access_token self.hs.get_auth().get_user_by_req = get_user_by_req # type: ignore[assignment]
self.hs.get_auth().get_access_token_from_request = Mock( self.hs.get_auth().get_user_by_access_token = get_user_by_access_token # type: ignore[assignment]
self.hs.get_auth().get_access_token_from_request = Mock( # type: ignore[assignment]
return_value="1234" return_value="1234"
) )
@ -417,7 +416,7 @@ class HomeserverTestCase(TestCase):
path: Union[bytes, str], path: Union[bytes, str],
content: Union[bytes, str, JsonDict] = b"", content: Union[bytes, str, JsonDict] = b"",
access_token: Optional[str] = None, access_token: Optional[str] = None,
request: Type[T] = SynapseRequest, request: Type[Request] = SynapseRequest,
shorthand: bool = True, shorthand: bool = True,
federation_auth_origin: Optional[bytes] = None, federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False, content_is_form: bool = False,
@ -596,7 +595,7 @@ class HomeserverTestCase(TestCase):
nonce_str += b"\x00notadmin" nonce_str += b"\x00notadmin"
want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str) want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str)
want_mac = want_mac.hexdigest() want_mac_digest = want_mac.hexdigest()
body = json.dumps( body = json.dumps(
{ {
@ -605,7 +604,7 @@ class HomeserverTestCase(TestCase):
"displayname": displayname, "displayname": displayname,
"password": password, "password": password,
"admin": admin, "admin": admin,
"mac": want_mac, "mac": want_mac_digest,
"inhibit_login": True, "inhibit_login": True,
} }
) )