Make StreamToken.room_key be a RoomStreamToken instance. (#8281)

This commit is contained in:
Erik Johnston 2020-09-11 12:22:55 +01:00 committed by GitHub
parent c312ee3cde
commit fe8ed1b46f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 114 additions and 99 deletions

1
changelog.d/8281.misc Normal file
View File

@ -0,0 +1 @@
Change `StreamToken.room_key` to be a `RoomStreamToken` instance.

View File

@ -46,10 +46,12 @@ files =
synapse/server_notices, synapse/server_notices,
synapse/spam_checker_api, synapse/spam_checker_api,
synapse/state, synapse/state,
synapse/storage/databases/main/events.py,
synapse/storage/databases/main/stream.py, synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py, synapse/storage/databases/main/ui_auth.py,
synapse/storage/database.py, synapse/storage/database.py,
synapse/storage/engines, synapse/storage/engines,
synapse/storage/persist_events.py,
synapse/storage/state.py, synapse/storage/state.py,
synapse/storage/util, synapse/storage/util,
synapse/streams, synapse/streams,

View File

@ -125,8 +125,8 @@ class AdminHandler(BaseHandler):
else: else:
stream_ordering = room.stream_ordering stream_ordering = room.stream_ordering
from_key = str(RoomStreamToken(0, 0)) from_key = RoomStreamToken(0, 0)
to_key = str(RoomStreamToken(None, stream_ordering)) to_key = RoomStreamToken(None, stream_ordering)
written_events = set() # Events that we've processed in this room written_events = set() # Events that we've processed in this room
@ -153,7 +153,7 @@ class AdminHandler(BaseHandler):
if not events: if not events:
break break
from_key = events[-1].internal_metadata.after from_key = RoomStreamToken.parse(events[-1].internal_metadata.after)
events = await filter_events_for_client(self.storage, user_id, events) events = await filter_events_for_client(self.storage, user_id, events)

View File

@ -29,6 +29,7 @@ from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import ( from synapse.types import (
RoomStreamToken, RoomStreamToken,
StreamToken,
get_domain_from_id, get_domain_from_id,
get_verify_key_from_cross_signing_key, get_verify_key_from_cross_signing_key,
) )
@ -104,18 +105,15 @@ class DeviceWorkerHandler(BaseHandler):
@trace @trace
@measure_func("device.get_user_ids_changed") @measure_func("device.get_user_ids_changed")
async def get_user_ids_changed(self, user_id, from_token): async def get_user_ids_changed(self, user_id: str, from_token: StreamToken):
"""Get list of users that have had the devices updated, or have newly """Get list of users that have had the devices updated, or have newly
joined a room, that `user_id` may be interested in. joined a room, that `user_id` may be interested in.
Args:
user_id (str)
from_token (StreamToken)
""" """
set_tag("user_id", user_id) set_tag("user_id", user_id)
set_tag("from_token", from_token) set_tag("from_token", from_token)
now_room_key = await self.store.get_room_events_max_id() now_room_id = self.store.get_room_max_stream_ordering()
now_room_key = RoomStreamToken(None, now_room_id)
room_ids = await self.store.get_rooms_for_user(user_id) room_ids = await self.store.get_rooms_for_user(user_id)
@ -142,7 +140,7 @@ class DeviceWorkerHandler(BaseHandler):
) )
rooms_changed.update(event.room_id for event in member_events) rooms_changed.update(event.room_id for event in member_events)
stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key).stream stream_ordering = from_token.room_key.stream
possibly_changed = set(changed) possibly_changed = set(changed)
possibly_left = set() possibly_left = set()

View File

@ -25,7 +25,7 @@ from synapse.handlers.presence import format_user_presence_state
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage.roommember import RoomsForUser from synapse.storage.roommember import RoomsForUser
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, Requester, StreamToken, UserID from synapse.types import JsonDict, Requester, RoomStreamToken, StreamToken, UserID
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async_helpers import concurrently_execute from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
@ -167,7 +167,7 @@ class InitialSyncHandler(BaseHandler):
self.state_handler.get_current_state, event.room_id self.state_handler.get_current_state, event.room_id
) )
elif event.membership == Membership.LEAVE: elif event.membership == Membership.LEAVE:
room_end_token = "s%d" % (event.stream_ordering,) room_end_token = RoomStreamToken(None, event.stream_ordering,)
deferred_room_state = run_in_background( deferred_room_state = run_in_background(
self.state_store.get_state_for_events, [event.event_id] self.state_store.get_state_for_events, [event.event_id]
) )

View File

@ -973,6 +973,7 @@ class EventCreationHandler:
This should only be run on the instance in charge of persisting events. This should only be run on the instance in charge of persisting events.
""" """
assert self._is_event_writer assert self._is_event_writer
assert self.storage.persistence is not None
if ratelimit: if ratelimit:
# We check if this is a room admin redacting an event so that we # We check if this is a room admin redacting an event so that we

View File

@ -344,7 +344,7 @@ class PaginationHandler:
# gets called. # gets called.
raise Exception("limit not set") raise Exception("limit not set")
room_token = RoomStreamToken.parse(from_token.room_key) room_token = from_token.room_key
with await self.pagination_lock.read(room_id): with await self.pagination_lock.read(room_id):
( (
@ -381,7 +381,7 @@ class PaginationHandler:
if leave_token.topological < max_topo: if leave_token.topological < max_topo:
from_token = from_token.copy_and_replace( from_token = from_token.copy_and_replace(
"room_key", leave_token_str "room_key", leave_token
) )
await self.hs.get_handlers().federation_handler.maybe_backfill( await self.hs.get_handlers().federation_handler.maybe_backfill(

View File

@ -1091,20 +1091,19 @@ class RoomEventSource:
async def get_new_events( async def get_new_events(
self, self,
user: UserID, user: UserID,
from_key: str, from_key: RoomStreamToken,
limit: int, limit: int,
room_ids: List[str], room_ids: List[str],
is_guest: bool, is_guest: bool,
explicit_room_id: Optional[str] = None, explicit_room_id: Optional[str] = None,
) -> Tuple[List[EventBase], str]: ) -> Tuple[List[EventBase], RoomStreamToken]:
# We just ignore the key for now. # We just ignore the key for now.
to_key = self.get_current_key() to_key = self.get_current_key()
from_token = RoomStreamToken.parse(from_key) if from_key.topological:
if from_token.topological:
logger.warning("Stream has topological part!!!! %r", from_key) logger.warning("Stream has topological part!!!! %r", from_key)
from_key = "s%s" % (from_token.stream,) from_key = RoomStreamToken(None, from_key.stream)
app_service = self.store.get_app_service_by_user_id(user.to_string()) app_service = self.store.get_app_service_by_user_id(user.to_string())
if app_service: if app_service:
@ -1133,14 +1132,14 @@ class RoomEventSource:
events[:] = events[:limit] events[:] = events[:limit]
if events: if events:
end_key = events[-1].internal_metadata.after end_key = RoomStreamToken.parse(events[-1].internal_metadata.after)
else: else:
end_key = to_key end_key = to_key
return (events, end_key) return (events, end_key)
def get_current_key(self) -> str: def get_current_key(self) -> RoomStreamToken:
return "s%d" % (self.store.get_room_max_stream_ordering(),) return RoomStreamToken(None, self.store.get_room_max_stream_ordering())
def get_current_key_for_room(self, room_id: str) -> Awaitable[str]: def get_current_key_for_room(self, room_id: str) -> Awaitable[str]:
return self.store.get_room_events_max_id(room_id) return self.store.get_room_events_max_id(room_id)

View File

@ -378,7 +378,7 @@ class SyncHandler:
sync_config = sync_result_builder.sync_config sync_config = sync_result_builder.sync_config
with Measure(self.clock, "ephemeral_by_room"): with Measure(self.clock, "ephemeral_by_room"):
typing_key = since_token.typing_key if since_token else "0" typing_key = since_token.typing_key if since_token else 0
room_ids = sync_result_builder.joined_room_ids room_ids = sync_result_builder.joined_room_ids
@ -402,7 +402,7 @@ class SyncHandler:
event_copy = {k: v for (k, v) in event.items() if k != "room_id"} event_copy = {k: v for (k, v) in event.items() if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy) ephemeral_by_room.setdefault(room_id, []).append(event_copy)
receipt_key = since_token.receipt_key if since_token else "0" receipt_key = since_token.receipt_key if since_token else 0
receipt_source = self.event_sources.sources["receipt"] receipt_source = self.event_sources.sources["receipt"]
receipts, receipt_key = await receipt_source.get_new_events( receipts, receipt_key = await receipt_source.get_new_events(
@ -533,7 +533,7 @@ class SyncHandler:
if len(recents) > timeline_limit: if len(recents) > timeline_limit:
limited = True limited = True
recents = recents[-timeline_limit:] recents = recents[-timeline_limit:]
room_key = recents[0].internal_metadata.before room_key = RoomStreamToken.parse(recents[0].internal_metadata.before)
prev_batch_token = now_token.copy_and_replace("room_key", room_key) prev_batch_token = now_token.copy_and_replace("room_key", room_key)
@ -1322,6 +1322,7 @@ class SyncHandler:
is_guest=sync_config.is_guest, is_guest=sync_config.is_guest,
include_offline=include_offline, include_offline=include_offline,
) )
assert presence_key
sync_result_builder.now_token = now_token.copy_and_replace( sync_result_builder.now_token = now_token.copy_and_replace(
"presence_key", presence_key "presence_key", presence_key
) )
@ -1484,7 +1485,7 @@ class SyncHandler:
if rooms_changed: if rooms_changed:
return True return True
stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream stream_id = since_token.room_key.stream
for room_id in sync_result_builder.joined_room_ids: for room_id in sync_result_builder.joined_room_ids:
if self.store.has_room_changed_since(room_id, stream_id): if self.store.has_room_changed_since(room_id, stream_id):
return True return True
@ -1750,7 +1751,7 @@ class SyncHandler:
continue continue
leave_token = now_token.copy_and_replace( leave_token = now_token.copy_and_replace(
"room_key", "s%d" % (event.stream_ordering,) "room_key", RoomStreamToken(None, event.stream_ordering)
) )
room_entries.append( room_entries.append(
RoomSyncResultBuilder( RoomSyncResultBuilder(

View File

@ -25,6 +25,7 @@ from typing import (
Set, Set,
Tuple, Tuple,
TypeVar, TypeVar,
Union,
) )
from prometheus_client import Counter from prometheus_client import Counter
@ -41,7 +42,7 @@ from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.types import Collection, StreamToken, UserID from synapse.types import Collection, RoomStreamToken, StreamToken, UserID
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
@ -111,7 +112,9 @@ class _NotifierUserStream:
with PreserveLoggingContext(): with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred()) self.notify_deferred = ObservableDeferred(defer.Deferred())
def notify(self, stream_key: str, stream_id: int, time_now_ms: int): def notify(
self, stream_key: str, stream_id: Union[int, RoomStreamToken], time_now_ms: int,
):
"""Notify any listeners for this user of a new event from an """Notify any listeners for this user of a new event from an
event source. event source.
Args: Args:
@ -294,7 +297,12 @@ class Notifier:
rooms.add(event.room_id) rooms.add(event.room_id)
if users or rooms: if users or rooms:
self.on_new_event("room_key", max_room_stream_id, users=users, rooms=rooms) self.on_new_event(
"room_key",
RoomStreamToken(None, max_room_stream_id),
users=users,
rooms=rooms,
)
self._on_updated_room_token(max_room_stream_id) self._on_updated_room_token(max_room_stream_id)
def _on_updated_room_token(self, max_room_stream_id: int): def _on_updated_room_token(self, max_room_stream_id: int):
@ -329,7 +337,7 @@ class Notifier:
def on_new_event( def on_new_event(
self, self,
stream_key: str, stream_key: str,
new_token: int, new_token: Union[int, RoomStreamToken],
users: Collection[UserID] = [], users: Collection[UserID] = [],
rooms: Collection[str] = [], rooms: Collection[str] = [],
): ):

View File

@ -47,6 +47,9 @@ class Storage:
# interfaces. # interfaces.
self.main = stores.main self.main = stores.main
self.persistence = EventsPersistenceStorage(hs, stores)
self.purge_events = PurgeEventsStorage(hs, stores) self.purge_events = PurgeEventsStorage(hs, stores)
self.state = StateGroupStorage(hs, stores) self.state = StateGroupStorage(hs, stores)
self.persistence = None
if stores.persist_events:
self.persistence = EventsPersistenceStorage(hs, stores)

View File

@ -213,7 +213,7 @@ class PersistEventsStore:
Returns: Returns:
Filtered event ids Filtered event ids
""" """
results = [] results = [] # type: List[str]
def _get_events_which_are_prevs_txn(txn, batch): def _get_events_which_are_prevs_txn(txn, batch):
sql = """ sql = """
@ -631,7 +631,9 @@ class PersistEventsStore:
) )
@classmethod @classmethod
def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts): def _filter_events_and_contexts_for_duplicates(
cls, events_and_contexts: List[Tuple[EventBase, EventContext]]
) -> List[Tuple[EventBase, EventContext]]:
"""Ensure that we don't have the same event twice. """Ensure that we don't have the same event twice.
Pick the earliest non-outlier if there is one, else the earliest one. Pick the earliest non-outlier if there is one, else the earliest one.
@ -641,7 +643,9 @@ class PersistEventsStore:
Returns: Returns:
list[(EventBase, EventContext)]: filtered list list[(EventBase, EventContext)]: filtered list
""" """
new_events_and_contexts = OrderedDict() new_events_and_contexts = (
OrderedDict()
) # type: OrderedDict[str, Tuple[EventBase, EventContext]]
for event, context in events_and_contexts: for event, context in events_and_contexts:
prev_event_context = new_events_and_contexts.get(event.event_id) prev_event_context = new_events_and_contexts.get(event.event_id)
if prev_event_context: if prev_event_context:
@ -655,7 +659,12 @@ class PersistEventsStore:
new_events_and_contexts[event.event_id] = (event, context) new_events_and_contexts[event.event_id] = (event, context)
return list(new_events_and_contexts.values()) return list(new_events_and_contexts.values())
def _update_room_depths_txn(self, txn, events_and_contexts, backfilled): def _update_room_depths_txn(
self,
txn,
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool,
):
"""Update min_depth for each room """Update min_depth for each room
Args: Args:
@ -664,7 +673,7 @@ class PersistEventsStore:
we are persisting we are persisting
backfilled (bool): True if the events were backfilled backfilled (bool): True if the events were backfilled
""" """
depth_updates = {} depth_updates = {} # type: Dict[str, int]
for event, context in events_and_contexts: for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids # Remove the any existing cache entries for the event_ids
txn.call_after(self.store._invalidate_get_event_cache, event.event_id) txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
@ -1436,7 +1445,7 @@ class PersistEventsStore:
Forward extremities are handled when we first start persisting the events. Forward extremities are handled when we first start persisting the events.
""" """
events_by_room = {} events_by_room = {} # type: Dict[str, List[EventBase]]
for ev in events: for ev in events:
events_by_room.setdefault(ev.room_id, []).append(ev) events_by_room.setdefault(ev.room_id, []).append(ev)

View File

@ -310,11 +310,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
async def get_room_events_stream_for_rooms( async def get_room_events_stream_for_rooms(
self, self,
room_ids: Collection[str], room_ids: Collection[str],
from_key: str, from_key: RoomStreamToken,
to_key: str, to_key: RoomStreamToken,
limit: int = 0, limit: int = 0,
order: str = "DESC", order: str = "DESC",
) -> Dict[str, Tuple[List[EventBase], str]]: ) -> Dict[str, Tuple[List[EventBase], RoomStreamToken]]:
"""Get new room events in stream ordering since `from_key`. """Get new room events in stream ordering since `from_key`.
Args: Args:
@ -333,9 +333,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
- list of recent events in the room - list of recent events in the room
- stream ordering key for the start of the chunk of events returned. - stream ordering key for the start of the chunk of events returned.
""" """
from_id = RoomStreamToken.parse_stream_token(from_key).stream room_ids = self._events_stream_cache.get_entities_changed(
room_ids, from_key.stream
room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id) )
if not room_ids: if not room_ids:
return {} return {}
@ -364,16 +364,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return results return results
def get_rooms_that_changed( def get_rooms_that_changed(
self, room_ids: Collection[str], from_key: str self, room_ids: Collection[str], from_key: RoomStreamToken
) -> Set[str]: ) -> Set[str]:
"""Given a list of rooms and a token, return rooms where there may have """Given a list of rooms and a token, return rooms where there may have
been changes. been changes.
Args:
room_ids
from_key: The room_key portion of a StreamToken
""" """
from_id = RoomStreamToken.parse_stream_token(from_key).stream from_id = from_key.stream
return { return {
room_id room_id
for room_id in room_ids for room_id in room_ids
@ -383,11 +379,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
async def get_room_events_stream_for_room( async def get_room_events_stream_for_room(
self, self,
room_id: str, room_id: str,
from_key: str, from_key: RoomStreamToken,
to_key: str, to_key: RoomStreamToken,
limit: int = 0, limit: int = 0,
order: str = "DESC", order: str = "DESC",
) -> Tuple[List[EventBase], str]: ) -> Tuple[List[EventBase], RoomStreamToken]:
"""Get new room events in stream ordering since `from_key`. """Get new room events in stream ordering since `from_key`.
Args: Args:
@ -408,8 +404,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if from_key == to_key: if from_key == to_key:
return [], from_key return [], from_key
from_id = RoomStreamToken.parse_stream_token(from_key).stream from_id = from_key.stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream to_id = to_key.stream
has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id) has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
@ -441,7 +437,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
ret.reverse() ret.reverse()
if rows: if rows:
key = "s%d" % min(r.stream_ordering for r in rows) key = RoomStreamToken(None, min(r.stream_ordering for r in rows))
else: else:
# Assume we didn't get anything because there was nothing to # Assume we didn't get anything because there was nothing to
# get. # get.
@ -450,10 +446,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key return ret, key
async def get_membership_changes_for_user( async def get_membership_changes_for_user(
self, user_id: str, from_key: str, to_key: str self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
) -> List[EventBase]: ) -> List[EventBase]:
from_id = RoomStreamToken.parse_stream_token(from_key).stream from_id = from_key.stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream to_id = to_key.stream
if from_key == to_key: if from_key == to_key:
return [] return []
@ -491,8 +487,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret return ret
async def get_recent_events_for_room( async def get_recent_events_for_room(
self, room_id: str, limit: int, end_token: str self, room_id: str, limit: int, end_token: RoomStreamToken
) -> Tuple[List[EventBase], str]: ) -> Tuple[List[EventBase], RoomStreamToken]:
"""Get the most recent events in the room in topological ordering. """Get the most recent events in the room in topological ordering.
Args: Args:
@ -518,8 +514,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return (events, token) return (events, token)
async def get_recent_event_ids_for_room( async def get_recent_event_ids_for_room(
self, room_id: str, limit: int, end_token: str self, room_id: str, limit: int, end_token: RoomStreamToken
) -> Tuple[List[_EventDictReturn], str]: ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
"""Get the most recent events in the room in topological ordering. """Get the most recent events in the room in topological ordering.
Args: Args:
@ -535,13 +531,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if limit == 0: if limit == 0:
return [], end_token return [], end_token
parsed_end_token = RoomStreamToken.parse(end_token)
rows, token = await self.db_pool.runInteraction( rows, token = await self.db_pool.runInteraction(
"get_recent_event_ids_for_room", "get_recent_event_ids_for_room",
self._paginate_room_events_txn, self._paginate_room_events_txn,
room_id, room_id,
from_token=parsed_end_token, from_token=end_token,
limit=limit, limit=limit,
) )
@ -619,17 +613,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
allow_none=allow_none, allow_none=allow_none,
) )
async def get_stream_token_for_event(self, event_id: str) -> str: async def get_stream_token_for_event(self, event_id: str) -> RoomStreamToken:
"""The stream token for an event """The stream token for an event
Args: Args:
event_id: The id of the event to look up a stream token for. event_id: The id of the event to look up a stream token for.
Raises: Raises:
StoreError if the event wasn't in the database. StoreError if the event wasn't in the database.
Returns: Returns:
A "s%d" stream token. A stream token.
""" """
stream_id = await self.get_stream_id_for_event(event_id) stream_id = await self.get_stream_id_for_event(event_id)
return "s%d" % (stream_id,) return RoomStreamToken(None, stream_id)
async def get_topological_token_for_event(self, event_id: str) -> str: async def get_topological_token_for_event(self, event_id: str) -> str:
"""The stream token for an event """The stream token for an event
@ -954,7 +948,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
direction: str = "b", direction: str = "b",
limit: int = -1, limit: int = -1,
event_filter: Optional[Filter] = None, event_filter: Optional[Filter] = None,
) -> Tuple[List[_EventDictReturn], str]: ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
"""Returns list of events before or after a given token. """Returns list of events before or after a given token.
Args: Args:
@ -1054,17 +1048,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# TODO (erikj): We should work out what to do here instead. # TODO (erikj): We should work out what to do here instead.
next_token = to_token if to_token else from_token next_token = to_token if to_token else from_token
return rows, str(next_token) return rows, next_token
async def paginate_room_events( async def paginate_room_events(
self, self,
room_id: str, room_id: str,
from_key: str, from_key: RoomStreamToken,
to_key: Optional[str] = None, to_key: Optional[RoomStreamToken] = None,
direction: str = "b", direction: str = "b",
limit: int = -1, limit: int = -1,
event_filter: Optional[Filter] = None, event_filter: Optional[Filter] = None,
) -> Tuple[List[EventBase], str]: ) -> Tuple[List[EventBase], RoomStreamToken]:
"""Returns list of events before or after a given token. """Returns list of events before or after a given token.
Args: Args:
@ -1083,17 +1077,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
and `to_key`). and `to_key`).
""" """
parsed_from_key = RoomStreamToken.parse(from_key)
parsed_to_key = None
if to_key:
parsed_to_key = RoomStreamToken.parse(to_key)
rows, token = await self.db_pool.runInteraction( rows, token = await self.db_pool.runInteraction(
"paginate_room_events", "paginate_room_events",
self._paginate_room_events_txn, self._paginate_room_events_txn,
room_id, room_id,
parsed_from_key, from_key,
parsed_to_key, to_key,
direction, direction,
limit, limit,
event_filter, event_filter,

View File

@ -18,7 +18,7 @@
import itertools import itertools
import logging import logging
from collections import deque, namedtuple from collections import deque, namedtuple
from typing import Iterable, List, Optional, Set, Tuple from typing import Dict, Iterable, List, Optional, Set, Tuple
from prometheus_client import Counter, Histogram from prometheus_client import Counter, Histogram
@ -31,7 +31,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases import Databases from synapse.storage.databases import Databases
from synapse.storage.databases.main.events import DeltaState from synapse.storage.databases.main.events import DeltaState
from synapse.types import StateMap from synapse.types import Collection, StateMap
from synapse.util.async_helpers import ObservableDeferred from synapse.util.async_helpers import ObservableDeferred
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -185,6 +185,8 @@ class EventsPersistenceStorage:
# store for now. # store for now.
self.main_store = stores.main self.main_store = stores.main
self.state_store = stores.state self.state_store = stores.state
assert stores.persist_events
self.persist_events_store = stores.persist_events self.persist_events_store = stores.persist_events
self._clock = hs.get_clock() self._clock = hs.get_clock()
@ -208,7 +210,7 @@ class EventsPersistenceStorage:
Returns: Returns:
the stream ordering of the latest persisted event the stream ordering of the latest persisted event
""" """
partitioned = {} partitioned = {} # type: Dict[str, List[Tuple[EventBase, EventContext]]]
for event, ctx in events_and_contexts: for event, ctx in events_and_contexts:
partitioned.setdefault(event.room_id, []).append((event, ctx)) partitioned.setdefault(event.room_id, []).append((event, ctx))
@ -305,7 +307,9 @@ class EventsPersistenceStorage:
# Work out the new "current state" for each room. # Work out the new "current state" for each room.
# We do this by working out what the new extremities are and then # We do this by working out what the new extremities are and then
# calculating the state from that. # calculating the state from that.
events_by_room = {} events_by_room = (
{}
) # type: Dict[str, List[Tuple[EventBase, EventContext]]]
for event, context in chunk: for event, context in chunk:
events_by_room.setdefault(event.room_id, []).append( events_by_room.setdefault(event.room_id, []).append(
(event, context) (event, context)
@ -436,7 +440,7 @@ class EventsPersistenceStorage:
self, self,
room_id: str, room_id: str,
event_contexts: List[Tuple[EventBase, EventContext]], event_contexts: List[Tuple[EventBase, EventContext]],
latest_event_ids: List[str], latest_event_ids: Collection[str],
): ):
"""Calculates the new forward extremities for a room given events to """Calculates the new forward extremities for a room given events to
persist. persist.
@ -470,7 +474,7 @@ class EventsPersistenceStorage:
# Remove any events which are prev_events of any existing events. # Remove any events which are prev_events of any existing events.
existing_prevs = await self.persist_events_store._get_events_which_are_prevs( existing_prevs = await self.persist_events_store._get_events_which_are_prevs(
result result
) ) # type: Collection[str]
result.difference_update(existing_prevs) result.difference_update(existing_prevs)
# Finally handle the case where the new events have soft-failed prev # Finally handle the case where the new events have soft-failed prev

View File

@ -425,7 +425,9 @@ class RoomStreamToken:
@attr.s(slots=True, frozen=True) @attr.s(slots=True, frozen=True)
class StreamToken: class StreamToken:
room_key = attr.ib(type=str) room_key = attr.ib(
type=RoomStreamToken, validator=attr.validators.instance_of(RoomStreamToken)
)
presence_key = attr.ib(type=int) presence_key = attr.ib(type=int)
typing_key = attr.ib(type=int) typing_key = attr.ib(type=int)
receipt_key = attr.ib(type=int) receipt_key = attr.ib(type=int)
@ -445,21 +447,16 @@ class StreamToken:
while len(keys) < len(attr.fields(cls)): while len(keys) < len(attr.fields(cls)):
# i.e. old token from before receipt_key # i.e. old token from before receipt_key
keys.append("0") keys.append("0")
return cls(keys[0], *(int(k) for k in keys[1:])) return cls(RoomStreamToken.parse(keys[0]), *(int(k) for k in keys[1:]))
except Exception: except Exception:
raise SynapseError(400, "Invalid Token") raise SynapseError(400, "Invalid Token")
def to_string(self): def to_string(self):
return self._SEPARATOR.join([str(k) for k in attr.astuple(self)]) return self._SEPARATOR.join([str(k) for k in attr.astuple(self, recurse=False)])
@property @property
def room_stream_id(self): def room_stream_id(self):
# TODO(markjh): Awful hack to work around hacks in the presence tests return self.room_key.stream
# which assume that the keys are integers.
if type(self.room_key) is int:
return self.room_key
else:
return int(self.room_key[1:].split("-")[-1])
def is_after(self, other): def is_after(self, other):
"""Does this token contain events that the other doesn't?""" """Does this token contain events that the other doesn't?"""
@ -475,7 +472,7 @@ class StreamToken:
or (int(other.groups_key) < int(self.groups_key)) or (int(other.groups_key) < int(self.groups_key))
) )
def copy_and_advance(self, key, new_value): def copy_and_advance(self, key, new_value) -> "StreamToken":
"""Advance the given key in the token to a new value if and only if the """Advance the given key in the token to a new value if and only if the
new value is after the old value. new value is after the old value.
""" """
@ -491,7 +488,7 @@ class StreamToken:
else: else:
return self return self
def copy_and_replace(self, key, new_value): def copy_and_replace(self, key, new_value) -> "StreamToken":
return attr.evolve(self, **{key: new_value}) return attr.evolve(self, **{key: new_value})

View File

@ -71,7 +71,10 @@ async def inject_event(
""" """
event, context = await create_event(hs, room_version, prev_event_ids, **kwargs) event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)
await hs.get_storage().persistence.persist_event(event, context) persistence = hs.get_storage().persistence
assert persistence is not None
await persistence.persist_event(event, context)
return event return event