Revert "Revert accidental fast-forward merge from v1.49.0rc1"

This reverts commit 158d73ebdd.
This commit is contained in:
Olivier Wilkinson (reivilibre) 2021-12-14 14:22:01 +00:00
parent 158d73ebdd
commit 4dd9ea8f4f
165 changed files with 7715 additions and 2703 deletions

View file

@ -21,7 +21,7 @@ from synapse.storage.database import LoggingTransaction # noqa: F401
from synapse.storage.database import make_in_list_sql_clause # noqa: F401
from synapse.storage.database import DatabasePool
from synapse.storage.types import Connection
from synapse.types import StreamToken, get_domain_from_id
from synapse.types import get_domain_from_id
from synapse.util import json_decoder
if TYPE_CHECKING:
@ -48,7 +48,7 @@ class SQLBaseStore(metaclass=ABCMeta):
self,
stream_name: str,
instance_name: str,
token: StreamToken,
token: int,
rows: Iterable[Any],
) -> None:
pass

View file

@ -12,12 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional
from typing import (
TYPE_CHECKING,
AsyncContextManager,
Awaitable,
Callable,
Dict,
Iterable,
Optional,
)
import attr
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.types import Connection
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util import Clock, json_encoder
from . import engines
@ -28,6 +38,45 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
ON_UPDATE_CALLBACK = Callable[[str, str, bool], AsyncContextManager[int]]
DEFAULT_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _BackgroundUpdateHandler:
"""A handler for a given background update.
Attributes:
callback: The function to call to make progress on the background
update. Per its type it is called with a JSON dict and an int and
returns an awaitable int (presumably the stored progress and the
batch size, returning the number of items processed - TODO confirm
against the call site).
oneshot: Whether the update is likely to happen all in one go, ignoring
the supplied target duration, e.g. index creation. This is used by
the update controller to help correctly schedule the update.
"""
callback: Callable[[JsonDict, int], Awaitable[int]]
oneshot: bool = False
class _BackgroundUpdateContextManager:
"""Async context manager that paces a single background update iteration.
Entering it optionally sleeps for `BACKGROUND_UPDATE_INTERVAL_MS` and then
yields `BACKGROUND_UPDATE_DURATION_MS`; exiting it does nothing.
"""
# How long to sleep, in milliseconds, on entry when sleeping is enabled.
BACKGROUND_UPDATE_INTERVAL_MS = 1000
# The value, in milliseconds, yielded to the body of the `async with`.
BACKGROUND_UPDATE_DURATION_MS = 100
def __init__(self, sleep: bool, clock: Clock):
"""
Args:
sleep: Whether to sleep on entry.
clock: The clock whose `sleep` method is awaited on entry.
"""
self._sleep = sleep
self._clock = clock
async def __aenter__(self) -> int:
"""Sleep if requested, then return the target duration in milliseconds."""
if self._sleep:
# `Clock.sleep` is given seconds here, hence the conversion from ms.
await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000)
return self.BACKGROUND_UPDATE_DURATION_MS
async def __aexit__(self, *exc) -> None:
"""Do nothing on exit; returning None means exceptions are not suppressed."""
pass
class BackgroundUpdatePerformance:
"""Tracks how long a background update is taking to update its items"""
@ -84,20 +133,22 @@ class BackgroundUpdater:
MINIMUM_BACKGROUND_BATCH_SIZE = 1
DEFAULT_BACKGROUND_BATCH_SIZE = 100
BACKGROUND_UPDATE_INTERVAL_MS = 1000
BACKGROUND_UPDATE_DURATION_MS = 100
def __init__(self, hs: "HomeServer", database: "DatabasePool"):
self._clock = hs.get_clock()
self.db_pool = database
self._database_name = database.name()
# if a background update is currently running, its name.
self._current_background_update: Optional[str] = None
self._on_update_callback: Optional[ON_UPDATE_CALLBACK] = None
self._default_batch_size_callback: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None
self._min_batch_size_callback: Optional[MIN_BATCH_SIZE_CALLBACK] = None
self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {}
self._background_update_handlers: Dict[
str, Callable[[JsonDict, int], Awaitable[int]]
] = {}
self._background_update_handlers: Dict[str, _BackgroundUpdateHandler] = {}
self._all_done = False
# Whether we're currently running updates
@ -107,6 +158,83 @@ class BackgroundUpdater:
# enable/disable background updates via the admin API.
self.enabled = True
def register_update_controller_callbacks(
    self,
    on_update: ON_UPDATE_CALLBACK,
    default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
    min_batch_size: Optional[MIN_BATCH_SIZE_CALLBACK] = None,
) -> None:
    """Register callbacks from a module for each hook.

    Only the first registration takes effect: if an `on_update` callback has
    already been registered, a warning is logged and *none* of the supplied
    callbacks are stored.

    Args:
        on_update: Called with the update name, the database name and whether
            the update is a oneshot; returns an async context manager that
            yields the target duration for the iteration.
        default_batch_size: Optional callback returning the batch size to use
            for the first iteration of an update. Left unset if None.
        min_batch_size: Optional callback returning a lower bound on the batch
            size. Left unset if None.
    """
    if self._on_update_callback is not None:
        logger.warning(
            "More than one module tried to register callbacks for controlling"
            " background updates. Only the callbacks registered by the first module"
            " (in order of appearance in Synapse's configuration file) that tried to"
            " do so will be called."
        )
        return

    self._on_update_callback = on_update

    # The batch size callbacks are optional; only override when provided.
    if default_batch_size is not None:
        self._default_batch_size_callback = default_batch_size

    if min_batch_size is not None:
        self._min_batch_size_callback = min_batch_size
def _get_context_manager_for_update(
self,
sleep: bool,
update_name: str,
database_name: str,
oneshot: bool,
) -> AsyncContextManager[int]:
"""Get a context manager to run a background update with.
If a module has registered an `on_update` callback, use the context manager
it returns.
Otherwise, returns a context manager that will return a default value, optionally
sleeping if needed.
Args:
sleep: Whether we can sleep between updates. Only used by the default
context manager; it is not passed to a module's `on_update` callback.
update_name: The name of the update.
database_name: The name of the database the update is being run on.
oneshot: Whether the update will complete all in one go, e.g. index creation.
In such cases the returned target duration is ignored. Only passed to
a module's `on_update` callback; the default context manager does
not use it.
Returns:
An async context manager which, when entered, yields the target duration
in milliseconds that the background update should run for.
Note: this is a *target*, and an iteration may take substantially longer or
shorter.
"""
# A module-registered callback takes precedence over the default behaviour.
if self._on_update_callback is not None:
return self._on_update_callback(update_name, database_name, oneshot)
return _BackgroundUpdateContextManager(sleep, self._clock)
async def _default_batch_size(self, update_name: str, database_name: str) -> int:
"""The batch size to use for the first iteration of a new background
update.
Args:
update_name: The name of the update.
database_name: The name of the database the update is being run on.
Returns:
The result of the registered default batch size callback if there is
one, otherwise `DEFAULT_BACKGROUND_BATCH_SIZE`.
"""
if self._default_batch_size_callback is not None:
return await self._default_batch_size_callback(update_name, database_name)
return self.DEFAULT_BACKGROUND_BATCH_SIZE
async def _min_batch_size(self, update_name: str, database_name: str) -> int:
"""A lower bound on the batch size of a new background update.
Used to ensure that progress is always made. Must be greater than 0.
Args:
update_name: The name of the update.
database_name: The name of the database the update is being run on.
Returns:
The result of the registered minimum batch size callback if there is
one, otherwise `MINIMUM_BACKGROUND_BATCH_SIZE`. The value is returned
as-is; the "greater than 0" requirement is not checked here.
"""
if self._min_batch_size_callback is not None:
return await self._min_batch_size_callback(update_name, database_name)
return self.MINIMUM_BACKGROUND_BATCH_SIZE
def get_current_update(self) -> Optional[BackgroundUpdatePerformance]:
"""Returns the current background update, if any."""
@ -135,13 +263,8 @@ class BackgroundUpdater:
try:
logger.info("Starting background schema updates")
while self.enabled:
if sleep:
await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
try:
result = await self.do_next_background_update(
self.BACKGROUND_UPDATE_DURATION_MS
)
result = await self.do_next_background_update(sleep)
except Exception:
logger.exception("Error doing update")
else:
@ -203,13 +326,15 @@ class BackgroundUpdater:
return not update_exists
async def do_next_background_update(self, desired_duration_ms: float) -> bool:
async def do_next_background_update(self, sleep: bool = True) -> bool:
"""Does some amount of work on the next queued background update
Returns once some amount of work is done.
Args:
desired_duration_ms: How long we want to spend updating.
sleep: Whether to limit how quickly we run background updates or
not.
Returns:
True if we have finished running all the background updates, otherwise False
"""
@ -252,7 +377,19 @@ class BackgroundUpdater:
self._current_background_update = upd["update_name"]
await self._do_background_update(desired_duration_ms)
# We have a background update to run, otherwise we would have returned
# early.
assert self._current_background_update is not None
update_info = self._background_update_handlers[self._current_background_update]
async with self._get_context_manager_for_update(
sleep=sleep,
update_name=self._current_background_update,
database_name=self._database_name,
oneshot=update_info.oneshot,
) as desired_duration_ms:
await self._do_background_update(desired_duration_ms)
return False
async def _do_background_update(self, desired_duration_ms: float) -> int:
@ -260,7 +397,7 @@ class BackgroundUpdater:
update_name = self._current_background_update
logger.info("Starting update batch on background update '%s'", update_name)
update_handler = self._background_update_handlers[update_name]
update_handler = self._background_update_handlers[update_name].callback
performance = self._background_update_performance.get(update_name)
@ -273,9 +410,14 @@ class BackgroundUpdater:
if items_per_ms is not None:
batch_size = int(desired_duration_ms * items_per_ms)
# Clamp the batch size so that we always make progress
batch_size = max(batch_size, self.MINIMUM_BACKGROUND_BATCH_SIZE)
batch_size = max(
batch_size,
await self._min_batch_size(update_name, self._database_name),
)
else:
batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
batch_size = await self._default_batch_size(
update_name, self._database_name
)
progress_json = await self.db_pool.simple_select_one_onecol(
"background_updates",
@ -294,6 +436,8 @@ class BackgroundUpdater:
duration_ms = time_stop - time_start
performance.update(items_updated, duration_ms)
logger.info(
"Running background update %r. Processed %r items in %rms."
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
@ -306,8 +450,6 @@ class BackgroundUpdater:
batch_size,
)
performance.update(items_updated, duration_ms)
return len(self._background_update_performance)
def register_background_update_handler(
@ -331,7 +473,9 @@ class BackgroundUpdater:
update_name: The name of the update that this code handles.
update_handler: The function that does the update.
"""
self._background_update_handlers[update_name] = update_handler
self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
update_handler
)
def register_noop_background_update(self, update_name: str) -> None:
"""Register a noop handler for a background update.
@ -453,7 +597,9 @@ class BackgroundUpdater:
await self._end_background_update(update_name)
return 1
self.register_background_update_handler(update_name, updater)
self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
updater, oneshot=True
)
async def _end_background_update(self, update_name: str) -> None:
"""Removes a completed background update task from the queue.

View file

@ -143,7 +143,7 @@ class ApplicationServiceTransactionWorkerStore(
A list of ApplicationServices, which may be empty.
"""
results = await self.db_pool.simple_select_list(
"application_services_state", {"state": state}, ["as_id"]
"application_services_state", {"state": state.value}, ["as_id"]
)
# NB: This assumes this class is linked with ApplicationServiceStore
as_list = self.get_app_services()
@ -173,7 +173,7 @@ class ApplicationServiceTransactionWorkerStore(
desc="get_appservice_state",
)
if result:
return result.get("state")
return ApplicationServiceState(result.get("state"))
return None
async def set_appservice_state(
@ -186,7 +186,7 @@ class ApplicationServiceTransactionWorkerStore(
state: The connectivity state to apply.
"""
await self.db_pool.simple_upsert(
"application_services_state", {"as_id": service.id}, {"state": state}
"application_services_state", {"as_id": service.id}, {"state": state.value}
)
async def create_appservice_txn(

View file

@ -139,6 +139,27 @@ class DeviceWorkerStore(SQLBaseStore):
return {d["device_id"]: d for d in devices}
async def get_devices_by_auth_provider_session_id(
self, auth_provider_id: str, auth_provider_session_id: str
) -> List[Dict[str, Any]]:
"""Retrieve the list of devices associated with an SSO IdP session ID.
Selects from the `device_auth_providers` table, matching on both the
provider ID and the session ID.
Args:
auth_provider_id: The SSO IdP ID as defined in the server config
auth_provider_session_id: The session ID within the IdP
Returns:
A list of dicts containing the device_id and the user_id of each device
"""
return await self.db_pool.simple_select_list(
table="device_auth_providers",
keyvalues={
"auth_provider_id": auth_provider_id,
"auth_provider_session_id": auth_provider_session_id,
},
retcols=("user_id", "device_id"),
desc="get_devices_by_auth_provider_session_id",
)
@trace
async def get_device_updates_by_remote(
self, destination: str, from_stream_id: int, limit: int
@ -1070,7 +1091,12 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
async def store_device(
self, user_id: str, device_id: str, initial_device_display_name: Optional[str]
self,
user_id: str,
device_id: str,
initial_device_display_name: Optional[str],
auth_provider_id: Optional[str] = None,
auth_provider_session_id: Optional[str] = None,
) -> bool:
"""Ensure the given device is known; add it to the store if not
@ -1079,6 +1105,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device_id: id of device
initial_device_display_name: initial displayname of the device.
Ignored if device exists.
auth_provider_id: The SSO IdP the user used, if any.
auth_provider_session_id: The session ID (sid) obtained from an OIDC login.
Returns:
Whether the device was inserted or an existing device existed with that ID.
@ -1115,6 +1143,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if hidden:
raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
if auth_provider_id and auth_provider_session_id:
await self.db_pool.simple_insert(
"device_auth_providers",
values={
"user_id": user_id,
"device_id": device_id,
"auth_provider_id": auth_provider_id,
"auth_provider_session_id": auth_provider_session_id,
},
desc="store_device_auth_provider",
)
self.device_id_exists_cache.set(key, True)
return inserted
except StoreError:
@ -1168,6 +1208,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
keyvalues={"user_id": user_id},
)
self.db_pool.simple_delete_many_txn(
txn,
table="device_auth_providers",
column="device_id",
values=device_ids,
keyvalues={"user_id": user_id},
)
await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)
for device_id in device_ids:
self.device_id_exists_cache.invalidate((user_id, device_id))

View file

@ -1552,9 +1552,9 @@ class EventFederationStore(EventFederationWorkerStore):
DELETE FROM event_auth
WHERE event_id IN (
SELECT event_id FROM events
LEFT JOIN state_events USING (room_id, event_id)
LEFT JOIN state_events AS se USING (room_id, event_id)
WHERE ? <= stream_ordering AND stream_ordering < ?
AND state_key IS null
AND se.state_key IS null
)
"""

View file

@ -16,6 +16,7 @@ import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import attr
from typing_extensions import TypedDict
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
@ -37,6 +38,20 @@ DEFAULT_HIGHLIGHT_ACTION = [
]
class BasePushAction(TypedDict):
"""Keys shared by every push action dict: the event and its actions."""
event_id: str
# Each action is either a dict or a plain string.
actions: List[Union[dict, str]]
class HttpPushAction(BasePushAction):
"""Push action dict that also carries the room ID and stream ordering.
Used as the element type of the unread push actions list fetched for the
httppusher.
"""
room_id: str
stream_ordering: int
class EmailPushAction(HttpPushAction):
"""Push action dict that additionally carries an optional `received_ts`.
Used as the element type of the unread push actions list fetched for the
emailpusher.
"""
# May be None; presumably the time the event was received - TODO confirm.
received_ts: Optional[int]
def _serialize_action(actions, is_highlight):
"""Custom serializer for actions. This allows us to "compress" common actions.
@ -221,7 +236,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
min_stream_ordering: int,
max_stream_ordering: int,
limit: int = 20,
) -> List[dict]:
) -> List[HttpPushAction]:
"""Get a list of the most recent unread push actions for a given user,
within the given stream ordering range. Called by the httppusher.
@ -326,7 +341,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
min_stream_ordering: int,
max_stream_ordering: int,
limit: int = 20,
) -> List[dict]:
) -> List[EmailPushAction]:
"""Get a list of the most recent unread push actions for a given user,
within the given stream ordering range. Called by the emailpusher

View file

@ -15,7 +15,7 @@
# limitations under the License.
import itertools
import logging
from collections import OrderedDict, namedtuple
from collections import OrderedDict
from typing import (
TYPE_CHECKING,
Any,
@ -41,9 +41,10 @@ from synapse.events.snapshot import EventContext # noqa: F401
from synapse.logging.utils import log_function
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventCacheEntry
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.storage.util.sequence import SequenceGenerator
from synapse.types import StateMap, get_domain_from_id
from synapse.util import json_encoder
@ -64,9 +65,6 @@ event_counter = Counter(
)
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
@attr.s(slots=True)
class DeltaState:
"""Deltas to use to update the `current_state_events` table.
@ -108,23 +106,30 @@ class PersistEventsStore:
self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
# Ideally we'd move these ID gens here, unfortunately some other ID
# generators are chained off them so doing so is a bit of a PITA.
self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen
# This should only exist on instances that are configured to write
assert (
hs.get_instance_name() in hs.config.worker.writers.events
), "Can only instantiate EventsStore on master"
# Since we have been configured to write, we ought to have id generators,
# rather than id trackers.
assert isinstance(self.store._backfill_id_gen, AbstractStreamIdGenerator)
assert isinstance(self.store._stream_id_gen, AbstractStreamIdGenerator)
# Ideally we'd move these ID gens here, unfortunately some other ID
# generators are chained off them so doing so is a bit of a PITA.
self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
async def _persist_events_and_state_updates(
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
*,
current_state_for_room: Dict[str, StateMap[str]],
state_delta_for_room: Dict[str, DeltaState],
new_forward_extremeties: Dict[str, List[str]],
backfilled: bool = False,
use_negative_stream_ordering: bool = False,
inhibit_local_membership_updates: bool = False,
) -> None:
"""Persist a set of events alongside updates to the current state and
forward extremities tables.
@ -137,7 +142,14 @@ class PersistEventsStore:
room state
new_forward_extremities: Map from room_id to list of event IDs
that are the new forward extremities of the room.
backfilled
use_negative_stream_ordering: Whether to start stream_ordering on
the negative side and decrement. This should be set as True
for backfilled events because backfilled events get a negative
stream ordering so they don't come down incremental `/sync`.
inhibit_local_membership_updates: Stop the local_current_membership
from being updated by these events. This should be set to True
for backfilled events because backfilled events in the past do
not affect the current local state.
Returns:
Resolves when the events have been persisted
@ -159,7 +171,7 @@ class PersistEventsStore:
#
# Note: Multiple instances of this function cannot be in flight at
# the same time for the same room.
if backfilled:
if use_negative_stream_ordering:
stream_ordering_manager = self._backfill_id_gen.get_next_mult(
len(events_and_contexts)
)
@ -176,13 +188,13 @@ class PersistEventsStore:
"persist_events",
self._persist_events_txn,
events_and_contexts=events_and_contexts,
backfilled=backfilled,
inhibit_local_membership_updates=inhibit_local_membership_updates,
state_delta_for_room=state_delta_for_room,
new_forward_extremeties=new_forward_extremeties,
)
persist_event_counter.inc(len(events_and_contexts))
if not backfilled:
if stream < 0:
# backfilled events have negative stream orderings, so we don't
# want to set the event_persisted_position to that.
synapse.metrics.event_persisted_position.set(
@ -316,8 +328,9 @@ class PersistEventsStore:
def _persist_events_txn(
self,
txn: LoggingTransaction,
*,
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool,
inhibit_local_membership_updates: bool = False,
state_delta_for_room: Optional[Dict[str, DeltaState]] = None,
new_forward_extremeties: Optional[Dict[str, List[str]]] = None,
):
@ -330,7 +343,10 @@ class PersistEventsStore:
Args:
txn
events_and_contexts: events to persist
backfilled: True if the events were backfilled
inhibit_local_membership_updates: Stop the local_current_membership
from being updated by these events. This should be set to True
for backfilled events because backfilled events in the past do
not affect the current local state.
delete_existing True to purge existing table rows for the events
from the database. This is useful when retrying due to
IntegrityError.
@ -363,9 +379,7 @@ class PersistEventsStore:
events_and_contexts
)
self._update_room_depths_txn(
txn, events_and_contexts=events_and_contexts, backfilled=backfilled
)
self._update_room_depths_txn(txn, events_and_contexts=events_and_contexts)
# _update_outliers_txn filters out any events which have already been
# persisted, and returns the filtered list.
@ -398,7 +412,7 @@ class PersistEventsStore:
txn,
events_and_contexts=events_and_contexts,
all_events_and_contexts=all_events_and_contexts,
backfilled=backfilled,
inhibit_local_membership_updates=inhibit_local_membership_updates,
)
# We call this last as it assumes we've inserted the events into
@ -561,9 +575,9 @@ class PersistEventsStore:
# fetch their auth event info.
while missing_auth_chains:
sql = """
SELECT event_id, events.type, state_key, chain_id, sequence_number
SELECT event_id, events.type, se.state_key, chain_id, sequence_number
FROM events
INNER JOIN state_events USING (event_id)
INNER JOIN state_events AS se USING (event_id)
LEFT JOIN event_auth_chains USING (event_id)
WHERE
"""
@ -1200,7 +1214,6 @@ class PersistEventsStore:
self,
txn,
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool,
):
"""Update min_depth for each room
@ -1208,13 +1221,18 @@ class PersistEventsStore:
txn (twisted.enterprise.adbapi.Connection): db connection
events_and_contexts (list[(EventBase, EventContext)]): events
we are persisting
backfilled (bool): True if the events were backfilled
"""
depth_updates: Dict[str, int] = {}
for event, context in events_and_contexts:
# Remove any existing cache entries for the event_ids
txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
if not backfilled:
# Then update the `stream_ordering` position to mark the latest
# event as the front of the room. This should not be done for
# backfilled events because backfilled events have negative
# stream_ordering and happened in the past so we know that we don't
# need to update the stream_ordering tip/front for the room.
assert event.internal_metadata.stream_ordering is not None
if event.internal_metadata.stream_ordering >= 0:
txn.call_after(
self.store._events_stream_cache.entity_has_changed,
event.room_id,
@ -1427,7 +1445,12 @@ class PersistEventsStore:
return [ec for ec in events_and_contexts if ec[0] not in to_remove]
def _update_metadata_tables_txn(
self, txn, events_and_contexts, all_events_and_contexts, backfilled
self,
txn,
*,
events_and_contexts,
all_events_and_contexts,
inhibit_local_membership_updates: bool = False,
):
"""Update all the miscellaneous tables for new events
@ -1439,7 +1462,10 @@ class PersistEventsStore:
events that we were going to persist. This includes events
we've already persisted, etc, that wouldn't appear in
events_and_context.
backfilled (bool): True if the events were backfilled
inhibit_local_membership_updates: Stop the local_current_membership
from being updated by these events. This should be set to True
for backfilled events because backfilled events in the past do
not affect the current local state.
"""
# Insert all the push actions into the event_push_actions table.
@ -1513,7 +1539,7 @@ class PersistEventsStore:
for event, _ in events_and_contexts
if event.type == EventTypes.Member
],
backfilled=backfilled,
inhibit_local_membership_updates=inhibit_local_membership_updates,
)
# Insert event_reference_hashes table.
@ -1553,11 +1579,13 @@ class PersistEventsStore:
for row in rows:
event = ev_map[row["event_id"]]
if not row["rejects"] and not row["redacts"]:
to_prefill.append(_EventCacheEntry(event=event, redacted_event=None))
to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
def prefill():
for cache_entry in to_prefill:
self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
self.store._get_event_cache.set(
(cache_entry.event.event_id,), cache_entry
)
txn.call_after(prefill)
@ -1638,8 +1666,19 @@ class PersistEventsStore:
txn, table="event_reference_hashes", values=vals
)
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database."""
def _store_room_members_txn(
self, txn, events, *, inhibit_local_membership_updates: bool = False
):
"""
Store a room member in the database.
Args:
txn: The transaction to use.
events: List of events to store.
inhibit_local_membership_updates: Stop the local_current_membership
from being updated by these events. This should be set to True
for backfilled events because backfilled events in the past do
not affect the current local state.
"""
def non_null_str_or_none(val: Any) -> Optional[str]:
return val if isinstance(val, str) and "\u0000" not in val else None
@ -1682,7 +1721,7 @@ class PersistEventsStore:
# band membership", like a remote invite or a rejection of a remote invite.
if (
self.is_mine_id(event.state_key)
and not backfilled
and not inhibit_local_membership_updates
and event.internal_metadata.is_outlier()
and event.internal_metadata.is_out_of_band_membership()
):

View file

@ -15,14 +15,18 @@
import logging
import threading
from typing import (
TYPE_CHECKING,
Any,
Collection,
Container,
Dict,
Iterable,
List,
NoReturn,
Optional,
Set,
Tuple,
cast,
overload,
)
@ -38,6 +42,7 @@ from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
RoomVersion,
RoomVersions,
)
from synapse.events import EventBase, make_event_from_dict
@ -56,10 +61,18 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import (
AbstractStreamIdTracker,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError
@ -69,10 +82,13 @@ from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
# These values are used in the `enqueus_event` and `_do_fetch` methods to
# These values are used in the `enqueue_event` and `_fetch_loop` methods to
# control how we batch/bulk fetch events from the database.
# The values are plucked out of thin air to make initial sync run faster
# on jki.re
@ -89,7 +105,7 @@ event_fetch_ongoing_gauge = Gauge(
@attr.s(slots=True, auto_attribs=True)
class _EventCacheEntry:
class EventCacheEntry:
event: EventBase
redacted_event: Optional[EventBase]
@ -129,7 +145,7 @@ class _EventRow:
json: str
internal_metadata: str
format_version: Optional[int]
room_version_id: Optional[int]
room_version_id: Optional[str]
rejected_reason: Optional[str]
redactions: List[str]
outlier: bool
@ -153,9 +169,16 @@ class EventsWorkerStore(SQLBaseStore):
# options controlling this.
USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self._stream_id_gen: AbstractStreamIdTracker
self._backfill_id_gen: AbstractStreamIdTracker
if isinstance(database.engine, PostgresEngine):
# If we're using Postgres then we can use `MultiWriterIdGenerator`
# regardless of whether this process writes to the streams or not.
@ -214,7 +237,7 @@ class EventsWorkerStore(SQLBaseStore):
5 * 60 * 1000,
)
self._get_event_cache = LruCache(
self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache(
cache_name="*getEvent*",
max_size=hs.config.caches.event_cache_size,
)
@ -223,19 +246,21 @@ class EventsWorkerStore(SQLBaseStore):
# ID to cache entry. Note that the returned dict may not have the
# requested event in it if the event isn't in the DB.
self._current_event_fetches: Dict[
str, ObservableDeferred[Dict[str, _EventCacheEntry]]
str, ObservableDeferred[Dict[str, EventCacheEntry]]
] = {}
self._event_fetch_lock = threading.Condition()
self._event_fetch_list = []
self._event_fetch_list: List[
Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
] = []
self._event_fetch_ongoing = 0
event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
# We define this sequence here so that it can be referenced from both
# the DataStore and PersistEventStore.
def get_chain_id_txn(txn):
def get_chain_id_txn(txn: Cursor) -> int:
txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
return txn.fetchone()[0]
return cast(Tuple[int], txn.fetchone())[0]
self.event_chain_id_gen = build_sequence_generator(
db_conn,
@ -246,7 +271,13 @@ class EventsWorkerStore(SQLBaseStore):
id_column="chain_id",
)
def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(
self,
stream_name: str,
instance_name: str,
token: int,
rows: Iterable[Any],
) -> None:
if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(instance_name, token)
elif stream_name == BackfillStream.NAME:
@ -280,10 +311,10 @@ class EventsWorkerStore(SQLBaseStore):
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
allow_none: Literal[False] = False,
check_room_id: Optional[str] = None,
get_prev_content: bool = ...,
allow_rejected: bool = ...,
allow_none: Literal[False] = ...,
check_room_id: Optional[str] = ...,
) -> EventBase:
...
@ -292,10 +323,10 @@ class EventsWorkerStore(SQLBaseStore):
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
allow_none: Literal[True] = False,
check_room_id: Optional[str] = None,
get_prev_content: bool = ...,
allow_rejected: bool = ...,
allow_none: Literal[True] = ...,
check_room_id: Optional[str] = ...,
) -> Optional[EventBase]:
...
@ -357,7 +388,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_events(
self,
event_ids: Iterable[str],
event_ids: Collection[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
@ -544,7 +575,7 @@ class EventsWorkerStore(SQLBaseStore):
async def _get_events_from_cache_or_db(
self, event_ids: Iterable[str], allow_rejected: bool = False
) -> Dict[str, _EventCacheEntry]:
) -> Dict[str, EventCacheEntry]:
"""Fetch a bunch of events from the cache or the database.
If events are pulled from the database, they will be cached for future lookups.
@ -578,7 +609,7 @@ class EventsWorkerStore(SQLBaseStore):
# same dict into itself N times).
already_fetching_ids: Set[str] = set()
already_fetching_deferreds: Set[
ObservableDeferred[Dict[str, _EventCacheEntry]]
ObservableDeferred[Dict[str, EventCacheEntry]]
] = set()
for event_id in missing_events_ids:
@ -601,8 +632,8 @@ class EventsWorkerStore(SQLBaseStore):
# function returning more events than requested, but that can happen
# already due to `_get_events_from_db`).
fetching_deferred: ObservableDeferred[
Dict[str, _EventCacheEntry]
] = ObservableDeferred(defer.Deferred())
Dict[str, EventCacheEntry]
] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
for event_id in missing_events_ids:
self._current_event_fetches[event_id] = fetching_deferred
@ -658,12 +689,12 @@ class EventsWorkerStore(SQLBaseStore):
return event_entry_map
def _invalidate_get_event_cache(self, event_id):
def _invalidate_get_event_cache(self, event_id: str) -> None:
self._get_event_cache.invalidate((event_id,))
def _get_events_from_cache(
self, events: Iterable[str], update_metrics: bool = True
) -> Dict[str, _EventCacheEntry]:
) -> Dict[str, EventCacheEntry]:
"""Fetch events from the caches.
May return rejected events.
@ -736,38 +767,123 @@ class EventsWorkerStore(SQLBaseStore):
for e in state_to_include.values()
]
def _do_fetch(self, conn: Connection) -> None:
def _maybe_start_fetch_thread(self) -> None:
    """Spawn one more event fetch thread, provided the cap allows it.

    A new fetcher is only started when `_event_fetch_list` holds queued work
    and fewer than `EVENT_QUEUE_THREADS` fetchers are currently counted in
    `_event_fetch_ongoing`.
    """
    with self._event_fetch_lock:
        spawn = bool(
            self._event_fetch_list
            and self._event_fetch_ongoing < EVENT_QUEUE_THREADS
        )
        if spawn:
            # Claim the slot while the lock is still held. The counter is
            # decremented again in `_fetch_thread`.
            self._event_fetch_ongoing += 1
            event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)

    if not spawn:
        return

    # Started outside the lock so that the lock is not held while the
    # background process is being kicked off.
    run_as_background_process("fetch_events", self._fetch_thread)
async def _fetch_thread(self) -> None:
    """Services requests for events from `_event_fetch_list`.

    Runs `_fetch_loop` on a database connection obtained through
    `db_pool.runWithConnection`. Once that call returns or raises, the
    `finally` block:

    * decrements `_event_fetch_ongoing` and updates the gauge;
    * if work is still queued and we exited cleanly, starts a replacement
      fetcher via `_maybe_start_fetch_thread`;
    * if work is still queued and we failed with an `Exception` while being
      the last fetcher, errbacks every queued deferred with that exception.

    Any exception raised by the fetch loop is re-raised to the caller.
    """
    # Remembers the exception (if any) so the `finally` block can tell a
    # clean exit apart from a failure.
    exc = None
    try:
        await self.db_pool.runWithConnection(self._fetch_loop)
    except BaseException as e:
        exc = e
        raise
    finally:
        should_restart = False
        event_fetches_to_fail = []
        with self._event_fetch_lock:
            self._event_fetch_ongoing -= 1
            event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)

            # There may still be work remaining in `_event_fetch_list` if we
            # failed, or it was added in between us deciding to exit and
            # decrementing `_event_fetch_ongoing`.
            if self._event_fetch_list:
                if exc is None:
                    # We decided to exit, but then some more work was added
                    # before `_event_fetch_ongoing` was decremented.
                    # If a new event fetch thread was not started, we should
                    # restart ourselves since the remaining event fetch threads
                    # may take a while to get around to the new work.
                    #
                    # Unfortunately it is not possible to tell whether a new
                    # event fetch thread was started, so we restart
                    # unconditionally. If we are unlucky, we will end up with
                    # an idle fetch thread, but it will time out after
                    # `EVENT_QUEUE_ITERATIONS * EVENT_QUEUE_TIMEOUT_S` seconds
                    # in any case.
                    #
                    # Note that multiple fetch threads may run down this path at
                    # the same time.
                    should_restart = True
                elif isinstance(exc, Exception):
                    if self._event_fetch_ongoing == 0:
                        # We were the last remaining fetcher and failed.
                        # Fail any outstanding fetches since no one else will
                        # handle them.
                        event_fetches_to_fail = self._event_fetch_list
                        self._event_fetch_list = []
                    else:
                        # We weren't the last remaining fetcher, so another
                        # fetcher will pick up the work. This will either happen
                        # after their existing work, however long that takes,
                        # or after at most `EVENT_QUEUE_TIMEOUT_S` seconds if
                        # they are idle.
                        pass
                else:
                    # The exception is a `SystemExit`, `KeyboardInterrupt` or
                    # `GeneratorExit`. Don't try to do anything clever here.
                    pass

        # Both follow-up actions run after the lock has been released.
        if should_restart:
            # We exited cleanly but noticed more work.
            self._maybe_start_fetch_thread()

        if event_fetches_to_fail:
            # We were the last remaining fetcher and failed.
            # Fail any outstanding fetches since no one else will handle them.
            assert exc is not None
            with PreserveLoggingContext():
                for _, deferred in event_fetches_to_fail:
                    deferred.errback(exc)
def _fetch_loop(self, conn: LoggingDatabaseConnection) -> None:
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
try:
i = 0
while True:
with self._event_fetch_lock:
event_list = self._event_fetch_list
self._event_fetch_list = []
i = 0
while True:
with self._event_fetch_lock:
event_list = self._event_fetch_list
self._event_fetch_list = []
if not event_list:
single_threaded = self.database_engine.single_threaded
if (
not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
or single_threaded
or i > EVENT_QUEUE_ITERATIONS
):
break
else:
self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
i += 1
continue
i = 0
if not event_list:
# There are no requests waiting. If we haven't yet reached the
# maximum iteration limit, wait for some more requests to turn up.
# Otherwise, bail out.
single_threaded = self.database_engine.single_threaded
if (
not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
or single_threaded
or i > EVENT_QUEUE_ITERATIONS
):
return
self._fetch_event_list(conn, event_list)
finally:
self._event_fetch_ongoing -= 1
event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
i += 1
continue
i = 0
self._fetch_event_list(conn, event_list)
def _fetch_event_list(
self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
self,
conn: LoggingDatabaseConnection,
event_list: List[Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]],
) -> None:
"""Handle a load of requests from the _event_fetch_list queue
@ -794,7 +910,7 @@ class EventsWorkerStore(SQLBaseStore):
)
# We only want to resolve deferreds from the main thread
def fire():
def fire() -> None:
for _, d in event_list:
d.callback(row_dict)
@ -804,18 +920,16 @@ class EventsWorkerStore(SQLBaseStore):
logger.exception("do_fetch")
# We only want to resolve deferreds from the main thread
def fire(evs, exc):
for _, d in evs:
if not d.called:
with PreserveLoggingContext():
d.errback(exc)
def fire_errback(exc: Exception) -> None:
for _, d in event_list:
d.errback(exc)
with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list, e)
self.hs.get_reactor().callFromThread(fire_errback, e)
async def _get_events_from_db(
self, event_ids: Iterable[str]
) -> Dict[str, _EventCacheEntry]:
self, event_ids: Collection[str]
) -> Dict[str, EventCacheEntry]:
"""Fetch a bunch of events from the database.
May return rejected events.
@ -831,29 +945,29 @@ class EventsWorkerStore(SQLBaseStore):
map from event id to result. May return extra events which
weren't asked for.
"""
fetched_events = {}
fetched_event_ids: Set[str] = set()
fetched_events: Dict[str, _EventRow] = {}
events_to_fetch = event_ids
while events_to_fetch:
row_map = await self._enqueue_events(events_to_fetch)
# we need to recursively fetch any redactions of those events
redaction_ids = set()
redaction_ids: Set[str] = set()
for event_id in events_to_fetch:
row = row_map.get(event_id)
fetched_events[event_id] = row
fetched_event_ids.add(event_id)
if row:
fetched_events[event_id] = row
redaction_ids.update(row.redactions)
events_to_fetch = redaction_ids.difference(fetched_events.keys())
events_to_fetch = redaction_ids.difference(fetched_event_ids)
if events_to_fetch:
logger.debug("Also fetching redaction events %s", events_to_fetch)
# build a map from event_id to EventBase
event_map = {}
event_map: Dict[str, EventBase] = {}
for event_id, row in fetched_events.items():
if not row:
continue
assert row.event_id == event_id
rejected_reason = row.rejected_reason
@ -881,6 +995,7 @@ class EventsWorkerStore(SQLBaseStore):
room_version_id = row.room_version_id
room_version: Optional[RoomVersion]
if not room_version_id:
# this should only happen for out-of-band membership events which
# arrived before #6983 landed. For all other events, we should have
@ -951,14 +1066,14 @@ class EventsWorkerStore(SQLBaseStore):
# finally, we can decide whether each one needs redacting, and build
# the cache entries.
result_map = {}
result_map: Dict[str, EventCacheEntry] = {}
for event_id, original_ev in event_map.items():
redactions = fetched_events[event_id].redactions
redacted_event = self._maybe_redact_event_row(
original_ev, redactions, event_map
)
cache_entry = _EventCacheEntry(
cache_entry = EventCacheEntry(
event=original_ev, redacted_event=redacted_event
)
@ -967,7 +1082,7 @@ class EventsWorkerStore(SQLBaseStore):
return result_map
async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]:
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
@ -980,23 +1095,12 @@ class EventsWorkerStore(SQLBaseStore):
that weren't requested.
"""
events_d = defer.Deferred()
events_d: "defer.Deferred[Dict[str, _EventRow]]" = defer.Deferred()
with self._event_fetch_lock:
self._event_fetch_list.append((events, events_d))
self._event_fetch_lock.notify()
if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
self._event_fetch_ongoing += 1
event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
should_start = True
else:
should_start = False
if should_start:
run_as_background_process(
"fetch_events", self.db_pool.runWithConnection, self._do_fetch
)
self._maybe_start_fetch_thread()
logger.debug("Loading %d events: %s", len(events), events)
with PreserveLoggingContext():
@ -1146,7 +1250,7 @@ class EventsWorkerStore(SQLBaseStore):
# no valid redaction found for this event
return None
async def have_events_in_timeline(self, event_ids):
async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]:
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
@ -1175,7 +1279,7 @@ class EventsWorkerStore(SQLBaseStore):
event_ids: events we are looking for
Returns:
set[str]: The events we have already seen.
The set of events we have already seen.
"""
res = await self._have_seen_events_dict(
(room_id, event_id) for event_id in event_ids
@ -1198,7 +1302,9 @@ class EventsWorkerStore(SQLBaseStore):
}
results = {x: True for x in cache_results}
def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]):
def have_seen_events_txn(
txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...]
) -> None:
# we deliberately do *not* query the database for room_id, to make the
# query an index-only lookup on `events_event_id_key`.
#
@ -1224,12 +1330,14 @@ class EventsWorkerStore(SQLBaseStore):
return results
@cached(max_entries=100000, tree=True)
async def have_seen_event(self, room_id: str, event_id: str):
async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
# this only exists for the benefit of the @cachedList descriptor on
# _have_seen_events_dict
raise NotImplementedError()
def _get_current_state_event_counts_txn(self, txn, room_id):
def _get_current_state_event_counts_txn(
self, txn: LoggingTransaction, room_id: str
) -> int:
"""
See get_current_state_event_counts.
"""
@ -1254,7 +1362,7 @@ class EventsWorkerStore(SQLBaseStore):
room_id,
)
async def get_room_complexity(self, room_id):
async def get_room_complexity(self, room_id: str) -> Dict[str, float]:
"""
Get a rough approximation of the complexity of the room. This is used by
remote servers to decide whether they wish to join the room or not.
@ -1262,10 +1370,10 @@ class EventsWorkerStore(SQLBaseStore):
more resources.
Args:
room_id (str)
room_id: The room ID to query.
Returns:
dict[str:int] of complexity version to complexity.
dict[str:float] of complexity version to complexity.
"""
state_events = await self.get_current_state_event_counts(room_id)
@ -1275,13 +1383,13 @@ class EventsWorkerStore(SQLBaseStore):
return {"v1": complexity_v1}
def get_current_events_token(self):
def get_current_events_token(self) -> int:
"""The current maximum token that events have reached"""
return self._stream_id_gen.get_current_token()
async def get_all_new_forward_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> List[Tuple]:
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
"""Returns new events, for the Events replication stream
Args:
@ -1295,13 +1403,15 @@ class EventsWorkerStore(SQLBaseStore):
EventsStreamRow.
"""
def get_all_new_forward_event_rows(txn):
def get_all_new_forward_event_rows(
txn: LoggingTransaction,
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN state_events AS se USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" LEFT JOIN room_memberships USING (event_id)"
" LEFT JOIN rejections USING (event_id)"
@ -1311,7 +1421,9 @@ class EventsWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (last_id, current_id, instance_name, limit))
return txn.fetchall()
return cast(
List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
)
return await self.db_pool.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
@ -1319,7 +1431,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_ex_outlier_stream_rows(
self, instance_name: str, last_id: int, current_id: int
) -> List[Tuple]:
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
"""Returns de-outliered events, for the Events replication stream
Args:
@ -1332,14 +1444,16 @@ class EventsWorkerStore(SQLBaseStore):
EventsStreamRow.
"""
def get_ex_outlier_stream_rows_txn(txn):
def get_ex_outlier_stream_rows_txn(
txn: LoggingTransaction,
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" FROM events AS e"
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN state_events AS se USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" LEFT JOIN room_memberships USING (event_id)"
" LEFT JOIN rejections USING (event_id)"
@ -1350,7 +1464,9 @@ class EventsWorkerStore(SQLBaseStore):
)
txn.execute(sql, (last_id, current_id, instance_name))
return txn.fetchall()
return cast(
List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
)
return await self.db_pool.runInteraction(
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
@ -1358,7 +1474,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_all_new_backfill_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, list]], int, bool]:
) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
"""Get updates for backfill replication stream, including all new
backfilled events and events that have gone from being outliers to not.
@ -1386,13 +1502,15 @@ class EventsWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
def get_all_new_backfill_event_rows(txn):
def get_all_new_backfill_event_rows(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" se.state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN state_events AS se USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > stream_ordering AND stream_ordering >= ?"
" AND instance_name = ?"
@ -1400,7 +1518,15 @@ class EventsWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (-last_id, -current_id, instance_name, limit))
new_event_updates = [(row[0], row[1:]) for row in txn]
new_event_updates: List[
Tuple[int, Tuple[str, str, str, str, str, str]]
] = []
row: Tuple[int, str, str, str, str, str, str]
# Type safety: iterating over `txn` yields `Tuple`, i.e.
# `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
# variadic tuple to a fixed length tuple and flags it up as an error.
for row in txn: # type: ignore[assignment]
new_event_updates.append((row[0], row[1:]))
limited = False
if len(new_event_updates) == limit:
@ -1411,11 +1537,11 @@ class EventsWorkerStore(SQLBaseStore):
sql = (
"SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" se.state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN state_events AS se USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
@ -1423,7 +1549,11 @@ class EventsWorkerStore(SQLBaseStore):
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (-last_id, -upper_bound, instance_name))
new_event_updates.extend((row[0], row[1:]) for row in txn)
# Type safety: iterating over `txn` yields `Tuple`, i.e.
# `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
# variadic tuple to a fixed length tuple and flags it up as an error.
for row in txn: # type: ignore[assignment]
new_event_updates.append((row[0], row[1:]))
if len(new_event_updates) >= limit:
upper_bound = new_event_updates[-1][0]
@ -1437,7 +1567,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_all_updated_current_state_deltas(
self, instance_name: str, from_token: int, to_token: int, target_row_count: int
) -> Tuple[List[Tuple], int, bool]:
) -> Tuple[List[Tuple[int, str, str, str, str]], int, bool]:
"""Fetch updates from current_state_delta_stream
Args:
@ -1457,7 +1587,9 @@ class EventsWorkerStore(SQLBaseStore):
* `limited` is whether there are more updates to fetch.
"""
def get_all_updated_current_state_deltas_txn(txn):
def get_all_updated_current_state_deltas_txn(
txn: LoggingTransaction,
) -> List[Tuple[int, str, str, str, str]]:
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
@ -1466,21 +1598,23 @@ class EventsWorkerStore(SQLBaseStore):
ORDER BY stream_id ASC LIMIT ?
"""
txn.execute(sql, (from_token, to_token, instance_name, target_row_count))
return txn.fetchall()
return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
def get_deltas_for_stream_id_txn(txn, stream_id):
def get_deltas_for_stream_id_txn(
txn: LoggingTransaction, stream_id: int
) -> List[Tuple[int, str, str, str, str]]:
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE stream_id = ?
"""
txn.execute(sql, [stream_id])
return txn.fetchall()
return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
# we need to make sure that, for every stream id in the results, we get *all*
# the rows with that stream id.
rows: List[Tuple] = await self.db_pool.runInteraction(
rows: List[Tuple[int, str, str, str, str]] = await self.db_pool.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
)
@ -1509,14 +1643,14 @@ class EventsWorkerStore(SQLBaseStore):
return rows, to_token, True
async def is_event_after(self, event_id1, event_id2):
async def is_event_after(self, event_id1: str, event_id2: str) -> bool:
"""Returns True if event_id1 is after event_id2 in the stream"""
to_1, so_1 = await self.get_event_ordering(event_id1)
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
@cached(max_entries=5000)
async def get_event_ordering(self, event_id):
async def get_event_ordering(self, event_id: str) -> Tuple[int, int]:
res = await self.db_pool.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
@ -1539,7 +1673,9 @@ class EventsWorkerStore(SQLBaseStore):
None otherwise.
"""
def get_next_event_to_expire_txn(txn):
def get_next_event_to_expire_txn(
txn: LoggingTransaction,
) -> Optional[Tuple[str, int]]:
txn.execute(
"""
SELECT event_id, expiry_ts FROM event_expiry
@ -1547,7 +1683,7 @@ class EventsWorkerStore(SQLBaseStore):
"""
)
return txn.fetchone()
return cast(Optional[Tuple[str, int]], txn.fetchone())
return await self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
@ -1611,10 +1747,10 @@ class EventsWorkerStore(SQLBaseStore):
return mapping
@wrap_as_background_process("_cleanup_old_transaction_ids")
async def _cleanup_old_transaction_ids(self):
async def _cleanup_old_transaction_ids(self) -> None:
"""Cleans out transaction id mappings older than 24hrs."""
def _cleanup_old_transaction_ids_txn(txn):
def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
sql = """
DELETE FROM event_txn_id
WHERE inserted_ts < ?
@ -1626,3 +1762,198 @@ class EventsWorkerStore(SQLBaseStore):
"_cleanup_old_transaction_ids",
_cleanup_old_transaction_ids_txn,
)
async def is_event_next_to_backward_gap(self, event: EventBase) -> bool:
    """Check if the given event is next to a backward gap of missing events.

    <latest messages> A(False)--->B(False)--->C(True)---> <gap, unknown events> <oldest messages>

    Args:
        event: The event to check. Its room ID, event ID and prev_events
            are used for the lookup.

    Returns:
        True if the event itself, or any of its prev_events, is listed as a
        backward extremity of the event's room; False otherwise.
    """

    def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool:
        # If the event in question is a backward extremity or has any of its
        # prev_events listed as a backward extremity, it's next to a
        # backward gap.
        #
        # We can't just check the backward edges in `event_edges` because
        # when we persist events, we will also record the prev_events as
        # edges to the event in question regardless of whether we have those
        # prev_events yet. We need to check whether those prev_events are
        # backward extremities, also known as gaps, that need to be
        # backfilled.
        backward_extremity_query = """
            SELECT 1 FROM event_backward_extremities
            WHERE
                room_id = ?
                AND %s
            LIMIT 1
        """

        clause, args = make_in_list_sql_clause(
            self.database_engine,
            "event_id",
            [event.event_id] + list(event.prev_event_ids()),
        )

        txn.execute(backward_extremity_query % (clause,), [event.room_id] + args)

        # We consider any backward extremity as a backward gap. The query is
        # `LIMIT 1`, so a non-empty result simply means "at least one match".
        return bool(txn.fetchall())

    return await self.db_pool.runInteraction(
        "is_event_next_to_backward_gap_txn",
        is_event_next_to_backward_gap_txn,
    )
async def is_event_next_to_forward_gap(self, event: EventBase) -> bool:
    """Check if the given event is next to a forward gap of missing events.

    The gap in front of the latest events is not considered a gap.

    <latest messages> A(False)--->B(False)--->C(False)---> <gap, unknown events> <oldest messages>
    <latest messages> A(False)--->B(False)---> <gap, unknown events> --->D(True)--->E(False) <oldest messages>

    Args:
        event: The event to check. Its room ID and event ID are used for the
            lookup.

    Returns:
        False if the event is a forward extremity of its room, or if at
        least one non-rejected event references it as a prev_event; True
        otherwise.
    """

    def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool:
        # If the event in question is a forward extremity, we will just
        # consider any potential forward gap as not a gap since it's one of
        # the latest events in the room.
        #
        # `event_forward_extremities` does not include backfilled or outlier
        # events so we can't rely on it to find forward gaps. We can only
        # use it to determine whether a message is the latest in the room.
        #
        # We can't combine this query with the `forward_edge_query` below
        # because if the event in question has no forward edges (isn't
        # referenced by any other event's prev_events) but is in
        # `event_forward_extremities`, we don't want to return 0 rows and
        # say it's next to a gap.
        forward_extremity_query = """
            SELECT 1 FROM event_forward_extremities
            WHERE
                room_id = ?
                AND event_id = ?
            LIMIT 1
        """

        # Check to see whether the event in question is already referenced
        # by another event. If we don't see any edges, we're next to a
        # forward gap.
        #
        # NB: the join condition uses the standard `=` operator. `==` is
        # only understood by SQLite and is rejected by PostgreSQL.
        forward_edge_query = """
            SELECT 1 FROM event_edges
            /* Check to make sure the event referencing our event in question is not rejected */
            LEFT JOIN rejections ON event_edges.event_id = rejections.event_id
            WHERE
                event_edges.room_id = ?
                AND event_edges.prev_event_id = ?
                /* It's not a valid edge if the event referencing our event in
                 * question is rejected.
                 */
                AND rejections.event_id IS NULL
            LIMIT 1
        """

        query_args = (event.room_id, event.event_id)

        # We consider any forward extremity as the latest in the room and
        # not a forward gap.
        #
        # To expand, even though there is technically a gap at the front of
        # the room where the forward extremities are, we consider those the
        # latest messages in the room so asking other homeservers for more
        # is useless. The new latest messages will just be federated as
        # usual.
        txn.execute(forward_extremity_query, query_args)
        if txn.fetchall():
            return False

        # If there are no forward edges to the event in question (another
        # event hasn't referenced this event in their prev_events), then we
        # assume there is a forward gap in the history.
        txn.execute(forward_edge_query, query_args)
        return not txn.fetchall()

    return await self.db_pool.runInteraction(
        "is_event_next_to_gap_txn",
        is_event_next_to_gap_txn,
    )
async def get_event_id_for_timestamp(
    self, room_id: str, timestamp: int, direction: str
) -> Optional[str]:
    """Look up the event nearest to a timestamp, searching in one direction.

    Rejected events are never returned.

    Args:
        room_id: Room to fetch the event from
        timestamp: The point in time (inclusive) to start searching from.
        direction: "f" to search forward in time from `timestamp`, "b" to
            search backward.

    Returns:
        The ID of the closest event in the requested direction, or None if
        there is no such event.

    Raises:
        ValueError: if `direction` is neither "f" nor "b".
    """
    if direction not in ("f", "b"):
        raise ValueError("Unknown direction: %s" % (direction,))

    sql_template = """
        SELECT event_id FROM events
        LEFT JOIN rejections USING (event_id)
        WHERE
            origin_server_ts %s ?
            AND room_id = ?
            /* Make sure event is not rejected */
            AND rejections.event_id IS NULL
        ORDER BY origin_server_ts %s
        LIMIT 1;
    """

    def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]:
        # Backward: we want the largest timestamp at or *before* the given
        # one, hence `<=` with a descending sort.
        # Forward: we want the smallest timestamp at or *after* the given
        # one, hence `>=` with an ascending sort.
        comparison_operator, order = (
            ("<=", "DESC") if direction == "b" else (">=", "ASC")
        )

        txn.execute(
            sql_template % (comparison_operator, order), (timestamp, room_id)
        )
        row = txn.fetchone()
        if not row:
            return None

        (event_id,) = row
        return event_id

    return await self.db_pool.runInteraction(
        "get_event_id_for_timestamp_txn",
        get_event_id_for_timestamp_txn,
    )

View file

@ -118,7 +118,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
logger.info("[purge] looking for events to delete")
should_delete_expr = "state_key IS NULL"
should_delete_expr = "state_events.state_key IS NULL"
should_delete_params: Tuple[Any, ...] = ()
if not delete_local_events:
should_delete_expr += " AND event_id NOT LIKE ?"

View file

@ -28,7 +28,10 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.storage.util.id_generators import (
AbstractStreamIdTracker,
StreamIdGenerator,
)
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -82,9 +85,9 @@ class PushRulesWorkerStore(
super().__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
self._push_rules_stream_id_gen: Union[
StreamIdGenerator, SlavedIdTracker
] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id")
self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn, "push_rules_stream", "stream_id"
)
else:
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id"

View file

@ -106,6 +106,15 @@ class RefreshTokenLookupResult:
has_next_access_token_been_used: bool
"""True if the next access token was already used at least once."""
expiry_ts: Optional[int]
"""The time at which the refresh token expires and can not be used.
If None, the refresh token doesn't expire."""
ultimate_session_expiry_ts: Optional[int]
"""The time at which the session comes to an end and can no longer be
refreshed.
If None, the session can be refreshed indefinitely."""
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(
@ -1626,8 +1635,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
rt.user_id,
rt.device_id,
rt.next_token_id,
(nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed,
at.used has_next_access_token_been_used
(nrt.next_token_id IS NOT NULL) AS has_next_refresh_token_been_refreshed,
at.used AS has_next_access_token_been_used,
rt.expiry_ts,
rt.ultimate_session_expiry_ts
FROM refresh_tokens rt
LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id
LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id
@ -1648,6 +1659,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
has_next_refresh_token_been_refreshed=row[4],
# This column is nullable, ensure it's a boolean
has_next_access_token_been_used=(row[5] or False),
expiry_ts=row[6],
ultimate_session_expiry_ts=row[7],
)
return await self.db_pool.runInteraction(
@ -1915,6 +1928,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
user_id: str,
token: str,
device_id: Optional[str],
expiry_ts: Optional[int],
ultimate_session_expiry_ts: Optional[int],
) -> int:
"""Adds a refresh token for the given user.
@ -1922,6 +1937,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
user_id: The user ID.
token: The new refresh token to add.
device_id: ID of the device to associate with the refresh token.
expiry_ts (milliseconds since the epoch): Time after which the
refresh token cannot be used.
If None, the refresh token never expires until it has been used.
ultimate_session_expiry_ts (milliseconds since the epoch):
Time at which the session will end and can not be extended any
further.
If None, the session can be refreshed indefinitely.
Raises:
StoreError if there was a problem adding this.
Returns:
@ -1937,6 +1959,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"device_id": device_id,
"token": token,
"next_token_id": None,
"expiry_ts": expiry_ts,
"ultimate_session_expiry_ts": ultimate_session_expiry_ts,
},
desc="add_refresh_token_to_user",
)

View file

@ -476,7 +476,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
INNER JOIN events AS e USING (room_id, event_id)
WHERE
c.type = 'm.room.member'
AND state_key = ?
AND c.state_key = ?
AND c.membership = ?
"""
else:
@ -487,7 +487,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
INNER JOIN events AS e USING (room_id, event_id)
WHERE
c.type = 'm.room.member'
AND state_key = ?
AND c.state_key = ?
AND m.membership = ?
"""

View file

@ -497,7 +497,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
oldest `limit` events.
Returns:
The list of events (in ascending order) and the token from the start
The list of events (in ascending stream order) and the token from the start
of the chunk of events returned.
"""
if from_key == to_key:
@ -510,7 +510,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
if not has_changed:
return [], from_key
def f(txn):
def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
# To handle tokens with a non-empty instance_map we fetch more
# results than necessary and then filter down
min_from_id = from_key.stream
@ -565,6 +565,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
async def get_membership_changes_for_user(
self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
) -> List[EventBase]:
"""Fetch membership events for a given user.
All such events whose stream ordering `s` lies in the range
`from_key < s <= to_key` are returned. Events are ordered by ascending stream
order.
"""
# Start by ruling out cases where a DB query is not necessary.
if from_key == to_key:
return []
@ -575,7 +582,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
if not has_changed:
return []
def f(txn):
def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
# To handle tokens with a non-empty instance_map we fetch more
# results than necessary and then filter down
min_from_id = from_key.stream
@ -634,7 +641,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
Returns:
A list of events and a token pointing to the start of the returned
events. The events returned are in ascending order.
events. The events returned are in ascending topological order.
"""
rows, token = await self.get_recent_event_ids_for_room(

View file

@ -14,6 +14,7 @@
import logging
from collections import namedtuple
from enum import Enum
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
import attr
@ -44,6 +45,16 @@ _UpdateTransactionRow = namedtuple(
)
class DestinationSortOrder(Enum):
    """Enum to define the sorting method used when returning destinations.

    Each member's value is the name of a column of the `destinations` table;
    `get_destinations_paginate` interpolates it into its `ORDER BY` clause
    after validating it against this enum.
    """

    DESTINATION = "destination"
    RETRY_LAST_TS = "retry_last_ts"
    RETRY_INTERVAL = "retry_interval"
    FAILURE_TS = "failure_ts"
    LAST_SUCCESSFUL_STREAM_ORDERING = "last_successful_stream_ordering"

    # Backwards-compatible alias for the original, misspelt member name.
    # Because it repeats an existing value, `Enum` treats it as an alias of
    # `RETRY_INTERVAL`: it is not yielded when iterating over the enum and
    # lookups by value return `RETRY_INTERVAL`.
    RETTRY_INTERVAL = "retry_interval"
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DestinationRetryTimings:
"""The current destination retry timing info for a remote server."""
@ -480,3 +491,62 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
destinations = [row[0] for row in txn]
return destinations
async def get_destinations_paginate(
    self,
    start: int,
    limit: int,
    destination: Optional[str] = None,
    order_by: str = DestinationSortOrder.DESTINATION.value,
    direction: str = "f",
) -> Tuple[List[JsonDict], int]:
    """Function to retrieve a paginated list of destinations.

    This will return a json list of destinations and the
    total number of destinations matching the filter criteria.

    Args:
        start: start number to begin the query from (passed as the SQL
            `OFFSET`)
        limit: number of rows to retrieve (passed as the SQL `LIMIT`)
        destination: search string in destination. It is lower-cased and
            matched as a substring against the lower-cased `destination`
            column. If None or empty, no filter is applied.
        order_by: the sort order of the returned list; must be the value of
            a `DestinationSortOrder` member
        direction: sort ascending or descending: "b" sorts descending, any
            other value sorts ascending
    Returns:
        A tuple of a list of mappings from destination to information
        and a count of total destinations.
    """

    def get_destinations_paginate_txn(
        txn: LoggingTransaction,
    ) -> Tuple[List[JsonDict], int]:
        # Round-trip `order_by` through the enum: the lookup raises
        # `ValueError` for anything that is not a known member value, so only
        # those fixed column names can reach the f-string SQL below.
        order_by_column = DestinationSortOrder(order_by).value

        # Only "b" selects descending order; every other value falls back to
        # ascending.
        if direction == "b":
            order = "DESC"
        else:
            order = "ASC"

        args = []
        where_statement = ""
        if destination:
            # The search term is passed as a bound parameter, never
            # interpolated. NOTE(review): `%` and `_` inside `destination`
            # are not escaped, so they act as LIKE wildcards.
            args.extend(["%" + destination.lower() + "%"])
            where_statement = "WHERE LOWER(destination) LIKE ?"

        # Shared FROM/WHERE fragment so that the count and the page query
        # apply the same filter.
        sql_base = f"FROM destinations {where_statement} "
        sql = f"SELECT COUNT(*) as total_destinations {sql_base}"
        txn.execute(sql, args)
        # Total number of matching rows, independent of LIMIT/OFFSET.
        count = txn.fetchone()[0]

        # `destination ASC` is a secondary sort key, used to order rows that
        # compare equal on the primary sort column.
        sql = f"""
SELECT destination, retry_last_ts, retry_interval, failure_ts,
last_successful_stream_ordering
{sql_base}
ORDER BY {order_by_column} {order}, destination ASC
LIMIT ? OFFSET ?
"""
        txn.execute(sql, args + [limit, start])
        destinations = self.db_pool.cursor_to_dict(txn)
        return destinations, count

    return await self.db_pool.runInteraction(
        "get_destinations_paginate_txn", get_destinations_paginate_txn
    )

View file

@ -583,7 +583,8 @@ class EventsPersistenceStorage:
current_state_for_room=current_state_for_room,
state_delta_for_room=state_delta_for_room,
new_forward_extremeties=new_forward_extremeties,
backfilled=backfilled,
use_negative_stream_ordering=backfilled,
inhibit_local_membership_updates=backfilled,
)
await self._handle_potentially_left_users(potentially_left_users)

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
SCHEMA_VERSION = 65 # remember to update the list below when updating
SCHEMA_VERSION = 66 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the
@ -46,6 +46,10 @@ Changes in SCHEMA_VERSION = 65:
- MSC2716: Remove unique event_id constraint from insertion_event_edges
because an insertion event can have multiple edges.
- Remove unused tables `user_stats_historical` and `room_stats_historical`.
Changes in SCHEMA_VERSION = 66:
- Queries on state_key columns are now disambiguated (ie, the codebase can handle
the `events` table having a `state_key` column).
"""

View file

@ -0,0 +1,28 @@
/* Copyright 2021 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Both columns added below are nullable and default to NULL, so rows that
-- already exist in refresh_tokens end up with NULL in each of them.
ALTER TABLE refresh_tokens
  -- We add an expiry_ts column (in milliseconds since the Epoch) to refresh tokens.
  -- They may not be used after they have expired.
  -- If null, then the refresh token's lifetime is unlimited.
  ADD COLUMN expiry_ts BIGINT DEFAULT NULL;
ALTER TABLE refresh_tokens
  -- We also add an ultimate session expiry time (in milliseconds since the Epoch).
  -- No matter how much the access and refresh tokens are refreshed, they cannot
  -- be extended past this time.
  -- If null, then the session length is unlimited.
  ADD COLUMN ultimate_session_expiry_ts BIGINT DEFAULT NULL;

View file

@ -0,0 +1,27 @@
/* Copyright 2021 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Track the auth provider used by each login as well as the session ID
-- NOTE(review): the table declares no PRIMARY KEY or UNIQUE constraint, so the
-- schema itself permits duplicate rows; confirm that this is intended.
CREATE TABLE device_auth_providers (
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
auth_provider_id TEXT NOT NULL,
auth_provider_session_id TEXT NOT NULL
);
-- Non-unique index on the (user_id, device_id) pair.
CREATE INDEX device_auth_providers_devices
ON device_auth_providers (user_id, device_id);
-- Non-unique index on the (auth_provider_id, auth_provider_session_id) pair.
CREATE INDEX device_auth_providers_sessions
ON device_auth_providers (auth_provider_id, auth_provider_session_id);

View file

@ -89,31 +89,77 @@ def _load_current_id(
return (max if step > 0 else min)(current_id, step)
class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
@abc.abstractmethod
def get_next(self) -> AsyncContextManager[int]:
raise NotImplementedError()
class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
"""Tracks the "current" stream ID of a stream that may have multiple writers.
Stream IDs are monotonically increasing or decreasing integers representing write
transactions. The "current" stream ID is the stream ID such that all transactions
with equal or smaller stream IDs have completed. Since transactions may complete out
of order, this is not the same as the stream ID of the last completed transaction.
Completed transactions include both committed transactions and transactions that
have been rolled back.
"""
@abc.abstractmethod
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
def advance(self, instance_name: str, new_id: int) -> None:
"""Advance the position of the named writer to the given ID, if greater
than existing entry.
"""
raise NotImplementedError()
@abc.abstractmethod
def get_current_token(self) -> int:
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
Returns:
The maximum stream id.
"""
raise NotImplementedError()
@abc.abstractmethod
def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.
For streams with single writers this is equivalent to `get_current_token`.
"""
raise NotImplementedError()
class AbstractStreamIdGenerator(AbstractStreamIdTracker):
    """Generates stream IDs for a stream that may have multiple writers.

    Each stream ID represents a write transaction, whose completion is tracked
    so that the "current" stream ID of the stream can be determined.

    See `AbstractStreamIdTracker` for more details.
    """

    @abc.abstractmethod
    def get_next(self) -> AsyncContextManager[int]:
        """Return an async context manager that provides the next stream ID.

        Usage:
            async with stream_id_gen.get_next() as stream_id:
                # ... persist event ...
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
        """Return an async context manager that provides the next `n` stream IDs.

        Usage:
            async with stream_id_gen.get_next_mult(n) as stream_ids:
                # ... persist events ...
        """
        raise NotImplementedError()
class StreamIdGenerator(AbstractStreamIdGenerator):
"""Used to generate new stream ids when persisting events while keeping
track of which transactions have been completed.
"""Generates and tracks stream IDs for a stream with a single writer.
This allows us to get the "current" stream id, i.e. the stream id such that
all ids less than or equal to it have completed. This handles the fact that
persistence of events can complete out of order.
This class must only be used when the current Synapse process is the sole
writer for a stream.
Args:
db_conn(connection): A database connection to use to fetch the
@ -157,12 +203,12 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
# The key and values are the same, but we never look at the values.
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
def advance(self, instance_name: str, new_id: int) -> None:
# `StreamIdGenerator` should only be used when there is a single writer,
# so replication should never happen.
raise Exception("Replication is not supported by StreamIdGenerator")
def get_next(self) -> AsyncContextManager[int]:
"""
Usage:
async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
with self._lock:
self._current += self._step
next_id = self._current
@ -180,11 +226,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return _AsyncCtxManagerWrapper(manager())
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
"""
Usage:
async with stream_id_gen.get_next(n) as stream_ids:
# ... persist events ...
"""
with self._lock:
next_ids = range(
self._current + self._step,
@ -208,12 +249,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return _AsyncCtxManagerWrapper(manager())
def get_current_token(self) -> int:
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
Returns:
The maximum stream id.
"""
with self._lock:
if self._unfinished_ids:
return next(iter(self._unfinished_ids)) - self._step
@ -221,16 +256,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return self._current
def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.
For streams with single writers this is equivalent to
`get_current_token`.
"""
return self.get_current_token()
class MultiWriterIdGenerator(AbstractStreamIdGenerator):
"""An ID generator that tracks a stream that can have multiple writers.
"""Generates and tracks stream IDs for a stream with multiple writers.
Uses a Postgres sequence to coordinate ID assignment, but positions of other
writers will only get updated when `advance` is called (by replication).
@ -475,12 +505,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
return stream_ids
def get_next(self) -> AsyncContextManager[int]:
"""
Usage:
async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
# If we have a list of instances that are allowed to write to this
# stream, make sure we're in it.
if self._writers and self._instance_name not in self._writers:
@ -492,12 +516,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
"""
Usage:
async with stream_id_gen.get_next_mult(5) as stream_ids:
# ... persist events ...
"""
# If we have a list of instances that are allowed to write to this
# stream, make sure we're in it.
if self._writers and self._instance_name not in self._writers:
@ -597,15 +615,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self._add_persisted_position(next_id)
def get_current_token(self) -> int:
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""
return self.get_persisted_upto_position()
def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer."""
# If we don't have an entry for the given instance name, we assume it's a
# new writer.
#
@ -631,10 +643,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
}
def advance(self, instance_name: str, new_id: int) -> None:
"""Advance the position of the named writer to the given ID, if greater
than existing entry.
"""
new_id *= self._return_factor
with self._lock: