Convert stream database to async/await. (#8074)

This commit is contained in:
Patrick Cloke 2020-08-17 07:24:46 -04:00 committed by GitHub
parent ac77cdb64e
commit ad6190c925
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 224 additions and 227 deletions

1
changelog.d/8074.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View File

@ -23,7 +23,7 @@ from jsonschema import FormatChecker
from synapse.api.constants import EventContentFields from synapse.api.constants import EventContentFields
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.storage.presence import UserPresenceState from synapse.api.presence import UserPresenceState
from synapse.types import RoomID, UserID from synapse.types import RoomID, UserID
FILTER_SCHEMA = { FILTER_SCHEMA = {

View File

@ -37,8 +37,8 @@ from sortedcontainers import SortedDict
from twisted.internet import defer from twisted.internet import defer
from synapse.api.presence import UserPresenceState
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
from synapse.storage.presence import UserPresenceState
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from .units import Edu from .units import Edu

View File

@ -22,6 +22,7 @@ from twisted.internet import defer
import synapse import synapse
import synapse.metrics import synapse.metrics
from synapse.api.presence import UserPresenceState
from synapse.events import EventBase from synapse.events import EventBase
from synapse.federation.sender.per_destination_queue import PerDestinationQueue from synapse.federation.sender.per_destination_queue import PerDestinationQueue
from synapse.federation.sender.transaction_manager import TransactionManager from synapse.federation.sender.transaction_manager import TransactionManager
@ -39,7 +40,6 @@ from synapse.metrics import (
events_processed_counter, events_processed_counter,
) )
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.presence import UserPresenceState
from synapse.types import ReadReceipt from synapse.types import ReadReceipt
from synapse.util.metrics import Measure, measure_func from synapse.util.metrics import Measure, measure_func

View File

@ -24,12 +24,12 @@ from synapse.api.errors import (
HttpResponseException, HttpResponseException,
RequestSendFailed, RequestSendFailed,
) )
from synapse.api.presence import UserPresenceState
from synapse.events import EventBase from synapse.events import EventBase
from synapse.federation.units import Edu from synapse.federation.units import Edu
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
from synapse.metrics import sent_transactions_counter from synapse.metrics import sent_transactions_counter
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.presence import UserPresenceState
from synapse.types import ReadReceipt from synapse.types import ReadReceipt
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter

View File

@ -33,13 +33,13 @@ from typing_extensions import ContextManager
import synapse.metrics import synapse.metrics
from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState
from synapse.logging.context import run_in_background from synapse.logging.context import run_in_background
from synapse.logging.utils import log_function from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateHandler from synapse.state import StateHandler
from synapse.storage.databases.main import DataStore from synapse.storage.databases.main import DataStore
from synapse.storage.presence import UserPresenceState
from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached

View File

@ -15,8 +15,8 @@
from typing import List, Tuple from typing import List, Tuple
from synapse.api.presence import UserPresenceState
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.presence import UserPresenceState
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter

View File

@ -39,15 +39,17 @@ what sort order was used:
import abc import abc
import logging import logging
from collections import namedtuple from collections import namedtuple
from typing import Optional from typing import Dict, Iterable, List, Optional, Tuple
from twisted.internet import defer from twisted.internet import defer
from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_in_list_sql_clause from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -68,8 +70,12 @@ _EventDictReturn = namedtuple(
def generate_pagination_where_clause( def generate_pagination_where_clause(
direction, column_names, from_token, to_token, engine direction: str,
): column_names: Tuple[str, str],
from_token: Optional[Tuple[int, int]],
to_token: Optional[Tuple[int, int]],
engine: BaseDatabaseEngine,
) -> str:
"""Creates an SQL expression to bound the columns by the pagination """Creates an SQL expression to bound the columns by the pagination
tokens. tokens.
@ -90,21 +96,19 @@ def generate_pagination_where_clause(
token, but include those that match the to token. token, but include those that match the to token.
Args: Args:
direction (str): Whether we're paginating backwards("b") or direction: Whether we're paginating backwards("b") or forwards ("f").
forwards ("f"). column_names: The column names to bound. Must *not* be user defined as
column_names (tuple[str, str]): The column names to bound. Must *not* these get inserted directly into the SQL statement without escapes.
be user defined as these get inserted directly into the SQL from_token: The start point for the pagination. This is an exclusive
statement without escapes. minimum bound if direction is "f", and an inclusive maximum bound if
from_token (tuple[int, int]|None): The start point for the pagination. direction is "b".
This is an exclusive minimum bound if direction is "f", and an to_token: The endpoint point for the pagination. This is an inclusive
inclusive maximum bound if direction is "b". maximum bound if direction is "f", and an exclusive minimum bound if
to_token (tuple[int, int]|None): The endpoint point for the pagination. direction is "b".
This is an inclusive maximum bound if direction is "f", and an
exclusive minimum bound if direction is "b".
engine: The database engine to generate the clauses for engine: The database engine to generate the clauses for
Returns: Returns:
str: The sql expression The sql expression
""" """
assert direction in ("b", "f") assert direction in ("b", "f")
@ -132,7 +136,12 @@ def generate_pagination_where_clause(
return " AND ".join(where_clause) return " AND ".join(where_clause)
def _make_generic_sql_bound(bound, column_names, values, engine): def _make_generic_sql_bound(
bound: str,
column_names: Tuple[str, str],
values: Tuple[Optional[int], int],
engine: BaseDatabaseEngine,
) -> str:
"""Create an SQL expression that bounds the given column names by the """Create an SQL expression that bounds the given column names by the
values, e.g. create the equivalent of `(1, 2) < (col1, col2)`. values, e.g. create the equivalent of `(1, 2) < (col1, col2)`.
@ -142,18 +151,18 @@ def _make_generic_sql_bound(bound, column_names, values, engine):
out manually. out manually.
Args: Args:
bound (str): The comparison operator to use. One of ">", "<", ">=", bound: The comparison operator to use. One of ">", "<", ">=",
"<=", where the values are on the left and columns on the right. "<=", where the values are on the left and columns on the right.
names (tuple[str, str]): The column names. Must *not* be user defined names: The column names. Must *not* be user defined
as these get inserted directly into the SQL statement without as these get inserted directly into the SQL statement without
escapes. escapes.
values (tuple[int|None, int]): The values to bound the columns by. If values: The values to bound the columns by. If
the first value is None then only creates a bound on the second the first value is None then only creates a bound on the second
column. column.
engine: The database engine to generate the SQL for engine: The database engine to generate the SQL for
Returns: Returns:
str The SQL statement
""" """
assert bound in (">", "<", ">=", "<=") assert bound in (">", "<", ">=", "<=")
@ -193,7 +202,7 @@ def _make_generic_sql_bound(bound, column_names, values, engine):
) )
def filter_to_clause(event_filter): def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]:
# NB: This may create SQL clauses that don't optimise well (and we don't # NB: This may create SQL clauses that don't optimise well (and we don't
# have indices on all possible clauses). E.g. it may create # have indices on all possible clauses). E.g. it may create
# "room_id == X AND room_id != X", which postgres doesn't optimise. # "room_id == X AND room_id != X", which postgres doesn't optimise.
@ -291,34 +300,35 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def get_room_min_stream_ordering(self): def get_room_min_stream_ordering(self):
raise NotImplementedError() raise NotImplementedError()
@defer.inlineCallbacks async def get_room_events_stream_for_rooms(
def get_room_events_stream_for_rooms( self,
self, room_ids, from_key, to_key, limit=0, order="DESC" room_ids: Iterable[str],
): from_key: str,
to_key: str,
limit: int = 0,
order: str = "DESC",
) -> Dict[str, Tuple[List[EventBase], str]]:
"""Get new room events in stream ordering since `from_key`. """Get new room events in stream ordering since `from_key`.
Args: Args:
room_id (str) room_ids
from_key (str): Token from which no events are returned before from_key: Token from which no events are returned before
to_key (str): Token from which no events are returned after. (This to_key: Token from which no events are returned after. (This
is typically the current stream token) is typically the current stream token)
limit (int): Maximum number of events to return limit: Maximum number of events to return
order (str): Either "DESC" or "ASC". Determines which events are order: Either "DESC" or "ASC". Determines which events are
returned when the result is limited. If "DESC" then the most returned when the result is limited. If "DESC" then the most
recent `limit` events are returned, otherwise returns the recent `limit` events are returned, otherwise returns the
oldest `limit` events. oldest `limit` events.
Returns: Returns:
Deferred[dict[str,tuple[list[FrozenEvent], str]]]
A map from room id to a tuple containing: A map from room id to a tuple containing:
- list of recent events in the room - list of recent events in the room
- stream ordering key for the start of the chunk of events returned. - stream ordering key for the start of the chunk of events returned.
""" """
from_id = RoomStreamToken.parse_stream_token(from_key).stream from_id = RoomStreamToken.parse_stream_token(from_key).stream
room_ids = yield self._events_stream_cache.get_entities_changed( room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id)
room_ids, from_id
)
if not room_ids: if not room_ids:
return {} return {}
@ -326,7 +336,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
results = {} results = {}
room_ids = list(room_ids) room_ids = list(room_ids)
for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)): for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)):
res = yield make_deferred_yieldable( res = await make_deferred_yieldable(
defer.gatherResults( defer.gatherResults(
[ [
run_in_background( run_in_background(
@ -361,28 +371,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if self._events_stream_cache.has_entity_changed(room_id, from_key) if self._events_stream_cache.has_entity_changed(room_id, from_key)
} }
@defer.inlineCallbacks async def get_room_events_stream_for_room(
def get_room_events_stream_for_room( self,
self, room_id, from_key, to_key, limit=0, order="DESC" room_id: str,
): from_key: str,
to_key: str,
limit: int = 0,
order: str = "DESC",
) -> Tuple[List[EventBase], str]:
"""Get new room events in stream ordering since `from_key`. """Get new room events in stream ordering since `from_key`.
Args: Args:
room_id (str) room_id
from_key (str): Token from which no events are returned before from_key: Token from which no events are returned before
to_key (str): Token from which no events are returned after. (This to_key: Token from which no events are returned after. (This
is typically the current stream token) is typically the current stream token)
limit (int): Maximum number of events to return limit: Maximum number of events to return
order (str): Either "DESC" or "ASC". Determines which events are order: Either "DESC" or "ASC". Determines which events are
returned when the result is limited. If "DESC" then the most returned when the result is limited. If "DESC" then the most
recent `limit` events are returned, otherwise returns the recent `limit` events are returned, otherwise returns the
oldest `limit` events. oldest `limit` events.
Returns: Returns:
Deferred[tuple[list[FrozenEvent], str]]: Returns the list of The list of events (in ascending order) and the token from the start
events (in ascending order) and the token from the start of of the chunk of events returned.
the chunk of events returned.
""" """
if from_key == to_key: if from_key == to_key:
return [], from_key return [], from_key
@ -390,9 +403,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
from_id = RoomStreamToken.parse_stream_token(from_key).stream from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream to_id = RoomStreamToken.parse_stream_token(to_key).stream
has_changed = yield self._events_stream_cache.has_entity_changed( has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
room_id, from_id
)
if not has_changed: if not has_changed:
return [], from_key return [], from_key
@ -410,9 +421,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn] rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
return rows return rows
rows = yield self.db_pool.runInteraction("get_room_events_stream_for_room", f) rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
ret = yield self.get_events_as_list( ret = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True [r.event_id for r in rows], get_prev_content=True
) )
@ -430,8 +441,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key return ret, key
@defer.inlineCallbacks async def get_membership_changes_for_user(self, user_id, from_key, to_key):
def get_membership_changes_for_user(self, user_id, from_key, to_key):
from_id = RoomStreamToken.parse_stream_token(from_key).stream from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream to_id = RoomStreamToken.parse_stream_token(to_key).stream
@ -460,9 +470,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows return rows
rows = yield self.db_pool.runInteraction("get_membership_changes_for_user", f) rows = await self.db_pool.runInteraction("get_membership_changes_for_user", f)
ret = yield self.get_events_as_list( ret = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True [r.event_id for r in rows], get_prev_content=True
) )
@ -470,27 +480,26 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret return ret
@defer.inlineCallbacks async def get_recent_events_for_room(
def get_recent_events_for_room(self, room_id, limit, end_token): self, room_id: str, limit: int, end_token: str
) -> Tuple[List[EventBase], str]:
"""Get the most recent events in the room in topological ordering. """Get the most recent events in the room in topological ordering.
Args: Args:
room_id (str) room_id
limit (int) limit
end_token (str): The stream token representing now. end_token: The stream token representing now.
Returns: Returns:
Deferred[tuple[list[FrozenEvent], str]]: Returns a list of A list of events and a token pointing to the start of the returned
events and a token pointing to the start of the returned events. The events returned are in ascending order.
events.
The events returned are in ascending order.
""" """
rows, token = yield self.get_recent_event_ids_for_room( rows, token = await self.get_recent_event_ids_for_room(
room_id, limit, end_token room_id, limit, end_token
) )
events = yield self.get_events_as_list( events = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True [r.event_id for r in rows], get_prev_content=True
) )
@ -498,20 +507,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return (events, token) return (events, token)
@defer.inlineCallbacks async def get_recent_event_ids_for_room(
def get_recent_event_ids_for_room(self, room_id, limit, end_token): self, room_id: str, limit: int, end_token: str
) -> Tuple[List[_EventDictReturn], str]:
"""Get the most recent events in the room in topological ordering. """Get the most recent events in the room in topological ordering.
Args: Args:
room_id (str) room_id
limit (int) limit
end_token (str): The stream token representing now. end_token: The stream token representing now.
Returns: Returns:
Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of A list of _EventDictReturn and a token pointing to the start of the
_EventDictReturn and a token pointing to the start of the returned returned events. The events returned are in ascending order.
events.
The events returned are in ascending order.
""" """
# Allow a zero limit here, and no-op. # Allow a zero limit here, and no-op.
if limit == 0: if limit == 0:
@ -519,7 +527,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
end_token = RoomStreamToken.parse(end_token) end_token = RoomStreamToken.parse(end_token)
rows, token = yield self.db_pool.runInteraction( rows, token = await self.db_pool.runInteraction(
"get_recent_event_ids_for_room", "get_recent_event_ids_for_room",
self._paginate_room_events_txn, self._paginate_room_events_txn,
room_id, room_id,
@ -532,12 +540,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows, token return rows, token
def get_room_event_before_stream_ordering(self, room_id, stream_ordering): def get_room_event_before_stream_ordering(self, room_id: str, stream_ordering: int):
"""Gets details of the first event in a room at or before a stream ordering """Gets details of the first event in a room at or before a stream ordering
Args: Args:
room_id (str): room_id:
stream_ordering (int): stream_ordering:
Returns: Returns:
Deferred[(int, int, str)]: Deferred[(int, int, str)]:
@ -574,55 +582,56 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
) )
return "t%d-%d" % (topo, token) return "t%d-%d" % (topo, token)
def get_stream_token_for_event(self, event_id): async def get_stream_token_for_event(self, event_id: str) -> str:
"""The stream token for an event """The stream token for an event
Args: Args:
event_id(str): The id of the event to look up a stream token for. event_id: The id of the event to look up a stream token for.
Raises: Raises:
StoreError if the event wasn't in the database. StoreError if the event wasn't in the database.
Returns: Returns:
A deferred "s%d" stream token. A "s%d" stream token.
""" """
return self.db_pool.simple_select_one_onecol( row = await self.db_pool.simple_select_one_onecol(
table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering" table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
).addCallback(lambda row: "s%d" % (row,)) )
return "s%d" % (row,)
def get_topological_token_for_event(self, event_id): async def get_topological_token_for_event(self, event_id: str) -> str:
"""The stream token for an event """The stream token for an event
Args: Args:
event_id(str): The id of the event to look up a stream token for. event_id: The id of the event to look up a stream token for.
Raises: Raises:
StoreError if the event wasn't in the database. StoreError if the event wasn't in the database.
Returns: Returns:
A deferred "t%d-%d" topological token. A "t%d-%d" topological token.
""" """
return self.db_pool.simple_select_one( row = await self.db_pool.simple_select_one(
table="events", table="events",
keyvalues={"event_id": event_id}, keyvalues={"event_id": event_id},
retcols=("stream_ordering", "topological_ordering"), retcols=("stream_ordering", "topological_ordering"),
desc="get_topological_token_for_event", desc="get_topological_token_for_event",
).addCallback(
lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
) )
return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
def get_max_topological_token(self, room_id, stream_key): async def get_max_topological_token(self, room_id: str, stream_key: int) -> int:
"""Get the max topological token in a room before the given stream """Get the max topological token in a room before the given stream
ordering. ordering.
Args: Args:
room_id (str) room_id
stream_key (int) stream_key
Returns: Returns:
Deferred[int] The maximum topological token.
""" """
sql = ( sql = (
"SELECT coalesce(max(topological_ordering), 0) FROM events" "SELECT coalesce(max(topological_ordering), 0) FROM events"
" WHERE room_id = ? AND stream_ordering < ?" " WHERE room_id = ? AND stream_ordering < ?"
) )
return self.db_pool.execute( row = await self.db_pool.execute(
"get_max_topological_token", None, sql, room_id, stream_key "get_max_topological_token", None, sql, room_id, stream_key
).addCallback(lambda r: r[0][0] if r else 0) )
return row[0][0] if row else 0
def _get_max_topological_txn(self, txn, room_id): def _get_max_topological_txn(self, txn, room_id):
txn.execute( txn.execute(
@ -634,16 +643,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows[0][0] if rows else 0 return rows[0][0] if rows else 0
@staticmethod @staticmethod
def _set_before_and_after(events, rows, topo_order=True): def _set_before_and_after(
events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True
):
"""Inserts ordering information to events' internal metadata from """Inserts ordering information to events' internal metadata from
the DB rows. the DB rows.
Args: Args:
events (list[FrozenEvent]) events
rows (list[_EventDictReturn]) rows
topo_order (bool): Whether the events were ordered topologically topo_order: Whether the events were ordered topologically or by stream
or by stream ordering. If true then all rows should have a non ordering. If true then all rows should have a non null
null topological_ordering. topological_ordering.
""" """
for event, row in zip(events, rows): for event, row in zip(events, rows):
stream = row.stream_ordering stream = row.stream_ordering
@ -656,25 +667,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
internal.after = str(RoomStreamToken(topo, stream)) internal.after = str(RoomStreamToken(topo, stream))
internal.order = (int(topo) if topo else 0, int(stream)) internal.order = (int(topo) if topo else 0, int(stream))
@defer.inlineCallbacks async def get_events_around(
def get_events_around( self,
self, room_id, event_id, before_limit, after_limit, event_filter=None room_id: str,
): event_id: str,
before_limit: int,
after_limit: int,
event_filter: Optional[Filter] = None,
) -> dict:
"""Retrieve events and pagination tokens around a given event in a """Retrieve events and pagination tokens around a given event in a
room. room.
Args:
room_id (str)
event_id (str)
before_limit (int)
after_limit (int)
event_filter (Filter|None)
Returns:
dict
""" """
results = yield self.db_pool.runInteraction( results = await self.db_pool.runInteraction(
"get_events_around", "get_events_around",
self._get_events_around_txn, self._get_events_around_txn,
room_id, room_id,
@ -684,11 +689,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_filter, event_filter,
) )
events_before = yield self.get_events_as_list( events_before = await self.get_events_as_list(
list(results["before"]["event_ids"]), get_prev_content=True list(results["before"]["event_ids"]), get_prev_content=True
) )
events_after = yield self.get_events_as_list( events_after = await self.get_events_as_list(
list(results["after"]["event_ids"]), get_prev_content=True list(results["after"]["event_ids"]), get_prev_content=True
) )
@ -700,17 +705,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
} }
def _get_events_around_txn( def _get_events_around_txn(
self, txn, room_id, event_id, before_limit, after_limit, event_filter self,
): txn,
room_id: str,
event_id: str,
before_limit: int,
after_limit: int,
event_filter: Optional[Filter],
) -> dict:
"""Retrieves event_ids and pagination tokens around a given event in a """Retrieves event_ids and pagination tokens around a given event in a
room. room.
Args: Args:
room_id (str) room_id
event_id (str) event_id
before_limit (int) before_limit
after_limit (int) after_limit
event_filter (Filter|None) event_filter
Returns: Returns:
dict dict
@ -758,22 +769,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"after": {"event_ids": events_after, "token": end_token}, "after": {"event_ids": events_after, "token": end_token},
} }
@defer.inlineCallbacks async def get_all_new_events_stream(
def get_all_new_events_stream(self, from_id, current_id, limit): self, from_id: int, current_id: int, limit: int
) -> Tuple[int, List[EventBase]]:
"""Get all new events """Get all new events
Returns all events with from_id < stream_ordering <= current_id. Returns all events with from_id < stream_ordering <= current_id.
Args: Args:
from_id (int): the stream_ordering of the last event we processed from_id: the stream_ordering of the last event we processed
current_id (int): the stream_ordering of the most recently processed event current_id: the stream_ordering of the most recently processed event
limit (int): the maximum number of events to return limit: the maximum number of events to return
Returns: Returns:
Deferred[Tuple[int, list[FrozenEvent]]]: A tuple of (next_id, events), where A tuple of (next_id, events), where `next_id` is the next value to
`next_id` is the next value to pass as `from_id` (it will either be the pass as `from_id` (it will either be the stream_ordering of the
stream_ordering of the last returned event, or, if fewer than `limit` events last returned event, or, if fewer than `limit` events were found,
were found, `current_id`. the `current_id`).
""" """
def get_all_new_events_stream_txn(txn): def get_all_new_events_stream_txn(txn):
@ -795,11 +807,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, [row[1] for row in rows] return upper_bound, [row[1] for row in rows]
upper_bound, event_ids = yield self.db_pool.runInteraction( upper_bound, event_ids = await self.db_pool.runInteraction(
"get_all_new_events_stream", get_all_new_events_stream_txn "get_all_new_events_stream", get_all_new_events_stream_txn
) )
events = yield self.get_events_as_list(event_ids) events = await self.get_events_as_list(event_ids)
return upper_bound, events return upper_bound, events
@ -817,21 +829,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="get_federation_out_pos", desc="get_federation_out_pos",
) )
async def update_federation_out_pos(self, typ, stream_id): async def update_federation_out_pos(self, typ: str, stream_id: int) -> None:
if self._need_to_reset_federation_stream_positions: if self._need_to_reset_federation_stream_positions:
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"_reset_federation_positions_txn", self._reset_federation_positions_txn "_reset_federation_positions_txn", self._reset_federation_positions_txn
) )
self._need_to_reset_federation_stream_positions = False self._need_to_reset_federation_stream_positions = False
return await self.db_pool.simple_update_one( await self.db_pool.simple_update_one(
table="federation_stream_position", table="federation_stream_position",
keyvalues={"type": typ, "instance_name": self._instance_name}, keyvalues={"type": typ, "instance_name": self._instance_name},
updatevalues={"stream_id": stream_id}, updatevalues={"stream_id": stream_id},
desc="update_federation_out_pos", desc="update_federation_out_pos",
) )
def _reset_federation_positions_txn(self, txn): def _reset_federation_positions_txn(self, txn) -> None:
"""Fiddles with the `federation_stream_position` table to make it match """Fiddles with the `federation_stream_position` table to make it match
the configured federation sender instances during start up. the configured federation sender instances during start up.
""" """
@ -892,39 +904,37 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
values={"stream_id": stream_id}, values={"stream_id": stream_id},
) )
def has_room_changed_since(self, room_id, stream_id): def has_room_changed_since(self, room_id: str, stream_id: int) -> bool:
return self._events_stream_cache.has_entity_changed(room_id, stream_id) return self._events_stream_cache.has_entity_changed(room_id, stream_id)
def _paginate_room_events_txn( def _paginate_room_events_txn(
self, self,
txn, txn,
room_id, room_id: str,
from_token, from_token: RoomStreamToken,
to_token=None, to_token: Optional[RoomStreamToken] = None,
direction="b", direction: str = "b",
limit=-1, limit: int = -1,
event_filter=None, event_filter: Optional[Filter] = None,
): ) -> Tuple[List[_EventDictReturn], str]:
"""Returns list of events before or after a given token. """Returns list of events before or after a given token.
Args: Args:
txn txn
room_id (str) room_id
from_token (RoomStreamToken): The token used to stream from from_token: The token used to stream from
to_token (RoomStreamToken|None): A token which if given limits the to_token: A token which if given limits the results to only those before
results to only those before direction: Either 'b' or 'f' to indicate whether we are paginating
direction(char): Either 'b' or 'f' to indicate whether we are forwards or backwards from `from_key`.
paginating forwards or backwards from `from_key`. limit: The maximum number of events to return.
limit (int): The maximum number of events to return. event_filter: If provided filters the events to
event_filter (Filter|None): If provided filters the events to
those that match the filter. those that match the filter.
Returns: Returns:
Deferred[tuple[list[_EventDictReturn], str]]: Returns the results A list of _EventDictReturn and a token that points to the end of the
as a list of _EventDictReturn and a token that points to the end result set. If no events are returned then the end of the stream has
of the result set. If no events are returned then the end of the been reached (i.e. there are no events between `from_token` and
stream has been reached (i.e. there are no events between `to_token`), or `limit` is zero.
`from_token` and `to_token`), or `limit` is zero.
""" """
assert int(limit) >= 0 assert int(limit) >= 0
@ -1008,35 +1018,38 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows, str(next_token) return rows, str(next_token)
@defer.inlineCallbacks async def paginate_room_events(
def paginate_room_events( self,
self, room_id, from_key, to_key=None, direction="b", limit=-1, event_filter=None room_id: str,
): from_key: str,
to_key: Optional[str] = None,
direction: str = "b",
limit: int = -1,
event_filter: Optional[Filter] = None,
) -> Tuple[List[EventBase], str]:
"""Returns list of events before or after a given token. """Returns list of events before or after a given token.
Args: Args:
room_id (str) room_id
from_key (str): The token used to stream from from_key: The token used to stream from
to_key (str|None): A token which if given limits the results to to_key: A token which if given limits the results to only those before
only those before direction: Either 'b' or 'f' to indicate whether we are paginating
direction(char): Either 'b' or 'f' to indicate whether we are forwards or backwards from `from_key`.
paginating forwards or backwards from `from_key`. limit: The maximum number of events to return.
limit (int): The maximum number of events to return. event_filter: If provided filters the events to those that match the filter.
event_filter (Filter|None): If provided filters the events to
those that match the filter.
Returns: Returns:
tuple[list[FrozenEvent], str]: Returns the results as a list of The results as a list of events and a token that points to the end
events and a token that points to the end of the result set. If no of the result set. If no events are returned then the end of the
events are returned then the end of the stream has been reached stream has been reached (i.e. there are no events between `from_key`
(i.e. there are no events between `from_key` and `to_key`). and `to_key`).
""" """
from_key = RoomStreamToken.parse(from_key) from_key = RoomStreamToken.parse(from_key)
if to_key: if to_key:
to_key = RoomStreamToken.parse(to_key) to_key = RoomStreamToken.parse(to_key)
rows, token = yield self.db_pool.runInteraction( rows, token = await self.db_pool.runInteraction(
"paginate_room_events", "paginate_room_events",
self._paginate_room_events_txn, self._paginate_room_events_txn,
room_id, room_id,
@ -1047,7 +1060,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_filter, event_filter,
) )
events = yield self.get_events_as_list( events = await self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True [r.event_id for r in rows], get_prev_content=True
) )
@ -1057,8 +1070,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
class StreamStore(StreamWorkerStore): class StreamStore(StreamWorkerStore):
def get_room_max_stream_ordering(self): def get_room_max_stream_ordering(self) -> int:
return self._stream_id_gen.get_current_token() return self._stream_id_gen.get_current_token()
def get_room_min_stream_ordering(self): def get_room_min_stream_ordering(self) -> int:
return self._backfill_id_gen.get_current_token() return self._backfill_id_gen.get_current_token()

View File

@ -19,6 +19,7 @@ from mock import Mock, call
from signedjson.key import generate_signing_key from signedjson.key import generate_signing_key
from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events.builder import EventBuilder from synapse.events.builder import EventBuilder
from synapse.handlers.presence import ( from synapse.handlers.presence import (
@ -32,7 +33,6 @@ from synapse.handlers.presence import (
handle_update, handle_update,
) )
from synapse.rest.client.v1 import room from synapse.rest.client.v1 import room
from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, get_domain_from_id from synapse.types import UserID, get_domain_from_id
from tests import unittest from tests import unittest

View File

@ -15,6 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import NotFoundError
from synapse.rest.client.v1 import room from synapse.rest.client.v1 import room
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -46,30 +47,19 @@ class PurgeTests(HomeserverTestCase):
storage = self.hs.get_storage() storage = self.hs.get_storage()
# Get the topological token # Get the topological token
event = store.get_topological_token_for_event(last["event_id"]) event = self.get_success(
self.pump() store.get_topological_token_for_event(last["event_id"])
event = self.successResultOf(event) )
# Purge everything before this topological token # Purge everything before this topological token
purge = defer.ensureDeferred( self.get_success(storage.purge_events.purge_history(self.room_id, event, True))
storage.purge_events.purge_history(self.room_id, event, True)
)
self.pump()
self.assertEqual(self.successResultOf(purge), None)
# Try and get the events
get_first = store.get_event(first["event_id"])
get_second = store.get_event(second["event_id"])
get_third = store.get_event(third["event_id"])
get_last = store.get_event(last["event_id"])
self.pump()
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted # 1-3 should fail and last will succeed, meaning that 1-3 are deleted
# and last is not. # and last is not.
self.failureResultOf(get_first) self.get_failure(store.get_event(first["event_id"]), NotFoundError)
self.failureResultOf(get_second) self.get_failure(store.get_event(second["event_id"]), NotFoundError)
self.failureResultOf(get_third) self.get_failure(store.get_event(third["event_id"]), NotFoundError)
self.successResultOf(get_last) self.get_success(store.get_event(last["event_id"]))
def test_purge_wont_delete_extrems(self): def test_purge_wont_delete_extrems(self):
""" """
@ -84,9 +74,9 @@ class PurgeTests(HomeserverTestCase):
storage = self.hs.get_datastore() storage = self.hs.get_datastore()
# Set the topological token higher than it should be # Set the topological token higher than it should be
event = storage.get_topological_token_for_event(last["event_id"]) event = self.get_success(
self.pump() storage.get_topological_token_for_event(last["event_id"])
event = self.successResultOf(event) )
event = "t{}-{}".format( event = "t{}-{}".format(
*list(map(lambda x: x + 1, map(int, event[1:].split("-")))) *list(map(lambda x: x + 1, map(int, event[1:].split("-"))))
) )
@ -98,14 +88,7 @@ class PurgeTests(HomeserverTestCase):
self.assertIn("greater than forward", f.value.args[0]) self.assertIn("greater than forward", f.value.args[0])
# Try and get the events # Try and get the events
get_first = storage.get_event(first["event_id"]) self.get_success(storage.get_event(first["event_id"]))
get_second = storage.get_event(second["event_id"]) self.get_success(storage.get_event(second["event_id"]))
get_third = storage.get_event(third["event_id"]) self.get_success(storage.get_event(third["event_id"]))
get_last = storage.get_event(last["event_id"]) self.get_success(storage.get_event(last["event_id"]))
self.pump()
# Nothing is deleted.
self.successResultOf(get_first)
self.successResultOf(get_second)
self.successResultOf(get_third)
self.successResultOf(get_last)