Additional type hints for the client REST servlets (part 3). (#10707)

This commit is contained in:
Patrick Cloke 2021-08-31 13:22:29 -04:00 committed by GitHub
parent 78e590d473
commit 287918e2d4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 306 additions and 150 deletions

View file

@ -16,9 +16,11 @@
""" This module contains REST servlets to do with rooms: /rooms/<paths> """
import logging
import re
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple
from urllib import parse as urlparse
from twisted.web.server import Request
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
AuthError,
@ -30,6 +32,7 @@ from synapse.api.errors import (
)
from synapse.api.filtering import Filter
from synapse.events.utils import format_event_for_client_v2
from synapse.http.server import HttpServer
from synapse.http.servlet import (
ResolveRoomIdMixin,
RestServlet,
@ -57,7 +60,7 @@ logger = logging.getLogger(__name__)
class TransactionRestServlet(RestServlet):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.txns = HttpTransactionCache(hs)
@ -65,20 +68,22 @@ class TransactionRestServlet(RestServlet):
class RoomCreateRestServlet(TransactionRestServlet):
# No PATTERN; we have custom dispatch rules here
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._room_creation_handler = hs.get_room_creation_handler()
self.auth = hs.get_auth()
def register(self, http_server):
def register(self, http_server: HttpServer) -> None:
PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server)
def on_PUT(self, request, txn_id):
def on_PUT(
self, request: SynapseRequest, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(request, self.on_POST, request)
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
info, _ = await self._room_creation_handler.create_room(
@ -87,21 +92,21 @@ class RoomCreateRestServlet(TransactionRestServlet):
return 200, info
def get_room_config(self, request):
def get_room_config(self, request: Request) -> JsonDict:
user_supplied_config = parse_json_object_from_request(request)
return user_supplied_config
# TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(TransactionRestServlet):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
def register(self, http_server):
def register(self, http_server: HttpServer) -> None:
# /room/$roomid/state/$eventtype
no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
@ -136,13 +141,19 @@ class RoomStateEventRestServlet(TransactionRestServlet):
self.__class__.__name__,
)
def on_GET_no_state_key(self, request, room_id, event_type):
def on_GET_no_state_key(
self, request: SynapseRequest, room_id: str, event_type: str
) -> Awaitable[Tuple[int, JsonDict]]:
return self.on_GET(request, room_id, event_type, "")
def on_PUT_no_state_key(self, request, room_id, event_type):
def on_PUT_no_state_key(
self, request: SynapseRequest, room_id: str, event_type: str
) -> Awaitable[Tuple[int, JsonDict]]:
return self.on_PUT(request, room_id, event_type, "")
async def on_GET(self, request, room_id, event_type, state_key):
async def on_GET(
self, request: SynapseRequest, room_id: str, event_type: str, state_key: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
format = parse_string(
request, "format", default="content", allowed_values=["content", "event"]
@ -165,7 +176,17 @@ class RoomStateEventRestServlet(TransactionRestServlet):
elif format == "content":
return 200, data.get_dict()["content"]
async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
# Format must be event or content, per the parse_string call above.
raise RuntimeError(f"Unknown format: {format:r}.")
async def on_PUT(
self,
request: SynapseRequest,
room_id: str,
event_type: str,
state_key: str,
txn_id: Optional[str] = None,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if txn_id:
@ -211,27 +232,35 @@ class RoomStateEventRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(TransactionRestServlet):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
def register(self, http_server):
def register(self, http_server: HttpServer) -> None:
# /rooms/$roomid/send/$event_type[/$txn_id]
PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
register_txn_path(self, PATTERNS, http_server, with_get=True)
async def on_POST(self, request, room_id, event_type, txn_id=None):
async def on_POST(
self,
request: SynapseRequest,
room_id: str,
event_type: str,
txn_id: Optional[str] = None,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
event_dict = {
event_dict: JsonDict = {
"type": event_type,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
}
# Twisted will have processed the args by now.
assert request.args is not None
if b"ts" in request.args and requester.app_service:
event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
@ -249,10 +278,14 @@ class RoomSendEventRestServlet(TransactionRestServlet):
set_tag("event_id", event_id)
return 200, {"event_id": event_id}
def on_GET(self, request, room_id, event_type, txn_id):
def on_GET(
self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
) -> Tuple[int, str]:
return 200, "Not implemented"
def on_PUT(self, request, room_id, event_type, txn_id):
def on_PUT(
self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
@ -262,12 +295,12 @@ class RoomSendEventRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
super(ResolveRoomIdMixin, self).__init__(hs) # ensure the Mixin is set up
self.auth = hs.get_auth()
def register(self, http_server):
def register(self, http_server: HttpServer) -> None:
# /join/$room_identifier[/$txn_id]
PATTERNS = "/join/(?P<room_identifier>[^/]*)"
register_txn_path(self, PATTERNS, http_server)
@ -277,7 +310,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
request: SynapseRequest,
room_identifier: str,
txn_id: Optional[str] = None,
):
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
try:
@ -308,7 +341,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
return 200, {"room_id": room_id}
def on_PUT(self, request, room_identifier, txn_id):
def on_PUT(
self, request: SynapseRequest, room_identifier: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
@ -320,12 +355,12 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
class PublicRoomListRestServlet(TransactionRestServlet):
PATTERNS = client_patterns("/publicRooms$", v1=True)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
async def on_GET(self, request):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
server = parse_string(request, "server")
try:
@ -374,7 +409,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
return 200, data
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
server = parse_string(request, "server")
@ -438,13 +473,15 @@ class PublicRoomListRestServlet(TransactionRestServlet):
class RoomMemberListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request, room_id):
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
# TODO support Pagination stream API (limit/tokens)
requester = await self.auth.get_user_by_req(request, allow_guest=True)
handler = self.message_handler
@ -490,12 +527,14 @@ class RoomMemberListRestServlet(RestServlet):
class JoinedRoomMemberListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
async def on_GET(self, request, room_id):
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
users_with_profile = await self.message_handler.get_joined_members(
@ -509,17 +548,21 @@ class JoinedRoomMemberListRestServlet(RestServlet):
class RoomMessageListRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request, room_id):
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = await PaginationConfig.from_request(
self.store, request, default_limit=10
)
# Twisted will have processed the args by now.
assert request.args is not None
as_client_event = b"raw" not in request.args
filter_str = parse_string(request, "filter", encoding="utf-8")
if filter_str:
@ -549,12 +592,14 @@ class RoomMessageListRestServlet(RestServlet):
class RoomStateRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
async def on_GET(self, request, room_id):
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, List[JsonDict]]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
# Get all the current state for this room
events = await self.message_handler.get_state_events(
@ -569,13 +614,15 @@ class RoomStateRestServlet(RestServlet):
class RoomInitialSyncRestServlet(RestServlet):
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request, room_id):
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = await PaginationConfig.from_request(self.store, request)
content = await self.initial_sync_handler.room_initial_sync(
@ -589,14 +636,16 @@ class RoomEventServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$", v1=True
)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
async def on_GET(self, request, room_id, event_id):
async def on_GET(
self, request: SynapseRequest, room_id: str, event_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
try:
event = await self.event_handler.get_event(
@ -610,10 +659,10 @@ class RoomEventServlet(RestServlet):
time_now = self.clock.time_msec()
if event:
event = await self._event_serializer.serialize_event(event, time_now)
return 200, event
event_dict = await self._event_serializer.serialize_event(event, time_now)
return 200, event_dict
return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
class RoomEventContextServlet(RestServlet):
@ -621,14 +670,16 @@ class RoomEventContextServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$", v1=True
)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.clock = hs.get_clock()
self.room_context_handler = hs.get_room_context_handler()
self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
async def on_GET(self, request, room_id, event_id):
async def on_GET(
self, request: SynapseRequest, room_id: str, event_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
limit = parse_integer(request, "limit", default=10)
@ -669,23 +720,27 @@ class RoomEventContextServlet(RestServlet):
class RoomForgetRestServlet(TransactionRestServlet):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
def register(self, http_server):
def register(self, http_server: HttpServer) -> None:
PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
register_txn_path(self, PATTERNS, http_server)
async def on_POST(self, request, room_id, txn_id=None):
async def on_POST(
self, request: SynapseRequest, room_id: str, txn_id: Optional[str] = None
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False)
await self.room_member_handler.forget(user=requester.user, room_id=room_id)
return 200, {}
def on_PUT(self, request, room_id, txn_id):
def on_PUT(
self, request: SynapseRequest, room_id: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
@ -695,12 +750,12 @@ class RoomForgetRestServlet(TransactionRestServlet):
# TODO: Needs unit testing
class RoomMembershipRestServlet(TransactionRestServlet):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
def register(self, http_server):
def register(self, http_server: HttpServer) -> None:
# /rooms/$roomid/[invite|join|leave]
PATTERNS = (
"/rooms/(?P<room_id>[^/]*)/"
@ -708,7 +763,13 @@ class RoomMembershipRestServlet(TransactionRestServlet):
)
register_txn_path(self, PATTERNS, http_server)
async def on_POST(self, request, room_id, membership_action, txn_id=None):
async def on_POST(
self,
request: SynapseRequest,
room_id: str,
membership_action: str,
txn_id: Optional[str] = None,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
if requester.is_guest and membership_action not in {
@ -771,13 +832,15 @@ class RoomMembershipRestServlet(TransactionRestServlet):
return 200, return_value
def _has_3pid_invite_keys(self, content):
def _has_3pid_invite_keys(self, content: JsonDict) -> bool:
for key in {"id_server", "medium", "address"}:
if key not in content:
return False
return True
def on_PUT(self, request, room_id, membership_action, txn_id):
def on_PUT(
self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
@ -786,16 +849,22 @@ class RoomMembershipRestServlet(TransactionRestServlet):
class RoomRedactEventRestServlet(TransactionRestServlet):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
def register(self, http_server):
def register(self, http_server: HttpServer) -> None:
PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
register_txn_path(self, PATTERNS, http_server)
async def on_POST(self, request, room_id, event_id, txn_id=None):
async def on_POST(
self,
request: SynapseRequest,
room_id: str,
event_id: str,
txn_id: Optional[str] = None,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
@ -821,7 +890,9 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
set_tag("event_id", event_id)
return 200, {"event_id": event_id}
def on_PUT(self, request, room_id, event_id, txn_id):
def on_PUT(
self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
@ -846,7 +917,9 @@ class RoomTypingRestServlet(RestServlet):
hs.config.worker.writers.typing == hs.get_instance_name()
)
async def on_PUT(self, request, room_id, user_id):
async def on_PUT(
self, request: SynapseRequest, room_id: str, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if not self._is_typing_writer:
@ -897,7 +970,9 @@ class RoomAliasListServlet(RestServlet):
self.auth = hs.get_auth()
self.directory_handler = hs.get_directory_handler()
async def on_GET(self, request, room_id):
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
alias_list = await self.directory_handler.get_aliases_for_room(
@ -910,12 +985,12 @@ class RoomAliasListServlet(RestServlet):
class SearchRestServlet(RestServlet):
PATTERNS = client_patterns("/search$", v1=True)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.search_handler = hs.get_search_handler()
self.auth = hs.get_auth()
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
@ -929,19 +1004,24 @@ class SearchRestServlet(RestServlet):
class JoinedRoomsRestServlet(RestServlet):
PATTERNS = client_patterns("/joined_rooms$", v1=True)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_GET(self, request):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
room_ids = await self.store.get_rooms_for_user(requester.user.to_string())
return 200, {"joined_rooms": list(room_ids)}
def register_txn_path(servlet, regex_string, http_server, with_get=False):
def register_txn_path(
servlet: RestServlet,
regex_string: str,
http_server: HttpServer,
with_get: bool = False,
) -> None:
"""Registers a transaction-based path.
This registers two paths:
@ -949,28 +1029,37 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
POST regex_string
Args:
regex_string (str): The regex string to register. Must NOT have a
trailing $ as this string will be appended to.
http_server : The http_server to register paths with.
regex_string: The regex string to register. Must NOT have a
trailing $ as this string will be appended to.
http_server: The http_server to register paths with.
with_get: True to also register respective GET paths for the PUTs.
"""
on_POST = getattr(servlet, "on_POST", None)
on_PUT = getattr(servlet, "on_PUT", None)
if on_POST is None or on_PUT is None:
raise RuntimeError("on_POST and on_PUT must exist when using register_txn_path")
http_server.register_paths(
"POST",
client_patterns(regex_string + "$", v1=True),
servlet.on_POST,
on_POST,
servlet.__class__.__name__,
)
http_server.register_paths(
"PUT",
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
servlet.on_PUT,
on_PUT,
servlet.__class__.__name__,
)
on_GET = getattr(servlet, "on_GET", None)
if with_get:
if on_GET is None:
raise RuntimeError(
"register_txn_path called with with_get = True, but no on_GET method exists"
)
http_server.register_paths(
"GET",
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
servlet.on_GET,
on_GET,
servlet.__class__.__name__,
)
@ -1120,7 +1209,9 @@ class RoomSummaryRestServlet(ResolveRoomIdMixin, RestServlet):
)
def register_servlets(hs: "HomeServer", http_server, is_worker=False):
def register_servlets(
hs: "HomeServer", http_server: HttpServer, is_worker: bool = False
) -> None:
RoomStateEventRestServlet(hs).register(http_server)
RoomMemberListRestServlet(hs).register(http_server)
JoinedRoomMemberListRestServlet(hs).register(http_server)
@ -1148,5 +1239,5 @@ def register_servlets(hs: "HomeServer", http_server, is_worker=False):
RoomForgetRestServlet(hs).register(http_server)
def register_deprecated_servlets(hs, http_server):
def register_deprecated_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RoomInitialSyncRestServlet(hs).register(http_server)