Fix a bug introduced in Synapse v1.50.0rc1 whereby outbound federation could fail because too many EDUs were produced for device updates. (#11730)

Co-authored-by: David Robertson <davidr@element.io>
This commit is contained in:
reivilibre 2022-01-13 18:12:18 +00:00 committed by GitHub
parent 22abfca8d9
commit b602ba194b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 190 additions and 17 deletions

1
changelog.d/11730.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a bug introduced in Synapse v1.50.0rc1 whereby outbound federation could fail because too many EDUs were produced for device updates.

View File

@ -191,7 +191,7 @@ class DeviceWorkerStore(SQLBaseStore):
@trace @trace
async def get_device_updates_by_remote( async def get_device_updates_by_remote(
self, destination: str, from_stream_id: int, limit: int self, destination: str, from_stream_id: int, limit: int
) -> Tuple[int, List[Tuple[str, dict]]]: ) -> Tuple[int, List[Tuple[str, JsonDict]]]:
"""Get a stream of device updates to send to the given remote server. """Get a stream of device updates to send to the given remote server.
Args: Args:
@ -200,9 +200,10 @@ class DeviceWorkerStore(SQLBaseStore):
limit: Maximum number of device updates to return limit: Maximum number of device updates to return
Returns: Returns:
A mapping from the current stream id (ie, the stream id of the last - The current stream id (i.e. the stream id of the last update included
update included in the response), and the list of updates, where in the response); and
each update is a pair of EDU type and EDU contents. - The list of updates, where each update is a pair of EDU type and
EDU contents.
""" """
now_stream_id = self.get_device_stream_token() now_stream_id = self.get_device_stream_token()
@ -221,6 +222,9 @@ class DeviceWorkerStore(SQLBaseStore):
limit, limit,
) )
# We need to ensure `updates` doesn't grow too big.
# Currently: `len(updates) <= limit`.
# Return an empty list if there are no updates # Return an empty list if there are no updates
if not updates: if not updates:
return now_stream_id, [] return now_stream_id, []
@ -277,16 +281,43 @@ class DeviceWorkerStore(SQLBaseStore):
query_map = {} query_map = {}
cross_signing_keys_by_user = {} cross_signing_keys_by_user = {}
for user_id, device_id, update_stream_id, update_context in updates: for user_id, device_id, update_stream_id, update_context in updates:
if ( # Calculate the remaining length budget.
# Note that, for now, each entry in `cross_signing_keys_by_user`
# gives rise to two device updates in the result, so those cost twice
# as much (and are the whole reason we need to separately calculate
# the budget; we know len(updates) <= limit otherwise!)
# N.B. len() on dicts is cheap since they store their size.
remaining_length_budget = limit - (
len(query_map) + 2 * len(cross_signing_keys_by_user)
)
assert remaining_length_budget >= 0
is_master_key_update = (
user_id in master_key_by_user user_id in master_key_by_user
and device_id == master_key_by_user[user_id]["device_id"] and device_id == master_key_by_user[user_id]["device_id"]
): )
result = cross_signing_keys_by_user.setdefault(user_id, {}) is_self_signing_key_update = (
result["master_key"] = master_key_by_user[user_id]["key_info"]
elif (
user_id in self_signing_key_by_user user_id in self_signing_key_by_user
and device_id == self_signing_key_by_user[user_id]["device_id"] and device_id == self_signing_key_by_user[user_id]["device_id"]
)
is_cross_signing_key_update = (
is_master_key_update or is_self_signing_key_update
)
if (
is_cross_signing_key_update
and user_id not in cross_signing_keys_by_user
): ):
# This will give rise to 2 device updates.
# If we don't have the budget, stop here!
if remaining_length_budget < 2:
break
if is_master_key_update:
result = cross_signing_keys_by_user.setdefault(user_id, {})
result["master_key"] = master_key_by_user[user_id]["key_info"]
elif is_self_signing_key_update:
result = cross_signing_keys_by_user.setdefault(user_id, {}) result = cross_signing_keys_by_user.setdefault(user_id, {})
result["self_signing_key"] = self_signing_key_by_user[user_id][ result["self_signing_key"] = self_signing_key_by_user[user_id][
"key_info" "key_info"
@ -294,23 +325,44 @@ class DeviceWorkerStore(SQLBaseStore):
else: else:
key = (user_id, device_id) key = (user_id, device_id)
if key not in query_map and remaining_length_budget < 1:
# We don't have space for a new entry
break
previous_update_stream_id, _ = query_map.get(key, (0, None)) previous_update_stream_id, _ = query_map.get(key, (0, None))
if update_stream_id > previous_update_stream_id: if update_stream_id > previous_update_stream_id:
# FIXME If this overwrites an older update, this discards the
# previous OpenTracing context.
# It might make it harder to track down issues using OpenTracing.
# If there's a good reason why it doesn't matter, a comment here
# about that would not hurt.
query_map[key] = (update_stream_id, update_context) query_map[key] = (update_stream_id, update_context)
# As this update has been added to the response, advance the stream
# position.
last_processed_stream_id = update_stream_id last_processed_stream_id = update_stream_id
# In the worst case scenario, each update is for a distinct user and is
# added either to the query_map or to cross_signing_keys_by_user,
# but not both:
# len(query_map) + len(cross_signing_keys_by_user) <= len(updates) here,
# so len(query_map) + len(cross_signing_keys_by_user) <= limit.
results = await self._get_device_update_edus_by_remote( results = await self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map destination, from_stream_id, query_map
) )
# add the updated cross-signing keys to the results list # len(results) <= len(query_map) here,
# so len(results) + len(cross_signing_keys_by_user) <= limit.
# Add the updated cross-signing keys to the results list
for user_id, result in cross_signing_keys_by_user.items(): for user_id, result in cross_signing_keys_by_user.items():
result["user_id"] = user_id result["user_id"] = user_id
results.append(("m.signing_key_update", result)) results.append(("m.signing_key_update", result))
# also send the unstable version # also send the unstable version
# FIXME: remove this when enough servers have upgraded # FIXME: remove this when enough servers have upgraded
# and remove the length budgeting above.
results.append(("org.matrix.signing_key_update", result)) results.append(("org.matrix.signing_key_update", result))
return last_processed_stream_id, results return last_processed_stream_id, results
@ -322,7 +374,7 @@ class DeviceWorkerStore(SQLBaseStore):
from_stream_id: int, from_stream_id: int,
now_stream_id: int, now_stream_id: int,
limit: int, limit: int,
): ) -> List[Tuple[str, str, int, Optional[str]]]:
"""Return device update information for a given remote destination """Return device update information for a given remote destination
Args: Args:
@ -333,7 +385,11 @@ class DeviceWorkerStore(SQLBaseStore):
limit: Maximum number of device updates to return limit: Maximum number of device updates to return
Returns: Returns:
List: List of device updates List: List of device update tuples:
- user_id
- device_id
- stream_id
- opentracing_context
""" """
# get the list of device updates that need to be sent # get the list of device updates that need to be sent
sql = """ sql = """
@ -357,15 +413,21 @@ class DeviceWorkerStore(SQLBaseStore):
Args: Args:
destination: The host the device updates are intended for destination: The host the device updates are intended for
from_stream_id: The minimum stream_id to filter updates by, exclusive from_stream_id: The minimum stream_id to filter updates by, exclusive
query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping query_map: Dictionary mapping (user_id, device_id) to
user_id/device_id to update stream_id and the relevant json-encoded (update stream_id, the relevant json-encoded opentracing context)
opentracing context
Returns: Returns:
List of objects representing an device update EDU List of objects representing a device update EDU.
Postconditions:
The returned list has a length not exceeding that of the query_map:
len(result) <= len(query_map)
""" """
devices = ( devices = (
await self.get_e2e_device_keys_and_signatures( await self.get_e2e_device_keys_and_signatures(
# Because these are (user_id, device_id) tuples with all
# device_ids not being None, the returned list's length will not
# exceed that of query_map.
query_map.keys(), query_map.keys(),
include_all_devices=True, include_all_devices=True,
include_deleted_devices=True, include_deleted_devices=True,

View File

@ -125,7 +125,7 @@ class DeviceStoreTestCase(HomeserverTestCase):
self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"]) self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
) )
# Get all device updates ever meant for this remote # Get device updates meant for this remote
next_stream_id, device_updates = self.get_success( next_stream_id, device_updates = self.get_success(
self.store.get_device_updates_by_remote("somehost", -1, limit=3) self.store.get_device_updates_by_remote("somehost", -1, limit=3)
) )
@ -155,6 +155,116 @@ class DeviceStoreTestCase(HomeserverTestCase):
# Check the newly-added device_ids are contained within these updates # Check the newly-added device_ids are contained within these updates
self._check_devices_in_updates(device_ids, device_updates) self._check_devices_in_updates(device_ids, device_updates)
# Check there are no more device updates left.
_, device_updates = self.get_success(
self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
)
self.assertEqual(device_updates, [])
def test_get_device_updates_by_remote_cross_signing_key_updates(
self,
) -> None:
"""
Tests that `get_device_updates_by_remote` limits the length of the return value
properly when cross-signing key updates are present.
Current behaviour is that the cross-signing key updates will always come in pairs,
even if that means leaving an earlier batch one EDU short of the limit.
"""
assert self.hs.is_mine_id(
"@user_id:test"
), "Test not valid: this MXID should be considered local"
self.get_success(
self.store.set_e2e_cross_signing_key(
"@user_id:test",
"master",
{
"keys": {
"ed25519:fakeMaster": "aaafakefakefake1AAAAAAAAAAAAAAAAAAAAAAAAAAA="
},
"signatures": {
"@user_id:test": {
"ed25519:fake2": "aaafakefakefake2AAAAAAAAAAAAAAAAAAAAAAAAAAA="
}
},
},
)
)
self.get_success(
self.store.set_e2e_cross_signing_key(
"@user_id:test",
"self_signing",
{
"keys": {
"ed25519:fakeSelfSigning": "aaafakefakefake3AAAAAAAAAAAAAAAAAAAAAAAAAAA="
},
"signatures": {
"@user_id:test": {
"ed25519:fake4": "aaafakefakefake4AAAAAAAAAAAAAAAAAAAAAAAAAAA="
}
},
},
)
)
# Add some device updates with sequential `stream_id`s
# Note that the public cross-signing keys occupy the same space as device IDs,
# so also notify that those have updated.
device_ids = [
"device_id1",
"device_id2",
"fakeMaster",
"fakeSelfSigning",
]
self.get_success(
self.store.add_device_change_to_streams(
"@user_id:test", device_ids, ["somehost"]
)
)
# Get device updates meant for this remote
next_stream_id, device_updates = self.get_success(
self.store.get_device_updates_by_remote("somehost", -1, limit=3)
)
# Here we expect the device updates for `device_id1` and `device_id2`.
# That means we only receive 2 updates this time around.
# If we had a higher limit, we would expect to see the pair of
# (unstable-prefixed & unprefixed) signing key updates for the device
# represented by `fakeMaster` and `fakeSelfSigning`.
# Our implementation only sends these two variants together, so we get
# a short batch.
self.assertEqual(len(device_updates), 2, device_updates)
# Check the first two devices (device_id1, device_id2) came out.
self._check_devices_in_updates(device_ids[:2], device_updates)
# Get more device updates meant for this remote
next_stream_id, device_updates = self.get_success(
self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
)
# The next 2 updates should be a cross-signing key update
# (the master key update and the self-signing key update are combined into
# one 'signing key update', but the cross-signing key update is emitted
# twice, once with an unprefixed type and once again with an unstable-prefixed type)
# (This is a temporary arrangement for backwards compatibility!)
self.assertEqual(len(device_updates), 2, device_updates)
self.assertEqual(
device_updates[0][0], "m.signing_key_update", device_updates[0]
)
self.assertEqual(
device_updates[1][0], "org.matrix.signing_key_update", device_updates[1]
)
# Check there are no more device updates left.
_, device_updates = self.get_success(
self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
)
self.assertEqual(device_updates, [])
def _check_devices_in_updates(self, expected_device_ids, device_updates): def _check_devices_in_updates(self, expected_device_ids, device_updates):
"""Check that an specific device ids exist in a list of device update EDUs""" """Check that an specific device ids exist in a list of device update EDUs"""
self.assertEqual(len(device_updates), len(expected_device_ids)) self.assertEqual(len(device_updates), len(expected_device_ids))