Track device list updates per room. (#12321)

This is a first step in dealing with #7721.

The idea is basically that rather than calculating the full set of users a device list update needs to be sent to up front, we instead simply record the rooms the user was in at the time of the change. This will allow a few things:

1. we can defer calculating the set of remote servers that need to be poked about the change; and
2. during `/sync` and `/keys/changes` we can avoid also avoid calculating users who share rooms with other users, and instead just look at the rooms that have changed.

However, care needs to be taken to correctly handle server downgrades. As such this PR writes to both `device_lists_changes_in_room` and the `device_lists_outbound_pokes` table synchronously. In a future release we can then bump the database schema compat version to `69` and then we can assume that the new `device_lists_changes_in_room` exists and is handled.

There is a temporary option to disable writing to `device_lists_outbound_pokes` synchronously, allowing us to test the new code path does work (and by implication upgrading to a future release and downgrading to this one will work correctly).

Note: Ideally we'd do the calculation of room to servers on a worker (e.g. the background worker), but currently only master can write to the `device_list_outbound_pokes` table.
This commit is contained in:
Erik Johnston 2022-04-04 15:25:20 +01:00 committed by GitHub
parent 80839a44f1
commit 5c9e39e619
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 390 additions and 47 deletions

View file

@ -810,6 +810,7 @@ class DeviceWorkerStore(SQLBaseStore):
SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
) AS e
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
@ -1528,7 +1529,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
async def add_device_change_to_streams(
self, user_id: str, device_ids: Collection[str], hosts: Collection[str]
self,
user_id: str,
device_ids: Collection[str],
hosts: Optional[Collection[str]],
room_ids: Collection[str],
) -> Optional[int]:
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
@ -1537,7 +1542,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: The ID of the user whose device changed.
device_ids: The IDs of any changed devices. If empty, this function will
return None.
hosts: The remote destinations that should be notified of the change.
hosts: The remote destinations that should be notified of the change. If
None then the set of hosts have *not* been calculated, and will be
calculated later by a background task.
room_ids: The rooms that the user is in
Returns:
The maximum stream ID of device list updates that were added to the database, or
@ -1546,34 +1554,62 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not device_ids:
return None
async with self._device_list_id_gen.get_next_mult(
len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
"add_device_change_to_stream",
self._add_device_change_to_stream_txn,
context = get_active_span_text_map()
def add_device_changes_txn(
txn, stream_ids_for_device_change, stream_ids_for_outbound_pokes
):
self._add_device_change_to_stream_txn(
txn,
user_id,
device_ids,
stream_ids,
stream_ids_for_device_change,
)
if not hosts:
return stream_ids[-1]
self._add_device_outbound_room_poke_txn(
txn,
user_id,
device_ids,
room_ids,
stream_ids_for_device_change,
context,
hosts_have_been_calculated=hosts is not None,
)
context = get_active_span_text_map()
async with self._device_list_id_gen.get_next_mult(
len(hosts) * len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
"add_device_outbound_poke_to_stream",
self._add_device_outbound_poke_to_stream_txn,
# If the set of hosts to send to has not been calculated yet (and so
# `hosts` is None) or there are no `hosts` to send to, then skip
# trying to persist them to the DB.
if not hosts:
return
self._add_device_outbound_poke_to_stream_txn(
txn,
user_id,
device_ids,
hosts,
stream_ids,
stream_ids_for_outbound_pokes,
context,
)
# `device_lists_stream` wants a stream ID per device update.
num_stream_ids = len(device_ids)
if hosts:
# `device_lists_outbound_pokes` wants a different stream ID for
# each row, which is a row per host per device update.
num_stream_ids += len(hosts) * len(device_ids)
async with self._device_list_id_gen.get_next_mult(num_stream_ids) as stream_ids:
stream_ids_for_device_change = stream_ids[: len(device_ids)]
stream_ids_for_outbound_pokes = stream_ids[len(device_ids) :]
await self.db_pool.runInteraction(
"add_device_change_to_stream",
add_device_changes_txn,
stream_ids_for_device_change,
stream_ids_for_outbound_pokes,
)
return stream_ids[-1]
def _add_device_change_to_stream_txn(
@ -1617,7 +1653,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: str,
device_ids: Iterable[str],
hosts: Collection[str],
stream_ids: List[str],
stream_ids: List[int],
context: Dict[str, str],
) -> None:
for host in hosts:
@ -1628,8 +1664,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
now = self._clock.time_msec()
next_stream_id = iter(stream_ids)
stream_id_iterator = iter(stream_ids)
encoded_context = json_encoder.encode(context)
self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
@ -1645,16 +1682,146 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
values=[
(
destination,
next(next_stream_id),
next(stream_id_iterator),
user_id,
device_id,
False,
now,
json_encoder.encode(context)
if whitelisted_homeserver(destination)
else "{}",
encoded_context if whitelisted_homeserver(destination) else "{}",
)
for destination in hosts
for device_id in device_ids
],
)
def _add_device_outbound_room_poke_txn(
self,
txn: LoggingTransaction,
user_id: str,
device_ids: Iterable[str],
room_ids: Collection[str],
stream_ids: List[str],
context: Dict[str, str],
hosts_have_been_calculated: bool,
) -> None:
"""Record the user in the room has updated their device.
Args:
hosts_have_been_calculated: True if `device_lists_outbound_pokes`
has been updated already with the updates.
"""
# We only need to convert to outbound pokes if they are our user.
converted_to_destinations = (
hosts_have_been_calculated or not self.hs.is_mine_id(user_id)
)
encoded_context = json_encoder.encode(context)
# The `device_lists_changes_in_room.stream_id` column matches the
# corresponding `stream_id` of the update in the `device_lists_stream`
# table, i.e. all rows persisted for the same device update will have
# the same `stream_id` (but different room IDs).
self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_changes_in_room",
keys=(
"user_id",
"device_id",
"room_id",
"stream_id",
"converted_to_destinations",
"opentracing_context",
),
values=[
(
user_id,
device_id,
room_id,
stream_id,
converted_to_destinations,
encoded_context,
)
for room_id in room_ids
for device_id, stream_id in zip(device_ids, stream_ids)
],
)
async def get_uncoverted_outbound_room_pokes(
self, limit: int = 10
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
"""Get device list changes by room that have not yet been handled and
written to `device_lists_outbound_pokes`.
Returns:
A list of user ID, device ID, room ID, stream ID and optional opentracing context.
"""
sql = """
SELECT user_id, device_id, room_id, stream_id, opentracing_context
FROM device_lists_changes_in_room
WHERE NOT converted_to_destinations
ORDER BY stream_id
LIMIT ?
"""
def get_uncoverted_outbound_room_pokes_txn(txn):
txn.execute(sql, (limit,))
return txn.fetchall()
return await self.db_pool.runInteraction(
"get_uncoverted_outbound_room_pokes", get_uncoverted_outbound_room_pokes_txn
)
async def add_device_list_outbound_pokes(
self,
user_id: str,
device_id: str,
room_id: str,
stream_id: int,
hosts: Collection[str],
context: Optional[Dict[str, str]],
) -> None:
"""Queue the device update to be sent to the given set of hosts,
calculated from the room ID.
Marks the associated row in `device_lists_changes_in_room` as handled.
"""
def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]):
if hosts:
self._add_device_outbound_poke_to_stream_txn(
txn,
user_id=user_id,
device_ids=[device_id],
hosts=hosts,
stream_ids=stream_ids,
context=context,
)
self.db_pool.simple_update_txn(
txn,
table="device_lists_changes_in_room",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"stream_id": stream_id,
"room_id": room_id,
},
updatevalues={"converted_to_destinations": True},
)
if not hosts:
# If there are no hosts then we don't try and generate stream IDs.
return await self.db_pool.runInteraction(
"add_device_list_outbound_pokes",
add_device_list_outbound_pokes_txn,
[],
)
async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
return await self.db_pool.runInteraction(
"add_device_list_outbound_pokes",
add_device_list_outbound_pokes_txn,
stream_ids,
)