Revert "Revert accidental fast-forward merge from v1.49.0rc1"
This reverts commit 158d73ebdd
.
This commit is contained in:
parent
158d73ebdd
commit
4dd9ea8f4f
165 changed files with 7715 additions and 2703 deletions
@@ -21,7 +21,7 @@ from synapse.storage.database import LoggingTransaction  # noqa: F401
 from synapse.storage.database import make_in_list_sql_clause  # noqa: F401
 from synapse.storage.database import DatabasePool
 from synapse.storage.types import Connection
-from synapse.types import StreamToken, get_domain_from_id
+from synapse.types import get_domain_from_id
 from synapse.util import json_decoder
 
 if TYPE_CHECKING:
@@ -48,7 +48,7 @@ class SQLBaseStore(metaclass=ABCMeta):
         self,
         stream_name: str,
         instance_name: str,
-        token: StreamToken,
+        token: int,
         rows: Iterable[Any],
     ) -> None:
         pass
@@ -12,12 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional
+from typing import (
+    TYPE_CHECKING,
+    AsyncContextManager,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    Optional,
+)
 
+import attr
+
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.types import Connection
 from synapse.types import JsonDict
-from synapse.util import json_encoder
+from synapse.util import Clock, json_encoder
 
 from . import engines
@@ -28,6 +38,45 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+ON_UPDATE_CALLBACK = Callable[[str, str, bool], AsyncContextManager[int]]
+DEFAULT_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
+MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _BackgroundUpdateHandler:
+    """A handler for a given background update.
+
+    Attributes:
+        callback: The function to call to make progress on the background
+            update.
+        oneshot: Whether the update is likely to happen all in one go, ignoring
+            the supplied target duration, e.g. index creation. This is used by
+            the update controller to help correctly schedule the update.
+    """
+
+    callback: Callable[[JsonDict, int], Awaitable[int]]
+    oneshot: bool = False
+
+
+class _BackgroundUpdateContextManager:
+    BACKGROUND_UPDATE_INTERVAL_MS = 1000
+    BACKGROUND_UPDATE_DURATION_MS = 100
+
+    def __init__(self, sleep: bool, clock: Clock):
+        self._sleep = sleep
+        self._clock = clock
+
+    async def __aenter__(self) -> int:
+        if self._sleep:
+            await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000)
+
+        return self.BACKGROUND_UPDATE_DURATION_MS
+
+    async def __aexit__(self, *exc) -> None:
+        pass
+
+
 class BackgroundUpdatePerformance:
     """Tracks how long a background update is taking to update its items"""
@ -84,20 +133,22 @@ class BackgroundUpdater:
|
|||
|
||||
MINIMUM_BACKGROUND_BATCH_SIZE = 1
|
||||
DEFAULT_BACKGROUND_BATCH_SIZE = 100
|
||||
BACKGROUND_UPDATE_INTERVAL_MS = 1000
|
||||
BACKGROUND_UPDATE_DURATION_MS = 100
|
||||
|
||||
def __init__(self, hs: "HomeServer", database: "DatabasePool"):
|
||||
self._clock = hs.get_clock()
|
||||
self.db_pool = database
|
||||
|
||||
self._database_name = database.name()
|
||||
|
||||
# if a background update is currently running, its name.
|
||||
self._current_background_update: Optional[str] = None
|
||||
|
||||
self._on_update_callback: Optional[ON_UPDATE_CALLBACK] = None
|
||||
self._default_batch_size_callback: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None
|
||||
self._min_batch_size_callback: Optional[MIN_BATCH_SIZE_CALLBACK] = None
|
||||
|
||||
self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {}
|
||||
self._background_update_handlers: Dict[
|
||||
str, Callable[[JsonDict, int], Awaitable[int]]
|
||||
] = {}
|
||||
self._background_update_handlers: Dict[str, _BackgroundUpdateHandler] = {}
|
||||
self._all_done = False
|
||||
|
||||
# Whether we're currently running updates
|
||||
|
@ -107,6 +158,83 @@ class BackgroundUpdater:
|
|||
# enable/disable background updates via the admin API.
|
||||
self.enabled = True
|
||||
|
||||
def register_update_controller_callbacks(
|
||||
self,
|
||||
on_update: ON_UPDATE_CALLBACK,
|
||||
default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
|
||||
min_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
|
||||
) -> None:
|
||||
"""Register callbacks from a module for each hook."""
|
||||
if self._on_update_callback is not None:
|
||||
logger.warning(
|
||||
"More than one module tried to register callbacks for controlling"
|
||||
" background updates. Only the callbacks registered by the first module"
|
||||
" (in order of appearance in Synapse's configuration file) that tried to"
|
||||
" do so will be called."
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
self._on_update_callback = on_update
|
||||
|
||||
if default_batch_size is not None:
|
||||
self._default_batch_size_callback = default_batch_size
|
||||
|
||||
if min_batch_size is not None:
|
||||
self._min_batch_size_callback = min_batch_size
|
||||
|
||||
def _get_context_manager_for_update(
|
||||
self,
|
||||
sleep: bool,
|
||||
update_name: str,
|
||||
database_name: str,
|
||||
oneshot: bool,
|
||||
) -> AsyncContextManager[int]:
|
||||
"""Get a context manager to run a background update with.
|
||||
|
||||
If a module has registered an `update_handler` callback, use the context manager
|
||||
it returns.
|
||||
|
||||
Otherwise, returns a context manager that will return a default value, optionally
|
||||
sleeping if needed.
|
||||
|
||||
Args:
|
||||
sleep: Whether we can sleep between updates.
|
||||
update_name: The name of the update.
|
||||
database_name: The name of the database the update is being run on.
|
||||
oneshot: Whether the update will complete all in one go, e.g. index creation.
|
||||
In such cases the returned target duration is ignored.
|
||||
|
||||
Returns:
|
||||
The target duration in milliseconds that the background update should run for.
|
||||
|
||||
Note: this is a *target*, and an iteration may take substantially longer or
|
||||
shorter.
|
||||
"""
|
||||
if self._on_update_callback is not None:
|
||||
return self._on_update_callback(update_name, database_name, oneshot)
|
||||
|
||||
return _BackgroundUpdateContextManager(sleep, self._clock)
|
||||
|
||||
async def _default_batch_size(self, update_name: str, database_name: str) -> int:
|
||||
"""The batch size to use for the first iteration of a new background
|
||||
update.
|
||||
"""
|
||||
if self._default_batch_size_callback is not None:
|
||||
return await self._default_batch_size_callback(update_name, database_name)
|
||||
|
||||
return self.DEFAULT_BACKGROUND_BATCH_SIZE
|
||||
|
||||
async def _min_batch_size(self, update_name: str, database_name: str) -> int:
|
||||
"""A lower bound on the batch size of a new background update.
|
||||
|
||||
Used to ensure that progress is always made. Must be greater than 0.
|
||||
"""
|
||||
if self._min_batch_size_callback is not None:
|
||||
return await self._min_batch_size_callback(update_name, database_name)
|
||||
|
||||
return self.MINIMUM_BACKGROUND_BATCH_SIZE
|
||||
|
||||
def get_current_update(self) -> Optional[BackgroundUpdatePerformance]:
|
||||
"""Returns the current background update, if any."""
|
||||
|
||||
|
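Illustrative sketch (not part of this diff): a module could plug into the hooks above roughly as follows. The callback shapes follow ON_UPDATE_CALLBACK and the batch-size callback aliases defined earlier in this file; the module-API registration method name used here is an assumption.

from contextlib import asynccontextmanager
from typing import AsyncIterator


class MyBackgroundUpdateController:
    def __init__(self, config: dict, api) -> None:
        # Assumed module-API hook; the three callbacks match the
        # ON_UPDATE_CALLBACK / *_BATCH_SIZE_CALLBACK shapes above.
        api.register_background_update_controller_callbacks(
            on_update=self.on_update,
            default_batch_size=self.default_batch_size,
            min_batch_size=self.min_batch_size,
        )

    @asynccontextmanager
    async def on_update(
        self, update_name: str, database_name: str, one_shot: bool
    ) -> AsyncIterator[int]:
        # Yield the target duration (in ms) for one iteration of the update.
        yield 500

    async def default_batch_size(self, update_name: str, database_name: str) -> int:
        return 100

    async def min_batch_size(self, update_name: str, database_name: str) -> int:
        return 1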
@ -135,13 +263,8 @@ class BackgroundUpdater:
|
|||
try:
|
||||
logger.info("Starting background schema updates")
|
||||
while self.enabled:
|
||||
if sleep:
|
||||
await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
|
||||
|
||||
try:
|
||||
result = await self.do_next_background_update(
|
||||
self.BACKGROUND_UPDATE_DURATION_MS
|
||||
)
|
||||
result = await self.do_next_background_update(sleep)
|
||||
except Exception:
|
||||
logger.exception("Error doing update")
|
||||
else:
|
||||
|
@ -203,13 +326,15 @@ class BackgroundUpdater:
|
|||
|
||||
return not update_exists
|
||||
|
||||
async def do_next_background_update(self, desired_duration_ms: float) -> bool:
|
||||
async def do_next_background_update(self, sleep: bool = True) -> bool:
|
||||
"""Does some amount of work on the next queued background update
|
||||
|
||||
Returns once some amount of work is done.
|
||||
|
||||
Args:
|
||||
desired_duration_ms: How long we want to spend updating.
|
||||
sleep: Whether to limit how quickly we run background updates or
|
||||
not.
|
||||
|
||||
Returns:
|
||||
True if we have finished running all the background updates, otherwise False
|
||||
"""
|
||||
|
@ -252,7 +377,19 @@ class BackgroundUpdater:
|
|||
|
||||
self._current_background_update = upd["update_name"]
|
||||
|
||||
await self._do_background_update(desired_duration_ms)
|
||||
# We have a background update to run, otherwise we would have returned
|
||||
# early.
|
||||
assert self._current_background_update is not None
|
||||
update_info = self._background_update_handlers[self._current_background_update]
|
||||
|
||||
async with self._get_context_manager_for_update(
|
||||
sleep=sleep,
|
||||
update_name=self._current_background_update,
|
||||
database_name=self._database_name,
|
||||
oneshot=update_info.oneshot,
|
||||
) as desired_duration_ms:
|
||||
await self._do_background_update(desired_duration_ms)
|
||||
|
||||
return False
|
||||
|
||||
async def _do_background_update(self, desired_duration_ms: float) -> int:
|
||||
|
@ -260,7 +397,7 @@ class BackgroundUpdater:
|
|||
update_name = self._current_background_update
|
||||
logger.info("Starting update batch on background update '%s'", update_name)
|
||||
|
||||
update_handler = self._background_update_handlers[update_name]
|
||||
update_handler = self._background_update_handlers[update_name].callback
|
||||
|
||||
performance = self._background_update_performance.get(update_name)
|
||||
|
||||
|
@ -273,9 +410,14 @@ class BackgroundUpdater:
|
|||
if items_per_ms is not None:
|
||||
batch_size = int(desired_duration_ms * items_per_ms)
|
||||
# Clamp the batch size so that we always make progress
|
||||
batch_size = max(batch_size, self.MINIMUM_BACKGROUND_BATCH_SIZE)
|
||||
batch_size = max(
|
||||
batch_size,
|
||||
await self._min_batch_size(update_name, self._database_name),
|
||||
)
|
||||
else:
|
||||
batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
|
||||
batch_size = await self._default_batch_size(
|
||||
update_name, self._database_name
|
||||
)
|
||||
|
||||
progress_json = await self.db_pool.simple_select_one_onecol(
|
||||
"background_updates",
|
||||
|
@ -294,6 +436,8 @@ class BackgroundUpdater:
|
|||
|
||||
duration_ms = time_stop - time_start
|
||||
|
||||
performance.update(items_updated, duration_ms)
|
||||
|
||||
logger.info(
|
||||
"Running background update %r. Processed %r items in %rms."
|
||||
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
|
||||
|
@ -306,8 +450,6 @@ class BackgroundUpdater:
|
|||
batch_size,
|
||||
)
|
||||
|
||||
performance.update(items_updated, duration_ms)
|
||||
|
||||
return len(self._background_update_performance)
|
||||
|
||||
def register_background_update_handler(
|
||||
|
@ -331,7 +473,9 @@ class BackgroundUpdater:
|
|||
update_name: The name of the update that this code handles.
|
||||
update_handler: The function that does the update.
|
||||
"""
|
||||
self._background_update_handlers[update_name] = update_handler
|
||||
self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
|
||||
update_handler
|
||||
)
|
||||
|
||||
def register_noop_background_update(self, update_name: str) -> None:
|
||||
"""Register a noop handler for a background update.
|
||||
|
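Minimal sketch (not the code in this commit) of how a store registers a background update handler with the callback signature wrapped by _BackgroundUpdateHandler above; the store and update names here are made up.

from synapse.storage._base import SQLBaseStore
from synapse.types import JsonDict


class ExampleStore(SQLBaseStore):
    """Hypothetical store showing how a background update is registered."""

    def __init__(self, database, db_conn, hs):
        super().__init__(database, db_conn, hs)
        self.db_pool.updates.register_background_update_handler(
            "example_populate_column", self._example_populate_column
        )

    async def _example_populate_column(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        # Do up to `batch_size` items of work, persisting progress as we go.
        # Returning the number of items processed lets BackgroundUpdater tune
        # the next batch; ending the update removes it from the queue.
        await self.db_pool.updates._end_background_update("example_populate_column")
        return 1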
@ -453,7 +597,9 @@ class BackgroundUpdater:
|
|||
await self._end_background_update(update_name)
|
||||
return 1
|
||||
|
||||
self.register_background_update_handler(update_name, updater)
|
||||
self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
|
||||
updater, oneshot=True
|
||||
)
|
||||
|
||||
async def _end_background_update(self, update_name: str) -> None:
|
||||
"""Removes a completed background update task from the queue.
|
||||
|
|
|
@ -143,7 +143,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
A list of ApplicationServices, which may be empty.
|
||||
"""
|
||||
results = await self.db_pool.simple_select_list(
|
||||
"application_services_state", {"state": state}, ["as_id"]
|
||||
"application_services_state", {"state": state.value}, ["as_id"]
|
||||
)
|
||||
# NB: This assumes this class is linked with ApplicationServiceStore
|
||||
as_list = self.get_app_services()
|
||||
|
@ -173,7 +173,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
desc="get_appservice_state",
|
||||
)
|
||||
if result:
|
||||
return result.get("state")
|
||||
return ApplicationServiceState(result.get("state"))
|
||||
return None
|
||||
|
||||
async def set_appservice_state(
|
||||
|
@ -186,7 +186,7 @@ class ApplicationServiceTransactionWorkerStore(
|
|||
state: The connectivity state to apply.
|
||||
"""
|
||||
await self.db_pool.simple_upsert(
|
||||
"application_services_state", {"as_id": service.id}, {"state": state}
|
||||
"application_services_state", {"as_id": service.id}, {"state": state.value}
|
||||
)
|
||||
|
||||
async def create_appservice_txn(
|
||||
|
|
|
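Rough usage sketch (assumed wiring, not from this commit): with the change above, the state column stores the enum's string value and reads convert it back to the ApplicationServiceState enum.

from synapse.appservice import ApplicationServiceState


async def mark_down_and_check(store, service) -> None:
    # `store` is the ApplicationServiceTransactionWorkerStore patched above.
    await store.set_appservice_state(service, ApplicationServiceState.DOWN)

    # Reads now return the enum rather than the raw string value.
    assert await store.get_appservice_state(service) is ApplicationServiceState.DOWN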
@ -139,6 +139,27 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
|
||||
return {d["device_id"]: d for d in devices}
|
||||
|
||||
async def get_devices_by_auth_provider_session_id(
|
||||
self, auth_provider_id: str, auth_provider_session_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Retrieve the list of devices associated with a SSO IdP session ID.
|
||||
|
||||
Args:
|
||||
auth_provider_id: The SSO IdP ID as defined in the server config
|
||||
auth_provider_session_id: The session ID within the IdP
|
||||
Returns:
|
||||
A list of dicts containing the device_id and the user_id of each device
|
||||
"""
|
||||
return await self.db_pool.simple_select_list(
|
||||
table="device_auth_providers",
|
||||
keyvalues={
|
||||
"auth_provider_id": auth_provider_id,
|
||||
"auth_provider_session_id": auth_provider_session_id,
|
||||
},
|
||||
retcols=("user_id", "device_id"),
|
||||
desc="get_devices_by_auth_provider_session_id",
|
||||
)
|
||||
|
||||
@trace
|
||||
async def get_device_updates_by_remote(
|
||||
self, destination: str, from_stream_id: int, limit: int
|
||||
|
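Hypothetical usage sketch of the new lookup above: find the devices created from a given IdP session so they can be logged out when the IdP session ends. The "oidc" IdP ID and the device-handler API used here are assumptions, not part of this diff.

async def log_out_sso_session(store, device_handler, sid: str) -> None:
    devices = await store.get_devices_by_auth_provider_session_id(
        auth_provider_id="oidc",  # assumed IdP ID from the homeserver config
        auth_provider_session_id=sid,
    )
    for device in devices:
        # Assumed device-handler call for removing a user's device.
        await device_handler.delete_devices(device["user_id"], [device["device_id"]])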
@ -1070,7 +1091,12 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
)
|
||||
|
||||
async def store_device(
|
||||
self, user_id: str, device_id: str, initial_device_display_name: Optional[str]
|
||||
self,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
initial_device_display_name: Optional[str],
|
||||
auth_provider_id: Optional[str] = None,
|
||||
auth_provider_session_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Ensure the given device is known; add it to the store if not
|
||||
|
||||
|
@ -1079,6 +1105,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
device_id: id of device
|
||||
initial_device_display_name: initial displayname of the device.
|
||||
Ignored if device exists.
|
||||
auth_provider_id: The SSO IdP the user used, if any.
|
||||
auth_provider_session_id: The session ID (sid) obtained from an OIDC login.
|
||||
|
||||
Returns:
|
||||
Whether the device was inserted or an existing device existed with that ID.
|
||||
|
@ -1115,6 +1143,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
if hidden:
|
||||
raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
|
||||
|
||||
if auth_provider_id and auth_provider_session_id:
|
||||
await self.db_pool.simple_insert(
|
||||
"device_auth_providers",
|
||||
values={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"auth_provider_id": auth_provider_id,
|
||||
"auth_provider_session_id": auth_provider_session_id,
|
||||
},
|
||||
desc="store_device_auth_provider",
|
||||
)
|
||||
|
||||
self.device_id_exists_cache.set(key, True)
|
||||
return inserted
|
||||
except StoreError:
|
||||
|
@ -1168,6 +1208,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
keyvalues={"user_id": user_id},
|
||||
)
|
||||
|
||||
self.db_pool.simple_delete_many_txn(
|
||||
txn,
|
||||
table="device_auth_providers",
|
||||
column="device_id",
|
||||
values=device_ids,
|
||||
keyvalues={"user_id": user_id},
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)
|
||||
for device_id in device_ids:
|
||||
self.device_id_exists_cache.invalidate((user_id, device_id))
|
||||
|
|
|
@@ -1552,9 +1552,9 @@ class EventFederationStore(EventFederationWorkerStore):
             DELETE FROM event_auth
             WHERE event_id IN (
                 SELECT event_id FROM events
-                LEFT JOIN state_events USING (room_id, event_id)
+                LEFT JOIN state_events AS se USING (room_id, event_id)
                 WHERE ? <= stream_ordering AND stream_ordering < ?
-                AND state_key IS null
+                AND se.state_key IS null
             )
         """
 
@ -16,6 +16,7 @@ import logging
|
|||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import attr
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
|
@ -37,6 +38,20 @@ DEFAULT_HIGHLIGHT_ACTION = [
|
|||
]
|
||||
|
||||
|
||||
class BasePushAction(TypedDict):
|
||||
event_id: str
|
||||
actions: List[Union[dict, str]]
|
||||
|
||||
|
||||
class HttpPushAction(BasePushAction):
|
||||
room_id: str
|
||||
stream_ordering: int
|
||||
|
||||
|
||||
class EmailPushAction(HttpPushAction):
|
||||
received_ts: Optional[int]
|
||||
|
||||
|
||||
def _serialize_action(actions, is_highlight):
|
||||
"""Custom serializer for actions. This allows us to "compress" common actions.
|
||||
|
||||
|
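Illustrative value (made-up IDs, not from this commit) matching the TypedDicts added above; EmailPushAction is the type defined in the hunk above, extending HttpPushAction and BasePushAction.

sample: EmailPushAction = {
    "event_id": "$143273582443PhrSn:example.org",
    "actions": ["notify", {"set_tweak": "highlight", "value": True}],
    "room_id": "!room:example.org",
    "stream_ordering": 1234,
    "received_ts": 1639526400000,
}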
@ -221,7 +236,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
min_stream_ordering: int,
|
||||
max_stream_ordering: int,
|
||||
limit: int = 20,
|
||||
) -> List[dict]:
|
||||
) -> List[HttpPushAction]:
|
||||
"""Get a list of the most recent unread push actions for a given user,
|
||||
within the given stream ordering range. Called by the httppusher.
|
||||
|
||||
|
@ -326,7 +341,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
min_stream_ordering: int,
|
||||
max_stream_ordering: int,
|
||||
limit: int = 20,
|
||||
) -> List[dict]:
|
||||
) -> List[EmailPushAction]:
|
||||
"""Get a list of the most recent unread push actions for a given user,
|
||||
within the given stream ordering range. Called by the emailpusher
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# limitations under the License.
|
||||
import itertools
|
||||
import logging
|
||||
from collections import OrderedDict, namedtuple
|
||||
from collections import OrderedDict
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
|
@ -41,9 +41,10 @@ from synapse.events.snapshot import EventContext # noqa: F401
|
|||
from synapse.logging.utils import log_function
|
||||
from synapse.storage._base import db_to_json, make_in_list_sql_clause
|
||||
from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||
from synapse.storage.databases.main.events_worker import EventCacheEntry
|
||||
from synapse.storage.databases.main.search import SearchEntry
|
||||
from synapse.storage.types import Connection
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
from synapse.storage.util.id_generators import AbstractStreamIdGenerator
|
||||
from synapse.storage.util.sequence import SequenceGenerator
|
||||
from synapse.types import StateMap, get_domain_from_id
|
||||
from synapse.util import json_encoder
|
||||
|
@ -64,9 +65,6 @@ event_counter = Counter(
|
|||
)
|
||||
|
||||
|
||||
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class DeltaState:
|
||||
"""Deltas to use to update the `current_state_events` table.
|
||||
|
@ -108,23 +106,30 @@ class PersistEventsStore:
|
|||
self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
|
||||
# Ideally we'd move these ID gens here, unfortunately some other ID
|
||||
# generators are chained off them so doing so is a bit of a PITA.
|
||||
self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
|
||||
self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen
|
||||
|
||||
# This should only exist on instances that are configured to write
|
||||
assert (
|
||||
hs.get_instance_name() in hs.config.worker.writers.events
|
||||
), "Can only instantiate EventsStore on master"
|
||||
|
||||
# Since we have been configured to write, we ought to have id generators,
|
||||
# rather than id trackers.
|
||||
assert isinstance(self.store._backfill_id_gen, AbstractStreamIdGenerator)
|
||||
assert isinstance(self.store._stream_id_gen, AbstractStreamIdGenerator)
|
||||
|
||||
# Ideally we'd move these ID gens here, unfortunately some other ID
|
||||
# generators are chained off them so doing so is a bit of a PITA.
|
||||
self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
|
||||
self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
|
||||
|
||||
async def _persist_events_and_state_updates(
|
||||
self,
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
*,
|
||||
current_state_for_room: Dict[str, StateMap[str]],
|
||||
state_delta_for_room: Dict[str, DeltaState],
|
||||
new_forward_extremeties: Dict[str, List[str]],
|
||||
backfilled: bool = False,
|
||||
use_negative_stream_ordering: bool = False,
|
||||
inhibit_local_membership_updates: bool = False,
|
||||
) -> None:
|
||||
"""Persist a set of events alongside updates to the current state and
|
||||
forward extremities tables.
|
||||
|
@ -137,7 +142,14 @@ class PersistEventsStore:
|
|||
room state
|
||||
new_forward_extremities: Map from room_id to list of event IDs
|
||||
that are the new forward extremities of the room.
|
||||
backfilled
|
||||
use_negative_stream_ordering: Whether to start stream_ordering on
|
||||
the negative side and decrement. This should be set as True
|
||||
for backfilled events because backfilled events get a negative
|
||||
stream ordering so they don't come down incremental `/sync`.
|
||||
inhibit_local_membership_updates: Stop the local_current_membership
|
||||
from being updated by these events. This should be set to True
|
||||
for backfilled events because backfilled events in the past do
|
||||
not affect the current local state.
|
||||
|
||||
Returns:
|
||||
Resolves when the events have been persisted
|
||||
|
@ -159,7 +171,7 @@ class PersistEventsStore:
|
|||
#
|
||||
# Note: Multiple instances of this function cannot be in flight at
|
||||
# the same time for the same room.
|
||||
if backfilled:
|
||||
if use_negative_stream_ordering:
|
||||
stream_ordering_manager = self._backfill_id_gen.get_next_mult(
|
||||
len(events_and_contexts)
|
||||
)
|
||||
|
@ -176,13 +188,13 @@ class PersistEventsStore:
|
|||
"persist_events",
|
||||
self._persist_events_txn,
|
||||
events_and_contexts=events_and_contexts,
|
||||
backfilled=backfilled,
|
||||
inhibit_local_membership_updates=inhibit_local_membership_updates,
|
||||
state_delta_for_room=state_delta_for_room,
|
||||
new_forward_extremeties=new_forward_extremeties,
|
||||
)
|
||||
persist_event_counter.inc(len(events_and_contexts))
|
||||
|
||||
if not backfilled:
|
||||
if stream < 0:
|
||||
# backfilled events have negative stream orderings, so we don't
|
||||
# want to set the event_persisted_position to that.
|
||||
synapse.metrics.event_persisted_position.set(
|
||||
|
@ -316,8 +328,9 @@ class PersistEventsStore:
|
|||
def _persist_events_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
*,
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
backfilled: bool,
|
||||
inhibit_local_membership_updates: bool = False,
|
||||
state_delta_for_room: Optional[Dict[str, DeltaState]] = None,
|
||||
new_forward_extremeties: Optional[Dict[str, List[str]]] = None,
|
||||
):
|
||||
|
@ -330,7 +343,10 @@ class PersistEventsStore:
|
|||
Args:
|
||||
txn
|
||||
events_and_contexts: events to persist
|
||||
backfilled: True if the events were backfilled
|
||||
inhibit_local_membership_updates: Stop the local_current_membership
|
||||
from being updated by these events. This should be set to True
|
||||
for backfilled events because backfilled events in the past do
|
||||
not affect the current local state.
|
||||
delete_existing: True to purge existing table rows for the events
|
||||
from the database. This is useful when retrying due to
|
||||
IntegrityError.
|
||||
|
@ -363,9 +379,7 @@ class PersistEventsStore:
|
|||
events_and_contexts
|
||||
)
|
||||
|
||||
self._update_room_depths_txn(
|
||||
txn, events_and_contexts=events_and_contexts, backfilled=backfilled
|
||||
)
|
||||
self._update_room_depths_txn(txn, events_and_contexts=events_and_contexts)
|
||||
|
||||
# _update_outliers_txn filters out any events which have already been
|
||||
# persisted, and returns the filtered list.
|
||||
|
@ -398,7 +412,7 @@ class PersistEventsStore:
|
|||
txn,
|
||||
events_and_contexts=events_and_contexts,
|
||||
all_events_and_contexts=all_events_and_contexts,
|
||||
backfilled=backfilled,
|
||||
inhibit_local_membership_updates=inhibit_local_membership_updates,
|
||||
)
|
||||
|
||||
# We call this last as it assumes we've inserted the events into
|
||||
|
@ -561,9 +575,9 @@ class PersistEventsStore:
|
|||
# fetch their auth event info.
|
||||
while missing_auth_chains:
|
||||
sql = """
|
||||
SELECT event_id, events.type, state_key, chain_id, sequence_number
|
||||
SELECT event_id, events.type, se.state_key, chain_id, sequence_number
|
||||
FROM events
|
||||
INNER JOIN state_events USING (event_id)
|
||||
INNER JOIN state_events AS se USING (event_id)
|
||||
LEFT JOIN event_auth_chains USING (event_id)
|
||||
WHERE
|
||||
"""
|
||||
|
@ -1200,7 +1214,6 @@ class PersistEventsStore:
|
|||
self,
|
||||
txn,
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
backfilled: bool,
|
||||
):
|
||||
"""Update min_depth for each room
|
||||
|
||||
|
@ -1208,13 +1221,18 @@ class PersistEventsStore:
|
|||
txn (twisted.enterprise.adbapi.Connection): db connection
|
||||
events_and_contexts (list[(EventBase, EventContext)]): events
|
||||
we are persisting
|
||||
backfilled (bool): True if the events were backfilled
|
||||
"""
|
||||
depth_updates: Dict[str, int] = {}
|
||||
for event, context in events_and_contexts:
|
||||
# Remove any existing cache entries for the event_ids
|
||||
txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
|
||||
if not backfilled:
|
||||
# Then update the `stream_ordering` position to mark the latest
|
||||
# event as the front of the room. This should not be done for
|
||||
# backfilled events because backfilled events have negative
|
||||
# stream_ordering and happened in the past so we know that we don't
|
||||
# need to update the stream_ordering tip/front for the room.
|
||||
assert event.internal_metadata.stream_ordering is not None
|
||||
if event.internal_metadata.stream_ordering >= 0:
|
||||
txn.call_after(
|
||||
self.store._events_stream_cache.entity_has_changed,
|
||||
event.room_id,
|
||||
|
@ -1427,7 +1445,12 @@ class PersistEventsStore:
|
|||
return [ec for ec in events_and_contexts if ec[0] not in to_remove]
|
||||
|
||||
def _update_metadata_tables_txn(
|
||||
self, txn, events_and_contexts, all_events_and_contexts, backfilled
|
||||
self,
|
||||
txn,
|
||||
*,
|
||||
events_and_contexts,
|
||||
all_events_and_contexts,
|
||||
inhibit_local_membership_updates: bool = False,
|
||||
):
|
||||
"""Update all the miscellaneous tables for new events
|
||||
|
||||
|
@ -1439,7 +1462,10 @@ class PersistEventsStore:
|
|||
events that we were going to persist. This includes events
|
||||
we've already persisted, etc, that wouldn't appear in
|
||||
events_and_context.
|
||||
backfilled (bool): True if the events were backfilled
|
||||
inhibit_local_membership_updates: Stop the local_current_membership
|
||||
from being updated by these events. This should be set to True
|
||||
for backfilled events because backfilled events in the past do
|
||||
not affect the current local state.
|
||||
"""
|
||||
|
||||
# Insert all the push actions into the event_push_actions table.
|
||||
|
@ -1513,7 +1539,7 @@ class PersistEventsStore:
|
|||
for event, _ in events_and_contexts
|
||||
if event.type == EventTypes.Member
|
||||
],
|
||||
backfilled=backfilled,
|
||||
inhibit_local_membership_updates=inhibit_local_membership_updates,
|
||||
)
|
||||
|
||||
# Insert event_reference_hashes table.
|
||||
|
@ -1553,11 +1579,13 @@ class PersistEventsStore:
|
|||
for row in rows:
|
||||
event = ev_map[row["event_id"]]
|
||||
if not row["rejects"] and not row["redacts"]:
|
||||
to_prefill.append(_EventCacheEntry(event=event, redacted_event=None))
|
||||
to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
|
||||
|
||||
def prefill():
|
||||
for cache_entry in to_prefill:
|
||||
self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
|
||||
self.store._get_event_cache.set(
|
||||
(cache_entry.event.event_id,), cache_entry
|
||||
)
|
||||
|
||||
txn.call_after(prefill)
|
||||
|
||||
|
@ -1638,8 +1666,19 @@ class PersistEventsStore:
|
|||
txn, table="event_reference_hashes", values=vals
|
||||
)
|
||||
|
||||
def _store_room_members_txn(self, txn, events, backfilled):
|
||||
"""Store a room member in the database."""
|
||||
def _store_room_members_txn(
|
||||
self, txn, events, *, inhibit_local_membership_updates: bool = False
|
||||
):
|
||||
"""
|
||||
Store a room member in the database.
|
||||
Args:
|
||||
txn: The transaction to use.
|
||||
events: List of events to store.
|
||||
inhibit_local_membership_updates: Stop the local_current_membership
|
||||
from being updated by these events. This should be set to True
|
||||
for backfilled events because backfilled events in the past do
|
||||
not affect the current local state.
|
||||
"""
|
||||
|
||||
def non_null_str_or_none(val: Any) -> Optional[str]:
|
||||
return val if isinstance(val, str) and "\u0000" not in val else None
|
||||
|
@ -1682,7 +1721,7 @@ class PersistEventsStore:
|
|||
# band membership", like a remote invite or a rejection of a remote invite.
|
||||
if (
|
||||
self.is_mine_id(event.state_key)
|
||||
and not backfilled
|
||||
and not inhibit_local_membership_updates
|
||||
and event.internal_metadata.is_outlier()
|
||||
and event.internal_metadata.is_out_of_band_membership()
|
||||
):
|
||||
|
|
|
@ -15,14 +15,18 @@
|
|||
import logging
|
||||
import threading
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Collection,
|
||||
Container,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
|
@ -38,6 +42,7 @@ from synapse.api.errors import NotFoundError, SynapseError
|
|||
from synapse.api.room_versions import (
|
||||
KNOWN_ROOM_VERSIONS,
|
||||
EventFormatVersions,
|
||||
RoomVersion,
|
||||
RoomVersions,
|
||||
)
|
||||
from synapse.events import EventBase, make_event_from_dict
|
||||
|
@ -56,10 +61,18 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
|||
from synapse.replication.tcp.streams import BackfillStream
|
||||
from synapse.replication.tcp.streams.events import EventsStream
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
||||
from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.types import Connection
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.storage.util.id_generators import (
|
||||
AbstractStreamIdTracker,
|
||||
MultiWriterIdGenerator,
|
||||
StreamIdGenerator,
|
||||
)
|
||||
from synapse.storage.util.sequence import build_sequence_generator
|
||||
from synapse.types import JsonDict, get_domain_from_id
|
||||
from synapse.util import unwrapFirstError
|
||||
|
@ -69,10 +82,13 @@ from synapse.util.caches.lrucache import LruCache
|
|||
from synapse.util.iterutils import batch_iter
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# These values are used in the `enqueus_event` and `_do_fetch` methods to
|
||||
# These values are used in the `enqueue_event` and `_fetch_loop` methods to
|
||||
# control how we batch/bulk fetch events from the database.
|
||||
# The values are plucked out of thin air to make initial sync run faster
|
||||
# on jki.re
|
||||
|
@ -89,7 +105,7 @@ event_fetch_ongoing_gauge = Gauge(
|
|||
|
||||
|
||||
@attr.s(slots=True, auto_attribs=True)
|
||||
class _EventCacheEntry:
|
||||
class EventCacheEntry:
|
||||
event: EventBase
|
||||
redacted_event: Optional[EventBase]
|
||||
|
||||
|
@ -129,7 +145,7 @@ class _EventRow:
|
|||
json: str
|
||||
internal_metadata: str
|
||||
format_version: Optional[int]
|
||||
room_version_id: Optional[int]
|
||||
room_version_id: Optional[str]
|
||||
rejected_reason: Optional[str]
|
||||
redactions: List[str]
|
||||
outlier: bool
|
||||
|
@ -153,9 +169,16 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
# options controlling this.
|
||||
USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True
|
||||
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
hs: "HomeServer",
|
||||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self._stream_id_gen: AbstractStreamIdTracker
|
||||
self._backfill_id_gen: AbstractStreamIdTracker
|
||||
if isinstance(database.engine, PostgresEngine):
|
||||
# If we're using Postgres than we can use `MultiWriterIdGenerator`
|
||||
# regardless of whether this process writes to the streams or not.
|
||||
|
@ -214,7 +237,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
5 * 60 * 1000,
|
||||
)
|
||||
|
||||
self._get_event_cache = LruCache(
|
||||
self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache(
|
||||
cache_name="*getEvent*",
|
||||
max_size=hs.config.caches.event_cache_size,
|
||||
)
|
||||
|
@ -223,19 +246,21 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
# ID to cache entry. Note that the returned dict may not have the
|
||||
# requested event in it if the event isn't in the DB.
|
||||
self._current_event_fetches: Dict[
|
||||
str, ObservableDeferred[Dict[str, _EventCacheEntry]]
|
||||
str, ObservableDeferred[Dict[str, EventCacheEntry]]
|
||||
] = {}
|
||||
|
||||
self._event_fetch_lock = threading.Condition()
|
||||
self._event_fetch_list = []
|
||||
self._event_fetch_list: List[
|
||||
Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
|
||||
] = []
|
||||
self._event_fetch_ongoing = 0
|
||||
event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
|
||||
|
||||
# We define this sequence here so that it can be referenced from both
|
||||
# the DataStore and PersistEventStore.
|
||||
def get_chain_id_txn(txn):
|
||||
def get_chain_id_txn(txn: Cursor) -> int:
|
||||
txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
|
||||
return txn.fetchone()[0]
|
||||
return cast(Tuple[int], txn.fetchone())[0]
|
||||
|
||||
self.event_chain_id_gen = build_sequence_generator(
|
||||
db_conn,
|
||||
|
@ -246,7 +271,13 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
id_column="chain_id",
|
||||
)
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
def process_replication_rows(
|
||||
self,
|
||||
stream_name: str,
|
||||
instance_name: str,
|
||||
token: int,
|
||||
rows: Iterable[Any],
|
||||
) -> None:
|
||||
if stream_name == EventsStream.NAME:
|
||||
self._stream_id_gen.advance(instance_name, token)
|
||||
elif stream_name == BackfillStream.NAME:
|
||||
|
@ -280,10 +311,10 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
self,
|
||||
event_id: str,
|
||||
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
|
||||
get_prev_content: bool = False,
|
||||
allow_rejected: bool = False,
|
||||
allow_none: Literal[False] = False,
|
||||
check_room_id: Optional[str] = None,
|
||||
get_prev_content: bool = ...,
|
||||
allow_rejected: bool = ...,
|
||||
allow_none: Literal[False] = ...,
|
||||
check_room_id: Optional[str] = ...,
|
||||
) -> EventBase:
|
||||
...
|
||||
|
||||
|
@ -292,10 +323,10 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
self,
|
||||
event_id: str,
|
||||
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
|
||||
get_prev_content: bool = False,
|
||||
allow_rejected: bool = False,
|
||||
allow_none: Literal[True] = False,
|
||||
check_room_id: Optional[str] = None,
|
||||
get_prev_content: bool = ...,
|
||||
allow_rejected: bool = ...,
|
||||
allow_none: Literal[True] = ...,
|
||||
check_room_id: Optional[str] = ...,
|
||||
) -> Optional[EventBase]:
|
||||
...
|
||||
|
||||
|
@ -357,7 +388,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
|
||||
async def get_events(
|
||||
self,
|
||||
event_ids: Iterable[str],
|
||||
event_ids: Collection[str],
|
||||
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
|
||||
get_prev_content: bool = False,
|
||||
allow_rejected: bool = False,
|
||||
|
@ -544,7 +575,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
|
||||
async def _get_events_from_cache_or_db(
|
||||
self, event_ids: Iterable[str], allow_rejected: bool = False
|
||||
) -> Dict[str, _EventCacheEntry]:
|
||||
) -> Dict[str, EventCacheEntry]:
|
||||
"""Fetch a bunch of events from the cache or the database.
|
||||
|
||||
If events are pulled from the database, they will be cached for future lookups.
|
||||
|
@ -578,7 +609,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
# same dict into itself N times).
|
||||
already_fetching_ids: Set[str] = set()
|
||||
already_fetching_deferreds: Set[
|
||||
ObservableDeferred[Dict[str, _EventCacheEntry]]
|
||||
ObservableDeferred[Dict[str, EventCacheEntry]]
|
||||
] = set()
|
||||
|
||||
for event_id in missing_events_ids:
|
||||
|
@ -601,8 +632,8 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
# function returning more events than requested, but that can happen
|
||||
# already due to `_get_events_from_db`).
|
||||
fetching_deferred: ObservableDeferred[
|
||||
Dict[str, _EventCacheEntry]
|
||||
] = ObservableDeferred(defer.Deferred())
|
||||
Dict[str, EventCacheEntry]
|
||||
] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
|
||||
for event_id in missing_events_ids:
|
||||
self._current_event_fetches[event_id] = fetching_deferred
|
||||
|
||||
|
@ -658,12 +689,12 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
|
||||
return event_entry_map
|
||||
|
||||
def _invalidate_get_event_cache(self, event_id):
|
||||
def _invalidate_get_event_cache(self, event_id: str) -> None:
|
||||
self._get_event_cache.invalidate((event_id,))
|
||||
|
||||
def _get_events_from_cache(
|
||||
self, events: Iterable[str], update_metrics: bool = True
|
||||
) -> Dict[str, _EventCacheEntry]:
|
||||
) -> Dict[str, EventCacheEntry]:
|
||||
"""Fetch events from the caches.
|
||||
|
||||
May return rejected events.
|
||||
|
@ -736,38 +767,123 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
for e in state_to_include.values()
|
||||
]
|
||||
|
||||
def _do_fetch(self, conn: Connection) -> None:
|
||||
def _maybe_start_fetch_thread(self) -> None:
|
||||
"""Starts an event fetch thread if we are not yet at the maximum number."""
|
||||
with self._event_fetch_lock:
|
||||
if (
|
||||
self._event_fetch_list
|
||||
and self._event_fetch_ongoing < EVENT_QUEUE_THREADS
|
||||
):
|
||||
self._event_fetch_ongoing += 1
|
||||
event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
|
||||
# `_event_fetch_ongoing` is decremented in `_fetch_thread`.
|
||||
should_start = True
|
||||
else:
|
||||
should_start = False
|
||||
|
||||
if should_start:
|
||||
run_as_background_process("fetch_events", self._fetch_thread)
|
||||
|
||||
async def _fetch_thread(self) -> None:
|
||||
"""Services requests for events from `_event_fetch_list`."""
|
||||
exc = None
|
||||
try:
|
||||
await self.db_pool.runWithConnection(self._fetch_loop)
|
||||
except BaseException as e:
|
||||
exc = e
|
||||
raise
|
||||
finally:
|
||||
should_restart = False
|
||||
event_fetches_to_fail = []
|
||||
with self._event_fetch_lock:
|
||||
self._event_fetch_ongoing -= 1
|
||||
event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
|
||||
|
||||
# There may still be work remaining in `_event_fetch_list` if we
|
||||
# failed, or it was added in between us deciding to exit and
|
||||
# decrementing `_event_fetch_ongoing`.
|
||||
if self._event_fetch_list:
|
||||
if exc is None:
|
||||
# We decided to exit, but then some more work was added
|
||||
# before `_event_fetch_ongoing` was decremented.
|
||||
# If a new event fetch thread was not started, we should
|
||||
# restart ourselves since the remaining event fetch threads
|
||||
# may take a while to get around to the new work.
|
||||
#
|
||||
# Unfortunately it is not possible to tell whether a new
|
||||
# event fetch thread was started, so we restart
|
||||
# unconditionally. If we are unlucky, we will end up with
|
||||
# an idle fetch thread, but it will time out after
|
||||
# `EVENT_QUEUE_ITERATIONS * EVENT_QUEUE_TIMEOUT_S` seconds
|
||||
# in any case.
|
||||
#
|
||||
# Note that multiple fetch threads may run down this path at
|
||||
# the same time.
|
||||
should_restart = True
|
||||
elif isinstance(exc, Exception):
|
||||
if self._event_fetch_ongoing == 0:
|
||||
# We were the last remaining fetcher and failed.
|
||||
# Fail any outstanding fetches since no one else will
|
||||
# handle them.
|
||||
event_fetches_to_fail = self._event_fetch_list
|
||||
self._event_fetch_list = []
|
||||
else:
|
||||
# We weren't the last remaining fetcher, so another
|
||||
# fetcher will pick up the work. This will either happen
|
||||
# after their existing work, however long that takes,
|
||||
# or after at most `EVENT_QUEUE_TIMEOUT_S` seconds if
|
||||
# they are idle.
|
||||
pass
|
||||
else:
|
||||
# The exception is a `SystemExit`, `KeyboardInterrupt` or
|
||||
# `GeneratorExit`. Don't try to do anything clever here.
|
||||
pass
|
||||
|
||||
if should_restart:
|
||||
# We exited cleanly but noticed more work.
|
||||
self._maybe_start_fetch_thread()
|
||||
|
||||
if event_fetches_to_fail:
|
||||
# We were the last remaining fetcher and failed.
|
||||
# Fail any outstanding fetches since no one else will handle them.
|
||||
assert exc is not None
|
||||
with PreserveLoggingContext():
|
||||
for _, deferred in event_fetches_to_fail:
|
||||
deferred.errback(exc)
|
||||
|
||||
def _fetch_loop(self, conn: LoggingDatabaseConnection) -> None:
|
||||
"""Takes a database connection and waits for requests for events from
|
||||
the _event_fetch_list queue.
|
||||
"""
|
||||
try:
|
||||
i = 0
|
||||
while True:
|
||||
with self._event_fetch_lock:
|
||||
event_list = self._event_fetch_list
|
||||
self._event_fetch_list = []
|
||||
i = 0
|
||||
while True:
|
||||
with self._event_fetch_lock:
|
||||
event_list = self._event_fetch_list
|
||||
self._event_fetch_list = []
|
||||
|
||||
if not event_list:
|
||||
single_threaded = self.database_engine.single_threaded
|
||||
if (
|
||||
not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
|
||||
or single_threaded
|
||||
or i > EVENT_QUEUE_ITERATIONS
|
||||
):
|
||||
break
|
||||
else:
|
||||
self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
|
||||
i += 1
|
||||
continue
|
||||
i = 0
|
||||
if not event_list:
|
||||
# There are no requests waiting. If we haven't yet reached the
|
||||
# maximum iteration limit, wait for some more requests to turn up.
|
||||
# Otherwise, bail out.
|
||||
single_threaded = self.database_engine.single_threaded
|
||||
if (
|
||||
not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
|
||||
or single_threaded
|
||||
or i > EVENT_QUEUE_ITERATIONS
|
||||
):
|
||||
return
|
||||
|
||||
self._fetch_event_list(conn, event_list)
|
||||
finally:
|
||||
self._event_fetch_ongoing -= 1
|
||||
event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
|
||||
self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
|
||||
i += 1
|
||||
continue
|
||||
i = 0
|
||||
|
||||
self._fetch_event_list(conn, event_list)
|
||||
|
||||
def _fetch_event_list(
|
||||
self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
|
||||
self,
|
||||
conn: LoggingDatabaseConnection,
|
||||
event_list: List[Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]],
|
||||
) -> None:
|
||||
"""Handle a load of requests from the _event_fetch_list queue
|
||||
|
||||
|
@ -794,7 +910,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
# We only want to resolve deferreds from the main thread
|
||||
def fire():
|
||||
def fire() -> None:
|
||||
for _, d in event_list:
|
||||
d.callback(row_dict)
|
||||
|
||||
|
@ -804,18 +920,16 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
logger.exception("do_fetch")
|
||||
|
||||
# We only want to resolve deferreds from the main thread
|
||||
def fire(evs, exc):
|
||||
for _, d in evs:
|
||||
if not d.called:
|
||||
with PreserveLoggingContext():
|
||||
d.errback(exc)
|
||||
def fire_errback(exc: Exception) -> None:
|
||||
for _, d in event_list:
|
||||
d.errback(exc)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
self.hs.get_reactor().callFromThread(fire, event_list, e)
|
||||
self.hs.get_reactor().callFromThread(fire_errback, e)
|
||||
|
||||
async def _get_events_from_db(
|
||||
self, event_ids: Iterable[str]
|
||||
) -> Dict[str, _EventCacheEntry]:
|
||||
self, event_ids: Collection[str]
|
||||
) -> Dict[str, EventCacheEntry]:
|
||||
"""Fetch a bunch of events from the database.
|
||||
|
||||
May return rejected events.
|
||||
|
@ -831,29 +945,29 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
map from event id to result. May return extra events which
|
||||
weren't asked for.
|
||||
"""
|
||||
fetched_events = {}
|
||||
fetched_event_ids: Set[str] = set()
|
||||
fetched_events: Dict[str, _EventRow] = {}
|
||||
events_to_fetch = event_ids
|
||||
|
||||
while events_to_fetch:
|
||||
row_map = await self._enqueue_events(events_to_fetch)
|
||||
|
||||
# we need to recursively fetch any redactions of those events
|
||||
redaction_ids = set()
|
||||
redaction_ids: Set[str] = set()
|
||||
for event_id in events_to_fetch:
|
||||
row = row_map.get(event_id)
|
||||
fetched_events[event_id] = row
|
||||
fetched_event_ids.add(event_id)
|
||||
if row:
|
||||
fetched_events[event_id] = row
|
||||
redaction_ids.update(row.redactions)
|
||||
|
||||
events_to_fetch = redaction_ids.difference(fetched_events.keys())
|
||||
events_to_fetch = redaction_ids.difference(fetched_event_ids)
|
||||
if events_to_fetch:
|
||||
logger.debug("Also fetching redaction events %s", events_to_fetch)
|
||||
|
||||
# build a map from event_id to EventBase
|
||||
event_map = {}
|
||||
event_map: Dict[str, EventBase] = {}
|
||||
for event_id, row in fetched_events.items():
|
||||
if not row:
|
||||
continue
|
||||
assert row.event_id == event_id
|
||||
|
||||
rejected_reason = row.rejected_reason
|
||||
|
@ -881,6 +995,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
|
||||
room_version_id = row.room_version_id
|
||||
|
||||
room_version: Optional[RoomVersion]
|
||||
if not room_version_id:
|
||||
# this should only happen for out-of-band membership events which
|
||||
# arrived before #6983 landed. For all other events, we should have
|
||||
|
@ -951,14 +1066,14 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
|
||||
# finally, we can decide whether each one needs redacting, and build
|
||||
# the cache entries.
|
||||
result_map = {}
|
||||
result_map: Dict[str, EventCacheEntry] = {}
|
||||
for event_id, original_ev in event_map.items():
|
||||
redactions = fetched_events[event_id].redactions
|
||||
redacted_event = self._maybe_redact_event_row(
|
||||
original_ev, redactions, event_map
|
||||
)
|
||||
|
||||
cache_entry = _EventCacheEntry(
|
||||
cache_entry = EventCacheEntry(
|
||||
event=original_ev, redacted_event=redacted_event
|
||||
)
|
||||
|
||||
|
@ -967,7 +1082,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
|
||||
return result_map
|
||||
|
||||
async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
|
||||
async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]:
|
||||
"""Fetches events from the database using the _event_fetch_list. This
|
||||
allows batch and bulk fetching of events - it allows us to fetch events
|
||||
without having to create a new transaction for each request for events.
|
||||
|
@ -980,23 +1095,12 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
that weren't requested.
|
||||
"""
|
||||
|
||||
events_d = defer.Deferred()
|
||||
events_d: "defer.Deferred[Dict[str, _EventRow]]" = defer.Deferred()
|
||||
with self._event_fetch_lock:
|
||||
self._event_fetch_list.append((events, events_d))
|
||||
|
||||
self._event_fetch_lock.notify()
|
||||
|
||||
if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
|
||||
self._event_fetch_ongoing += 1
|
||||
event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
|
||||
should_start = True
|
||||
else:
|
||||
should_start = False
|
||||
|
||||
if should_start:
|
||||
run_as_background_process(
|
||||
"fetch_events", self.db_pool.runWithConnection, self._do_fetch
|
||||
)
|
||||
self._maybe_start_fetch_thread()
|
||||
|
||||
logger.debug("Loading %d events: %s", len(events), events)
|
||||
with PreserveLoggingContext():
|
||||
|
@ -1146,7 +1250,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
# no valid redaction found for this event
|
||||
return None
|
||||
|
||||
async def have_events_in_timeline(self, event_ids):
|
||||
async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]:
|
||||
"""Given a list of event ids, check if we have already processed and
|
||||
stored them as non outliers.
|
||||
"""
|
||||
|
@ -1175,7 +1279,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
event_ids: events we are looking for
|
||||
|
||||
Returns:
|
||||
set[str]: The events we have already seen.
|
||||
The set of events we have already seen.
|
||||
"""
|
||||
res = await self._have_seen_events_dict(
|
||||
(room_id, event_id) for event_id in event_ids
|
||||
|
@ -1198,7 +1302,9 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
}
|
||||
results = {x: True for x in cache_results}
|
||||
|
||||
def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]):
|
||||
def have_seen_events_txn(
|
||||
txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...]
|
||||
) -> None:
|
||||
# we deliberately do *not* query the database for room_id, to make the
|
||||
# query an index-only lookup on `events_event_id_key`.
|
||||
#
|
||||
|
@ -1224,12 +1330,14 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
return results
|
||||
|
||||
@cached(max_entries=100000, tree=True)
|
||||
async def have_seen_event(self, room_id: str, event_id: str):
|
||||
async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
|
||||
# this only exists for the benefit of the @cachedList descriptor on
|
||||
# _have_seen_events_dict
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_current_state_event_counts_txn(self, txn, room_id):
|
||||
def _get_current_state_event_counts_txn(
|
||||
self, txn: LoggingTransaction, room_id: str
|
||||
) -> int:
|
||||
"""
|
||||
See get_current_state_event_counts.
|
||||
"""
|
||||
|
@ -1254,7 +1362,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
room_id,
|
||||
)
|
||||
|
||||
async def get_room_complexity(self, room_id):
|
||||
async def get_room_complexity(self, room_id: str) -> Dict[str, float]:
|
||||
"""
|
||||
Get a rough approximation of the complexity of the room. This is used by
|
||||
remote servers to decide whether they wish to join the room or not.
|
||||
|
@ -1262,10 +1370,10 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
more resources.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
room_id: The room ID to query.
|
||||
|
||||
Returns:
|
||||
dict[str:int] of complexity version to complexity.
|
||||
dict[str:float] of complexity version to complexity.
|
||||
"""
|
||||
state_events = await self.get_current_state_event_counts(room_id)
|
||||
|
||||
|
@ -1275,13 +1383,13 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
|
||||
return {"v1": complexity_v1}
|
||||
|
||||
def get_current_events_token(self):
|
||||
def get_current_events_token(self) -> int:
|
||||
"""The current maximum token that events have reached"""
|
||||
return self._stream_id_gen.get_current_token()
|
||||
|
||||
async def get_all_new_forward_event_rows(
|
||||
self, instance_name: str, last_id: int, current_id: int, limit: int
|
||||
) -> List[Tuple]:
|
||||
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
|
||||
"""Returns new events, for the Events replication stream
|
||||
|
||||
Args:
|
||||
|
@ -1295,13 +1403,15 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
EventsStreamRow.
|
||||
"""
|
||||
|
||||
def get_all_new_forward_event_rows(txn):
|
||||
def get_all_new_forward_event_rows(
|
||||
txn: LoggingTransaction,
|
||||
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
|
||||
sql = (
|
||||
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
|
||||
" state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
|
||||
" se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
|
||||
" FROM events AS e"
|
||||
" LEFT JOIN redactions USING (event_id)"
|
||||
" LEFT JOIN state_events USING (event_id)"
|
||||
" LEFT JOIN state_events AS se USING (event_id)"
|
||||
" LEFT JOIN event_relations USING (event_id)"
|
||||
" LEFT JOIN room_memberships USING (event_id)"
|
||||
" LEFT JOIN rejections USING (event_id)"
|
||||
|
@ -1311,7 +1421,9 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
" LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (last_id, current_id, instance_name, limit))
|
||||
return txn.fetchall()
|
||||
return cast(
|
||||
List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
|
||||
|
@ -1319,7 +1431,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
|
||||
async def get_ex_outlier_stream_rows(
|
||||
self, instance_name: str, last_id: int, current_id: int
|
||||
) -> List[Tuple]:
|
||||
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
|
||||
"""Returns de-outliered events, for the Events replication stream
|
||||
|
||||
Args:
|
||||
|
@ -1332,14 +1444,16 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
EventsStreamRow.
|
||||
"""
|
||||
|
||||
def get_ex_outlier_stream_rows_txn(txn):
|
||||
def get_ex_outlier_stream_rows_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
|
||||
sql = (
|
||||
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
|
||||
" state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
|
||||
" se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
|
||||
" FROM events AS e"
|
||||
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
|
||||
" LEFT JOIN redactions USING (event_id)"
|
||||
" LEFT JOIN state_events USING (event_id)"
|
||||
" LEFT JOIN state_events AS se USING (event_id)"
|
||||
" LEFT JOIN event_relations USING (event_id)"
|
||||
" LEFT JOIN room_memberships USING (event_id)"
|
||||
" LEFT JOIN rejections USING (event_id)"
|
||||
|
@ -1350,7 +1464,9 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
txn.execute(sql, (last_id, current_id, instance_name))
|
||||
return txn.fetchall()
|
||||
return cast(
|
||||
List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
|
||||
|
@ -1358,7 +1474,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
|
||||
async def get_all_new_backfill_event_rows(
|
||||
self, instance_name: str, last_id: int, current_id: int, limit: int
|
||||
) -> Tuple[List[Tuple[int, list]], int, bool]:
|
||||
) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
|
||||
"""Get updates for backfill replication stream, including all new
|
||||
backfilled events and events that have gone from being outliers to not.
|
||||
|
||||
|
@ -1386,13 +1502,15 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
if last_id == current_id:
|
||||
return [], current_id, False
|
||||
|
||||
def get_all_new_backfill_event_rows(txn):
|
||||
def get_all_new_backfill_event_rows(
|
||||
txn: LoggingTransaction,
|
||||
) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
|
||||
sql = (
|
||||
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
|
||||
" state_key, redacts, relates_to_id"
|
||||
" se.state_key, redacts, relates_to_id"
|
||||
" FROM events AS e"
|
||||
" LEFT JOIN redactions USING (event_id)"
|
||||
" LEFT JOIN state_events USING (event_id)"
|
||||
" LEFT JOIN state_events AS se USING (event_id)"
|
||||
" LEFT JOIN event_relations USING (event_id)"
|
||||
" WHERE ? > stream_ordering AND stream_ordering >= ?"
|
||||
" AND instance_name = ?"
|
||||
|
@ -1400,7 +1518,15 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
" LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (-last_id, -current_id, instance_name, limit))
|
||||
new_event_updates = [(row[0], row[1:]) for row in txn]
|
||||
new_event_updates: List[
|
||||
Tuple[int, Tuple[str, str, str, str, str, str]]
|
||||
] = []
|
||||
row: Tuple[int, str, str, str, str, str, str]
|
||||
# Type safety: iterating over `txn` yields `Tuple`, i.e.
|
||||
# `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
|
||||
# variadic tuple to a fixed length tuple and flags it up as an error.
|
||||
for row in txn: # type: ignore[assignment]
|
||||
new_event_updates.append((row[0], row[1:]))
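As a standalone illustration of the mypy behaviour described in the comment above (not part of the patch), assigning an arbitrary-length tuple to a fixed-length annotated variable is rejected unless it is explicitly ignored:

from typing import Any, Tuple

variadic: Tuple[Any, ...] = (1, "a", "b")
fixed: Tuple[int, str, str]
fixed = variadic  # type: ignore[assignment]  # mypy flags this line without the ignore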
|
||||
|
||||
limited = False
|
||||
if len(new_event_updates) == limit:
|
||||
|
@ -1411,11 +1537,11 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
|
||||
sql = (
|
||||
"SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
|
||||
" state_key, redacts, relates_to_id"
|
||||
" se.state_key, redacts, relates_to_id"
|
||||
" FROM events AS e"
|
||||
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
|
||||
" LEFT JOIN redactions USING (event_id)"
|
||||
" LEFT JOIN state_events USING (event_id)"
|
||||
" LEFT JOIN state_events AS se USING (event_id)"
|
||||
" LEFT JOIN event_relations USING (event_id)"
|
||||
" WHERE ? > event_stream_ordering"
|
||||
" AND event_stream_ordering >= ?"
|
||||
|
@ -1423,7 +1549,11 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
" ORDER BY event_stream_ordering DESC"
|
||||
)
|
||||
txn.execute(sql, (-last_id, -upper_bound, instance_name))
|
||||
new_event_updates.extend((row[0], row[1:]) for row in txn)
|
||||
# Type safety: iterating over `txn` yields `Tuple`, i.e.
|
||||
# `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
|
||||
# variadic tuple to a fixed length tuple and flags it up as an error.
|
||||
for row in txn: # type: ignore[assignment]
|
||||
new_event_updates.append((row[0], row[1:]))
|
||||
|
||||
if len(new_event_updates) >= limit:
|
||||
upper_bound = new_event_updates[-1][0]
|
||||
|
@ -1437,7 +1567,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
|
||||
async def get_all_updated_current_state_deltas(
|
||||
self, instance_name: str, from_token: int, to_token: int, target_row_count: int
|
||||
) -> Tuple[List[Tuple], int, bool]:
|
||||
) -> Tuple[List[Tuple[int, str, str, str, str]], int, bool]:
|
||||
"""Fetch updates from current_state_delta_stream
|
||||
|
||||
Args:
|
||||
|
@ -1457,7 +1587,9 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
* `limited` is whether there are more updates to fetch.
|
||||
"""
|
||||
|
||||
def get_all_updated_current_state_deltas_txn(txn):
|
||||
def get_all_updated_current_state_deltas_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> List[Tuple[int, str, str, str, str]]:
|
||||
sql = """
|
||||
SELECT stream_id, room_id, type, state_key, event_id
|
||||
FROM current_state_delta_stream
|
||||
|
@ -1466,21 +1598,23 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
ORDER BY stream_id ASC LIMIT ?
|
||||
"""
|
||||
txn.execute(sql, (from_token, to_token, instance_name, target_row_count))
|
||||
return txn.fetchall()
|
||||
return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
|
||||
|
||||
def get_deltas_for_stream_id_txn(txn, stream_id):
|
||||
def get_deltas_for_stream_id_txn(
|
||||
txn: LoggingTransaction, stream_id: int
|
||||
) -> List[Tuple[int, str, str, str, str]]:
|
||||
sql = """
|
||||
SELECT stream_id, room_id, type, state_key, event_id
|
||||
FROM current_state_delta_stream
|
||||
WHERE stream_id = ?
|
||||
"""
|
||||
txn.execute(sql, [stream_id])
|
||||
return txn.fetchall()
|
||||
return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
|
||||
|
||||
# we need to make sure that, for every stream id in the results, we get *all*
|
||||
# the rows with that stream id.
|
||||
|
||||
rows: List[Tuple] = await self.db_pool.runInteraction(
|
||||
rows: List[Tuple[int, str, str, str, str]] = await self.db_pool.runInteraction(
|
||||
"get_all_updated_current_state_deltas",
|
||||
get_all_updated_current_state_deltas_txn,
|
||||
)
|
||||
|
@ -1509,14 +1643,14 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
|
||||
return rows, to_token, True
|
||||
|
||||
async def is_event_after(self, event_id1, event_id2):
|
||||
async def is_event_after(self, event_id1: str, event_id2: str) -> bool:
|
||||
"""Returns True if event_id1 is after event_id2 in the stream"""
|
||||
to_1, so_1 = await self.get_event_ordering(event_id1)
|
||||
to_2, so_2 = await self.get_event_ordering(event_id2)
|
||||
return (to_1, so_1) > (to_2, so_2)
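For reference, the comparison above relies on Python's lexicographic tuple ordering: topological ordering is compared first, and stream ordering only breaks ties. A two-line illustration (not from the patch):

assert (4, 100) > (3, 999)  # a later topological ordering wins outright
assert (3, 101) > (3, 100)  # at the same depth, stream ordering breaks the tie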
|
||||
|
||||
@cached(max_entries=5000)
|
||||
async def get_event_ordering(self, event_id):
|
||||
async def get_event_ordering(self, event_id: str) -> Tuple[int, int]:
|
||||
res = await self.db_pool.simple_select_one(
|
||||
table="events",
|
||||
retcols=["topological_ordering", "stream_ordering"],
|
||||
|
@ -1539,7 +1673,9 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
None otherwise.
|
||||
"""
|
||||
|
||||
def get_next_event_to_expire_txn(txn):
|
||||
def get_next_event_to_expire_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Optional[Tuple[str, int]]:
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT event_id, expiry_ts FROM event_expiry
|
||||
|
@ -1547,7 +1683,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
"""
|
||||
)
|
||||
|
||||
return txn.fetchone()
|
||||
return cast(Optional[Tuple[str, int]], txn.fetchone())
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
|
||||
|
@ -1611,10 +1747,10 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
return mapping
|
||||
|
||||
@wrap_as_background_process("_cleanup_old_transaction_ids")
|
||||
async def _cleanup_old_transaction_ids(self):
|
||||
async def _cleanup_old_transaction_ids(self) -> None:
|
||||
"""Cleans out transaction id mappings older than 24hrs."""
|
||||
|
||||
def _cleanup_old_transaction_ids_txn(txn):
|
||||
def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
|
||||
sql = """
|
||||
DELETE FROM event_txn_id
|
||||
WHERE inserted_ts < ?
|
||||
|
@ -1626,3 +1762,198 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
"_cleanup_old_transaction_ids",
|
||||
_cleanup_old_transaction_ids_txn,
|
||||
)
|
||||
|
||||
async def is_event_next_to_backward_gap(self, event: EventBase) -> bool:
"""Check if the given event is next to a backward gap of missing events.
<latest messages> A(False)--->B(False)--->C(True)---> <gap, unknown events> <oldest messages>
Args:
event: the event to check
Returns:
Boolean indicating whether the event is next to a backward gap of missing events
"""
|
||||
|
||||
def is_event_next_to_backward_gap_txn(txn: LoggingTransaction) -> bool:
|
||||
# If the event in question has any of its prev_events listed as a
|
||||
# backward extremity, it's next to a gap.
|
||||
#
|
||||
# We can't just check the backward edges in `event_edges` because
|
||||
# when we persist events, we will also record the prev_events as
|
||||
# edges to the event in question regardless of whether we have those
|
||||
# prev_events yet. We need to check whether those prev_events are
|
||||
# backward extremities, also known as gaps, that need to be
|
||||
# backfilled.
|
||||
backward_extremity_query = """
|
||||
SELECT 1 FROM event_backward_extremities
|
||||
WHERE
|
||||
room_id = ?
|
||||
AND %s
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
# If the event in question is a backward extremity or has any of its
|
||||
# prev_events listed as a backward extremity, it's next to a
|
||||
# backward gap.
|
||||
clause, args = make_in_list_sql_clause(
|
||||
self.database_engine,
|
||||
"event_id",
|
||||
[event.event_id] + list(event.prev_event_ids()),
|
||||
)
|
||||
|
||||
txn.execute(backward_extremity_query % (clause,), [event.room_id] + args)
|
||||
backward_extremities = txn.fetchall()
|
||||
|
||||
# We consider any backward extremity as a backward gap
|
||||
if len(backward_extremities):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"is_event_next_to_backward_gap_txn",
|
||||
is_event_next_to_backward_gap_txn,
|
||||
)
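The check above can be restated in plain Python for clarity; this is an illustrative sketch over in-memory sets, not the database query the patch actually performs:

def next_to_backward_gap(event_id: str, prev_event_ids: list, backward_extremities: set) -> bool:
    # The event is next to a backward gap if it, or any of its prev_events,
    # is currently recorded as a backward extremity of the room.
    return bool({event_id, *prev_event_ids} & backward_extremities)

assert next_to_backward_gap("$c", ["$b"], {"$b"})      # prev_event marks missing history
assert not next_to_backward_gap("$c", ["$b"], set())   # nothing to backfill behind $c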
|
||||
|
||||
async def is_event_next_to_forward_gap(self, event: EventBase) -> bool:
"""Check if the given event is next to a forward gap of missing events.
The gap in front of the latest events is not considered a gap.
<latest messages> A(False)--->B(False)--->C(False)---> <gap, unknown events> <oldest messages>
<latest messages> A(False)--->B(False)---> <gap, unknown events> --->D(True)--->E(False) <oldest messages>
Args:
event: the event to check
Returns:
Boolean indicating whether the event is next to a forward gap of missing events
"""
|
||||
|
||||
def is_event_next_to_gap_txn(txn: LoggingTransaction) -> bool:
|
||||
# If the event in question is a forward extremity, we will just
|
||||
# consider any potential forward gap as not a gap since it's one of
|
||||
# the latest events in the room.
|
||||
#
|
||||
# `event_forward_extremities` does not include backfilled or outlier
|
||||
# events so we can't rely on it to find forward gaps. We can only
|
||||
# use it to determine whether a message is the latest in the room.
|
||||
#
|
||||
# We can't combine this query with the `forward_edge_query` below
|
||||
# because if the event in question has no forward edges (isn't
|
||||
# referenced by any other event's prev_events) but is in
|
||||
# `event_forward_extremities`, we don't want to return 0 rows and
|
||||
# say it's next to a gap.
|
||||
forward_extremity_query = """
|
||||
SELECT 1 FROM event_forward_extremities
|
||||
WHERE
|
||||
room_id = ?
|
||||
AND event_id = ?
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
# Check to see whether the event in question is already referenced
|
||||
# by another event. If we don't see any edges, we're next to a
|
||||
# forward gap.
|
||||
forward_edge_query = """
|
||||
SELECT 1 FROM event_edges
|
||||
/* Check to make sure the event referencing our event in question is not rejected */
|
||||
LEFT JOIN rejections ON event_edges.event_id = rejections.event_id
|
||||
WHERE
|
||||
event_edges.room_id = ?
|
||||
AND event_edges.prev_event_id = ?
|
||||
/* It's not a valid edge if the event referencing our event in
|
||||
* question is rejected.
|
||||
*/
|
||||
AND rejections.event_id IS NULL
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
# We consider any forward extremity as the latest in the room and
|
||||
# not a forward gap.
|
||||
#
|
||||
# To expand, even though there is technically a gap at the front of
|
||||
# the room where the forward extremities are, we consider those the
|
||||
# latest messages in the room so asking other homeservers for more
|
||||
# is useless. The new latest messages will just be federated as
|
||||
# usual.
|
||||
txn.execute(forward_extremity_query, (event.room_id, event.event_id))
|
||||
forward_extremities = txn.fetchall()
|
||||
if len(forward_extremities):
|
||||
return False
|
||||
|
||||
# If there are no forward edges to the event in question (another
|
||||
# event hasn't referenced this event in their prev_events), then we
|
||||
# assume there is a forward gap in the history.
|
||||
txn.execute(forward_edge_query, (event.room_id, event.event_id))
|
||||
forward_edges = txn.fetchall()
|
||||
if not len(forward_edges):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"is_event_next_to_gap_txn",
|
||||
is_event_next_to_gap_txn,
|
||||
)
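A hypothetical caller might combine the two helpers to decide whether to ask other homeservers for missing history; `store` and `event` are assumed to be an EventsWorkerStore and an EventBase, and the helper name is invented for the example:

async def should_backfill_around(store, event) -> bool:
    # If the event borders missing history in either direction, the caller
    # may want to trigger backfill before serving it to clients.
    return await store.is_event_next_to_backward_gap(event) or await store.is_event_next_to_forward_gap(event)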
|
||||
|
||||
async def get_event_id_for_timestamp(
self, room_id: str, timestamp: int, direction: str
) -> Optional[str]:
"""Find the closest event to the given timestamp in the given direction.
Args:
room_id: Room to fetch the event from
timestamp: The point in time (inclusive, in milliseconds since the epoch) we
should navigate from in the given direction to find the closest event.
direction: ["f"|"b"] to indicate whether we should navigate forward
or backward from the given timestamp to find the closest event.
Returns:
The closest event_id, or None if we can't find any event in
the given direction.
"""
|
||||
|
||||
sql_template = """
|
||||
SELECT event_id FROM events
|
||||
LEFT JOIN rejections USING (event_id)
|
||||
WHERE
|
||||
origin_server_ts %s ?
|
||||
AND room_id = ?
|
||||
/* Make sure event is not rejected */
|
||||
AND rejections.event_id IS NULL
|
||||
ORDER BY origin_server_ts %s
|
||||
LIMIT 1;
|
||||
"""
|
||||
|
||||
def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]:
|
||||
if direction == "b":
|
||||
# Find closest event *before* a given timestamp. We use descending
|
||||
# (which gives values largest to smallest) because we want the
|
||||
# largest possible timestamp *before* the given timestamp.
|
||||
comparison_operator = "<="
|
||||
order = "DESC"
|
||||
else:
|
||||
# Find closest event *after* a given timestamp. We use ascending
|
||||
# (which gives values smallest to largest) because we want the
|
||||
# closest possible timestamp *after* the given timestamp.
|
||||
comparison_operator = ">="
|
||||
order = "ASC"
|
||||
|
||||
txn.execute(
|
||||
sql_template % (comparison_operator, order), (timestamp, room_id)
|
||||
)
|
||||
row = txn.fetchone()
|
||||
if row:
|
||||
(event_id,) = row
|
||||
return event_id
|
||||
|
||||
return None
|
||||
|
||||
if direction not in ("f", "b"):
|
||||
raise ValueError("Unknown direction: %s" % (direction,))
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_event_id_for_timestamp_txn",
|
||||
get_event_id_for_timestamp_txn,
|
||||
)
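A hypothetical usage sketch, assuming `store` is an EventsWorkerStore and `room_id` names a room the server knows about: find the closest event at or after midnight UTC on 1 January 2022.

import datetime

async def closest_event_after_new_year(store, room_id: str):
    # origin_server_ts is stored in milliseconds since the epoch.
    ts = int(datetime.datetime(2022, 1, 1, tzinfo=datetime.timezone.utc).timestamp() * 1000)
    return await store.get_event_id_for_timestamp(room_id, ts, "f")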
|
||||
|
|
|
@ -118,7 +118,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
|
|||
|
||||
logger.info("[purge] looking for events to delete")
|
||||
|
||||
should_delete_expr = "state_key IS NULL"
|
||||
should_delete_expr = "state_events.state_key IS NULL"
|
||||
should_delete_params: Tuple[Any, ...] = ()
|
||||
if not delete_local_events:
|
||||
should_delete_expr += " AND event_id NOT LIKE ?"
|
||||
|
|
|
@ -28,7 +28,10 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
|
|||
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
|
||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||
from synapse.storage.util.id_generators import (
|
||||
AbstractStreamIdTracker,
|
||||
StreamIdGenerator,
|
||||
)
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
@ -82,9 +85,9 @@ class PushRulesWorkerStore(
|
|||
super().__init__(database, db_conn, hs)
|
||||
|
||||
if hs.config.worker.worker_app is None:
|
||||
self._push_rules_stream_id_gen: Union[
|
||||
StreamIdGenerator, SlavedIdTracker
|
||||
] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id")
|
||||
self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
|
||||
db_conn, "push_rules_stream", "stream_id"
|
||||
)
|
||||
else:
|
||||
self._push_rules_stream_id_gen = SlavedIdTracker(
|
||||
db_conn, "push_rules_stream", "stream_id"
|
||||
|
|
|
@ -106,6 +106,15 @@ class RefreshTokenLookupResult:
|
|||
has_next_access_token_been_used: bool
|
||||
"""True if the next access token was already used at least once."""
|
||||
|
||||
expiry_ts: Optional[int]
"""The time (in milliseconds since the epoch) at which the refresh token expires
and can no longer be used. If None, the refresh token doesn't expire."""
ultimate_session_expiry_ts: Optional[int]
"""The time (in milliseconds since the epoch) at which the session comes to an end
and can no longer be refreshed. If None, the session can be refreshed indefinitely."""
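A minimal sketch of how these two timestamps might be applied when deciding whether a refresh is still allowed; the helper is illustrative and not part of the patch:

import time
from typing import Optional

def refresh_allowed(expiry_ts: Optional[int], ultimate_session_expiry_ts: Optional[int]) -> bool:
    now_ms = int(time.time() * 1000)
    if expiry_ts is not None and now_ms >= expiry_ts:
        return False  # the refresh token itself has expired
    if ultimate_session_expiry_ts is not None and now_ms >= ultimate_session_expiry_ts:
        return False  # the overall session has reached its hard end time
    return True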
|
||||
|
||||
|
||||
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||
def __init__(
|
||||
|
@ -1626,8 +1635,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
rt.user_id,
|
||||
rt.device_id,
|
||||
rt.next_token_id,
|
||||
(nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed,
|
||||
at.used has_next_access_token_been_used
|
||||
(nrt.next_token_id IS NOT NULL) AS has_next_refresh_token_been_refreshed,
|
||||
at.used AS has_next_access_token_been_used,
|
||||
rt.expiry_ts,
|
||||
rt.ultimate_session_expiry_ts
|
||||
FROM refresh_tokens rt
|
||||
LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id
|
||||
LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id
|
||||
|
@ -1648,6 +1659,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
has_next_refresh_token_been_refreshed=row[4],
|
||||
# This column is nullable, ensure it's a boolean
|
||||
has_next_access_token_been_used=(row[5] or False),
|
||||
expiry_ts=row[6],
|
||||
ultimate_session_expiry_ts=row[7],
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
|
@ -1915,6 +1928,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
user_id: str,
|
||||
token: str,
|
||||
device_id: Optional[str],
|
||||
expiry_ts: Optional[int],
|
||||
ultimate_session_expiry_ts: Optional[int],
|
||||
) -> int:
|
||||
"""Adds a refresh token for the given user.
|
||||
|
||||
|
@ -1922,6 +1937,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
user_id: The user ID.
|
||||
token: The new access token to add.
|
||||
device_id: ID of the device to associate with the refresh token.
|
||||
expiry_ts (milliseconds since the epoch): Time after which the
|
||||
refresh token cannot be used.
|
||||
If None, the refresh token never expires until it has been used.
|
||||
ultimate_session_expiry_ts (milliseconds since the epoch):
|
||||
Time at which the session will end and can not be extended any
|
||||
further.
|
||||
If None, the session can be refreshed indefinitely.
|
||||
Raises:
|
||||
StoreError if there was a problem adding this.
|
||||
Returns:
|
||||
|
@ -1937,6 +1959,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
|||
"device_id": device_id,
|
||||
"token": token,
|
||||
"next_token_id": None,
|
||||
"expiry_ts": expiry_ts,
|
||||
"ultimate_session_expiry_ts": ultimate_session_expiry_ts,
|
||||
},
|
||||
desc="add_refresh_token_to_user",
|
||||
)
|
||||
|
|
|
@ -476,7 +476,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
INNER JOIN events AS e USING (room_id, event_id)
|
||||
WHERE
|
||||
c.type = 'm.room.member'
|
||||
AND state_key = ?
|
||||
AND c.state_key = ?
|
||||
AND c.membership = ?
|
||||
"""
|
||||
else:
|
||||
|
@ -487,7 +487,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
INNER JOIN events AS e USING (room_id, event_id)
|
||||
WHERE
|
||||
c.type = 'm.room.member'
|
||||
AND state_key = ?
|
||||
AND c.state_key = ?
|
||||
AND m.membership = ?
|
||||
"""
|
||||
|
||||
|
|
|
@ -497,7 +497,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
|||
oldest `limit` events.
|
||||
|
||||
Returns:
|
||||
The list of events (in ascending order) and the token from the start
|
||||
The list of events (in ascending stream order) and the token from the start
|
||||
of the chunk of events returned.
|
||||
"""
|
||||
if from_key == to_key:
|
||||
|
@ -510,7 +510,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
|||
if not has_changed:
|
||||
return [], from_key
|
||||
|
||||
def f(txn):
|
||||
def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
|
||||
# To handle tokens with a non-empty instance_map we fetch more
|
||||
# results than necessary and then filter down
|
||||
min_from_id = from_key.stream
|
||||
|
@ -565,6 +565,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
|||
async def get_membership_changes_for_user(
|
||||
self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
|
||||
) -> List[EventBase]:
|
||||
"""Fetch membership events for a given user.
|
||||
|
||||
All such events whose stream ordering `s` lies in the range
|
||||
`from_key < s <= to_key` are returned. Events are ordered by ascending stream
|
||||
order.
|
||||
"""
|
||||
# Start by ruling out cases where a DB query is not necessary.
|
||||
if from_key == to_key:
|
||||
return []
|
||||
|
||||
|
@ -575,7 +582,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
|||
if not has_changed:
|
||||
return []
|
||||
|
||||
def f(txn):
|
||||
def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
|
||||
# To handle tokens with a non-empty instance_map we fetch more
|
||||
# results than necessary and then filter down
|
||||
min_from_id = from_key.stream
|
||||
|
@ -634,7 +641,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
|||
|
||||
Returns:
|
||||
A list of events and a token pointing to the start of the returned
|
||||
events. The events returned are in ascending order.
|
||||
events. The events returned are in ascending topological order.
|
||||
"""
|
||||
|
||||
rows, token = await self.get_recent_event_ids_for_room(
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
|
@ -44,6 +45,16 @@ _UpdateTransactionRow = namedtuple(
|
|||
)
|
||||
|
||||
|
||||
class DestinationSortOrder(Enum):
|
||||
"""Enum to define the sorting method used when returning destinations."""
|
||||
|
||||
DESTINATION = "destination"
|
||||
RETRY_LAST_TS = "retry_last_ts"
|
||||
RETTRY_INTERVAL = "retry_interval"
|
||||
FAILURE_TS = "failure_ts"
|
||||
LAST_SUCCESSFUL_STREAM_ORDERING = "last_successful_stream_ordering"
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class DestinationRetryTimings:
|
||||
"""The current destination retry timing info for a remote server."""
|
||||
|
@ -480,3 +491,62 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
|||
|
||||
destinations = [row[0] for row in txn]
|
||||
return destinations
|
||||
|
||||
async def get_destinations_paginate(
|
||||
self,
|
||||
start: int,
|
||||
limit: int,
|
||||
destination: Optional[str] = None,
|
||||
order_by: str = DestinationSortOrder.DESTINATION.value,
|
||||
direction: str = "f",
|
||||
) -> Tuple[List[JsonDict], int]:
|
||||
"""Function to retrieve a paginated list of destinations.
|
||||
This will return a json list of destinations and the
|
||||
total number of destinations matching the filter criteria.
|
||||
|
||||
Args:
|
||||
start: start number to begin the query from
|
||||
limit: number of rows to retrieve
|
||||
destination: search string in destination
|
||||
order_by: the sort order of the returned list
|
||||
direction: sort ascending or descending
|
||||
Returns:
|
||||
A tuple of a list of mappings from destination to information
|
||||
and a count of total destinations.
|
||||
"""
|
||||
|
||||
def get_destinations_paginate_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Tuple[List[JsonDict], int]:
|
||||
order_by_column = DestinationSortOrder(order_by).value
|
||||
|
||||
if direction == "b":
|
||||
order = "DESC"
|
||||
else:
|
||||
order = "ASC"
|
||||
|
||||
args = []
|
||||
where_statement = ""
|
||||
if destination:
|
||||
args.extend(["%" + destination.lower() + "%"])
|
||||
where_statement = "WHERE LOWER(destination) LIKE ?"
|
||||
|
||||
sql_base = f"FROM destinations {where_statement} "
|
||||
sql = f"SELECT COUNT(*) as total_destinations {sql_base}"
|
||||
txn.execute(sql, args)
|
||||
count = txn.fetchone()[0]
|
||||
|
||||
sql = f"""
|
||||
SELECT destination, retry_last_ts, retry_interval, failure_ts,
|
||||
last_successful_stream_ordering
|
||||
{sql_base}
|
||||
ORDER BY {order_by_column} {order}, destination ASC
|
||||
LIMIT ? OFFSET ?
|
||||
"""
|
||||
txn.execute(sql, args + [limit, start])
|
||||
destinations = self.db_pool.cursor_to_dict(txn)
|
||||
return destinations, count
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_destinations_paginate_txn", get_destinations_paginate_txn
|
||||
)
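A hypothetical admin-style caller, assuming `store` is a TransactionWorkerStore and DestinationSortOrder is imported from the same module: page through destinations matching "example", most recently retried first.

async def list_matching_destinations(store):
    destinations, total = await store.get_destinations_paginate(
        start=0,
        limit=25,
        destination="example",
        order_by=DestinationSortOrder.RETRY_LAST_TS.value,
        direction="b",  # descending
    )
    return destinations, total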
|
||||
|
|
|
@ -583,7 +583,8 @@ class EventsPersistenceStorage:
|
|||
current_state_for_room=current_state_for_room,
|
||||
state_delta_for_room=state_delta_for_room,
|
||||
new_forward_extremeties=new_forward_extremeties,
|
||||
backfilled=backfilled,
|
||||
use_negative_stream_ordering=backfilled,
|
||||
inhibit_local_membership_updates=backfilled,
|
||||
)
|
||||
|
||||
await self._handle_potentially_left_users(potentially_left_users)
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
SCHEMA_VERSION = 65 # remember to update the list below when updating
|
||||
SCHEMA_VERSION = 66 # remember to update the list below when updating
|
||||
"""Represents the expectations made by the codebase about the database schema
|
||||
|
||||
This should be incremented whenever the codebase changes its requirements on the
|
||||
|
@ -46,6 +46,10 @@ Changes in SCHEMA_VERSION = 65:
|
|||
- MSC2716: Remove unique event_id constraint from insertion_event_edges
|
||||
because an insertion event can have multiple edges.
|
||||
- Remove unused tables `user_stats_historical` and `room_stats_historical`.
|
||||
|
||||
Changes in SCHEMA_VERSION = 66:
|
||||
- Queries on state_key columns are now disambiguated (ie, the codebase can handle
|
||||
the `events` table having a `state_key` column).
|
||||
"""
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
/* Copyright 2021 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
ALTER TABLE refresh_tokens
|
||||
-- We add an expiry_ts column (in milliseconds since the Epoch) to refresh tokens.
|
||||
-- They may not be used after they have expired.
|
||||
-- If null, then the refresh token's lifetime is unlimited.
|
||||
ADD COLUMN expiry_ts BIGINT DEFAULT NULL;
|
||||
|
||||
ALTER TABLE refresh_tokens
|
||||
-- We also add an ultimate session expiry time (in milliseconds since the Epoch).
|
||||
-- No matter how much the access and refresh tokens are refreshed, they cannot
|
||||
-- be extended past this time.
|
||||
-- If null, then the session length is unlimited.
|
||||
ADD COLUMN ultimate_session_expiry_ts BIGINT DEFAULT NULL;
|
|
@ -0,0 +1,27 @@
|
|||
/* Copyright 2021 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- Track the auth provider used by each login as well as the session ID
|
||||
CREATE TABLE device_auth_providers (
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
auth_provider_id TEXT NOT NULL,
|
||||
auth_provider_session_id TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX device_auth_providers_devices
|
||||
ON device_auth_providers (user_id, device_id);
|
||||
CREATE INDEX device_auth_providers_sessions
|
||||
ON device_auth_providers (auth_provider_id, auth_provider_session_id);
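To make the shape of the new table concrete, here is an illustrative sketch using the standard library's sqlite3 driver rather than Synapse's DatabasePool; the user, device, and session values are made up.

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE device_auth_providers (
        user_id TEXT NOT NULL,
        device_id TEXT NOT NULL,
        auth_provider_id TEXT NOT NULL,
        auth_provider_session_id TEXT NOT NULL
    )
    """
)
conn.execute(
    "INSERT INTO device_auth_providers VALUES (?, ?, ?, ?)",
    ("@alice:example.org", "ABCDEFGH", "oidc-keycloak", "sid-123"),
)
# Look up which SSO session a device came from, as the new indexes support.
row = conn.execute(
    "SELECT auth_provider_id, auth_provider_session_id FROM device_auth_providers"
    " WHERE user_id = ? AND device_id = ?",
    ("@alice:example.org", "ABCDEFGH"),
).fetchone()
assert row == ("oidc-keycloak", "sid-123")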
|
|
@ -89,31 +89,77 @@ def _load_current_id(
|
|||
return (max if step > 0 else min)(current_id, step)
|
||||
|
||||
|
||||
class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def get_next(self) -> AsyncContextManager[int]:
|
||||
raise NotImplementedError()
|
||||
class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
"""Tracks the "current" stream ID of a stream that may have multiple writers.
Stream IDs are monotonically increasing or decreasing integers representing write
transactions. The "current" stream ID is the stream ID such that all transactions
with equal or smaller stream IDs have completed. Since transactions may complete out
of order, this is not the same as the stream ID of the last completed transaction.
Completed transactions include both committed transactions and transactions that
have been rolled back.
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
|
||||
def advance(self, instance_name: str, new_id: int) -> None:
|
||||
"""Advance the position of the named writer to the given ID, if greater
|
||||
than existing entry.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_current_token(self) -> int:
|
||||
"""Returns the maximum stream id such that all stream ids less than or
|
||||
equal to it have been successfully persisted.
|
||||
|
||||
Returns:
|
||||
The maximum stream id.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_current_token_for_writer(self, instance_name: str) -> int:
|
||||
"""Returns the position of the given writer.
|
||||
|
||||
For streams with single writers this is equivalent to `get_current_token`.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class AbstractStreamIdGenerator(AbstractStreamIdTracker):
|
||||
"""Generates stream IDs for a stream that may have multiple writers.
|
||||
|
||||
Each stream ID represents a write transaction, whose completion is tracked
|
||||
so that the "current" stream ID of the stream can be determined.
|
||||
|
||||
See `AbstractStreamIdTracker` for more details.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_next(self) -> AsyncContextManager[int]:
|
||||
"""
|
||||
Usage:
|
||||
async with stream_id_gen.get_next() as stream_id:
|
||||
# ... persist event ...
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
|
||||
"""
|
||||
Usage:
|
||||
async with stream_id_gen.get_next(n) as stream_ids:
|
||||
# ... persist events ...
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class StreamIdGenerator(AbstractStreamIdGenerator):
|
||||
"""Used to generate new stream ids when persisting events while keeping
|
||||
track of which transactions have been completed.
|
||||
"""Generates and tracks stream IDs for a stream with a single writer.
|
||||
|
||||
This allows us to get the "current" stream id, i.e. the stream id such that
|
||||
all ids less than or equal to it have completed. This handles the fact that
|
||||
persistence of events can complete out of order.
|
||||
This class must only be used when the current Synapse process is the sole
|
||||
writer for a stream.
|
||||
|
||||
Args:
|
||||
db_conn(connection): A database connection to use to fetch the
|
||||
|
@ -157,12 +203,12 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
|
|||
# The key and values are the same, but we never look at the values.
|
||||
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
|
||||
|
||||
def advance(self, instance_name: str, new_id: int) -> None:
|
||||
# `StreamIdGenerator` should only be used when there is a single writer,
|
||||
# so replication should never happen.
|
||||
raise Exception("Replication is not supported by StreamIdGenerator")
|
||||
|
||||
def get_next(self) -> AsyncContextManager[int]:
|
||||
"""
|
||||
Usage:
|
||||
async with stream_id_gen.get_next() as stream_id:
|
||||
# ... persist event ...
|
||||
"""
|
||||
with self._lock:
|
||||
self._current += self._step
|
||||
next_id = self._current
|
||||
|
@ -180,11 +226,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
|
|||
return _AsyncCtxManagerWrapper(manager())
|
||||
|
||||
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
|
||||
"""
|
||||
Usage:
|
||||
async with stream_id_gen.get_next(n) as stream_ids:
|
||||
# ... persist events ...
|
||||
"""
|
||||
with self._lock:
|
||||
next_ids = range(
|
||||
self._current + self._step,
|
||||
|
@ -208,12 +249,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
|
|||
return _AsyncCtxManagerWrapper(manager())
|
||||
|
||||
def get_current_token(self) -> int:
|
||||
"""Returns the maximum stream id such that all stream ids less than or
|
||||
equal to it have been successfully persisted.
|
||||
|
||||
Returns:
|
||||
The maximum stream id.
|
||||
"""
|
||||
with self._lock:
|
||||
if self._unfinished_ids:
|
||||
return next(iter(self._unfinished_ids)) - self._step
|
||||
|
@ -221,16 +256,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
|
|||
return self._current
|
||||
|
||||
def get_current_token_for_writer(self, instance_name: str) -> int:
|
||||
"""Returns the position of the given writer.
|
||||
|
||||
For streams with single writers this is equivalent to
|
||||
`get_current_token`.
|
||||
"""
|
||||
return self.get_current_token()
|
||||
|
||||
|
||||
class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
||||
"""An ID generator that tracks a stream that can have multiple writers.
|
||||
"""Generates and tracks stream IDs for a stream with multiple writers.
|
||||
|
||||
Uses a Postgres sequence to coordinate ID assignment, but positions of other
|
||||
writers will only get updated when `advance` is called (by replication).
|
||||
|
@ -475,12 +505,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
|||
return stream_ids
|
||||
|
||||
def get_next(self) -> AsyncContextManager[int]:
|
||||
"""
|
||||
Usage:
|
||||
async with stream_id_gen.get_next() as stream_id:
|
||||
# ... persist event ...
|
||||
"""
|
||||
|
||||
# If we have a list of instances that are allowed to write to this
|
||||
# stream, make sure we're in it.
|
||||
if self._writers and self._instance_name not in self._writers:
|
||||
|
@ -492,12 +516,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
|||
return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
|
||||
|
||||
def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
|
||||
"""
|
||||
Usage:
|
||||
async with stream_id_gen.get_next_mult(5) as stream_ids:
|
||||
# ... persist events ...
|
||||
"""
|
||||
|
||||
# If we have a list of instances that are allowed to write to this
|
||||
# stream, make sure we're in it.
|
||||
if self._writers and self._instance_name not in self._writers:
|
||||
|
@ -597,15 +615,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
|||
self._add_persisted_position(next_id)
|
||||
|
||||
def get_current_token(self) -> int:
|
||||
"""Returns the maximum stream id such that all stream ids less than or
|
||||
equal to it have been successfully persisted.
|
||||
"""
|
||||
|
||||
return self.get_persisted_upto_position()
|
||||
|
||||
def get_current_token_for_writer(self, instance_name: str) -> int:
|
||||
"""Returns the position of the given writer."""
|
||||
|
||||
# If we don't have an entry for the given instance name, we assume it's a
|
||||
# new writer.
|
||||
#
|
||||
|
@ -631,10 +643,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
|||
}
|
||||
|
||||
def advance(self, instance_name: str, new_id: int) -> None:
|
||||
"""Advance the position of the named writer to the given ID, if greater
|
||||
than existing entry.
|
||||
"""
|
||||
|
||||
new_id *= self._return_factor
|
||||
|
||||
with self._lock:
|
||||
|
|