Optimize query for fetching to-device messages in /sync (#16805)

The current query supports passing in a list of users, which generates a
query using `user_id = ANY(..)`. This is generates a less efficient
query plan that is notably slower than a simple `user_id = ?` condition.

Note: The new function is mostly a copy and paste and then a
simplification of the existing function.
This commit is contained in:
Erik Johnston 2024-01-11 13:37:57 +00:00 committed by GitHub
parent b11f7b5122
commit c43f751013
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 73 additions and 77 deletions

1
changelog.d/16805.misc Normal file
View File

@ -0,0 +1 @@
Optimize query for fetching to-device messages in `/sync`.

View File

@ -245,33 +245,74 @@ class DeviceInboxWorkerStore(SQLBaseStore):
* The last-processed stream ID. Subsequent calls of this function with the * The last-processed stream ID. Subsequent calls of this function with the
same device should pass this value as 'from_stream_id'. same device should pass this value as 'from_stream_id'.
""" """
( if not self._device_inbox_stream_cache.has_entity_changed(
user_id_device_id_to_messages, user_id, from_stream_id
last_processed_stream_id, ):
) = await self._get_device_messages(
user_ids=[user_id],
device_id=device_id,
from_stream_id=from_stream_id,
to_stream_id=to_stream_id,
limit=limit,
)
if not user_id_device_id_to_messages:
# There were no messages! # There were no messages!
return [], to_stream_id return [], to_stream_id
# Extract the messages, no need to return the user and device ID again def get_device_messages_txn(
to_device_messages = user_id_device_id_to_messages.get((user_id, device_id), []) txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]:
sql = """
SELECT stream_id, message_json FROM device_inbox
WHERE user_id = ? AND device_id = ?
AND ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (user_id, device_id, from_stream_id, to_stream_id, limit))
return to_device_messages, last_processed_stream_id # Create and fill a dictionary of (user ID, device ID) -> list of messages
# intended for each device.
last_processed_stream_pos = to_stream_id
to_device_messages: List[JsonDict] = []
rowcount = 0
for row in txn:
rowcount += 1
last_processed_stream_pos = row[0]
message_dict = db_to_json(row[1])
# Store the device details
to_device_messages.append(message_dict)
# start a new span for each message, so that we can tag each separately
with start_active_span("get_to_device_message"):
set_tag(SynapseTags.TO_DEVICE_TYPE, message_dict["type"])
set_tag(SynapseTags.TO_DEVICE_SENDER, message_dict["sender"])
set_tag(SynapseTags.TO_DEVICE_RECIPIENT, user_id)
set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, device_id)
set_tag(
SynapseTags.TO_DEVICE_MSGID,
message_dict["content"].get(EventContentFields.TO_DEVICE_MSGID),
)
if rowcount == limit:
# We ended up bumping up against the message limit. There may be more messages
# to retrieve. Return what we have, as well as the last stream position that
# was processed.
#
# The caller is expected to set this as the lower (exclusive) bound
# for the next query of this device.
return to_device_messages, last_processed_stream_pos
# The limit was not reached, thus we know that recipient_device_to_messages
# contains all to-device messages for the given device and stream id range.
#
# We return to_stream_id, which the caller should then provide as the lower
# (exclusive) bound on the next query of this device.
return to_device_messages, to_stream_id
return await self.db_pool.runInteraction(
"get_messages_for_device", get_device_messages_txn
)
async def _get_device_messages( async def _get_device_messages(
self, self,
user_ids: Collection[str], user_ids: Collection[str],
from_stream_id: int, from_stream_id: int,
to_stream_id: int, to_stream_id: int,
device_id: Optional[str] = None,
limit: Optional[int] = None,
) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]: ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
""" """
Retrieve pending to-device messages for a collection of user devices. Retrieve pending to-device messages for a collection of user devices.
@ -291,11 +332,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
user_ids: The user IDs to filter device messages by. user_ids: The user IDs to filter device messages by.
from_stream_id: The lower boundary of stream id to filter with (exclusive). from_stream_id: The lower boundary of stream id to filter with (exclusive).
to_stream_id: The upper boundary of stream id to filter with (inclusive). to_stream_id: The upper boundary of stream id to filter with (inclusive).
device_id: A device ID to query to-device messages for. If not provided, to-device
messages from all device IDs for the given user IDs will be queried. May not be
provided if `user_ids` contains more than one entry.
limit: The maximum number of to-device messages to return. Can only be used when
passing a single user ID / device ID tuple.
Returns: Returns:
A tuple containing: A tuple containing:
@ -308,30 +345,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
logger.warning("No users provided upon querying for device IDs") logger.warning("No users provided upon querying for device IDs")
return {}, to_stream_id return {}, to_stream_id
# Prevent a query for one user's device also retrieving another user's device with
# the same device ID (device IDs are not unique across users).
if len(user_ids) > 1 and device_id is not None:
raise AssertionError(
"Programming error: 'device_id' cannot be supplied to "
"_get_device_messages when >1 user_id has been provided"
)
# A limit can only be applied when querying for a single user ID / device ID tuple.
# See the docstring of this function for more details.
if limit is not None and device_id is None:
raise AssertionError(
"Programming error: _get_device_messages was passed 'limit' "
"without a specific user_id/device_id"
)
user_ids_to_query: Set[str] = set() user_ids_to_query: Set[str] = set()
device_ids_to_query: Set[str] = set()
# Note that a device ID could be an empty str
if device_id is not None:
# If a device ID was passed, use it to filter results.
# Otherwise, device IDs will be derived from the given collection of user IDs.
device_ids_to_query.add(device_id)
# Determine which users have devices with pending messages # Determine which users have devices with pending messages
for user_id in user_ids: for user_id in user_ids:
@ -355,20 +369,20 @@ class DeviceInboxWorkerStore(SQLBaseStore):
# hidden devices should not receive to-device messages. # hidden devices should not receive to-device messages.
# Note that this is more efficient than just dropping `device_id` from the query, # Note that this is more efficient than just dropping `device_id` from the query,
# since device_inbox has an index on `(user_id, device_id, stream_id)` # since device_inbox has an index on `(user_id, device_id, stream_id)`
if not device_ids_to_query:
user_device_dicts = cast(
List[Tuple[str]],
self.db_pool.simple_select_many_txn(
txn,
table="devices",
column="user_id",
iterable=user_ids_to_query,
keyvalues={"hidden": False},
retcols=("device_id",),
),
)
device_ids_to_query.update({row[0] for row in user_device_dicts}) user_device_dicts = cast(
List[Tuple[str]],
self.db_pool.simple_select_many_txn(
txn,
table="devices",
column="user_id",
iterable=user_ids_to_query,
keyvalues={"hidden": False},
retcols=("device_id",),
),
)
device_ids_to_query = {row[0] for row in user_device_dicts}
if not device_ids_to_query: if not device_ids_to_query:
# We've ended up with no devices to query. # We've ended up with no devices to query.
@ -400,22 +414,15 @@ class DeviceInboxWorkerStore(SQLBaseStore):
to_stream_id, to_stream_id,
) )
# If a limit was provided, limit the data retrieved from the database
if limit is not None:
sql += "LIMIT ?"
sql_args += (limit,)
txn.execute(sql, sql_args) txn.execute(sql, sql_args)
# Create and fill a dictionary of (user ID, device ID) -> list of messages # Create and fill a dictionary of (user ID, device ID) -> list of messages
# intended for each device. # intended for each device.
last_processed_stream_pos = to_stream_id
recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {} recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {}
rowcount = 0 rowcount = 0
for row in txn: for row in txn:
rowcount += 1 rowcount += 1
last_processed_stream_pos = row[0]
recipient_user_id = row[1] recipient_user_id = row[1]
recipient_device_id = row[2] recipient_device_id = row[2]
message_dict = db_to_json(row[3]) message_dict = db_to_json(row[3])
@ -436,18 +443,6 @@ class DeviceInboxWorkerStore(SQLBaseStore):
message_dict["content"].get(EventContentFields.TO_DEVICE_MSGID), message_dict["content"].get(EventContentFields.TO_DEVICE_MSGID),
) )
if limit is not None and rowcount == limit:
# We ended up bumping up against the message limit. There may be more messages
# to retrieve. Return what we have, as well as the last stream position that
# was processed.
#
# The caller is expected to set this as the lower (exclusive) bound
# for the next query of this device.
return recipient_device_to_messages, last_processed_stream_pos
# The limit was not reached, thus we know that recipient_device_to_messages
# contains all to-device messages for the given device and stream id range.
#
# We return to_stream_id, which the caller should then provide as the lower # We return to_stream_id, which the caller should then provide as the lower
# (exclusive) bound on the next query of this device. # (exclusive) bound on the next query of this device.
return recipient_device_to_messages, to_stream_id return recipient_device_to_messages, to_stream_id