Change device list streams to have one row per ID (#7010)

* Add 'device_lists_outbound_pokes' as extra table.

This makes sure we check all the relevant tables to get the current max
stream ID.

Not doing so currently isn't a problem, as the max stream ID in
`device_lists_outbound_pokes` is the same as in `device_lists_stream`;
however, that is about to change.
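As a rough illustration (a hypothetical helper, not Synapse's actual
`SlavedIdTracker`), a tracker whose ID sequence is shared across several
tables has to take the max over all of them, otherwise an ID written only
to `device_lists_outbound_pokes` would be invisible:

    import sqlite3

    def current_max_stream_id(conn, tables):
        """Return the largest stream_id across all tables sharing one sequence."""
        max_id = 0
        for table in tables:
            # MAX() returns NULL (None) on an empty table.
            row = conn.execute("SELECT MAX(stream_id) FROM %s" % (table,)).fetchone()
            if row[0] is not None:
                max_id = max(max_id, row[0])
        return max_id

    conn = sqlite3.connect(":memory:")
    for t in ("device_lists_stream", "device_lists_outbound_pokes"):
        conn.execute("CREATE TABLE %s (stream_id INTEGER)" % (t,))
    conn.execute("INSERT INTO device_lists_stream VALUES (3)")
    conn.execute("INSERT INTO device_lists_outbound_pokes VALUES (5)")

    # Consulting only device_lists_stream would report 3 and miss the poke at 5.
    assert current_max_stream_id(
        conn, ["device_lists_stream", "device_lists_outbound_pokes"]
    ) == 5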

* Change device lists stream to have one row per ID.

This will make it possible to process the streams more incrementally,
avoiding having to process large chunks at once.
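A hedged sketch of what this buys us (illustrative consumer code, not from
this PR): when every row carries its own stream ID, a consumer can stop
after any row and use that row's ID as a resume token, instead of having to
swallow a whole batch that shares a single ID:

    # Hypothetical rows: one (stream_id, device_id) pair per changed device.
    rows = [(101, "dev1"), (102, "dev2"), (103, "dev3"), (104, "dev4")]

    def process_incrementally(rows, batch_size=2):
        """Process up to batch_size rows, returning the token to resume from."""
        upto = rows[:batch_size]
        for stream_id, device_id in upto:
            pass  # handle one device change
        # Safe resume point: every row has its own stream ID, so nothing
        # else can be hiding behind the same token.
        return upto[-1][0]

    token = process_incrementally(rows)
    assert token == 102  # the next call resumes strictly after 102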

* Change device list replication to match new semantics.

Instead of sending down batches of user ID/host tuples, send down a row
per entity (user ID or host).
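A minimal sketch of how a consumer tells the two cases apart (Matrix user
IDs always begin with `@`, so anything else must be a server name); the
same check appears in the diffs below:

    def handle_device_list_row(entity: str) -> str:
        # Matrix user IDs look like "@alice:example.com"; anything that
        # does not start with "@" is a remote server that needs poking.
        if entity.startswith("@"):
            return "local user whose devices changed"
        return "remote host to send device updates to"

    assert handle_device_list_row("@alice:example.com").startswith("local")
    assert handle_device_list_row("matrix.org").startswith("remote")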

* Newsfile

* Remove handling of multiple rows per ID

* Fix worker handling

* Comments from review
commit a319cb1dd1
Erik Johnston, 2020-03-19 11:36:53 +00:00, committed by GitHub
7 changed files with 112 additions and 130 deletions

changelog.d/7010.misc (new file)

@@ -0,0 +1 @@
+Change device list streams to have one row per ID.

synapse/app/generic_worker.py

@@ -676,7 +676,8 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler):
             elif stream_name == "device_lists":
                 all_room_ids = set()
                 for row in rows:
-                    room_ids = await self.store.get_rooms_for_user(row.user_id)
-                    all_room_ids.update(room_ids)
+                    if row.entity.startswith("@"):
+                        room_ids = await self.store.get_rooms_for_user(row.entity)
+                        all_room_ids.update(room_ids)
                 self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids)
             elif stream_name == "presence":

@@ -774,7 +775,10 @@ class FederationSenderHandler(object):
         # ... as well as device updates and messages
         elif stream_name == DeviceListsStream.NAME:
-            hosts = {row.destination for row in rows}
+            # The entities are either user IDs (starting with '@') whose devices
+            # have changed, or remote servers that we need to tell about
+            # changes.
+            hosts = {row.entity for row in rows if not row.entity.startswith("@")}
             for host in hosts:
                 self.federation_sender.send_device_messages(host)

synapse/replication/slave/storage/devices.py

@@ -29,7 +29,13 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
         self.hs = hs

         self._device_list_id_gen = SlavedIdTracker(
-            db_conn, "device_lists_stream", "stream_id"
+            db_conn,
+            "device_lists_stream",
+            "stream_id",
+            extra_tables=[
+                ("user_signature_stream", "stream_id"),
+                ("device_lists_outbound_pokes", "stream_id"),
+            ],
         )
         device_list_max = self._device_list_id_gen.get_current_token()
         self._device_list_stream_cache = StreamChangeCache(

@@ -55,23 +61,27 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == DeviceListsStream.NAME:
             self._device_list_id_gen.advance(token)
-            for row in rows:
-                self._invalidate_caches_for_devices(token, row.user_id, row.destination)
+            self._invalidate_caches_for_devices(token, rows)
         elif stream_name == UserSignatureStream.NAME:
+            self._device_list_id_gen.advance(token)
             for row in rows:
                 self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
         return super(SlavedDeviceStore, self).process_replication_rows(
             stream_name, token, rows
         )

-    def _invalidate_caches_for_devices(self, token, user_id, destination):
-        self._device_list_stream_cache.entity_has_changed(user_id, token)
-
-        if destination:
-            self._device_list_federation_stream_cache.entity_has_changed(
-                destination, token
-            )
-
-        self.get_cached_devices_for_user.invalidate((user_id,))
-        self._get_cached_user_device.invalidate_many((user_id,))
-        self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
+    def _invalidate_caches_for_devices(self, token, rows):
+        for row in rows:
+            # The entities are either user IDs (starting with '@') whose devices
+            # have changed, or remote servers that we need to tell about
+            # changes.
+            if row.entity.startswith("@"):
+                self._device_list_stream_cache.entity_has_changed(row.entity, token)
+                self.get_cached_devices_for_user.invalidate((row.entity,))
+                self._get_cached_user_device.invalidate_many((row.entity,))
+                self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
+            else:
+                self._device_list_federation_stream_cache.entity_has_changed(
+                    row.entity, token
+                )

synapse/replication/tcp/streams/_base.py

@@ -94,9 +94,13 @@ PublicRoomsStreamRow = namedtuple(
         "network_id",  # str, optional
     ),
 )
-DeviceListsStreamRow = namedtuple(
-    "DeviceListsStreamRow", ("user_id", "destination")  # str  # str
-)
+
+
+@attr.s
+class DeviceListsStreamRow:
+    entity = attr.ib(type=str)
+
+
 ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",))  # str
 TagAccountDataStreamRow = namedtuple(
     "TagAccountDataStreamRow", ("user_id", "room_id", "data")  # str  # str  # dict

@@ -363,7 +367,8 @@ class PublicRoomsStream(Stream):
 class DeviceListsStream(Stream):
-    """Someone added/changed/removed a device
+    """Either a user has updated their devices or a remote server needs to be
+    told about a device update.
     """

     NAME = "device_lists"
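As a quick aside, the attrs-based row accepts the entity positionally, much
as a one-field namedtuple would; a standalone check (assumes the `attrs`
package, which Synapse already depends on):

    import attr

    @attr.s
    class DeviceListsStreamRow:
        entity = attr.ib(type=str)

    # attrs generates an __init__ that takes the fields positionally.
    row = DeviceListsStreamRow("@alice:example.com")
    assert row.entity == "@alice:example.com"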

synapse/storage/data_stores/main/__init__.py

@@ -144,7 +144,10 @@ class DataStore(
             db_conn,
             "device_lists_stream",
             "stream_id",
-            extra_tables=[("user_signature_stream", "stream_id")],
+            extra_tables=[
+                ("user_signature_stream", "stream_id"),
+                ("device_lists_outbound_pokes", "stream_id"),
+            ],
         )
         self._cross_signing_id_gen = StreamIdGenerator(
             db_conn, "e2e_cross_signing_keys", "stream_id"

synapse/storage/data_stores/main/devices.py

@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import List, Tuple

 from six import iteritems

@@ -31,7 +32,7 @@ from synapse.logging.opentracing import (
 )
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import Database
+from synapse.storage.database import Database, LoggingTransaction
 from synapse.types import Collection, get_verify_key_from_cross_signing_key
 from synapse.util.caches.descriptors import (
     Cache,

@@ -112,23 +113,13 @@ class DeviceWorkerStore(SQLBaseStore):
         if not has_changed:
             return now_stream_id, []

-        # We retrieve n+1 devices from the list of outbound pokes where n is
-        # our outbound device update limit. We then check if the very last
-        # device has the same stream_id as the second-to-last device. If so,
-        # then we ignore all devices with that stream_id and only send the
-        # devices with a lower stream_id.
-        #
-        # If when culling the list we end up with no devices afterwards, we
-        # consider the device update to be too large, and simply skip the
-        # stream_id; the rationale being that such a large device list update
-        # is likely an error.
         updates = yield self.db.runInteraction(
             "get_device_updates_by_remote",
             self._get_device_updates_by_remote_txn,
             destination,
             from_stream_id,
             now_stream_id,
-            limit + 1,
+            limit,
         )

         # Return an empty list if there are no updates

@@ -166,14 +157,6 @@ class DeviceWorkerStore(SQLBaseStore):
                 "device_id": verify_key.version,
             }

-        # if we have exceeded the limit, we need to exclude any results with the
-        # same stream_id as the last row.
-        if len(updates) > limit:
-            stream_id_cutoff = updates[-1][2]
-            now_stream_id = stream_id_cutoff - 1
-        else:
-            stream_id_cutoff = None
-
         # Perform the equivalent of a GROUP BY
         #
         # Iterate through the updates list and copy non-duplicate

@@ -192,10 +175,6 @@ class DeviceWorkerStore(SQLBaseStore):
         query_map = {}
         cross_signing_keys_by_user = {}
         for user_id, device_id, update_stream_id, update_context in updates:
-            if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
-                # Stop processing updates
-                break
-
             if (
                 user_id in master_key_by_user
                 and device_id == master_key_by_user[user_id]["device_id"]

@@ -218,17 +197,6 @@ class DeviceWorkerStore(SQLBaseStore):
             if update_stream_id > previous_update_stream_id:
                 query_map[key] = (update_stream_id, update_context)

-        # If we didn't find any updates with a stream_id lower than the cutoff, it
-        # means that there are more than limit updates all of which have the same
-        # steam_id.
-        # That should only happen if a client is spamming the server with new
-        # devices, in which case E2E isn't going to work well anyway. We'll just
-        # skip that stream_id and return an empty list, and continue with the next
-        # stream_id next time.
-        if not query_map and not cross_signing_keys_by_user:
-            return stream_id_cutoff, []
-
         results = yield self._get_device_update_edus_by_remote(
             destination, from_stream_id, query_map
         )

@@ -607,21 +575,26 @@ class DeviceWorkerStore(SQLBaseStore):
         else:
             return set()

-    def get_all_device_list_changes_for_remotes(self, from_key, to_key):
-        """Return a list of `(stream_id, user_id, destination)` which is the
-        combined list of changes to devices, and which destinations need to be
-        poked. `destination` may be None if no destinations need to be poked.
+    async def get_all_device_list_changes_for_remotes(
+        self, from_key: int, to_key: int
+    ) -> List[Tuple[int, str]]:
+        """Return a list of `(stream_id, entity)` which is the combined list of
+        changes to devices and which destinations need to be poked. Entity is
+        either a user ID (starting with '@') or a remote destination.
         """
-        # We do a group by here as there can be a large number of duplicate
-        # entries, since we throw away device IDs.
+
+        # This query Does The Right Thing where it'll correctly apply the
+        # bounds to the inner queries.
         sql = """
-            SELECT MAX(stream_id) AS stream_id, user_id, destination
-            FROM device_lists_stream
-            LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
+            SELECT stream_id, entity FROM (
+                SELECT stream_id, user_id AS entity FROM device_lists_stream
+                UNION ALL
+                SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
+            ) AS e
             WHERE ? < stream_id AND stream_id <= ?
-            GROUP BY user_id, destination
         """
-        return self.db.execute(
+
+        return await self.db.execute(
             "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
         )
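The comment in that hunk is about the outer `WHERE` clause being applied to
both arms of the `UNION ALL`. A self-contained way to convince yourself of
the semantics (plain `sqlite3`, schemas cut down to the relevant columns;
illustrative data only):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE device_lists_stream (stream_id INTEGER, user_id TEXT)")
    conn.execute(
        "CREATE TABLE device_lists_outbound_pokes (stream_id INTEGER, destination TEXT)"
    )
    conn.execute(
        "INSERT INTO device_lists_stream VALUES (1, '@old:hs'), (3, '@alice:hs')"
    )
    conn.execute("INSERT INTO device_lists_outbound_pokes VALUES (4, 'remote.example')")

    sql = """
        SELECT stream_id, entity FROM (
            SELECT stream_id, user_id AS entity FROM device_lists_stream
            UNION ALL
            SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
        ) AS e
        WHERE ? < stream_id AND stream_id <= ?
    """

    # The row at stream_id 1 is excluded by the lower bound; both tables
    # contribute rows within (2, 4].
    rows = conn.execute(sql, (2, 4)).fetchall()
    assert sorted(rows) == [(3, "@alice:hs"), (4, "remote.example")]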
@@ -1017,30 +990,50 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         """Persist that a user's devices have been updated, and which hosts
         (if any) should be poked.
         """
-        with self._device_list_id_gen.get_next() as stream_id:
+        if not device_ids:
+            return
+
+        with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
             yield self.db.runInteraction(
-                "add_device_change_to_streams",
-                self._add_device_change_txn,
+                "add_device_change_to_stream",
+                self._add_device_change_to_stream_txn,
+                user_id,
+                device_ids,
+                stream_ids,
+            )
+
+        if not hosts:
+            return stream_ids[-1]
+
+        context = get_active_span_text_map()
+        with self._device_list_id_gen.get_next_mult(
+            len(hosts) * len(device_ids)
+        ) as stream_ids:
+            yield self.db.runInteraction(
+                "add_device_outbound_poke_to_stream",
+                self._add_device_outbound_poke_to_stream_txn,
                 user_id,
                 device_ids,
                 hosts,
-                stream_id,
+                stream_ids,
+                context,
             )
-        return stream_id

-    def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
-        now = self._clock.time_msec()
+        return stream_ids[-1]

+    def _add_device_change_to_stream_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_ids: Collection[str],
+        stream_ids: List[str],
+    ):
         txn.call_after(
-            self._device_list_stream_cache.entity_has_changed, user_id, stream_id
+            self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1],
         )
-        for host in hosts:
-            txn.call_after(
-                self._device_list_federation_stream_cache.entity_has_changed,
-                host,
-                stream_id,
-            )
+
+        min_stream_id = stream_ids[0]

         # Delete older entries in the table, as we really only care about
         # when the latest change happened.

@@ -1048,7 +1041,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 DELETE FROM device_lists_stream
                 WHERE user_id = ? AND device_id = ? AND stream_id < ?
             """,
-            [(user_id, device_id, stream_id) for device_id in device_ids],
+            [(user_id, device_id, min_stream_id) for device_id in device_ids],
         )

         self.db.simple_insert_many_txn(

@@ -1056,11 +1049,22 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             table="device_lists_stream",
             values=[
                 {"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
-                for device_id in device_ids
+                for stream_id, device_id in zip(stream_ids, device_ids)
             ],
         )

-        context = get_active_span_text_map()
+    def _add_device_outbound_poke_to_stream_txn(
+        self, txn, user_id, device_ids, hosts, stream_ids, context,
+    ):
+        for host in hosts:
+            txn.call_after(
+                self._device_list_federation_stream_cache.entity_has_changed,
+                host,
+                stream_ids[-1],
+            )
+
+        now = self._clock.time_msec()
+        next_stream_id = iter(stream_ids)

         self.db.simple_insert_many_txn(
             txn,

@@ -1068,7 +1072,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             values=[
                 {
                     "destination": destination,
-                    "stream_id": stream_id,
+                    "stream_id": next(next_stream_id),
                     "user_id": user_id,
                     "device_id": device_id,
                     "sent": False,

tests/storage/test_devices.py

@@ -88,51 +88,6 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         # Check original device_ids are contained within these updates
         self._check_devices_in_updates(device_ids, device_updates)

-    @defer.inlineCallbacks
-    def test_get_device_updates_by_remote_limited(self):
-        # Test breaking the update limit in 1, 101, and 1 device_id segments
-
-        # first add one device
-        device_ids1 = ["device_id0"]
-        yield self.store.add_device_change_to_streams(
-            "user_id", device_ids1, ["someotherhost"]
-        )
-
-        # then add 101
-        device_ids2 = ["device_id" + str(i + 1) for i in range(101)]
-        yield self.store.add_device_change_to_streams(
-            "user_id", device_ids2, ["someotherhost"]
-        )
-
-        # then one more
-        device_ids3 = ["newdevice"]
-        yield self.store.add_device_change_to_streams(
-            "user_id", device_ids3, ["someotherhost"]
-        )
-
-        #
-        # now read them back.
-        #
-
-        # first we should get a single update
-        now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
-            "someotherhost", -1, limit=100
-        )
-        self._check_devices_in_updates(device_ids1, device_updates)
-
-        # Then we should get an empty list back as the 101 devices broke the limit
-        now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
-            "someotherhost", now_stream_id, limit=100
-        )
-        self.assertEqual(len(device_updates), 0)
-
-        # The 101 devices should've been cleared, so we should now just get one device
-        # update
-        now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
-            "someotherhost", now_stream_id, limit=100
-        )
-        self._check_devices_in_updates(device_ids3, device_updates)
-
     def _check_devices_in_updates(self, expected_device_ids, device_updates):
         """Check that an specific device ids exist in a list of device update EDUs"""
         self.assertEqual(len(device_updates), len(expected_device_ids))