mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-08-03 05:56:02 -04:00
Add support for MSC2732: olm fallback keys (#8312)
This commit is contained in:
parent
a024461130
commit
3cd78bbe9e
8 changed files with 215 additions and 1 deletions
|
@ -367,6 +367,57 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||
"count_e2e_one_time_keys", _count_e2e_one_time_keys
|
||||
)
|
||||
|
||||
async def set_e2e_fallback_keys(
|
||||
self, user_id: str, device_id: str, fallback_keys: JsonDict
|
||||
) -> None:
|
||||
"""Set the user's e2e fallback keys.
|
||||
|
||||
Args:
|
||||
user_id: the user whose keys are being set
|
||||
device_id: the device whose keys are being set
|
||||
fallback_keys: the keys to set. This is a map from key ID (which is
|
||||
of the form "algorithm:id") to key data.
|
||||
"""
|
||||
# fallback_keys will usually only have one item in it, so using a for
|
||||
# loop (as opposed to calling simple_upsert_many_txn) won't be too bad
|
||||
# FIXME: make sure that only one key per algorithm is uploaded
|
||||
for key_id, fallback_key in fallback_keys.items():
|
||||
algorithm, key_id = key_id.split(":", 1)
|
||||
await self.db_pool.simple_upsert(
|
||||
"e2e_fallback_keys_json",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"algorithm": algorithm,
|
||||
},
|
||||
values={
|
||||
"key_id": key_id,
|
||||
"key_json": json_encoder.encode(fallback_key),
|
||||
"used": False,
|
||||
},
|
||||
desc="set_e2e_fallback_key",
|
||||
)
|
||||
|
||||
@cached(max_entries=10000)
|
||||
async def get_e2e_unused_fallback_key_types(
|
||||
self, user_id: str, device_id: str
|
||||
) -> List[str]:
|
||||
"""Returns the fallback key types that have an unused key.
|
||||
|
||||
Args:
|
||||
user_id: the user whose keys are being queried
|
||||
device_id: the device whose keys are being queried
|
||||
|
||||
Returns:
|
||||
a list of key types
|
||||
"""
|
||||
return await self.db_pool.simple_select_onecol(
|
||||
"e2e_fallback_keys_json",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id, "used": False},
|
||||
retcol="algorithm",
|
||||
desc="get_e2e_unused_fallback_key_types",
|
||||
)
|
||||
|
||||
async def get_e2e_cross_signing_key(
|
||||
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
|
||||
) -> Optional[dict]:
|
||||
|
@ -701,15 +752,37 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
|
||||
" LIMIT 1"
|
||||
)
|
||||
fallback_sql = (
|
||||
"SELECT key_id, key_json, used FROM e2e_fallback_keys_json"
|
||||
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
|
||||
" LIMIT 1"
|
||||
)
|
||||
result = {}
|
||||
delete = []
|
||||
used_fallbacks = []
|
||||
for user_id, device_id, algorithm in query_list:
|
||||
user_result = result.setdefault(user_id, {})
|
||||
device_result = user_result.setdefault(device_id, {})
|
||||
txn.execute(sql, (user_id, device_id, algorithm))
|
||||
for key_id, key_json in txn:
|
||||
otk_row = txn.fetchone()
|
||||
if otk_row is not None:
|
||||
key_id, key_json = otk_row
|
||||
device_result[algorithm + ":" + key_id] = key_json
|
||||
delete.append((user_id, device_id, algorithm, key_id))
|
||||
else:
|
||||
# no one-time key available, so see if there's a fallback
|
||||
# key
|
||||
txn.execute(fallback_sql, (user_id, device_id, algorithm))
|
||||
fallback_row = txn.fetchone()
|
||||
if fallback_row is not None:
|
||||
key_id, key_json, used = fallback_row
|
||||
device_result[algorithm + ":" + key_id] = key_json
|
||||
if not used:
|
||||
used_fallbacks.append(
|
||||
(user_id, device_id, algorithm, key_id)
|
||||
)
|
||||
|
||||
# drop any one-time keys that were claimed
|
||||
sql = (
|
||||
"DELETE FROM e2e_one_time_keys_json"
|
||||
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
|
||||
|
@ -726,6 +799,23 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||
self._invalidate_cache_and_stream(
|
||||
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
||||
)
|
||||
# mark fallback keys as used
|
||||
for user_id, device_id, algorithm, key_id in used_fallbacks:
|
||||
self.db_pool.simple_update_txn(
|
||||
txn,
|
||||
"e2e_fallback_keys_json",
|
||||
{
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"algorithm": algorithm,
|
||||
"key_id": key_id,
|
||||
},
|
||||
{"used": True},
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
|
@ -754,6 +844,14 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||
self._invalidate_cache_and_stream(
|
||||
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
||||
)
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn,
|
||||
table="e2e_fallback_keys_json",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue