Refactor _get_e2e_device_keys_txn to split large queries (#13956)

Instead of running a single large query, run a single query for
user-only lookups and additional queries for batches of user device
lookups.

Resolves #13580.
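
In outline, the fix partitions the lookups and batches only the per-device part. Below is a minimal self-contained sketch of that idea (illustrative helper names, not Synapse's own code; Synapse itself uses its batch_iter utility with a batch size of 1024, as in the diff further down):

from itertools import islice
from typing import Iterable, Iterator, List, Optional, Tuple

def batched(items: Iterable, size: int) -> Iterator[list]:
    # Stand-in for Synapse's batch_iter helper: yield lists of at most `size` items.
    it = iter(items)
    while chunk := list(islice(it, size)):
        yield chunk

def split_query_list(
    query_list: Iterable[Tuple[str, Optional[str]]], batch_size: int = 1024
) -> Tuple[List[str], List[List[Tuple[str, str]]]]:
    # Whole-user lookups (device_id is None) go into one bucket, specific
    # (user_id, device_id) lookups into another.
    user_list: List[str] = []
    user_device_list: List[Tuple[str, str]] = []
    for user_id, device_id in query_list:
        if device_id is None:
            user_list.append(user_id)
        else:
            user_device_list.append((user_id, device_id))
    # One query covers every whole-user lookup; the device lookups are chopped
    # into fixed-size batches, one query per batch.
    return user_list, list(batched(user_device_list, batch_size))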

Signed-off-by: Sean Quah <seanq@matrix.org>
Sean Quah 2022-10-03 13:46:36 +01:00 committed by GitHub
parent 061739d10f
commit d65862c41f
3 changed files with 114 additions and 28 deletions

changelog.d/13956.bugfix (new file)

@@ -0,0 +1 @@
+Fix a long-standing bug where `POST /_matrix/client/v3/keys/query` requests could result in excessively large SQL queries.

synapse/storage/database.py

@@ -2461,6 +2461,66 @@ def make_in_list_sql_clause(
     return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
 
 
+# These overloads ensure that `columns` and `iterable` values have the same length.
+# Suppress "Single overload definition, multiple required" complaint.
+@overload  # type: ignore[misc]
+def make_tuple_in_list_sql_clause(
+    database_engine: BaseDatabaseEngine,
+    columns: Tuple[str, str],
+    iterable: Collection[Tuple[Any, Any]],
+) -> Tuple[str, list]:
+    ...
+
+
+def make_tuple_in_list_sql_clause(
+    database_engine: BaseDatabaseEngine,
+    columns: Tuple[str, ...],
+    iterable: Collection[Tuple[Any, ...]],
+) -> Tuple[str, list]:
+    """Returns an SQL clause that checks the given tuple of columns is in the iterable.
+
+    Args:
+        database_engine
+        columns: Names of the columns in the tuple.
+        iterable: The tuples to check the columns against.
+
+    Returns:
+        A tuple of SQL query and the args
+    """
+    if len(columns) == 0:
+        # Should be unreachable due to mypy, as long as the overloads are set up right.
+        if () in iterable:
+            return "TRUE", []
+        else:
+            return "FALSE", []
+
+    if len(columns) == 1:
+        # Use `= ANY(?)` on postgres.
+        return make_in_list_sql_clause(
+            database_engine, next(iter(columns)), [values[0] for values in iterable]
+        )
+
+    # There are multiple columns. Avoid using an `= ANY(?)` clause on postgres, as
+    # indices are not used when there are multiple columns. Instead, use an `IN`
+    # expression.
+    #
+    # `IN ((?, ...), ...)` with tuples is supported by postgres only, whereas
+    # `IN (VALUES (?, ...), ...)` is supported by both sqlite and postgres.
+    # Thus, the latter is chosen.
+
+    if len(iterable) == 0:
+        # A 0-length `VALUES` list is not allowed in sqlite or postgres.
+        # Also note that a 0-length `IN (...)` clause (not using `VALUES`) is not
+        # allowed in postgres.
+        return "FALSE", []
+
+    tuple_sql = "(%s)" % (",".join("?" for _ in columns),)
+    return "(%s) IN (VALUES %s)" % (
+        ",".join(column for column in columns),
+        ",".join(tuple_sql for _ in iterable),
+    ), [value for values in iterable for value in values]
+
+
 KV = TypeVar("KV")
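
As a quick illustration of what the new helper produces on the multi-column path (made-up values; database_engine is only consulted when there is a single column, so any BaseDatabaseEngine instance will do here):

clause, args = make_tuple_in_list_sql_clause(
    database_engine,  # any BaseDatabaseEngine; ignored for two or more columns
    ("user_id", "device_id"),
    [("@alice:example.com", "DEV1"), ("@bob:example.com", "DEV2")],
)
# clause == "(user_id,device_id) IN (VALUES (?,?),(?,?))"
# args   == ["@alice:example.com", "DEV1", "@bob:example.com", "DEV2"]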

synapse/storage/databases/main/end_to_end_keys.py

@@ -43,6 +43,7 @@ from synapse.storage.database import (
     LoggingDatabaseConnection,
     LoggingTransaction,
     make_in_list_sql_clause,
+    make_tuple_in_list_sql_clause,
 )
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
@@ -278,7 +279,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
     def _get_e2e_device_keys_txn(
         self,
         txn: LoggingTransaction,
-        query_list: Collection[Tuple[str, str]],
+        query_list: Collection[Tuple[str, Optional[str]]],
         include_all_devices: bool = False,
         include_deleted_devices: bool = False,
     ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
@@ -288,8 +289,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         cross-signing signatures which have been added subsequently (for which, see
         get_e2e_device_keys_and_signatures)
         """
-        query_clauses = []
-        query_params = []
+        query_clauses: List[str] = []
+        query_params_list: List[List[object]] = []
 
         if include_all_devices is False:
             include_deleted_devices = False
@ -297,16 +298,38 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
if include_deleted_devices: if include_deleted_devices:
deleted_devices = set(query_list) deleted_devices = set(query_list)
# Split the query list into queries for users and queries for particular
# devices.
user_list = []
user_device_list = []
for (user_id, device_id) in query_list: for (user_id, device_id) in query_list:
query_clause = "user_id = ?" if device_id is None:
query_params.append(user_id) user_list.append(user_id)
else:
user_device_list.append((user_id, device_id))
if device_id is not None: if user_list:
query_clause += " AND device_id = ?" user_id_in_list_clause, user_args = make_in_list_sql_clause(
query_params.append(device_id) txn.database_engine, "user_id", user_list
)
query_clauses.append(user_id_in_list_clause)
query_params_list.append(user_args)
query_clauses.append(query_clause) if user_device_list:
# Divide the device queries into batches, to avoid excessively large
# queries.
for user_device_batch in batch_iter(user_device_list, 1024):
(
user_device_id_in_list_clause,
user_device_args,
) = make_tuple_in_list_sql_clause(
txn.database_engine, ("user_id", "device_id"), user_device_batch
)
query_clauses.append(user_device_id_in_list_clause)
query_params_list.append(user_device_args)
result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {}
for query_clause, query_params in zip(query_clauses, query_params_list):
sql = ( sql = (
"SELECT user_id, device_id, " "SELECT user_id, device_id, "
" d.display_name, " " d.display_name, "
@@ -316,13 +339,13 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                 " WHERE %s AND NOT d.hidden"
             ) % (
                 "LEFT" if include_all_devices else "INNER",
-            " OR ".join("(" + q + ")" for q in query_clauses),
+                query_clause,
             )
 
             txn.execute(sql, query_params)
 
-        result: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]] = {}
             for (user_id, device_id, display_name, key_json) in txn:
+                assert device_id is not None
                 if include_deleted_devices:
                     deleted_devices.remove((user_id, device_id))
                 result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
@@ -331,6 +354,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
         if include_deleted_devices:
             for user_id, device_id in deleted_devices:
+                if device_id is None:
+                    continue
                 result.setdefault(user_id, {})[device_id] = None
 
         return result
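
For context, here is a hypothetical mixed lookup and the statements the rewritten loop would now issue (SQLite-style placeholders; the exact clause text comes from make_in_list_sql_clause and make_tuple_in_list_sql_clause):

# Hypothetical input: one whole-user lookup plus two specific devices.
query_list = [
    ("@alice:example.com", None),
    ("@bob:example.com", "DEVICE1"),
    ("@bob:example.com", "DEVICE2"),
]

# Two SELECTs run instead of one ever-growing OR-ed WHERE clause (the device
# lookups would only be split further once a batch exceeds 1024 entries):
#
#   ... WHERE user_id IN (?) AND NOT d.hidden
#       params: ["@alice:example.com"]
#
#   ... WHERE (user_id,device_id) IN (VALUES (?,?),(?,?)) AND NOT d.hidden
#       params: ["@bob:example.com", "DEVICE1", "@bob:example.com", "DEVICE2"]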