Add support for MSC2732: olm fallback keys (#8312)

This commit is contained in:
Hubert Chathi 2020-10-06 13:26:29 -04:00 committed by GitHub
parent a024461130
commit 3cd78bbe9e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 215 additions and 1 deletions

1
changelog.d/8312.feature Normal file
View File

@ -0,0 +1 @@
Add support for olm fallback keys ([MSC2732](https://github.com/matrix-org/matrix-doc/pull/2732)).

View File

@ -90,6 +90,7 @@ BOOLEAN_COLUMNS = {
"room_stats_state": ["is_federatable"], "room_stats_state": ["is_federatable"],
"local_media_repository": ["safe_from_quarantine"], "local_media_repository": ["safe_from_quarantine"],
"users": ["shadow_banned"], "users": ["shadow_banned"],
"e2e_fallback_keys_json": ["used"],
} }

View File

@ -496,6 +496,22 @@ class E2eKeysHandler:
log_kv( log_kv(
{"message": "Did not update one_time_keys", "reason": "no keys given"} {"message": "Did not update one_time_keys", "reason": "no keys given"}
) )
fallback_keys = keys.get("org.matrix.msc2732.fallback_keys", None)
if fallback_keys and isinstance(fallback_keys, dict):
log_kv(
{
"message": "Updating fallback_keys for device.",
"user_id": user_id,
"device_id": device_id,
}
)
await self.store.set_e2e_fallback_keys(user_id, device_id, fallback_keys)
elif fallback_keys:
log_kv({"message": "Did not update fallback_keys", "reason": "not a dict"})
else:
log_kv(
{"message": "Did not update fallback_keys", "reason": "no keys given"}
)
# the device should have been registered already, but it may have been # the device should have been registered already, but it may have been
# deleted due to a race with a DELETE request. Or we may be using an # deleted due to a race with a DELETE request. Or we may be using an

View File

@ -201,6 +201,8 @@ class SyncResult:
device_lists: List of user_ids whose devices have changed device_lists: List of user_ids whose devices have changed
device_one_time_keys_count: Dict of algorithm to count for one time keys device_one_time_keys_count: Dict of algorithm to count for one time keys
for this device for this device
device_unused_fallback_key_types: List of key types that have an unused fallback
key
groups: Group updates, if any groups: Group updates, if any
""" """
@ -213,6 +215,7 @@ class SyncResult:
to_device = attr.ib(type=List[JsonDict]) to_device = attr.ib(type=List[JsonDict])
device_lists = attr.ib(type=DeviceLists) device_lists = attr.ib(type=DeviceLists)
device_one_time_keys_count = attr.ib(type=JsonDict) device_one_time_keys_count = attr.ib(type=JsonDict)
device_unused_fallback_key_types = attr.ib(type=List[str])
groups = attr.ib(type=Optional[GroupsSyncResult]) groups = attr.ib(type=Optional[GroupsSyncResult])
def __bool__(self) -> bool: def __bool__(self) -> bool:
@ -1014,10 +1017,14 @@ class SyncHandler:
logger.debug("Fetching OTK data") logger.debug("Fetching OTK data")
device_id = sync_config.device_id device_id = sync_config.device_id
one_time_key_counts = {} # type: JsonDict one_time_key_counts = {} # type: JsonDict
unused_fallback_key_types = [] # type: List[str]
if device_id: if device_id:
one_time_key_counts = await self.store.count_e2e_one_time_keys( one_time_key_counts = await self.store.count_e2e_one_time_keys(
user_id, device_id user_id, device_id
) )
unused_fallback_key_types = await self.store.get_e2e_unused_fallback_key_types(
user_id, device_id
)
logger.debug("Fetching group data") logger.debug("Fetching group data")
await self._generate_sync_entry_for_groups(sync_result_builder) await self._generate_sync_entry_for_groups(sync_result_builder)
@ -1041,6 +1048,7 @@ class SyncHandler:
device_lists=device_lists, device_lists=device_lists,
groups=sync_result_builder.groups, groups=sync_result_builder.groups,
device_one_time_keys_count=one_time_key_counts, device_one_time_keys_count=one_time_key_counts,
device_unused_fallback_key_types=unused_fallback_key_types,
next_batch=sync_result_builder.now_token, next_batch=sync_result_builder.now_token,
) )

View File

@ -236,6 +236,7 @@ class SyncRestServlet(RestServlet):
"leave": sync_result.groups.leave, "leave": sync_result.groups.leave,
}, },
"device_one_time_keys_count": sync_result.device_one_time_keys_count, "device_one_time_keys_count": sync_result.device_one_time_keys_count,
"org.matrix.msc2732.device_unused_fallback_key_types": sync_result.device_unused_fallback_key_types,
"next_batch": await sync_result.next_batch.to_string(self.store), "next_batch": await sync_result.next_batch.to_string(self.store),
} }

View File

@ -367,6 +367,57 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
"count_e2e_one_time_keys", _count_e2e_one_time_keys "count_e2e_one_time_keys", _count_e2e_one_time_keys
) )
async def set_e2e_fallback_keys(
self, user_id: str, device_id: str, fallback_keys: JsonDict
) -> None:
"""Set the user's e2e fallback keys.
Args:
user_id: the user whose keys are being set
device_id: the device whose keys are being set
fallback_keys: the keys to set. This is a map from key ID (which is
of the form "algorithm:id") to key data.
"""
# fallback_keys will usually only have one item in it, so using a for
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
# FIXME: make sure that only one key per algorithm is uploaded
for key_id, fallback_key in fallback_keys.items():
algorithm, key_id = key_id.split(":", 1)
await self.db_pool.simple_upsert(
"e2e_fallback_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
},
values={
"key_id": key_id,
"key_json": json_encoder.encode(fallback_key),
"used": False,
},
desc="set_e2e_fallback_key",
)
@cached(max_entries=10000)
async def get_e2e_unused_fallback_key_types(
self, user_id: str, device_id: str
) -> List[str]:
"""Returns the fallback key types that have an unused key.
Args:
user_id: the user whose keys are being queried
device_id: the device whose keys are being queried
Returns:
a list of key types
"""
return await self.db_pool.simple_select_onecol(
"e2e_fallback_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id, "used": False},
retcol="algorithm",
desc="get_e2e_unused_fallback_key_types",
)
async def get_e2e_cross_signing_key( async def get_e2e_cross_signing_key(
self, user_id: str, key_type: str, from_user_id: Optional[str] = None self, user_id: str, key_type: str, from_user_id: Optional[str] = None
) -> Optional[dict]: ) -> Optional[dict]:
@ -701,15 +752,37 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
" WHERE user_id = ? AND device_id = ? AND algorithm = ?" " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
" LIMIT 1" " LIMIT 1"
) )
fallback_sql = (
"SELECT key_id, key_json, used FROM e2e_fallback_keys_json"
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
" LIMIT 1"
)
result = {} result = {}
delete = [] delete = []
used_fallbacks = []
for user_id, device_id, algorithm in query_list: for user_id, device_id, algorithm in query_list:
user_result = result.setdefault(user_id, {}) user_result = result.setdefault(user_id, {})
device_result = user_result.setdefault(device_id, {}) device_result = user_result.setdefault(device_id, {})
txn.execute(sql, (user_id, device_id, algorithm)) txn.execute(sql, (user_id, device_id, algorithm))
for key_id, key_json in txn: otk_row = txn.fetchone()
if otk_row is not None:
key_id, key_json = otk_row
device_result[algorithm + ":" + key_id] = key_json device_result[algorithm + ":" + key_id] = key_json
delete.append((user_id, device_id, algorithm, key_id)) delete.append((user_id, device_id, algorithm, key_id))
else:
# no one-time key available, so see if there's a fallback
# key
txn.execute(fallback_sql, (user_id, device_id, algorithm))
fallback_row = txn.fetchone()
if fallback_row is not None:
key_id, key_json, used = fallback_row
device_result[algorithm + ":" + key_id] = key_json
if not used:
used_fallbacks.append(
(user_id, device_id, algorithm, key_id)
)
# drop any one-time keys that were claimed
sql = ( sql = (
"DELETE FROM e2e_one_time_keys_json" "DELETE FROM e2e_one_time_keys_json"
" WHERE user_id = ? AND device_id = ? AND algorithm = ?" " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
@ -726,6 +799,23 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id) txn, self.count_e2e_one_time_keys, (user_id, device_id)
) )
# mark fallback keys as used
for user_id, device_id, algorithm, key_id in used_fallbacks:
self.db_pool.simple_update_txn(
txn,
"e2e_fallback_keys_json",
{
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
"key_id": key_id,
},
{"used": True},
)
self._invalidate_cache_and_stream(
txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
)
return result return result
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
@ -754,6 +844,14 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
self._invalidate_cache_and_stream( self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id) txn, self.count_e2e_one_time_keys, (user_id, device_id)
) )
self.db_pool.simple_delete_txn(
txn,
table="e2e_fallback_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
)
self._invalidate_cache_and_stream(
txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
)
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn

View File

@ -0,0 +1,24 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS e2e_fallback_keys_json (
user_id TEXT NOT NULL, -- The user this fallback key is for.
device_id TEXT NOT NULL, -- The device this fallback key is for.
algorithm TEXT NOT NULL, -- Which algorithm this fallback key is for.
key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads.
key_json TEXT NOT NULL, -- The key as a JSON blob.
used BOOLEAN NOT NULL DEFAULT FALSE, -- Whether the key has been used or not.
CONSTRAINT e2e_fallback_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm)
);

View File

@ -171,6 +171,71 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
}, },
) )
@defer.inlineCallbacks
def test_fallback_key(self):
local_user = "@boris:" + self.hs.hostname
device_id = "xyz"
fallback_key = {"alg1:k1": "key1"}
otk = {"alg1:k2": "key2"}
yield defer.ensureDeferred(
self.handler.upload_keys_for_user(
local_user,
device_id,
{"org.matrix.msc2732.fallback_keys": fallback_key},
)
)
# claiming an OTK when no OTKs are available should return the fallback
# key
res = yield defer.ensureDeferred(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)
# claiming an OTK again should return the same fallback key
res = yield defer.ensureDeferred(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)
# if the user uploads a one-time key, the next claim should fetch the
# one-time key, and then go back to the fallback
yield defer.ensureDeferred(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": otk}
)
)
res = yield defer.ensureDeferred(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
)
res = yield defer.ensureDeferred(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_replace_master_key(self): def test_replace_master_key(self):
"""uploading a new signing key should make the old signing key unavailable""" """uploading a new signing key should make the old signing key unavailable"""