Split fetching device keys and signatures into two transactions (#8233)

I think this is simpler (and moves stuff out of the db threads)
This commit is contained in:
Richard van der Hoff 2020-09-03 18:27:26 +01:00 committed by GitHub
parent 208e1d3eb3
commit f97f9485ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 68 additions and 46 deletions

1
changelog.d/8233.misc Normal file
View File

@ -0,0 +1 @@
Refactor queries for device keys and cross-signatures.

View File

@ -25,6 +25,7 @@ from twisted.enterprise.adbapi import Connection
from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import make_in_list_sql_clause from synapse.storage.database import make_in_list_sql_clause
from synapse.storage.types import Cursor
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
@ -45,8 +46,9 @@ class DeviceKeyLookupResult:
# key) and "signatures" (a signature of the structure by the ed25519 key) # key) and "signatures" (a signature of the structure by the ed25519 key)
key_json = attr.ib(type=Optional[str]) key_json = attr.ib(type=Optional[str])
# cross-signing sigs # cross-signing sigs on this device.
signatures = attr.ib(type=Optional[Dict], default=None) # dict from (signing user_id)->(signing device_id)->sig
signatures = attr.ib(type=Optional[Dict[str, Dict[str, str]]], factory=dict)
class EndToEndKeyWorkerStore(SQLBaseStore): class EndToEndKeyWorkerStore(SQLBaseStore):
@ -133,7 +135,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
include_all_devices: bool = False, include_all_devices: bool = False,
include_deleted_devices: bool = False, include_deleted_devices: bool = False,
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
"""Fetch a list of device keys, together with their cross-signatures. """Fetch a list of device keys
Any cross-signatures made on the keys by the owner of the device are also
included.
Args: Args:
query_list: List of pairs of user_ids and device_ids. Device id can be None query_list: List of pairs of user_ids and device_ids. Device id can be None
@ -154,22 +159,51 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
result = await self.db_pool.runInteraction( result = await self.db_pool.runInteraction(
"get_e2e_device_keys", "get_e2e_device_keys",
self._get_e2e_device_keys_and_signatures_txn, self._get_e2e_device_keys_txn,
query_list, query_list,
include_all_devices, include_all_devices,
include_deleted_devices, include_deleted_devices,
) )
# get the (user_id, device_id) tuples to look up cross-signatures for
signature_query = (
(user_id, device_id)
for user_id, dev in result.items()
for device_id, d in dev.items()
if d is not None
)
for batch in batch_iter(signature_query, 50):
cross_sigs_result = await self.db_pool.runInteraction(
"get_e2e_cross_signing_signatures",
self._get_e2e_cross_signing_signatures_for_devices_txn,
batch,
)
# add each cross-signing signature to the correct device in the result dict.
for (user_id, key_id, device_id, signature) in cross_sigs_result:
target_device_result = result[user_id][device_id]
target_device_signatures = target_device_result.signatures
signing_user_signatures = target_device_signatures.setdefault(
user_id, {}
)
signing_user_signatures[key_id] = signature
log_kv(result) log_kv(result)
return result return result
def _get_e2e_device_keys_and_signatures_txn( def _get_e2e_device_keys_txn(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False self, txn, query_list, include_all_devices=False, include_deleted_devices=False
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
"""Get information on devices from the database
The results include the device's keys and self-signatures, but *not* any
cross-signing signatures which have been added subsequently (for which, see
get_e2e_device_keys_and_signatures)
"""
query_clauses = [] query_clauses = []
query_params = [] query_params = []
signature_query_clauses = []
signature_query_params = []
if include_all_devices is False: if include_all_devices is False:
include_deleted_devices = False include_deleted_devices = False
@ -180,20 +214,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for (user_id, device_id) in query_list: for (user_id, device_id) in query_list:
query_clause = "user_id = ?" query_clause = "user_id = ?"
query_params.append(user_id) query_params.append(user_id)
signature_query_clause = "target_user_id = ?"
signature_query_params.append(user_id)
if device_id is not None: if device_id is not None:
query_clause += " AND device_id = ?" query_clause += " AND device_id = ?"
query_params.append(device_id) query_params.append(device_id)
signature_query_clause += " AND target_device_id = ?"
signature_query_params.append(device_id)
signature_query_clause += " AND user_id = ?"
signature_query_params.append(user_id)
query_clauses.append(query_clause) query_clauses.append(query_clause)
signature_query_clauses.append(signature_query_clause)
sql = ( sql = (
"SELECT user_id, device_id, " "SELECT user_id, device_id, "
@ -221,41 +247,36 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for user_id, device_id in deleted_devices: for user_id, device_id in deleted_devices:
result.setdefault(user_id, {})[device_id] = None result.setdefault(user_id, {})[device_id] = None
# get signatures on the device return result
signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % (
def _get_e2e_cross_signing_signatures_for_devices_txn(
self, txn: Cursor, device_query: Iterable[Tuple[str, str]]
) -> List[Tuple[str, str, str, str]]:
"""Get cross-signing signatures for a given list of devices
Returns signatures made by the owners of the devices.
Returns: a list of results; each entry in the list is a tuple of
(user_id, key_id, target_device_id, signature).
"""
signature_query_clauses = []
signature_query_params = []
for (user_id, device_id) in device_query:
signature_query_clauses.append(
"target_user_id = ? AND target_device_id = ? AND user_id = ?"
)
signature_query_params.extend([user_id, device_id, user_id])
signature_sql = """
SELECT user_id, key_id, target_device_id, signature
FROM e2e_cross_signing_signatures WHERE %s
""" % (
" OR ".join("(" + q + ")" for q in signature_query_clauses) " OR ".join("(" + q + ")" for q in signature_query_clauses)
) )
txn.execute(signature_sql, signature_query_params) txn.execute(signature_sql, signature_query_params)
rows = self.db_pool.cursor_to_dict(txn) return txn.fetchall()
# add each cross-signing signature to the correct device in the result dict.
for row in rows:
signing_user_id = row["user_id"]
signing_key_id = row["key_id"]
target_user_id = row["target_user_id"]
target_device_id = row["target_device_id"]
signature = row["signature"]
target_user_result = result.get(target_user_id)
if not target_user_result:
continue
target_device_result = target_user_result.get(target_device_id)
if not target_device_result:
# note that target_device_result will be None for deleted devices.
continue
target_device_signatures = target_device_result.signatures
if target_device_signatures is None:
target_device_signatures = target_device_result.signatures = {}
signing_user_signatures = target_device_signatures.setdefault(
signing_user_id, {}
)
signing_user_signatures[signing_key_id] = signature
return result
async def get_e2e_one_time_keys( async def get_e2e_one_time_keys(
self, user_id: str, device_id: str, key_ids: List[str] self, user_id: str, device_id: str, key_ids: List[str]