mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Add helper to parse an enum from query args & use it. (#14956)
The `parse_enum` helper pulls an enum value from the query string (by delegating down to the parse_string helper with values generated from the enum). This is used to pull out "f" and "b" in most places and then we thread the resulting Direction enum throughout more code.
This commit is contained in:
parent
230a831c73
commit
1182ae5063
1
changelog.d/14956.misc
Normal file
1
changelog.d/14956.misc
Normal file
@ -0,0 +1 @@
|
||||
Add missing type hints.
|
@ -37,7 +37,7 @@ from typing import (
|
||||
import attr
|
||||
from prometheus_client import Counter
|
||||
|
||||
from synapse.api.constants import EventContentFields, EventTypes, Membership
|
||||
from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership
|
||||
from synapse.api.errors import (
|
||||
CodeMessageException,
|
||||
Codes,
|
||||
@ -1680,7 +1680,12 @@ class FederationClient(FederationBase):
|
||||
return result
|
||||
|
||||
async def timestamp_to_event(
|
||||
self, *, destinations: List[str], room_id: str, timestamp: int, direction: str
|
||||
self,
|
||||
*,
|
||||
destinations: List[str],
|
||||
room_id: str,
|
||||
timestamp: int,
|
||||
direction: Direction,
|
||||
) -> Optional["TimestampToEventResponse"]:
|
||||
"""
|
||||
Calls each remote federating server from `destinations` asking for their closest
|
||||
@ -1693,7 +1698,7 @@ class FederationClient(FederationBase):
|
||||
room_id: Room to fetch the event from
|
||||
timestamp: The point in time (inclusive) we should navigate from in
|
||||
the given direction to find the closest event.
|
||||
direction: ["f"|"b"] to indicate whether we should navigate forward
|
||||
direction: indicates whether we should navigate forward
|
||||
or backward from the given timestamp to find the closest event.
|
||||
|
||||
Returns:
|
||||
@ -1738,7 +1743,7 @@ class FederationClient(FederationBase):
|
||||
return None
|
||||
|
||||
async def _timestamp_to_event_from_destination(
|
||||
self, destination: str, room_id: str, timestamp: int, direction: str
|
||||
self, destination: str, room_id: str, timestamp: int, direction: Direction
|
||||
) -> "TimestampToEventResponse":
|
||||
"""
|
||||
Calls a remote federating server at `destination` asking for their
|
||||
@ -1751,7 +1756,7 @@ class FederationClient(FederationBase):
|
||||
room_id: Room to fetch the event from
|
||||
timestamp: The point in time (inclusive) we should navigate from in
|
||||
the given direction to find the closest event.
|
||||
direction: ["f"|"b"] to indicate whether we should navigate forward
|
||||
direction: indicates whether we should navigate forward
|
||||
or backward from the given timestamp to find the closest event.
|
||||
|
||||
Returns:
|
||||
|
@ -34,7 +34,13 @@ from prometheus_client import Counter, Gauge, Histogram
|
||||
from twisted.internet.abstract import isIPAddress
|
||||
from twisted.python import failure
|
||||
|
||||
from synapse.api.constants import EduTypes, EventContentFields, EventTypes, Membership
|
||||
from synapse.api.constants import (
|
||||
Direction,
|
||||
EduTypes,
|
||||
EventContentFields,
|
||||
EventTypes,
|
||||
Membership,
|
||||
)
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
Codes,
|
||||
@ -218,7 +224,7 @@ class FederationServer(FederationBase):
|
||||
return 200, res
|
||||
|
||||
async def on_timestamp_to_event_request(
|
||||
self, origin: str, room_id: str, timestamp: int, direction: str
|
||||
self, origin: str, room_id: str, timestamp: int, direction: Direction
|
||||
) -> Tuple[int, Dict[str, Any]]:
|
||||
"""When we receive a federated `/timestamp_to_event` request,
|
||||
handle all of the logic for validating and fetching the event.
|
||||
@ -228,7 +234,7 @@ class FederationServer(FederationBase):
|
||||
room_id: Room to fetch the event from
|
||||
timestamp: The point in time (inclusive) we should navigate from in
|
||||
the given direction to find the closest event.
|
||||
direction: ["f"|"b"] to indicate whether we should navigate forward
|
||||
direction: indicates whether we should navigate forward
|
||||
or backward from the given timestamp to find the closest event.
|
||||
|
||||
Returns:
|
||||
|
@ -32,7 +32,7 @@ from typing import (
|
||||
import attr
|
||||
import ijson
|
||||
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.constants import Direction, Membership
|
||||
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
||||
from synapse.api.room_versions import RoomVersion
|
||||
from synapse.api.urls import (
|
||||
@ -169,7 +169,7 @@ class TransportLayerClient:
|
||||
)
|
||||
|
||||
async def timestamp_to_event(
|
||||
self, destination: str, room_id: str, timestamp: int, direction: str
|
||||
self, destination: str, room_id: str, timestamp: int, direction: Direction
|
||||
) -> Union[JsonDict, List]:
|
||||
"""
|
||||
Calls a remote federating server at `destination` asking for their
|
||||
@ -180,7 +180,7 @@ class TransportLayerClient:
|
||||
room_id: Room to fetch the event from
|
||||
timestamp: The point in time (inclusive) we should navigate from in
|
||||
the given direction to find the closest event.
|
||||
direction: ["f"|"b"] to indicate whether we should navigate forward
|
||||
direction: indicates whether we should navigate forward
|
||||
or backward from the given timestamp to find the closest event.
|
||||
|
||||
Returns:
|
||||
@ -194,7 +194,7 @@ class TransportLayerClient:
|
||||
room_id,
|
||||
)
|
||||
|
||||
args = {"ts": [str(timestamp)], "dir": [direction]}
|
||||
args = {"ts": [str(timestamp)], "dir": [direction.value]}
|
||||
|
||||
remote_response = await self.client.get_json(
|
||||
destination, path=path, args=args, try_trailing_slash_on_400=True
|
||||
|
@ -26,7 +26,7 @@ from typing import (
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
from synapse.api.constants import EduTypes
|
||||
from synapse.api.constants import Direction, EduTypes
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.api.urls import FEDERATION_UNSTABLE_PREFIX, FEDERATION_V2_PREFIX
|
||||
@ -234,9 +234,10 @@ class FederationTimestampLookupServlet(BaseFederationServerServlet):
|
||||
room_id: str,
|
||||
) -> Tuple[int, JsonDict]:
|
||||
timestamp = parse_integer_from_args(query, "ts", required=True)
|
||||
direction = parse_string_from_args(
|
||||
query, "dir", default="f", allowed_values=["f", "b"], required=True
|
||||
direction_str = parse_string_from_args(
|
||||
query, "dir", allowed_values=["f", "b"], required=True
|
||||
)
|
||||
direction = Direction(direction_str)
|
||||
|
||||
return await self.handler.on_timestamp_to_event_request(
|
||||
origin, room_id, timestamp, direction
|
||||
|
@ -314,7 +314,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastores().main
|
||||
|
||||
def get_current_key(self, direction: str = "f") -> int:
|
||||
def get_current_key(self) -> int:
|
||||
return self.store.get_max_account_data_stream_id()
|
||||
|
||||
async def get_new_events(
|
||||
|
@ -315,5 +315,5 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
|
||||
|
||||
return events, to_key
|
||||
|
||||
def get_current_key(self, direction: str = "f") -> int:
|
||||
def get_current_key(self) -> int:
|
||||
return self.store.get_max_receipt_stream_id()
|
||||
|
@ -27,6 +27,7 @@ from typing_extensions import TypedDict
|
||||
|
||||
import synapse.events.snapshot
|
||||
from synapse.api.constants import (
|
||||
Direction,
|
||||
EventContentFields,
|
||||
EventTypes,
|
||||
GuestAccess,
|
||||
@ -1487,7 +1488,7 @@ class TimestampLookupHandler:
|
||||
requester: Requester,
|
||||
room_id: str,
|
||||
timestamp: int,
|
||||
direction: str,
|
||||
direction: Direction,
|
||||
) -> Tuple[str, int]:
|
||||
"""Find the closest event to the given timestamp in the given direction.
|
||||
If we can't find an event locally or the event we have locally is next to a gap,
|
||||
@ -1498,7 +1499,7 @@ class TimestampLookupHandler:
|
||||
room_id: Room to fetch the event from
|
||||
timestamp: The point in time (inclusive) we should navigate from in
|
||||
the given direction to find the closest event.
|
||||
direction: ["f"|"b"] to indicate whether we should navigate forward
|
||||
direction: indicates whether we should navigate forward
|
||||
or backward from the given timestamp to find the closest event.
|
||||
|
||||
Returns:
|
||||
@ -1533,13 +1534,13 @@ class TimestampLookupHandler:
|
||||
local_event_id, allow_none=False, allow_rejected=False
|
||||
)
|
||||
|
||||
if direction == "f":
|
||||
if direction == Direction.FORWARDS:
|
||||
# We only need to check for a backward gap if we're looking forwards
|
||||
# to ensure there is nothing in between.
|
||||
is_event_next_to_backward_gap = (
|
||||
await self.store.is_event_next_to_backward_gap(local_event)
|
||||
)
|
||||
elif direction == "b":
|
||||
elif direction == Direction.BACKWARDS:
|
||||
# We only need to check for a forward gap if we're looking backwards
|
||||
# to ensure there is nothing in between
|
||||
is_event_next_to_forward_gap = (
|
||||
|
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
""" This module contains base REST classes for constructing REST servlets. """
|
||||
import enum
|
||||
import logging
|
||||
from http import HTTPStatus
|
||||
from typing import (
|
||||
@ -362,6 +363,7 @@ def parse_string(
|
||||
request: Request,
|
||||
name: str,
|
||||
*,
|
||||
default: Optional[str] = None,
|
||||
required: bool = False,
|
||||
allowed_values: Optional[Iterable[str]] = None,
|
||||
encoding: str = "ascii",
|
||||
@ -413,6 +415,74 @@ def parse_string(
|
||||
)
|
||||
|
||||
|
||||
EnumT = TypeVar("EnumT", bound=enum.Enum)
|
||||
|
||||
|
||||
@overload
|
||||
def parse_enum(
|
||||
request: Request,
|
||||
name: str,
|
||||
E: Type[EnumT],
|
||||
default: EnumT,
|
||||
) -> EnumT:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def parse_enum(
|
||||
request: Request,
|
||||
name: str,
|
||||
E: Type[EnumT],
|
||||
*,
|
||||
required: Literal[True],
|
||||
) -> EnumT:
|
||||
...
|
||||
|
||||
|
||||
def parse_enum(
|
||||
request: Request,
|
||||
name: str,
|
||||
E: Type[EnumT],
|
||||
default: Optional[EnumT] = None,
|
||||
required: bool = False,
|
||||
) -> Optional[EnumT]:
|
||||
"""
|
||||
Parse an enum parameter from the request query string.
|
||||
|
||||
Note that the enum *must only have string values*.
|
||||
|
||||
Args:
|
||||
request: the twisted HTTP request.
|
||||
name: the name of the query parameter.
|
||||
E: the enum which represents valid values
|
||||
default: enum value to use if the parameter is absent, defaults to None.
|
||||
required: whether to raise a 400 SynapseError if the
|
||||
parameter is absent, defaults to False.
|
||||
|
||||
Returns:
|
||||
An enum value.
|
||||
|
||||
Raises:
|
||||
SynapseError if the parameter is absent and required, or if the
|
||||
parameter is present, must be one of a list of allowed values and
|
||||
is not one of those allowed values.
|
||||
"""
|
||||
# Assert the enum values are strings.
|
||||
assert all(
|
||||
isinstance(e.value, str) for e in E
|
||||
), "parse_enum only works with string values"
|
||||
str_value = parse_string(
|
||||
request,
|
||||
name,
|
||||
default=default.value if default is not None else None,
|
||||
required=required,
|
||||
allowed_values=[e.value for e in E],
|
||||
)
|
||||
if str_value is None:
|
||||
return None
|
||||
return E(str_value)
|
||||
|
||||
|
||||
def _parse_string_value(
|
||||
value: bytes,
|
||||
allowed_values: Optional[Iterable[str]],
|
||||
|
@ -16,8 +16,9 @@ import logging
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.constants import Direction
|
||||
from synapse.api.errors import Codes, NotFoundError, SynapseError
|
||||
from synapse.http.servlet import RestServlet, parse_integer, parse_string
|
||||
from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
|
||||
from synapse.types import JsonDict
|
||||
@ -60,7 +61,7 @@ class EventReportsRestServlet(RestServlet):
|
||||
|
||||
start = parse_integer(request, "from", default=0)
|
||||
limit = parse_integer(request, "limit", default=100)
|
||||
direction = parse_string(request, "dir", default="b")
|
||||
direction = parse_enum(request, "dir", Direction, Direction.BACKWARDS)
|
||||
user_id = parse_string(request, "user_id")
|
||||
room_id = parse_string(request, "room_id")
|
||||
|
||||
@ -78,13 +79,6 @@ class EventReportsRestServlet(RestServlet):
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
if direction not in ("f", "b"):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"Unknown direction: %s" % (direction,),
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
event_reports, total = await self.store.get_event_reports_paginate(
|
||||
start, limit, direction, user_id, room_id
|
||||
)
|
||||
|
@ -15,9 +15,10 @@ import logging
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.constants import Direction
|
||||
from synapse.api.errors import Codes, NotFoundError, SynapseError
|
||||
from synapse.federation.transport.server import Authenticator
|
||||
from synapse.http.servlet import RestServlet, parse_integer, parse_string
|
||||
from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
|
||||
from synapse.storage.databases.main.transactions import DestinationSortOrder
|
||||
@ -79,7 +80,7 @@ class ListDestinationsRestServlet(RestServlet):
|
||||
allowed_values=[dest.value for dest in DestinationSortOrder],
|
||||
)
|
||||
|
||||
direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
|
||||
direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
|
||||
|
||||
destinations, total = await self._store.get_destinations_paginate(
|
||||
start, limit, destination, order_by, direction
|
||||
@ -192,7 +193,7 @@ class DestinationMembershipRestServlet(RestServlet):
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
|
||||
direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
|
||||
|
||||
rooms, total = await self._store.get_destination_rooms_paginate(
|
||||
destination, start, limit, direction
|
||||
|
@ -17,9 +17,16 @@ import logging
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.constants import Direction
|
||||
from synapse.api.errors import Codes, NotFoundError, SynapseError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
parse_boolean,
|
||||
parse_enum,
|
||||
parse_integer,
|
||||
parse_string,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.admin._base import (
|
||||
admin_patterns,
|
||||
@ -389,7 +396,7 @@ class UserMediaRestServlet(RestServlet):
|
||||
# to newest media is on top for backward compatibility.
|
||||
if b"order_by" not in request.args and b"dir" not in request.args:
|
||||
order_by = MediaSortOrder.CREATED_TS.value
|
||||
direction = "b"
|
||||
direction = Direction.BACKWARDS
|
||||
else:
|
||||
order_by = parse_string(
|
||||
request,
|
||||
@ -397,8 +404,8 @@ class UserMediaRestServlet(RestServlet):
|
||||
default=MediaSortOrder.CREATED_TS.value,
|
||||
allowed_values=[sort_order.value for sort_order in MediaSortOrder],
|
||||
)
|
||||
direction = parse_string(
|
||||
request, "dir", default="f", allowed_values=("f", "b")
|
||||
direction = parse_enum(
|
||||
request, "dir", Direction, default=Direction.FORWARDS
|
||||
)
|
||||
|
||||
media, total = await self.store.get_local_media_by_user_paginate(
|
||||
@ -447,7 +454,7 @@ class UserMediaRestServlet(RestServlet):
|
||||
# to newest media is on top for backward compatibility.
|
||||
if b"order_by" not in request.args and b"dir" not in request.args:
|
||||
order_by = MediaSortOrder.CREATED_TS.value
|
||||
direction = "b"
|
||||
direction = Direction.BACKWARDS
|
||||
else:
|
||||
order_by = parse_string(
|
||||
request,
|
||||
@ -455,8 +462,8 @@ class UserMediaRestServlet(RestServlet):
|
||||
default=MediaSortOrder.CREATED_TS.value,
|
||||
allowed_values=[sort_order.value for sort_order in MediaSortOrder],
|
||||
)
|
||||
direction = parse_string(
|
||||
request, "dir", default="f", allowed_values=("f", "b")
|
||||
direction = parse_enum(
|
||||
request, "dir", Direction, default=Direction.FORWARDS
|
||||
)
|
||||
|
||||
media, _ = await self.store.get_local_media_by_user_paginate(
|
||||
|
@ -16,13 +16,14 @@ from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
|
||||
from urllib import parse as urlparse
|
||||
|
||||
from synapse.api.constants import EventTypes, JoinRules, Membership
|
||||
from synapse.api.constants import Direction, EventTypes, JoinRules, Membership
|
||||
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
|
||||
from synapse.api.filtering import Filter
|
||||
from synapse.http.servlet import (
|
||||
ResolveRoomIdMixin,
|
||||
RestServlet,
|
||||
assert_params_in_dict,
|
||||
parse_enum,
|
||||
parse_integer,
|
||||
parse_json_object_from_request,
|
||||
parse_string,
|
||||
@ -224,15 +225,8 @@ class ListRoomRestServlet(RestServlet):
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
direction = parse_string(request, "dir", default="f")
|
||||
if direction not in ("f", "b"):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"Unknown direction: %s" % (direction,),
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
reverse_order = True if direction == "b" else False
|
||||
direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
|
||||
reverse_order = True if direction == Direction.BACKWARDS else False
|
||||
|
||||
# Return list of rooms according to parameters
|
||||
rooms, total_rooms = await self.store.get_rooms_paginate(
|
||||
@ -949,7 +943,7 @@ class RoomTimestampToEventRestServlet(RestServlet):
|
||||
await assert_user_is_admin(self._auth, requester)
|
||||
|
||||
timestamp = parse_integer(request, "ts", required=True)
|
||||
direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
|
||||
direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
|
||||
|
||||
(
|
||||
event_id,
|
||||
|
@ -16,8 +16,9 @@ import logging
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.constants import Direction
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.servlet import RestServlet, parse_integer, parse_string
|
||||
from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
|
||||
from synapse.storage.databases.main.stats import UserSortOrder
|
||||
@ -102,13 +103,7 @@ class UserMediaStatisticsRestServlet(RestServlet):
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
direction = parse_string(request, "dir", default="f")
|
||||
if direction not in ("f", "b"):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"Unknown direction: %s" % (direction,),
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
|
||||
|
||||
users_media, total = await self.store.get_users_media_usage_paginate(
|
||||
start, limit, from_ts, until_ts, order_by, direction, search_term
|
||||
|
@ -18,12 +18,13 @@ import secrets
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
from synapse.api.constants import UserTypes
|
||||
from synapse.api.constants import Direction, UserTypes
|
||||
from synapse.api.errors import Codes, NotFoundError, SynapseError
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
assert_params_in_dict,
|
||||
parse_boolean,
|
||||
parse_enum,
|
||||
parse_integer,
|
||||
parse_json_object_from_request,
|
||||
parse_string,
|
||||
@ -120,7 +121,7 @@ class UsersRestServletV2(RestServlet):
|
||||
),
|
||||
)
|
||||
|
||||
direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
|
||||
direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
|
||||
|
||||
users, total = await self.store.get_users_paginate(
|
||||
start,
|
||||
|
@ -16,6 +16,7 @@ import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
from synapse.api.constants import Direction
|
||||
from synapse.handlers.relations import ThreadsListInclude
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_integer, parse_string
|
||||
@ -59,7 +60,7 @@ class RelationPaginationServlet(RestServlet):
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
pagination_config = await PaginationConfig.from_request(
|
||||
self._store, request, default_limit=5, default_dir="b"
|
||||
self._store, request, default_limit=5, default_dir=Direction.BACKWARDS
|
||||
)
|
||||
|
||||
# The unstable version of this API returns an extra field for client
|
||||
|
@ -26,7 +26,7 @@ from prometheus_client.core import Histogram
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse import event_auth
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.constants import Direction, EventTypes, Membership
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
Codes,
|
||||
@ -44,6 +44,7 @@ from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
assert_params_in_dict,
|
||||
parse_boolean,
|
||||
parse_enum,
|
||||
parse_integer,
|
||||
parse_json_object_from_request,
|
||||
parse_string,
|
||||
@ -1297,7 +1298,7 @@ class TimestampLookupRestServlet(RestServlet):
|
||||
await self._auth.check_user_in_room_or_world_readable(room_id, requester)
|
||||
|
||||
timestamp = parse_integer(request, "ts", required=True)
|
||||
direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
|
||||
direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
|
||||
|
||||
(
|
||||
event_id,
|
||||
|
@ -17,6 +17,7 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
|
||||
|
||||
from synapse.api.constants import Direction
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
@ -167,7 +168,7 @@ class DataStore(
|
||||
guests: bool = True,
|
||||
deactivated: bool = False,
|
||||
order_by: str = UserSortOrder.NAME.value,
|
||||
direction: str = "f",
|
||||
direction: Direction = Direction.FORWARDS,
|
||||
approved: bool = True,
|
||||
) -> Tuple[List[JsonDict], int]:
|
||||
"""Function to retrieve a paginated list of users from
|
||||
@ -197,7 +198,7 @@ class DataStore(
|
||||
# Set ordering
|
||||
order_by_column = UserSortOrder(order_by).value
|
||||
|
||||
if direction == "b":
|
||||
if direction == Direction.BACKWARDS:
|
||||
order = "DESC"
|
||||
else:
|
||||
order = "ASC"
|
||||
|
@ -38,7 +38,7 @@ from typing_extensions import Literal
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.constants import Direction, EventTypes
|
||||
from synapse.api.errors import NotFoundError, SynapseError
|
||||
from synapse.api.room_versions import (
|
||||
KNOWN_ROOM_VERSIONS,
|
||||
@ -2240,7 +2240,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
)
|
||||
|
||||
async def get_event_id_for_timestamp(
|
||||
self, room_id: str, timestamp: int, direction: str
|
||||
self, room_id: str, timestamp: int, direction: Direction
|
||||
) -> Optional[str]:
|
||||
"""Find the closest event to the given timestamp in the given direction.
|
||||
|
||||
@ -2248,14 +2248,14 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
room_id: Room to fetch the event from
|
||||
timestamp: The point in time (inclusive) we should navigate from in
|
||||
the given direction to find the closest event.
|
||||
direction: ["f"|"b"] to indicate whether we should navigate forward
|
||||
direction: indicates whether we should navigate forward
|
||||
or backward from the given timestamp to find the closest event.
|
||||
|
||||
Returns:
|
||||
The closest event_id otherwise None if we can't find any event in
|
||||
the given direction.
|
||||
"""
|
||||
if direction == "b":
|
||||
if direction == Direction.BACKWARDS:
|
||||
# Find closest event *before* a given timestamp. We use descending
|
||||
# (which gives values largest to smallest) because we want the
|
||||
# largest possible timestamp *before* the given timestamp.
|
||||
@ -2307,9 +2307,6 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
return None
|
||||
|
||||
if direction not in ("f", "b"):
|
||||
raise ValueError("Unknown direction: %s" % (direction,))
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_event_id_for_timestamp_txn",
|
||||
get_event_id_for_timestamp_txn,
|
||||
|
@ -26,6 +26,7 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from synapse.api.constants import Direction
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
@ -176,7 +177,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||
limit: int,
|
||||
user_id: str,
|
||||
order_by: str = MediaSortOrder.CREATED_TS.value,
|
||||
direction: str = "f",
|
||||
direction: Direction = Direction.FORWARDS,
|
||||
) -> Tuple[List[Dict[str, Any]], int]:
|
||||
"""Get a paginated list of metadata for a local piece of media
|
||||
which an user_id has uploaded
|
||||
@ -199,7 +200,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||
# Set ordering
|
||||
order_by_column = MediaSortOrder(order_by).value
|
||||
|
||||
if direction == "b":
|
||||
if direction == Direction.BACKWARDS:
|
||||
order = "DESC"
|
||||
else:
|
||||
order = "ASC"
|
||||
|
@ -35,6 +35,7 @@ from typing import (
|
||||
import attr
|
||||
|
||||
from synapse.api.constants import (
|
||||
Direction,
|
||||
EventContentFields,
|
||||
EventTypes,
|
||||
JoinRules,
|
||||
@ -2204,7 +2205,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
|
||||
self,
|
||||
start: int,
|
||||
limit: int,
|
||||
direction: str = "b",
|
||||
direction: Direction = Direction.BACKWARDS,
|
||||
user_id: Optional[str] = None,
|
||||
room_id: Optional[str] = None,
|
||||
) -> Tuple[List[Dict[str, Any]], int]:
|
||||
@ -2213,8 +2214,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
|
||||
Args:
|
||||
start: event offset to begin the query from
|
||||
limit: number of rows to retrieve
|
||||
direction: Whether to fetch the most recent first (`"b"`) or the
|
||||
oldest first (`"f"`)
|
||||
direction: Whether to fetch the most recent first (backwards) or the
|
||||
oldest first (forwards)
|
||||
user_id: search for user_id. Ignored if user_id is None
|
||||
room_id: search for room_id. Ignored if room_id is None
|
||||
Returns:
|
||||
@ -2236,7 +2237,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
|
||||
filters.append("er.room_id LIKE ?")
|
||||
args.extend(["%" + room_id + "%"])
|
||||
|
||||
if direction == "b":
|
||||
if direction == Direction.BACKWARDS:
|
||||
order = "DESC"
|
||||
else:
|
||||
order = "ASC"
|
||||
|
@ -22,7 +22,7 @@ from typing_extensions import Counter
|
||||
|
||||
from twisted.internet.defer import DeferredLock
|
||||
|
||||
from synapse.api.constants import EventContentFields, EventTypes, Membership
|
||||
from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
@ -663,7 +663,7 @@ class StatsStore(StateDeltasStore):
|
||||
from_ts: Optional[int] = None,
|
||||
until_ts: Optional[int] = None,
|
||||
order_by: Optional[str] = UserSortOrder.USER_ID.value,
|
||||
direction: Optional[str] = "f",
|
||||
direction: Direction = Direction.FORWARDS,
|
||||
search_term: Optional[str] = None,
|
||||
) -> Tuple[List[JsonDict], int]:
|
||||
"""Function to retrieve a paginated list of users and their uploaded local media
|
||||
@ -714,7 +714,7 @@ class StatsStore(StateDeltasStore):
|
||||
500, "Incorrect value for order_by provided: %s" % order_by
|
||||
)
|
||||
|
||||
if direction == "b":
|
||||
if direction == Direction.BACKWARDS:
|
||||
order = "DESC"
|
||||
else:
|
||||
order = "ASC"
|
||||
|
@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, cast
|
||||
import attr
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from synapse.api.constants import Direction
|
||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||
from synapse.storage._base import db_to_json
|
||||
from synapse.storage.database import (
|
||||
@ -496,7 +497,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
||||
limit: int,
|
||||
destination: Optional[str] = None,
|
||||
order_by: str = DestinationSortOrder.DESTINATION.value,
|
||||
direction: str = "f",
|
||||
direction: Direction = Direction.FORWARDS,
|
||||
) -> Tuple[List[JsonDict], int]:
|
||||
"""Function to retrieve a paginated list of destinations.
|
||||
This will return a json list of destinations and the
|
||||
@ -518,7 +519,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
||||
) -> Tuple[List[JsonDict], int]:
|
||||
order_by_column = DestinationSortOrder(order_by).value
|
||||
|
||||
if direction == "b":
|
||||
if direction == Direction.BACKWARDS:
|
||||
order = "DESC"
|
||||
else:
|
||||
order = "ASC"
|
||||
@ -550,7 +551,11 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
||||
)
|
||||
|
||||
async def get_destination_rooms_paginate(
|
||||
self, destination: str, start: int, limit: int, direction: str = "f"
|
||||
self,
|
||||
destination: str,
|
||||
start: int,
|
||||
limit: int,
|
||||
direction: Direction = Direction.FORWARDS,
|
||||
) -> Tuple[List[JsonDict], int]:
|
||||
"""Function to retrieve a paginated list of destination's rooms.
|
||||
This will return a json list of rooms and the
|
||||
@ -569,7 +574,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
||||
txn: LoggingTransaction,
|
||||
) -> Tuple[List[JsonDict], int]:
|
||||
|
||||
if direction == "b":
|
||||
if direction == Direction.BACKWARDS:
|
||||
order = "DESC"
|
||||
else:
|
||||
order = "ASC"
|
||||
|
@ -18,7 +18,7 @@ import attr
|
||||
|
||||
from synapse.api.constants import Direction
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.http.servlet import parse_integer, parse_string
|
||||
from synapse.http.servlet import parse_enum, parse_integer, parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.storage.databases.main import DataStore
|
||||
from synapse.types import StreamToken
|
||||
@ -44,15 +44,9 @@ class PaginationConfig:
|
||||
store: "DataStore",
|
||||
request: SynapseRequest,
|
||||
default_limit: int,
|
||||
default_dir: str = "f",
|
||||
default_dir: Direction = Direction.FORWARDS,
|
||||
) -> "PaginationConfig":
|
||||
direction_str = parse_string(
|
||||
request,
|
||||
"dir",
|
||||
default=default_dir,
|
||||
allowed_values=[Direction.FORWARDS.value, Direction.BACKWARDS.value],
|
||||
)
|
||||
direction = Direction(direction_str)
|
||||
direction = parse_enum(request, "dir", Direction, default=default_dir)
|
||||
|
||||
from_tok_str = parse_string(request, "from")
|
||||
to_tok_str = parse_string(request, "to")
|
||||
|
@ -280,7 +280,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.assertEqual(400, channel.code, msg=channel.json_body)
|
||||
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
|
||||
self.assertEqual("Unknown direction: bar", channel.json_body["error"])
|
||||
self.assertEqual(
|
||||
"Query parameter 'dir' must be one of ['b', 'f']",
|
||||
channel.json_body["error"],
|
||||
)
|
||||
|
||||
def test_limit_is_negative(self) -> None:
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user