mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
b26cbe3d45
* Fix type error introduced accidentally by #13045 * Newsfile Signed-off-by: Olivier Wilkinson (reivilibre) <oliverw@matrix.org>
1958 lines
71 KiB
Python
1958 lines
71 KiB
Python
# Copyright 2016 OpenMarket Ltd
|
|
# Copyright 2019 New Vector Ltd
|
|
# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import abc
|
|
import logging
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
Collection,
|
|
Dict,
|
|
Iterable,
|
|
List,
|
|
Optional,
|
|
Set,
|
|
Tuple,
|
|
cast,
|
|
)
|
|
|
|
from typing_extensions import Literal
|
|
|
|
from synapse.api.constants import EduTypes
|
|
from synapse.api.errors import Codes, StoreError
|
|
from synapse.logging.opentracing import (
|
|
get_active_span_text_map,
|
|
set_tag,
|
|
trace,
|
|
whitelisted_homeserver,
|
|
)
|
|
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
|
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
|
from synapse.storage.database import (
|
|
DatabasePool,
|
|
LoggingDatabaseConnection,
|
|
LoggingTransaction,
|
|
make_tuple_comparison_clause,
|
|
)
|
|
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
|
|
from synapse.storage.types import Cursor
|
|
from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
|
|
from synapse.util import json_decoder, json_encoder
|
|
from synapse.util.caches.descriptors import cached, cachedList
|
|
from synapse.util.caches.lrucache import LruCache
|
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
|
from synapse.util.iterutils import batch_iter
|
|
from synapse.util.stringutils import shortstr
|
|
|
|
if TYPE_CHECKING:
|
|
from synapse.server import HomeServer
|
|
|
|
logger = logging.getLogger(__name__)
|
|
issue_8631_logger = logging.getLogger("synapse.8631_debug")
|
|
|
|
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
|
|
"drop_device_list_streams_non_unique_indexes"
|
|
)
|
|
|
|
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
|
|
|
|
|
|
class DeviceWorkerStore(EndToEndKeyWorkerStore):
|
|
def __init__(
|
|
self,
|
|
database: DatabasePool,
|
|
db_conn: LoggingDatabaseConnection,
|
|
hs: "HomeServer",
|
|
):
|
|
super().__init__(database, db_conn, hs)
|
|
|
|
# Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a
|
|
# StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker).
|
|
device_list_max = self._device_list_id_gen.get_current_token() # type: ignore[attr-defined]
|
|
device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
|
|
db_conn,
|
|
"device_lists_stream",
|
|
entity_column="user_id",
|
|
stream_column="stream_id",
|
|
max_value=device_list_max,
|
|
limit=10000,
|
|
)
|
|
self._device_list_stream_cache = StreamChangeCache(
|
|
"DeviceListStreamChangeCache",
|
|
min_device_list_id,
|
|
prefilled_cache=device_list_prefill,
|
|
)
|
|
|
|
(
|
|
user_signature_stream_prefill,
|
|
user_signature_stream_list_id,
|
|
) = self.db_pool.get_cache_dict(
|
|
db_conn,
|
|
"user_signature_stream",
|
|
entity_column="from_user_id",
|
|
stream_column="stream_id",
|
|
max_value=device_list_max,
|
|
limit=1000,
|
|
)
|
|
self._user_signature_stream_cache = StreamChangeCache(
|
|
"UserSignatureStreamChangeCache",
|
|
user_signature_stream_list_id,
|
|
prefilled_cache=user_signature_stream_prefill,
|
|
)
|
|
|
|
(
|
|
device_list_federation_prefill,
|
|
device_list_federation_list_id,
|
|
) = self.db_pool.get_cache_dict(
|
|
db_conn,
|
|
"device_lists_outbound_pokes",
|
|
entity_column="destination",
|
|
stream_column="stream_id",
|
|
max_value=device_list_max,
|
|
limit=10000,
|
|
)
|
|
self._device_list_federation_stream_cache = StreamChangeCache(
|
|
"DeviceListFederationStreamChangeCache",
|
|
device_list_federation_list_id,
|
|
prefilled_cache=device_list_federation_prefill,
|
|
)
|
|
|
|
if hs.config.worker.run_background_tasks:
|
|
self._clock.looping_call(
|
|
self._prune_old_outbound_device_pokes, 60 * 60 * 1000
|
|
)
|
|
|
|
async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int:
|
|
"""Retrieve number of all devices of given users.
|
|
Only returns number of devices that are not marked as hidden.
|
|
|
|
Args:
|
|
user_ids: The IDs of the users which owns devices
|
|
Returns:
|
|
Number of devices of this users.
|
|
"""
|
|
|
|
def count_devices_by_users_txn(
|
|
txn: LoggingTransaction, user_ids: List[str]
|
|
) -> int:
|
|
sql = """
|
|
SELECT count(*)
|
|
FROM devices
|
|
WHERE
|
|
hidden = '0' AND
|
|
"""
|
|
|
|
clause, args = make_in_list_sql_clause(
|
|
txn.database_engine, "user_id", user_ids
|
|
)
|
|
|
|
txn.execute(sql + clause, args)
|
|
return cast(Tuple[int], txn.fetchone())[0]
|
|
|
|
if not user_ids:
|
|
return 0
|
|
|
|
return await self.db_pool.runInteraction(
|
|
"count_devices_by_users", count_devices_by_users_txn, user_ids
|
|
)
|
|
|
|
async def get_device(
|
|
self, user_id: str, device_id: str
|
|
) -> Optional[Dict[str, Any]]:
|
|
"""Retrieve a device. Only returns devices that are not marked as
|
|
hidden.
|
|
|
|
Args:
|
|
user_id: The ID of the user which owns the device
|
|
device_id: The ID of the device to retrieve
|
|
Returns:
|
|
A dict containing the device information, or `None` if the device does not
|
|
exist.
|
|
"""
|
|
return await self.db_pool.simple_select_one(
|
|
table="devices",
|
|
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
|
|
retcols=("user_id", "device_id", "display_name"),
|
|
desc="get_device",
|
|
allow_none=True,
|
|
)
|
|
|
|
async def get_device_opt(
|
|
self, user_id: str, device_id: str
|
|
) -> Optional[Dict[str, Any]]:
|
|
"""Retrieve a device. Only returns devices that are not marked as
|
|
hidden.
|
|
|
|
Args:
|
|
user_id: The ID of the user which owns the device
|
|
device_id: The ID of the device to retrieve
|
|
Returns:
|
|
A dict containing the device information, or None if the device does not exist.
|
|
"""
|
|
return await self.db_pool.simple_select_one(
|
|
table="devices",
|
|
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
|
|
retcols=("user_id", "device_id", "display_name"),
|
|
desc="get_device",
|
|
allow_none=True,
|
|
)
|
|
|
|
async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
|
|
"""Retrieve all of a user's registered devices. Only returns devices
|
|
that are not marked as hidden.
|
|
|
|
Args:
|
|
user_id:
|
|
Returns:
|
|
A mapping from device_id to a dict containing "device_id", "user_id"
|
|
and "display_name" for each device.
|
|
"""
|
|
devices = await self.db_pool.simple_select_list(
|
|
table="devices",
|
|
keyvalues={"user_id": user_id, "hidden": False},
|
|
retcols=("user_id", "device_id", "display_name"),
|
|
desc="get_devices_by_user",
|
|
)
|
|
|
|
return {d["device_id"]: d for d in devices}
|
|
|
|
async def get_devices_by_auth_provider_session_id(
|
|
self, auth_provider_id: str, auth_provider_session_id: str
|
|
) -> List[Dict[str, Any]]:
|
|
"""Retrieve the list of devices associated with a SSO IdP session ID.
|
|
|
|
Args:
|
|
auth_provider_id: The SSO IdP ID as defined in the server config
|
|
auth_provider_session_id: The session ID within the IdP
|
|
Returns:
|
|
A list of dicts containing the device_id and the user_id of each device
|
|
"""
|
|
return await self.db_pool.simple_select_list(
|
|
table="device_auth_providers",
|
|
keyvalues={
|
|
"auth_provider_id": auth_provider_id,
|
|
"auth_provider_session_id": auth_provider_session_id,
|
|
},
|
|
retcols=("user_id", "device_id"),
|
|
desc="get_devices_by_auth_provider_session_id",
|
|
)
|
|
|
|
@trace
|
|
async def get_device_updates_by_remote(
|
|
self, destination: str, from_stream_id: int, limit: int
|
|
) -> Tuple[int, List[Tuple[str, JsonDict]]]:
|
|
"""Get a stream of device updates to send to the given remote server.
|
|
|
|
Args:
|
|
destination: The host the device updates are intended for
|
|
from_stream_id: The minimum stream_id to filter updates by, exclusive
|
|
limit: Maximum number of device updates to return
|
|
|
|
Returns:
|
|
- The current stream id (i.e. the stream id of the last update included
|
|
in the response); and
|
|
- The list of updates, where each update is a pair of EDU type and
|
|
EDU contents.
|
|
"""
|
|
now_stream_id = self.get_device_stream_token()
|
|
|
|
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
|
|
destination, int(from_stream_id)
|
|
)
|
|
if not has_changed:
|
|
return now_stream_id, []
|
|
|
|
updates = await self.db_pool.runInteraction(
|
|
"get_device_updates_by_remote",
|
|
self._get_device_updates_by_remote_txn,
|
|
destination,
|
|
from_stream_id,
|
|
now_stream_id,
|
|
limit,
|
|
)
|
|
|
|
# We need to ensure `updates` doesn't grow too big.
|
|
# Currently: `len(updates) <= limit`.
|
|
|
|
# Return an empty list if there are no updates
|
|
if not updates:
|
|
return now_stream_id, []
|
|
|
|
if issue_8631_logger.isEnabledFor(logging.DEBUG):
|
|
data = {(user, device): stream_id for user, device, stream_id, _ in updates}
|
|
issue_8631_logger.debug(
|
|
"device updates need to be sent to %s: %s", destination, data
|
|
)
|
|
|
|
# get the cross-signing keys of the users in the list, so that we can
|
|
# determine which of the device changes were cross-signing keys
|
|
users = {r[0] for r in updates}
|
|
master_key_by_user = {}
|
|
self_signing_key_by_user = {}
|
|
for user in users:
|
|
cross_signing_key = await self.get_e2e_cross_signing_key(user, "master")
|
|
if cross_signing_key:
|
|
key_id, verify_key = get_verify_key_from_cross_signing_key(
|
|
cross_signing_key
|
|
)
|
|
# verify_key is a VerifyKey from signedjson, which uses
|
|
# .version to denote the portion of the key ID after the
|
|
# algorithm and colon, which is the device ID
|
|
master_key_by_user[user] = {
|
|
"key_info": cross_signing_key,
|
|
"device_id": verify_key.version,
|
|
}
|
|
|
|
cross_signing_key = await self.get_e2e_cross_signing_key(
|
|
user, "self_signing"
|
|
)
|
|
if cross_signing_key:
|
|
key_id, verify_key = get_verify_key_from_cross_signing_key(
|
|
cross_signing_key
|
|
)
|
|
self_signing_key_by_user[user] = {
|
|
"key_info": cross_signing_key,
|
|
"device_id": verify_key.version,
|
|
}
|
|
|
|
# Perform the equivalent of a GROUP BY
|
|
#
|
|
# Iterate through the updates list and copy non-duplicate
|
|
# (user_id, device_id) entries into a map, with the value being
|
|
# the max stream_id across each set of duplicate entries
|
|
#
|
|
# maps (user_id, device_id) -> (stream_id, opentracing_context)
|
|
#
|
|
# opentracing_context contains the opentracing metadata for the request
|
|
# that created the poke
|
|
#
|
|
# The most recent request's opentracing_context is used as the
|
|
# context which created the Edu.
|
|
|
|
# This is the stream ID that we will return for the consumer to resume
|
|
# following this stream later.
|
|
last_processed_stream_id = from_stream_id
|
|
|
|
# A map of (user ID, device ID) to (stream ID, context).
|
|
query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]] = {}
|
|
cross_signing_keys_by_user: Dict[str, Dict[str, object]] = {}
|
|
for user_id, device_id, update_stream_id, update_context in updates:
|
|
# Calculate the remaining length budget.
|
|
# Note that, for now, each entry in `cross_signing_keys_by_user`
|
|
# gives rise to two device updates in the result, so those cost twice
|
|
# as much (and are the whole reason we need to separately calculate
|
|
# the budget; we know len(updates) <= limit otherwise!)
|
|
# N.B. len() on dicts is cheap since they store their size.
|
|
remaining_length_budget = limit - (
|
|
len(query_map) + 2 * len(cross_signing_keys_by_user)
|
|
)
|
|
assert remaining_length_budget >= 0
|
|
|
|
is_master_key_update = (
|
|
user_id in master_key_by_user
|
|
and device_id == master_key_by_user[user_id]["device_id"]
|
|
)
|
|
is_self_signing_key_update = (
|
|
user_id in self_signing_key_by_user
|
|
and device_id == self_signing_key_by_user[user_id]["device_id"]
|
|
)
|
|
|
|
is_cross_signing_key_update = (
|
|
is_master_key_update or is_self_signing_key_update
|
|
)
|
|
|
|
if (
|
|
is_cross_signing_key_update
|
|
and user_id not in cross_signing_keys_by_user
|
|
):
|
|
# This will give rise to 2 device updates.
|
|
# If we don't have the budget, stop here!
|
|
if remaining_length_budget < 2:
|
|
break
|
|
|
|
if is_master_key_update:
|
|
result = cross_signing_keys_by_user.setdefault(user_id, {})
|
|
result["master_key"] = master_key_by_user[user_id]["key_info"]
|
|
elif is_self_signing_key_update:
|
|
result = cross_signing_keys_by_user.setdefault(user_id, {})
|
|
result["self_signing_key"] = self_signing_key_by_user[user_id][
|
|
"key_info"
|
|
]
|
|
else:
|
|
key = (user_id, device_id)
|
|
|
|
if key not in query_map and remaining_length_budget < 1:
|
|
# We don't have space for a new entry
|
|
break
|
|
|
|
previous_update_stream_id, _ = query_map.get(key, (0, None))
|
|
|
|
if update_stream_id > previous_update_stream_id:
|
|
# FIXME If this overwrites an older update, this discards the
|
|
# previous OpenTracing context.
|
|
# It might make it harder to track down issues using OpenTracing.
|
|
# If there's a good reason why it doesn't matter, a comment here
|
|
# about that would not hurt.
|
|
query_map[key] = (update_stream_id, update_context)
|
|
|
|
# As this update has been added to the response, advance the stream
|
|
# position.
|
|
last_processed_stream_id = update_stream_id
|
|
|
|
# In the worst case scenario, each update is for a distinct user and is
|
|
# added either to the query_map or to cross_signing_keys_by_user,
|
|
# but not both:
|
|
# len(query_map) + len(cross_signing_keys_by_user) <= len(updates) here,
|
|
# so len(query_map) + len(cross_signing_keys_by_user) <= limit.
|
|
|
|
results = await self._get_device_update_edus_by_remote(
|
|
destination, from_stream_id, query_map
|
|
)
|
|
|
|
# len(results) <= len(query_map) here,
|
|
# so len(results) + len(cross_signing_keys_by_user) <= limit.
|
|
|
|
# Add the updated cross-signing keys to the results list
|
|
for user_id, result in cross_signing_keys_by_user.items():
|
|
result["user_id"] = user_id
|
|
results.append((EduTypes.SIGNING_KEY_UPDATE, result))
|
|
# also send the unstable version
|
|
# FIXME: remove this when enough servers have upgraded
|
|
# and remove the length budgeting above.
|
|
results.append(("org.matrix.signing_key_update", result))
|
|
|
|
if issue_8631_logger.isEnabledFor(logging.DEBUG):
|
|
for (user_id, edu) in results:
|
|
issue_8631_logger.debug(
|
|
"device update to %s for %s from %s to %s: %s",
|
|
destination,
|
|
user_id,
|
|
from_stream_id,
|
|
last_processed_stream_id,
|
|
edu,
|
|
)
|
|
|
|
return last_processed_stream_id, results
|
|
|
|
def _get_device_updates_by_remote_txn(
|
|
self,
|
|
txn: LoggingTransaction,
|
|
destination: str,
|
|
from_stream_id: int,
|
|
now_stream_id: int,
|
|
limit: int,
|
|
) -> List[Tuple[str, str, int, Optional[str]]]:
|
|
"""Return device update information for a given remote destination
|
|
|
|
Args:
|
|
txn: The transaction to execute
|
|
destination: The host the device updates are intended for
|
|
from_stream_id: The minimum stream_id to filter updates by, exclusive
|
|
now_stream_id: The maximum stream_id to filter updates by, inclusive
|
|
limit: Maximum number of device updates to return
|
|
|
|
Returns:
|
|
List: List of device update tuples:
|
|
- user_id
|
|
- device_id
|
|
- stream_id
|
|
- opentracing_context
|
|
"""
|
|
# get the list of device updates that need to be sent
|
|
sql = """
|
|
SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
|
|
WHERE destination = ? AND ? < stream_id AND stream_id <= ?
|
|
ORDER BY stream_id
|
|
LIMIT ?
|
|
"""
|
|
txn.execute(sql, (destination, from_stream_id, now_stream_id, limit))
|
|
|
|
return cast(List[Tuple[str, str, int, Optional[str]]], txn.fetchall())
|
|
|
|
async def _get_device_update_edus_by_remote(
|
|
self,
|
|
destination: str,
|
|
from_stream_id: int,
|
|
query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]],
|
|
) -> List[Tuple[str, dict]]:
|
|
"""Returns a list of device update EDUs as well as E2EE keys
|
|
|
|
Args:
|
|
destination: The host the device updates are intended for
|
|
from_stream_id: The minimum stream_id to filter updates by, exclusive
|
|
query_map: Dictionary mapping (user_id, device_id) to
|
|
(update stream_id, the relevant json-encoded opentracing context)
|
|
|
|
Returns:
|
|
List of objects representing a device update EDU.
|
|
|
|
Postconditions:
|
|
The returned list has a length not exceeding that of the query_map:
|
|
len(result) <= len(query_map)
|
|
"""
|
|
devices = (
|
|
await self.get_e2e_device_keys_and_signatures(
|
|
# Because these are (user_id, device_id) tuples with all
|
|
# device_ids not being None, the returned list's length will not
|
|
# exceed that of query_map.
|
|
query_map.keys(),
|
|
include_all_devices=True,
|
|
include_deleted_devices=True,
|
|
)
|
|
if query_map
|
|
else {}
|
|
)
|
|
|
|
results = []
|
|
for user_id, user_devices in devices.items():
|
|
# The prev_id for the first row is always the last row before
|
|
# `from_stream_id`
|
|
prev_id = await self._get_last_device_update_for_remote_user(
|
|
destination, user_id, from_stream_id
|
|
)
|
|
|
|
# make sure we go through the devices in stream order
|
|
device_ids = sorted(
|
|
user_devices.keys(),
|
|
key=lambda i: query_map[(user_id, i)][0],
|
|
)
|
|
|
|
for device_id in device_ids:
|
|
device = user_devices[device_id]
|
|
stream_id, opentracing_context = query_map[(user_id, device_id)]
|
|
result = {
|
|
"user_id": user_id,
|
|
"device_id": device_id,
|
|
"prev_id": [prev_id] if prev_id else [],
|
|
"stream_id": stream_id,
|
|
"org.matrix.opentracing_context": opentracing_context,
|
|
}
|
|
|
|
prev_id = stream_id
|
|
|
|
if device is not None:
|
|
keys = device.keys
|
|
if keys:
|
|
result["keys"] = keys
|
|
|
|
device_display_name = device.display_name
|
|
if device_display_name:
|
|
result["device_display_name"] = device_display_name
|
|
else:
|
|
result["deleted"] = True
|
|
|
|
results.append((EduTypes.DEVICE_LIST_UPDATE, result))
|
|
|
|
return results
|
|
|
|
async def _get_last_device_update_for_remote_user(
|
|
self, destination: str, user_id: str, from_stream_id: int
|
|
) -> int:
|
|
def f(txn: LoggingTransaction) -> int:
|
|
prev_sent_id_sql = """
|
|
SELECT coalesce(max(stream_id), 0) as stream_id
|
|
FROM device_lists_outbound_last_success
|
|
WHERE destination = ? AND user_id = ? AND stream_id <= ?
|
|
"""
|
|
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
|
|
rows = txn.fetchall()
|
|
return rows[0][0]
|
|
|
|
return await self.db_pool.runInteraction(
|
|
"get_last_device_update_for_remote_user", f
|
|
)
|
|
|
|
async def mark_as_sent_devices_by_remote(
|
|
self, destination: str, stream_id: int
|
|
) -> None:
|
|
"""Mark that updates have successfully been sent to the destination."""
|
|
await self.db_pool.runInteraction(
|
|
"mark_as_sent_devices_by_remote",
|
|
self._mark_as_sent_devices_by_remote_txn,
|
|
destination,
|
|
stream_id,
|
|
)
|
|
|
|
def _mark_as_sent_devices_by_remote_txn(
|
|
self, txn: LoggingTransaction, destination: str, stream_id: int
|
|
) -> None:
|
|
# We update the device_lists_outbound_last_success with the successfully
|
|
# poked users.
|
|
sql = """
|
|
SELECT user_id, coalesce(max(o.stream_id), 0)
|
|
FROM device_lists_outbound_pokes as o
|
|
WHERE destination = ? AND o.stream_id <= ?
|
|
GROUP BY user_id
|
|
"""
|
|
txn.execute(sql, (destination, stream_id))
|
|
rows = txn.fetchall()
|
|
|
|
self.db_pool.simple_upsert_many_txn(
|
|
txn=txn,
|
|
table="device_lists_outbound_last_success",
|
|
key_names=("destination", "user_id"),
|
|
key_values=[(destination, user_id) for user_id, _ in rows],
|
|
value_names=("stream_id",),
|
|
value_values=((stream_id,) for _, stream_id in rows),
|
|
)
|
|
|
|
# Delete all sent outbound pokes
|
|
sql = """
|
|
DELETE FROM device_lists_outbound_pokes
|
|
WHERE destination = ? AND stream_id <= ?
|
|
"""
|
|
txn.execute(sql, (destination, stream_id))
|
|
|
|
async def add_user_signature_change_to_streams(
|
|
self, from_user_id: str, user_ids: List[str]
|
|
) -> int:
|
|
"""Persist that a user has made new signatures
|
|
|
|
Args:
|
|
from_user_id: the user who made the signatures
|
|
user_ids: the users who were signed
|
|
|
|
Returns:
|
|
The new stream ID.
|
|
"""
|
|
|
|
# TODO: this looks like it's _writing_. Should this be on DeviceStore rather
|
|
# than DeviceWorkerStore?
|
|
async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
|
|
await self.db_pool.runInteraction(
|
|
"add_user_sig_change_to_streams",
|
|
self._add_user_signature_change_txn,
|
|
from_user_id,
|
|
user_ids,
|
|
stream_id,
|
|
)
|
|
return stream_id
|
|
|
|
def _add_user_signature_change_txn(
|
|
self,
|
|
txn: LoggingTransaction,
|
|
from_user_id: str,
|
|
user_ids: List[str],
|
|
stream_id: int,
|
|
) -> None:
|
|
txn.call_after(
|
|
self._user_signature_stream_cache.entity_has_changed,
|
|
from_user_id,
|
|
stream_id,
|
|
)
|
|
self.db_pool.simple_insert_txn(
|
|
txn,
|
|
"user_signature_stream",
|
|
values={
|
|
"stream_id": stream_id,
|
|
"from_user_id": from_user_id,
|
|
"user_ids": json_encoder.encode(user_ids),
|
|
},
|
|
)
|
|
|
|
@abc.abstractmethod
|
|
def get_device_stream_token(self) -> int:
|
|
"""Get the current stream id from the _device_list_id_gen"""
|
|
...
|
|
|
|
@trace
|
|
async def get_user_devices_from_cache(
|
|
self, query_list: List[Tuple[str, str]]
|
|
) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
|
|
"""Get the devices (and keys if any) for remote users from the cache.
|
|
|
|
Args:
|
|
query_list: List of (user_id, device_ids), if device_ids is
|
|
falsey then return all device ids for that user.
|
|
|
|
Returns:
|
|
A tuple of (user_ids_not_in_cache, results_map), where
|
|
user_ids_not_in_cache is a set of user_ids and results_map is a
|
|
mapping of user_id -> device_id -> device_info.
|
|
"""
|
|
user_ids = {user_id for user_id, _ in query_list}
|
|
user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids))
|
|
|
|
# We go and check if any of the users need to have their device lists
|
|
# resynced. If they do then we remove them from the cached list.
|
|
users_needing_resync = await self.get_user_ids_requiring_device_list_resync(
|
|
user_ids
|
|
)
|
|
user_ids_in_cache = {
|
|
user_id for user_id, stream_id in user_map.items() if stream_id
|
|
} - users_needing_resync
|
|
user_ids_not_in_cache = user_ids - user_ids_in_cache
|
|
|
|
results: Dict[str, Dict[str, JsonDict]] = {}
|
|
for user_id, device_id in query_list:
|
|
if user_id not in user_ids_in_cache:
|
|
continue
|
|
|
|
if device_id:
|
|
device = await self._get_cached_user_device(user_id, device_id)
|
|
results.setdefault(user_id, {})[device_id] = device
|
|
else:
|
|
results[user_id] = await self.get_cached_devices_for_user(user_id)
|
|
|
|
set_tag("in_cache", results)
|
|
set_tag("not_in_cache", user_ids_not_in_cache)
|
|
|
|
return user_ids_not_in_cache, results
|
|
|
|
@cached(num_args=2, tree=True)
|
|
async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict:
|
|
content = await self.db_pool.simple_select_one_onecol(
|
|
table="device_lists_remote_cache",
|
|
keyvalues={"user_id": user_id, "device_id": device_id},
|
|
retcol="content",
|
|
desc="_get_cached_user_device",
|
|
)
|
|
return db_to_json(content)
|
|
|
|
@cached()
|
|
async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]:
|
|
devices = await self.db_pool.simple_select_list(
|
|
table="device_lists_remote_cache",
|
|
keyvalues={"user_id": user_id},
|
|
retcols=("device_id", "content"),
|
|
desc="get_cached_devices_for_user",
|
|
)
|
|
return {
|
|
device["device_id"]: db_to_json(device["content"]) for device in devices
|
|
}
|
|
|
|
def get_cached_device_list_changes(
|
|
self,
|
|
from_key: int,
|
|
) -> Optional[List[str]]:
|
|
"""Get set of users whose devices have changed since `from_key`, or None
|
|
if that information is not in our cache.
|
|
"""
|
|
|
|
return self._device_list_stream_cache.get_all_entities_changed(from_key)
|
|
|
|
async def get_users_whose_devices_changed(
|
|
self,
|
|
from_key: int,
|
|
user_ids: Optional[Collection[str]] = None,
|
|
to_key: Optional[int] = None,
|
|
) -> Set[str]:
|
|
"""Get set of users whose devices have changed since `from_key` that
|
|
are in the given list of user_ids.
|
|
|
|
Args:
|
|
from_key: The minimum device lists stream token to query device list changes for,
|
|
exclusive.
|
|
user_ids: If provided, only check if these users have changed their device lists.
|
|
Otherwise changes from all users are returned.
|
|
to_key: The maximum device lists stream token to query device list changes for,
|
|
inclusive.
|
|
|
|
Returns:
|
|
The set of user_ids whose devices have changed since `from_key` (exclusive)
|
|
until `to_key` (inclusive).
|
|
"""
|
|
# Get set of users who *may* have changed. Users not in the returned
|
|
# list have definitely not changed.
|
|
user_ids_to_check: Optional[Collection[str]]
|
|
if user_ids is None:
|
|
# Get set of all users that have had device list changes since 'from_key'
|
|
user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
|
|
from_key
|
|
)
|
|
else:
|
|
# The same as above, but filter results to only those users in 'user_ids'
|
|
user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
|
|
user_ids, from_key
|
|
)
|
|
|
|
if not user_ids_to_check:
|
|
return set()
|
|
|
|
def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
|
|
changes: Set[str] = set()
|
|
|
|
stream_id_where_clause = "stream_id > ?"
|
|
sql_args = [from_key]
|
|
|
|
if to_key:
|
|
stream_id_where_clause += " AND stream_id <= ?"
|
|
sql_args.append(to_key)
|
|
|
|
sql = f"""
|
|
SELECT DISTINCT user_id FROM device_lists_stream
|
|
WHERE {stream_id_where_clause}
|
|
AND
|
|
"""
|
|
|
|
# Query device changes with a batch of users at a time
|
|
# Assertion for mypy's benefit; see also
|
|
# https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions
|
|
assert user_ids_to_check is not None
|
|
for chunk in batch_iter(user_ids_to_check, 100):
|
|
clause, args = make_in_list_sql_clause(
|
|
txn.database_engine, "user_id", chunk
|
|
)
|
|
txn.execute(sql + clause, sql_args + args)
|
|
changes.update(user_id for user_id, in txn)
|
|
|
|
return changes
|
|
|
|
return await self.db_pool.runInteraction(
|
|
"get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
|
|
)
|
|
|
|
async def get_users_whose_signatures_changed(
|
|
self, user_id: str, from_key: int
|
|
) -> Set[str]:
|
|
"""Get the users who have new cross-signing signatures made by `user_id` since
|
|
`from_key`.
|
|
|
|
Args:
|
|
user_id: the user who made the signatures
|
|
from_key: The device lists stream token
|
|
|
|
Returns:
|
|
A set of user IDs with updated signatures.
|
|
"""
|
|
|
|
if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
|
|
sql = """
|
|
SELECT DISTINCT user_ids FROM user_signature_stream
|
|
WHERE from_user_id = ? AND stream_id > ?
|
|
"""
|
|
rows = await self.db_pool.execute(
|
|
"get_users_whose_signatures_changed", None, sql, user_id, from_key
|
|
)
|
|
return {user for row in rows for user in db_to_json(row[0])}
|
|
else:
|
|
return set()
|
|
|
|
async def get_all_device_list_changes_for_remotes(
|
|
self, instance_name: str, last_id: int, current_id: int, limit: int
|
|
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
|
"""Get updates for device lists replication stream.
|
|
|
|
Args:
|
|
instance_name: The writer we want to fetch updates from. Unused
|
|
here since there is only ever one writer.
|
|
last_id: The token to fetch updates from. Exclusive.
|
|
current_id: The token to fetch updates up to. Inclusive.
|
|
limit: The requested limit for the number of rows to return. The
|
|
function may return more or fewer rows.
|
|
|
|
Returns:
|
|
A tuple consisting of: the updates, a token to use to fetch
|
|
subsequent updates, and whether we returned fewer rows than exists
|
|
between the requested tokens due to the limit.
|
|
|
|
The token returned can be used in a subsequent call to this
|
|
function to get further updates.
|
|
|
|
The updates are a list of 2-tuples of stream ID and the row data
|
|
"""
|
|
|
|
if last_id == current_id:
|
|
return [], current_id, False
|
|
|
|
def _get_all_device_list_changes_for_remotes(
|
|
txn: Cursor,
|
|
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
|
# This query Does The Right Thing where it'll correctly apply the
|
|
# bounds to the inner queries.
|
|
sql = """
|
|
SELECT stream_id, entity FROM (
|
|
SELECT stream_id, user_id AS entity FROM device_lists_stream
|
|
UNION ALL
|
|
SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
|
|
) AS e
|
|
WHERE ? < stream_id AND stream_id <= ?
|
|
ORDER BY stream_id ASC
|
|
LIMIT ?
|
|
"""
|
|
|
|
txn.execute(sql, (last_id, current_id, limit))
|
|
updates = [(row[0], row[1:]) for row in txn]
|
|
limited = False
|
|
upto_token = current_id
|
|
if len(updates) >= limit:
|
|
upto_token = updates[-1][0]
|
|
limited = True
|
|
|
|
return updates, upto_token, limited
|
|
|
|
return await self.db_pool.runInteraction(
|
|
"get_all_device_list_changes_for_remotes",
|
|
_get_all_device_list_changes_for_remotes,
|
|
)
|
|
|
|
@cached(max_entries=10000)
|
|
async def get_device_list_last_stream_id_for_remote(
|
|
self, user_id: str
|
|
) -> Optional[str]:
|
|
"""Get the last stream_id we got for a user. May be None if we haven't
|
|
got any information for them.
|
|
"""
|
|
return await self.db_pool.simple_select_one_onecol(
|
|
table="device_lists_remote_extremeties",
|
|
keyvalues={"user_id": user_id},
|
|
retcol="stream_id",
|
|
desc="get_device_list_last_stream_id_for_remote",
|
|
allow_none=True,
|
|
)
|
|
|
|
@cachedList(
|
|
cached_method_name="get_device_list_last_stream_id_for_remote",
|
|
list_name="user_ids",
|
|
)
|
|
async def get_device_list_last_stream_id_for_remotes(
|
|
self, user_ids: Iterable[str]
|
|
) -> Dict[str, Optional[str]]:
|
|
rows = await self.db_pool.simple_select_many_batch(
|
|
table="device_lists_remote_extremeties",
|
|
column="user_id",
|
|
iterable=user_ids,
|
|
retcols=("user_id", "stream_id"),
|
|
desc="get_device_list_last_stream_id_for_remotes",
|
|
)
|
|
|
|
results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids}
|
|
results.update({row["user_id"]: row["stream_id"] for row in rows})
|
|
|
|
return results
|
|
|
|
async def get_user_ids_requiring_device_list_resync(
|
|
self,
|
|
user_ids: Optional[Collection[str]] = None,
|
|
) -> Set[str]:
|
|
"""Given a list of remote users return the list of users that we
|
|
should resync the device lists for. If None is given instead of a list,
|
|
return every user that we should resync the device lists for.
|
|
|
|
Returns:
|
|
The IDs of users whose device lists need resync.
|
|
"""
|
|
if user_ids:
|
|
rows = await self.db_pool.simple_select_many_batch(
|
|
table="device_lists_remote_resync",
|
|
column="user_id",
|
|
iterable=user_ids,
|
|
retcols=("user_id",),
|
|
desc="get_user_ids_requiring_device_list_resync_with_iterable",
|
|
)
|
|
else:
|
|
rows = await self.db_pool.simple_select_list(
|
|
table="device_lists_remote_resync",
|
|
keyvalues=None,
|
|
retcols=("user_id",),
|
|
desc="get_user_ids_requiring_device_list_resync",
|
|
)
|
|
|
|
return {row["user_id"] for row in rows}
|
|
|
|
async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None:
    """Flag our cached copy of a remote user's device list as out of date.

    Upserts a row for `user_id` into `device_lists_remote_resync`. The
    current time is passed as `insertion_values`, i.e. (presumably) only
    written when a new row is created.
    """
    now_ms = self._clock.time_msec()
    await self.db_pool.simple_upsert(
        table="device_lists_remote_resync",
        keyvalues={"user_id": user_id},
        values={},
        insertion_values={"added_ts": now_ms},
        desc="mark_remote_user_device_cache_as_stale",
    )
|
|
|
|
async def mark_remote_user_device_cache_as_valid(self, user_id: str) -> None:
    """Clear the "needs resync" flag for a remote user after a resync.

    Deletes the user's row from `device_lists_remote_resync`.
    """
    resync_key = {"user_id": user_id}
    await self.db_pool.simple_delete(
        table="device_lists_remote_resync",
        keyvalues=resync_key,
        desc="mark_remote_user_device_cache_as_valid",
    )
|
|
|
|
async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
    """Stop tracking the device list of the remote user `user_id`.

    Removes the user's row from `device_lists_remote_extremeties` and, in the
    same transaction, invalidates the user's
    `get_device_list_last_stream_id_for_remote` cache entry via
    `_invalidate_cache_and_stream`.
    """

    def _unsubscribe_txn(txn: LoggingTransaction) -> None:
        self.db_pool.simple_delete_txn(
            txn,
            table="device_lists_remote_extremeties",
            keyvalues={"user_id": user_id},
        )
        self._invalidate_cache_and_stream(
            txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
        )

    await self.db_pool.runInteraction(
        "mark_remote_user_device_list_as_unsubscribed", _unsubscribe_txn
    )
|
|
|
|
async def get_dehydrated_device(
    self, user_id: str
) -> Optional[Tuple[str, JsonDict]]:
    """Retrieve the information for a dehydrated device.

    Args:
        user_id: the user whose dehydrated device we are looking for

    Returns:
        a tuple whose first item is the device ID, and the second item is
        the dehydrated device information (JSON-decoded from the stored
        `device_data` column), or None if the user has no dehydrated device.
    """
    # FIXME: make sure device ID still exists in devices table
    row = await self.db_pool.simple_select_one(
        table="dehydrated_devices",
        keyvalues={"user_id": user_id},
        retcols=["device_id", "device_data"],
        allow_none=True,
        # Label the query like every other DB call in this store, so that it
        # is attributable in logs/metrics instead of using whatever default
        # label `simple_select_one` applies.
        desc="get_dehydrated_device",
    )
    return (
        (row["device_id"], json_decoder.decode(row["device_data"])) if row else None
    )
|
|
|
|
def _store_dehydrated_device_txn(
    self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str
) -> Optional[str]:
    """Replace the dehydrated device stored for `user_id`, within `txn`.

    Args:
        txn: the transaction to run in.
        user_id: owner of the dehydrated device.
        device_id: ID of the new dehydrated device.
        device_data: the device data, already encoded as a string.

    Returns:
        The device ID that was stored for this user before the upsert, or
        None if there was no existing row.
    """
    # Read the previous value first: the upsert below overwrites it.
    previous_device_id = self.db_pool.simple_select_one_onecol_txn(
        txn,
        table="dehydrated_devices",
        keyvalues={"user_id": user_id},
        retcol="device_id",
        allow_none=True,
    )

    new_values = {"device_id": device_id, "device_data": device_data}
    self.db_pool.simple_upsert_txn(
        txn,
        table="dehydrated_devices",
        keyvalues={"user_id": user_id},
        values=new_values,
    )

    return previous_device_id
|
|
|
|
async def store_dehydrated_device(
    self, user_id: str, device_id: str, device_data: JsonDict
) -> Optional[str]:
    """Store (replacing any existing one) the dehydrated device for a user.

    Args:
        user_id: the user that we are storing the device for
        device_id: the ID of the dehydrated device
        device_data: the dehydrated device information; JSON-encoded before
            being written

    Returns:
        device id of the user's previous dehydrated device, if any
    """
    serialized = json_encoder.encode(device_data)
    return await self.db_pool.runInteraction(
        "store_dehydrated_device_txn",
        self._store_dehydrated_device_txn,
        user_id,
        device_id,
        serialized,
    )
|
|
|
|
async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool:
    """Remove a dehydrated device.

    Args:
        user_id: the user that the dehydrated device belongs to
        device_id: the ID of the dehydrated device

    Returns:
        True if at least one row was deleted, i.e. `device_id` was the
        user's dehydrated device; False otherwise.
    """
    deleted_rows = await self.db_pool.simple_delete(
        "dehydrated_devices",
        {"user_id": user_id, "device_id": device_id},
        desc="remove_dehydrated_device",
    )
    return deleted_rows > 0
|
|
|
|
@wrap_as_background_process("prune_old_outbound_device_pokes")
async def _prune_old_outbound_device_pokes(
    self, prune_age: int = 24 * 60 * 60 * 1000
) -> None:
    """Delete old entries out of the device_lists_outbound_pokes to ensure
    that we don't fill up due to dead servers.

    Normally, we try to send device updates as a delta since a previous known point:
    this is done by setting the prev_id in the m.device_list_update EDU. However,
    for that to work, we have to have a complete record of each change to
    each device, which can add up to quite a lot of data.

    An alternative mechanism is that, if the remote server sees that it has missed
    an entry in the stream_id sequence for a given user, it will request a full
    list of that user's devices. Hence, we can reduce the amount of data we have to
    store (and transmit in some future transaction), by clearing almost everything
    for a given destination out of the database, and having the remote server
    resync.

    All we need to do is make sure we keep at least one row for each
    (user, destination) pair, to remind us to send a m.device_list_update EDU for
    that user when the destination comes back. It doesn't matter which device
    we keep.

    Args:
        prune_age: age in milliseconds (default: 24 hours). Only
            (user, destination) pairs with at least one poke whose `ts` is
            older than this are pruned.
    """
    # Cutoff timestamp in ms; pokes with `ts` earlier than this count as old.
    yesterday = self._clock.time_msec() - prune_age

    def _prune_txn(txn: LoggingTransaction) -> None:
        # look for (user, destination) pairs which have an update older than
        # the cutoff.
        #
        # For each pair, we also need to know the most recent stream_id, and
        # an arbitrary device_id at that stream_id.
        #
        # `count(*) > 1` skips pairs that already have a single row: there is
        # nothing to prune for them.
        select_sql = """
            SELECT
                dlop1.destination,
                dlop1.user_id,
                MAX(dlop1.stream_id) AS stream_id,
                (SELECT MIN(dlop2.device_id) AS device_id FROM
                    device_lists_outbound_pokes dlop2
                    WHERE dlop2.destination = dlop1.destination AND
                      dlop2.user_id=dlop1.user_id AND
                      dlop2.stream_id=MAX(dlop1.stream_id)
                )
            FROM device_lists_outbound_pokes dlop1
                GROUP BY destination, user_id
                HAVING min(ts) < ? AND count(*) > 1
            """

        txn.execute(select_sql, (yesterday,))
        # Each row: (destination, user_id, max stream_id, device_id to keep).
        rows = txn.fetchall()

        if not rows:
            return

        logger.info(
            "Pruning old outbound device list updates for %i users/destinations: %s",
            len(rows),
            shortstr((row[0], row[1]) for row in rows),
        )

        # we want to keep the update with the highest stream_id for each user.
        #
        # there might be more than one update (with different device_ids) with the
        # same stream_id, so we also delete all but one rows with the max stream id.
        delete_sql = """
            DELETE FROM device_lists_outbound_pokes
            WHERE destination = ? AND user_id = ? AND (
                stream_id < ? OR
                (stream_id = ? AND device_id != ?)
            )
        """
        # Total rows deleted across all pairs, for the log line below.
        count = 0
        for (destination, user_id, stream_id, device_id) in rows:
            txn.execute(
                delete_sql, (destination, user_id, stream_id, stream_id, device_id)
            )
            count += txn.rowcount

        # Since we've deleted unsent deltas, we need to remove the entry
        # of last successful sent so that the prev_ids are correctly set.
        sql = """
            DELETE FROM device_lists_outbound_last_success
            WHERE destination = ? AND user_id = ?
        """
        txn.execute_batch(sql, ((row[0], row[1]) for row in rows))

        logger.info("Pruned %d device list outbound pokes", count)

    await self.db_pool.runInteraction(
        "_prune_old_outbound_device_pokes",
        _prune_txn,
    )
|
|
|
|
async def get_local_devices_not_accessed_since(
    self, since_ms: int
) -> Dict[str, List[str]]:
    """Find local, non-hidden devices last seen before a given time.

    Args:
        since_ms: cutoff timestamp; devices whose `last_seen` is strictly
            earlier than this are returned.

    Returns:
        A map from user ID to the IDs of that user's matching devices. Only
        users with at least one matching device appear, and only local
        users (per `hs.is_mine_id`) are included.
    """

    def _fetch_stale_txn(
        txn: LoggingTransaction,
    ) -> List[Dict[str, str]]:
        query = """
            SELECT user_id, device_id
            FROM devices WHERE last_seen < ? AND hidden = FALSE
        """
        txn.execute(query, (since_ms,))
        return self.db_pool.cursor_to_dict(txn)

    candidates = await self.db_pool.runInteraction(
        "get_devices_not_accessed_since",
        _fetch_stale_txn,
    )

    by_user: Dict[str, List[str]] = {}
    for candidate in candidates:
        owner = candidate["user_id"]
        # Remote devices are never stale from our point of view.
        if not self.hs.is_mine_id(owner):
            continue
        by_user.setdefault(owner, []).append(candidate["device_id"])

    return by_user
|
|
|
|
@cached()
async def _get_min_device_lists_changes_in_room(self) -> int:
    """Return the lowest `stream_id` in `device_lists_changes_in_room`,
    or 0 if the table has no rows.

    NOTE(review): the result is memoised by `@cached()` and nothing in this
    part of the file invalidates it, so it reflects the table as first read.
    Confirm this is acceptable if rows are ever removed from the table.
    """
    return await self.db_pool.simple_select_one_onecol(
        desc="get_min_device_lists_changes_in_room",
        table="device_lists_changes_in_room",
        retcol="COALESCE(MIN(stream_id), 0)",
        keyvalues={},
    )
|
|
|
|
async def get_device_list_changes_in_rooms(
    self, room_ids: Collection[str], from_id: int
) -> Optional[Set[str]]:
    """Return the set of users whose devices have changed in the given rooms
    since the given stream ID.

    Returns None if the given stream ID is too old.

    Args:
        room_ids: the rooms to look in.
        from_id: exclusive lower bound; only changes with a stream ID strictly
            greater than this are returned.

    Returns:
        The IDs of users with a device change recorded in one of `room_ids`
        after `from_id`, or None if `from_id` is below the lowest stream ID in
        `device_lists_changes_in_room` (the answer could be incomplete).
    """

    if not room_ids:
        return set()

    min_stream_id = await self._get_min_device_lists_changes_in_room()

    if min_stream_id > from_id:
        return None

    # "Since" is exclusive: a caller positioned at `from_id` has presumably
    # already seen the change recorded at that stream ID, so an inclusive
    # `>=` would report it a second time.
    sql = """
        SELECT DISTINCT user_id FROM device_lists_changes_in_room
        WHERE {clause} AND stream_id > ?
    """

    def _get_device_list_changes_in_rooms_txn(
        txn: LoggingTransaction,
        clause: str,
        args: List[Any],
    ) -> Set[str]:
        txn.execute(sql.format(clause=clause), args)
        return {user_id for user_id, in txn}

    changes: Set[str] = set()
    # Query in chunks of rooms so the generated IN (...) list stays bounded.
    for chunk in batch_iter(room_ids, 1000):
        clause, args = make_in_list_sql_clause(
            self.database_engine, "room_id", chunk
        )
        args.append(from_id)

        changes |= await self.db_pool.runInteraction(
            "get_device_list_changes_in_rooms",
            _get_device_list_changes_in_rooms_txn,
            clause,
            args,
        )

    return changes
|
|
|
|
|
|
class DeviceBackgroundUpdateStore(SQLBaseStore):
    """Registers and implements background database updates for the
    device-list tables.

    Covers:
    - index creation on `device_lists_stream`, `device_lists_remote_cache`
      and `device_lists_remote_extremeties`;
    - dropping the superseded non-unique indexes;
    - removing duplicate rows from `device_lists_outbound_pokes`.
    """

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self.db_pool.updates.register_background_index_update(
            "device_lists_stream_idx",
            index_name="device_lists_stream_user_id",
            table="device_lists_stream",
            columns=["user_id", "device_id"],
        )

        # create a unique index on device_lists_remote_cache
        self.db_pool.updates.register_background_index_update(
            "device_lists_remote_cache_unique_idx",
            index_name="device_lists_remote_cache_unique_id",
            table="device_lists_remote_cache",
            columns=["user_id", "device_id"],
            unique=True,
        )

        # And one on device_lists_remote_extremeties
        self.db_pool.updates.register_background_index_update(
            "device_lists_remote_extremeties_unique_idx",
            index_name="device_lists_remote_extremeties_unique_idx",
            table="device_lists_remote_extremeties",
            columns=["user_id"],
            unique=True,
        )

        # once they complete, we can remove the old non-unique indexes.
        self.db_pool.updates.register_background_update_handler(
            DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES,
            self._drop_device_list_streams_non_unique_indexes,
        )

        # clear out duplicate device list outbound pokes
        self.db_pool.updates.register_background_update_handler(
            BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
            self._remove_duplicate_outbound_pokes,
        )

    async def _drop_device_list_streams_non_unique_indexes(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """Drop the old non-unique indexes on the remote device-list tables,
        then mark this background update as finished.

        `progress` and `batch_size` are not used; all work is done in one step.

        Returns:
            1
        """

        def f(conn: LoggingDatabaseConnection) -> None:
            txn = conn.cursor()
            # Always close the cursor, even if one of the DROPs raises, so it
            # is not leaked on the error path.
            try:
                txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
                txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
            finally:
                txn.close()

        await self.db_pool.runWithConnection(f)
        await self.db_pool.updates._end_background_update(
            DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES
        )
        return 1

    async def _remove_duplicate_outbound_pokes(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """Collapse duplicate rows in `device_lists_outbound_pokes`.

        Processes up to `batch_size` duplicated keys per call, in key order,
        resuming after `progress["last_row"]`. The update is ended once a
        batch finds no more duplicates.

        Returns:
            The number of duplicated keys handled in this batch.
        """
        # for some reason, we have accumulated duplicate entries in
        # device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less
        # efficient.
        #
        # For each duplicate, we delete all the existing rows and put one back.

        KEY_COLS = ["stream_id", "destination", "user_id", "device_id"]
        last_row = progress.get(
            "last_row",
            {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
        )

        def _txn(txn: LoggingTransaction) -> int:
            # Resume strictly after the last key handled by a previous batch.
            clause, args = make_tuple_comparison_clause(
                [(x, last_row[x]) for x in KEY_COLS]
            )
            sql = """
                SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
                FROM device_lists_outbound_pokes
                WHERE %s
                GROUP BY %s
                HAVING count(*) > 1
                ORDER BY %s
                LIMIT ?
                """ % (
                clause,  # WHERE
                ",".join(KEY_COLS),  # GROUP BY
                ",".join(KEY_COLS),  # ORDER BY
            )
            txn.execute(sql, args + [batch_size])
            rows = self.db_pool.cursor_to_dict(txn)

            row = None
            for row in rows:
                self.db_pool.simple_delete_txn(
                    txn,
                    "device_lists_outbound_pokes",
                    {x: row[x] for x in KEY_COLS},
                )

                # Re-insert a single row for the key, unsent, with the latest ts.
                row["sent"] = False
                self.db_pool.simple_insert_txn(
                    txn,
                    "device_lists_outbound_pokes",
                    row,
                )

            # `row` is the last key processed (rows are ordered by KEY_COLS).
            if row:
                self.db_pool.updates._background_update_progress_txn(
                    txn,
                    BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
                    {"last_row": row},
                )

            return len(rows)

        rows = await self.db_pool.runInteraction(
            BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn
        )

        if not rows:
            await self.db_pool.updates._end_background_update(
                BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES
            )

        return rows
|
|
|
|
|
|
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|
def __init__(
    self,
    database: DatabasePool,
    db_conn: LoggingDatabaseConnection,
    hs: "HomeServer",
):
    super().__init__(database, db_conn, hs)

    # Bounded LRU used as a set of (user_id, device_id) pairs: a key being
    # present (its value is always True) means the device is known to exist.
    # `store_device` populates it and `delete_devices` invalidates it.
    self.device_id_exists_cache: LruCache[
        Tuple[str, str], Literal[True]
    ] = LruCache(max_size=10000, cache_name="device_id_exists")
|
|
|
|
async def store_device(
    self,
    user_id: str,
    device_id: str,
    initial_device_display_name: Optional[str],
    auth_provider_id: Optional[str] = None,
    auth_provider_session_id: Optional[str] = None,
) -> bool:
    """Ensure the given device is known; add it to the store if not

    Args:
        user_id: id of user associated with the device
        device_id: id of device
        initial_device_display_name: initial displayname of the device.
            Ignored if device exists.
        auth_provider_id: The SSO IdP the user used, if any.
        auth_provider_session_id: The session ID (sid) got from a OIDC login.

    Returns:
        True if a new device row was inserted; False if a (non-hidden) device
        with that ID already existed for the user.

    Raises:
        StoreError: 400 if the device ID is in use by a hidden device; 500 if
            storing the device failed for any other reason (the underlying
            exception is chained as the cause).
    """
    key = (user_id, device_id)
    if self.device_id_exists_cache.get(key, None):
        return False

    try:
        inserted = await self.db_pool.simple_upsert(
            "devices",
            keyvalues={
                "user_id": user_id,
                "device_id": device_id,
            },
            values={},
            insertion_values={
                "display_name": initial_device_display_name,
                "hidden": False,
            },
            desc="store_device",
        )
        if not inserted:
            # if the device already exists, check if it's a real device, or
            # if the device ID is reserved by something else
            hidden = await self.db_pool.simple_select_one_onecol(
                "devices",
                keyvalues={"user_id": user_id, "device_id": device_id},
                retcol="hidden",
            )
            if hidden:
                raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)

        if auth_provider_id and auth_provider_session_id:
            await self.db_pool.simple_insert(
                "device_auth_providers",
                values={
                    "user_id": user_id,
                    "device_id": device_id,
                    "auth_provider_id": auth_provider_id,
                    "auth_provider_session_id": auth_provider_session_id,
                },
                desc="store_device_auth_provider",
            )

        self.device_id_exists_cache.set(key, True)
        return inserted
    except StoreError:
        raise
    except Exception as e:
        logger.error(
            "store_device with device_id=%s(%r) user_id=%s(%r)"
            " display_name=%s(%r) failed: %s",
            type(device_id).__name__,
            device_id,
            type(user_id).__name__,
            user_id,
            type(initial_device_display_name).__name__,
            initial_device_display_name,
            e,
        )
        # Chain the original exception so its traceback is not lost.
        raise StoreError(500, "Problem storing device.") from e
|
|
|
|
async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
    """Delete several devices belonging to one user.

    The `devices` rows are only removed where `hidden` is False. The user's
    `device_inbox` and `device_auth_providers` rows for those device IDs are
    removed unconditionally. Afterwards, the corresponding
    `device_id_exists_cache` entries are invalidated.

    Args:
        user_id: The ID of the user which owns the devices
        device_ids: The IDs of the devices to delete
    """

    def _delete_devices_txn(txn: LoggingTransaction) -> None:
        self.db_pool.simple_delete_many_txn(
            txn,
            table="devices",
            column="device_id",
            values=device_ids,
            keyvalues={"user_id": user_id, "hidden": False},
        )
        for dependent_table in ("device_inbox", "device_auth_providers"):
            self.db_pool.simple_delete_many_txn(
                txn,
                table=dependent_table,
                column="device_id",
                values=device_ids,
                keyvalues={"user_id": user_id},
            )

    await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)

    for removed_id in device_ids:
        self.device_id_exists_cache.invalidate((user_id, removed_id))
|
|
|
|
async def update_device(
    self, user_id: str, device_id: str, new_display_name: Optional[str] = None
) -> None:
    """Update a non-hidden device.

    Args:
        user_id: The ID of the user which owns the device
        device_id: The ID of the device to update
        new_display_name: new displayname for device; None to leave unchanged
            (in which case nothing is written)

    Raises:
        StoreError: if the device is not found
    """
    if new_display_name is None:
        # Nothing to change.
        return None

    await self.db_pool.simple_update_one(
        table="devices",
        keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
        updatevalues={"display_name": new_display_name},
        desc="update_device",
    )
|
|
|
|
async def update_remote_device_list_cache_entry(
    self, user_id: str, device_id: str, content: JsonDict, stream_id: str
) -> None:
    """Update one device in our cache of a remote user's device list, and
    record `stream_id` as the user's latest device list version.

    Note: assumes that we are the only thread that can be updating this user's
    device list.

    Args:
        user_id: User to update device list for
        device_id: ID of the device being updated
        content: new data on this device; if it has a truthy "deleted" key
            the device is removed from the cache instead
        stream_id: the version of the device list
    """
    txn_args = (user_id, device_id, content, stream_id)
    await self.db_pool.runInteraction(
        "update_remote_device_list_cache_entry",
        self._update_remote_device_list_cache_entry_txn,
        *txn_args,
    )
|
|
|
|
def _update_remote_device_list_cache_entry_txn(
    self,
    txn: LoggingTransaction,
    user_id: str,
    device_id: str,
    content: JsonDict,
    stream_id: str,
) -> None:
    """Delete, update or insert the cached row for this (user, device) pair,
    then store `stream_id` as the user's device list extremity.

    After the transaction, the per-device, per-user and last-stream-id
    caches for the user are invalidated.
    """
    device_key = (user_id, device_id)

    if not content.get("deleted"):
        self.db_pool.simple_upsert_txn(
            txn,
            table="device_lists_remote_cache",
            keyvalues={"user_id": user_id, "device_id": device_id},
            values={"content": json_encoder.encode(content)},
            # we don't need to lock, because we assume we are the only thread
            # updating this user's devices.
            lock=False,
        )
    else:
        self.db_pool.simple_delete_txn(
            txn,
            table="device_lists_remote_cache",
            keyvalues={"user_id": user_id, "device_id": device_id},
        )
        txn.call_after(self.device_id_exists_cache.invalidate, device_key)

    txn.call_after(self._get_cached_user_device.invalidate, device_key)
    txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,))
    txn.call_after(
        self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
    )

    self.db_pool.simple_upsert_txn(
        txn,
        table="device_lists_remote_extremeties",
        keyvalues={"user_id": user_id},
        values={"stream_id": stream_id},
        # again, we can assume we are the only thread updating this user's
        # extremity.
        lock=False,
    )
|
|
|
|
async def update_remote_device_list_cache(
    self, user_id: str, devices: List[dict], stream_id: int
) -> None:
    """Overwrite our whole cache of a remote user's devices, and record
    `stream_id` as the user's device list version.

    Note: assumes that we are the only thread that can be updating this user's
    device list.

    Args:
        user_id: User to update device list for
        devices: list of device objects supplied over federation; each must
            have a "device_id" key
        stream_id: the version of the device list
    """
    txn_args = (user_id, devices, stream_id)
    await self.db_pool.runInteraction(
        "update_remote_device_list_cache",
        self._update_remote_device_list_cache_txn,
        *txn_args,
    )
|
|
|
|
def _update_remote_device_list_cache_txn(
    self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
) -> None:
    """Replace all cached devices of `user_id` with `devices`, within `txn`,
    and upsert `stream_id` as the user's extremity.

    Each device dict is stored JSON-encoded, keyed by its "device_id".
    """
    # Wipe the existing cached devices before writing the new set.
    self.db_pool.simple_delete_txn(
        txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
    )

    new_rows = [
        (user_id, device["device_id"], json_encoder.encode(device))
        for device in devices
    ]
    self.db_pool.simple_insert_many_txn(
        txn,
        table="device_lists_remote_cache",
        keys=("user_id", "device_id", "content"),
        values=new_rows,
    )

    user_key = (user_id,)
    txn.call_after(self.get_cached_devices_for_user.invalidate, user_key)
    txn.call_after(self._get_cached_user_device.invalidate, user_key)
    txn.call_after(self.get_device_list_last_stream_id_for_remote.invalidate, user_key)

    self.db_pool.simple_upsert_txn(
        txn,
        table="device_lists_remote_extremeties",
        keyvalues={"user_id": user_id},
        values={"stream_id": stream_id},
        # we don't need to lock, because we can assume we are the only thread
        # updating this user's extremity.
        lock=False,
    )
|
|
|
|
async def add_device_change_to_streams(
    self,
    user_id: str,
    device_ids: Collection[str],
    room_ids: Collection[str],
) -> Optional[int]:
    """Persist that a user's devices have been updated, and which hosts
    (if any) should be poked.

    One stream ID is allocated per changed device. The changes are written
    to `device_lists_stream` and, per room, to `device_lists_changes_in_room`.

    Args:
        user_id: The ID of the user whose device changed.
        device_ids: The IDs of any changed devices. If empty, this function will
            return None.
        room_ids: The rooms that the user is in

    Returns:
        The maximum stream ID of device list updates that were added to the database, or
        None if no updates were added.
    """
    if not device_ids:
        return None

    # Capture the active tracing context now, so it can be stored with the
    # per-room rows written inside the transaction.
    tracing_context = get_active_span_text_map()

    def _record_changes_txn(
        txn: LoggingTransaction, stream_ids: List[int]
    ) -> None:
        self._add_device_change_to_stream_txn(txn, user_id, device_ids, stream_ids)
        self._add_device_outbound_room_poke_txn(
            txn, user_id, device_ids, room_ids, stream_ids, tracing_context
        )

    async with self._device_list_id_gen.get_next_mult(  # type: ignore[attr-defined]
        len(device_ids)
    ) as stream_ids:
        await self.db_pool.runInteraction(
            "add_device_change_to_stream", _record_changes_txn, stream_ids
        )

    # Last allocated ID; "maximum" assumes the generator hands IDs out in
    # ascending order.
    return stream_ids[-1]
|
|
|
|
def _add_device_change_to_stream_txn(
    self,
    txn: LoggingTransaction,
    user_id: str,
    device_ids: Collection[str],
    stream_ids: List[int],
) -> None:
    """Record changes to `user_id`'s `device_ids` in `device_lists_stream`.

    `stream_ids` is paired positionally with `device_ids` (via `zip`, so any
    surplus on either side is ignored). Older rows for the same devices are
    deleted first, so only the newest change per device remains. After the
    transaction, `_device_list_stream_cache` is told that `user_id` changed
    at the last stream ID.
    """
    txn.call_after(
        self._device_list_stream_cache.entity_has_changed,
        user_id,
        stream_ids[-1],
    )

    # Delete older entries in the table, as we really only care about
    # when the latest change happened.
    lowest_new_id = stream_ids[0]
    stale_keys = [(user_id, device_id, lowest_new_id) for device_id in device_ids]
    txn.execute_batch(
        """
        DELETE FROM device_lists_stream
        WHERE user_id = ? AND device_id = ? AND stream_id < ?
        """,
        stale_keys,
    )

    new_rows = [
        (stream_id, user_id, device_id)
        for stream_id, device_id in zip(stream_ids, device_ids)
    ]
    self.db_pool.simple_insert_many_txn(
        txn,
        table="device_lists_stream",
        keys=("stream_id", "user_id", "device_id"),
        values=new_rows,
    )
|
|
|
|
def _add_device_outbound_poke_to_stream_txn(
    self,
    txn: LoggingTransaction,
    user_id: str,
    device_ids: Iterable[str],
    hosts: Collection[str],
    stream_ids: List[int],
    context: Optional[Dict[str, str]],
) -> None:
    """Queue a row in `device_lists_outbound_pokes` for every (host, device)
    pair.

    One stream ID is taken from `stream_ids` per row, host-major, so there
    must be at least len(hosts) * len(device_ids) of them (otherwise `next`
    raises StopIteration). `device_ids` is re-iterated for each host, so it
    should be a re-iterable collection. The encoded tracing context is only
    stored for whitelisted destinations; others get "{}".
    """
    for host in hosts:
        txn.call_after(
            self._device_list_federation_stream_cache.entity_has_changed,
            host,
            stream_ids[-1],
        )

    timestamp = self._clock.time_msec()
    id_source = iter(stream_ids)
    encoded_context = json_encoder.encode(context)

    poke_rows: List[Tuple[Any, ...]] = []
    for destination in hosts:
        for device_id in device_ids:
            poke_rows.append(
                (
                    destination,
                    next(id_source),
                    user_id,
                    device_id,
                    # We only need to send out update for *our* users
                    not self.hs.is_mine_id(user_id),
                    timestamp,
                    encoded_context if whitelisted_homeserver(destination) else "{}",
                )
            )

    self.db_pool.simple_insert_many_txn(
        txn,
        table="device_lists_outbound_pokes",
        keys=(
            "destination",
            "stream_id",
            "user_id",
            "device_id",
            "sent",
            "ts",
            "opentracing_context",
        ),
        values=poke_rows,
    )
|
|
|
|
def _add_device_outbound_room_poke_txn(
    self,
    txn: LoggingTransaction,
    user_id: str,
    device_ids: Iterable[str],
    room_ids: Collection[str],
    stream_ids: List[int],
    context: Dict[str, str],
) -> None:
    """Record the user in the room has updated their device.

    Writes one `device_lists_changes_in_room` row per (room, device), where
    each device is paired with its stream ID by position.
    """
    encoded_context = json_encoder.encode(context)

    # The `device_lists_changes_in_room.stream_id` column matches the
    # corresponding `stream_id` of the update in the `device_lists_stream`
    # table, i.e. all rows persisted for the same device update will have
    # the same `stream_id` (but different room IDs).
    room_rows: List[Tuple[Any, ...]] = []
    for room_id in room_ids:
        for device_id, stream_id in zip(device_ids, stream_ids):
            room_rows.append(
                (
                    user_id,
                    device_id,
                    room_id,
                    stream_id,
                    # We only need to calculate outbound pokes for local users
                    not self.hs.is_mine_id(user_id),
                    encoded_context,
                )
            )

    self.db_pool.simple_insert_many_txn(
        txn,
        table="device_lists_changes_in_room",
        keys=(
            "user_id",
            "device_id",
            "room_id",
            "stream_id",
            "converted_to_destinations",
            "opentracing_context",
        ),
        values=room_rows,
    )
|
|
|
|
async def get_uncoverted_outbound_room_pokes(
    self, limit: int = 10
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
    """Fetch up to `limit` rows of `device_lists_changes_in_room` that have
    not yet been converted into `device_lists_outbound_pokes`, oldest stream
    ID first.

    Returns:
        A list of (user ID, device ID, room ID, stream ID, opentracing
        context) tuples; the context is decoded from its stored JSON.
    """

    sql = """
        SELECT user_id, device_id, room_id, stream_id, opentracing_context
        FROM device_lists_changes_in_room
        WHERE NOT converted_to_destinations
        ORDER BY stream_id
        LIMIT ?
    """

    def _fetch_pending_txn(
        txn: LoggingTransaction,
    ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
        txn.execute(sql, (limit,))

        pending: List[Tuple[str, str, str, int, Optional[Dict[str, str]]]] = []
        for user_id, device_id, room_id, stream_id, raw_context in txn:
            pending.append(
                (user_id, device_id, room_id, stream_id, db_to_json(raw_context))
            )
        return pending

    return await self.db_pool.runInteraction(
        "get_uncoverted_outbound_room_pokes", _fetch_pending_txn
    )
|
|
|
|
async def add_device_list_outbound_pokes(
    self,
    user_id: str,
    device_id: str,
    room_id: str,
    stream_id: int,
    hosts: Collection[str],
    context: Optional[Dict[str, str]],
) -> None:
    """Queue the device update to be sent to the given set of hosts,
    calculated from the room ID.

    Marks the associated row in `device_lists_changes_in_room` as handled.
    New stream IDs (one per host) are only allocated when `hosts` is
    non-empty.
    """

    def _queue_and_mark_txn(
        txn: LoggingTransaction, stream_ids: List[int]
    ) -> None:
        if hosts:
            self._add_device_outbound_poke_to_stream_txn(
                txn,
                user_id=user_id,
                device_ids=[device_id],
                hosts=hosts,
                stream_ids=stream_ids,
                context=context,
            )

        self.db_pool.simple_update_txn(
            txn,
            table="device_lists_changes_in_room",
            keyvalues={
                "user_id": user_id,
                "device_id": device_id,
                "stream_id": stream_id,
                "room_id": room_id,
            },
            updatevalues={"converted_to_destinations": True},
        )

    if hosts:
        async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:  # type: ignore[attr-defined]
            return await self.db_pool.runInteraction(
                "add_device_list_outbound_pokes",
                _queue_and_mark_txn,
                stream_ids,
            )

    # If there are no hosts then we don't try and generate stream IDs.
    return await self.db_pool.runInteraction(
        "add_device_list_outbound_pokes",
        _queue_and_mark_txn,
        [],
    )
|