Add missing type hints to non-client REST servlets. (#10817)

Including admin, consent, key, synapse, and media. All REST servlets
(the synapse.rest module) now require typed method definitions.
This commit is contained in:
Patrick Cloke 2021-09-15 08:45:32 -04:00 committed by GitHub
parent 8c7a531e27
commit b93259082c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 169 additions and 96 deletions

View File

@ -1 +1 @@
Convert the internal `FileInfo` class to attrs and add type hints.
Add missing type hints to REST servlets.

1
changelog.d/10817.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints to REST servlets.

View File

@ -90,7 +90,7 @@ files =
tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py
[mypy-synapse.rest.client.*]
[mypy-synapse.rest.*]
disallow_untyped_defs = True
[mypy-synapse.util.batching_queue]

View File

@ -12,7 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.server import JsonResource
from typing import TYPE_CHECKING
from synapse.http.server import HttpServer, JsonResource
from synapse.rest import admin
from synapse.rest.client import (
account,
@ -57,6 +59,9 @@ from synapse.rest.client import (
voip,
)
if TYPE_CHECKING:
from synapse.server import HomeServer
class ClientRestResource(JsonResource):
"""Matrix Client API REST resource.
@ -68,12 +73,12 @@ class ClientRestResource(JsonResource):
* etc
"""
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
JsonResource.__init__(self, hs, canonical_json=False)
self.register_servlets(self, hs)
@staticmethod
def register_servlets(client_resource, hs):
def register_servlets(client_resource: HttpServer, hs: "HomeServer") -> None:
versions.register_servlets(hs, client_resource)
# Deprecated in r0

View File

@ -47,7 +47,7 @@ class DeviceRestServlet(RestServlet):
self.store = hs.get_datastore()
async def on_GET(
self, request: SynapseRequest, user_id, device_id: str
self, request: SynapseRequest, user_id: str, device_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)

View File

@ -57,7 +57,7 @@ class SendServerNoticeServlet(RestServlet):
self.admin_handler = hs.get_admin_handler()
self.txns = HttpTransactionCache(hs)
def register(self, json_resource: HttpServer):
def register(self, json_resource: HttpServer) -> None:
PATTERN = "/send_server_notice"
json_resource.register_paths(
"POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__

View File

@ -419,7 +419,7 @@ class UserRegisterServlet(RestServlet):
self.nonces: Dict[str, int] = {}
self.hs = hs
def _clear_old_nonces(self):
def _clear_old_nonces(self) -> None:
"""
Clear out old nonces that are older than NONCE_TIMEOUT.
"""

View File

@ -17,17 +17,22 @@ import logging
from hashlib import sha256
from http import HTTPStatus
from os import path
from typing import Dict, List
from typing import TYPE_CHECKING, Any, Dict, List
import jinja2
from jinja2 import TemplateNotFound
from twisted.web.server import Request
from synapse.api.errors import NotFoundError, StoreError, SynapseError
from synapse.config import ConfigError
from synapse.http.server import DirectServeHtmlResource, respond_with_html
from synapse.http.servlet import parse_bytes_from_args, parse_string
from synapse.types import UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
# language to use for the templates. TODO: figure this out from Accept-Language
TEMPLATE_LANGUAGE = "en"
@ -69,11 +74,7 @@ class ConsentResource(DirectServeHtmlResource):
against the user.
"""
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): homeserver
"""
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
@ -106,18 +107,14 @@ class ConsentResource(DirectServeHtmlResource):
self._hmac_secret = hs.config.form_secret.encode("utf-8")
async def _async_render_GET(self, request):
"""
Args:
request (twisted.web.http.Request):
"""
async def _async_render_GET(self, request: Request) -> None:
version = parse_string(request, "v", default=self._default_consent_version)
username = parse_string(request, "u", default="")
userhmac = None
has_consented = False
public_version = username == ""
if not public_version:
args: Dict[bytes, List[bytes]] = request.args
args: Dict[bytes, List[bytes]] = request.args # type: ignore
userhmac_bytes = parse_bytes_from_args(args, "h", required=True)
self._check_hash(username, userhmac_bytes)
@ -147,14 +144,10 @@ class ConsentResource(DirectServeHtmlResource):
except TemplateNotFound:
raise NotFoundError("Unknown policy version")
async def _async_render_POST(self, request):
"""
Args:
request (twisted.web.http.Request):
"""
async def _async_render_POST(self, request: Request) -> None:
version = parse_string(request, "v", required=True)
username = parse_string(request, "u", required=True)
args: Dict[bytes, List[bytes]] = request.args
args: Dict[bytes, List[bytes]] = request.args # type: ignore
userhmac = parse_bytes_from_args(args, "h", required=True)
self._check_hash(username, userhmac)
@ -177,7 +170,9 @@ class ConsentResource(DirectServeHtmlResource):
except TemplateNotFound:
raise NotFoundError("success.html not found")
def _render_template(self, request, template_name, **template_args):
def _render_template(
self, request: Request, template_name: str, **template_args: Any
) -> None:
# get_template checks for ".." so we don't need to worry too much
# about path traversal here.
template_html = self._jinja_env.get_template(
@ -186,11 +181,11 @@ class ConsentResource(DirectServeHtmlResource):
html = template_html.render(**template_args)
respond_with_html(request, 200, html)
def _check_hash(self, userid, userhmac):
def _check_hash(self, userid: str, userhmac: bytes) -> None:
"""
Args:
userid (unicode):
userhmac (bytes):
userid:
userhmac:
Raises:
SynapseError if the hash doesn't match

View File

@ -13,6 +13,7 @@
# limitations under the License.
from twisted.web.resource import Resource
from twisted.web.server import Request
class HealthResource(Resource):
@ -25,6 +26,6 @@ class HealthResource(Resource):
isLeaf = 1
def render_GET(self, request):
def render_GET(self, request: Request) -> bytes:
request.setHeader(b"Content-Type", b"text/plain")
return b"OK"

View File

@ -12,14 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from twisted.web.resource import Resource
from .local_key_resource import LocalKey
from .remote_key_resource import RemoteKey
if TYPE_CHECKING:
from synapse.server import HomeServer
class KeyApiV2Resource(Resource):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self.putChild(b"server", LocalKey(hs))
self.putChild(b"query", RemoteKey(hs))

View File

@ -12,16 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from canonicaljson import encode_canonical_json
from signedjson.sign import sign_json
from unpaddedbase64 import encode_base64
from twisted.web.resource import Resource
from twisted.web.server import Request
from synapse.http.server import respond_with_json_bytes
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -58,18 +63,18 @@ class LocalKey(Resource):
isLeaf = True
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.config = hs.config
self.clock = hs.get_clock()
self.update_response_body(self.clock.time_msec())
Resource.__init__(self)
def update_response_body(self, time_now_msec):
def update_response_body(self, time_now_msec: int) -> None:
refresh_interval = self.config.key_refresh_interval
self.valid_until_ts = int(time_now_msec + refresh_interval)
self.response_body = encode_canonical_json(self.response_json_object())
def response_json_object(self):
def response_json_object(self) -> JsonDict:
verify_keys = {}
for key in self.config.signing_key:
verify_key_bytes = key.verify_key.encode()
@ -94,7 +99,7 @@ class LocalKey(Resource):
json_object = sign_json(json_object, self.config.server.server_name, key)
return json_object
def render_GET(self, request):
def render_GET(self, request: Request) -> int:
time_now = self.clock.time_msec()
# Update the expiry time if less than half the interval remains.
if time_now + self.config.key_refresh_interval / 2 > self.valid_until_ts:

View File

@ -13,17 +13,23 @@
# limitations under the License.
import logging
from typing import Dict
from typing import TYPE_CHECKING, Dict
from signedjson.sign import sign_json
from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_integer, parse_json_object_from_request
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -85,7 +91,7 @@ class RemoteKey(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.fetcher = ServerKeyFetcher(hs)
@ -94,7 +100,8 @@ class RemoteKey(DirectServeJsonResource):
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
self.config = hs.config
async def _async_render_GET(self, request):
async def _async_render_GET(self, request: Request) -> None:
assert request.postpath is not None
if len(request.postpath) == 1:
(server,) = request.postpath
query: dict = {server.decode("ascii"): {}}
@ -110,14 +117,19 @@ class RemoteKey(DirectServeJsonResource):
await self.query_keys(request, query, query_remote_on_cache_miss=True)
async def _async_render_POST(self, request):
async def _async_render_POST(self, request: Request) -> None:
content = parse_json_object_from_request(request)
query = content["server_keys"]
await self.query_keys(request, query, query_remote_on_cache_miss=True)
async def query_keys(self, request, query, query_remote_on_cache_miss=False):
async def query_keys(
self,
request: Request,
query: JsonDict,
query_remote_on_cache_miss: bool = False,
) -> None:
logger.info("Handling query for keys %r", query)
store_queries = []
@ -142,8 +154,8 @@ class RemoteKey(DirectServeJsonResource):
# Note that the value is unused.
cache_misses: Dict[str, Dict[str, int]] = {}
for (server_name, key_id, _), results in cached.items():
results = [(result["ts_added_ms"], result) for result in results]
for (server_name, key_id, _), key_results in cached.items():
results = [(result["ts_added_ms"], result) for result in key_results]
if not results and key_id is not None:
cache_misses.setdefault(server_name, {})[key_id] = 0
@ -230,6 +242,6 @@ class RemoteKey(DirectServeJsonResource):
signed_keys.append(key_json)
results = {"server_keys": signed_keys}
response = {"server_keys": signed_keys}
respond_with_json(request, 200, results, canonical_json=True)
respond_with_json(request, 200, response, canonical_json=True)

View File

@ -16,7 +16,8 @@
import logging
import os
import urllib
from typing import Awaitable, Dict, Generator, List, Optional, Tuple
from types import TracebackType
from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type
import attr
@ -122,7 +123,7 @@ def add_file_headers(
upload_name: The name of the requested file, if any.
"""
def _quote(x):
def _quote(x: str) -> str:
return urllib.parse.quote(x.encode("utf-8"))
# Default to a UTF-8 charset for text content types.
@ -282,10 +283,15 @@ class Responder:
"""
pass
def __enter__(self):
def __enter__(self) -> None:
pass
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
pass
@ -317,31 +323,31 @@ class FileInfo:
# The below properties exist to maintain compatibility with third-party modules.
@property
def thumbnail_width(self):
def thumbnail_width(self) -> Optional[int]:
if not self.thumbnail:
return None
return self.thumbnail.width
@property
def thumbnail_height(self):
def thumbnail_height(self) -> Optional[int]:
if not self.thumbnail:
return None
return self.thumbnail.height
@property
def thumbnail_method(self):
def thumbnail_method(self) -> Optional[str]:
if not self.thumbnail:
return None
return self.thumbnail.method
@property
def thumbnail_type(self):
def thumbnail_type(self) -> Optional[str]:
if not self.thumbnail:
return None
return self.thumbnail.type
@property
def thumbnail_length(self):
def thumbnail_length(self) -> Optional[int]:
if not self.thumbnail:
return None
return self.thumbnail.length

View File

@ -16,7 +16,7 @@
import functools
import os
import re
from typing import Callable, List
from typing import Any, Callable, List
NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
@ -27,7 +27,7 @@ def _wrap_in_base_path(func: Callable[..., str]) -> Callable[..., str]:
"""
@functools.wraps(func)
def _wrapped(self, *args, **kwargs):
def _wrapped(self: "MediaFilePaths", *args: Any, **kwargs: Any) -> str:
path = func(self, *args, **kwargs)
return os.path.join(self.base_path, path)
@ -129,7 +129,7 @@ class MediaFilePaths:
# using the new path.
def remote_media_thumbnail_rel_legacy(
self, server_name: str, file_id: str, width: int, height: int, content_type: str
):
) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
return os.path.join(

View File

@ -21,6 +21,7 @@ from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
import twisted.internet.error
import twisted.web.http
from twisted.internet.defer import Deferred
from twisted.web.resource import Resource
from twisted.web.server import Request
@ -32,6 +33,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.config._base import ConfigError
from synapse.config.repository import ThumbnailRequirement
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID
@ -114,7 +116,7 @@ class MediaRepository:
self._start_update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS
)
def _start_update_recently_accessed(self):
def _start_update_recently_accessed(self) -> Deferred:
return run_as_background_process(
"update_recently_accessed_media", self._update_recently_accessed
)
@ -469,7 +471,9 @@ class MediaRepository:
return media_info
def _get_thumbnail_requirements(self, media_type):
def _get_thumbnail_requirements(
self, media_type: str
) -> Tuple[ThumbnailRequirement, ...]:
scpos = media_type.find(";")
if scpos > 0:
media_type = media_type[:scpos]

View File

@ -15,7 +15,20 @@ import contextlib
import logging
import os
import shutil
from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Sequence
from types import TracebackType
from typing import (
IO,
TYPE_CHECKING,
Any,
Awaitable,
BinaryIO,
Callable,
Generator,
Optional,
Sequence,
Tuple,
Type,
)
import attr
@ -83,12 +96,14 @@ class MediaStorage:
return fname
async def write_to_file(self, source: IO, output: IO):
async def write_to_file(self, source: IO, output: IO) -> None:
"""Asynchronously write the `source` to `output`."""
await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
@contextlib.contextmanager
def store_into_file(self, file_info: FileInfo):
def store_into_file(
self, file_info: FileInfo
) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]:
"""Context manager used to get a file like object to write into, as
described by file_info.
@ -125,7 +140,7 @@ class MediaStorage:
try:
with open(fname, "wb") as f:
async def finish():
async def finish() -> None:
# Ensure that all writes have been flushed and close the
# file.
f.flush()
@ -315,7 +330,12 @@ class FileResponder(Responder):
FileSender().beginFileTransfer(self.open_file, consumer)
)
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.open_file.close()
@ -339,7 +359,7 @@ class ReadableFileWrapper:
clock = attr.ib(type=Clock)
path = attr.ib(type=str)
async def write_chunks_to(self, callback: Callable[[bytes], None]):
async def write_chunks_to(self, callback: Callable[[bytes], None]) -> None:
"""Reads the file in chunks and calls the callback with each chunk."""
with open(self.path, "rb") as file:

View File

@ -27,6 +27,7 @@ from urllib import parse as urlparse
import attr
from twisted.internet.defer import Deferred
from twisted.internet.error import DNSLookupError
from twisted.web.server import Request
@ -473,7 +474,7 @@ class PreviewUrlResource(DirectServeJsonResource):
etag=etag,
)
def _start_expire_url_cache_data(self):
def _start_expire_url_cache_data(self) -> Deferred:
return run_as_background_process(
"expire_url_cache_data", self._expire_url_cache_data
)
@ -782,7 +783,7 @@ def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
def _iterate_over_text(
tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
tree: "etree.Element", *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
) -> Generator[str, None, None]:
"""Iterate over the tree returning text nodes in a depth first fashion,
skipping text nodes inside certain tags.

View File

@ -99,7 +99,7 @@ class StorageProviderWrapper(StorageProvider):
await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore
else:
# TODO: Handle errors.
async def store():
async def store() -> None:
try:
return await maybe_awaitable(
self.backend.store_file(path, file_info)
@ -128,7 +128,7 @@ class FileStorageProviderBackend(StorageProvider):
self.cache_directory = hs.config.media_store_path
self.base_directory = config
def __str__(self):
def __str__(self) -> str:
return "FileStorageProviderBackend[%s]" % (self.base_directory,)
async def store_file(self, path: str, file_info: FileInfo) -> None:

View File

@ -41,7 +41,7 @@ class Thumbnailer:
FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
@staticmethod
def set_limits(max_image_pixels: int):
def set_limits(max_image_pixels: int) -> None:
Image.MAX_IMAGE_PIXELS = max_image_pixels
def __init__(self, input_path: str):

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Generator
from twisted.web.server import Request
@ -45,7 +45,7 @@ class NewUserConsentResource(DirectServeHtmlResource):
self._server_name = hs.hostname
self._consent_version = hs.config.consent.user_consent_version
def template_search_dirs():
def template_search_dirs() -> Generator[str, None, None]:
if hs.config.server.custom_template_directory:
yield hs.config.server.custom_template_directory
if hs.config.sso.sso_template_dir:
@ -88,7 +88,7 @@ class NewUserConsentResource(DirectServeHtmlResource):
html = template.render(template_params)
respond_with_html(request, 200, html)
async def _async_render_POST(self, request: Request):
async def _async_render_POST(self, request: Request) -> None:
try:
session_id = get_username_mapping_session_cookie_from_request(request)
except SynapseError as e:

View File

@ -13,16 +13,20 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from twisted.web.resource import Resource
from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class OIDCResource(Resource):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self.putChild(b"callback", OIDCCallbackResource(hs))

View File

@ -16,6 +16,7 @@ import logging
from typing import TYPE_CHECKING
from synapse.http.server import DirectServeHtmlResource
from synapse.http.site import SynapseRequest
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -30,10 +31,10 @@ class OIDCCallbackResource(DirectServeHtmlResource):
super().__init__()
self._oidc_handler = hs.get_oidc_handler()
async def _async_render_GET(self, request):
async def _async_render_GET(self, request: SynapseRequest) -> None:
await self._oidc_handler.handle_oidc_callback(request)
async def _async_render_POST(self, request):
async def _async_render_POST(self, request: SynapseRequest) -> None:
# the auth response can be returned via an x-www-form-urlencoded form instead
# of GET params, as per
# https://openid.net/specs/oauth-v2-form-post-response-mode-1_0.html.

View File

@ -13,7 +13,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, Generator, List, Tuple
from twisted.web.resource import Resource
from twisted.web.server import Request
@ -27,6 +27,7 @@ from synapse.http.server import (
)
from synapse.http.servlet import parse_boolean, parse_string
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from synapse.util.templates import build_jinja_env
if TYPE_CHECKING:
@ -57,7 +58,7 @@ class AvailabilityCheckResource(DirectServeJsonResource):
super().__init__()
self._sso_handler = hs.get_sso_handler()
async def _async_render_GET(self, request: Request):
async def _async_render_GET(self, request: Request) -> Tuple[int, JsonDict]:
localpart = parse_string(request, "username", required=True)
session_id = get_username_mapping_session_cookie_from_request(request)
@ -73,7 +74,7 @@ class AccountDetailsResource(DirectServeHtmlResource):
super().__init__()
self._sso_handler = hs.get_sso_handler()
def template_search_dirs():
def template_search_dirs() -> Generator[str, None, None]:
if hs.config.server.custom_template_directory:
yield hs.config.server.custom_template_directory
if hs.config.sso.sso_template_dir:
@ -104,7 +105,7 @@ class AccountDetailsResource(DirectServeHtmlResource):
html = template.render(template_params)
respond_with_html(request, 200, html)
async def _async_render_POST(self, request: SynapseRequest):
async def _async_render_POST(self, request: SynapseRequest) -> None:
# This will always be set by the time Twisted calls us.
assert request.args is not None

View File

@ -13,17 +13,21 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from twisted.web.resource import Resource
from synapse.rest.synapse.client.saml2.metadata_resource import SAML2MetadataResource
from synapse.rest.synapse.client.saml2.response_resource import SAML2ResponseResource
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class SAML2Resource(Resource):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self.putChild(b"metadata.xml", SAML2MetadataResource(hs))
self.putChild(b"authn_response", SAML2ResponseResource(hs))

View File

@ -12,10 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
import saml2.metadata
from twisted.web.resource import Resource
from twisted.web.server import Request
if TYPE_CHECKING:
from synapse.server import HomeServer
class SAML2MetadataResource(Resource):
@ -23,11 +28,11 @@ class SAML2MetadataResource(Resource):
isLeaf = 1
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self.sp_config = hs.config.saml2_sp_config
def render_GET(self, request):
def render_GET(self, request: Request) -> bytes:
metadata_xml = saml2.metadata.create_metadata_string(
configfile=None, config=self.sp_config
)

View File

@ -15,7 +15,10 @@
from typing import TYPE_CHECKING
from twisted.web.server import Request
from synapse.http.server import DirectServeHtmlResource
from synapse.http.site import SynapseRequest
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -31,7 +34,7 @@ class SAML2ResponseResource(DirectServeHtmlResource):
self._saml_handler = hs.get_saml_handler()
self._sso_handler = hs.get_sso_handler()
async def _async_render_GET(self, request):
async def _async_render_GET(self, request: Request) -> None:
# We're not expecting any GET request on that resource if everything goes right,
# but some IdPs sometimes end up responding with a 302 redirect on this endpoint.
# In this case, just tell the user that something went wrong and they should
@ -40,5 +43,5 @@ class SAML2ResponseResource(DirectServeHtmlResource):
request, "unexpected_get", "Unexpected GET request on /saml2/authn_response"
)
async def _async_render_POST(self, request):
async def _async_render_POST(self, request: SynapseRequest) -> None:
await self._saml_handler.handle_saml_response(request)

View File

@ -13,26 +13,26 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Optional
from twisted.web.resource import Resource
from twisted.web.server import Request
from synapse.http.server import set_cors_headers
from synapse.types import JsonDict
from synapse.util import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class WellKnownBuilder:
"""Utility to construct the well-known response
Args:
hs (synapse.server.HomeServer):
"""
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self._config = hs.config
def get_well_known(self):
def get_well_known(self) -> Optional[JsonDict]:
# if we don't have a public_baseurl, we can't help much here.
if self._config.server.public_baseurl is None:
return None
@ -52,11 +52,11 @@ class WellKnownResource(Resource):
isLeaf = 1
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self._well_known_builder = WellKnownBuilder(hs)
def render_GET(self, request):
def render_GET(self, request: Request) -> bytes:
set_cors_headers(request)
r = self._well_known_builder.get_well_known()
if not r: