Misc typing fixes for tests, part 2 of N (#11330)

This commit is contained in:
David Robertson 2021-11-16 10:41:35 +00:00 committed by GitHub
parent e72135b9d3
commit 0dda1a7968
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 66 additions and 29 deletions

View file

@ -44,6 +44,7 @@ from twisted.python.threadpool import ThreadPool
from twisted.test.proto_helpers import MemoryReactor
from twisted.trial import unittest
from twisted.web.resource import Resource
from twisted.web.server import Request
from synapse import events
from synapse.api.constants import EventTypes, Membership
@ -95,16 +96,13 @@ def around(target):
return _around
T = TypeVar("T")
class TestCase(unittest.TestCase):
"""A subclass of twisted.trial's TestCase which looks for 'loglevel'
attributes on both itself and its individual test methods, to override the
root logger's logging level while that test (case|method) runs."""
def __init__(self, methodName, *args, **kwargs):
super().__init__(methodName, *args, **kwargs)
def __init__(self, methodName: str):
super().__init__(methodName)
method = getattr(self, methodName)
@ -220,16 +218,16 @@ class HomeserverTestCase(TestCase):
Attributes:
servlets: List of servlet registration function.
user_id (str): The user ID to assume if auth is hijacked.
hijack_auth (bool): Whether to hijack auth to return the user specified
hijack_auth: Whether to hijack auth to return the user specified
in user_id.
"""
hijack_auth = True
needs_threadpool = False
hijack_auth: ClassVar[bool] = True
needs_threadpool: ClassVar[bool] = False
servlets: ClassVar[List[RegisterServletsFunc]] = []
def __init__(self, methodName, *args, **kwargs):
super().__init__(methodName, *args, **kwargs)
def __init__(self, methodName: str):
super().__init__(methodName)
# see if we have any additional config for this test
method = getattr(self, methodName)
@ -301,9 +299,10 @@ class HomeserverTestCase(TestCase):
None,
)
self.hs.get_auth().get_user_by_req = get_user_by_req
self.hs.get_auth().get_user_by_access_token = get_user_by_access_token
self.hs.get_auth().get_access_token_from_request = Mock(
# Type ignore: mypy doesn't like us assigning to methods.
self.hs.get_auth().get_user_by_req = get_user_by_req # type: ignore[assignment]
self.hs.get_auth().get_user_by_access_token = get_user_by_access_token # type: ignore[assignment]
self.hs.get_auth().get_access_token_from_request = Mock( # type: ignore[assignment]
return_value="1234"
)
@ -417,7 +416,7 @@ class HomeserverTestCase(TestCase):
path: Union[bytes, str],
content: Union[bytes, str, JsonDict] = b"",
access_token: Optional[str] = None,
request: Type[T] = SynapseRequest,
request: Type[Request] = SynapseRequest,
shorthand: bool = True,
federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False,
@ -596,7 +595,7 @@ class HomeserverTestCase(TestCase):
nonce_str += b"\x00notadmin"
want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str)
want_mac = want_mac.hexdigest()
want_mac_digest = want_mac.hexdigest()
body = json.dumps(
{
@ -605,7 +604,7 @@ class HomeserverTestCase(TestCase):
"displayname": displayname,
"password": password,
"admin": admin,
"mac": want_mac,
"mac": want_mac_digest,
"inhibit_login": True,
}
)