Revert "Revert accidental fast-forward merge from v1.49.0rc1"

This reverts commit 158d73ebdd.
commit 4dd9ea8f4f
parent 158d73ebdd
Author: Olivier Wilkinson (reivilibre)
Date: 2021-12-14 14:22:01 +00:00
165 changed files with 7715 additions and 2703 deletions

synapse/storage/databases/main/events_worker.py

@@ -15,14 +15,18 @@
import logging
import threading
from typing import (
TYPE_CHECKING,
Any,
Collection,
Container,
Dict,
Iterable,
List,
NoReturn,
Optional,
Set,
Tuple,
cast,
overload,
)
@@ -38,6 +42,7 @@ from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
RoomVersion,
RoomVersions,
)
from synapse.events import EventBase, make_event_from_dict
@@ -56,10 +61,18 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import (
AbstractStreamIdTracker,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError
@@ -69,10 +82,13 @@ from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
# These values are used in the `enqueus_event` and `_do_fetch` methods to
# These values are used in the `enqueue_event` and `_fetch_loop` methods to
# control how we batch/bulk fetch events from the database.
# The values are plucked out of thin air to make initial sync run faster
# on jki.re
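For context, the batching knobs this comment refers to are module-level constants defined alongside the gauge below; in this version of the file they read roughly as follows (reproduced from the surrounding source, so treat the exact values as indicative):

EVENT_QUEUE_THREADS = 3  # Max number of threads that will fetch events
EVENT_QUEUE_ITERATIONS = 3  # No. times we block waiting for requests for events
EVENT_QUEUE_TIMEOUT_S = 0.1  # Timeout when waiting for requests for events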
@@ -89,7 +105,7 @@ event_fetch_ongoing_gauge = Gauge(
@attr.s(slots=True, auto_attribs=True)
class _EventCacheEntry:
class EventCacheEntry:
event: EventBase
redacted_event: Optional[EventBase]
@@ -129,7 +145,7 @@ class _EventRow:
json: str
internal_metadata: str
format_version: Optional[int]
room_version_id: Optional[int]
room_version_id: Optional[str]
rejected_reason: Optional[str]
redactions: List[str]
outlier: bool
@@ -153,9 +169,16 @@ class EventsWorkerStore(SQLBaseStore):
# options controlling this.
USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self._stream_id_gen: AbstractStreamIdTracker
self._backfill_id_gen: AbstractStreamIdTracker
if isinstance(database.engine, PostgresEngine):
# If we're using Postgres then we can use `MultiWriterIdGenerator`
# regardless of whether this process writes to the streams or not.
@@ -214,7 +237,7 @@ class EventsWorkerStore(SQLBaseStore):
5 * 60 * 1000,
)
self._get_event_cache = LruCache(
self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache(
cache_name="*getEvent*",
max_size=hs.config.caches.event_cache_size,
)
@@ -223,19 +246,21 @@ class EventsWorkerStore(SQLBaseStore):
# ID to cache entry. Note that the returned dict may not have the
# requested event in it if the event isn't in the DB.
self._current_event_fetches: Dict[
str, ObservableDeferred[Dict[str, _EventCacheEntry]]
str, ObservableDeferred[Dict[str, EventCacheEntry]]
] = {}
self._event_fetch_lock = threading.Condition()
self._event_fetch_list = []
self._event_fetch_list: List[
Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
] = []
self._event_fetch_ongoing = 0
event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
# We define this sequence here so that it can be referenced from both
# the DataStore and PersistEventStore.
def get_chain_id_txn(txn):
def get_chain_id_txn(txn: Cursor) -> int:
txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
return txn.fetchone()[0]
return cast(Tuple[int], txn.fetchone())[0]
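The new `cast` is needed because DB-API cursors type `fetchone()` as returning an optional, arbitrary-width row. A standalone sketch of the pattern (the stub function is hypothetical, not Synapse code):

from typing import Optional, Tuple, cast

def fetchone_stub() -> Optional[Tuple[object, ...]]:
    # Stand-in for Cursor.fetchone(), which may return None and whose
    # column types are unknown to the type checker.
    return (42,)

# The cast asserts that this particular query always yields exactly one
# integer column; without it, mypy flags both the possible None and the
# unknown element type. At runtime, cast() is a no-op.
chain_id = cast(Tuple[int], fetchone_stub())[0]
print(chain_id)  # 42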
self.event_chain_id_gen = build_sequence_generator(
db_conn,
@@ -246,7 +271,13 @@ class EventsWorkerStore(SQLBaseStore):
id_column="chain_id",
)
def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(
self,
stream_name: str,
instance_name: str,
token: int,
rows: Iterable[Any],
) -> None:
if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(instance_name, token)
elif stream_name == BackfillStream.NAME:
@@ -280,10 +311,10 @@ class EventsWorkerStore(SQLBaseStore):
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
allow_none: Literal[False] = False,
check_room_id: Optional[str] = None,
get_prev_content: bool = ...,
allow_rejected: bool = ...,
allow_none: Literal[False] = ...,
check_room_id: Optional[str] = ...,
) -> EventBase:
...
@@ -292,10 +323,10 @@ class EventsWorkerStore(SQLBaseStore):
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
allow_none: Literal[True] = False,
check_room_id: Optional[str] = None,
get_prev_content: bool = ...,
allow_rejected: bool = ...,
allow_none: Literal[True] = ...,
check_room_id: Optional[str] = ...,
) -> Optional[EventBase]:
...
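The switch from concrete defaults to `= ...` follows the usual `@overload` idiom: overload signatures are never executed, so the ellipsis merely marks "has a default" while the single real implementation supplies the actual value. A minimal sketch of the idiom (hypothetical function, not from this file):

from typing import Literal, Optional, overload

@overload
def find(key: str, allow_none: Literal[False] = ...) -> int:
    ...

@overload
def find(key: str, allow_none: Literal[True] = ...) -> Optional[int]:
    ...

def find(key: str, allow_none: bool = False) -> Optional[int]:
    # The implementation owns the real default; the overloads above only
    # describe how the return type depends on `allow_none`.
    value = {"a": 1}.get(key)
    if value is None and not allow_none:
        raise KeyError(key)
    return value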
@@ -357,7 +388,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_events(
self,
event_ids: Iterable[str],
event_ids: Collection[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
@@ -544,7 +575,7 @@ class EventsWorkerStore(SQLBaseStore):
async def _get_events_from_cache_or_db(
self, event_ids: Iterable[str], allow_rejected: bool = False
) -> Dict[str, _EventCacheEntry]:
) -> Dict[str, EventCacheEntry]:
"""Fetch a bunch of events from the cache or the database.
If events are pulled from the database, they will be cached for future lookups.
@@ -578,7 +609,7 @@ class EventsWorkerStore(SQLBaseStore):
# same dict into itself N times).
already_fetching_ids: Set[str] = set()
already_fetching_deferreds: Set[
ObservableDeferred[Dict[str, _EventCacheEntry]]
ObservableDeferred[Dict[str, EventCacheEntry]]
] = set()
for event_id in missing_events_ids:
@@ -601,8 +632,8 @@ class EventsWorkerStore(SQLBaseStore):
# function returning more events than requested, but that can happen
# already due to `_get_events_from_db`).
fetching_deferred: ObservableDeferred[
Dict[str, _EventCacheEntry]
] = ObservableDeferred(defer.Deferred())
Dict[str, EventCacheEntry]
] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
for event_id in missing_events_ids:
self._current_event_fetches[event_id] = fetching_deferred
@@ -658,12 +689,12 @@ class EventsWorkerStore(SQLBaseStore):
return event_entry_map
def _invalidate_get_event_cache(self, event_id):
def _invalidate_get_event_cache(self, event_id: str) -> None:
self._get_event_cache.invalidate((event_id,))
def _get_events_from_cache(
self, events: Iterable[str], update_metrics: bool = True
) -> Dict[str, _EventCacheEntry]:
) -> Dict[str, EventCacheEntry]:
"""Fetch events from the caches.
May return rejected events.
@@ -736,38 +767,123 @@ class EventsWorkerStore(SQLBaseStore):
for e in state_to_include.values()
]
def _do_fetch(self, conn: Connection) -> None:
def _maybe_start_fetch_thread(self) -> None:
"""Starts an event fetch thread if we are not yet at the maximum number."""
with self._event_fetch_lock:
if (
self._event_fetch_list
and self._event_fetch_ongoing < EVENT_QUEUE_THREADS
):
self._event_fetch_ongoing += 1
event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
# `_event_fetch_ongoing` is decremented in `_fetch_thread`.
should_start = True
else:
should_start = False
if should_start:
run_as_background_process("fetch_events", self._fetch_thread)
async def _fetch_thread(self) -> None:
"""Services requests for events from `_event_fetch_list`."""
exc = None
try:
await self.db_pool.runWithConnection(self._fetch_loop)
except BaseException as e:
exc = e
raise
finally:
should_restart = False
event_fetches_to_fail = []
with self._event_fetch_lock:
self._event_fetch_ongoing -= 1
event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
# There may still be work remaining in `_event_fetch_list` if we
# failed, or it was added in between us deciding to exit and
# decrementing `_event_fetch_ongoing`.
if self._event_fetch_list:
if exc is None:
# We decided to exit, but then some more work was added
# before `_event_fetch_ongoing` was decremented.
# If a new event fetch thread was not started, we should
# restart ourselves since the remaining event fetch threads
# may take a while to get around to the new work.
#
# Unfortunately it is not possible to tell whether a new
# event fetch thread was started, so we restart
# unconditionally. If we are unlucky, we will end up with
# an idle fetch thread, but it will time out after
# `EVENT_QUEUE_ITERATIONS * EVENT_QUEUE_TIMEOUT_S` seconds
# in any case.
#
# Note that multiple fetch threads may run down this path at
# the same time.
should_restart = True
elif isinstance(exc, Exception):
if self._event_fetch_ongoing == 0:
# We were the last remaining fetcher and failed.
# Fail any outstanding fetches since no one else will
# handle them.
event_fetches_to_fail = self._event_fetch_list
self._event_fetch_list = []
else:
# We weren't the last remaining fetcher, so another
# fetcher will pick up the work. This will either happen
# after their existing work, however long that takes,
# or after at most `EVENT_QUEUE_TIMEOUT_S` seconds if
# they are idle.
pass
else:
# The exception is a `SystemExit`, `KeyboardInterrupt` or
# `GeneratorExit`. Don't try to do anything clever here.
pass
if should_restart:
# We exited cleanly but noticed more work.
self._maybe_start_fetch_thread()
if event_fetches_to_fail:
# We were the last remaining fetcher and failed.
# Fail any outstanding fetches since no one else will handle them.
assert exc is not None
with PreserveLoggingContext():
for _, deferred in event_fetches_to_fail:
deferred.errback(exc)
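The comments above describe a race between a fetcher deciding to exit and decrementing `_event_fetch_ongoing`. Stripped of the Synapse specifics, the shape of the pattern is roughly this (hedged, standalone sketch; all names are illustrative):

import threading

class WorkerPool:
    """Bounded pool where an exiting worker re-checks the queue."""

    MAX_WORKERS = 3

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._queue: list = []
        self._ongoing = 0

    def submit(self, item: object) -> None:
        with self._lock:
            self._queue.append(item)
        self._maybe_start_worker()

    def _maybe_start_worker(self) -> None:
        with self._lock:
            should_start = bool(self._queue) and self._ongoing < self.MAX_WORKERS
            if should_start:
                self._ongoing += 1  # decremented in _worker's finally block
        if should_start:
            threading.Thread(target=self._worker).start()

    def _worker(self) -> None:
        try:
            while True:
                with self._lock:
                    if not self._queue:
                        # Decided to exit; new work may still be submitted
                        # before `finally` decrements the counter. That is
                        # the window `_fetch_thread` closes by restarting
                        # unconditionally below.
                        return
                    item = self._queue.pop()
                print("processing", item)
        finally:
            with self._lock:
                self._ongoing -= 1
                more_work = bool(self._queue)
            if more_work:
                self._maybe_start_worker()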
def _fetch_loop(self, conn: LoggingDatabaseConnection) -> None:
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
try:
i = 0
while True:
with self._event_fetch_lock:
event_list = self._event_fetch_list
self._event_fetch_list = []
i = 0
while True:
with self._event_fetch_lock:
event_list = self._event_fetch_list
self._event_fetch_list = []
if not event_list:
single_threaded = self.database_engine.single_threaded
if (
not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
or single_threaded
or i > EVENT_QUEUE_ITERATIONS
):
break
else:
self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
i += 1
continue
i = 0
if not event_list:
# There are no requests waiting. If we haven't yet reached the
# maximum iteration limit, wait for some more requests to turn up.
# Otherwise, bail out.
single_threaded = self.database_engine.single_threaded
if (
not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
or single_threaded
or i > EVENT_QUEUE_ITERATIONS
):
return
self._fetch_event_list(conn, event_list)
finally:
self._event_fetch_ongoing -= 1
event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
i += 1
continue
i = 0
self._fetch_event_list(conn, event_list)
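`_fetch_loop` itself is a condition-variable consumer: drain the shared list in one batch and, when idle, wait up to the queue timeout for a bounded number of rounds before releasing the connection. A self-contained sketch of that wait/drain shape (illustrative constants and names, not the real ones):

import threading

work_queue: list = []
cond = threading.Condition()
MAX_IDLE_ROUNDS = 3  # cf. EVENT_QUEUE_ITERATIONS
IDLE_WAIT_S = 0.1  # cf. EVENT_QUEUE_TIMEOUT_S

def enqueue(item: object) -> None:
    with cond:
        work_queue.append(item)
        cond.notify()  # wake one idle consumer, as `_enqueue_events` does

def consume_batches() -> None:
    idle_rounds = 0
    while True:
        with cond:
            batch = list(work_queue)
            work_queue.clear()
            if not batch:
                if idle_rounds > MAX_IDLE_ROUNDS:
                    return  # give up and let the thread exit
                # Wait for a producer's notify(), or time out and retry.
                cond.wait(IDLE_WAIT_S)
                idle_rounds += 1
                continue
        idle_rounds = 0
        print("processing batch of", len(batch), "items")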
def _fetch_event_list(
self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
self,
conn: LoggingDatabaseConnection,
event_list: List[Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]],
) -> None:
"""Handle a load of requests from the _event_fetch_list queue
@@ -794,7 +910,7 @@ class EventsWorkerStore(SQLBaseStore):
)
# We only want to resolve deferreds from the main thread
def fire():
def fire() -> None:
for _, d in event_list:
d.callback(row_dict)
@@ -804,18 +920,16 @@ class EventsWorkerStore(SQLBaseStore):
logger.exception("do_fetch")
# We only want to resolve deferreds from the main thread
def fire(evs, exc):
for _, d in evs:
if not d.called:
with PreserveLoggingContext():
d.errback(exc)
def fire_errback(exc: Exception) -> None:
for _, d in event_list:
d.errback(exc)
with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list, e)
self.hs.get_reactor().callFromThread(fire_errback, e)
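Both the success and failure paths hop back to the reactor thread before touching any deferred, since Twisted deferreds are not thread-safe. A minimal sketch of the `callFromThread` hand-off (assumes Twisted's default reactor; names are illustrative):

import threading

from twisted.internet import defer, reactor

def worker(d: "defer.Deferred[str]") -> None:
    # Runs on a non-reactor thread: never call d.callback() directly here.
    # callFromThread schedules the call to run on the reactor thread.
    reactor.callFromThread(d.callback, "result computed off-thread")

d: "defer.Deferred[str]" = defer.Deferred()
d.addCallback(lambda value: print("got:", value))
threading.Thread(target=worker, args=(d,)).start()
reactor.callLater(0.5, reactor.stop)  # give the hand-off time to run
reactor.run()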
async def _get_events_from_db(
self, event_ids: Iterable[str]
) -> Dict[str, _EventCacheEntry]:
self, event_ids: Collection[str]
) -> Dict[str, EventCacheEntry]:
"""Fetch a bunch of events from the database.
May return rejected events.
@@ -831,29 +945,29 @@ class EventsWorkerStore(SQLBaseStore):
map from event id to result. May return extra events which
weren't asked for.
"""
fetched_events = {}
fetched_event_ids: Set[str] = set()
fetched_events: Dict[str, _EventRow] = {}
events_to_fetch = event_ids
while events_to_fetch:
row_map = await self._enqueue_events(events_to_fetch)
# we need to recursively fetch any redactions of those events
redaction_ids = set()
redaction_ids: Set[str] = set()
for event_id in events_to_fetch:
row = row_map.get(event_id)
fetched_events[event_id] = row
fetched_event_ids.add(event_id)
if row:
fetched_events[event_id] = row
redaction_ids.update(row.redactions)
events_to_fetch = redaction_ids.difference(fetched_events.keys())
events_to_fetch = redaction_ids.difference(fetched_event_ids)
if events_to_fetch:
logger.debug("Also fetching redaction events %s", events_to_fetch)
# build a map from event_id to EventBase
event_map = {}
event_map: Dict[str, EventBase] = {}
for event_id, row in fetched_events.items():
if not row:
continue
assert row.event_id == event_id
rejected_reason = row.rejected_reason
@@ -881,6 +995,7 @@ class EventsWorkerStore(SQLBaseStore):
room_version_id = row.room_version_id
room_version: Optional[RoomVersion]
if not room_version_id:
# this should only happen for out-of-band membership events which
# arrived before #6983 landed. For all other events, we should have
@@ -951,14 +1066,14 @@ class EventsWorkerStore(SQLBaseStore):
# finally, we can decide whether each one needs redacting, and build
# the cache entries.
result_map = {}
result_map: Dict[str, EventCacheEntry] = {}
for event_id, original_ev in event_map.items():
redactions = fetched_events[event_id].redactions
redacted_event = self._maybe_redact_event_row(
original_ev, redactions, event_map
)
cache_entry = _EventCacheEntry(
cache_entry = EventCacheEntry(
event=original_ev, redacted_event=redacted_event
)
@@ -967,7 +1082,7 @@ class EventsWorkerStore(SQLBaseStore):
return result_map
async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]:
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
@@ -980,23 +1095,12 @@ class EventsWorkerStore(SQLBaseStore):
that weren't requested.
"""
events_d = defer.Deferred()
events_d: "defer.Deferred[Dict[str, _EventRow]]" = defer.Deferred()
with self._event_fetch_lock:
self._event_fetch_list.append((events, events_d))
self._event_fetch_lock.notify()
if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
self._event_fetch_ongoing += 1
event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
should_start = True
else:
should_start = False
if should_start:
run_as_background_process(
"fetch_events", self.db_pool.runWithConnection, self._do_fetch
)
self._maybe_start_fetch_thread()
logger.debug("Loading %d events: %s", len(events), events)
with PreserveLoggingContext():
@@ -1146,7 +1250,7 @@ class EventsWorkerStore(SQLBaseStore):
# no valid redaction found for this event
return None
async def have_events_in_timeline(self, event_ids):
async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]:
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
@@ -1175,7 +1279,7 @@ class EventsWorkerStore(SQLBaseStore):
event_ids: events we are looking for
Returns:
set[str]: The events we have already seen.
The set of events we have already seen.
"""
res = await self._have_seen_events_dict(
(room_id, event_id) for event_id in event_ids
@@ -1198,7 +1302,9 @@ class EventsWorkerStore(SQLBaseStore):
}
results = {x: True for x in cache_results}
def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]):
def have_seen_events_txn(
txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...]
) -> None:
# we deliberately do *not* query the database for room_id, to make the
# query an index-only lookup on `events_event_id_key`.
#
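The point of the comment above is index-only scans: `events_event_id_key` covers only `event_id`, so a query touching no other column can be answered from the index alone, while filtering on `room_id` as well would force heap fetches. Schematically (illustrative SQL, not the actual statement used in this function):

# Served entirely from the unique index on events(event_id):
INDEX_ONLY = "SELECT event_id FROM events WHERE event_id = ?"

# Referencing room_id drags in a column the index does not cover, so the
# planner must also visit the table heap (or pick a different index):
NOT_INDEX_ONLY = "SELECT event_id FROM events WHERE room_id = ? AND event_id = ?"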
@@ -1224,12 +1330,14 @@ class EventsWorkerStore(SQLBaseStore):
return results
@cached(max_entries=100000, tree=True)
async def have_seen_event(self, room_id: str, event_id: str):
async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
# this only exists for the benefit of the @cachedList descriptor on
# _have_seen_events_dict
raise NotImplementedError()
def _get_current_state_event_counts_txn(self, txn, room_id):
def _get_current_state_event_counts_txn(
self, txn: LoggingTransaction, room_id: str
) -> int:
"""
See get_current_state_event_counts.
"""
@@ -1254,7 +1362,7 @@ class EventsWorkerStore(SQLBaseStore):
room_id,
)
async def get_room_complexity(self, room_id):
async def get_room_complexity(self, room_id: str) -> Dict[str, float]:
"""
Get a rough approximation of the complexity of the room. This is used by
remote servers to decide whether they wish to join the room or not.
@@ -1262,10 +1370,10 @@ class EventsWorkerStore(SQLBaseStore):
more resources.
Args:
room_id (str)
room_id: The room ID to query.
Returns:
dict[str:int] of complexity version to complexity.
dict[str:float] of complexity version to complexity.
"""
state_events = await self.get_current_state_event_counts(room_id)
@@ -1275,13 +1383,13 @@ class EventsWorkerStore(SQLBaseStore):
return {"v1": complexity_v1}
def get_current_events_token(self):
def get_current_events_token(self) -> int:
"""The current maximum token that events have reached"""
return self._stream_id_gen.get_current_token()
async def get_all_new_forward_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> List[Tuple]:
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
"""Returns new events, for the Events replication stream
Args:
@@ -1295,13 +1403,15 @@ class EventsWorkerStore(SQLBaseStore):
EventsStreamRow.
"""
def get_all_new_forward_event_rows(txn):
def get_all_new_forward_event_rows(
txn: LoggingTransaction,
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN state_events AS se USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" LEFT JOIN room_memberships USING (event_id)"
" LEFT JOIN rejections USING (event_id)"
@@ -1311,7 +1421,9 @@ class EventsWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (last_id, current_id, instance_name, limit))
return txn.fetchall()
return cast(
List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
)
return await self.db_pool.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
@@ -1319,7 +1431,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_ex_outlier_stream_rows(
self, instance_name: str, last_id: int, current_id: int
) -> List[Tuple]:
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
"""Returns de-outliered events, for the Events replication stream
Args:
@@ -1332,14 +1444,16 @@ class EventsWorkerStore(SQLBaseStore):
EventsStreamRow.
"""
def get_ex_outlier_stream_rows_txn(txn):
def get_ex_outlier_stream_rows_txn(
txn: LoggingTransaction,
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" FROM events AS e"
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN state_events AS se USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" LEFT JOIN room_memberships USING (event_id)"
" LEFT JOIN rejections USING (event_id)"
@@ -1350,7 +1464,9 @@ class EventsWorkerStore(SQLBaseStore):
)
txn.execute(sql, (last_id, current_id, instance_name))
return txn.fetchall()
return cast(
List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
)
return await self.db_pool.runInteraction(
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
@@ -1358,7 +1474,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_all_new_backfill_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, list]], int, bool]:
) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
"""Get updates for backfill replication stream, including all new
backfilled events and events that have gone from being outliers to not.
@@ -1386,13 +1502,15 @@ class EventsWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
def get_all_new_backfill_event_rows(txn):
def get_all_new_backfill_event_rows(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" se.state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN state_events AS se USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > stream_ordering AND stream_ordering >= ?"
" AND instance_name = ?"
@@ -1400,7 +1518,15 @@ class EventsWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (-last_id, -current_id, instance_name, limit))
new_event_updates = [(row[0], row[1:]) for row in txn]
new_event_updates: List[
Tuple[int, Tuple[str, str, str, str, str, str]]
] = []
row: Tuple[int, str, str, str, str, str, str]
# Type safety: iterating over `txn` yields `Tuple`, i.e.
# `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
# variadic tuple to a fixed length tuple and flags it up as an error.
for row in txn: # type: ignore[assignment]
new_event_updates.append((row[0], row[1:]))
limited = False
if len(new_event_updates) == limit:
@@ -1411,11 +1537,11 @@ class EventsWorkerStore(SQLBaseStore):
sql = (
"SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" se.state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN state_events AS se USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
@@ -1423,7 +1549,11 @@ class EventsWorkerStore(SQLBaseStore):
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (-last_id, -upper_bound, instance_name))
new_event_updates.extend((row[0], row[1:]) for row in txn)
# Type safety: iterating over `txn` yields `Tuple`, i.e.
# `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
# variadic tuple to a fixed length tuple and flags it up as an error.
for row in txn: # type: ignore[assignment]
new_event_updates.append((row[0], row[1:]))
if len(new_event_updates) >= limit:
upper_bound = new_event_updates[-1][0]
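The `# type: ignore[assignment]` dance above exists because DB-API cursors are typed as yielding `Tuple[Any, ...]`, and assigning that to a fixed-width tuple annotation is a narrowing mypy rejects. A standalone reproduction (hypothetical generator standing in for the cursor):

from typing import Any, Iterator, Tuple

def cursor_rows() -> Iterator[Tuple[Any, ...]]:
    # Mimics iterating a cursor: tuples of unknown length and content.
    yield (-7, "event_id", "room_id", "type", "state_key", "redacts", "rel")

row: Tuple[int, str, str, str, str, str, str]
for row in cursor_rows():  # type: ignore[assignment]
    # This query is known to return exactly seven columns, but mypy only
    # sees Tuple[Any, ...], so the assignment needs the targeted ignore.
    print((row[0], row[1:]))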
@@ -1437,7 +1567,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_all_updated_current_state_deltas(
self, instance_name: str, from_token: int, to_token: int, target_row_count: int
) -> Tuple[List[Tuple], int, bool]:
) -> Tuple[List[Tuple[int, str, str, str, str]], int, bool]:
"""Fetch updates from current_state_delta_stream
Args:
@@ -1457,7 +1587,9 @@ class EventsWorkerStore(SQLBaseStore):
* `limited` is whether there are more updates to fetch.
"""
def get_all_updated_current_state_deltas_txn(txn):
def get_all_updated_current_state_deltas_txn(
txn: LoggingTransaction,
) -> List[Tuple[int, str, str, str, str]]:
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
@@ -1466,21 +1598,23 @@ class EventsWorkerStore(SQLBaseStore):
ORDER BY stream_id ASC LIMIT ?
"""
txn.execute(sql, (from_token, to_token, instance_name, target_row_count))
return txn.fetchall()
return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
def get_deltas_for_stream_id_txn(txn, stream_id):
def get_deltas_for_stream_id_txn(
txn: LoggingTransaction, stream_id: int
) -> List[Tuple[int, str, str, str, str]]:
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE stream_id = ?
"""
txn.execute(sql, [stream_id])
return txn.fetchall()
return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
# we need to make sure that, for every stream id in the results, we get *all*
# the rows with that stream id.
rows: List[Tuple] = await self.db_pool.runInteraction(
rows: List[Tuple[int, str, str, str, str]] = await self.db_pool.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
)
@@ -1509,14 +1643,14 @@ class EventsWorkerStore(SQLBaseStore):
return rows, to_token, True
async def is_event_after(self, event_id1, event_id2):
async def is_event_after(self, event_id1: str, event_id2: str) -> bool:
"""Returns True if event_id1 is after event_id2 in the stream"""
to_1, so_1 = await self.get_event_ordering(event_id1)
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
@cached(max_entries=5000)
async def get_event_ordering(self, event_id):
async def get_event_ordering(self, event_id: str) -> Tuple[int, int]:
res = await self.db_pool.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
@@ -1539,7 +1673,9 @@ class EventsWorkerStore(SQLBaseStore):
None otherwise.
"""
def get_next_event_to_expire_txn(txn):
def get_next_event_to_expire_txn(
txn: LoggingTransaction,
) -> Optional[Tuple[str, int]]:
txn.execute(
"""
SELECT event_id, expiry_ts FROM event_expiry
@ -1547,7 +1683,7 @@ class EventsWorkerStore(SQLBaseStore):
"""
)
return txn.fetchone()
return cast(Optional[Tuple[str, int]], txn.fetchone())
return await self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
@@ -1611,10 +1747,10 @@ class EventsWorkerStore(SQLBaseStore):
return mapping
@wrap_as_background_process("_cleanup_old_transaction_ids")
async def _cleanup_old_transaction_ids(self):
async def _cleanup_old_transaction_ids(self) -> None:
"""Cleans out transaction id mappings older than 24hrs."""
def _cleanup_old_transaction_ids_txn(txn):
def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
sql = """
DELETE FROM event_txn_id
WHERE inserted_ts < ?
@@ -1626,3 +1762,198 @@ class EventsWorkerStore(SQLBaseStore):
"_cleanup_old_transaction_ids",
_cleanup_old_transaction_ids_txn,
)
async def is_event_next_to_backward_gap(self, event: EventBase) -> bool:
"""Check if the given event is next to a backward gap of missing events.
<latest messages> A(False)--->B(False)--->C(True)---> <gap, unknown events> <oldest messages>
Args:
event: the event to check
Returns:
True if the event is next to a backward gap of missing events
"""
def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool:
# If the event in question has any of its prev_events listed as a
# backward extremity, it's next to a gap.
#
# We can't just check the backward edges in `event_edges` because
# when we persist events, we will also record the prev_events as
# edges to the event in question regardless of whether we have those
# prev_events yet. We need to check whether those prev_events are
# backward extremities, also known as gaps, that need to be
# backfilled.
backward_extremity_query = """
SELECT 1 FROM event_backward_extremities
WHERE
room_id = ?
AND %s
LIMIT 1
"""
# If the event in question is a backward extremity or has any of its
# prev_events listed as a backward extremity, it's next to a
# backward gap.
clause, args = make_in_list_sql_clause(
self.database_engine,
"event_id",
[event.event_id] + list(event.prev_event_ids()),
)
txn.execute(backward_extremity_query % (clause,), [event.room_id] + args)
backward_extremities = txn.fetchall()
# We consider any backward extremity as a backward gap
if len(backward_extremities):
return True
return False
return await self.db_pool.runInteraction(
"is_event_next_to_backward_gap_txn",
is_event_next_to_backward_gap_txn,
)
async def is_event_next_to_forward_gap(self, event: EventBase) -> bool:
"""Check if the given event is next to a forward gap of missing events.
The gap in front of the latest events is not considered a gap.
<latest messages> A(False)--->B(False)--->C(False)---> <gap, unknown events> <oldest messages>
<latest messages> A(False)--->B(False)---> <gap, unknown events> --->D(True)--->E(False) <oldest messages>
Args:
event: the event to check
Returns:
True if the event is next to a forward gap of missing events
"""
def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool:
# If the event in question is a forward extremity, we will just
# consider any potential forward gap as not a gap since it's one of
# the latest events in the room.
#
# `event_forward_extremities` does not include backfilled or outlier
# events so we can't rely on it to find forward gaps. We can only
# use it to determine whether a message is the latest in the room.
#
# We can't combine this query with the `forward_edge_query` below
# because if the event in question has no forward edges (isn't
# referenced by any other event's prev_events) but is in
# `event_forward_extremities`, we don't want to return 0 rows and
# say it's next to a gap.
forward_extremity_query = """
SELECT 1 FROM event_forward_extremities
WHERE
room_id = ?
AND event_id = ?
LIMIT 1
"""
# Check to see whether the event in question is already referenced
# by another event. If we don't see any edges, we're next to a
# forward gap.
forward_edge_query = """
SELECT 1 FROM event_edges
/* Check to make sure the event referencing our event in question is not rejected */
LEFT JOIN rejections ON event_edges.event_id = rejections.event_id
WHERE
event_edges.room_id = ?
AND event_edges.prev_event_id = ?
/* It's not a valid edge if the event referencing our event in
* question is rejected.
*/
AND rejections.event_id IS NULL
LIMIT 1
"""
# We consider any forward extremity as the latest in the room and
# not a forward gap.
#
# To expand, even though there is technically a gap at the front of
# the room where the forward extremities are, we consider those the
# latest messages in the room so asking other homeservers for more
# is useless. The new latest messages will just be federated as
# usual.
txn.execute(forward_extremity_query, (event.room_id, event.event_id))
forward_extremities = txn.fetchall()
if len(forward_extremities):
return False
# If there are no forward edges to the event in question (another
# event hasn't referenced this event in their prev_events), then we
# assume there is a forward gap in the history.
txn.execute(forward_edge_query, (event.room_id, event.event_id))
forward_edges = txn.fetchall()
if not len(forward_edges):
return True
return False
return await self.db_pool.runInteraction(
"is_event_next_to_gap_txn",
is_event_next_to_gap_txn,
)
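A caller might combine the two checks to decide whether history around an event is complete or a backfill is warranted, along these lines (hypothetical usage; `needs_backfill` is illustrative and not part of this commit):

from synapse.events import EventBase

async def needs_backfill(store: "EventsWorkerStore", event: EventBase) -> bool:
    # True if events are missing on either side of `event`, i.e. we may
    # want to ask other homeservers for more history around it.
    if await store.is_event_next_to_backward_gap(event):
        return True
    return await store.is_event_next_to_forward_gap(event)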
async def get_event_id_for_timestamp(
self, room_id: str, timestamp: int, direction: str
) -> Optional[str]:
"""Find the closest event to the given timestamp in the given direction.
Args:
room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event.
direction: ["f"|"b"] to indicate whether we should navigate forward
or backward from the given timestamp to find the closest event.
Returns:
The closest event_id otherwise None if we can't find any event in
the given direction.
"""
sql_template = """
SELECT event_id FROM events
LEFT JOIN rejections USING (event_id)
WHERE
origin_server_ts %s ?
AND room_id = ?
/* Make sure event is not rejected */
AND rejections.event_id IS NULL
ORDER BY origin_server_ts %s
LIMIT 1;
"""
def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]:
if direction == "b":
# Find closest event *before* a given timestamp. We use descending
# (which gives values largest to smallest) because we want the
# largest possible timestamp *before* the given timestamp.
comparison_operator = "<="
order = "DESC"
else:
# Find closest event *after* a given timestamp. We use ascending
# (which gives values smallest to largest) because we want the
# closest possible timestamp *after* the given timestamp.
comparison_operator = ">="
order = "ASC"
txn.execute(
sql_template % (comparison_operator, order), (timestamp, room_id)
)
row = txn.fetchone()
if row:
(event_id,) = row
return event_id
return None
if direction not in ("f", "b"):
raise ValueError("Unknown direction: %s" % (direction,))
return await self.db_pool.runInteraction(
"get_event_id_for_timestamp_txn",
get_event_id_for_timestamp_txn,
)
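Usage is then a matter of picking a direction: "f" walks forward from the timestamp, "b" backward, and None means nothing was found on that side. A hedged example (the jump-to-date wrapper is hypothetical, not from this commit):

async def jump_to_date(
    store: "EventsWorkerStore", room_id: str, ts_ms: int
) -> str:
    # Prefer the closest event at or after the timestamp; fall back to
    # the closest one before it if the room has nothing newer.
    event_id = await store.get_event_id_for_timestamp(room_id, ts_ms, "f")
    if event_id is None:
        event_id = await store.get_event_id_for_timestamp(room_id, ts_ms, "b")
    if event_id is None:
        raise ValueError("No events found near timestamp %d" % (ts_ms,))
    return event_id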