synapse-product/synapse/storage/databases/main/e2e_room_keys.py
Patrick Cloke d8cc86eff4
Remove redundant types from comments. (#14412)
Remove type hints from comments which have been added
as Python type hints. This helps avoid drift between comments
and reality, as well as removing redundant information.

Also adds some missing type hints which were simple to fill in.
2022-11-16 15:25:24 +00:00

553 lines
20 KiB
Python

# Copyright 2017 New Vector Ltd
# Copyright 2019 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Iterable, Mapping, Optional, Tuple, cast
from typing_extensions import Literal, TypedDict
from synapse.api.errors import StoreError
from synapse.logging.opentracing import log_kv, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict, JsonSerializable, StreamKeyType
from synapse.util import json_encoder
class RoomKey(TypedDict):
"""`KeyBackupData` in the Matrix spec.
https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3room_keyskeysroomidsessionid
"""
first_message_index: int
forwarded_count: int
is_verified: bool
session_data: JsonSerializable
class EndToEndRoomKeyStore(SQLBaseStore):
"""The store for end to end room key backups.
See https://spec.matrix.org/v1.1/client-server-api/#server-side-key-backups
As per the spec, backups are identified by an opaque version string. Internally,
version identifiers are assigned using incrementing integers. Non-numeric version
strings are treated as if they do not exist, since we would have never issued them.
"""
async def update_e2e_room_key(
self,
user_id: str,
version: str,
room_id: str,
session_id: str,
room_key: RoomKey,
) -> None:
"""Replaces the encrypted E2E room key for a given session in a given backup
Args:
user_id: the user whose backup we're setting
version: the version ID of the backup we're updating
room_id: the ID of the room whose keys we're setting
session_id: the session whose room_key we're setting
room_key: the room_key being set
Raises:
StoreError
"""
try:
version_int = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it doesn't exist.
raise StoreError(404, "No backup with that version exists")
await self.db_pool.simple_update_one(
table="e2e_room_keys",
keyvalues={
"user_id": user_id,
"version": version_int,
"room_id": room_id,
"session_id": session_id,
},
updatevalues={
"first_message_index": room_key["first_message_index"],
"forwarded_count": room_key["forwarded_count"],
"is_verified": room_key["is_verified"],
"session_data": json_encoder.encode(room_key["session_data"]),
},
desc="update_e2e_room_key",
)
async def add_e2e_room_keys(
self, user_id: str, version: str, room_keys: Iterable[Tuple[str, str, RoomKey]]
) -> None:
"""Bulk add room keys to a given backup.
Args:
user_id: the user whose backup we're adding to
version: the version ID of the backup for the set of keys we're adding to
room_keys: the keys to add, in the form (roomID, sessionID, keyData)
"""
try:
version_int = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it doesn't exist.
raise StoreError(404, "No backup with that version exists")
values = []
for (room_id, session_id, room_key) in room_keys:
values.append(
(
user_id,
version_int,
room_id,
session_id,
room_key["first_message_index"],
room_key["forwarded_count"],
room_key["is_verified"],
json_encoder.encode(room_key["session_data"]),
)
)
log_kv(
{
"message": "Set room key",
"room_id": room_id,
"session_id": session_id,
StreamKeyType.ROOM: room_key,
}
)
await self.db_pool.simple_insert_many(
table="e2e_room_keys",
keys=(
"user_id",
"version",
"room_id",
"session_id",
"first_message_index",
"forwarded_count",
"is_verified",
"session_data",
),
values=values,
desc="add_e2e_room_keys",
)
@trace
async def get_e2e_room_keys(
self,
user_id: str,
version: str,
room_id: Optional[str] = None,
session_id: Optional[str] = None,
) -> Dict[
Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
]:
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session.
Args:
user_id: the user whose backup we're querying
version: the version ID of the backup for the set of keys we're querying
room_id: Optional. the ID of the room whose keys we're querying, if any.
If not specified, we return the keys for all the rooms in the backup.
session_id: Optional. the session whose room_key we're querying, if any.
If specified, we also require the room_id to be specified.
If not specified, we return all the keys in this version of
the backup (or for the specified room)
Returns:
A dict giving the session_data and message metadata for these room keys.
`{"rooms": {room_id: {"sessions": {session_id: room_key}}}}`
"""
try:
version_int = int(version)
except ValueError:
return {"rooms": {}}
keyvalues = {"user_id": user_id, "version": version_int}
if room_id:
keyvalues["room_id"] = room_id
if session_id:
keyvalues["session_id"] = session_id
rows = await self.db_pool.simple_select_list(
table="e2e_room_keys",
keyvalues=keyvalues,
retcols=(
"user_id",
"room_id",
"session_id",
"first_message_index",
"forwarded_count",
"is_verified",
"session_data",
),
desc="get_e2e_room_keys",
)
sessions: Dict[
Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
] = {"rooms": {}}
for row in rows:
room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}})
room_entry["sessions"][row["session_id"]] = {
"first_message_index": row["first_message_index"],
"forwarded_count": row["forwarded_count"],
# is_verified must be returned to the client as a boolean
"is_verified": bool(row["is_verified"]),
"session_data": db_to_json(row["session_data"]),
}
return sessions
async def get_e2e_room_keys_multi(
self,
user_id: str,
version: str,
room_keys: Mapping[str, Mapping[Literal["sessions"], Iterable[str]]],
) -> Dict[str, Dict[str, RoomKey]]:
"""Get multiple room keys at a time. The difference between this function and
get_e2e_room_keys is that this function can be used to retrieve
multiple specific keys at a time, whereas get_e2e_room_keys is used for
getting all the keys in a backup version, all the keys for a room, or a
specific key.
Args:
user_id: the user whose backup we're querying
version: the version ID of the backup we're querying about
room_keys: a map from room ID -> {"sessions": [session ids]}
indicating the session IDs that we want to query
Returns:
A map of room IDs to session IDs to room key
"""
try:
version_int = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it doesn't exist.
return {}
return await self.db_pool.runInteraction(
"get_e2e_room_keys_multi",
self._get_e2e_room_keys_multi_txn,
user_id,
version_int,
room_keys,
)
@staticmethod
def _get_e2e_room_keys_multi_txn(
txn: LoggingTransaction,
user_id: str,
version: int,
room_keys: Mapping[str, Mapping[Literal["sessions"], Iterable[str]]],
) -> Dict[str, Dict[str, RoomKey]]:
if not room_keys:
return {}
where_clauses = []
params = [user_id, version]
for room_id, room in room_keys.items():
sessions = list(room["sessions"])
if not sessions:
continue
params.append(room_id)
params.extend(sessions)
where_clauses.append(
"(room_id = ? AND session_id IN (%s))"
% (",".join(["?" for _ in sessions]),)
)
# check if we're actually querying something
if not where_clauses:
return {}
sql = """
SELECT room_id, session_id, first_message_index, forwarded_count,
is_verified, session_data
FROM e2e_room_keys
WHERE user_id = ? AND version = ? AND (%s)
""" % (
" OR ".join(where_clauses)
)
txn.execute(sql, params)
ret: Dict[str, Dict[str, RoomKey]] = {}
for row in txn:
room_id = row[0]
session_id = row[1]
ret.setdefault(room_id, {})
ret[room_id][session_id] = {
"first_message_index": row[2],
"forwarded_count": row[3],
"is_verified": row[4],
"session_data": db_to_json(row[5]),
}
return ret
async def count_e2e_room_keys(self, user_id: str, version: str) -> int:
"""Get the number of keys in a backup version.
Args:
user_id: the user whose backup we're querying
version: the version ID of the backup we're querying about
"""
try:
version_int = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it doesn't exist.
return 0
return await self.db_pool.simple_select_one_onecol(
table="e2e_room_keys",
keyvalues={"user_id": user_id, "version": version_int},
retcol="COUNT(*)",
desc="count_e2e_room_keys",
)
@trace
async def delete_e2e_room_keys(
self,
user_id: str,
version: str,
room_id: Optional[str] = None,
session_id: Optional[str] = None,
) -> None:
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
room or a given session.
Args:
user_id: the user whose backup we're deleting from
version: the version ID of the backup for the set of keys we're deleting
room_id: Optional. the ID of the room whose keys we're deleting, if any.
If not specified, we delete the keys for all the rooms in the backup.
session_id: Optional. the session whose room_key we're querying, if any.
If specified, we also require the room_id to be specified.
If not specified, we delete all the keys in this version of
the backup (or for the specified room)
"""
try:
version_int = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it doesn't exist.
return
keyvalues = {"user_id": user_id, "version": version_int}
if room_id:
keyvalues["room_id"] = room_id
if session_id:
keyvalues["session_id"] = session_id
await self.db_pool.simple_delete(
table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys"
)
@staticmethod
def _get_current_version(txn: LoggingTransaction, user_id: str) -> int:
txn.execute(
"SELECT MAX(version) FROM e2e_room_keys_versions "
"WHERE user_id=? AND deleted=0",
(user_id,),
)
# `SELECT MAX() FROM ...` will always return 1 row. The value in that row will
# be `NULL` when there are no available versions.
row = cast(Tuple[Optional[int]], txn.fetchone())
if row[0] is None:
raise StoreError(404, "No current backup version")
return row[0]
async def get_e2e_room_keys_version_info(
self, user_id: str, version: Optional[str] = None
) -> JsonDict:
"""Get info metadata about a version of our room_keys backup.
Args:
user_id: the user whose backup we're querying
version: Optional. the version ID of the backup we're querying about
If missing, we return the information about the current version.
Raises:
StoreError: with code 404 if there are no e2e_room_keys_versions present
Returns:
A dict giving the info metadata for this backup version, with
fields including:
version (str)
algorithm (str)
auth_data (object): opaque dict supplied by the client
etag (int): tag of the keys in the backup
"""
def _get_e2e_room_keys_version_info_txn(txn: LoggingTransaction) -> JsonDict:
if version is None:
this_version = self._get_current_version(txn, user_id)
else:
try:
this_version = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it isn't there.
raise StoreError(404, "No backup with that version exists")
result = self.db_pool.simple_select_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
retcols=("version", "algorithm", "auth_data", "etag"),
allow_none=False,
)
assert result is not None # see comment on `simple_select_one_txn`
result["auth_data"] = db_to_json(result["auth_data"])
result["version"] = str(result["version"])
if result["etag"] is None:
result["etag"] = 0
return result
return await self.db_pool.runInteraction(
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
)
@trace
async def create_e2e_room_keys_version(self, user_id: str, info: JsonDict) -> str:
"""Atomically creates a new version of this user's e2e_room_keys store
with the given version info.
Args:
user_id: the user whose backup we're creating a version
info: the info about the backup version to be created
Returns:
The newly created version ID
"""
def _create_e2e_room_keys_version_txn(txn: LoggingTransaction) -> str:
txn.execute(
"SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?",
(user_id,),
)
current_version = cast(Tuple[Optional[int]], txn.fetchone())[0]
if current_version is None:
current_version = 0
new_version = current_version + 1
self.db_pool.simple_insert_txn(
txn,
table="e2e_room_keys_versions",
values={
"user_id": user_id,
"version": new_version,
"algorithm": info["algorithm"],
"auth_data": json_encoder.encode(info["auth_data"]),
},
)
return str(new_version)
return await self.db_pool.runInteraction(
"create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
)
@trace
async def update_e2e_room_keys_version(
self,
user_id: str,
version: str,
info: Optional[JsonDict] = None,
version_etag: Optional[int] = None,
) -> None:
"""Update a given backup version
Args:
user_id: the user whose backup version we're updating
version: the version ID of the backup version we're updating
info: the new backup version info to store. If None, then the backup
version info is not updated.
version_etag: etag of the keys in the backup. If None, then the etag
is not updated.
"""
updatevalues: Dict[str, object] = {}
if info is not None and "auth_data" in info:
updatevalues["auth_data"] = json_encoder.encode(info["auth_data"])
if version_etag is not None:
updatevalues["etag"] = version_etag
if updatevalues:
try:
version_int = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it doesn't exist.
raise StoreError(404, "No backup with that version exists")
await self.db_pool.simple_update_one(
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": version_int},
updatevalues=updatevalues,
desc="update_e2e_room_keys_version",
)
@trace
async def delete_e2e_room_keys_version(
self, user_id: str, version: Optional[str] = None
) -> None:
"""Delete a given backup version of the user's room keys.
Doesn't delete their actual key data.
Args:
user_id: the user whose backup version we're deleting
version: Optional. the version ID of the backup version we're deleting
If missing, we delete the current backup version info.
Raises:
StoreError: with code 404 if there are no e2e_room_keys_versions present,
or if the version requested doesn't exist.
"""
def _delete_e2e_room_keys_version_txn(txn: LoggingTransaction) -> None:
if version is None:
this_version = self._get_current_version(txn, user_id)
else:
try:
this_version = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it isn't there.
raise StoreError(404, "No backup with that version exists")
self.db_pool.simple_delete_txn(
txn,
table="e2e_room_keys",
keyvalues={"user_id": user_id, "version": this_version},
)
self.db_pool.simple_update_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version},
updatevalues={"deleted": 1},
)
await self.db_pool.runInteraction(
"delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn
)