Make _get_e2e_device_keys_and_signatures_txn return an attrs (#8224)

this makes it a bit clearer what's going on.
This commit is contained in:
Richard van der Hoff 2020-09-02 11:47:26 +01:00 committed by GitHub
parent b939251c37
commit abeab964d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 41 additions and 20 deletions

1
changelog.d/8224.misc Normal file
View File

@ -0,0 +1 @@
Refactor queries for device keys and cross-signatures.

View File

@ -293,17 +293,17 @@ class DeviceWorkerStore(SQLBaseStore):
prev_id = stream_id prev_id = stream_id
if device is not None: if device is not None:
key_json = device.get("key_json", None) key_json = device.key_json
if key_json: if key_json:
result["keys"] = db_to_json(key_json) result["keys"] = db_to_json(key_json)
if "signatures" in device: if device.signatures:
for sig_user_id, sigs in device["signatures"].items(): for sig_user_id, sigs in device.signatures.items():
result["keys"].setdefault("signatures", {}).setdefault( result["keys"].setdefault("signatures", {}).setdefault(
sig_user_id, {} sig_user_id, {}
).update(sigs) ).update(sigs)
device_display_name = device.get("device_display_name", None) device_display_name = device.display_name
if device_display_name: if device_display_name:
result["device_display_name"] = device_display_name result["device_display_name"] = device_display_name
else: else:

View File

@ -17,6 +17,7 @@
import abc import abc
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from twisted.enterprise.adbapi import Connection from twisted.enterprise.adbapi import Connection
@ -33,6 +34,21 @@ if TYPE_CHECKING:
from synapse.handlers.e2e_keys import SignatureListItem from synapse.handlers.e2e_keys import SignatureListItem
@attr.s
class DeviceKeyLookupResult:
"""The type returned by _get_e2e_device_keys_and_signatures_txn"""
display_name = attr.ib(type=Optional[str])
# the key data from e2e_device_keys_json. Typically includes fields like
# "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing
# key) and "signatures" (a signature of the structure by the ed25519 key)
key_json = attr.ib(type=Optional[str])
# cross-signing sigs
signatures = attr.ib(type=Optional[Dict], default=None)
class EndToEndKeyWorkerStore(SQLBaseStore): class EndToEndKeyWorkerStore(SQLBaseStore):
async def get_e2e_device_keys_for_federation_query( async def get_e2e_device_keys_for_federation_query(
self, user_id: str self, user_id: str
@ -61,17 +77,17 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for device_id, device in user_devices.items(): for device_id, device in user_devices.items():
result = {"device_id": device_id} result = {"device_id": device_id}
key_json = device.get("key_json", None) key_json = device.key_json
if key_json: if key_json:
result["keys"] = db_to_json(key_json) result["keys"] = db_to_json(key_json)
if "signatures" in device: if device.signatures:
for sig_user_id, sigs in device["signatures"].items(): for sig_user_id, sigs in device.signatures.items():
result["keys"].setdefault("signatures", {}).setdefault( result["keys"].setdefault("signatures", {}).setdefault(
sig_user_id, {} sig_user_id, {}
).update(sigs) ).update(sigs)
device_display_name = device.get("device_display_name", None) device_display_name = device.display_name
if device_display_name: if device_display_name:
result["device_display_name"] = device_display_name result["device_display_name"] = device_display_name
@ -109,13 +125,13 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for user_id, device_keys in results.items(): for user_id, device_keys in results.items():
rv[user_id] = {} rv[user_id] = {}
for device_id, device_info in device_keys.items(): for device_id, device_info in device_keys.items():
r = db_to_json(device_info.pop("key_json")) r = db_to_json(device_info.key_json)
r["unsigned"] = {} r["unsigned"] = {}
display_name = device_info["device_display_name"] display_name = device_info.display_name
if display_name is not None: if display_name is not None:
r["unsigned"]["device_display_name"] = display_name r["unsigned"]["device_display_name"] = display_name
if "signatures" in device_info: if device_info.signatures:
for sig_user_id, sigs in device_info["signatures"].items(): for sig_user_id, sigs in device_info.signatures.items():
r.setdefault("signatures", {}).setdefault( r.setdefault("signatures", {}).setdefault(
sig_user_id, {} sig_user_id, {}
).update(sigs) ).update(sigs)
@ -126,7 +142,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
@trace @trace
def _get_e2e_device_keys_and_signatures_txn( def _get_e2e_device_keys_and_signatures_txn(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False self, txn, query_list, include_all_devices=False, include_deleted_devices=False
) -> Dict[str, Dict[str, Optional[Dict]]]: ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
set_tag("include_all_devices", include_all_devices) set_tag("include_all_devices", include_all_devices)
set_tag("include_deleted_devices", include_deleted_devices) set_tag("include_deleted_devices", include_deleted_devices)
@ -161,7 +177,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
sql = ( sql = (
"SELECT user_id, device_id, " "SELECT user_id, device_id, "
" d.display_name AS device_display_name, " " d.display_name, "
" k.key_json" " k.key_json"
" FROM devices d" " FROM devices d"
" %s JOIN e2e_device_keys_json k USING (user_id, device_id)" " %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
@ -172,13 +188,14 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
) )
txn.execute(sql, query_params) txn.execute(sql, query_params)
rows = self.db_pool.cursor_to_dict(txn)
result = {} result = {} # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]
for row in rows: for (user_id, device_id, display_name, key_json) in txn:
if include_deleted_devices: if include_deleted_devices:
deleted_devices.remove((row["user_id"], row["device_id"])) deleted_devices.remove((user_id, device_id))
result.setdefault(row["user_id"], {})[row["device_id"]] = row result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
display_name, key_json
)
if include_deleted_devices: if include_deleted_devices:
for user_id, device_id in deleted_devices: for user_id, device_id in deleted_devices:
@ -209,7 +226,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
# note that target_device_result will be None for deleted devices. # note that target_device_result will be None for deleted devices.
continue continue
target_device_signatures = target_device_result.setdefault("signatures", {}) target_device_signatures = target_device_result.signatures
if target_device_signatures is None:
target_device_signatures = target_device_result.signatures = {}
signing_user_signatures = target_device_signatures.setdefault( signing_user_signatures = target_device_signatures.setdefault(
signing_user_id, {} signing_user_id, {}
) )