Additional type hints for the client REST servlets (part 3). (#10707)

This commit is contained in:
Patrick Cloke 2021-08-31 13:22:29 -04:00 committed by GitHub
parent 78e590d473
commit 287918e2d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 306 additions and 150 deletions

1
changelog.d/10707.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints to REST servlets.

View File

@ -13,12 +13,19 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import AuthError, NotFoundError, SynapseError from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from ._base import client_patterns from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -32,13 +39,15 @@ class AccountDataServlet(RestServlet):
"/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)" "/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)"
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.handler = hs.get_account_data_handler() self.handler = hs.get_account_data_handler()
async def on_PUT(self, request, user_id, account_data_type): async def on_PUT(
self, request: SynapseRequest, user_id: str, account_data_type: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.") raise AuthError(403, "Cannot add account data for other users.")
@ -49,7 +58,9 @@ class AccountDataServlet(RestServlet):
return 200, {} return 200, {}
async def on_GET(self, request, user_id, account_data_type): async def on_GET(
self, request: SynapseRequest, user_id: str, account_data_type: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.") raise AuthError(403, "Cannot get account data for other users.")
@ -76,13 +87,19 @@ class RoomAccountDataServlet(RestServlet):
"/account_data/(?P<account_data_type>[^/]*)" "/account_data/(?P<account_data_type>[^/]*)"
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.handler = hs.get_account_data_handler() self.handler = hs.get_account_data_handler()
async def on_PUT(self, request, user_id, room_id, account_data_type): async def on_PUT(
self,
request: SynapseRequest,
user_id: str,
room_id: str,
account_data_type: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.") raise AuthError(403, "Cannot add account data for other users.")
@ -102,7 +119,13 @@ class RoomAccountDataServlet(RestServlet):
return 200, {} return 200, {}
async def on_GET(self, request, user_id, room_id, account_data_type): async def on_GET(
self,
request: SynapseRequest,
user_id: str,
room_id: str,
account_data_type: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.") raise AuthError(403, "Cannot get account data for other users.")
@ -117,6 +140,6 @@ class RoomAccountDataServlet(RestServlet):
return 200, event return 200, event
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
AccountDataServlet(hs).register(http_server) AccountDataServlet(hs).register(http_server)
RoomAccountDataServlet(hs).register(http_server) RoomAccountDataServlet(hs).register(http_server)

View File

@ -156,7 +156,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
group_id: str, group_id: str,
category_id: Optional[str], category_id: Optional[str],
room_id: str, room_id: str,
): ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -188,7 +188,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
@_validate_group_id @_validate_group_id
async def on_DELETE( async def on_DELETE(
self, request: SynapseRequest, group_id: str, category_id: str, room_id: str self, request: SynapseRequest, group_id: str, category_id: str, room_id: str
): ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -451,7 +451,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
@_validate_group_id @_validate_group_id
async def on_DELETE( async def on_DELETE(
self, request: SynapseRequest, group_id: str, role_id: str, user_id: str self, request: SynapseRequest, group_id: str, role_id: str, user_id: str
): ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -674,7 +674,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
@_validate_group_id @_validate_group_id
async def on_PUT( async def on_PUT(
self, request: SynapseRequest, group_id: str, room_id: str, config_key: str self, request: SynapseRequest, group_id: str, room_id: str, config_key: str
): ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -706,7 +706,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
@_validate_group_id @_validate_group_id
async def on_PUT( async def on_PUT(
self, request: SynapseRequest, group_id, user_id self, request: SynapseRequest, group_id: str, user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
@ -738,7 +738,7 @@ class GroupAdminUsersKickServlet(RestServlet):
@_validate_group_id @_validate_group_id
async def on_PUT( async def on_PUT(
self, request: SynapseRequest, group_id, user_id self, request: SynapseRequest, group_id: str, user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()

View File

@ -13,13 +13,20 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import ReadReceiptEventFields from synapse.api.constants import ReadReceiptEventFields
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from ._base import client_patterns from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -30,14 +37,16 @@ class ReceiptRestServlet(RestServlet):
"/(?P<event_id>[^/]*)$" "/(?P<event_id>[^/]*)$"
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.receipts_handler = hs.get_receipts_handler() self.receipts_handler = hs.get_receipts_handler()
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
async def on_POST(self, request, room_id, receipt_type, event_id): async def on_POST(
self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if receipt_type != "m.read": if receipt_type != "m.read":
@ -67,5 +76,5 @@ class ReceiptRestServlet(RestServlet):
return 200, {} return 200, {}
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReceiptRestServlet(hs).register(http_server) ReceiptRestServlet(hs).register(http_server)

View File

@ -14,7 +14,9 @@
# limitations under the License. # limitations under the License.
import logging import logging
import random import random
from typing import List, Union from typing import TYPE_CHECKING, List, Optional, Tuple
from twisted.web.server import Request
import synapse import synapse
import synapse.api.auth import synapse.api.auth
@ -29,15 +31,13 @@ from synapse.api.errors import (
) )
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.config.captcha import CaptchaConfig
from synapse.config.consent import ConsentConfig
from synapse.config.emailconfig import ThreepidBehaviour from synapse.config.emailconfig import ThreepidBehaviour
from synapse.config.homeserver import HomeServerConfig
from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.config.registration import RegistrationConfig
from synapse.config.server import is_threepid_reserved from synapse.config.server import is_threepid_reserved
from synapse.handlers.auth import AuthHandler from synapse.handlers.auth import AuthHandler
from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http.server import finish_request, respond_with_html from synapse.http.server import HttpServer, finish_request, respond_with_html
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
assert_params_in_dict, assert_params_in_dict,
@ -45,6 +45,7 @@ from synapse.http.servlet import (
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
) )
from synapse.http.site import SynapseRequest
from synapse.metrics import threepid_send_requests from synapse.metrics import threepid_send_requests
from synapse.push.mailer import Mailer from synapse.push.mailer import Mailer
from synapse.types import JsonDict from synapse.types import JsonDict
@ -59,17 +60,16 @@ from synapse.util.threepids import (
from ._base import client_patterns, interactive_auth_handler from ._base import client_patterns, interactive_auth_handler
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class EmailRegisterRequestTokenRestServlet(RestServlet): class EmailRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/register/email/requestToken$") PATTERNS = client_patterns("/register/email/requestToken$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.identity_handler = hs.get_identity_handler() self.identity_handler = hs.get_identity_handler()
@ -83,7 +83,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
template_text=self.config.email_registration_template_text, template_text=self.config.email_registration_template_text,
) )
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.hs.config.local_threepid_handling_disabled_due_to_email_config: if self.hs.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning( logger.warning(
@ -171,16 +171,12 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
class MsisdnRegisterRequestTokenRestServlet(RestServlet): class MsisdnRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/register/msisdn/requestToken$") PATTERNS = client_patterns("/register/msisdn/requestToken$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.identity_handler = hs.get_identity_handler() self.identity_handler = hs.get_identity_handler()
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_dict( assert_params_in_dict(
@ -255,11 +251,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
"/registration/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True "/registration/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
@ -272,7 +264,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
self.config.email_registration_template_failure_html self.config.email_registration_template_failure_html
) )
async def on_GET(self, request, medium): async def on_GET(self, request: Request, medium: str) -> None:
if medium != "email": if medium != "email":
raise SynapseError( raise SynapseError(
400, "This medium is currently not supported for registration" 400, "This medium is currently not supported for registration"
@ -326,11 +318,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
class UsernameAvailabilityRestServlet(RestServlet): class UsernameAvailabilityRestServlet(RestServlet):
PATTERNS = client_patterns("/register/available") PATTERNS = client_patterns("/register/available")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.registration_handler = hs.get_registration_handler() self.registration_handler = hs.get_registration_handler()
@ -350,7 +338,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
), ),
) )
async def on_GET(self, request): async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
if not self.hs.config.enable_registration: if not self.hs.config.enable_registration:
raise SynapseError( raise SynapseError(
403, "Registration has been disabled", errcode=Codes.FORBIDDEN 403, "Registration has been disabled", errcode=Codes.FORBIDDEN
@ -419,11 +407,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
class RegisterRestServlet(RestServlet): class RegisterRestServlet(RestServlet):
PATTERNS = client_patterns("/register$") PATTERNS = client_patterns("/register$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super().__init__() super().__init__()
self.hs = hs self.hs = hs
@ -445,23 +429,21 @@ class RegisterRestServlet(RestServlet):
) )
@interactive_auth_handler @interactive_auth_handler
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
client_addr = request.getClientIP() client_addr = request.getClientIP()
await self.ratelimiter.ratelimit(None, client_addr, update=False) await self.ratelimiter.ratelimit(None, client_addr, update=False)
kind = b"user" kind = parse_string(request, "kind", default="user")
if b"kind" in request.args:
kind = request.args[b"kind"][0]
if kind == b"guest": if kind == "guest":
ret = await self._do_guest_registration(body, address=client_addr) ret = await self._do_guest_registration(body, address=client_addr)
return ret return ret
elif kind != b"user": elif kind != "user":
raise UnrecognizedRequestError( raise UnrecognizedRequestError(
"Do not understand membership kind: %s" % (kind.decode("utf8"),) f"Do not understand membership kind: {kind}",
) )
if self._msc2918_enabled: if self._msc2918_enabled:
@ -749,7 +731,7 @@ class RegisterRestServlet(RestServlet):
async def _do_appservice_registration( async def _do_appservice_registration(
self, username, as_token, body, should_issue_refresh_token: bool = False self, username, as_token, body, should_issue_refresh_token: bool = False
): ) -> JsonDict:
user_id = await self.registration_handler.appservice_register( user_id = await self.registration_handler.appservice_register(
username, as_token username, as_token
) )
@ -766,7 +748,7 @@ class RegisterRestServlet(RestServlet):
params: JsonDict, params: JsonDict,
is_appservice_ghost: bool = False, is_appservice_ghost: bool = False,
should_issue_refresh_token: bool = False, should_issue_refresh_token: bool = False,
): ) -> JsonDict:
"""Complete registration of newly-registered user """Complete registration of newly-registered user
Allocates device_id if one was not given; also creates access_token. Allocates device_id if one was not given; also creates access_token.
@ -810,7 +792,9 @@ class RegisterRestServlet(RestServlet):
return result return result
async def _do_guest_registration(self, params, address=None): async def _do_guest_registration(
self, params: JsonDict, address: Optional[str] = None
) -> Tuple[int, JsonDict]:
if not self.hs.config.allow_guest_access: if not self.hs.config.allow_guest_access:
raise SynapseError(403, "Guest access is disabled") raise SynapseError(403, "Guest access is disabled")
user_id = await self.registration_handler.register_user( user_id = await self.registration_handler.register_user(
@ -848,9 +832,7 @@ class RegisterRestServlet(RestServlet):
def _calculate_registration_flows( def _calculate_registration_flows(
# technically `config` has to provide *all* of these interfaces, not just one config: HomeServerConfig, auth_handler: AuthHandler
config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig],
auth_handler: AuthHandler,
) -> List[List[str]]: ) -> List[List[str]]:
"""Get a suitable flows list for registration """Get a suitable flows list for registration
@ -929,7 +911,7 @@ def _calculate_registration_flows(
return flows return flows
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
EmailRegisterRequestTokenRestServlet(hs).register(http_server) EmailRegisterRequestTokenRestServlet(hs).register(http_server)
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server) MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
UsernameAvailabilityRestServlet(hs).register(http_server) UsernameAvailabilityRestServlet(hs).register(http_server)

View File

@ -19,25 +19,32 @@ any time to reflect changes in the MSC.
""" """
import logging import logging
from typing import TYPE_CHECKING, Awaitable, Optional, Tuple
from synapse.api.constants import EventTypes, RelationTypes from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.errors import ShadowBanError, SynapseError from synapse.api.errors import ShadowBanError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
parse_integer, parse_integer,
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
) )
from synapse.http.site import SynapseRequest
from synapse.rest.client.transactions import HttpTransactionCache from synapse.rest.client.transactions import HttpTransactionCache
from synapse.storage.relations import ( from synapse.storage.relations import (
AggregationPaginationToken, AggregationPaginationToken,
PaginationChunk, PaginationChunk,
RelationPaginationToken, RelationPaginationToken,
) )
from synapse.types import JsonDict
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from ._base import client_patterns from ._base import client_patterns
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -59,13 +66,13 @@ class RelationSendServlet(RestServlet):
"/(?P<parent_id>[^/]*)/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)" "/(?P<parent_id>[^/]*)/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)"
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.txns = HttpTransactionCache(hs) self.txns = HttpTransactionCache(hs)
def register(self, http_server): def register(self, http_server: HttpServer) -> None:
http_server.register_paths( http_server.register_paths(
"POST", "POST",
client_patterns(self.PATTERN + "$", releases=()), client_patterns(self.PATTERN + "$", releases=()),
@ -79,14 +86,35 @@ class RelationSendServlet(RestServlet):
self.__class__.__name__, self.__class__.__name__,
) )
def on_PUT(self, request, *args, **kwargs): def on_PUT(
self,
request: SynapseRequest,
room_id: str,
parent_id: str,
relation_type: str,
event_type: str,
txn_id: Optional[str] = None,
) -> Awaitable[Tuple[int, JsonDict]]:
return self.txns.fetch_or_execute_request( return self.txns.fetch_or_execute_request(
request, self.on_PUT_or_POST, request, *args, **kwargs request,
self.on_PUT_or_POST,
request,
room_id,
parent_id,
relation_type,
event_type,
txn_id,
) )
async def on_PUT_or_POST( async def on_PUT_or_POST(
self, request, room_id, parent_id, relation_type, event_type, txn_id=None self,
): request: SynapseRequest,
room_id: str,
parent_id: str,
relation_type: str,
event_type: str,
txn_id: Optional[str] = None,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
if event_type == EventTypes.Member: if event_type == EventTypes.Member:
@ -136,7 +164,7 @@ class RelationPaginationServlet(RestServlet):
releases=(), releases=(),
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -145,8 +173,13 @@ class RelationPaginationServlet(RestServlet):
self.event_handler = hs.get_event_handler() self.event_handler = hs.get_event_handler()
async def on_GET( async def on_GET(
self, request, room_id, parent_id, relation_type=None, event_type=None self,
): request: SynapseRequest,
room_id: str,
parent_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable( await self.auth.check_user_in_room_or_world_readable(
@ -156,6 +189,8 @@ class RelationPaginationServlet(RestServlet):
# This gets the original event and checks that a) the event exists and # This gets the original event and checks that a) the event exists and
# b) the user is allowed to view it. # b) the user is allowed to view it.
event = await self.event_handler.get_event(requester.user, room_id, parent_id) event = await self.event_handler.get_event(requester.user, room_id, parent_id)
if event is None:
raise SynapseError(404, "Unknown parent event.")
limit = parse_integer(request, "limit", default=5) limit = parse_integer(request, "limit", default=5)
from_token_str = parse_string(request, "from") from_token_str = parse_string(request, "from")
@ -233,15 +268,20 @@ class RelationAggregationPaginationServlet(RestServlet):
releases=(), releases=(),
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.event_handler = hs.get_event_handler() self.event_handler = hs.get_event_handler()
async def on_GET( async def on_GET(
self, request, room_id, parent_id, relation_type=None, event_type=None self,
): request: SynapseRequest,
room_id: str,
parent_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable( await self.auth.check_user_in_room_or_world_readable(
@ -253,6 +293,8 @@ class RelationAggregationPaginationServlet(RestServlet):
# This checks that a) the event exists and b) the user is allowed to # This checks that a) the event exists and b) the user is allowed to
# view it. # view it.
event = await self.event_handler.get_event(requester.user, room_id, parent_id) event = await self.event_handler.get_event(requester.user, room_id, parent_id)
if event is None:
raise SynapseError(404, "Unknown parent event.")
if relation_type not in (RelationTypes.ANNOTATION, None): if relation_type not in (RelationTypes.ANNOTATION, None):
raise SynapseError(400, "Relation type must be 'annotation'") raise SynapseError(400, "Relation type must be 'annotation'")
@ -315,7 +357,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
releases=(), releases=(),
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -323,7 +365,15 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler() self.event_handler = hs.get_event_handler()
async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): async def on_GET(
self,
request: SynapseRequest,
room_id: str,
parent_id: str,
relation_type: str,
event_type: str,
key: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable( await self.auth.check_user_in_room_or_world_readable(
@ -374,7 +424,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
return 200, return_value return 200, return_value
def register_servlets(hs, http_server): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RelationSendServlet(hs).register(http_server) RelationSendServlet(hs).register(http_server)
RelationPaginationServlet(hs).register(http_server) RelationPaginationServlet(hs).register(http_server)
RelationAggregationPaginationServlet(hs).register(http_server) RelationAggregationPaginationServlet(hs).register(http_server)

View File

@ -16,9 +16,11 @@
""" This module contains REST servlets to do with rooms: /rooms/<paths> """ """ This module contains REST servlets to do with rooms: /rooms/<paths> """
import logging import logging
import re import re
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple
from urllib import parse as urlparse from urllib import parse as urlparse
from twisted.web.server import Request
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
@ -30,6 +32,7 @@ from synapse.api.errors import (
) )
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.events.utils import format_event_for_client_v2 from synapse.events.utils import format_event_for_client_v2
from synapse.http.server import HttpServer
from synapse.http.servlet import ( from synapse.http.servlet import (
ResolveRoomIdMixin, ResolveRoomIdMixin,
RestServlet, RestServlet,
@ -57,7 +60,7 @@ logger = logging.getLogger(__name__)
class TransactionRestServlet(RestServlet): class TransactionRestServlet(RestServlet):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.txns = HttpTransactionCache(hs) self.txns = HttpTransactionCache(hs)
@ -65,20 +68,22 @@ class TransactionRestServlet(RestServlet):
class RoomCreateRestServlet(TransactionRestServlet): class RoomCreateRestServlet(TransactionRestServlet):
# No PATTERN; we have custom dispatch rules here # No PATTERN; we have custom dispatch rules here
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self._room_creation_handler = hs.get_room_creation_handler() self._room_creation_handler = hs.get_room_creation_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server: HttpServer) -> None:
PATTERNS = "/createRoom" PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
def on_PUT(self, request, txn_id): def on_PUT(
self, request: SynapseRequest, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id) set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(request, self.on_POST, request) return self.txns.fetch_or_execute_request(request, self.on_POST, request)
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
info, _ = await self._room_creation_handler.create_room( info, _ = await self._room_creation_handler.create_room(
@ -87,21 +92,21 @@ class RoomCreateRestServlet(TransactionRestServlet):
return 200, info return 200, info
def get_room_config(self, request): def get_room_config(self, request: Request) -> JsonDict:
user_supplied_config = parse_json_object_from_request(request) user_supplied_config = parse_json_object_from_request(request)
return user_supplied_config return user_supplied_config
# TODO: Needs unit testing for generic events # TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(TransactionRestServlet): class RoomStateEventRestServlet(TransactionRestServlet):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler() self.room_member_handler = hs.get_room_member_handler()
self.message_handler = hs.get_message_handler() self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server: HttpServer) -> None:
# /room/$roomid/state/$eventtype # /room/$roomid/state/$eventtype
no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$" no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
@ -136,13 +141,19 @@ class RoomStateEventRestServlet(TransactionRestServlet):
self.__class__.__name__, self.__class__.__name__,
) )
def on_GET_no_state_key(self, request, room_id, event_type): def on_GET_no_state_key(
self, request: SynapseRequest, room_id: str, event_type: str
) -> Awaitable[Tuple[int, JsonDict]]:
return self.on_GET(request, room_id, event_type, "") return self.on_GET(request, room_id, event_type, "")
def on_PUT_no_state_key(self, request, room_id, event_type): def on_PUT_no_state_key(
self, request: SynapseRequest, room_id: str, event_type: str
) -> Awaitable[Tuple[int, JsonDict]]:
return self.on_PUT(request, room_id, event_type, "") return self.on_PUT(request, room_id, event_type, "")
async def on_GET(self, request, room_id, event_type, state_key): async def on_GET(
self, request: SynapseRequest, room_id: str, event_type: str, state_key: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
format = parse_string( format = parse_string(
request, "format", default="content", allowed_values=["content", "event"] request, "format", default="content", allowed_values=["content", "event"]
@ -165,7 +176,17 @@ class RoomStateEventRestServlet(TransactionRestServlet):
elif format == "content": elif format == "content":
return 200, data.get_dict()["content"] return 200, data.get_dict()["content"]
async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): # Format must be event or content, per the parse_string call above.
raise RuntimeError(f"Unknown format: {format:r}.")
async def on_PUT(
self,
request: SynapseRequest,
room_id: str,
event_type: str,
state_key: str,
txn_id: Optional[str] = None,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if txn_id: if txn_id:
@ -211,27 +232,35 @@ class RoomStateEventRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for generic events + feedback # TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(TransactionRestServlet): class RoomSendEventRestServlet(TransactionRestServlet):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server: HttpServer) -> None:
# /rooms/$roomid/send/$event_type[/$txn_id] # /rooms/$roomid/send/$event_type[/$txn_id]
PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)" PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
register_txn_path(self, PATTERNS, http_server, with_get=True) register_txn_path(self, PATTERNS, http_server, with_get=True)
async def on_POST(self, request, room_id, event_type, txn_id=None): async def on_POST(
self,
request: SynapseRequest,
room_id: str,
event_type: str,
txn_id: Optional[str] = None,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
event_dict = { event_dict: JsonDict = {
"type": event_type, "type": event_type,
"content": content, "content": content,
"room_id": room_id, "room_id": room_id,
"sender": requester.user.to_string(), "sender": requester.user.to_string(),
} }
# Twisted will have processed the args by now.
assert request.args is not None
if b"ts" in request.args and requester.app_service: if b"ts" in request.args and requester.app_service:
event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
@ -249,10 +278,14 @@ class RoomSendEventRestServlet(TransactionRestServlet):
set_tag("event_id", event_id) set_tag("event_id", event_id)
return 200, {"event_id": event_id} return 200, {"event_id": event_id}
def on_GET(self, request, room_id, event_type, txn_id): def on_GET(
self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
) -> Tuple[int, str]:
return 200, "Not implemented" return 200, "Not implemented"
def on_PUT(self, request, room_id, event_type, txn_id): def on_PUT(
self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id) set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request( return self.txns.fetch_or_execute_request(
@ -262,12 +295,12 @@ class RoomSendEventRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for room ID + alias joins # TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
super(ResolveRoomIdMixin, self).__init__(hs) # ensure the Mixin is set up super(ResolveRoomIdMixin, self).__init__(hs) # ensure the Mixin is set up
self.auth = hs.get_auth() self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server: HttpServer) -> None:
# /join/$room_identifier[/$txn_id] # /join/$room_identifier[/$txn_id]
PATTERNS = "/join/(?P<room_identifier>[^/]*)" PATTERNS = "/join/(?P<room_identifier>[^/]*)"
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
@ -277,7 +310,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
request: SynapseRequest, request: SynapseRequest,
room_identifier: str, room_identifier: str,
txn_id: Optional[str] = None, txn_id: Optional[str] = None,
): ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
try: try:
@ -308,7 +341,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
return 200, {"room_id": room_id} return 200, {"room_id": room_id}
def on_PUT(self, request, room_identifier, txn_id): def on_PUT(
self, request: SynapseRequest, room_identifier: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id) set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request( return self.txns.fetch_or_execute_request(
@ -320,12 +355,12 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
class PublicRoomListRestServlet(TransactionRestServlet): class PublicRoomListRestServlet(TransactionRestServlet):
PATTERNS = client_patterns("/publicRooms$", v1=True) PATTERNS = client_patterns("/publicRooms$", v1=True)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_GET(self, request): async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
server = parse_string(request, "server") server = parse_string(request, "server")
try: try:
@ -374,7 +409,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
return 200, data return 200, data
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.get_user_by_req(request, allow_guest=True)
server = parse_string(request, "server") server = parse_string(request, "server")
@ -438,13 +473,15 @@ class PublicRoomListRestServlet(TransactionRestServlet):
class RoomMemberListRestServlet(RestServlet): class RoomMemberListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True) PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.message_handler = hs.get_message_handler() self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def on_GET(self, request, room_id): async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
# TODO support Pagination stream API (limit/tokens) # TODO support Pagination stream API (limit/tokens)
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
handler = self.message_handler handler = self.message_handler
@ -490,12 +527,14 @@ class RoomMemberListRestServlet(RestServlet):
class JoinedRoomMemberListRestServlet(RestServlet): class JoinedRoomMemberListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True) PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.message_handler = hs.get_message_handler() self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_GET(self, request, room_id): async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
users_with_profile = await self.message_handler.get_joined_members( users_with_profile = await self.message_handler.get_joined_members(
@ -509,17 +548,21 @@ class JoinedRoomMemberListRestServlet(RestServlet):
class RoomMessageListRestServlet(RestServlet): class RoomMessageListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True) PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.pagination_handler = hs.get_pagination_handler() self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def on_GET(self, request, room_id): async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = await PaginationConfig.from_request( pagination_config = await PaginationConfig.from_request(
self.store, request, default_limit=10 self.store, request, default_limit=10
) )
# Twisted will have processed the args by now.
assert request.args is not None
as_client_event = b"raw" not in request.args as_client_event = b"raw" not in request.args
filter_str = parse_string(request, "filter", encoding="utf-8") filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str: if filter_str:
@ -549,12 +592,14 @@ class RoomMessageListRestServlet(RestServlet):
class RoomStateRestServlet(RestServlet): class RoomStateRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True) PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.message_handler = hs.get_message_handler() self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_GET(self, request, room_id): async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, List[JsonDict]]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
# Get all the current state for this room # Get all the current state for this room
events = await self.message_handler.get_state_events( events = await self.message_handler.get_state_events(
@ -569,13 +614,15 @@ class RoomStateRestServlet(RestServlet):
class RoomInitialSyncRestServlet(RestServlet): class RoomInitialSyncRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True) PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.initial_sync_handler = hs.get_initial_sync_handler() self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def on_GET(self, request, room_id): async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = await PaginationConfig.from_request(self.store, request) pagination_config = await PaginationConfig.from_request(self.store, request)
content = await self.initial_sync_handler.room_initial_sync( content = await self.initial_sync_handler.room_initial_sync(
@ -589,14 +636,16 @@ class RoomEventServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$", v1=True "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$", v1=True
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler() self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_GET(self, request, room_id, event_id): async def on_GET(
self, request: SynapseRequest, room_id: str, event_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
try: try:
event = await self.event_handler.get_event( event = await self.event_handler.get_event(
@ -610,10 +659,10 @@ class RoomEventServlet(RestServlet):
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
if event: if event:
event = await self._event_serializer.serialize_event(event, time_now) event_dict = await self._event_serializer.serialize_event(event, time_now)
return 200, event return 200, event_dict
return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
class RoomEventContextServlet(RestServlet): class RoomEventContextServlet(RestServlet):
@ -621,14 +670,16 @@ class RoomEventContextServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$", v1=True "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$", v1=True
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.room_context_handler = hs.get_room_context_handler() self.room_context_handler = hs.get_room_context_handler()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_GET(self, request, room_id, event_id): async def on_GET(
self, request: SynapseRequest, room_id: str, event_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
limit = parse_integer(request, "limit", default=10) limit = parse_integer(request, "limit", default=10)
@ -669,23 +720,27 @@ class RoomEventContextServlet(RestServlet):
class RoomForgetRestServlet(TransactionRestServlet): class RoomForgetRestServlet(TransactionRestServlet):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.room_member_handler = hs.get_room_member_handler() self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server: HttpServer) -> None:
PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget" PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
async def on_POST(self, request, room_id, txn_id=None): async def on_POST(
self, request: SynapseRequest, room_id: str, txn_id: Optional[str] = None
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False) requester = await self.auth.get_user_by_req(request, allow_guest=False)
await self.room_member_handler.forget(user=requester.user, room_id=room_id) await self.room_member_handler.forget(user=requester.user, room_id=room_id)
return 200, {} return 200, {}
def on_PUT(self, request, room_id, txn_id): def on_PUT(
self, request: SynapseRequest, room_id: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id) set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request( return self.txns.fetch_or_execute_request(
@ -695,12 +750,12 @@ class RoomForgetRestServlet(TransactionRestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class RoomMembershipRestServlet(TransactionRestServlet): class RoomMembershipRestServlet(TransactionRestServlet):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.room_member_handler = hs.get_room_member_handler() self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server: HttpServer) -> None:
# /rooms/$roomid/[invite|join|leave] # /rooms/$roomid/[invite|join|leave]
PATTERNS = ( PATTERNS = (
"/rooms/(?P<room_id>[^/]*)/" "/rooms/(?P<room_id>[^/]*)/"
@ -708,7 +763,13 @@ class RoomMembershipRestServlet(TransactionRestServlet):
) )
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
async def on_POST(self, request, room_id, membership_action, txn_id=None): async def on_POST(
self,
request: SynapseRequest,
room_id: str,
membership_action: str,
txn_id: Optional[str] = None,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
if requester.is_guest and membership_action not in { if requester.is_guest and membership_action not in {
@ -771,13 +832,15 @@ class RoomMembershipRestServlet(TransactionRestServlet):
return 200, return_value return 200, return_value
def _has_3pid_invite_keys(self, content): def _has_3pid_invite_keys(self, content: JsonDict) -> bool:
for key in {"id_server", "medium", "address"}: for key in {"id_server", "medium", "address"}:
if key not in content: if key not in content:
return False return False
return True return True
def on_PUT(self, request, room_id, membership_action, txn_id): def on_PUT(
self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id) set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request( return self.txns.fetch_or_execute_request(
@ -786,16 +849,22 @@ class RoomMembershipRestServlet(TransactionRestServlet):
class RoomRedactEventRestServlet(TransactionRestServlet): class RoomRedactEventRestServlet(TransactionRestServlet):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server: HttpServer) -> None:
PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)" PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
async def on_POST(self, request, room_id, event_id, txn_id=None): async def on_POST(
self,
request: SynapseRequest,
room_id: str,
event_id: str,
txn_id: Optional[str] = None,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -821,7 +890,9 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
set_tag("event_id", event_id) set_tag("event_id", event_id)
return 200, {"event_id": event_id} return 200, {"event_id": event_id}
def on_PUT(self, request, room_id, event_id, txn_id): def on_PUT(
self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id) set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request( return self.txns.fetch_or_execute_request(
@ -846,7 +917,9 @@ class RoomTypingRestServlet(RestServlet):
hs.config.worker.writers.typing == hs.get_instance_name() hs.config.worker.writers.typing == hs.get_instance_name()
) )
async def on_PUT(self, request, room_id, user_id): async def on_PUT(
self, request: SynapseRequest, room_id: str, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if not self._is_typing_writer: if not self._is_typing_writer:
@ -897,7 +970,9 @@ class RoomAliasListServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.directory_handler = hs.get_directory_handler() self.directory_handler = hs.get_directory_handler()
async def on_GET(self, request, room_id): async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
alias_list = await self.directory_handler.get_aliases_for_room( alias_list = await self.directory_handler.get_aliases_for_room(
@ -910,12 +985,12 @@ class RoomAliasListServlet(RestServlet):
class SearchRestServlet(RestServlet): class SearchRestServlet(RestServlet):
PATTERNS = client_patterns("/search$", v1=True) PATTERNS = client_patterns("/search$", v1=True)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.search_handler = hs.get_search_handler() self.search_handler = hs.get_search_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_POST(self, request): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -929,19 +1004,24 @@ class SearchRestServlet(RestServlet):
class JoinedRoomsRestServlet(RestServlet): class JoinedRoomsRestServlet(RestServlet):
PATTERNS = client_patterns("/joined_rooms$", v1=True) PATTERNS = client_patterns("/joined_rooms$", v1=True)
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_GET(self, request): async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
room_ids = await self.store.get_rooms_for_user(requester.user.to_string()) room_ids = await self.store.get_rooms_for_user(requester.user.to_string())
return 200, {"joined_rooms": list(room_ids)} return 200, {"joined_rooms": list(room_ids)}
def register_txn_path(servlet, regex_string, http_server, with_get=False): def register_txn_path(
servlet: RestServlet,
regex_string: str,
http_server: HttpServer,
with_get: bool = False,
) -> None:
"""Registers a transaction-based path. """Registers a transaction-based path.
This registers two paths: This registers two paths:
@ -949,28 +1029,37 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
POST regex_string POST regex_string
Args: Args:
regex_string (str): The regex string to register. Must NOT have a regex_string: The regex string to register. Must NOT have a
trailing $ as this string will be appended to. trailing $ as this string will be appended to.
http_server: The http_server to register paths with. http_server: The http_server to register paths with.
with_get: True to also register respective GET paths for the PUTs. with_get: True to also register respective GET paths for the PUTs.
""" """
on_POST = getattr(servlet, "on_POST", None)
on_PUT = getattr(servlet, "on_PUT", None)
if on_POST is None or on_PUT is None:
raise RuntimeError("on_POST and on_PUT must exist when using register_txn_path")
http_server.register_paths( http_server.register_paths(
"POST", "POST",
client_patterns(regex_string + "$", v1=True), client_patterns(regex_string + "$", v1=True),
servlet.on_POST, on_POST,
servlet.__class__.__name__, servlet.__class__.__name__,
) )
http_server.register_paths( http_server.register_paths(
"PUT", "PUT",
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True), client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
servlet.on_PUT, on_PUT,
servlet.__class__.__name__, servlet.__class__.__name__,
) )
on_GET = getattr(servlet, "on_GET", None)
if with_get: if with_get:
if on_GET is None:
raise RuntimeError(
"register_txn_path called with with_get = True, but no on_GET method exists"
)
http_server.register_paths( http_server.register_paths(
"GET", "GET",
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True), client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
servlet.on_GET, on_GET,
servlet.__class__.__name__, servlet.__class__.__name__,
) )
@ -1120,7 +1209,9 @@ class RoomSummaryRestServlet(ResolveRoomIdMixin, RestServlet):
) )
def register_servlets(hs: "HomeServer", http_server, is_worker=False): def register_servlets(
hs: "HomeServer", http_server: HttpServer, is_worker: bool = False
) -> None:
RoomStateEventRestServlet(hs).register(http_server) RoomStateEventRestServlet(hs).register(http_server)
RoomMemberListRestServlet(hs).register(http_server) RoomMemberListRestServlet(hs).register(http_server)
JoinedRoomMemberListRestServlet(hs).register(http_server) JoinedRoomMemberListRestServlet(hs).register(http_server)
@ -1148,5 +1239,5 @@ def register_servlets(hs: "HomeServer", http_server, is_worker=False):
RoomForgetRestServlet(hs).register(http_server) RoomForgetRestServlet(hs).register(http_server)
def register_deprecated_servlets(hs, http_server): def register_deprecated_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomInitialSyncRestServlet(hs).register(http_server) RoomInitialSyncRestServlet(hs).register(http_server)