Add parse_strings_from_args to get prev_events array (#10048)

Split out from https://github.com/matrix-org/synapse/pull/9247

Strings:

 - `parse_string`
 - `parse_string_from_args`
 - `parse_strings_from_args`

For comparison with ints:

 - `parse_integer`
 - `parse_integer_from_args`

Previous discussions:

 - https://github.com/matrix-org/synapse/pull/9247#discussion_r573195687
 - https://github.com/matrix-org/synapse/pull/9247#discussion_r574214156
 - https://github.com/matrix-org/synapse/pull/9247#discussion_r573264791

Signed-off-by: Eric Eastwood <erice@element.io>
This commit is contained in:
Eric Eastwood 2021-05-28 08:19:06 -05:00 committed by GitHub
parent 5eed6348ce
commit ac3e02d089
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 153 additions and 42 deletions

1
changelog.d/10048.misc Normal file
View File

@ -0,0 +1 @@
Add `parse_strings_from_args` for parsing an array from query parameters.

View File

@ -15,6 +15,9 @@
""" This module contains base REST classes for constructing REST servlets. """ """ This module contains base REST classes for constructing REST servlets. """
import logging import logging
from typing import Iterable, List, Optional, Union, overload
from typing_extensions import Literal
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.util import json_decoder from synapse.util import json_decoder
@ -107,12 +110,11 @@ def parse_boolean_from_args(args, name, default=None, required=False):
def parse_string( def parse_string(
request, request,
name, name: Union[bytes, str],
default=None, default: Optional[str] = None,
required=False, required: bool = False,
allowed_values=None, allowed_values: Optional[Iterable[str]] = None,
param_type="string", encoding: Optional[str] = "ascii",
encoding="ascii",
): ):
""" """
Parse a string parameter from the request query string. Parse a string parameter from the request query string.
@ -122,18 +124,17 @@ def parse_string(
Args: Args:
request: the twisted HTTP request. request: the twisted HTTP request.
name (bytes|unicode): the name of the query parameter. name: the name of the query parameter.
default (bytes|unicode|None): value to use if the parameter is absent, default: value to use if the parameter is absent,
defaults to None. Must be bytes if encoding is None. defaults to None. Must be bytes if encoding is None.
required (bool): whether to raise a 400 SynapseError if the required: whether to raise a 400 SynapseError if the
parameter is absent, defaults to False. parameter is absent, defaults to False.
allowed_values (list[bytes|unicode]): List of allowed values for the allowed_values: List of allowed values for the
string, or None if any value is allowed, defaults to None. Must be string, or None if any value is allowed, defaults to None. Must be
the same type as name, if given. the same type as name, if given.
encoding (str|None): The encoding to decode the string content with. encoding : The encoding to decode the string content with.
Returns: Returns:
bytes/unicode|None: A string value or the default. Unicode if encoding A string value or the default. Unicode if encoding
was given, bytes otherwise. was given, bytes otherwise.
Raises: Raises:
@ -142,33 +143,21 @@ def parse_string(
is not one of those allowed values. is not one of those allowed values.
""" """
return parse_string_from_args( return parse_string_from_args(
request.args, name, default, required, allowed_values, param_type, encoding request.args, name, default, required, allowed_values, encoding
) )
def parse_string_from_args( def _parse_string_value(
args, value: Union[str, bytes],
name, allowed_values: Optional[Iterable[str]],
default=None, name: str,
required=False, encoding: Optional[str],
allowed_values=None, ) -> Union[str, bytes]:
param_type="string",
encoding="ascii",
):
if not isinstance(name, bytes):
name = name.encode("ascii")
if name in args:
value = args[name][0]
if encoding: if encoding:
try: try:
value = value.decode(encoding) value = value.decode(encoding)
except ValueError: except ValueError:
raise SynapseError( raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding))
400, "Query parameter %r must be %s" % (name, encoding)
)
if allowed_values is not None and value not in allowed_values: if allowed_values is not None and value not in allowed_values:
message = "Query parameter %r must be one of [%s]" % ( message = "Query parameter %r must be one of [%s]" % (
@ -178,9 +167,81 @@ def parse_string_from_args(
raise SynapseError(400, message) raise SynapseError(400, message)
else: else:
return value return value
@overload
def parse_strings_from_args(
args: List[str],
name: Union[bytes, str],
default: Optional[List[str]] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
encoding: Literal[None] = None,
) -> Optional[List[bytes]]:
...
@overload
def parse_strings_from_args(
args: List[str],
name: Union[bytes, str],
default: Optional[List[str]] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
encoding: str = "ascii",
) -> Optional[List[str]]:
...
def parse_strings_from_args(
args: List[str],
name: Union[bytes, str],
default: Optional[List[str]] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
encoding: Optional[str] = "ascii",
) -> Optional[List[Union[bytes, str]]]:
"""
Parse a string parameter from the request query string list.
If encoding is not None, the content of the query param will be
decoded to Unicode using the encoding, otherwise it will be encoded
Args:
args: the twisted HTTP request.args list.
name: the name of the query parameter.
default: value to use if the parameter is absent,
defaults to None. Must be bytes if encoding is None.
required : whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
allowed_values (list[bytes|unicode]): List of allowed values for the
string, or None if any value is allowed, defaults to None. Must be
the same type as name, if given.
encoding: The encoding to decode the string content with.
Returns:
A string value or the default. Unicode if encoding
was given, bytes otherwise.
Raises:
SynapseError if the parameter is absent and required, or if the
parameter is present, must be one of a list of allowed values and
is not one of those allowed values.
"""
if not isinstance(name, bytes):
name = name.encode("ascii")
if name in args:
values = args[name]
return [
_parse_string_value(value, allowed_values, name=name, encoding=encoding)
for value in values
]
else: else:
if required: if required:
message = "Missing %s query parameter %r" % (param_type, name) message = "Missing string query parameter %r" % (name)
raise SynapseError(400, message, errcode=Codes.MISSING_PARAM) raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
else: else:
@ -190,6 +251,55 @@ def parse_string_from_args(
return default return default
def parse_string_from_args(
args: List[str],
name: Union[bytes, str],
default: Optional[str] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
encoding: Optional[str] = "ascii",
) -> Optional[Union[bytes, str]]:
"""
Parse the string parameter from the request query string list
and return the first result.
If encoding is not None, the content of the query param will be
decoded to Unicode using the encoding, otherwise it will be encoded
Args:
args: the twisted HTTP request.args list.
name: the name of the query parameter.
default: value to use if the parameter is absent,
defaults to None. Must be bytes if encoding is None.
required: whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
allowed_values: List of allowed values for the
string, or None if any value is allowed, defaults to None. Must be
the same type as name, if given.
encoding: The encoding to decode the string content with.
Returns:
A string value or the default. Unicode if encoding
was given, bytes otherwise.
Raises:
SynapseError if the parameter is absent and required, or if the
parameter is present, must be one of a list of allowed values and
is not one of those allowed values.
"""
strings = parse_strings_from_args(
args,
name,
default=[default],
required=required,
allowed_values=allowed_values,
encoding=encoding,
)
return strings[0]
def parse_json_value_from_request(request, allow_empty_body=False): def parse_json_value_from_request(request, allow_empty_body=False):
"""Parse a JSON value from the body of a twisted HTTP request. """Parse a JSON value from the body of a twisted HTTP request.
@ -215,7 +325,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
try: try:
content = json_decoder.decode(content_bytes.decode("utf-8")) content = json_decoder.decode(content_bytes.decode("utf-8"))
except Exception as e: except Exception as e:
logger.warning("Unable to parse JSON: %s", e) logger.warning("Unable to parse JSON: %s (%s)", e, content_bytes)
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
return content return content