forked-synapse/synapse/replication/slave/storage/devices.py
Erik Johnston 5c9e39e619
Track device list updates per room. (#12321)
This is a first step in dealing with #7721.

The idea is basically that rather than calculating the full set of users a device list update needs to be sent to up front, we instead simply record the rooms the user was in at the time of the change. This will allow a few things:

1. we can defer calculating the set of remote servers that need to be poked about the change; and
2. during `/sync` and `/keys/changes` we can avoid also avoid calculating users who share rooms with other users, and instead just look at the rooms that have changed.

However, care needs to be taken to correctly handle server downgrades. As such this PR writes to both `device_lists_changes_in_room` and the `device_lists_outbound_pokes` table synchronously. In a future release we can then bump the database schema compat version to `69` and then we can assume that the new `device_lists_changes_in_room` exists and is handled.

There is a temporary option to disable writing to `device_lists_outbound_pokes` synchronously, allowing us to test the new code path does work (and by implication upgrading to a future release and downgrading to this one will work correctly).

Note: Ideally we'd do the calculation of room to servers on a worker (e.g. the background worker), but currently only master can write to the `device_list_outbound_pokes` table.
2022-04-04 15:25:20 +01:00

93 lines
3.9 KiB
Python

# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Iterable
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.devices import DeviceWorkerStore
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
if TYPE_CHECKING:
from synapse.server import HomeServer
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.hs = hs
self._device_list_id_gen = SlavedIdTracker(
db_conn,
"device_lists_stream",
"stream_id",
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
("device_lists_changes_in_room", "stream_id"),
],
)
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
"DeviceListStreamChangeCache", device_list_max
)
self._user_signature_stream_cache = StreamChangeCache(
"UserSignatureStreamChangeCache", device_list_max
)
self._device_list_federation_stream_cache = StreamChangeCache(
"DeviceListFederationStreamChangeCache", device_list_max
)
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
def process_replication_rows(
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(instance_name, token)
self._invalidate_caches_for_devices(token, rows)
elif stream_name == UserSignatureStream.NAME:
self._device_list_id_gen.advance(instance_name, token)
for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
def _invalidate_caches_for_devices(
self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
) -> None:
for row in rows:
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
if row.entity.startswith("@"):
self._device_list_stream_cache.entity_has_changed(row.entity, token)
self.get_cached_devices_for_user.invalidate((row.entity,))
self._get_cached_user_device.invalidate((row.entity,))
self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
else:
self._device_list_federation_stream_cache.entity_has_changed(
row.entity, token
)