mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-01-24 08:11:04 -05:00
Add type hints to synapse/storage/databases/main/end_to_end_keys.py
(#11551)
This commit is contained in:
parent
6da8591f2e
commit
1abfb15f07
1
changelog.d/11551.misc
Normal file
1
changelog.d/11551.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add missing type hints to storage classes.
|
4
mypy.ini
4
mypy.ini
@ -28,7 +28,6 @@ exclude = (?x)
|
|||||||
|synapse/storage/databases/main/cache.py
|
|synapse/storage/databases/main/cache.py
|
||||||
|synapse/storage/databases/main/devices.py
|
|synapse/storage/databases/main/devices.py
|
||||||
|synapse/storage/databases/main/e2e_room_keys.py
|
|synapse/storage/databases/main/e2e_room_keys.py
|
||||||
|synapse/storage/databases/main/end_to_end_keys.py
|
|
||||||
|synapse/storage/databases/main/event_federation.py
|
|synapse/storage/databases/main/event_federation.py
|
||||||
|synapse/storage/databases/main/event_push_actions.py
|
|synapse/storage/databases/main/event_push_actions.py
|
||||||
|synapse/storage/databases/main/events_bg_updates.py
|
|synapse/storage/databases/main/events_bg_updates.py
|
||||||
@ -189,6 +188,9 @@ disallow_untyped_defs = True
|
|||||||
[mypy-synapse.storage.databases.main.directory]
|
[mypy-synapse.storage.databases.main.directory]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.storage.databases.main.end_to_end_keys]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-synapse.storage.databases.main.events_worker]
|
[mypy-synapse.storage.databases.main.events_worker]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
@ -143,9 +143,6 @@ class DataStore(
|
|||||||
("device_lists_outbound_pokes", "stream_id"),
|
("device_lists_outbound_pokes", "stream_id"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
self._cross_signing_id_gen = StreamIdGenerator(
|
|
||||||
db_conn, "e2e_cross_signing_keys", "stream_id"
|
|
||||||
)
|
|
||||||
|
|
||||||
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
|
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
|
||||||
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
|
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
|
||||||
|
@ -14,19 +14,32 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import abc
|
import abc
|
||||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Collection,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
|
|
||||||
from twisted.enterprise.adbapi import Connection
|
|
||||||
|
|
||||||
from synapse.api.constants import DeviceKeyAlgorithms
|
from synapse.api.constants import DeviceKeyAlgorithms
|
||||||
from synapse.logging.opentracing import log_kv, set_tag, trace
|
from synapse.logging.opentracing import log_kv, set_tag, trace
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
|
from synapse.storage.database import (
|
||||||
|
DatabasePool,
|
||||||
|
LoggingDatabaseConnection,
|
||||||
|
LoggingTransaction,
|
||||||
|
make_in_list_sql_clause,
|
||||||
|
)
|
||||||
|
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
from synapse.storage.types import Cursor
|
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
@ -50,7 +63,12 @@ class DeviceKeyLookupResult:
|
|||||||
|
|
||||||
|
|
||||||
class EndToEndKeyBackgroundStore(SQLBaseStore):
|
class EndToEndKeyBackgroundStore(SQLBaseStore):
|
||||||
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
|
def __init__(
|
||||||
|
self,
|
||||||
|
database: DatabasePool,
|
||||||
|
db_conn: LoggingDatabaseConnection,
|
||||||
|
hs: "HomeServer",
|
||||||
|
):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
self.db_pool.updates.register_background_index_update(
|
self.db_pool.updates.register_background_index_update(
|
||||||
@ -62,8 +80,13 @@ class EndToEndKeyBackgroundStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorkerStore):
|
||||||
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
|
def __init__(
|
||||||
|
self,
|
||||||
|
database: DatabasePool,
|
||||||
|
db_conn: LoggingDatabaseConnection,
|
||||||
|
hs: "HomeServer",
|
||||||
|
):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
self._allow_device_name_lookup_over_federation = (
|
self._allow_device_name_lookup_over_federation = (
|
||||||
@ -124,7 +147,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||||||
|
|
||||||
# Build the result structure, un-jsonify the results, and add the
|
# Build the result structure, un-jsonify the results, and add the
|
||||||
# "unsigned" section
|
# "unsigned" section
|
||||||
rv = {}
|
rv: Dict[str, Dict[str, JsonDict]] = {}
|
||||||
for user_id, device_keys in results.items():
|
for user_id, device_keys in results.items():
|
||||||
rv[user_id] = {}
|
rv[user_id] = {}
|
||||||
for device_id, device_info in device_keys.items():
|
for device_id, device_info in device_keys.items():
|
||||||
@ -195,6 +218,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||||||
# add each cross-signing signature to the correct device in the result dict.
|
# add each cross-signing signature to the correct device in the result dict.
|
||||||
for (user_id, key_id, device_id, signature) in cross_sigs_result:
|
for (user_id, key_id, device_id, signature) in cross_sigs_result:
|
||||||
target_device_result = result[user_id][device_id]
|
target_device_result = result[user_id][device_id]
|
||||||
|
# We've only looked up cross-signatures for non-deleted devices with key
|
||||||
|
# data.
|
||||||
|
assert target_device_result is not None
|
||||||
|
assert target_device_result.keys is not None
|
||||||
target_device_signatures = target_device_result.keys.setdefault(
|
target_device_signatures = target_device_result.keys.setdefault(
|
||||||
"signatures", {}
|
"signatures", {}
|
||||||
)
|
)
|
||||||
@ -207,7 +234,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def _get_e2e_device_keys_txn(
|
def _get_e2e_device_keys_txn(
|
||||||
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
|
self,
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
query_list: Collection[Tuple[str, str]],
|
||||||
|
include_all_devices: bool = False,
|
||||||
|
include_deleted_devices: bool = False,
|
||||||
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
|
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
|
||||||
"""Get information on devices from the database
|
"""Get information on devices from the database
|
||||||
|
|
||||||
@ -263,7 +294,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def _get_e2e_cross_signing_signatures_for_devices_txn(
|
def _get_e2e_cross_signing_signatures_for_devices_txn(
|
||||||
self, txn: Cursor, device_query: Iterable[Tuple[str, str]]
|
self, txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]]
|
||||||
) -> List[Tuple[str, str, str, str]]:
|
) -> List[Tuple[str, str, str, str]]:
|
||||||
"""Get cross-signing signatures for a given list of devices
|
"""Get cross-signing signatures for a given list of devices
|
||||||
|
|
||||||
@ -289,7 +320,17 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
txn.execute(signature_sql, signature_query_params)
|
txn.execute(signature_sql, signature_query_params)
|
||||||
return txn.fetchall()
|
return cast(
|
||||||
|
List[
|
||||||
|
Tuple[
|
||||||
|
str,
|
||||||
|
str,
|
||||||
|
str,
|
||||||
|
str,
|
||||||
|
]
|
||||||
|
],
|
||||||
|
txn.fetchall(),
|
||||||
|
)
|
||||||
|
|
||||||
async def get_e2e_one_time_keys(
|
async def get_e2e_one_time_keys(
|
||||||
self, user_id: str, device_id: str, key_ids: List[str]
|
self, user_id: str, device_id: str, key_ids: List[str]
|
||||||
@ -335,7 +376,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||||||
new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
|
new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _add_e2e_one_time_keys(txn):
|
def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
|
||||||
set_tag("user_id", user_id)
|
set_tag("user_id", user_id)
|
||||||
set_tag("device_id", device_id)
|
set_tag("device_id", device_id)
|
||||||
set_tag("new_keys", new_keys)
|
set_tag("new_keys", new_keys)
|
||||||
@ -375,7 +416,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||||||
A mapping from algorithm to number of keys for that algorithm.
|
A mapping from algorithm to number of keys for that algorithm.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _count_e2e_one_time_keys(txn):
|
def _count_e2e_one_time_keys(txn: LoggingTransaction) -> Dict[str, int]:
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
|
"SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
|
||||||
" WHERE user_id = ? AND device_id = ?"
|
" WHERE user_id = ? AND device_id = ?"
|
||||||
@ -421,7 +462,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _set_e2e_fallback_keys_txn(
|
def _set_e2e_fallback_keys_txn(
|
||||||
self, txn: Connection, user_id: str, device_id: str, fallback_keys: JsonDict
|
self,
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
user_id: str,
|
||||||
|
device_id: str,
|
||||||
|
fallback_keys: JsonDict,
|
||||||
) -> None:
|
) -> None:
|
||||||
# fallback_keys will usually only have one item in it, so using a for
|
# fallback_keys will usually only have one item in it, so using a for
|
||||||
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
|
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
|
||||||
@ -483,7 +528,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||||||
|
|
||||||
async def get_e2e_cross_signing_key(
|
async def get_e2e_cross_signing_key(
|
||||||
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
|
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
|
||||||
) -> Optional[dict]:
|
) -> Optional[JsonDict]:
|
||||||
"""Returns a user's cross-signing key.
|
"""Returns a user's cross-signing key.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -504,7 +549,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||||||
return user_keys.get(key_type)
|
return user_keys.get(key_type)
|
||||||
|
|
||||||
@cached(num_args=1)
|
@cached(num_args=1)
|
||||||
def _get_bare_e2e_cross_signing_keys(self, user_id):
|
def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]:
|
||||||
"""Dummy function. Only used to make a cache for
|
"""Dummy function. Only used to make a cache for
|
||||||
_get_bare_e2e_cross_signing_keys_bulk.
|
_get_bare_e2e_cross_signing_keys_bulk.
|
||||||
"""
|
"""
|
||||||
@ -517,7 +562,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||||||
)
|
)
|
||||||
async def _get_bare_e2e_cross_signing_keys_bulk(
|
async def _get_bare_e2e_cross_signing_keys_bulk(
|
||||||
self, user_ids: Iterable[str]
|
self, user_ids: Iterable[str]
|
||||||
) -> Dict[str, Dict[str, dict]]:
|
) -> Dict[str, Optional[Dict[str, JsonDict]]]:
|
||||||
"""Returns the cross-signing keys for a set of users. The output of this
|
"""Returns the cross-signing keys for a set of users. The output of this
|
||||||
function should be passed to _get_e2e_cross_signing_signatures_txn if
|
function should be passed to _get_e2e_cross_signing_signatures_txn if
|
||||||
the signatures for the calling user need to be fetched.
|
the signatures for the calling user need to be fetched.
|
||||||
@ -531,32 +576,35 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||||||
their user ID will map to None.
|
their user ID will map to None.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return await self.db_pool.runInteraction(
|
result = await self.db_pool.runInteraction(
|
||||||
"get_bare_e2e_cross_signing_keys_bulk",
|
"get_bare_e2e_cross_signing_keys_bulk",
|
||||||
self._get_bare_e2e_cross_signing_keys_bulk_txn,
|
self._get_bare_e2e_cross_signing_keys_bulk_txn,
|
||||||
user_ids,
|
user_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# The `Optional` comes from the `@cachedList` decorator.
|
||||||
|
return cast(Dict[str, Optional[Dict[str, JsonDict]]], result)
|
||||||
|
|
||||||
def _get_bare_e2e_cross_signing_keys_bulk_txn(
|
def _get_bare_e2e_cross_signing_keys_bulk_txn(
|
||||||
self,
|
self,
|
||||||
txn: Connection,
|
txn: LoggingTransaction,
|
||||||
user_ids: Iterable[str],
|
user_ids: Iterable[str],
|
||||||
) -> Dict[str, Dict[str, dict]]:
|
) -> Dict[str, Dict[str, JsonDict]]:
|
||||||
"""Returns the cross-signing keys for a set of users. The output of this
|
"""Returns the cross-signing keys for a set of users. The output of this
|
||||||
function should be passed to _get_e2e_cross_signing_signatures_txn if
|
function should be passed to _get_e2e_cross_signing_signatures_txn if
|
||||||
the signatures for the calling user need to be fetched.
|
the signatures for the calling user need to be fetched.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
txn (twisted.enterprise.adbapi.Connection): db connection
|
txn: db connection
|
||||||
user_ids (list[str]): the users whose keys are being requested
|
user_ids: the users whose keys are being requested
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, dict[str, dict]]: mapping from user ID to key type to key
|
Mapping from user ID to key type to key data.
|
||||||
data. If a user's cross-signing keys were not found, their user
|
If a user's cross-signing keys were not found, their user ID will not be in
|
||||||
ID will not be in the dict.
|
the dict.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
result = {}
|
result: Dict[str, Dict[str, JsonDict]] = {}
|
||||||
|
|
||||||
for user_chunk in batch_iter(user_ids, 100):
|
for user_chunk in batch_iter(user_ids, 100):
|
||||||
clause, params = make_in_list_sql_clause(
|
clause, params = make_in_list_sql_clause(
|
||||||
@ -596,43 +644,48 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||||||
user_id = row["user_id"]
|
user_id = row["user_id"]
|
||||||
key_type = row["keytype"]
|
key_type = row["keytype"]
|
||||||
key = db_to_json(row["keydata"])
|
key = db_to_json(row["keydata"])
|
||||||
user_info = result.setdefault(user_id, {})
|
user_keys = result.setdefault(user_id, {})
|
||||||
user_info[key_type] = key
|
user_keys[key_type] = key
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _get_e2e_cross_signing_signatures_txn(
|
def _get_e2e_cross_signing_signatures_txn(
|
||||||
self,
|
self,
|
||||||
txn: Connection,
|
txn: LoggingTransaction,
|
||||||
keys: Dict[str, Dict[str, dict]],
|
keys: Dict[str, Optional[Dict[str, JsonDict]]],
|
||||||
from_user_id: str,
|
from_user_id: str,
|
||||||
) -> Dict[str, Dict[str, dict]]:
|
) -> Dict[str, Optional[Dict[str, JsonDict]]]:
|
||||||
"""Returns the cross-signing signatures made by a user on a set of keys.
|
"""Returns the cross-signing signatures made by a user on a set of keys.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
txn (twisted.enterprise.adbapi.Connection): db connection
|
txn: db connection
|
||||||
keys (dict[str, dict[str, dict]]): a map of user ID to key type to
|
keys: a map of user ID to key type to key data.
|
||||||
key data. This dict will be modified to add signatures.
|
This dict will be modified to add signatures.
|
||||||
from_user_id (str): fetch the signatures made by this user
|
from_user_id: fetch the signatures made by this user
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, dict[str, dict]]: mapping from user ID to key type to key
|
Mapping from user ID to key type to key data.
|
||||||
data. The return value will be the same as the keys argument,
|
The return value will be the same as the keys argument, with the
|
||||||
with the modifications included.
|
modifications included.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# find out what cross-signing keys (a.k.a. devices) we need to get
|
# find out what cross-signing keys (a.k.a. devices) we need to get
|
||||||
# signatures for. This is a map of (user_id, device_id) to key type
|
# signatures for. This is a map of (user_id, device_id) to key type
|
||||||
# (device_id is the key's public part).
|
# (device_id is the key's public part).
|
||||||
devices = {}
|
devices: Dict[Tuple[str, str], str] = {}
|
||||||
|
|
||||||
for user_id, user_info in keys.items():
|
for user_id, user_keys in keys.items():
|
||||||
if user_info is None:
|
if user_keys is None:
|
||||||
continue
|
continue
|
||||||
for key_type, key in user_info.items():
|
for key_type, key in user_keys.items():
|
||||||
device_id = None
|
device_id = None
|
||||||
for k in key["keys"].values():
|
for k in key["keys"].values():
|
||||||
device_id = k
|
device_id = k
|
||||||
|
# `key` ought to be a `CrossSigningKey`, whose .keys property is a
|
||||||
|
# dictionary with a single entry:
|
||||||
|
# "algorithm:base64_public_key": "base64_public_key"
|
||||||
|
# See https://spec.matrix.org/v1.1/client-server-api/#cross-signing
|
||||||
|
assert isinstance(device_id, str)
|
||||||
devices[(user_id, device_id)] = key_type
|
devices[(user_id, device_id)] = key_type
|
||||||
|
|
||||||
for batch in batch_iter(devices.keys(), size=100):
|
for batch in batch_iter(devices.keys(), size=100):
|
||||||
@ -656,15 +709,20 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||||||
|
|
||||||
# and add the signatures to the appropriate keys
|
# and add the signatures to the appropriate keys
|
||||||
for row in rows:
|
for row in rows:
|
||||||
key_id = row["key_id"]
|
key_id: str = row["key_id"]
|
||||||
target_user_id = row["target_user_id"]
|
target_user_id: str = row["target_user_id"]
|
||||||
target_device_id = row["target_device_id"]
|
target_device_id: str = row["target_device_id"]
|
||||||
key_type = devices[(target_user_id, target_device_id)]
|
key_type = devices[(target_user_id, target_device_id)]
|
||||||
# We need to copy everything, because the result may have come
|
# We need to copy everything, because the result may have come
|
||||||
# from the cache. dict.copy only does a shallow copy, so we
|
# from the cache. dict.copy only does a shallow copy, so we
|
||||||
# need to recursively copy the dicts that will be modified.
|
# need to recursively copy the dicts that will be modified.
|
||||||
user_info = keys[target_user_id] = keys[target_user_id].copy()
|
user_keys = keys[target_user_id]
|
||||||
target_user_key = user_info[key_type] = user_info[key_type].copy()
|
# `user_keys` cannot be `None` because we only fetched signatures for
|
||||||
|
# users with keys
|
||||||
|
assert user_keys is not None
|
||||||
|
user_keys = keys[target_user_id] = user_keys.copy()
|
||||||
|
|
||||||
|
target_user_key = user_keys[key_type] = user_keys[key_type].copy()
|
||||||
if "signatures" in target_user_key:
|
if "signatures" in target_user_key:
|
||||||
signatures = target_user_key["signatures"] = target_user_key[
|
signatures = target_user_key["signatures"] = target_user_key[
|
||||||
"signatures"
|
"signatures"
|
||||||
@ -683,7 +741,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||||||
|
|
||||||
async def get_e2e_cross_signing_keys_bulk(
|
async def get_e2e_cross_signing_keys_bulk(
|
||||||
self, user_ids: List[str], from_user_id: Optional[str] = None
|
self, user_ids: List[str], from_user_id: Optional[str] = None
|
||||||
) -> Dict[str, Optional[Dict[str, dict]]]:
|
) -> Dict[str, Optional[Dict[str, JsonDict]]]:
|
||||||
"""Returns the cross-signing keys for a set of users.
|
"""Returns the cross-signing keys for a set of users.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -741,7 +799,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||||||
if last_id == current_id:
|
if last_id == current_id:
|
||||||
return [], current_id, False
|
return [], current_id, False
|
||||||
|
|
||||||
def _get_all_user_signature_changes_for_remotes_txn(txn):
|
def _get_all_user_signature_changes_for_remotes_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
||||||
sql = """
|
sql = """
|
||||||
SELECT stream_id, from_user_id AS user_id
|
SELECT stream_id, from_user_id AS user_id
|
||||||
FROM user_signature_stream
|
FROM user_signature_stream
|
||||||
@ -785,7 +845,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||||||
|
|
||||||
@trace
|
@trace
|
||||||
def _claim_e2e_one_time_key_simple(
|
def _claim_e2e_one_time_key_simple(
|
||||||
txn, user_id: str, device_id: str, algorithm: str
|
txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
|
||||||
) -> Optional[Tuple[str, str]]:
|
) -> Optional[Tuple[str, str]]:
|
||||||
"""Claim OTK for device for DBs that don't support RETURNING.
|
"""Claim OTK for device for DBs that don't support RETURNING.
|
||||||
|
|
||||||
@ -825,7 +885,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||||||
|
|
||||||
@trace
|
@trace
|
||||||
def _claim_e2e_one_time_key_returning(
|
def _claim_e2e_one_time_key_returning(
|
||||||
txn, user_id: str, device_id: str, algorithm: str
|
txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
|
||||||
) -> Optional[Tuple[str, str]]:
|
) -> Optional[Tuple[str, str]]:
|
||||||
"""Claim OTK for device for DBs that support RETURNING.
|
"""Claim OTK for device for DBs that support RETURNING.
|
||||||
|
|
||||||
@ -860,7 +920,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||||||
key_id, key_json = otk_row
|
key_id, key_json = otk_row
|
||||||
return f"{algorithm}:{key_id}", key_json
|
return f"{algorithm}:{key_id}", key_json
|
||||||
|
|
||||||
results = {}
|
results: Dict[str, Dict[str, Dict[str, str]]] = {}
|
||||||
for user_id, device_id, algorithm in query_list:
|
for user_id, device_id, algorithm in query_list:
|
||||||
if self.database_engine.supports_returning:
|
if self.database_engine.supports_returning:
|
||||||
# If we support RETURNING clause we can use a single query that
|
# If we support RETURNING clause we can use a single query that
|
||||||
@ -930,6 +990,18 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||||||
|
|
||||||
|
|
||||||
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
database: DatabasePool,
|
||||||
|
db_conn: LoggingDatabaseConnection,
|
||||||
|
hs: "HomeServer",
|
||||||
|
):
|
||||||
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
|
self._cross_signing_id_gen = StreamIdGenerator(
|
||||||
|
db_conn, "e2e_cross_signing_keys", "stream_id"
|
||||||
|
)
|
||||||
|
|
||||||
async def set_e2e_device_keys(
|
async def set_e2e_device_keys(
|
||||||
self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
|
self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
|
||||||
) -> bool:
|
) -> bool:
|
||||||
@ -937,7 +1009,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||||||
or the keys were already in the database.
|
or the keys were already in the database.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _set_e2e_device_keys_txn(txn):
|
def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool:
|
||||||
set_tag("user_id", user_id)
|
set_tag("user_id", user_id)
|
||||||
set_tag("device_id", device_id)
|
set_tag("device_id", device_id)
|
||||||
set_tag("time_now", time_now)
|
set_tag("time_now", time_now)
|
||||||
@ -973,7 +1045,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
|
async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
|
||||||
def delete_e2e_keys_by_device_txn(txn):
|
def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
|
||||||
log_kv(
|
log_kv(
|
||||||
{
|
{
|
||||||
"message": "Deleting keys for device",
|
"message": "Deleting keys for device",
|
||||||
@ -1012,17 +1084,24 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||||||
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
|
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id):
|
def _set_e2e_cross_signing_key_txn(
|
||||||
|
self,
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
user_id: str,
|
||||||
|
key_type: str,
|
||||||
|
key: JsonDict,
|
||||||
|
stream_id: int,
|
||||||
|
) -> None:
|
||||||
"""Set a user's cross-signing key.
|
"""Set a user's cross-signing key.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
txn (twisted.enterprise.adbapi.Connection): db connection
|
txn: db connection
|
||||||
user_id (str): the user to set the signing key for
|
user_id: the user to set the signing key for
|
||||||
key_type (str): the type of key that is being set: either 'master'
|
key_type: the type of key that is being set: either 'master'
|
||||||
for a master key, 'self_signing' for a self-signing key, or
|
for a master key, 'self_signing' for a self-signing key, or
|
||||||
'user_signing' for a user-signing key
|
'user_signing' for a user-signing key
|
||||||
key (dict): the key data
|
key: the key data
|
||||||
stream_id (int)
|
stream_id
|
||||||
"""
|
"""
|
||||||
# the 'key' dict will look something like:
|
# the 'key' dict will look something like:
|
||||||
# {
|
# {
|
||||||
@ -1075,13 +1154,15 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||||||
txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
|
txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def set_e2e_cross_signing_key(self, user_id, key_type, key):
|
async def set_e2e_cross_signing_key(
|
||||||
|
self, user_id: str, key_type: str, key: JsonDict
|
||||||
|
) -> None:
|
||||||
"""Set a user's cross-signing key.
|
"""Set a user's cross-signing key.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): the user to set the user-signing key for
|
user_id: the user to set the user-signing key for
|
||||||
key_type (str): the type of cross-signing key to set
|
key_type: the type of cross-signing key to set
|
||||||
key (dict): the key data
|
key: the key data
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async with self._cross_signing_id_gen.get_next() as stream_id:
|
async with self._cross_signing_id_gen.get_next() as stream_id:
|
||||||
|
Loading…
Reference in New Issue
Block a user