mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-01-27 01:46:59 -05:00
Add a return type to parse_string. (#10438)
And set the required attribute in a few places which will error if a parameter is not provided.
This commit is contained in:
parent
2d89c66b88
commit
5db118626b
1
changelog.d/10438.misc
Normal file
1
changelog.d/10438.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Improve servlet type hints.
|
@ -172,6 +172,42 @@ def parse_bytes_from_args(
|
|||||||
return default
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def parse_string(
|
||||||
|
request: Request,
|
||||||
|
name: str,
|
||||||
|
default: str,
|
||||||
|
*,
|
||||||
|
allowed_values: Optional[Iterable[str]] = None,
|
||||||
|
encoding: str = "ascii",
|
||||||
|
) -> str:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def parse_string(
|
||||||
|
request: Request,
|
||||||
|
name: str,
|
||||||
|
*,
|
||||||
|
required: Literal[True],
|
||||||
|
allowed_values: Optional[Iterable[str]] = None,
|
||||||
|
encoding: str = "ascii",
|
||||||
|
) -> str:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def parse_string(
|
||||||
|
request: Request,
|
||||||
|
name: str,
|
||||||
|
*,
|
||||||
|
required: bool = False,
|
||||||
|
allowed_values: Optional[Iterable[str]] = None,
|
||||||
|
encoding: str = "ascii",
|
||||||
|
) -> Optional[str]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
def parse_string(
|
def parse_string(
|
||||||
request: Request,
|
request: Request,
|
||||||
name: str,
|
name: str,
|
||||||
@ -179,7 +215,7 @@ def parse_string(
|
|||||||
required: bool = False,
|
required: bool = False,
|
||||||
allowed_values: Optional[Iterable[str]] = None,
|
allowed_values: Optional[Iterable[str]] = None,
|
||||||
encoding: str = "ascii",
|
encoding: str = "ascii",
|
||||||
):
|
) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Parse a string parameter from the request query string.
|
Parse a string parameter from the request query string.
|
||||||
|
|
||||||
|
@ -90,8 +90,8 @@ class UsersRestServletV2(RestServlet):
|
|||||||
errcode=Codes.INVALID_PARAM,
|
errcode=Codes.INVALID_PARAM,
|
||||||
)
|
)
|
||||||
|
|
||||||
user_id = parse_string(request, "user_id", default=None)
|
user_id = parse_string(request, "user_id")
|
||||||
name = parse_string(request, "name", default=None)
|
name = parse_string(request, "name")
|
||||||
guests = parse_boolean(request, "guests", default=True)
|
guests = parse_boolean(request, "guests", default=True)
|
||||||
deactivated = parse_boolean(request, "deactivated", default=False)
|
deactivated = parse_boolean(request, "deactivated", default=False)
|
||||||
|
|
||||||
|
@ -413,7 +413,7 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
|
|||||||
assert_params_in_dict(body, ["state_events_at_start", "events"])
|
assert_params_in_dict(body, ["state_events_at_start", "events"])
|
||||||
|
|
||||||
prev_events_from_query = parse_strings_from_args(request.args, "prev_event")
|
prev_events_from_query = parse_strings_from_args(request.args, "prev_event")
|
||||||
chunk_id_from_query = parse_string(request, "chunk_id", default=None)
|
chunk_id_from_query = parse_string(request, "chunk_id")
|
||||||
|
|
||||||
if prev_events_from_query is None:
|
if prev_events_from_query is None:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
@ -735,7 +735,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
|
|||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
async def on_GET(self, request):
|
async def on_GET(self, request):
|
||||||
server = parse_string(request, "server", default=None)
|
server = parse_string(request, "server")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.auth.get_user_by_req(request, allow_guest=True)
|
await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
@ -755,7 +755,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
limit = parse_integer(request, "limit", 0)
|
limit = parse_integer(request, "limit", 0)
|
||||||
since_token = parse_string(request, "since", None)
|
since_token = parse_string(request, "since")
|
||||||
|
|
||||||
if limit == 0:
|
if limit == 0:
|
||||||
# zero is a special value which corresponds to no limit.
|
# zero is a special value which corresponds to no limit.
|
||||||
@ -789,7 +789,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
|
|||||||
async def on_POST(self, request):
|
async def on_POST(self, request):
|
||||||
await self.auth.get_user_by_req(request, allow_guest=True)
|
await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
|
||||||
server = parse_string(request, "server", default=None)
|
server = parse_string(request, "server")
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
limit: Optional[int] = int(content.get("limit", 100))
|
limit: Optional[int] = int(content.get("limit", 100))
|
||||||
|
@ -194,7 +194,7 @@ class KeyChangesServlet(RestServlet):
|
|||||||
async def on_GET(self, request):
|
async def on_GET(self, request):
|
||||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
|
||||||
from_token_string = parse_string(request, "from")
|
from_token_string = parse_string(request, "from", required=True)
|
||||||
set_tag("from", from_token_string)
|
set_tag("from", from_token_string)
|
||||||
|
|
||||||
# We want to enforce they do pass us one, but we ignore it and return
|
# We want to enforce they do pass us one, but we ignore it and return
|
||||||
|
@ -158,19 +158,21 @@ class RelationPaginationServlet(RestServlet):
|
|||||||
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
|
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
|
||||||
|
|
||||||
limit = parse_integer(request, "limit", default=5)
|
limit = parse_integer(request, "limit", default=5)
|
||||||
from_token = parse_string(request, "from")
|
from_token_str = parse_string(request, "from")
|
||||||
to_token = parse_string(request, "to")
|
to_token_str = parse_string(request, "to")
|
||||||
|
|
||||||
if event.internal_metadata.is_redacted():
|
if event.internal_metadata.is_redacted():
|
||||||
# If the event is redacted, return an empty list of relations
|
# If the event is redacted, return an empty list of relations
|
||||||
pagination_chunk = PaginationChunk(chunk=[])
|
pagination_chunk = PaginationChunk(chunk=[])
|
||||||
else:
|
else:
|
||||||
# Return the relations
|
# Return the relations
|
||||||
if from_token:
|
from_token = None
|
||||||
from_token = RelationPaginationToken.from_string(from_token)
|
if from_token_str:
|
||||||
|
from_token = RelationPaginationToken.from_string(from_token_str)
|
||||||
|
|
||||||
if to_token:
|
to_token = None
|
||||||
to_token = RelationPaginationToken.from_string(to_token)
|
if to_token_str:
|
||||||
|
to_token = RelationPaginationToken.from_string(to_token_str)
|
||||||
|
|
||||||
pagination_chunk = await self.store.get_relations_for_event(
|
pagination_chunk = await self.store.get_relations_for_event(
|
||||||
event_id=parent_id,
|
event_id=parent_id,
|
||||||
@ -256,19 +258,21 @@ class RelationAggregationPaginationServlet(RestServlet):
|
|||||||
raise SynapseError(400, "Relation type must be 'annotation'")
|
raise SynapseError(400, "Relation type must be 'annotation'")
|
||||||
|
|
||||||
limit = parse_integer(request, "limit", default=5)
|
limit = parse_integer(request, "limit", default=5)
|
||||||
from_token = parse_string(request, "from")
|
from_token_str = parse_string(request, "from")
|
||||||
to_token = parse_string(request, "to")
|
to_token_str = parse_string(request, "to")
|
||||||
|
|
||||||
if event.internal_metadata.is_redacted():
|
if event.internal_metadata.is_redacted():
|
||||||
# If the event is redacted, return an empty list of relations
|
# If the event is redacted, return an empty list of relations
|
||||||
pagination_chunk = PaginationChunk(chunk=[])
|
pagination_chunk = PaginationChunk(chunk=[])
|
||||||
else:
|
else:
|
||||||
# Return the relations
|
# Return the relations
|
||||||
if from_token:
|
from_token = None
|
||||||
from_token = AggregationPaginationToken.from_string(from_token)
|
if from_token_str:
|
||||||
|
from_token = AggregationPaginationToken.from_string(from_token_str)
|
||||||
|
|
||||||
if to_token:
|
to_token = None
|
||||||
to_token = AggregationPaginationToken.from_string(to_token)
|
if to_token_str:
|
||||||
|
to_token = AggregationPaginationToken.from_string(to_token_str)
|
||||||
|
|
||||||
pagination_chunk = await self.store.get_aggregation_groups_for_event(
|
pagination_chunk = await self.store.get_aggregation_groups_for_event(
|
||||||
event_id=parent_id,
|
event_id=parent_id,
|
||||||
@ -336,14 +340,16 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
|
|||||||
raise SynapseError(400, "Relation type must be 'annotation'")
|
raise SynapseError(400, "Relation type must be 'annotation'")
|
||||||
|
|
||||||
limit = parse_integer(request, "limit", default=5)
|
limit = parse_integer(request, "limit", default=5)
|
||||||
from_token = parse_string(request, "from")
|
from_token_str = parse_string(request, "from")
|
||||||
to_token = parse_string(request, "to")
|
to_token_str = parse_string(request, "to")
|
||||||
|
|
||||||
if from_token:
|
from_token = None
|
||||||
from_token = RelationPaginationToken.from_string(from_token)
|
if from_token_str:
|
||||||
|
from_token = RelationPaginationToken.from_string(from_token_str)
|
||||||
|
|
||||||
if to_token:
|
to_token = None
|
||||||
to_token = RelationPaginationToken.from_string(to_token)
|
if to_token_str:
|
||||||
|
to_token = RelationPaginationToken.from_string(to_token_str)
|
||||||
|
|
||||||
result = await self.store.get_relations_for_event(
|
result = await self.store.get_relations_for_event(
|
||||||
event_id=parent_id,
|
event_id=parent_id,
|
||||||
|
@ -112,7 +112,7 @@ class SyncRestServlet(RestServlet):
|
|||||||
default="online",
|
default="online",
|
||||||
allowed_values=self.ALLOWED_PRESENCE,
|
allowed_values=self.ALLOWED_PRESENCE,
|
||||||
)
|
)
|
||||||
filter_id = parse_string(request, "filter", default=None)
|
filter_id = parse_string(request, "filter")
|
||||||
full_state = parse_boolean(request, "full_state", default=False)
|
full_state = parse_boolean(request, "full_state", default=False)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
@ -112,7 +112,7 @@ class ConsentResource(DirectServeHtmlResource):
|
|||||||
request (twisted.web.http.Request):
|
request (twisted.web.http.Request):
|
||||||
"""
|
"""
|
||||||
version = parse_string(request, "v", default=self._default_consent_version)
|
version = parse_string(request, "v", default=self._default_consent_version)
|
||||||
username = parse_string(request, "u", required=False, default="")
|
username = parse_string(request, "u", default="")
|
||||||
userhmac = None
|
userhmac = None
|
||||||
has_consented = False
|
has_consented = False
|
||||||
public_version = username == ""
|
public_version = username == ""
|
||||||
|
@ -186,15 +186,11 @@ class PreviewUrlResource(DirectServeJsonResource):
|
|||||||
respond_with_json(request, 200, {}, send_cors=True)
|
respond_with_json(request, 200, {}, send_cors=True)
|
||||||
|
|
||||||
async def _async_render_GET(self, request: SynapseRequest) -> None:
|
async def _async_render_GET(self, request: SynapseRequest) -> None:
|
||||||
# This will always be set by the time Twisted calls us.
|
|
||||||
assert request.args is not None
|
|
||||||
|
|
||||||
# XXX: if get_user_by_req fails, what should we do in an async render?
|
# XXX: if get_user_by_req fails, what should we do in an async render?
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
url = parse_string(request, "url")
|
url = parse_string(request, "url", required=True)
|
||||||
if b"ts" in request.args:
|
ts = parse_integer(request, "ts")
|
||||||
ts = parse_integer(request, "ts")
|
if ts is None:
|
||||||
else:
|
|
||||||
ts = self.clock.time_msec()
|
ts = self.clock.time_msec()
|
||||||
|
|
||||||
# XXX: we could move this into _do_preview if we wanted.
|
# XXX: we could move this into _do_preview if we wanted.
|
||||||
|
@ -249,7 +249,7 @@ class DataStore(
|
|||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
guests: bool = True,
|
guests: bool = True,
|
||||||
deactivated: bool = False,
|
deactivated: bool = False,
|
||||||
order_by: UserSortOrder = UserSortOrder.USER_ID.value,
|
order_by: str = UserSortOrder.USER_ID.value,
|
||||||
direction: str = "f",
|
direction: str = "f",
|
||||||
) -> Tuple[List[JsonDict], int]:
|
) -> Tuple[List[JsonDict], int]:
|
||||||
"""Function to retrieve a paginated list of users from
|
"""Function to retrieve a paginated list of users from
|
||||||
|
@ -363,7 +363,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||||||
self,
|
self,
|
||||||
start: int,
|
start: int,
|
||||||
limit: int,
|
limit: int,
|
||||||
order_by: RoomSortOrder,
|
order_by: str,
|
||||||
reverse_order: bool,
|
reverse_order: bool,
|
||||||
search_term: Optional[str],
|
search_term: Optional[str],
|
||||||
) -> Tuple[List[Dict[str, Any]], int]:
|
) -> Tuple[List[Dict[str, Any]], int]:
|
||||||
|
@ -647,7 +647,7 @@ class StatsStore(StateDeltasStore):
|
|||||||
limit: int,
|
limit: int,
|
||||||
from_ts: Optional[int] = None,
|
from_ts: Optional[int] = None,
|
||||||
until_ts: Optional[int] = None,
|
until_ts: Optional[int] = None,
|
||||||
order_by: Optional[UserSortOrder] = UserSortOrder.USER_ID.value,
|
order_by: Optional[str] = UserSortOrder.USER_ID.value,
|
||||||
direction: Optional[str] = "f",
|
direction: Optional[str] = "f",
|
||||||
search_term: Optional[str] = None,
|
search_term: Optional[str] = None,
|
||||||
) -> Tuple[List[JsonDict], Dict[str, int]]:
|
) -> Tuple[List[JsonDict], Dict[str, int]]:
|
||||||
|
@ -47,20 +47,22 @@ class PaginationConfig:
|
|||||||
) -> "PaginationConfig":
|
) -> "PaginationConfig":
|
||||||
direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
|
direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
|
||||||
|
|
||||||
from_tok = parse_string(request, "from")
|
from_tok_str = parse_string(request, "from")
|
||||||
to_tok = parse_string(request, "to")
|
to_tok_str = parse_string(request, "to")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if from_tok == "END":
|
from_tok = None
|
||||||
|
if from_tok_str == "END":
|
||||||
from_tok = None # For backwards compat.
|
from_tok = None # For backwards compat.
|
||||||
elif from_tok:
|
elif from_tok_str:
|
||||||
from_tok = await StreamToken.from_string(store, from_tok)
|
from_tok = await StreamToken.from_string(store, from_tok_str)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise SynapseError(400, "'from' parameter is invalid")
|
raise SynapseError(400, "'from' parameter is invalid")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if to_tok:
|
to_tok = None
|
||||||
to_tok = await StreamToken.from_string(store, to_tok)
|
if to_tok_str:
|
||||||
|
to_tok = await StreamToken.from_string(store, to_tok_str)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise SynapseError(400, "'to' parameter is invalid")
|
raise SynapseError(400, "'to' parameter is invalid")
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user