Add missing type hints to tests.handlers. (#14680)

And do not allow untyped defs in tests.handlers.
This commit is contained in:
Patrick Cloke 2022-12-16 06:53:01 -05:00 committed by GitHub
parent 54c012c5a8
commit 652d1669c5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 527 additions and 378 deletions

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Dict, Tuple
from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple
from unittest.mock import ANY, Mock, patch
from urllib.parse import parse_qs, urlparse
@ -23,7 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.handlers.sso import MappingException
from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
from synapse.types import UserID
from synapse.types import JsonDict, UserID
from synapse.util import Clock
from synapse.util.macaroons import get_value_from_macaroon
from synapse.util.stringutils import random_string
@ -34,6 +34,10 @@ from tests.unittest import HomeserverTestCase, override_config
try:
import authlib # noqa: F401
from authlib.oidc.core import UserInfo
from authlib.oidc.discovery import OpenIDProviderMetadata
from synapse.handlers.oidc import Token, UserAttributeDict
HAS_OIDC = True
except ImportError:
@ -70,29 +74,37 @@ EXPLICIT_ENDPOINT_CONFIG = {
class TestMappingProvider:
@staticmethod
def parse_config(config):
return
def parse_config(config: JsonDict) -> None:
return None
def __init__(self, config):
def __init__(self, config: None):
pass
def get_remote_user_id(self, userinfo):
def get_remote_user_id(self, userinfo: "UserInfo") -> str:
return userinfo["sub"]
async def map_user_attributes(self, userinfo, token):
return {"localpart": userinfo["username"], "display_name": None}
async def map_user_attributes(
self, userinfo: "UserInfo", token: "Token"
) -> "UserAttributeDict":
# This is testing not providing the full map.
return {"localpart": userinfo["username"], "display_name": None} # type: ignore[typeddict-item]
# Do not include get_extra_attributes to test backwards compatibility paths.
class TestMappingProviderExtra(TestMappingProvider):
async def get_extra_attributes(self, userinfo, token):
async def get_extra_attributes(
self, userinfo: "UserInfo", token: "Token"
) -> JsonDict:
return {"phone": userinfo["phone"]}
class TestMappingProviderFailures(TestMappingProvider):
async def map_user_attributes(self, userinfo, token, failures):
return {
# Superclass is testing the legacy interface for map_user_attributes.
async def map_user_attributes( # type: ignore[override]
self, userinfo: "UserInfo", token: "Token", failures: int
) -> "UserAttributeDict":
return { # type: ignore[typeddict-item]
"localpart": userinfo["username"] + (str(failures) if failures else ""),
"display_name": None,
}
@ -161,13 +173,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.hs_patcher.stop()
return super().tearDown()
def reset_mocks(self):
def reset_mocks(self) -> None:
"""Reset all the Mocks."""
self.fake_server.reset_mocks()
self.render_error.reset_mock()
self.complete_sso_login.reset_mock()
def metadata_edit(self, values):
def metadata_edit(self, values: dict) -> ContextManager[Mock]:
"""Modify the result that will be returned by the well-known query"""
metadata = self.fake_server.get_metadata()
@ -196,7 +208,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
return _build_callback_request(code, state, session), grant
def assertRenderedError(self, error, error_description=None):
def assertRenderedError(
self, error: str, error_description: Optional[str] = None
) -> Tuple[Any, ...]:
self.render_error.assert_called_once()
args = self.render_error.call_args[0]
self.assertEqual(args[1], error)
@ -273,8 +287,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
"""Provider metadatas are extensively validated."""
h = self.provider
def force_load_metadata():
async def force_load():
def force_load_metadata() -> Awaitable[None]:
async def force_load() -> "OpenIDProviderMetadata":
return await h.load_metadata(force=True)
return get_awaitable_result(force_load())
@ -1198,7 +1212,7 @@ def _build_callback_request(
state: str,
session: str,
ip_address: str = "10.0.0.1",
):
) -> Mock:
"""Builds a fake SynapseRequest to mock the browser callback
Returns a Mock object which looks like the SynapseRequest we get from a browser