Additional type hints for the client REST servlets (part 3). (#10707)

This commit is contained in:
Patrick Cloke 2021-08-31 13:22:29 -04:00 committed by GitHub
parent 78e590d473
commit 287918e2d4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 306 additions and 150 deletions

View file

@ -14,7 +14,9 @@
# limitations under the License.
import logging
import random
from typing import List, Union
from typing import TYPE_CHECKING, List, Optional, Tuple
from twisted.web.server import Request
import synapse
import synapse.api.auth
@ -29,15 +31,13 @@ from synapse.api.errors import (
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.config import ConfigError
from synapse.config.captcha import CaptchaConfig
from synapse.config.consent import ConsentConfig
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.config.homeserver import HomeServerConfig
from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.config.registration import RegistrationConfig
from synapse.config.server import is_threepid_reserved
from synapse.handlers.auth import AuthHandler
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http.server import finish_request, respond_with_html
from synapse.http.server import HttpServer, finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@ -45,6 +45,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.metrics import threepid_send_requests
from synapse.push.mailer import Mailer
from synapse.types import JsonDict
@ -59,17 +60,16 @@ from synapse.util.threepids import (
from ._base import client_patterns, interactive_auth_handler
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class EmailRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/register/email/requestToken$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.identity_handler = hs.get_identity_handler()
@ -83,7 +83,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
template_text=self.config.email_registration_template_text,
)
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.hs.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
@ -171,16 +171,12 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
class MsisdnRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/register/msisdn/requestToken$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.identity_handler = hs.get_identity_handler()
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
assert_params_in_dict(
@ -255,11 +251,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
"/registration/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True
)
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@ -272,7 +264,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
self.config.email_registration_template_failure_html
)
async def on_GET(self, request, medium):
async def on_GET(self, request: Request, medium: str) -> None:
if medium != "email":
raise SynapseError(
400, "This medium is currently not supported for registration"
@ -326,11 +318,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
class UsernameAvailabilityRestServlet(RestServlet):
PATTERNS = client_patterns("/register/available")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.registration_handler = hs.get_registration_handler()
@ -350,7 +338,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
),
)
async def on_GET(self, request):
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
if not self.hs.config.enable_registration:
raise SynapseError(
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
@ -419,11 +407,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
class RegisterRestServlet(RestServlet):
PATTERNS = client_patterns("/register$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
@ -445,23 +429,21 @@ class RegisterRestServlet(RestServlet):
)
@interactive_auth_handler
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
client_addr = request.getClientIP()
await self.ratelimiter.ratelimit(None, client_addr, update=False)
kind = b"user"
if b"kind" in request.args:
kind = request.args[b"kind"][0]
kind = parse_string(request, "kind", default="user")
if kind == b"guest":
if kind == "guest":
ret = await self._do_guest_registration(body, address=client_addr)
return ret
elif kind != b"user":
elif kind != "user":
raise UnrecognizedRequestError(
"Do not understand membership kind: %s" % (kind.decode("utf8"),)
f"Do not understand membership kind: {kind}",
)
if self._msc2918_enabled:
@ -749,7 +731,7 @@ class RegisterRestServlet(RestServlet):
async def _do_appservice_registration(
self, username, as_token, body, should_issue_refresh_token: bool = False
):
) -> JsonDict:
user_id = await self.registration_handler.appservice_register(
username, as_token
)
@ -766,7 +748,7 @@ class RegisterRestServlet(RestServlet):
params: JsonDict,
is_appservice_ghost: bool = False,
should_issue_refresh_token: bool = False,
):
) -> JsonDict:
"""Complete registration of newly-registered user
Allocates device_id if one was not given; also creates access_token.
@ -810,7 +792,9 @@ class RegisterRestServlet(RestServlet):
return result
async def _do_guest_registration(self, params, address=None):
async def _do_guest_registration(
self, params: JsonDict, address: Optional[str] = None
) -> Tuple[int, JsonDict]:
if not self.hs.config.allow_guest_access:
raise SynapseError(403, "Guest access is disabled")
user_id = await self.registration_handler.register_user(
@ -848,9 +832,7 @@ class RegisterRestServlet(RestServlet):
def _calculate_registration_flows(
# technically `config` has to provide *all* of these interfaces, not just one
config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig],
auth_handler: AuthHandler,
config: HomeServerConfig, auth_handler: AuthHandler
) -> List[List[str]]:
"""Get a suitable flows list for registration
@ -929,7 +911,7 @@ def _calculate_registration_flows(
return flows
def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
EmailRegisterRequestTokenRestServlet(hs).register(http_server)
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
UsernameAvailabilityRestServlet(hs).register(http_server)