Add type hints to synapse/storage/databases/main/events_worker.py (#11411)
Also refactor the stream ID trackers/generators a bit and try to document them better.
Parent commit: 1d8b80b334
This commit: ffd858aa68
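Note: the refactoring splits the stream-ID machinery into a read-only "tracker" interface (AbstractStreamIdTracker, which SlavedIdTracker now implements) and a writer-capable "generator" interface (AbstractStreamIdGenerator). The following is a minimal, illustrative sketch of that split and of how a read-only worker advances its tracker from replication traffic; the toy classes and wiring are assumptions for illustration, not Synapse's actual code:

    import abc


    class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
        # Read-only view of a stream: can be advanced from replication traffic
        # and reports the "current" stream ID.
        @abc.abstractmethod
        def advance(self, instance_name: str, new_id: int) -> None:
            raise NotImplementedError()

        @abc.abstractmethod
        def get_current_token(self) -> int:
            raise NotImplementedError()


    class AbstractStreamIdGenerator(AbstractStreamIdTracker):
        # Writer view: can also hand out new stream IDs (simplified signature;
        # the real method returns an async context manager).
        @abc.abstractmethod
        def get_next(self) -> int:
            raise NotImplementedError()


    class ToyTracker(AbstractStreamIdTracker):
        # Hypothetical tracker a read-only worker might hold.
        def __init__(self) -> None:
            self._current = 0

        def advance(self, instance_name: str, new_id: int) -> None:
            # Remember the highest stream ID any writer has reported.
            self._current = max(self._current, new_id)

        def get_current_token(self) -> int:
            return self._current


    tracker = ToyTracker()
    tracker.advance("master", 42)  # a worker learns about a write via replication
    print(tracker.get_current_token())  # 42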
changelog.d/11411.misc (new file, 1 line)

@@ -0,0 +1 @@
+Add type hints to storage classes.
mypy.ini (4 changes)

@@ -33,7 +33,6 @@ exclude = (?x)
   |synapse/storage/databases/main/event_federation.py
   |synapse/storage/databases/main/event_push_actions.py
   |synapse/storage/databases/main/events_bg_updates.py
-  |synapse/storage/databases/main/events_worker.py
   |synapse/storage/databases/main/group_server.py
   |synapse/storage/databases/main/metrics.py
   |synapse/storage/databases/main/monthly_active_users.py

@@ -184,6 +183,9 @@ disallow_untyped_defs = True
 [mypy-synapse.storage.databases.main.directory]
 disallow_untyped_defs = True
 
+[mypy-synapse.storage.databases.main.events_worker]
+disallow_untyped_defs = True
+
 [mypy-synapse.storage.databases.main.room_batch]
 disallow_untyped_defs = True
 
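For context, removing events_worker.py from mypy's exclude list and adding a per-module section with disallow_untyped_defs = True makes mypy reject any function in that module that lacks parameter or return annotations, which is what drives the annotations added below. A hypothetical illustration (not taken from the Synapse codebase):

    # With disallow_untyped_defs = True for this module, mypy reports:
    #   error: Function is missing a type annotation
    def get_event(event_id):
        return {"event_id": event_id}


    # Fully annotated, so it passes the stricter check:
    def get_event_typed(event_id: str) -> dict:
        return {"event_id": event_id}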
synapse/replication/slave/storage/_slaved_id_tracker.py

@@ -14,10 +14,18 @@
 from typing import List, Optional, Tuple
 
 from synapse.storage.database import LoggingDatabaseConnection
-from synapse.storage.util.id_generators import _load_current_id
+from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id
 
 
-class SlavedIdTracker:
+class SlavedIdTracker(AbstractStreamIdTracker):
+    """Tracks the "current" stream ID of a stream with a single writer.
+
+    See `AbstractStreamIdTracker` for more details.
+
+    Note that this class does not work correctly when there are multiple
+    writers.
+    """
+
     def __init__(
         self,
         db_conn: LoggingDatabaseConnection,

@@ -36,17 +44,7 @@ class SlavedIdTracker:
         self._current = (max if self.step > 0 else min)(self._current, new_id)
 
     def get_current_token(self) -> int:
-        """
-
-        Returns:
-            int
-        """
         return self._current
 
     def get_current_token_for_writer(self, instance_name: str) -> int:
-        """Returns the position of the given writer.
-
-        For streams with single writers this is equivalent to
-        `get_current_token`.
-        """
         return self.get_current_token()
synapse/replication/slave/storage/push_rule.py

@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import PushRulesStream
 from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
 

@@ -25,9 +24,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
         return self._push_rules_stream_id_gen.get_current_token()
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
-        # We assert this for the benefit of mypy
-        assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
-
         if stream_name == PushRulesStream.NAME:
             self._push_rules_stream_id_gen.advance(instance_name, token)
             for row in rows:
synapse/replication/tcp/streams/events.py

@@ -14,7 +14,7 @@
 # limitations under the License.
 import heapq
 from collections.abc import Iterable
-from typing import TYPE_CHECKING, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Optional, Tuple, Type
 
 import attr
 

@@ -157,7 +157,7 @@ class EventsStream(Stream):
 
         # now we fetch up to that many rows from the events table
 
-        event_rows: List[Tuple] = await self._store.get_all_new_forward_event_rows(
+        event_rows = await self._store.get_all_new_forward_event_rows(
            instance_name, from_token, current_token, target_row_count
        )
 

@@ -191,7 +191,7 @@ class EventsStream(Stream):
        # finally, fetch the ex-outliers rows. We assume there are few enough of these
        # not to bother with the limit.
 
-        ex_outliers_rows: List[Tuple] = await self._store.get_ex_outlier_stream_rows(
+        ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
            instance_name, from_token, upper_limit
        )
 
synapse/state/__init__.py

@@ -764,7 +764,7 @@ class StateResolutionStore:
     store: "DataStore"
 
     def get_events(
-        self, event_ids: Iterable[str], allow_rejected: bool = False
+        self, event_ids: Collection[str], allow_rejected: bool = False
     ) -> Awaitable[Dict[str, EventBase]]:
         """Get events from the database
 
synapse/state/v1.py

@@ -17,6 +17,7 @@ import logging
 from typing import (
     Awaitable,
     Callable,
+    Collection,
     Dict,
     Iterable,
     List,

@@ -44,7 +45,7 @@ async def resolve_events_with_store(
     room_version: RoomVersion,
     state_sets: Sequence[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
-    state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
+    state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]],
 ) -> StateMap[str]:
     """
     Args:
synapse/storage/_base.py

@@ -21,7 +21,7 @@ from synapse.storage.database import LoggingTransaction  # noqa: F401
 from synapse.storage.database import make_in_list_sql_clause  # noqa: F401
 from synapse.storage.database import DatabasePool
 from synapse.storage.types import Connection
-from synapse.types import StreamToken, get_domain_from_id
+from synapse.types import get_domain_from_id
 from synapse.util import json_decoder
 
 if TYPE_CHECKING:

@@ -48,7 +48,7 @@ class SQLBaseStore(metaclass=ABCMeta):
         self,
         stream_name: str,
         instance_name: str,
-        token: StreamToken,
+        token: int,
         rows: Iterable[Any],
     ) -> None:
         pass
synapse/storage/databases/main/events.py

@@ -15,7 +15,7 @@
 # limitations under the License.
 import itertools
 import logging
-from collections import OrderedDict, namedtuple
+from collections import OrderedDict
 from typing import (
     TYPE_CHECKING,
     Any,

@@ -41,9 +41,10 @@ from synapse.events.snapshot import EventContext  # noqa: F401
 from synapse.logging.utils import log_function
 from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.databases.main.events_worker import EventCacheEntry
 from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.util.id_generators import AbstractStreamIdGenerator
 from synapse.storage.util.sequence import SequenceGenerator
 from synapse.types import StateMap, get_domain_from_id
 from synapse.util import json_encoder

@@ -64,9 +65,6 @@ event_counter = Counter(
 )
 
 
-_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
-
-
 @attr.s(slots=True)
 class DeltaState:
     """Deltas to use to update the `current_state_events` table.

@@ -108,16 +106,21 @@ class PersistEventsStore:
         self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
         self.is_mine_id = hs.is_mine_id
 
-        # Ideally we'd move these ID gens here, unfortunately some other ID
-        # generators are chained off them so doing so is a bit of a PITA.
-        self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
-        self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen
-
         # This should only exist on instances that are configured to write
         assert (
             hs.get_instance_name() in hs.config.worker.writers.events
         ), "Can only instantiate EventsStore on master"
 
+        # Since we have been configured to write, we ought to have id generators,
+        # rather than id trackers.
+        assert isinstance(self.store._backfill_id_gen, AbstractStreamIdGenerator)
+        assert isinstance(self.store._stream_id_gen, AbstractStreamIdGenerator)
+
+        # Ideally we'd move these ID gens here, unfortunately some other ID
+        # generators are chained off them so doing so is a bit of a PITA.
+        self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
+        self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
+
     async def _persist_events_and_state_updates(
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],

@@ -1553,11 +1556,13 @@ class PersistEventsStore:
         for row in rows:
             event = ev_map[row["event_id"]]
             if not row["rejects"] and not row["redacts"]:
-                to_prefill.append(_EventCacheEntry(event=event, redacted_event=None))
+                to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
 
         def prefill():
             for cache_entry in to_prefill:
-                self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
+                self.store._get_event_cache.set(
+                    (cache_entry.event.event_id,), cache_entry
+                )
 
         txn.call_after(prefill)
 
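Note on the assertions added to PersistEventsStore above: the attributes are declared on the worker store with the broader tracker type, so the isinstance asserts are what let mypy narrow them to the generator type on the write path. A toy sketch of that narrowing pattern (the classes here are stand-ins, not the real Synapse ones):

    class Tracker:
        def get_current_token(self) -> int:
            return 0


    class Generator(Tracker):
        def get_next(self) -> int:
            return 1


    class Store:
        # Declared with the broad, read-only type on the base store...
        _stream_id_gen: Tracker = Generator()


    store = Store()
    # Without narrowing, store._stream_id_gen.get_next() is a mypy error,
    # because Tracker has no get_next().
    assert isinstance(store._stream_id_gen, Generator)
    # After the assert, mypy narrows the attribute, so the writer-only API is allowed.
    print(store._stream_id_gen.get_next())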
synapse/storage/databases/main/events_worker.py

@@ -15,14 +15,18 @@
 import logging
 import threading
 from typing import (
+    TYPE_CHECKING,
+    Any,
     Collection,
     Container,
     Dict,
     Iterable,
     List,
+    NoReturn,
     Optional,
     Set,
     Tuple,
+    cast,
     overload,
 )
 

@@ -38,6 +42,7 @@ from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.room_versions import (
     KNOWN_ROOM_VERSIONS,
     EventFormatVersions,
+    RoomVersion,
     RoomVersions,
 )
 from synapse.events import EventBase, make_event_from_dict

@@ -56,10 +61,18 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams import BackfillStream
 from synapse.replication.tcp.streams.events import EventsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.types import Cursor
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdTracker,
+    MultiWriterIdGenerator,
+    StreamIdGenerator,
+)
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import JsonDict, get_domain_from_id
 from synapse.util import unwrapFirstError

@@ -69,6 +82,9 @@ from synapse.util.caches.lrucache import LruCache
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 

@@ -89,7 +105,7 @@ event_fetch_ongoing_gauge = Gauge(
 
 
 @attr.s(slots=True, auto_attribs=True)
-class _EventCacheEntry:
+class EventCacheEntry:
     event: EventBase
     redacted_event: Optional[EventBase]
 

@@ -129,7 +145,7 @@ class _EventRow:
     json: str
     internal_metadata: str
     format_version: Optional[int]
-    room_version_id: Optional[int]
+    room_version_id: Optional[str]
     rejected_reason: Optional[str]
     redactions: List[str]
     outlier: bool

@@ -153,9 +169,16 @@ class EventsWorkerStore(SQLBaseStore):
    # options controlling this.
    USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = True
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
+        self._stream_id_gen: AbstractStreamIdTracker
+        self._backfill_id_gen: AbstractStreamIdTracker
         if isinstance(database.engine, PostgresEngine):
             # If we're using Postgres than we can use `MultiWriterIdGenerator`
             # regardless of whether this process writes to the streams or not.

@@ -214,7 +237,7 @@ class EventsWorkerStore(SQLBaseStore):
             5 * 60 * 1000,
         )
 
-        self._get_event_cache = LruCache(
+        self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache(
             cache_name="*getEvent*",
             max_size=hs.config.caches.event_cache_size,
         )

@@ -223,19 +246,21 @@ class EventsWorkerStore(SQLBaseStore):
         # ID to cache entry. Note that the returned dict may not have the
         # requested event in it if the event isn't in the DB.
         self._current_event_fetches: Dict[
-            str, ObservableDeferred[Dict[str, _EventCacheEntry]]
+            str, ObservableDeferred[Dict[str, EventCacheEntry]]
         ] = {}
 
         self._event_fetch_lock = threading.Condition()
-        self._event_fetch_list = []
+        self._event_fetch_list: List[
+            Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
+        ] = []
         self._event_fetch_ongoing = 0
         event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
 
         # We define this sequence here so that it can be referenced from both
         # the DataStore and PersistEventStore.
-        def get_chain_id_txn(txn):
+        def get_chain_id_txn(txn: Cursor) -> int:
             txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
-            return txn.fetchone()[0]
+            return cast(Tuple[int], txn.fetchone())[0]
 
         self.event_chain_id_gen = build_sequence_generator(
             db_conn,

@@ -246,7 +271,13 @@ class EventsWorkerStore(SQLBaseStore):
             id_column="chain_id",
         )
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+        rows: Iterable[Any],
+    ) -> None:
         if stream_name == EventsStream.NAME:
             self._stream_id_gen.advance(instance_name, token)
         elif stream_name == BackfillStream.NAME:

@@ -280,10 +311,10 @@ class EventsWorkerStore(SQLBaseStore):
         self,
         event_id: str,
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
-        get_prev_content: bool = False,
-        allow_rejected: bool = False,
-        allow_none: Literal[False] = False,
-        check_room_id: Optional[str] = None,
+        get_prev_content: bool = ...,
+        allow_rejected: bool = ...,
+        allow_none: Literal[False] = ...,
+        check_room_id: Optional[str] = ...,
     ) -> EventBase:
         ...
 

@@ -292,10 +323,10 @@ class EventsWorkerStore(SQLBaseStore):
         self,
         event_id: str,
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
-        get_prev_content: bool = False,
-        allow_rejected: bool = False,
-        allow_none: Literal[True] = False,
-        check_room_id: Optional[str] = None,
+        get_prev_content: bool = ...,
+        allow_rejected: bool = ...,
+        allow_none: Literal[True] = ...,
+        check_room_id: Optional[str] = ...,
     ) -> Optional[EventBase]:
         ...
 
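The overloads above replace the literal defaults with `...`, the conventional placeholder for "has a default" in an @overload stub; only the single real implementation carries the actual default values. A minimal sketch of the same pattern with a hypothetical function:

    from typing import Literal, Optional, overload


    @overload
    def fetch_event(event_id: str, allow_none: Literal[False] = ...) -> str: ...
    @overload
    def fetch_event(event_id: str, allow_none: Literal[True] = ...) -> Optional[str]: ...
    def fetch_event(event_id: str, allow_none: bool = False) -> Optional[str]:
        # The implementation carries the real default; the overloads only describe
        # how the return type depends on allow_none.
        if event_id == "missing":
            if allow_none:
                return None
            raise KeyError(event_id)
        return event_id


    print(fetch_event("$abc"))           # typed as str
    print(fetch_event("missing", True))  # typed as Optional[str]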
@@ -357,7 +388,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_events(
         self,
-        event_ids: Iterable[str],
+        event_ids: Collection[str],
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
         get_prev_content: bool = False,
         allow_rejected: bool = False,

@@ -544,7 +575,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def _get_events_from_cache_or_db(
         self, event_ids: Iterable[str], allow_rejected: bool = False
-    ) -> Dict[str, _EventCacheEntry]:
+    ) -> Dict[str, EventCacheEntry]:
         """Fetch a bunch of events from the cache or the database.
 
         If events are pulled from the database, they will be cached for future lookups.

@@ -578,7 +609,7 @@ class EventsWorkerStore(SQLBaseStore):
         # same dict into itself N times).
         already_fetching_ids: Set[str] = set()
         already_fetching_deferreds: Set[
-            ObservableDeferred[Dict[str, _EventCacheEntry]]
+            ObservableDeferred[Dict[str, EventCacheEntry]]
         ] = set()
 
         for event_id in missing_events_ids:

@@ -601,7 +632,7 @@ class EventsWorkerStore(SQLBaseStore):
             # function returning more events than requested, but that can happen
             # already due to `_get_events_from_db`).
             fetching_deferred: ObservableDeferred[
-                Dict[str, _EventCacheEntry]
+                Dict[str, EventCacheEntry]
             ] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
             for event_id in missing_events_ids:
                 self._current_event_fetches[event_id] = fetching_deferred

@@ -658,12 +689,12 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event_entry_map
 
-    def _invalidate_get_event_cache(self, event_id):
+    def _invalidate_get_event_cache(self, event_id: str) -> None:
         self._get_event_cache.invalidate((event_id,))
 
     def _get_events_from_cache(
         self, events: Iterable[str], update_metrics: bool = True
-    ) -> Dict[str, _EventCacheEntry]:
+    ) -> Dict[str, EventCacheEntry]:
         """Fetch events from the caches.
 
         May return rejected events.

@@ -820,7 +851,7 @@ class EventsWorkerStore(SQLBaseStore):
         for _, deferred in event_fetches_to_fail:
             deferred.errback(exc)
 
-    def _fetch_loop(self, conn: Connection) -> None:
+    def _fetch_loop(self, conn: LoggingDatabaseConnection) -> None:
         """Takes a database connection and waits for requests for events from
         the _event_fetch_list queue.
         """

@@ -850,7 +881,9 @@ class EventsWorkerStore(SQLBaseStore):
                 self._fetch_event_list(conn, event_list)
 
     def _fetch_event_list(
-        self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
+        self,
+        conn: LoggingDatabaseConnection,
+        event_list: List[Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]],
     ) -> None:
         """Handle a load of requests from the _event_fetch_list queue
 

@@ -877,7 +910,7 @@ class EventsWorkerStore(SQLBaseStore):
             )
 
             # We only want to resolve deferreds from the main thread
-            def fire():
+            def fire() -> None:
                 for _, d in event_list:
                     d.callback(row_dict)
 

@@ -887,16 +920,16 @@ class EventsWorkerStore(SQLBaseStore):
             logger.exception("do_fetch")
 
             # We only want to resolve deferreds from the main thread
-            def fire(evs, exc):
-                for _, d in evs:
+            def fire_errback(exc: Exception) -> None:
+                for _, d in event_list:
                     d.errback(exc)
 
             with PreserveLoggingContext():
-                self.hs.get_reactor().callFromThread(fire, event_list, e)
+                self.hs.get_reactor().callFromThread(fire_errback, e)
 
     async def _get_events_from_db(
-        self, event_ids: Iterable[str]
-    ) -> Dict[str, _EventCacheEntry]:
+        self, event_ids: Collection[str]
+    ) -> Dict[str, EventCacheEntry]:
         """Fetch a bunch of events from the database.
 
         May return rejected events.

@@ -912,29 +945,29 @@ class EventsWorkerStore(SQLBaseStore):
             map from event id to result. May return extra events which
            weren't asked for.
        """
-        fetched_events = {}
+        fetched_event_ids: Set[str] = set()
+        fetched_events: Dict[str, _EventRow] = {}
         events_to_fetch = event_ids
 
         while events_to_fetch:
             row_map = await self._enqueue_events(events_to_fetch)
 
             # we need to recursively fetch any redactions of those events
-            redaction_ids = set()
+            redaction_ids: Set[str] = set()
             for event_id in events_to_fetch:
                 row = row_map.get(event_id)
-                fetched_events[event_id] = row
+                fetched_event_ids.add(event_id)
                 if row:
+                    fetched_events[event_id] = row
                     redaction_ids.update(row.redactions)
 
-            events_to_fetch = redaction_ids.difference(fetched_events.keys())
+            events_to_fetch = redaction_ids.difference(fetched_event_ids)
             if events_to_fetch:
                 logger.debug("Also fetching redaction events %s", events_to_fetch)
 
         # build a map from event_id to EventBase
-        event_map = {}
+        event_map: Dict[str, EventBase] = {}
         for event_id, row in fetched_events.items():
-            if not row:
-                continue
             assert row.event_id == event_id
 
             rejected_reason = row.rejected_reason

@@ -962,6 +995,7 @@ class EventsWorkerStore(SQLBaseStore):
 
             room_version_id = row.room_version_id
 
+            room_version: Optional[RoomVersion]
             if not room_version_id:
                 # this should only happen for out-of-band membership events which
                 # arrived before #6983 landed. For all other events, we should have

@@ -1032,14 +1066,14 @@ class EventsWorkerStore(SQLBaseStore):
 
         # finally, we can decide whether each one needs redacting, and build
         # the cache entries.
-        result_map = {}
+        result_map: Dict[str, EventCacheEntry] = {}
         for event_id, original_ev in event_map.items():
             redactions = fetched_events[event_id].redactions
             redacted_event = self._maybe_redact_event_row(
                 original_ev, redactions, event_map
             )
 
-            cache_entry = _EventCacheEntry(
+            cache_entry = EventCacheEntry(
                 event=original_ev, redacted_event=redacted_event
             )
 

@@ -1048,7 +1082,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         return result_map
 
-    async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]:
+    async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]:
         """Fetches events from the database using the _event_fetch_list. This
         allows batch and bulk fetching of events - it allows us to fetch events
         without having to create a new transaction for each request for events.

@@ -1061,7 +1095,7 @@ class EventsWorkerStore(SQLBaseStore):
             that weren't requested.
         """
 
-        events_d = defer.Deferred()
+        events_d: "defer.Deferred[Dict[str, _EventRow]]" = defer.Deferred()
         with self._event_fetch_lock:
             self._event_fetch_list.append((events, events_d))
             self._event_fetch_lock.notify()

@@ -1216,7 +1250,7 @@ class EventsWorkerStore(SQLBaseStore):
         # no valid redaction found for this event
         return None
 
-    async def have_events_in_timeline(self, event_ids):
+    async def have_events_in_timeline(self, event_ids: Iterable[str]) -> Set[str]:
         """Given a list of event ids, check if we have already processed and
         stored them as non outliers.
         """

@@ -1245,7 +1279,7 @@ class EventsWorkerStore(SQLBaseStore):
             event_ids: events we are looking for
 
         Returns:
-            set[str]: The events we have already seen.
+            The set of events we have already seen.
         """
         res = await self._have_seen_events_dict(
             (room_id, event_id) for event_id in event_ids

@@ -1268,7 +1302,9 @@ class EventsWorkerStore(SQLBaseStore):
         }
         results = {x: True for x in cache_results}
 
-        def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]):
+        def have_seen_events_txn(
+            txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...]
+        ) -> None:
             # we deliberately do *not* query the database for room_id, to make the
             # query an index-only lookup on `events_event_id_key`.
             #

@@ -1294,12 +1330,14 @@ class EventsWorkerStore(SQLBaseStore):
         return results
 
     @cached(max_entries=100000, tree=True)
-    async def have_seen_event(self, room_id: str, event_id: str):
+    async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
         # this only exists for the benefit of the @cachedList descriptor on
         # _have_seen_events_dict
         raise NotImplementedError()
 
-    def _get_current_state_event_counts_txn(self, txn, room_id):
+    def _get_current_state_event_counts_txn(
+        self, txn: LoggingTransaction, room_id: str
+    ) -> int:
         """
         See get_current_state_event_counts.
         """

@@ -1324,7 +1362,7 @@ class EventsWorkerStore(SQLBaseStore):
             room_id,
         )
 
-    async def get_room_complexity(self, room_id):
+    async def get_room_complexity(self, room_id: str) -> Dict[str, float]:
         """
         Get a rough approximation of the complexity of the room. This is used by
         remote servers to decide whether they wish to join the room or not.

@@ -1332,10 +1370,10 @@ class EventsWorkerStore(SQLBaseStore):
         more resources.
 
         Args:
-            room_id (str)
+            room_id: The room ID to query.
 
         Returns:
-            dict[str:int] of complexity version to complexity.
+            dict[str:float] of complexity version to complexity.
         """
         state_events = await self.get_current_state_event_counts(room_id)
 

@@ -1345,13 +1383,13 @@ class EventsWorkerStore(SQLBaseStore):
 
         return {"v1": complexity_v1}
 
-    def get_current_events_token(self):
+    def get_current_events_token(self) -> int:
         """The current maximum token that events have reached"""
         return self._stream_id_gen.get_current_token()
 
     async def get_all_new_forward_event_rows(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> List[Tuple]:
+    ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
         """Returns new events, for the Events replication stream
 
         Args:

@@ -1365,7 +1403,9 @@ class EventsWorkerStore(SQLBaseStore):
             EventsStreamRow.
         """
 
-        def get_all_new_forward_event_rows(txn):
+        def get_all_new_forward_event_rows(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
             sql = (
                 "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
                 " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"

@@ -1381,7 +1421,9 @@ class EventsWorkerStore(SQLBaseStore):
                 " LIMIT ?"
             )
             txn.execute(sql, (last_id, current_id, instance_name, limit))
-            return txn.fetchall()
+            return cast(
+                List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+            )
 
         return await self.db_pool.runInteraction(
             "get_all_new_forward_event_rows", get_all_new_forward_event_rows

@@ -1389,7 +1431,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_ex_outlier_stream_rows(
         self, instance_name: str, last_id: int, current_id: int
-    ) -> List[Tuple]:
+    ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
         """Returns de-outliered events, for the Events replication stream
 
         Args:

@@ -1402,7 +1444,9 @@ class EventsWorkerStore(SQLBaseStore):
             EventsStreamRow.
         """
 
-        def get_ex_outlier_stream_rows_txn(txn):
+        def get_ex_outlier_stream_rows_txn(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
             sql = (
                 "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
                 " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"

@@ -1420,7 +1464,9 @@ class EventsWorkerStore(SQLBaseStore):
             )
 
             txn.execute(sql, (last_id, current_id, instance_name))
-            return txn.fetchall()
+            return cast(
+                List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
+            )
 
         return await self.db_pool.runInteraction(
             "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn

@@ -1428,7 +1474,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_all_new_backfill_event_rows(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> Tuple[List[Tuple[int, list]], int, bool]:
+    ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
         """Get updates for backfill replication stream, including all new
         backfilled events and events that have gone from being outliers to not.
 

@@ -1456,7 +1502,9 @@ class EventsWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_new_backfill_event_rows(txn):
+        def get_all_new_backfill_event_rows(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, str]]], int, bool]:
             sql = (
                 "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
                 " state_key, redacts, relates_to_id"

@@ -1470,7 +1518,15 @@ class EventsWorkerStore(SQLBaseStore):
                 " LIMIT ?"
             )
             txn.execute(sql, (-last_id, -current_id, instance_name, limit))
-            new_event_updates = [(row[0], row[1:]) for row in txn]
+            new_event_updates: List[
+                Tuple[int, Tuple[str, str, str, str, str, str]]
+            ] = []
+            row: Tuple[int, str, str, str, str, str, str]
+            # Type safety: iterating over `txn` yields `Tuple`, i.e.
+            # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+            # variadic tuple to a fixed length tuple and flags it up as an error.
+            for row in txn:  # type: ignore[assignment]
+                new_event_updates.append((row[0], row[1:]))
 
             limited = False
             if len(new_event_updates) == limit:

@@ -1493,7 +1549,11 @@ class EventsWorkerStore(SQLBaseStore):
                 " ORDER BY event_stream_ordering DESC"
             )
             txn.execute(sql, (-last_id, -upper_bound, instance_name))
-            new_event_updates.extend((row[0], row[1:]) for row in txn)
+            # Type safety: iterating over `txn` yields `Tuple`, i.e.
+            # `Tuple[Any, ...]` of arbitrary length. Mypy detects assigning a
+            # variadic tuple to a fixed length tuple and flags it up as an error.
+            for row in txn:  # type: ignore[assignment]
+                new_event_updates.append((row[0], row[1:]))
 
             if len(new_event_updates) >= limit:
                 upper_bound = new_event_updates[-1][0]
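The cast() calls and `# type: ignore[assignment]` loops in the hunks above exist because DB-API cursors are typed as returning tuples of arbitrary length, so the concrete row shape has to be asserted to mypy by hand. A small, self-contained sketch of the cast pattern using sqlite3 from the standard library (illustrative only; the real code goes through Synapse's LoggingTransaction):

    import sqlite3
    from typing import List, Tuple, cast

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE events (stream_ordering INTEGER, event_id TEXT)")
    conn.execute("INSERT INTO events VALUES (1, '$abc')")

    cur = conn.execute("SELECT stream_ordering, event_id FROM events")
    # fetchall() is typed as returning rows of unknown shape; cast() tells mypy
    # the concrete row type without changing anything at runtime.
    rows = cast(List[Tuple[int, str]], cur.fetchall())
    print(rows)  # [(1, '$abc')]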
@@ -1507,7 +1567,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def get_all_updated_current_state_deltas(
         self, instance_name: str, from_token: int, to_token: int, target_row_count: int
-    ) -> Tuple[List[Tuple], int, bool]:
+    ) -> Tuple[List[Tuple[int, str, str, str, str]], int, bool]:
         """Fetch updates from current_state_delta_stream
 
         Args:

@@ -1527,7 +1587,9 @@ class EventsWorkerStore(SQLBaseStore):
             * `limited` is whether there are more updates to fetch.
         """
 
-        def get_all_updated_current_state_deltas_txn(txn):
+        def get_all_updated_current_state_deltas_txn(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[int, str, str, str, str]]:
             sql = """
                 SELECT stream_id, room_id, type, state_key, event_id
                 FROM current_state_delta_stream

@@ -1536,21 +1598,23 @@ class EventsWorkerStore(SQLBaseStore):
                 ORDER BY stream_id ASC LIMIT ?
             """
             txn.execute(sql, (from_token, to_token, instance_name, target_row_count))
-            return txn.fetchall()
+            return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
 
-        def get_deltas_for_stream_id_txn(txn, stream_id):
+        def get_deltas_for_stream_id_txn(
+            txn: LoggingTransaction, stream_id: int
+        ) -> List[Tuple[int, str, str, str, str]]:
             sql = """
                 SELECT stream_id, room_id, type, state_key, event_id
                 FROM current_state_delta_stream
                 WHERE stream_id = ?
             """
             txn.execute(sql, [stream_id])
-            return txn.fetchall()
+            return cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
 
         # we need to make sure that, for every stream id in the results, we get *all*
         # the rows with that stream id.
 
-        rows: List[Tuple] = await self.db_pool.runInteraction(
+        rows: List[Tuple[int, str, str, str, str]] = await self.db_pool.runInteraction(
             "get_all_updated_current_state_deltas",
             get_all_updated_current_state_deltas_txn,
         )

@@ -1579,14 +1643,14 @@ class EventsWorkerStore(SQLBaseStore):
 
         return rows, to_token, True
 
-    async def is_event_after(self, event_id1, event_id2):
+    async def is_event_after(self, event_id1: str, event_id2: str) -> bool:
         """Returns True if event_id1 is after event_id2 in the stream"""
         to_1, so_1 = await self.get_event_ordering(event_id1)
         to_2, so_2 = await self.get_event_ordering(event_id2)
         return (to_1, so_1) > (to_2, so_2)
 
     @cached(max_entries=5000)
-    async def get_event_ordering(self, event_id):
+    async def get_event_ordering(self, event_id: str) -> Tuple[int, int]:
         res = await self.db_pool.simple_select_one(
             table="events",
             retcols=["topological_ordering", "stream_ordering"],

@@ -1609,7 +1673,9 @@ class EventsWorkerStore(SQLBaseStore):
             None otherwise.
         """
 
-        def get_next_event_to_expire_txn(txn):
+        def get_next_event_to_expire_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[Tuple[str, int]]:
             txn.execute(
                 """
                 SELECT event_id, expiry_ts FROM event_expiry

@@ -1617,7 +1683,7 @@ class EventsWorkerStore(SQLBaseStore):
                 """
             )
 
-            return txn.fetchone()
+            return cast(Optional[Tuple[str, int]], txn.fetchone())
 
         return await self.db_pool.runInteraction(
             desc="get_next_event_to_expire", func=get_next_event_to_expire_txn

@@ -1681,10 +1747,10 @@ class EventsWorkerStore(SQLBaseStore):
         return mapping
 
     @wrap_as_background_process("_cleanup_old_transaction_ids")
-    async def _cleanup_old_transaction_ids(self):
+    async def _cleanup_old_transaction_ids(self) -> None:
         """Cleans out transaction id mappings older than 24hrs."""
 
-        def _cleanup_old_transaction_ids_txn(txn):
+        def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None:
             sql = """
                 DELETE FROM event_txn_id
                 WHERE inserted_ts < ?
synapse/storage/databases/main/push_rule.py

@@ -28,7 +28,10 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdTracker,
+    StreamIdGenerator,
+)
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache

@@ -82,9 +85,9 @@ class PushRulesWorkerStore(
         super().__init__(database, db_conn, hs)
 
         if hs.config.worker.worker_app is None:
-            self._push_rules_stream_id_gen: Union[
-                StreamIdGenerator, SlavedIdTracker
-            ] = StreamIdGenerator(db_conn, "push_rules_stream", "stream_id")
+            self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+                db_conn, "push_rules_stream", "stream_id"
+            )
         else:
             self._push_rules_stream_id_gen = SlavedIdTracker(
                 db_conn, "push_rules_stream", "stream_id"
@ -89,31 +89,77 @@ def _load_current_id(
|
|||||||
return (max if step > 0 else min)(current_id, step)
|
return (max if step > 0 else min)(current_id, step)
|
||||||
|
|
||||||
|
|
||||||
-class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
-    @abc.abstractmethod
-    def get_next(self) -> AsyncContextManager[int]:
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_current_token(self) -> int:
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_current_token_for_writer(self, instance_name: str) -> int:
-        raise NotImplementedError()
+class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
+    """Tracks the "current" stream ID of a stream that may have multiple writers.
+
+    Stream IDs are monotonically increasing or decreasing integers representing write
+    transactions. The "current" stream ID is the stream ID such that all transactions
+    with equal or smaller stream IDs have completed. Since transactions may complete out
+    of order, this is not the same as the stream ID of the last completed transaction.
+
+    Completed transactions include both committed transactions and transactions that
+    have been rolled back.
+    """
+
+    @abc.abstractmethod
+    def advance(self, instance_name: str, new_id: int) -> None:
+        """Advance the position of the named writer to the given ID, if greater
+        than existing entry.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_current_token(self) -> int:
+        """Returns the maximum stream id such that all stream ids less than or
+        equal to it have been successfully persisted.
+
+        Returns:
+            The maximum stream id.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_current_token_for_writer(self, instance_name: str) -> int:
+        """Returns the position of the given writer.
+
+        For streams with single writers this is equivalent to `get_current_token`.
+        """
+        raise NotImplementedError()
+
+
+class AbstractStreamIdGenerator(AbstractStreamIdTracker):
+    """Generates stream IDs for a stream that may have multiple writers.
+
+    Each stream ID represents a write transaction, whose completion is tracked
+    so that the "current" stream ID of the stream can be determined.
+
+    See `AbstractStreamIdTracker` for more details.
+    """
+
+    @abc.abstractmethod
+    def get_next(self) -> AsyncContextManager[int]:
+        """
+        Usage:
+            async with stream_id_gen.get_next() as stream_id:
+                # ... persist event ...
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
+        """
+        Usage:
+            async with stream_id_gen.get_next(n) as stream_ids:
+                # ... persist events ...
+        """
+        raise NotImplementedError()
 
 
 class StreamIdGenerator(AbstractStreamIdGenerator):
-    """Used to generate new stream ids when persisting events while keeping
-    track of which transactions have been completed.
-
-    This allows us to get the "current" stream id, i.e. the stream id such that
-    all ids less than or equal to it have completed. This handles the fact that
-    persistence of events can complete out of order.
+    """Generates and tracks stream IDs for a stream with a single writer.
+
+    This class must only be used when the current Synapse process is the sole
+    writer for a stream.
 
     Args:
         db_conn(connection): A database connection to use to fetch the
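A minimal usage sketch of the new `AbstractStreamIdGenerator` interface (illustrative only, not part of the commit; `persist_fn` is a hypothetical placeholder for the real database write): the reserved stream ID only counts as completed once the `async with` block exits, whether the underlying transaction committed or rolled back.

    from typing import Awaitable, Callable

    from synapse.storage.util.id_generators import AbstractStreamIdGenerator


    async def persist_with_stream_id(
        stream_id_gen: AbstractStreamIdGenerator,
        persist_fn: Callable[[int], Awaitable[None]],
    ) -> int:
        # Reserve the next stream ID; the generator records it as unfinished.
        async with stream_id_gen.get_next() as stream_id:
            # Perform the actual write under the reserved ID (placeholder coroutine).
            await persist_fn(stream_id)
        # On exit the ID is marked completed, so the current token may advance up to
        # `stream_id` once all smaller IDs have also completed.
        return stream_id_gen.get_current_token()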
@@ -157,12 +203,12 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
         # The key and values are the same, but we never look at the values.
         self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
 
+    def advance(self, instance_name: str, new_id: int) -> None:
+        # `StreamIdGenerator` should only be used when there is a single writer,
+        # so replication should never happen.
+        raise Exception("Replication is not supported by StreamIdGenerator")
+
     def get_next(self) -> AsyncContextManager[int]:
-        """
-        Usage:
-            async with stream_id_gen.get_next() as stream_id:
-                # ... persist event ...
-        """
         with self._lock:
             self._current += self._step
             next_id = self._current
@@ -180,11 +226,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
         return _AsyncCtxManagerWrapper(manager())
 
     def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
-        """
-        Usage:
-            async with stream_id_gen.get_next(n) as stream_ids:
-                # ... persist events ...
-        """
         with self._lock:
             next_ids = range(
                 self._current + self._step,
@@ -208,12 +249,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
         return _AsyncCtxManagerWrapper(manager())
 
     def get_current_token(self) -> int:
-        """Returns the maximum stream id such that all stream ids less than or
-        equal to it have been successfully persisted.
-
-        Returns:
-            The maximum stream id.
-        """
         with self._lock:
             if self._unfinished_ids:
                 return next(iter(self._unfinished_ids)) - self._step
@@ -221,16 +256,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
             return self._current
 
     def get_current_token_for_writer(self, instance_name: str) -> int:
-        """Returns the position of the given writer.
-
-        For streams with single writers this is equivalent to
-        `get_current_token`.
-        """
         return self.get_current_token()
 
 
 class MultiWriterIdGenerator(AbstractStreamIdGenerator):
-    """An ID generator that tracks a stream that can have multiple writers.
+    """Generates and tracks stream IDs for a stream with multiple writers.
 
     Uses a Postgres sequence to coordinate ID assignment, but positions of other
     writers will only get updated when `advance` is called (by replication).
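A rough sketch of the tracker side of this interface (illustrative only, not from the diff): replication feeds per-writer positions into any `AbstractStreamIdTracker` via `advance`, and `get_current_token_for_writer` can then be used, for example, to compute a conservative read bound across a hypothetical list of writer names.

    from typing import Iterable

    from synapse.storage.util.id_generators import AbstractStreamIdTracker


    def on_position_update(
        tracker: AbstractStreamIdTracker, instance_name: str, new_id: int
    ) -> None:
        # Replication reports that `instance_name` has persisted up to `new_id`;
        # the tracker only moves the recorded position forward, never backward.
        tracker.advance(instance_name, new_id)


    def conservative_read_bound(
        tracker: AbstractStreamIdTracker, writers: Iterable[str]
    ) -> int:
        # One conservative choice of read bound: the smallest known position across
        # all writers, so nothing ahead of a lagging writer is consumed too early.
        return min(tracker.get_current_token_for_writer(name) for name in writers)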
@@ -475,12 +505,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         return stream_ids
 
     def get_next(self) -> AsyncContextManager[int]:
-        """
-        Usage:
-            async with stream_id_gen.get_next() as stream_id:
-                # ... persist event ...
-        """
-
         # If we have a list of instances that are allowed to write to this
         # stream, make sure we're in it.
         if self._writers and self._instance_name not in self._writers:
@@ -492,12 +516,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
 
     def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
-        """
-        Usage:
-            async with stream_id_gen.get_next_mult(5) as stream_ids:
-                # ... persist events ...
-        """
-
         # If we have a list of instances that are allowed to write to this
         # stream, make sure we're in it.
         if self._writers and self._instance_name not in self._writers:
@@ -597,15 +615,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         self._add_persisted_position(next_id)
 
     def get_current_token(self) -> int:
-        """Returns the maximum stream id such that all stream ids less than or
-        equal to it have been successfully persisted.
-        """
-
         return self.get_persisted_upto_position()
 
     def get_current_token_for_writer(self, instance_name: str) -> int:
-        """Returns the position of the given writer."""
-
         # If we don't have an entry for the given instance name, we assume it's a
         # new writer.
         #
@@ -631,10 +643,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         }
 
     def advance(self, instance_name: str, new_id: int) -> None:
-        """Advance the position of the named writer to the given ID, if greater
-        than existing entry.
-        """
-
        new_id *= self._return_factor
 
         with self._lock:
tests/replication/test_sharded_event_persister.py
@@ -17,6 +17,7 @@ from unittest.mock import patch
 from synapse.api.room_versions import RoomVersion
 from synapse.rest import admin
 from synapse.rest.client import login, room, sync
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 
 from tests.replication._base import BaseMultiWorkerStreamTestCase
 from tests.server import make_request
@@ -193,7 +194,10 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         #
         # Worker2's event stream position will not advance until we call
         # __aexit__ again.
-        actx = worker_hs2.get_datastore()._stream_id_gen.get_next()
+        worker_store2 = worker_hs2.get_datastore()
+        assert isinstance(worker_store2._stream_id_gen, MultiWriterIdGenerator)
+
+        actx = worker_store2._stream_id_gen.get_next()
         self.get_success(actx.__aenter__())
 
         response = self.helper.send(room_id1, body="Hi!", tok=self.other_access_token)