Use an enum for direction. (#14927)

For better type safety we  use an enum instead of strings to
configure direction (backwards or forwards).
This commit is contained in:
Patrick Cloke 2023-01-27 07:27:55 -05:00 committed by GitHub
parent fc35e0673f
commit 265735db9d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 76 additions and 44 deletions

1
changelog.d/14927.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints.

View File

@ -17,6 +17,8 @@
"""Contains constants from the specification."""
import enum
from typing_extensions import Final
# the max size of a (canonical-json-encoded) event
@ -290,3 +292,8 @@ class ApprovalNoticeMedium:
NONE = "org.matrix.msc3866.none"
EMAIL = "org.matrix.msc3866.email"
class Direction(enum.Enum):
BACKWARDS = "b"
FORWARDS = "f"

View File

@ -16,7 +16,7 @@ import abc
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
from synapse.api.constants import Membership
from synapse.api.constants import Direction, Membership
from synapse.events import EventBase
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
from synapse.visibility import filter_events_for_client
@ -197,7 +197,7 @@ class AdminHandler:
# efficient method perhaps but it does guarantee we get everything.
while True:
events, _ = await self.store.paginate_room_events(
room_id, from_key, to_key, limit=100, direction="f"
room_id, from_key, to_key, limit=100, direction=Direction.FORWARDS
)
if not events:
break

View File

@ -15,7 +15,13 @@
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
from synapse.api.constants import AccountDataTypes, EduTypes, EventTypes, Membership
from synapse.api.constants import (
AccountDataTypes,
Direction,
EduTypes,
EventTypes,
Membership,
)
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.events.utils import SerializeEventConfig
@ -57,7 +63,13 @@ class InitialSyncHandler:
self.validator = EventValidator()
self.snapshot_cache: ResponseCache[
Tuple[
str, Optional[StreamToken], Optional[StreamToken], str, int, bool, bool
str,
Optional[StreamToken],
Optional[StreamToken],
Direction,
int,
bool,
bool,
]
] = ResponseCache(hs.get_clock(), "initial_sync_cache")
self._event_serializer = hs.get_event_client_serializer()

View File

@ -19,7 +19,7 @@ import attr
from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership
from synapse.api.constants import Direction, EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
from synapse.events.utils import SerializeEventConfig
@ -448,7 +448,7 @@ class PaginationHandler:
if pagin_config.from_token:
from_token = pagin_config.from_token
elif pagin_config.direction == "f":
elif pagin_config.direction == Direction.FORWARDS:
from_token = (
await self.hs.get_event_sources().get_start_token_for_pagination(
room_id
@ -476,7 +476,7 @@ class PaginationHandler:
room_id, requester, allow_departed_users=True
)
if pagin_config.direction == "b":
if pagin_config.direction == Direction.BACKWARDS:
# if we're going backwards, we might need to backfill. This
# requires that we have a topo token.
if room_token.topological:

View File

@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, O
import attr
from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.constants import Direction, EventTypes, RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase, relation_from_event
from synapse.logging.context import make_deferred_yieldable, run_in_background
@ -413,7 +413,11 @@ class RelationsHandler:
# Attempt to find another event to use as the latest event.
potential_events, _ = await self._main_store.get_relations_for_event(
event_id, event, room_id, RelationTypes.THREAD, direction="f"
event_id,
event,
room_id,
RelationTypes.THREAD,
direction=Direction.FORWARDS,
)
# Filter out ignored users.

View File

@ -30,7 +30,7 @@ from typing import (
import attr
from synapse.api.constants import MAIN_TIMELINE, RelationTypes
from synapse.api.constants import MAIN_TIMELINE, Direction, RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
@ -168,7 +168,7 @@ class RelationsWorkerStore(SQLBaseStore):
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
limit: int = 5,
direction: str = "b",
direction: Direction = Direction.BACKWARDS,
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
@ -181,8 +181,8 @@ class RelationsWorkerStore(SQLBaseStore):
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
limit: Only fetch the most recent `limit` events.
direction: Whether to fetch the most recent first (`"b"`) or the
oldest first (`"f"`).
direction: Whether to fetch the most recent first (backwards) or the
oldest first (forwards).
from_token: Fetch rows from the given token, or from the start if None.
to_token: Fetch rows up to the given token, or up to the end if None.

View File

@ -55,6 +55,7 @@ from typing_extensions import Literal
from twisted.internet import defer
from synapse.api.constants import Direction
from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
@ -86,7 +87,6 @@ MAX_STREAM_SIZE = 1000
_STREAM_TOKEN = "stream"
_TOPOLOGICAL_TOKEN = "topological"
# Used as return values for pagination APIs
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _EventDictReturn:
@ -104,7 +104,7 @@ class _EventsAround:
def generate_pagination_where_clause(
direction: str,
direction: Direction,
column_names: Tuple[str, str],
from_token: Optional[Tuple[Optional[int], int]],
to_token: Optional[Tuple[Optional[int], int]],
@ -130,27 +130,26 @@ def generate_pagination_where_clause(
token, but include those that match the to token.
Args:
direction: Whether we're paginating backwards("b") or forwards ("f").
direction: Whether we're paginating backwards or forwards.
column_names: The column names to bound. Must *not* be user defined as
these get inserted directly into the SQL statement without escapes.
from_token: The start point for the pagination. This is an exclusive
minimum bound if direction is "f", and an inclusive maximum bound if
direction is "b".
minimum bound if direction is forwards, and an inclusive maximum bound if
direction is backwards.
to_token: The endpoint point for the pagination. This is an inclusive
maximum bound if direction is "f", and an exclusive minimum bound if
direction is "b".
maximum bound if direction is forwards, and an exclusive minimum bound if
direction is backwards.
engine: The database engine to generate the clauses for
Returns:
The sql expression
"""
assert direction in ("b", "f")
where_clause = []
if from_token:
where_clause.append(
_make_generic_sql_bound(
bound=">=" if direction == "b" else "<",
bound=">=" if direction == Direction.BACKWARDS else "<",
column_names=column_names,
values=from_token,
engine=engine,
@ -160,7 +159,7 @@ def generate_pagination_where_clause(
if to_token:
where_clause.append(
_make_generic_sql_bound(
bound="<" if direction == "b" else ">=",
bound="<" if direction == Direction.BACKWARDS else ">=",
column_names=column_names,
values=to_token,
engine=engine,
@ -171,7 +170,7 @@ def generate_pagination_where_clause(
def generate_pagination_bounds(
direction: str,
direction: Direction,
from_token: Optional[RoomStreamToken],
to_token: Optional[RoomStreamToken],
) -> Tuple[
@ -181,7 +180,7 @@ def generate_pagination_bounds(
Generate a start and end point for this page of events.
Args:
direction: Whether pagination is going forwards or backwards. One of "f" or "b".
direction: Whether pagination is going forwards or backwards.
from_token: The token to start pagination at, or None to start at the first value.
to_token: The token to end pagination at, or None to not limit the end point.
@ -201,7 +200,7 @@ def generate_pagination_bounds(
# Tokens really represent positions between elements, but we use
# the convention of pointing to the event before the gap. Hence
# we have a bit of asymmetry when it comes to equalities.
if direction == "b":
if direction == Direction.BACKWARDS:
order = "DESC"
else:
order = "ASC"
@ -215,7 +214,7 @@ def generate_pagination_bounds(
if from_token:
if from_token.topological is not None:
from_bound = from_token.as_historical_tuple()
elif direction == "b":
elif direction == Direction.BACKWARDS:
from_bound = (
None,
from_token.get_max_stream_pos(),
@ -230,7 +229,7 @@ def generate_pagination_bounds(
if to_token:
if to_token.topological is not None:
to_bound = to_token.as_historical_tuple()
elif direction == "b":
elif direction == Direction.BACKWARDS:
to_bound = (
None,
to_token.stream,
@ -245,20 +244,20 @@ def generate_pagination_bounds(
def generate_next_token(
direction: str, last_topo_ordering: int, last_stream_ordering: int
direction: Direction, last_topo_ordering: int, last_stream_ordering: int
) -> RoomStreamToken:
"""
Generate the next room stream token based on the currently returned data.
Args:
direction: Whether pagination is going forwards or backwards. One of "f" or "b".
direction: Whether pagination is going forwards or backwards.
last_topo_ordering: The last topological ordering being returned.
last_stream_ordering: The last stream ordering being returned.
Returns:
A new RoomStreamToken to return to the client.
"""
if direction == "b":
if direction == Direction.BACKWARDS:
# Tokens are positions between events.
# This token points *after* the last event in the chunk.
# We need it to point to the event before it in the chunk
@ -1201,7 +1200,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn,
room_id,
before_token,
direction="b",
direction=Direction.BACKWARDS,
limit=before_limit,
event_filter=event_filter,
)
@ -1211,7 +1210,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn,
room_id,
after_token,
direction="f",
direction=Direction.FORWARDS,
limit=after_limit,
event_filter=event_filter,
)
@ -1374,7 +1373,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
room_id: str,
from_token: RoomStreamToken,
to_token: Optional[RoomStreamToken] = None,
direction: str = "b",
direction: Direction = Direction.BACKWARDS,
limit: int = -1,
event_filter: Optional[Filter] = None,
) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
@ -1385,8 +1384,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
room_id
from_token: The token used to stream from
to_token: A token which if given limits the results to only those before
direction: Either 'b' or 'f' to indicate whether we are paginating
forwards or backwards from `from_key`.
direction: Indicates whether we are paginating forwards or backwards
from `from_key`.
limit: The maximum number of events to return.
event_filter: If provided filters the events to
those that match the filter.
@ -1489,8 +1488,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
_EventDictReturn(event_id, topological_ordering, stream_ordering)
for event_id, instance_name, topological_ordering, stream_ordering in txn
if _filter_results(
lower_token=to_token if direction == "b" else from_token,
upper_token=from_token if direction == "b" else to_token,
lower_token=to_token
if direction == Direction.BACKWARDS
else from_token,
upper_token=from_token
if direction == Direction.BACKWARDS
else to_token,
instance_name=instance_name,
topological_ordering=topological_ordering,
stream_ordering=stream_ordering,
@ -1514,7 +1517,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
room_id: str,
from_key: RoomStreamToken,
to_key: Optional[RoomStreamToken] = None,
direction: str = "b",
direction: Direction = Direction.BACKWARDS,
limit: int = -1,
event_filter: Optional[Filter] = None,
) -> Tuple[List[EventBase], RoomStreamToken]:
@ -1524,8 +1527,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
room_id
from_key: The token used to stream from
to_key: A token which if given limits the results to only those before
direction: Either 'b' or 'f' to indicate whether we are paginating
forwards or backwards from `from_key`.
direction: Indicates whether we are paginating forwards or backwards
from `from_key`.
limit: The maximum number of events to return.
event_filter: If provided filters the events to those that match the filter.

View File

@ -16,6 +16,7 @@ from typing import Optional
import attr
from synapse.api.constants import Direction
from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_integer, parse_string
from synapse.http.site import SynapseRequest
@ -34,7 +35,7 @@ class PaginationConfig:
from_token: Optional[StreamToken]
to_token: Optional[StreamToken]
direction: str
direction: Direction
limit: int
@classmethod
@ -45,9 +46,13 @@ class PaginationConfig:
default_limit: int,
default_dir: str = "f",
) -> "PaginationConfig":
direction = parse_string(
request, "dir", default=default_dir, allowed_values=["f", "b"]
direction_str = parse_string(
request,
"dir",
default=default_dir,
allowed_values=[Direction.FORWARDS.value, Direction.BACKWARDS.value],
)
direction = Direction(direction_str)
from_tok_str = parse_string(request, "from")
to_tok_str = parse_string(request, "to")