Merge/remove Slaved* stores into WorkerStores (#14375)

This commit is contained in:
Nick Mills-Barrett 2022-11-11 10:51:49 +00:00 committed by GitHub
parent 13ca8bb2fc
commit 3a4f80f8c6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 202 additions and 377 deletions

View file

@ -26,9 +26,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
from .account_data import AccountDataStore
from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
@ -138,41 +136,8 @@ class DataStore(
self._clock = hs.get_clock()
self.database_engine = database.engine
self._device_list_id_gen = StreamIdGenerator(
db_conn,
"device_lists_stream",
"stream_id",
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
("device_lists_changes_in_room", "stream_id"),
],
)
super().__init__(database, db_conn, hs)
events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
db_conn,
"current_state_delta_stream",
entity_column="room_id",
stream_column="stream_id",
max_value=events_max, # As we share the stream id with events token
limit=1000,
)
self._curr_state_delta_stream_cache = StreamChangeCache(
"_curr_state_delta_stream_cache",
min_curr_state_delta_id,
prefilled_cache=curr_state_delta_prefill,
)
self._stream_order_on_start = self.get_room_max_stream_ordering()
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
def get_device_stream_token(self) -> int:
# TODO: shouldn't this be moved to `DeviceWorkerStore`?
return self._device_list_id_gen.get_current_token()
async def get_users(self) -> List[JsonDict]:
"""Function to retrieve a list of users in users table.

View file

@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
from typing import (
TYPE_CHECKING,
@ -39,6 +38,8 @@ from synapse.logging.opentracing import (
whitelisted_homeserver,
)
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
@ -49,6 +50,11 @@ from synapse.storage.database import (
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
AbstractStreamIdTracker,
StreamIdGenerator,
)
from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@ -80,9 +86,32 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
):
super().__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn,
"device_lists_stream",
"stream_id",
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
("device_lists_changes_in_room", "stream_id"),
],
)
else:
self._device_list_id_gen = SlavedIdTracker(
db_conn,
"device_lists_stream",
"stream_id",
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
("device_lists_changes_in_room", "stream_id"),
],
)
# Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a
# StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker).
device_list_max = self._device_list_id_gen.get_current_token() # type: ignore[attr-defined]
device_list_max = self._device_list_id_gen.get_current_token()
device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
db_conn,
"device_lists_stream",
@ -136,6 +165,39 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
self._prune_old_outbound_device_pokes, 60 * 60 * 1000
)
def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(instance_name, token)
self._invalidate_caches_for_devices(token, rows)
elif stream_name == UserSignatureStream.NAME:
self._device_list_id_gen.advance(instance_name, token)
for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
def _invalidate_caches_for_devices(
self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
) -> None:
for row in rows:
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
if row.entity.startswith("@"):
self._device_list_stream_cache.entity_has_changed(row.entity, token)
self.get_cached_devices_for_user.invalidate((row.entity,))
self._get_cached_user_device.invalidate((row.entity,))
self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
else:
self._device_list_federation_stream_cache.entity_has_changed(
row.entity, token
)
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int:
"""Retrieve number of all devices of given users.
Only returns number of devices that are not marked as hidden.
@ -677,11 +739,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
},
)
@abc.abstractmethod
def get_device_stream_token(self) -> int:
"""Get the current stream id from the _device_list_id_gen"""
...
@trace
@cancellable
async def get_user_devices_from_cache(
@ -1481,6 +1538,10 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Because we have write access, this will be a StreamIdGenerator
# (see DeviceWorkerStore.__init__)
_device_list_id_gen: AbstractStreamIdGenerator
def __init__(
self,
database: DatabasePool,
@ -1805,7 +1866,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
context,
)
async with self._device_list_id_gen.get_next_mult( # type: ignore[attr-defined]
async with self._device_list_id_gen.get_next_mult(
len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
@ -2044,7 +2105,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
[],
)
async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: # type: ignore[attr-defined]
async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
return await self.db_pool.runInteraction(
"add_device_list_outbound_pokes",
add_device_list_outbound_pokes_txn,
@ -2058,7 +2119,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
updates during partial joins.
"""
async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
async with self._device_list_id_gen.get_next() as stream_id:
await self.db_pool.simple_upsert(
table="device_lists_remote_pending",
keyvalues={

View file

@ -81,6 +81,7 @@ from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import AsyncLruCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@ -233,6 +234,21 @@ class EventsWorkerStore(SQLBaseStore):
db_conn, "events", "stream_ordering", step=-1
)
events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
db_conn,
"current_state_delta_stream",
entity_column="room_id",
stream_column="stream_id",
max_value=events_max, # As we share the stream id with events token
limit=1000,
)
self._curr_state_delta_stream_cache: StreamChangeCache = StreamChangeCache(
"_curr_state_delta_stream_cache",
min_curr_state_delta_id,
prefilled_cache=curr_state_delta_prefill,
)
if hs.config.worker.run_background_tasks:
# We periodically clean out old transaction ID mappings
self._clock.looping_call(

View file

@ -24,7 +24,7 @@ from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
class FilteringStore(SQLBaseStore):
class FilteringWorkerStore(SQLBaseStore):
@cached(num_args=2)
async def get_user_filter(
self, user_localpart: str, filter_id: Union[int, str]
@ -46,6 +46,8 @@ class FilteringStore(SQLBaseStore):
return db_to_json(def_json)
class FilteringStore(FilteringWorkerStore):
async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int:
def_json = encode_canonical_json(user_filter)

View file

@ -12,13 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
Iterable,
List,
Mapping,
Optional,
@ -31,6 +31,7 @@ from typing import (
from synapse.api.errors import StoreError
from synapse.config.homeserver import ExperimentalConfig
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import PushRulesStream
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
@ -90,8 +91,6 @@ def _load_rules(
return filtered_rules
# The ABCMeta metaclass ensures that it cannot be instantiated without
# the abstract methods being implemented.
class PushRulesWorkerStore(
ApplicationServiceWorkerStore,
PusherWorkerStore,
@ -99,7 +98,6 @@ class PushRulesWorkerStore(
ReceiptsWorkerStore,
EventsWorkerStore,
SQLBaseStore,
metaclass=abc.ABCMeta,
):
"""This is an abstract base class where subclasses must implement
`get_max_push_rules_stream_id` which can be called in the initializer.
@ -136,14 +134,23 @@ class PushRulesWorkerStore(
prefilled_cache=push_rules_prefill,
)
@abc.abstractmethod
def get_max_push_rules_stream_id(self) -> int:
"""Get the position of the push rules stream.
Returns:
int
"""
raise NotImplementedError()
return self._push_rules_stream_id_gen.get_current_token()
def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == PushRulesStream.NAME:
self._push_rules_stream_id_gen.advance(instance_name, token)
for row in rows:
self.get_push_rules_for_user.invalidate((row.user_id,))
self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
@cached(max_entries=5000)
async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:

View file

@ -27,13 +27,19 @@ from typing import (
)
from synapse.push import PusherConfig, ThrottleParams
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import PushersStream
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
AbstractStreamIdTracker,
StreamIdGenerator,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@ -52,9 +58,21 @@ class PusherWorkerStore(SQLBaseStore):
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self._pushers_id_gen = StreamIdGenerator(
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
if hs.config.worker.worker_app is None:
self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
db_conn,
"pushers",
"id",
extra_tables=[("deleted_pushers", "stream_id")],
)
else:
self._pushers_id_gen = SlavedIdTracker(
db_conn,
"pushers",
"id",
extra_tables=[("deleted_pushers", "stream_id")],
)
self.db_pool.updates.register_background_update_handler(
"remove_deactivated_pushers",
@ -96,6 +114,16 @@ class PusherWorkerStore(SQLBaseStore):
yield PusherConfig(**r)
def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()
def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == PushersStream.NAME:
self._pushers_id_gen.advance(instance_name, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
async def get_pushers_by_app_id_and_pushkey(
self, app_id: str, pushkey: str
) -> Iterator[PusherConfig]:
@ -545,8 +573,9 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()
# Because we have write access, this will be a StreamIdGenerator
# (see PusherWorkerStore.__init__)
_pushers_id_gen: AbstractStreamIdGenerator
async def add_pusher(
self,

View file

@ -415,6 +415,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
self._stream_order_on_start = self.get_room_max_stream_ordering()
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
def get_room_max_stream_ordering(self) -> int:
"""Get the stream_ordering of regular events that we have committed up to