Add a return type to parse_string. (#10438)

And set the required attribute in a few places which will error if
a parameter is not provided.
This commit is contained in:
Patrick Cloke 2021-07-21 09:47:56 -04:00 committed by GitHub
parent 2d89c66b88
commit 5db118626b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 86 additions and 45 deletions

1
changelog.d/10438.misc Normal file
View File

@ -0,0 +1 @@
Improve servlet type hints.

View File

@ -172,6 +172,42 @@ def parse_bytes_from_args(
return default return default
@overload
def parse_string(
request: Request,
name: str,
default: str,
*,
allowed_values: Optional[Iterable[str]] = None,
encoding: str = "ascii",
) -> str:
...
@overload
def parse_string(
request: Request,
name: str,
*,
required: Literal[True],
allowed_values: Optional[Iterable[str]] = None,
encoding: str = "ascii",
) -> str:
...
@overload
def parse_string(
request: Request,
name: str,
*,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
encoding: str = "ascii",
) -> Optional[str]:
...
def parse_string( def parse_string(
request: Request, request: Request,
name: str, name: str,
@ -179,7 +215,7 @@ def parse_string(
required: bool = False, required: bool = False,
allowed_values: Optional[Iterable[str]] = None, allowed_values: Optional[Iterable[str]] = None,
encoding: str = "ascii", encoding: str = "ascii",
): ) -> Optional[str]:
""" """
Parse a string parameter from the request query string. Parse a string parameter from the request query string.

View File

@ -90,8 +90,8 @@ class UsersRestServletV2(RestServlet):
errcode=Codes.INVALID_PARAM, errcode=Codes.INVALID_PARAM,
) )
user_id = parse_string(request, "user_id", default=None) user_id = parse_string(request, "user_id")
name = parse_string(request, "name", default=None) name = parse_string(request, "name")
guests = parse_boolean(request, "guests", default=True) guests = parse_boolean(request, "guests", default=True)
deactivated = parse_boolean(request, "deactivated", default=False) deactivated = parse_boolean(request, "deactivated", default=False)

View File

@ -413,7 +413,7 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
assert_params_in_dict(body, ["state_events_at_start", "events"]) assert_params_in_dict(body, ["state_events_at_start", "events"])
prev_events_from_query = parse_strings_from_args(request.args, "prev_event") prev_events_from_query = parse_strings_from_args(request.args, "prev_event")
chunk_id_from_query = parse_string(request, "chunk_id", default=None) chunk_id_from_query = parse_string(request, "chunk_id")
if prev_events_from_query is None: if prev_events_from_query is None:
raise SynapseError( raise SynapseError(
@ -735,7 +735,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_GET(self, request): async def on_GET(self, request):
server = parse_string(request, "server", default=None) server = parse_string(request, "server")
try: try:
await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.get_user_by_req(request, allow_guest=True)
@ -755,7 +755,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
raise e raise e
limit = parse_integer(request, "limit", 0) limit = parse_integer(request, "limit", 0)
since_token = parse_string(request, "since", None) since_token = parse_string(request, "since")
if limit == 0: if limit == 0:
# zero is a special value which corresponds to no limit. # zero is a special value which corresponds to no limit.
@ -789,7 +789,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
async def on_POST(self, request): async def on_POST(self, request):
await self.auth.get_user_by_req(request, allow_guest=True) await self.auth.get_user_by_req(request, allow_guest=True)
server = parse_string(request, "server", default=None) server = parse_string(request, "server")
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
limit: Optional[int] = int(content.get("limit", 100)) limit: Optional[int] = int(content.get("limit", 100))

View File

@ -194,7 +194,7 @@ class KeyChangesServlet(RestServlet):
async def on_GET(self, request): async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
from_token_string = parse_string(request, "from") from_token_string = parse_string(request, "from", required=True)
set_tag("from", from_token_string) set_tag("from", from_token_string)
# We want to enforce they do pass us one, but we ignore it and return # We want to enforce they do pass us one, but we ignore it and return

View File

@ -158,19 +158,21 @@ class RelationPaginationServlet(RestServlet):
event = await self.event_handler.get_event(requester.user, room_id, parent_id) event = await self.event_handler.get_event(requester.user, room_id, parent_id)
limit = parse_integer(request, "limit", default=5) limit = parse_integer(request, "limit", default=5)
from_token = parse_string(request, "from") from_token_str = parse_string(request, "from")
to_token = parse_string(request, "to") to_token_str = parse_string(request, "to")
if event.internal_metadata.is_redacted(): if event.internal_metadata.is_redacted():
# If the event is redacted, return an empty list of relations # If the event is redacted, return an empty list of relations
pagination_chunk = PaginationChunk(chunk=[]) pagination_chunk = PaginationChunk(chunk=[])
else: else:
# Return the relations # Return the relations
if from_token: from_token = None
from_token = RelationPaginationToken.from_string(from_token) if from_token_str:
from_token = RelationPaginationToken.from_string(from_token_str)
if to_token: to_token = None
to_token = RelationPaginationToken.from_string(to_token) if to_token_str:
to_token = RelationPaginationToken.from_string(to_token_str)
pagination_chunk = await self.store.get_relations_for_event( pagination_chunk = await self.store.get_relations_for_event(
event_id=parent_id, event_id=parent_id,
@ -256,19 +258,21 @@ class RelationAggregationPaginationServlet(RestServlet):
raise SynapseError(400, "Relation type must be 'annotation'") raise SynapseError(400, "Relation type must be 'annotation'")
limit = parse_integer(request, "limit", default=5) limit = parse_integer(request, "limit", default=5)
from_token = parse_string(request, "from") from_token_str = parse_string(request, "from")
to_token = parse_string(request, "to") to_token_str = parse_string(request, "to")
if event.internal_metadata.is_redacted(): if event.internal_metadata.is_redacted():
# If the event is redacted, return an empty list of relations # If the event is redacted, return an empty list of relations
pagination_chunk = PaginationChunk(chunk=[]) pagination_chunk = PaginationChunk(chunk=[])
else: else:
# Return the relations # Return the relations
if from_token: from_token = None
from_token = AggregationPaginationToken.from_string(from_token) if from_token_str:
from_token = AggregationPaginationToken.from_string(from_token_str)
if to_token: to_token = None
to_token = AggregationPaginationToken.from_string(to_token) if to_token_str:
to_token = AggregationPaginationToken.from_string(to_token_str)
pagination_chunk = await self.store.get_aggregation_groups_for_event( pagination_chunk = await self.store.get_aggregation_groups_for_event(
event_id=parent_id, event_id=parent_id,
@ -336,14 +340,16 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
raise SynapseError(400, "Relation type must be 'annotation'") raise SynapseError(400, "Relation type must be 'annotation'")
limit = parse_integer(request, "limit", default=5) limit = parse_integer(request, "limit", default=5)
from_token = parse_string(request, "from") from_token_str = parse_string(request, "from")
to_token = parse_string(request, "to") to_token_str = parse_string(request, "to")
if from_token: from_token = None
from_token = RelationPaginationToken.from_string(from_token) if from_token_str:
from_token = RelationPaginationToken.from_string(from_token_str)
if to_token: to_token = None
to_token = RelationPaginationToken.from_string(to_token) if to_token_str:
to_token = RelationPaginationToken.from_string(to_token_str)
result = await self.store.get_relations_for_event( result = await self.store.get_relations_for_event(
event_id=parent_id, event_id=parent_id,

View File

@ -112,7 +112,7 @@ class SyncRestServlet(RestServlet):
default="online", default="online",
allowed_values=self.ALLOWED_PRESENCE, allowed_values=self.ALLOWED_PRESENCE,
) )
filter_id = parse_string(request, "filter", default=None) filter_id = parse_string(request, "filter")
full_state = parse_boolean(request, "full_state", default=False) full_state = parse_boolean(request, "full_state", default=False)
logger.debug( logger.debug(

View File

@ -112,7 +112,7 @@ class ConsentResource(DirectServeHtmlResource):
request (twisted.web.http.Request): request (twisted.web.http.Request):
""" """
version = parse_string(request, "v", default=self._default_consent_version) version = parse_string(request, "v", default=self._default_consent_version)
username = parse_string(request, "u", required=False, default="") username = parse_string(request, "u", default="")
userhmac = None userhmac = None
has_consented = False has_consented = False
public_version = username == "" public_version = username == ""

View File

@ -186,15 +186,11 @@ class PreviewUrlResource(DirectServeJsonResource):
respond_with_json(request, 200, {}, send_cors=True) respond_with_json(request, 200, {}, send_cors=True)
async def _async_render_GET(self, request: SynapseRequest) -> None: async def _async_render_GET(self, request: SynapseRequest) -> None:
# This will always be set by the time Twisted calls us.
assert request.args is not None
# XXX: if get_user_by_req fails, what should we do in an async render? # XXX: if get_user_by_req fails, what should we do in an async render?
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
url = parse_string(request, "url") url = parse_string(request, "url", required=True)
if b"ts" in request.args: ts = parse_integer(request, "ts")
ts = parse_integer(request, "ts") if ts is None:
else:
ts = self.clock.time_msec() ts = self.clock.time_msec()
# XXX: we could move this into _do_preview if we wanted. # XXX: we could move this into _do_preview if we wanted.

View File

@ -249,7 +249,7 @@ class DataStore(
name: Optional[str] = None, name: Optional[str] = None,
guests: bool = True, guests: bool = True,
deactivated: bool = False, deactivated: bool = False,
order_by: UserSortOrder = UserSortOrder.USER_ID.value, order_by: str = UserSortOrder.USER_ID.value,
direction: str = "f", direction: str = "f",
) -> Tuple[List[JsonDict], int]: ) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of users from """Function to retrieve a paginated list of users from

View File

@ -363,7 +363,7 @@ class RoomWorkerStore(SQLBaseStore):
self, self,
start: int, start: int,
limit: int, limit: int,
order_by: RoomSortOrder, order_by: str,
reverse_order: bool, reverse_order: bool,
search_term: Optional[str], search_term: Optional[str],
) -> Tuple[List[Dict[str, Any]], int]: ) -> Tuple[List[Dict[str, Any]], int]:

View File

@ -647,7 +647,7 @@ class StatsStore(StateDeltasStore):
limit: int, limit: int,
from_ts: Optional[int] = None, from_ts: Optional[int] = None,
until_ts: Optional[int] = None, until_ts: Optional[int] = None,
order_by: Optional[UserSortOrder] = UserSortOrder.USER_ID.value, order_by: Optional[str] = UserSortOrder.USER_ID.value,
direction: Optional[str] = "f", direction: Optional[str] = "f",
search_term: Optional[str] = None, search_term: Optional[str] = None,
) -> Tuple[List[JsonDict], Dict[str, int]]: ) -> Tuple[List[JsonDict], Dict[str, int]]:

View File

@ -47,20 +47,22 @@ class PaginationConfig:
) -> "PaginationConfig": ) -> "PaginationConfig":
direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"]) direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
from_tok = parse_string(request, "from") from_tok_str = parse_string(request, "from")
to_tok = parse_string(request, "to") to_tok_str = parse_string(request, "to")
try: try:
if from_tok == "END": from_tok = None
if from_tok_str == "END":
from_tok = None # For backwards compat. from_tok = None # For backwards compat.
elif from_tok: elif from_tok_str:
from_tok = await StreamToken.from_string(store, from_tok) from_tok = await StreamToken.from_string(store, from_tok_str)
except Exception: except Exception:
raise SynapseError(400, "'from' parameter is invalid") raise SynapseError(400, "'from' parameter is invalid")
try: try:
if to_tok: to_tok = None
to_tok = await StreamToken.from_string(store, to_tok) if to_tok_str:
to_tok = await StreamToken.from_string(store, to_tok_str)
except Exception: except Exception:
raise SynapseError(400, "'to' parameter is invalid") raise SynapseError(400, "'to' parameter is invalid")