Prevent multiple device list updates from breaking a batch send (#5156)

fixes #5153
This commit is contained in:
Andrew Morgan 2019-06-06 23:54:00 +01:00 committed by Richard van der Hoff
parent a11865016e
commit 2d1d7b7e6f
4 changed files with 198 additions and 33 deletions

1
changelog.d/5156.bugfix Normal file
View File

@ -0,0 +1 @@
Prevent federation device list updates breaking when processing multiple updates at once.

View File

@ -349,9 +349,10 @@ class PerDestinationQueue(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_new_device_messages(self, limit): def _get_new_device_messages(self, limit):
last_device_list = self._last_device_list_stream_id last_device_list = self._last_device_list_stream_id
# Will return at most 20 entries
# Retrieve list of new device updates to send to the destination
now_stream_id, results = yield self._store.get_devices_by_remote( now_stream_id, results = yield self._store.get_devices_by_remote(
self._destination, last_device_list self._destination, last_device_list, limit=limit,
) )
edus = [ edus = [
Edu( Edu(

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from six import iteritems, itervalues from six import iteritems
from canonicaljson import json from canonicaljson import json
@ -72,11 +72,14 @@ class DeviceWorkerStore(SQLBaseStore):
defer.returnValue({d["device_id"]: d for d in devices}) defer.returnValue({d["device_id"]: d for d in devices})
def get_devices_by_remote(self, destination, from_stream_id): @defer.inlineCallbacks
def get_devices_by_remote(self, destination, from_stream_id, limit):
"""Get stream of updates to send to remote servers """Get stream of updates to send to remote servers
Returns: Returns:
(int, list[dict]): current stream id and list of updates Deferred[tuple[int, list[dict]]]:
current stream id (ie, the stream id of the last update included in the
response), and the list of updates
""" """
now_stream_id = self._device_list_id_gen.get_current_token() now_stream_id = self._device_list_id_gen.get_current_token()
@ -84,55 +87,131 @@ class DeviceWorkerStore(SQLBaseStore):
destination, int(from_stream_id) destination, int(from_stream_id)
) )
if not has_changed: if not has_changed:
return (now_stream_id, []) defer.returnValue((now_stream_id, []))
return self.runInteraction( # We retrieve n+1 devices from the list of outbound pokes where n is
# our outbound device update limit. We then check if the very last
# device has the same stream_id as the second-to-last device. If so,
# then we ignore all devices with that stream_id and only send the
# devices with a lower stream_id.
#
# If when culling the list we end up with no devices afterwards, we
# consider the device update to be too large, and simply skip the
# stream_id; the rationale being that such a large device list update
# is likely an error.
updates = yield self.runInteraction(
"get_devices_by_remote", "get_devices_by_remote",
self._get_devices_by_remote_txn, self._get_devices_by_remote_txn,
destination, destination,
from_stream_id, from_stream_id,
now_stream_id, now_stream_id,
limit + 1,
) )
def _get_devices_by_remote_txn( # Return an empty list if there are no updates
self, txn, destination, from_stream_id, now_stream_id if not updates:
): defer.returnValue((now_stream_id, []))
sql = """
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
GROUP BY user_id, device_id
LIMIT 20
"""
txn.execute(sql, (destination, from_stream_id, now_stream_id, False))
# if we have exceeded the limit, we need to exclude any results with the
# same stream_id as the last row.
if len(updates) > limit:
stream_id_cutoff = updates[-1][2]
now_stream_id = stream_id_cutoff - 1
else:
stream_id_cutoff = None
# Perform the equivalent of a GROUP BY
#
# Iterate through the updates list and copy non-duplicate
# (user_id, device_id) entries into a map, with the value being
# the max stream_id across each set of duplicate entries
#
# maps (user_id, device_id) -> stream_id # maps (user_id, device_id) -> stream_id
query_map = {(r[0], r[1]): r[2] for r in txn} # as long as their stream_id does not match that of the last row
query_map = {}
for update in updates:
if stream_id_cutoff is not None and update[2] >= stream_id_cutoff:
# Stop processing updates
break
key = (update[0], update[1])
query_map[key] = max(query_map.get(key, 0), update[2])
# If we didn't find any updates with a stream_id lower than the cutoff, it
# means that there are more than limit updates all of which have the same
# steam_id.
# That should only happen if a client is spamming the server with new
# devices, in which case E2E isn't going to work well anyway. We'll just
# skip that stream_id and return an empty list, and continue with the next
# stream_id next time.
if not query_map: if not query_map:
return (now_stream_id, []) defer.returnValue((stream_id_cutoff, []))
if len(query_map) >= 20: results = yield self._get_device_update_edus_by_remote(
now_stream_id = max(stream_id for stream_id in itervalues(query_map)) destination,
from_stream_id,
query_map,
)
devices = self._get_e2e_device_keys_txn( defer.returnValue((now_stream_id, results))
txn,
def _get_devices_by_remote_txn(
self, txn, destination, from_stream_id, now_stream_id, limit
):
"""Return device update information for a given remote destination
Args:
txn (LoggingTransaction): The transaction to execute
destination (str): The host the device updates are intended for
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
now_stream_id (int): The maximum stream_id to filter updates by, inclusive
limit (int): Maximum number of device updates to return
Returns:
List: List of device updates
"""
sql = """
SELECT user_id, device_id, stream_id FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
ORDER BY stream_id
LIMIT ?
"""
txn.execute(sql, (destination, from_stream_id, now_stream_id, False, limit))
return list(txn)
@defer.inlineCallbacks
def _get_device_update_edus_by_remote(
self, destination, from_stream_id, query_map,
):
"""Returns a list of device update EDUs as well as E2EE keys
Args:
destination (str): The host the device updates are intended for
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
query_map (Dict[(str, str): int]): Dictionary mapping
user_id/device_id to update stream_id
Returns:
List[Dict]: List of objects representing an device update EDU
"""
devices = yield self.runInteraction(
"_get_e2e_device_keys_txn",
self._get_e2e_device_keys_txn,
query_map.keys(), query_map.keys(),
include_all_devices=True, include_all_devices=True,
include_deleted_devices=True, include_deleted_devices=True,
) )
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
FROM device_lists_outbound_last_success
WHERE destination = ? AND user_id = ? AND stream_id <= ?
"""
results = [] results = []
for user_id, user_devices in iteritems(devices): for user_id, user_devices in iteritems(devices):
# The prev_id for the first row is always the last row before # The prev_id for the first row is always the last row before
# `from_stream_id` # `from_stream_id`
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id)) prev_id = yield self._get_last_device_update_for_remote_user(
rows = txn.fetchall() destination, user_id, from_stream_id,
prev_id = rows[0][0] )
for device_id, device in iteritems(user_devices): for device_id, device in iteritems(user_devices):
stream_id = query_map[(user_id, device_id)] stream_id = query_map[(user_id, device_id)]
result = { result = {
@ -156,7 +235,22 @@ class DeviceWorkerStore(SQLBaseStore):
results.append(result) results.append(result)
return (now_stream_id, results) defer.returnValue(results)
def _get_last_device_update_for_remote_user(
self, destination, user_id, from_stream_id,
):
def f(txn):
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
FROM device_lists_outbound_last_success
WHERE destination = ? AND user_id = ? AND stream_id <= ?
"""
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
rows = txn.fetchall()
return rows[0][0]
return self.runInteraction("get_last_device_update_for_remote_user", f)
def mark_as_sent_devices_by_remote(self, destination, stream_id): def mark_as_sent_devices_by_remote(self, destination, stream_id):
"""Mark that updates have successfully been sent to the destination. """Mark that updates have successfully been sent to the destination.

View File

@ -71,6 +71,75 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
res["device2"], res["device2"],
) )
@defer.inlineCallbacks
def test_get_devices_by_remote(self):
device_ids = ["device_id1", "device_id2"]
# Add two device updates with a single stream_id
yield self.store.add_device_change_to_streams(
"user_id", device_ids, ["somehost"],
)
# Get all device updates ever meant for this remote
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
"somehost", -1, limit=100,
)
# Check original device_ids are contained within these updates
self._check_devices_in_updates(device_ids, device_updates)
@defer.inlineCallbacks
def test_get_devices_by_remote_limited(self):
# Test breaking the update limit in 1, 101, and 1 device_id segments
# first add one device
device_ids1 = ["device_id0"]
yield self.store.add_device_change_to_streams(
"user_id", device_ids1, ["someotherhost"],
)
# then add 101
device_ids2 = ["device_id" + str(i + 1) for i in range(101)]
yield self.store.add_device_change_to_streams(
"user_id", device_ids2, ["someotherhost"],
)
# then one more
device_ids3 = ["newdevice"]
yield self.store.add_device_change_to_streams(
"user_id", device_ids3, ["someotherhost"],
)
#
# now read them back.
#
# first we should get a single update
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
"someotherhost", -1, limit=100,
)
self._check_devices_in_updates(device_ids1, device_updates)
# Then we should get an empty list back as the 101 devices broke the limit
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
"someotherhost", now_stream_id, limit=100,
)
self.assertEqual(len(device_updates), 0)
# The 101 devices should've been cleared, so we should now just get one device
# update
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
"someotherhost", now_stream_id, limit=100,
)
self._check_devices_in_updates(device_ids3, device_updates)
def _check_devices_in_updates(self, expected_device_ids, device_updates):
"""Check that an specific device ids exist in a list of device update EDUs"""
self.assertEqual(len(device_updates), len(expected_device_ids))
received_device_ids = {update["device_id"] for update in device_updates}
self.assertEqual(received_device_ids, set(expected_device_ids))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_update_device(self): def test_update_device(self):
yield self.store.store_device("user_id", "device_id", "display_name 1") yield self.store.store_device("user_id", "device_id", "display_name 1")