Add missing type hints to non-client REST servlets. (#10817)

Including admin, consent, key, synapse, and media. All REST servlets
(the synapse.rest module) now require typed method definitions.
This commit is contained in:
Patrick Cloke 2021-09-15 08:45:32 -04:00 committed by GitHub
parent 8c7a531e27
commit b93259082c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 169 additions and 96 deletions

View File

@ -1 +1 @@
Convert the internal `FileInfo` class to attrs and add type hints. Add missing type hints to REST servlets.

1
changelog.d/10817.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints to REST servlets.

View File

@ -90,7 +90,7 @@ files =
tests/util/test_itertools.py, tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py tests/util/test_stream_change_cache.py
[mypy-synapse.rest.client.*] [mypy-synapse.rest.*]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-synapse.util.batching_queue] [mypy-synapse.util.batching_queue]

View File

@ -12,7 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.http.server import JsonResource from typing import TYPE_CHECKING
from synapse.http.server import HttpServer, JsonResource
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import ( from synapse.rest.client import (
account, account,
@ -57,6 +59,9 @@ from synapse.rest.client import (
voip, voip,
) )
if TYPE_CHECKING:
from synapse.server import HomeServer
class ClientRestResource(JsonResource): class ClientRestResource(JsonResource):
"""Matrix Client API REST resource. """Matrix Client API REST resource.
@ -68,12 +73,12 @@ class ClientRestResource(JsonResource):
* etc * etc
""" """
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
JsonResource.__init__(self, hs, canonical_json=False) JsonResource.__init__(self, hs, canonical_json=False)
self.register_servlets(self, hs) self.register_servlets(self, hs)
@staticmethod @staticmethod
def register_servlets(client_resource, hs): def register_servlets(client_resource: HttpServer, hs: "HomeServer") -> None:
versions.register_servlets(hs, client_resource) versions.register_servlets(hs, client_resource)
# Deprecated in r0 # Deprecated in r0

View File

@ -47,7 +47,7 @@ class DeviceRestServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def on_GET( async def on_GET(
self, request: SynapseRequest, user_id, device_id: str self, request: SynapseRequest, user_id: str, device_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)

View File

@ -57,7 +57,7 @@ class SendServerNoticeServlet(RestServlet):
self.admin_handler = hs.get_admin_handler() self.admin_handler = hs.get_admin_handler()
self.txns = HttpTransactionCache(hs) self.txns = HttpTransactionCache(hs)
def register(self, json_resource: HttpServer): def register(self, json_resource: HttpServer) -> None:
PATTERN = "/send_server_notice" PATTERN = "/send_server_notice"
json_resource.register_paths( json_resource.register_paths(
"POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__ "POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__

View File

@ -419,7 +419,7 @@ class UserRegisterServlet(RestServlet):
self.nonces: Dict[str, int] = {} self.nonces: Dict[str, int] = {}
self.hs = hs self.hs = hs
def _clear_old_nonces(self): def _clear_old_nonces(self) -> None:
""" """
Clear out old nonces that are older than NONCE_TIMEOUT. Clear out old nonces that are older than NONCE_TIMEOUT.
""" """

View File

@ -17,17 +17,22 @@ import logging
from hashlib import sha256 from hashlib import sha256
from http import HTTPStatus from http import HTTPStatus
from os import path from os import path
from typing import Dict, List from typing import TYPE_CHECKING, Any, Dict, List
import jinja2 import jinja2
from jinja2 import TemplateNotFound from jinja2 import TemplateNotFound
from twisted.web.server import Request
from synapse.api.errors import NotFoundError, StoreError, SynapseError from synapse.api.errors import NotFoundError, StoreError, SynapseError
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.http.server import DirectServeHtmlResource, respond_with_html from synapse.http.server import DirectServeHtmlResource, respond_with_html
from synapse.http.servlet import parse_bytes_from_args, parse_string from synapse.http.servlet import parse_bytes_from_args, parse_string
from synapse.types import UserID from synapse.types import UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
# language to use for the templates. TODO: figure this out from Accept-Language # language to use for the templates. TODO: figure this out from Accept-Language
TEMPLATE_LANGUAGE = "en" TEMPLATE_LANGUAGE = "en"
@ -69,11 +74,7 @@ class ConsentResource(DirectServeHtmlResource):
against the user. against the user.
""" """
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): homeserver
"""
super().__init__() super().__init__()
self.hs = hs self.hs = hs
@ -106,18 +107,14 @@ class ConsentResource(DirectServeHtmlResource):
self._hmac_secret = hs.config.form_secret.encode("utf-8") self._hmac_secret = hs.config.form_secret.encode("utf-8")
async def _async_render_GET(self, request): async def _async_render_GET(self, request: Request) -> None:
"""
Args:
request (twisted.web.http.Request):
"""
version = parse_string(request, "v", default=self._default_consent_version) version = parse_string(request, "v", default=self._default_consent_version)
username = parse_string(request, "u", default="") username = parse_string(request, "u", default="")
userhmac = None userhmac = None
has_consented = False has_consented = False
public_version = username == "" public_version = username == ""
if not public_version: if not public_version:
args: Dict[bytes, List[bytes]] = request.args args: Dict[bytes, List[bytes]] = request.args # type: ignore
userhmac_bytes = parse_bytes_from_args(args, "h", required=True) userhmac_bytes = parse_bytes_from_args(args, "h", required=True)
self._check_hash(username, userhmac_bytes) self._check_hash(username, userhmac_bytes)
@ -147,14 +144,10 @@ class ConsentResource(DirectServeHtmlResource):
except TemplateNotFound: except TemplateNotFound:
raise NotFoundError("Unknown policy version") raise NotFoundError("Unknown policy version")
async def _async_render_POST(self, request): async def _async_render_POST(self, request: Request) -> None:
"""
Args:
request (twisted.web.http.Request):
"""
version = parse_string(request, "v", required=True) version = parse_string(request, "v", required=True)
username = parse_string(request, "u", required=True) username = parse_string(request, "u", required=True)
args: Dict[bytes, List[bytes]] = request.args args: Dict[bytes, List[bytes]] = request.args # type: ignore
userhmac = parse_bytes_from_args(args, "h", required=True) userhmac = parse_bytes_from_args(args, "h", required=True)
self._check_hash(username, userhmac) self._check_hash(username, userhmac)
@ -177,7 +170,9 @@ class ConsentResource(DirectServeHtmlResource):
except TemplateNotFound: except TemplateNotFound:
raise NotFoundError("success.html not found") raise NotFoundError("success.html not found")
def _render_template(self, request, template_name, **template_args): def _render_template(
self, request: Request, template_name: str, **template_args: Any
) -> None:
# get_template checks for ".." so we don't need to worry too much # get_template checks for ".." so we don't need to worry too much
# about path traversal here. # about path traversal here.
template_html = self._jinja_env.get_template( template_html = self._jinja_env.get_template(
@ -186,11 +181,11 @@ class ConsentResource(DirectServeHtmlResource):
html = template_html.render(**template_args) html = template_html.render(**template_args)
respond_with_html(request, 200, html) respond_with_html(request, 200, html)
def _check_hash(self, userid, userhmac): def _check_hash(self, userid: str, userhmac: bytes) -> None:
""" """
Args: Args:
userid (unicode): userid:
userhmac (bytes): userhmac:
Raises: Raises:
SynapseError if the hash doesn't match SynapseError if the hash doesn't match

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import Request
class HealthResource(Resource): class HealthResource(Resource):
@ -25,6 +26,6 @@ class HealthResource(Resource):
isLeaf = 1 isLeaf = 1
def render_GET(self, request): def render_GET(self, request: Request) -> bytes:
request.setHeader(b"Content-Type", b"text/plain") request.setHeader(b"Content-Type", b"text/plain")
return b"OK" return b"OK"

View File

@ -12,14 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING
from twisted.web.resource import Resource from twisted.web.resource import Resource
from .local_key_resource import LocalKey from .local_key_resource import LocalKey
from .remote_key_resource import RemoteKey from .remote_key_resource import RemoteKey
if TYPE_CHECKING:
from synapse.server import HomeServer
class KeyApiV2Resource(Resource): class KeyApiV2Resource(Resource):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
Resource.__init__(self) Resource.__init__(self)
self.putChild(b"server", LocalKey(hs)) self.putChild(b"server", LocalKey(hs))
self.putChild(b"query", RemoteKey(hs)) self.putChild(b"query", RemoteKey(hs))

View File

@ -12,16 +12,21 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from signedjson.sign import sign_json from signedjson.sign import sign_json
from unpaddedbase64 import encode_base64 from unpaddedbase64 import encode_base64
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import Request
from synapse.http.server import respond_with_json_bytes from synapse.http.server import respond_with_json_bytes
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -58,18 +63,18 @@ class LocalKey(Resource):
isLeaf = True isLeaf = True
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.config = hs.config self.config = hs.config
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.update_response_body(self.clock.time_msec()) self.update_response_body(self.clock.time_msec())
Resource.__init__(self) Resource.__init__(self)
def update_response_body(self, time_now_msec): def update_response_body(self, time_now_msec: int) -> None:
refresh_interval = self.config.key_refresh_interval refresh_interval = self.config.key_refresh_interval
self.valid_until_ts = int(time_now_msec + refresh_interval) self.valid_until_ts = int(time_now_msec + refresh_interval)
self.response_body = encode_canonical_json(self.response_json_object()) self.response_body = encode_canonical_json(self.response_json_object())
def response_json_object(self): def response_json_object(self) -> JsonDict:
verify_keys = {} verify_keys = {}
for key in self.config.signing_key: for key in self.config.signing_key:
verify_key_bytes = key.verify_key.encode() verify_key_bytes = key.verify_key.encode()
@ -94,7 +99,7 @@ class LocalKey(Resource):
json_object = sign_json(json_object, self.config.server.server_name, key) json_object = sign_json(json_object, self.config.server.server_name, key)
return json_object return json_object
def render_GET(self, request): def render_GET(self, request: Request) -> int:
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
# Update the expiry time if less than half the interval remains. # Update the expiry time if less than half the interval remains.
if time_now + self.config.key_refresh_interval / 2 > self.valid_until_ts: if time_now + self.config.key_refresh_interval / 2 > self.valid_until_ts:

View File

@ -13,17 +13,23 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict from typing import TYPE_CHECKING, Dict
from signedjson.sign import sign_json from signedjson.sign import sign_json
from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_integer, parse_json_object_from_request from synapse.http.servlet import parse_integer, parse_json_object_from_request
from synapse.types import JsonDict
from synapse.util import json_decoder from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results from synapse.util.async_helpers import yieldable_gather_results
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -85,7 +91,7 @@ class RemoteKey(DirectServeJsonResource):
isLeaf = True isLeaf = True
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.fetcher = ServerKeyFetcher(hs) self.fetcher = ServerKeyFetcher(hs)
@ -94,7 +100,8 @@ class RemoteKey(DirectServeJsonResource):
self.federation_domain_whitelist = hs.config.federation_domain_whitelist self.federation_domain_whitelist = hs.config.federation_domain_whitelist
self.config = hs.config self.config = hs.config
async def _async_render_GET(self, request): async def _async_render_GET(self, request: Request) -> None:
assert request.postpath is not None
if len(request.postpath) == 1: if len(request.postpath) == 1:
(server,) = request.postpath (server,) = request.postpath
query: dict = {server.decode("ascii"): {}} query: dict = {server.decode("ascii"): {}}
@ -110,14 +117,19 @@ class RemoteKey(DirectServeJsonResource):
await self.query_keys(request, query, query_remote_on_cache_miss=True) await self.query_keys(request, query, query_remote_on_cache_miss=True)
async def _async_render_POST(self, request): async def _async_render_POST(self, request: Request) -> None:
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
query = content["server_keys"] query = content["server_keys"]
await self.query_keys(request, query, query_remote_on_cache_miss=True) await self.query_keys(request, query, query_remote_on_cache_miss=True)
async def query_keys(self, request, query, query_remote_on_cache_miss=False): async def query_keys(
self,
request: Request,
query: JsonDict,
query_remote_on_cache_miss: bool = False,
) -> None:
logger.info("Handling query for keys %r", query) logger.info("Handling query for keys %r", query)
store_queries = [] store_queries = []
@ -142,8 +154,8 @@ class RemoteKey(DirectServeJsonResource):
# Note that the value is unused. # Note that the value is unused.
cache_misses: Dict[str, Dict[str, int]] = {} cache_misses: Dict[str, Dict[str, int]] = {}
for (server_name, key_id, _), results in cached.items(): for (server_name, key_id, _), key_results in cached.items():
results = [(result["ts_added_ms"], result) for result in results] results = [(result["ts_added_ms"], result) for result in key_results]
if not results and key_id is not None: if not results and key_id is not None:
cache_misses.setdefault(server_name, {})[key_id] = 0 cache_misses.setdefault(server_name, {})[key_id] = 0
@ -230,6 +242,6 @@ class RemoteKey(DirectServeJsonResource):
signed_keys.append(key_json) signed_keys.append(key_json)
results = {"server_keys": signed_keys} response = {"server_keys": signed_keys}
respond_with_json(request, 200, results, canonical_json=True) respond_with_json(request, 200, response, canonical_json=True)

View File

@ -16,7 +16,8 @@
import logging import logging
import os import os
import urllib import urllib
from typing import Awaitable, Dict, Generator, List, Optional, Tuple from types import TracebackType
from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type
import attr import attr
@ -122,7 +123,7 @@ def add_file_headers(
upload_name: The name of the requested file, if any. upload_name: The name of the requested file, if any.
""" """
def _quote(x): def _quote(x: str) -> str:
return urllib.parse.quote(x.encode("utf-8")) return urllib.parse.quote(x.encode("utf-8"))
# Default to a UTF-8 charset for text content types. # Default to a UTF-8 charset for text content types.
@ -282,10 +283,15 @@ class Responder:
""" """
pass pass
def __enter__(self): def __enter__(self) -> None:
pass pass
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
pass pass
@ -317,31 +323,31 @@ class FileInfo:
# The below properties exist to maintain compatibility with third-party modules. # The below properties exist to maintain compatibility with third-party modules.
@property @property
def thumbnail_width(self): def thumbnail_width(self) -> Optional[int]:
if not self.thumbnail: if not self.thumbnail:
return None return None
return self.thumbnail.width return self.thumbnail.width
@property @property
def thumbnail_height(self): def thumbnail_height(self) -> Optional[int]:
if not self.thumbnail: if not self.thumbnail:
return None return None
return self.thumbnail.height return self.thumbnail.height
@property @property
def thumbnail_method(self): def thumbnail_method(self) -> Optional[str]:
if not self.thumbnail: if not self.thumbnail:
return None return None
return self.thumbnail.method return self.thumbnail.method
@property @property
def thumbnail_type(self): def thumbnail_type(self) -> Optional[str]:
if not self.thumbnail: if not self.thumbnail:
return None return None
return self.thumbnail.type return self.thumbnail.type
@property @property
def thumbnail_length(self): def thumbnail_length(self) -> Optional[int]:
if not self.thumbnail: if not self.thumbnail:
return None return None
return self.thumbnail.length return self.thumbnail.length

View File

@ -16,7 +16,7 @@
import functools import functools
import os import os
import re import re
from typing import Callable, List from typing import Any, Callable, List
NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
@ -27,7 +27,7 @@ def _wrap_in_base_path(func: Callable[..., str]) -> Callable[..., str]:
""" """
@functools.wraps(func) @functools.wraps(func)
def _wrapped(self, *args, **kwargs): def _wrapped(self: "MediaFilePaths", *args: Any, **kwargs: Any) -> str:
path = func(self, *args, **kwargs) path = func(self, *args, **kwargs)
return os.path.join(self.base_path, path) return os.path.join(self.base_path, path)
@ -129,7 +129,7 @@ class MediaFilePaths:
# using the new path. # using the new path.
def remote_media_thumbnail_rel_legacy( def remote_media_thumbnail_rel_legacy(
self, server_name: str, file_id: str, width: int, height: int, content_type: str self, server_name: str, file_id: str, width: int, height: int, content_type: str
): ) -> str:
top_level_type, sub_type = content_type.split("/") top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
return os.path.join( return os.path.join(

View File

@ -21,6 +21,7 @@ from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
import twisted.internet.error import twisted.internet.error
import twisted.web.http import twisted.web.http
from twisted.internet.defer import Deferred
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import Request from twisted.web.server import Request
@ -32,6 +33,7 @@ from synapse.api.errors import (
SynapseError, SynapseError,
) )
from synapse.config._base import ConfigError from synapse.config._base import ConfigError
from synapse.config.repository import ThumbnailRequirement
from synapse.logging.context import defer_to_thread from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID from synapse.types import UserID
@ -114,7 +116,7 @@ class MediaRepository:
self._start_update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS self._start_update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS
) )
def _start_update_recently_accessed(self): def _start_update_recently_accessed(self) -> Deferred:
return run_as_background_process( return run_as_background_process(
"update_recently_accessed_media", self._update_recently_accessed "update_recently_accessed_media", self._update_recently_accessed
) )
@ -469,7 +471,9 @@ class MediaRepository:
return media_info return media_info
def _get_thumbnail_requirements(self, media_type): def _get_thumbnail_requirements(
self, media_type: str
) -> Tuple[ThumbnailRequirement, ...]:
scpos = media_type.find(";") scpos = media_type.find(";")
if scpos > 0: if scpos > 0:
media_type = media_type[:scpos] media_type = media_type[:scpos]

View File

@ -15,7 +15,20 @@ import contextlib
import logging import logging
import os import os
import shutil import shutil
from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Sequence from types import TracebackType
from typing import (
IO,
TYPE_CHECKING,
Any,
Awaitable,
BinaryIO,
Callable,
Generator,
Optional,
Sequence,
Tuple,
Type,
)
import attr import attr
@ -83,12 +96,14 @@ class MediaStorage:
return fname return fname
async def write_to_file(self, source: IO, output: IO): async def write_to_file(self, source: IO, output: IO) -> None:
"""Asynchronously write the `source` to `output`.""" """Asynchronously write the `source` to `output`."""
await defer_to_thread(self.reactor, _write_file_synchronously, source, output) await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
@contextlib.contextmanager @contextlib.contextmanager
def store_into_file(self, file_info: FileInfo): def store_into_file(
self, file_info: FileInfo
) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]:
"""Context manager used to get a file like object to write into, as """Context manager used to get a file like object to write into, as
described by file_info. described by file_info.
@ -125,7 +140,7 @@ class MediaStorage:
try: try:
with open(fname, "wb") as f: with open(fname, "wb") as f:
async def finish(): async def finish() -> None:
# Ensure that all writes have been flushed and close the # Ensure that all writes have been flushed and close the
# file. # file.
f.flush() f.flush()
@ -315,7 +330,12 @@ class FileResponder(Responder):
FileSender().beginFileTransfer(self.open_file, consumer) FileSender().beginFileTransfer(self.open_file, consumer)
) )
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.open_file.close() self.open_file.close()
@ -339,7 +359,7 @@ class ReadableFileWrapper:
clock = attr.ib(type=Clock) clock = attr.ib(type=Clock)
path = attr.ib(type=str) path = attr.ib(type=str)
async def write_chunks_to(self, callback: Callable[[bytes], None]): async def write_chunks_to(self, callback: Callable[[bytes], None]) -> None:
"""Reads the file in chunks and calls the callback with each chunk.""" """Reads the file in chunks and calls the callback with each chunk."""
with open(self.path, "rb") as file: with open(self.path, "rb") as file:

View File

@ -27,6 +27,7 @@ from urllib import parse as urlparse
import attr import attr
from twisted.internet.defer import Deferred
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.web.server import Request from twisted.web.server import Request
@ -473,7 +474,7 @@ class PreviewUrlResource(DirectServeJsonResource):
etag=etag, etag=etag,
) )
def _start_expire_url_cache_data(self): def _start_expire_url_cache_data(self) -> Deferred:
return run_as_background_process( return run_as_background_process(
"expire_url_cache_data", self._expire_url_cache_data "expire_url_cache_data", self._expire_url_cache_data
) )
@ -782,7 +783,7 @@ def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
def _iterate_over_text( def _iterate_over_text(
tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]] tree: "etree.Element", *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
"""Iterate over the tree returning text nodes in a depth first fashion, """Iterate over the tree returning text nodes in a depth first fashion,
skipping text nodes inside certain tags. skipping text nodes inside certain tags.

View File

@ -99,7 +99,7 @@ class StorageProviderWrapper(StorageProvider):
await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore
else: else:
# TODO: Handle errors. # TODO: Handle errors.
async def store(): async def store() -> None:
try: try:
return await maybe_awaitable( return await maybe_awaitable(
self.backend.store_file(path, file_info) self.backend.store_file(path, file_info)
@ -128,7 +128,7 @@ class FileStorageProviderBackend(StorageProvider):
self.cache_directory = hs.config.media_store_path self.cache_directory = hs.config.media_store_path
self.base_directory = config self.base_directory = config
def __str__(self): def __str__(self) -> str:
return "FileStorageProviderBackend[%s]" % (self.base_directory,) return "FileStorageProviderBackend[%s]" % (self.base_directory,)
async def store_file(self, path: str, file_info: FileInfo) -> None: async def store_file(self, path: str, file_info: FileInfo) -> None:

View File

@ -41,7 +41,7 @@ class Thumbnailer:
FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"} FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
@staticmethod @staticmethod
def set_limits(max_image_pixels: int): def set_limits(max_image_pixels: int) -> None:
Image.MAX_IMAGE_PIXELS = max_image_pixels Image.MAX_IMAGE_PIXELS = max_image_pixels
def __init__(self, input_path: str): def __init__(self, input_path: str):

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Generator
from twisted.web.server import Request from twisted.web.server import Request
@ -45,7 +45,7 @@ class NewUserConsentResource(DirectServeHtmlResource):
self._server_name = hs.hostname self._server_name = hs.hostname
self._consent_version = hs.config.consent.user_consent_version self._consent_version = hs.config.consent.user_consent_version
def template_search_dirs(): def template_search_dirs() -> Generator[str, None, None]:
if hs.config.server.custom_template_directory: if hs.config.server.custom_template_directory:
yield hs.config.server.custom_template_directory yield hs.config.server.custom_template_directory
if hs.config.sso.sso_template_dir: if hs.config.sso.sso_template_dir:
@ -88,7 +88,7 @@ class NewUserConsentResource(DirectServeHtmlResource):
html = template.render(template_params) html = template.render(template_params)
respond_with_html(request, 200, html) respond_with_html(request, 200, html)
async def _async_render_POST(self, request: Request): async def _async_render_POST(self, request: Request) -> None:
try: try:
session_id = get_username_mapping_session_cookie_from_request(request) session_id = get_username_mapping_session_cookie_from_request(request)
except SynapseError as e: except SynapseError as e:

View File

@ -13,16 +13,20 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING
from twisted.web.resource import Resource from twisted.web.resource import Resource
from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class OIDCResource(Resource): class OIDCResource(Resource):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
Resource.__init__(self) Resource.__init__(self)
self.putChild(b"callback", OIDCCallbackResource(hs)) self.putChild(b"callback", OIDCCallbackResource(hs))

View File

@ -16,6 +16,7 @@ import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from synapse.http.server import DirectServeHtmlResource from synapse.http.server import DirectServeHtmlResource
from synapse.http.site import SynapseRequest
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -30,10 +31,10 @@ class OIDCCallbackResource(DirectServeHtmlResource):
super().__init__() super().__init__()
self._oidc_handler = hs.get_oidc_handler() self._oidc_handler = hs.get_oidc_handler()
async def _async_render_GET(self, request): async def _async_render_GET(self, request: SynapseRequest) -> None:
await self._oidc_handler.handle_oidc_callback(request) await self._oidc_handler.handle_oidc_callback(request)
async def _async_render_POST(self, request): async def _async_render_POST(self, request: SynapseRequest) -> None:
# the auth response can be returned via an x-www-form-urlencoded form instead # the auth response can be returned via an x-www-form-urlencoded form instead
# of GET params, as per # of GET params, as per
# https://openid.net/specs/oauth-v2-form-post-response-mode-1_0.html. # https://openid.net/specs/oauth-v2-form-post-response-mode-1_0.html.

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, List from typing import TYPE_CHECKING, Generator, List, Tuple
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import Request from twisted.web.server import Request
@ -27,6 +27,7 @@ from synapse.http.server import (
) )
from synapse.http.servlet import parse_boolean, parse_string from synapse.http.servlet import parse_boolean, parse_string
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from synapse.util.templates import build_jinja_env from synapse.util.templates import build_jinja_env
if TYPE_CHECKING: if TYPE_CHECKING:
@ -57,7 +58,7 @@ class AvailabilityCheckResource(DirectServeJsonResource):
super().__init__() super().__init__()
self._sso_handler = hs.get_sso_handler() self._sso_handler = hs.get_sso_handler()
async def _async_render_GET(self, request: Request): async def _async_render_GET(self, request: Request) -> Tuple[int, JsonDict]:
localpart = parse_string(request, "username", required=True) localpart = parse_string(request, "username", required=True)
session_id = get_username_mapping_session_cookie_from_request(request) session_id = get_username_mapping_session_cookie_from_request(request)
@ -73,7 +74,7 @@ class AccountDetailsResource(DirectServeHtmlResource):
super().__init__() super().__init__()
self._sso_handler = hs.get_sso_handler() self._sso_handler = hs.get_sso_handler()
def template_search_dirs(): def template_search_dirs() -> Generator[str, None, None]:
if hs.config.server.custom_template_directory: if hs.config.server.custom_template_directory:
yield hs.config.server.custom_template_directory yield hs.config.server.custom_template_directory
if hs.config.sso.sso_template_dir: if hs.config.sso.sso_template_dir:
@ -104,7 +105,7 @@ class AccountDetailsResource(DirectServeHtmlResource):
html = template.render(template_params) html = template.render(template_params)
respond_with_html(request, 200, html) respond_with_html(request, 200, html)
async def _async_render_POST(self, request: SynapseRequest): async def _async_render_POST(self, request: SynapseRequest) -> None:
# This will always be set by the time Twisted calls us. # This will always be set by the time Twisted calls us.
assert request.args is not None assert request.args is not None

View File

@ -13,17 +13,21 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING
from twisted.web.resource import Resource from twisted.web.resource import Resource
from synapse.rest.synapse.client.saml2.metadata_resource import SAML2MetadataResource from synapse.rest.synapse.client.saml2.metadata_resource import SAML2MetadataResource
from synapse.rest.synapse.client.saml2.response_resource import SAML2ResponseResource from synapse.rest.synapse.client.saml2.response_resource import SAML2ResponseResource
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SAML2Resource(Resource): class SAML2Resource(Resource):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
Resource.__init__(self) Resource.__init__(self)
self.putChild(b"metadata.xml", SAML2MetadataResource(hs)) self.putChild(b"metadata.xml", SAML2MetadataResource(hs))
self.putChild(b"authn_response", SAML2ResponseResource(hs)) self.putChild(b"authn_response", SAML2ResponseResource(hs))

View File

@ -12,10 +12,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING
import saml2.metadata import saml2.metadata
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import Request
if TYPE_CHECKING:
from synapse.server import HomeServer
class SAML2MetadataResource(Resource): class SAML2MetadataResource(Resource):
@ -23,11 +28,11 @@ class SAML2MetadataResource(Resource):
isLeaf = 1 isLeaf = 1
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
Resource.__init__(self) Resource.__init__(self)
self.sp_config = hs.config.saml2_sp_config self.sp_config = hs.config.saml2_sp_config
def render_GET(self, request): def render_GET(self, request: Request) -> bytes:
metadata_xml = saml2.metadata.create_metadata_string( metadata_xml = saml2.metadata.create_metadata_string(
configfile=None, config=self.sp_config configfile=None, config=self.sp_config
) )

View File

@ -15,7 +15,10 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from twisted.web.server import Request
from synapse.http.server import DirectServeHtmlResource from synapse.http.server import DirectServeHtmlResource
from synapse.http.site import SynapseRequest
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -31,7 +34,7 @@ class SAML2ResponseResource(DirectServeHtmlResource):
self._saml_handler = hs.get_saml_handler() self._saml_handler = hs.get_saml_handler()
self._sso_handler = hs.get_sso_handler() self._sso_handler = hs.get_sso_handler()
async def _async_render_GET(self, request): async def _async_render_GET(self, request: Request) -> None:
# We're not expecting any GET request on that resource if everything goes right, # We're not expecting any GET request on that resource if everything goes right,
# but some IdPs sometimes end up responding with a 302 redirect on this endpoint. # but some IdPs sometimes end up responding with a 302 redirect on this endpoint.
# In this case, just tell the user that something went wrong and they should # In this case, just tell the user that something went wrong and they should
@ -40,5 +43,5 @@ class SAML2ResponseResource(DirectServeHtmlResource):
request, "unexpected_get", "Unexpected GET request on /saml2/authn_response" request, "unexpected_get", "Unexpected GET request on /saml2/authn_response"
) )
async def _async_render_POST(self, request): async def _async_render_POST(self, request: SynapseRequest) -> None:
await self._saml_handler.handle_saml_response(request) await self._saml_handler.handle_saml_response(request)

View File

@ -13,26 +13,26 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Optional
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import Request
from synapse.http.server import set_cors_headers from synapse.http.server import set_cors_headers
from synapse.types import JsonDict
from synapse.util import json_encoder from synapse.util import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class WellKnownBuilder: class WellKnownBuilder:
"""Utility to construct the well-known response def __init__(self, hs: "HomeServer"):
Args:
hs (synapse.server.HomeServer):
"""
def __init__(self, hs):
self._config = hs.config self._config = hs.config
def get_well_known(self): def get_well_known(self) -> Optional[JsonDict]:
# if we don't have a public_baseurl, we can't help much here. # if we don't have a public_baseurl, we can't help much here.
if self._config.server.public_baseurl is None: if self._config.server.public_baseurl is None:
return None return None
@ -52,11 +52,11 @@ class WellKnownResource(Resource):
isLeaf = 1 isLeaf = 1
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
Resource.__init__(self) Resource.__init__(self)
self._well_known_builder = WellKnownBuilder(hs) self._well_known_builder = WellKnownBuilder(hs)
def render_GET(self, request): def render_GET(self, request: Request) -> bytes:
set_cors_headers(request) set_cors_headers(request)
r = self._well_known_builder.get_well_known() r = self._well_known_builder.get_well_known()
if not r: if not r: