Add helper to parse an enum from query args & use it. (#14956)

The `parse_enum` helper pulls an enum value from the query string
(by delegating down to the parse_string helper with values generated
from the enum).

This is used to pull out "f" and "b" in most places and then we thread
the resulting Direction enum throughout more code.
This commit is contained in:
Patrick Cloke 2023-02-01 16:35:24 -05:00 committed by GitHub
parent 230a831c73
commit 1182ae5063
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 176 additions and 96 deletions

1
changelog.d/14956.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints.

View File

@ -37,7 +37,7 @@ from typing import (
import attr
from prometheus_client import Counter
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership
from synapse.api.errors import (
CodeMessageException,
Codes,
@ -1680,7 +1680,12 @@ class FederationClient(FederationBase):
return result
async def timestamp_to_event(
self, *, destinations: List[str], room_id: str, timestamp: int, direction: str
self,
*,
destinations: List[str],
room_id: str,
timestamp: int,
direction: Direction,
) -> Optional["TimestampToEventResponse"]:
"""
Calls each remote federating server from `destinations` asking for their closest
@ -1693,7 +1698,7 @@ class FederationClient(FederationBase):
room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event.
direction: ["f"|"b"] to indicate whether we should navigate forward
direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event.
Returns:
@ -1738,7 +1743,7 @@ class FederationClient(FederationBase):
return None
async def _timestamp_to_event_from_destination(
self, destination: str, room_id: str, timestamp: int, direction: str
self, destination: str, room_id: str, timestamp: int, direction: Direction
) -> "TimestampToEventResponse":
"""
Calls a remote federating server at `destination` asking for their
@ -1751,7 +1756,7 @@ class FederationClient(FederationBase):
room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event.
direction: ["f"|"b"] to indicate whether we should navigate forward
direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event.
Returns:

View File

@ -34,7 +34,13 @@ from prometheus_client import Counter, Gauge, Histogram
from twisted.internet.abstract import isIPAddress
from twisted.python import failure
from synapse.api.constants import EduTypes, EventContentFields, EventTypes, Membership
from synapse.api.constants import (
Direction,
EduTypes,
EventContentFields,
EventTypes,
Membership,
)
from synapse.api.errors import (
AuthError,
Codes,
@ -218,7 +224,7 @@ class FederationServer(FederationBase):
return 200, res
async def on_timestamp_to_event_request(
self, origin: str, room_id: str, timestamp: int, direction: str
self, origin: str, room_id: str, timestamp: int, direction: Direction
) -> Tuple[int, Dict[str, Any]]:
"""When we receive a federated `/timestamp_to_event` request,
handle all of the logic for validating and fetching the event.
@ -228,7 +234,7 @@ class FederationServer(FederationBase):
room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event.
direction: ["f"|"b"] to indicate whether we should navigate forward
direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event.
Returns:

View File

@ -32,7 +32,7 @@ from typing import (
import attr
import ijson
from synapse.api.constants import Membership
from synapse.api.constants import Direction, Membership
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.api.urls import (
@ -169,7 +169,7 @@ class TransportLayerClient:
)
async def timestamp_to_event(
self, destination: str, room_id: str, timestamp: int, direction: str
self, destination: str, room_id: str, timestamp: int, direction: Direction
) -> Union[JsonDict, List]:
"""
Calls a remote federating server at `destination` asking for their
@ -180,7 +180,7 @@ class TransportLayerClient:
room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event.
direction: ["f"|"b"] to indicate whether we should navigate forward
direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event.
Returns:
@ -194,7 +194,7 @@ class TransportLayerClient:
room_id,
)
args = {"ts": [str(timestamp)], "dir": [direction]}
args = {"ts": [str(timestamp)], "dir": [direction.value]}
remote_response = await self.client.get_json(
destination, path=path, args=args, try_trailing_slash_on_400=True

View File

@ -26,7 +26,7 @@ from typing import (
from typing_extensions import Literal
from synapse.api.constants import EduTypes
from synapse.api.constants import Direction, EduTypes
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.api.urls import FEDERATION_UNSTABLE_PREFIX, FEDERATION_V2_PREFIX
@ -234,9 +234,10 @@ class FederationTimestampLookupServlet(BaseFederationServerServlet):
room_id: str,
) -> Tuple[int, JsonDict]:
timestamp = parse_integer_from_args(query, "ts", required=True)
direction = parse_string_from_args(
query, "dir", default="f", allowed_values=["f", "b"], required=True
direction_str = parse_string_from_args(
query, "dir", allowed_values=["f", "b"], required=True
)
direction = Direction(direction_str)
return await self.handler.on_timestamp_to_event_request(
origin, room_id, timestamp, direction

View File

@ -314,7 +314,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
def get_current_key(self, direction: str = "f") -> int:
def get_current_key(self) -> int:
return self.store.get_max_account_data_stream_id()
async def get_new_events(

View File

@ -315,5 +315,5 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
return events, to_key
def get_current_key(self, direction: str = "f") -> int:
def get_current_key(self) -> int:
return self.store.get_max_receipt_stream_id()

View File

@ -27,6 +27,7 @@ from typing_extensions import TypedDict
import synapse.events.snapshot
from synapse.api.constants import (
Direction,
EventContentFields,
EventTypes,
GuestAccess,
@ -1487,7 +1488,7 @@ class TimestampLookupHandler:
requester: Requester,
room_id: str,
timestamp: int,
direction: str,
direction: Direction,
) -> Tuple[str, int]:
"""Find the closest event to the given timestamp in the given direction.
If we can't find an event locally or the event we have locally is next to a gap,
@ -1498,7 +1499,7 @@ class TimestampLookupHandler:
room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event.
direction: ["f"|"b"] to indicate whether we should navigate forward
direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event.
Returns:
@ -1533,13 +1534,13 @@ class TimestampLookupHandler:
local_event_id, allow_none=False, allow_rejected=False
)
if direction == "f":
if direction == Direction.FORWARDS:
# We only need to check for a backward gap if we're looking forwards
# to ensure there is nothing in between.
is_event_next_to_backward_gap = (
await self.store.is_event_next_to_backward_gap(local_event)
)
elif direction == "b":
elif direction == Direction.BACKWARDS:
# We only need to check for a forward gap if we're looking backwards
# to ensure there is nothing in between
is_event_next_to_forward_gap = (

View File

@ -13,6 +13,7 @@
# limitations under the License.
""" This module contains base REST classes for constructing REST servlets. """
import enum
import logging
from http import HTTPStatus
from typing import (
@ -362,6 +363,7 @@ def parse_string(
request: Request,
name: str,
*,
default: Optional[str] = None,
required: bool = False,
allowed_values: Optional[Iterable[str]] = None,
encoding: str = "ascii",
@ -413,6 +415,74 @@ def parse_string(
)
EnumT = TypeVar("EnumT", bound=enum.Enum)
@overload
def parse_enum(
request: Request,
name: str,
E: Type[EnumT],
default: EnumT,
) -> EnumT:
...
@overload
def parse_enum(
request: Request,
name: str,
E: Type[EnumT],
*,
required: Literal[True],
) -> EnumT:
...
def parse_enum(
request: Request,
name: str,
E: Type[EnumT],
default: Optional[EnumT] = None,
required: bool = False,
) -> Optional[EnumT]:
"""
Parse an enum parameter from the request query string.
Note that the enum *must only have string values*.
Args:
request: the twisted HTTP request.
name: the name of the query parameter.
E: the enum which represents valid values
default: enum value to use if the parameter is absent, defaults to None.
required: whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
Returns:
An enum value.
Raises:
SynapseError if the parameter is absent and required, or if the
parameter is present, must be one of a list of allowed values and
is not one of those allowed values.
"""
# Assert the enum values are strings.
assert all(
isinstance(e.value, str) for e in E
), "parse_enum only works with string values"
str_value = parse_string(
request,
name,
default=default.value if default is not None else None,
required=required,
allowed_values=[e.value for e in E],
)
if str_value is None:
return None
return E(str_value)
def _parse_string_value(
value: bytes,
allowed_values: Optional[Iterable[str]],

View File

@ -16,8 +16,9 @@ import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import Direction
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.types import JsonDict
@ -60,7 +61,7 @@ class EventReportsRestServlet(RestServlet):
start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100)
direction = parse_string(request, "dir", default="b")
direction = parse_enum(request, "dir", Direction, Direction.BACKWARDS)
user_id = parse_string(request, "user_id")
room_id = parse_string(request, "room_id")
@ -78,13 +79,6 @@ class EventReportsRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM,
)
if direction not in ("f", "b"):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Unknown direction: %s" % (direction,),
errcode=Codes.INVALID_PARAM,
)
event_reports, total = await self.store.get_event_reports_paginate(
start, limit, direction, user_id, room_id
)

View File

@ -15,9 +15,10 @@ import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import Direction
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.federation.transport.server import Authenticator
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.storage.databases.main.transactions import DestinationSortOrder
@ -79,7 +80,7 @@ class ListDestinationsRestServlet(RestServlet):
allowed_values=[dest.value for dest in DestinationSortOrder],
)
direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
destinations, total = await self._store.get_destinations_paginate(
start, limit, destination, order_by, direction
@ -192,7 +193,7 @@ class DestinationMembershipRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM,
)
direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
rooms, total = await self._store.get_destination_rooms_paginate(
destination, start, limit, direction

View File

@ -17,9 +17,16 @@ import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import Direction
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.servlet import (
RestServlet,
parse_boolean,
parse_enum,
parse_integer,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import (
admin_patterns,
@ -389,7 +396,7 @@ class UserMediaRestServlet(RestServlet):
# to newest media is on top for backward compatibility.
if b"order_by" not in request.args and b"dir" not in request.args:
order_by = MediaSortOrder.CREATED_TS.value
direction = "b"
direction = Direction.BACKWARDS
else:
order_by = parse_string(
request,
@ -397,8 +404,8 @@ class UserMediaRestServlet(RestServlet):
default=MediaSortOrder.CREATED_TS.value,
allowed_values=[sort_order.value for sort_order in MediaSortOrder],
)
direction = parse_string(
request, "dir", default="f", allowed_values=("f", "b")
direction = parse_enum(
request, "dir", Direction, default=Direction.FORWARDS
)
media, total = await self.store.get_local_media_by_user_paginate(
@ -447,7 +454,7 @@ class UserMediaRestServlet(RestServlet):
# to newest media is on top for backward compatibility.
if b"order_by" not in request.args and b"dir" not in request.args:
order_by = MediaSortOrder.CREATED_TS.value
direction = "b"
direction = Direction.BACKWARDS
else:
order_by = parse_string(
request,
@ -455,8 +462,8 @@ class UserMediaRestServlet(RestServlet):
default=MediaSortOrder.CREATED_TS.value,
allowed_values=[sort_order.value for sort_order in MediaSortOrder],
)
direction = parse_string(
request, "dir", default="f", allowed_values=("f", "b")
direction = parse_enum(
request, "dir", Direction, default=Direction.FORWARDS
)
media, _ = await self.store.get_local_media_by_user_paginate(

View File

@ -16,13 +16,14 @@ from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
from urllib import parse as urlparse
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.constants import Direction, EventTypes, JoinRules, Membership
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.api.filtering import Filter
from synapse.http.servlet import (
ResolveRoomIdMixin,
RestServlet,
assert_params_in_dict,
parse_enum,
parse_integer,
parse_json_object_from_request,
parse_string,
@ -224,15 +225,8 @@ class ListRoomRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM,
)
direction = parse_string(request, "dir", default="f")
if direction not in ("f", "b"):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Unknown direction: %s" % (direction,),
errcode=Codes.INVALID_PARAM,
)
reverse_order = True if direction == "b" else False
direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
reverse_order = True if direction == Direction.BACKWARDS else False
# Return list of rooms according to parameters
rooms, total_rooms = await self.store.get_rooms_paginate(
@ -949,7 +943,7 @@ class RoomTimestampToEventRestServlet(RestServlet):
await assert_user_is_admin(self._auth, requester)
timestamp = parse_integer(request, "ts", required=True)
direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
(
event_id,

View File

@ -16,8 +16,9 @@ import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import Direction
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.storage.databases.main.stats import UserSortOrder
@ -102,13 +103,7 @@ class UserMediaStatisticsRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM,
)
direction = parse_string(request, "dir", default="f")
if direction not in ("f", "b"):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Unknown direction: %s" % (direction,),
errcode=Codes.INVALID_PARAM,
)
direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
users_media, total = await self.store.get_users_media_usage_paginate(
start, limit, from_ts, until_ts, order_by, direction, search_term

View File

@ -18,12 +18,13 @@ import secrets
from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from synapse.api.constants import UserTypes
from synapse.api.constants import Direction, UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_boolean,
parse_enum,
parse_integer,
parse_json_object_from_request,
parse_string,
@ -120,7 +121,7 @@ class UsersRestServletV2(RestServlet):
),
)
direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
users, total = await self.store.get_users_paginate(
start,

View File

@ -16,6 +16,7 @@ import logging
import re
from typing import TYPE_CHECKING, Optional, Tuple
from synapse.api.constants import Direction
from synapse.handlers.relations import ThreadsListInclude
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_integer, parse_string
@ -59,7 +60,7 @@ class RelationPaginationServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = await PaginationConfig.from_request(
self._store, request, default_limit=5, default_dir="b"
self._store, request, default_limit=5, default_dir=Direction.BACKWARDS
)
# The unstable version of this API returns an extra field for client

View File

@ -26,7 +26,7 @@ from prometheus_client.core import Histogram
from twisted.web.server import Request
from synapse import event_auth
from synapse.api.constants import EventTypes, Membership
from synapse.api.constants import Direction, EventTypes, Membership
from synapse.api.errors import (
AuthError,
Codes,
@ -44,6 +44,7 @@ from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_boolean,
parse_enum,
parse_integer,
parse_json_object_from_request,
parse_string,
@ -1297,7 +1298,7 @@ class TimestampLookupRestServlet(RestServlet):
await self._auth.check_user_in_room_or_world_readable(room_id, requester)
timestamp = parse_integer(request, "ts", required=True)
direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
(
event_id,

View File

@ -17,6 +17,7 @@
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
from synapse.api.constants import Direction
from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import (
DatabasePool,
@ -167,7 +168,7 @@ class DataStore(
guests: bool = True,
deactivated: bool = False,
order_by: str = UserSortOrder.NAME.value,
direction: str = "f",
direction: Direction = Direction.FORWARDS,
approved: bool = True,
) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of users from
@ -197,7 +198,7 @@ class DataStore(
# Set ordering
order_by_column = UserSortOrder(order_by).value
if direction == "b":
if direction == Direction.BACKWARDS:
order = "DESC"
else:
order = "ASC"

View File

@ -38,7 +38,7 @@ from typing_extensions import Literal
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.constants import Direction, EventTypes
from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
@ -2240,7 +2240,7 @@ class EventsWorkerStore(SQLBaseStore):
)
async def get_event_id_for_timestamp(
self, room_id: str, timestamp: int, direction: str
self, room_id: str, timestamp: int, direction: Direction
) -> Optional[str]:
"""Find the closest event to the given timestamp in the given direction.
@ -2248,14 +2248,14 @@ class EventsWorkerStore(SQLBaseStore):
room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event.
direction: ["f"|"b"] to indicate whether we should navigate forward
direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event.
Returns:
The closest event_id otherwise None if we can't find any event in
the given direction.
"""
if direction == "b":
if direction == Direction.BACKWARDS:
# Find closest event *before* a given timestamp. We use descending
# (which gives values largest to smallest) because we want the
# largest possible timestamp *before* the given timestamp.
@ -2307,9 +2307,6 @@ class EventsWorkerStore(SQLBaseStore):
return None
if direction not in ("f", "b"):
raise ValueError("Unknown direction: %s" % (direction,))
return await self.db_pool.runInteraction(
"get_event_id_for_timestamp_txn",
get_event_id_for_timestamp_txn,

View File

@ -26,6 +26,7 @@ from typing import (
cast,
)
from synapse.api.constants import Direction
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@ -176,7 +177,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
limit: int,
user_id: str,
order_by: str = MediaSortOrder.CREATED_TS.value,
direction: str = "f",
direction: Direction = Direction.FORWARDS,
) -> Tuple[List[Dict[str, Any]], int]:
"""Get a paginated list of metadata for a local piece of media
which an user_id has uploaded
@ -199,7 +200,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
# Set ordering
order_by_column = MediaSortOrder(order_by).value
if direction == "b":
if direction == Direction.BACKWARDS:
order = "DESC"
else:
order = "ASC"

View File

@ -35,6 +35,7 @@ from typing import (
import attr
from synapse.api.constants import (
Direction,
EventContentFields,
EventTypes,
JoinRules,
@ -2204,7 +2205,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
self,
start: int,
limit: int,
direction: str = "b",
direction: Direction = Direction.BACKWARDS,
user_id: Optional[str] = None,
room_id: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
@ -2213,8 +2214,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
Args:
start: event offset to begin the query from
limit: number of rows to retrieve
direction: Whether to fetch the most recent first (`"b"`) or the
oldest first (`"f"`)
direction: Whether to fetch the most recent first (backwards) or the
oldest first (forwards)
user_id: search for user_id. Ignored if user_id is None
room_id: search for room_id. Ignored if room_id is None
Returns:
@ -2236,7 +2237,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
filters.append("er.room_id LIKE ?")
args.extend(["%" + room_id + "%"])
if direction == "b":
if direction == Direction.BACKWARDS:
order = "DESC"
else:
order = "ASC"

View File

@ -22,7 +22,7 @@ from typing_extensions import Counter
from twisted.internet.defer import DeferredLock
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership
from synapse.api.errors import StoreError
from synapse.storage.database import (
DatabasePool,
@ -663,7 +663,7 @@ class StatsStore(StateDeltasStore):
from_ts: Optional[int] = None,
until_ts: Optional[int] = None,
order_by: Optional[str] = UserSortOrder.USER_ID.value,
direction: Optional[str] = "f",
direction: Direction = Direction.FORWARDS,
search_term: Optional[str] = None,
) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of users and their uploaded local media
@ -714,7 +714,7 @@ class StatsStore(StateDeltasStore):
500, "Incorrect value for order_by provided: %s" % order_by
)
if direction == "b":
if direction == Direction.BACKWARDS:
order = "DESC"
else:
order = "ASC"

View File

@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, cast
import attr
from canonicaljson import encode_canonical_json
from synapse.api.constants import Direction
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import db_to_json
from synapse.storage.database import (
@ -496,7 +497,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
limit: int,
destination: Optional[str] = None,
order_by: str = DestinationSortOrder.DESTINATION.value,
direction: str = "f",
direction: Direction = Direction.FORWARDS,
) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of destinations.
This will return a json list of destinations and the
@ -518,7 +519,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
) -> Tuple[List[JsonDict], int]:
order_by_column = DestinationSortOrder(order_by).value
if direction == "b":
if direction == Direction.BACKWARDS:
order = "DESC"
else:
order = "ASC"
@ -550,7 +551,11 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
)
async def get_destination_rooms_paginate(
self, destination: str, start: int, limit: int, direction: str = "f"
self,
destination: str,
start: int,
limit: int,
direction: Direction = Direction.FORWARDS,
) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of destination's rooms.
This will return a json list of rooms and the
@ -569,7 +574,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]:
if direction == "b":
if direction == Direction.BACKWARDS:
order = "DESC"
else:
order = "ASC"

View File

@ -18,7 +18,7 @@ import attr
from synapse.api.constants import Direction
from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_integer, parse_string
from synapse.http.servlet import parse_enum, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.storage.databases.main import DataStore
from synapse.types import StreamToken
@ -44,15 +44,9 @@ class PaginationConfig:
store: "DataStore",
request: SynapseRequest,
default_limit: int,
default_dir: str = "f",
default_dir: Direction = Direction.FORWARDS,
) -> "PaginationConfig":
direction_str = parse_string(
request,
"dir",
default=default_dir,
allowed_values=[Direction.FORWARDS.value, Direction.BACKWARDS.value],
)
direction = Direction(direction_str)
direction = parse_enum(request, "dir", Direction, default=default_dir)
from_tok_str = parse_string(request, "from")
to_tok_str = parse_string(request, "to")

View File

@ -280,7 +280,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual("Unknown direction: bar", channel.json_body["error"])
self.assertEqual(
"Query parameter 'dir' must be one of ['b', 'f']",
channel.json_body["error"],
)
def test_limit_is_negative(self) -> None:
"""