mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Add support for MSC2732: olm fallback keys (#8312)
This commit is contained in:
parent
a024461130
commit
3cd78bbe9e
1
changelog.d/8312.feature
Normal file
1
changelog.d/8312.feature
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add support for olm fallback keys ([MSC2732](https://github.com/matrix-org/matrix-doc/pull/2732)).
|
@ -90,6 +90,7 @@ BOOLEAN_COLUMNS = {
|
|||||||
"room_stats_state": ["is_federatable"],
|
"room_stats_state": ["is_federatable"],
|
||||||
"local_media_repository": ["safe_from_quarantine"],
|
"local_media_repository": ["safe_from_quarantine"],
|
||||||
"users": ["shadow_banned"],
|
"users": ["shadow_banned"],
|
||||||
|
"e2e_fallback_keys_json": ["used"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -496,6 +496,22 @@ class E2eKeysHandler:
|
|||||||
log_kv(
|
log_kv(
|
||||||
{"message": "Did not update one_time_keys", "reason": "no keys given"}
|
{"message": "Did not update one_time_keys", "reason": "no keys given"}
|
||||||
)
|
)
|
||||||
|
fallback_keys = keys.get("org.matrix.msc2732.fallback_keys", None)
|
||||||
|
if fallback_keys and isinstance(fallback_keys, dict):
|
||||||
|
log_kv(
|
||||||
|
{
|
||||||
|
"message": "Updating fallback_keys for device.",
|
||||||
|
"user_id": user_id,
|
||||||
|
"device_id": device_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await self.store.set_e2e_fallback_keys(user_id, device_id, fallback_keys)
|
||||||
|
elif fallback_keys:
|
||||||
|
log_kv({"message": "Did not update fallback_keys", "reason": "not a dict"})
|
||||||
|
else:
|
||||||
|
log_kv(
|
||||||
|
{"message": "Did not update fallback_keys", "reason": "no keys given"}
|
||||||
|
)
|
||||||
|
|
||||||
# the device should have been registered already, but it may have been
|
# the device should have been registered already, but it may have been
|
||||||
# deleted due to a race with a DELETE request. Or we may be using an
|
# deleted due to a race with a DELETE request. Or we may be using an
|
||||||
|
@ -201,6 +201,8 @@ class SyncResult:
|
|||||||
device_lists: List of user_ids whose devices have changed
|
device_lists: List of user_ids whose devices have changed
|
||||||
device_one_time_keys_count: Dict of algorithm to count for one time keys
|
device_one_time_keys_count: Dict of algorithm to count for one time keys
|
||||||
for this device
|
for this device
|
||||||
|
device_unused_fallback_key_types: List of key types that have an unused fallback
|
||||||
|
key
|
||||||
groups: Group updates, if any
|
groups: Group updates, if any
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -213,6 +215,7 @@ class SyncResult:
|
|||||||
to_device = attr.ib(type=List[JsonDict])
|
to_device = attr.ib(type=List[JsonDict])
|
||||||
device_lists = attr.ib(type=DeviceLists)
|
device_lists = attr.ib(type=DeviceLists)
|
||||||
device_one_time_keys_count = attr.ib(type=JsonDict)
|
device_one_time_keys_count = attr.ib(type=JsonDict)
|
||||||
|
device_unused_fallback_key_types = attr.ib(type=List[str])
|
||||||
groups = attr.ib(type=Optional[GroupsSyncResult])
|
groups = attr.ib(type=Optional[GroupsSyncResult])
|
||||||
|
|
||||||
def __bool__(self) -> bool:
|
def __bool__(self) -> bool:
|
||||||
@ -1014,10 +1017,14 @@ class SyncHandler:
|
|||||||
logger.debug("Fetching OTK data")
|
logger.debug("Fetching OTK data")
|
||||||
device_id = sync_config.device_id
|
device_id = sync_config.device_id
|
||||||
one_time_key_counts = {} # type: JsonDict
|
one_time_key_counts = {} # type: JsonDict
|
||||||
|
unused_fallback_key_types = [] # type: List[str]
|
||||||
if device_id:
|
if device_id:
|
||||||
one_time_key_counts = await self.store.count_e2e_one_time_keys(
|
one_time_key_counts = await self.store.count_e2e_one_time_keys(
|
||||||
user_id, device_id
|
user_id, device_id
|
||||||
)
|
)
|
||||||
|
unused_fallback_key_types = await self.store.get_e2e_unused_fallback_key_types(
|
||||||
|
user_id, device_id
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug("Fetching group data")
|
logger.debug("Fetching group data")
|
||||||
await self._generate_sync_entry_for_groups(sync_result_builder)
|
await self._generate_sync_entry_for_groups(sync_result_builder)
|
||||||
@ -1041,6 +1048,7 @@ class SyncHandler:
|
|||||||
device_lists=device_lists,
|
device_lists=device_lists,
|
||||||
groups=sync_result_builder.groups,
|
groups=sync_result_builder.groups,
|
||||||
device_one_time_keys_count=one_time_key_counts,
|
device_one_time_keys_count=one_time_key_counts,
|
||||||
|
device_unused_fallback_key_types=unused_fallback_key_types,
|
||||||
next_batch=sync_result_builder.now_token,
|
next_batch=sync_result_builder.now_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -236,6 +236,7 @@ class SyncRestServlet(RestServlet):
|
|||||||
"leave": sync_result.groups.leave,
|
"leave": sync_result.groups.leave,
|
||||||
},
|
},
|
||||||
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
|
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
|
||||||
|
"org.matrix.msc2732.device_unused_fallback_key_types": sync_result.device_unused_fallback_key_types,
|
||||||
"next_batch": await sync_result.next_batch.to_string(self.store),
|
"next_batch": await sync_result.next_batch.to_string(self.store),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -367,6 +367,57 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||||||
"count_e2e_one_time_keys", _count_e2e_one_time_keys
|
"count_e2e_one_time_keys", _count_e2e_one_time_keys
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def set_e2e_fallback_keys(
|
||||||
|
self, user_id: str, device_id: str, fallback_keys: JsonDict
|
||||||
|
) -> None:
|
||||||
|
"""Set the user's e2e fallback keys.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: the user whose keys are being set
|
||||||
|
device_id: the device whose keys are being set
|
||||||
|
fallback_keys: the keys to set. This is a map from key ID (which is
|
||||||
|
of the form "algorithm:id") to key data.
|
||||||
|
"""
|
||||||
|
# fallback_keys will usually only have one item in it, so using a for
|
||||||
|
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
|
||||||
|
# FIXME: make sure that only one key per algorithm is uploaded
|
||||||
|
for key_id, fallback_key in fallback_keys.items():
|
||||||
|
algorithm, key_id = key_id.split(":", 1)
|
||||||
|
await self.db_pool.simple_upsert(
|
||||||
|
"e2e_fallback_keys_json",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_id,
|
||||||
|
"device_id": device_id,
|
||||||
|
"algorithm": algorithm,
|
||||||
|
},
|
||||||
|
values={
|
||||||
|
"key_id": key_id,
|
||||||
|
"key_json": json_encoder.encode(fallback_key),
|
||||||
|
"used": False,
|
||||||
|
},
|
||||||
|
desc="set_e2e_fallback_key",
|
||||||
|
)
|
||||||
|
|
||||||
|
@cached(max_entries=10000)
|
||||||
|
async def get_e2e_unused_fallback_key_types(
|
||||||
|
self, user_id: str, device_id: str
|
||||||
|
) -> List[str]:
|
||||||
|
"""Returns the fallback key types that have an unused key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: the user whose keys are being queried
|
||||||
|
device_id: the device whose keys are being queried
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a list of key types
|
||||||
|
"""
|
||||||
|
return await self.db_pool.simple_select_onecol(
|
||||||
|
"e2e_fallback_keys_json",
|
||||||
|
keyvalues={"user_id": user_id, "device_id": device_id, "used": False},
|
||||||
|
retcol="algorithm",
|
||||||
|
desc="get_e2e_unused_fallback_key_types",
|
||||||
|
)
|
||||||
|
|
||||||
async def get_e2e_cross_signing_key(
|
async def get_e2e_cross_signing_key(
|
||||||
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
|
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
|
||||||
) -> Optional[dict]:
|
) -> Optional[dict]:
|
||||||
@ -701,15 +752,37 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||||||
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
|
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
|
||||||
" LIMIT 1"
|
" LIMIT 1"
|
||||||
)
|
)
|
||||||
|
fallback_sql = (
|
||||||
|
"SELECT key_id, key_json, used FROM e2e_fallback_keys_json"
|
||||||
|
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
|
||||||
|
" LIMIT 1"
|
||||||
|
)
|
||||||
result = {}
|
result = {}
|
||||||
delete = []
|
delete = []
|
||||||
|
used_fallbacks = []
|
||||||
for user_id, device_id, algorithm in query_list:
|
for user_id, device_id, algorithm in query_list:
|
||||||
user_result = result.setdefault(user_id, {})
|
user_result = result.setdefault(user_id, {})
|
||||||
device_result = user_result.setdefault(device_id, {})
|
device_result = user_result.setdefault(device_id, {})
|
||||||
txn.execute(sql, (user_id, device_id, algorithm))
|
txn.execute(sql, (user_id, device_id, algorithm))
|
||||||
for key_id, key_json in txn:
|
otk_row = txn.fetchone()
|
||||||
|
if otk_row is not None:
|
||||||
|
key_id, key_json = otk_row
|
||||||
device_result[algorithm + ":" + key_id] = key_json
|
device_result[algorithm + ":" + key_id] = key_json
|
||||||
delete.append((user_id, device_id, algorithm, key_id))
|
delete.append((user_id, device_id, algorithm, key_id))
|
||||||
|
else:
|
||||||
|
# no one-time key available, so see if there's a fallback
|
||||||
|
# key
|
||||||
|
txn.execute(fallback_sql, (user_id, device_id, algorithm))
|
||||||
|
fallback_row = txn.fetchone()
|
||||||
|
if fallback_row is not None:
|
||||||
|
key_id, key_json, used = fallback_row
|
||||||
|
device_result[algorithm + ":" + key_id] = key_json
|
||||||
|
if not used:
|
||||||
|
used_fallbacks.append(
|
||||||
|
(user_id, device_id, algorithm, key_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
# drop any one-time keys that were claimed
|
||||||
sql = (
|
sql = (
|
||||||
"DELETE FROM e2e_one_time_keys_json"
|
"DELETE FROM e2e_one_time_keys_json"
|
||||||
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
|
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
|
||||||
@ -726,6 +799,23 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||||||
self._invalidate_cache_and_stream(
|
self._invalidate_cache_and_stream(
|
||||||
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
||||||
)
|
)
|
||||||
|
# mark fallback keys as used
|
||||||
|
for user_id, device_id, algorithm, key_id in used_fallbacks:
|
||||||
|
self.db_pool.simple_update_txn(
|
||||||
|
txn,
|
||||||
|
"e2e_fallback_keys_json",
|
||||||
|
{
|
||||||
|
"user_id": user_id,
|
||||||
|
"device_id": device_id,
|
||||||
|
"algorithm": algorithm,
|
||||||
|
"key_id": key_id,
|
||||||
|
},
|
||||||
|
{"used": True},
|
||||||
|
)
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
|
||||||
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
@ -754,6 +844,14 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||||||
self._invalidate_cache_and_stream(
|
self._invalidate_cache_and_stream(
|
||||||
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
||||||
)
|
)
|
||||||
|
self.db_pool.simple_delete_txn(
|
||||||
|
txn,
|
||||||
|
table="e2e_fallback_keys_json",
|
||||||
|
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||||
|
)
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
|
||||||
|
)
|
||||||
|
|
||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
|
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
|
||||||
|
@ -0,0 +1,24 @@
|
|||||||
|
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS e2e_fallback_keys_json (
|
||||||
|
user_id TEXT NOT NULL, -- The user this fallback key is for.
|
||||||
|
device_id TEXT NOT NULL, -- The device this fallback key is for.
|
||||||
|
algorithm TEXT NOT NULL, -- Which algorithm this fallback key is for.
|
||||||
|
key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads.
|
||||||
|
key_json TEXT NOT NULL, -- The key as a JSON blob.
|
||||||
|
used BOOLEAN NOT NULL DEFAULT FALSE, -- Whether the key has been used or not.
|
||||||
|
CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm)
|
||||||
|
);
|
@ -171,6 +171,71 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_fallback_key(self):
|
||||||
|
local_user = "@boris:" + self.hs.hostname
|
||||||
|
device_id = "xyz"
|
||||||
|
fallback_key = {"alg1:k1": "key1"}
|
||||||
|
otk = {"alg1:k2": "key2"}
|
||||||
|
|
||||||
|
yield defer.ensureDeferred(
|
||||||
|
self.handler.upload_keys_for_user(
|
||||||
|
local_user,
|
||||||
|
device_id,
|
||||||
|
{"org.matrix.msc2732.fallback_keys": fallback_key},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# claiming an OTK when no OTKs are available should return the fallback
|
||||||
|
# key
|
||||||
|
res = yield defer.ensureDeferred(
|
||||||
|
self.handler.claim_one_time_keys(
|
||||||
|
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
res,
|
||||||
|
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
|
||||||
|
)
|
||||||
|
|
||||||
|
# claiming an OTK again should return the same fallback key
|
||||||
|
res = yield defer.ensureDeferred(
|
||||||
|
self.handler.claim_one_time_keys(
|
||||||
|
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
res,
|
||||||
|
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
|
||||||
|
)
|
||||||
|
|
||||||
|
# if the user uploads a one-time key, the next claim should fetch the
|
||||||
|
# one-time key, and then go back to the fallback
|
||||||
|
yield defer.ensureDeferred(
|
||||||
|
self.handler.upload_keys_for_user(
|
||||||
|
local_user, device_id, {"one_time_keys": otk}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
res = yield defer.ensureDeferred(
|
||||||
|
self.handler.claim_one_time_keys(
|
||||||
|
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
|
||||||
|
)
|
||||||
|
|
||||||
|
res = yield defer.ensureDeferred(
|
||||||
|
self.handler.claim_one_time_keys(
|
||||||
|
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
res,
|
||||||
|
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_replace_master_key(self):
|
def test_replace_master_key(self):
|
||||||
"""uploading a new signing key should make the old signing key unavailable"""
|
"""uploading a new signing key should make the old signing key unavailable"""
|
||||||
|
Loading…
Reference in New Issue
Block a user