Add type hints to synapse/storage/databases/main/e2e_room_keys.py (#11549)

This commit is contained in:
Sean Quah 2021-12-14 17:46:47 +00:00 committed by GitHub
parent 0147b3de20
commit ecfcd9bbbe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 188 additions and 79 deletions

1
changelog.d/11549.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints to storage classes.

View File

@ -27,7 +27,6 @@ exclude = (?x)
|synapse/storage/databases/main/__init__.py |synapse/storage/databases/main/__init__.py
|synapse/storage/databases/main/cache.py |synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py |synapse/storage/databases/main/devices.py
|synapse/storage/databases/main/e2e_room_keys.py
|synapse/storage/databases/main/event_federation.py |synapse/storage/databases/main/event_federation.py
|synapse/storage/databases/main/event_push_actions.py |synapse/storage/databases/main/event_push_actions.py
|synapse/storage/databases/main/events_bg_updates.py |synapse/storage/databases/main/events_bg_updates.py
@ -197,6 +196,9 @@ disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.directory] [mypy-synapse.storage.databases.main.directory]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.e2e_room_keys]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.end_to_end_keys] [mypy-synapse.storage.databases.main.end_to_end_keys]
disallow_untyped_defs = True disallow_untyped_defs = True

View File

@ -14,7 +14,9 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, Dict, Optional
from typing_extensions import Literal
from synapse.api.errors import ( from synapse.api.errors import (
Codes, Codes,
@ -24,6 +26,7 @@ from synapse.api.errors import (
SynapseError, SynapseError,
) )
from synapse.logging.opentracing import log_kv, trace from synapse.logging.opentracing import log_kv, trace
from synapse.storage.databases.main.e2e_room_keys import RoomKey
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
@ -58,7 +61,9 @@ class E2eRoomKeysHandler:
version: str, version: str,
room_id: Optional[str] = None, room_id: Optional[str] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
) -> List[JsonDict]: ) -> Dict[
Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
]:
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given """Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session. room, or a given session.
See EndToEndRoomKeyStore.get_e2e_room_keys for full details. See EndToEndRoomKeyStore.get_e2e_room_keys for full details.
@ -72,8 +77,8 @@ class E2eRoomKeysHandler:
Raises: Raises:
NotFoundError: if the backup version does not exist NotFoundError: if the backup version does not exist
Returns: Returns:
A list of dicts giving the session_data and message metadata for A dict giving the session_data and message metadata for these room keys.
these room keys. `{"rooms": {room_id: {"sessions": {session_id: room_key}}}}`
""" """
# we deliberately take the lock to get keys so that changing the version # we deliberately take the lock to get keys so that changing the version
@ -273,7 +278,7 @@ class E2eRoomKeysHandler:
@staticmethod @staticmethod
def _should_replace_room_key( def _should_replace_room_key(
current_room_key: Optional[JsonDict], room_key: JsonDict current_room_key: Optional[RoomKey], room_key: RoomKey
) -> bool: ) -> bool:
""" """
Determine whether to replace a given current_room_key (if any) Determine whether to replace a given current_room_key (if any)

View File

@ -13,35 +13,71 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional from typing import Dict, Iterable, Mapping, Optional, Tuple, cast
from typing_extensions import Literal, TypedDict
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.logging.opentracing import log_kv, trace from synapse.logging.opentracing import log_kv, trace
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict, JsonSerializable
from synapse.util import json_encoder from synapse.util import json_encoder
class RoomKey(TypedDict):
"""`KeyBackupData` in the Matrix spec.
https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3room_keyskeysroomidsessionid
"""
first_message_index: int
forwarded_count: int
is_verified: bool
session_data: JsonSerializable
class EndToEndRoomKeyStore(SQLBaseStore): class EndToEndRoomKeyStore(SQLBaseStore):
"""The store for end to end room key backups.
See https://spec.matrix.org/v1.1/client-server-api/#server-side-key-backups
As per the spec, backups are identified by an opaque version string. Internally,
version identifiers are assigned using incrementing integers. Non-numeric version
strings are treated as if they do not exist, since we would have never issued them.
"""
async def update_e2e_room_key( async def update_e2e_room_key(
self, user_id, version, room_id, session_id, room_key self,
): user_id: str,
version: str,
room_id: str,
session_id: str,
room_key: RoomKey,
) -> None:
"""Replaces the encrypted E2E room key for a given session in a given backup """Replaces the encrypted E2E room key for a given session in a given backup
Args: Args:
user_id(str): the user whose backup we're setting user_id: the user whose backup we're setting
version(str): the version ID of the backup we're updating version: the version ID of the backup we're updating
room_id(str): the ID of the room whose keys we're setting room_id: the ID of the room whose keys we're setting
session_id(str): the session whose room_key we're setting session_id: the session whose room_key we're setting
room_key(dict): the room_key being set room_key: the room_key being set
Raises: Raises:
StoreError StoreError
""" """
try:
version_int = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it doesn't exist.
raise StoreError(404, "No backup with that version exists")
await self.db_pool.simple_update_one( await self.db_pool.simple_update_one(
table="e2e_room_keys", table="e2e_room_keys",
keyvalues={ keyvalues={
"user_id": user_id, "user_id": user_id,
"version": version, "version": version_int,
"room_id": room_id, "room_id": room_id,
"session_id": session_id, "session_id": session_id,
}, },
@ -54,22 +90,29 @@ class EndToEndRoomKeyStore(SQLBaseStore):
desc="update_e2e_room_key", desc="update_e2e_room_key",
) )
async def add_e2e_room_keys(self, user_id, version, room_keys): async def add_e2e_room_keys(
self, user_id: str, version: str, room_keys: Iterable[Tuple[str, str, RoomKey]]
) -> None:
"""Bulk add room keys to a given backup. """Bulk add room keys to a given backup.
Args: Args:
user_id (str): the user whose backup we're adding to user_id: the user whose backup we're adding to
version (str): the version ID of the backup for the set of keys we're adding to version: the version ID of the backup for the set of keys we're adding to
room_keys (iterable[(str, str, dict)]): the keys to add, in the form room_keys: the keys to add, in the form (roomID, sessionID, keyData)
(roomID, sessionID, keyData)
""" """
try:
version_int = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it doesn't exist.
raise StoreError(404, "No backup with that version exists")
values = [] values = []
for (room_id, session_id, room_key) in room_keys: for (room_id, session_id, room_key) in room_keys:
values.append( values.append(
{ {
"user_id": user_id, "user_id": user_id,
"version": version, "version": version_int,
"room_id": room_id, "room_id": room_id,
"session_id": session_id, "session_id": session_id,
"first_message_index": room_key["first_message_index"], "first_message_index": room_key["first_message_index"],
@ -92,31 +135,39 @@ class EndToEndRoomKeyStore(SQLBaseStore):
) )
@trace @trace
async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None): async def get_e2e_room_keys(
self,
user_id: str,
version: str,
room_id: Optional[str] = None,
session_id: Optional[str] = None,
) -> Dict[
Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
]:
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given """Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session. room, or a given session.
Args: Args:
user_id (str): the user whose backup we're querying user_id: the user whose backup we're querying
version (str): the version ID of the backup for the set of keys we're querying version: the version ID of the backup for the set of keys we're querying
room_id (str): Optional. the ID of the room whose keys we're querying, if any. room_id: Optional. the ID of the room whose keys we're querying, if any.
If not specified, we return the keys for all the rooms in the backup. If not specified, we return the keys for all the rooms in the backup.
session_id (str): Optional. the session whose room_key we're querying, if any. session_id: Optional. the session whose room_key we're querying, if any.
If specified, we also require the room_id to be specified. If specified, we also require the room_id to be specified.
If not specified, we return all the keys in this version of If not specified, we return all the keys in this version of
the backup (or for the specified room) the backup (or for the specified room)
Returns: Returns:
A list of dicts giving the session_data and message metadata for A dict giving the session_data and message metadata for these room keys.
these room keys. `{"rooms": {room_id: {"sessions": {session_id: room_key}}}}`
""" """
try: try:
version = int(version) version_int = int(version)
except ValueError: except ValueError:
return {"rooms": {}} return {"rooms": {}}
keyvalues = {"user_id": user_id, "version": version} keyvalues = {"user_id": user_id, "version": version_int}
if room_id: if room_id:
keyvalues["room_id"] = room_id keyvalues["room_id"] = room_id
if session_id: if session_id:
@ -137,7 +188,9 @@ class EndToEndRoomKeyStore(SQLBaseStore):
desc="get_e2e_room_keys", desc="get_e2e_room_keys",
) )
sessions = {"rooms": {}} sessions: Dict[
Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
] = {"rooms": {}}
for row in rows: for row in rows:
room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}}) room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}})
room_entry["sessions"][row["session_id"]] = { room_entry["sessions"][row["session_id"]] = {
@ -150,7 +203,12 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return sessions return sessions
async def get_e2e_room_keys_multi(self, user_id, version, room_keys): async def get_e2e_room_keys_multi(
self,
user_id: str,
version: str,
room_keys: Mapping[str, Mapping[Literal["sessions"], Iterable[str]]],
) -> Dict[str, Dict[str, RoomKey]]:
"""Get multiple room keys at a time. The difference between this function and """Get multiple room keys at a time. The difference between this function and
get_e2e_room_keys is that this function can be used to retrieve get_e2e_room_keys is that this function can be used to retrieve
multiple specific keys at a time, whereas get_e2e_room_keys is used for multiple specific keys at a time, whereas get_e2e_room_keys is used for
@ -158,26 +216,36 @@ class EndToEndRoomKeyStore(SQLBaseStore):
specific key. specific key.
Args: Args:
user_id (str): the user whose backup we're querying user_id: the user whose backup we're querying
version (str): the version ID of the backup we're querying about version: the version ID of the backup we're querying about
room_keys (dict[str, dict[str, iterable[str]]]): a map from room_keys: a map from room ID -> {"sessions": [session ids]}
room ID -> {"session": [session ids]} indicating the session IDs indicating the session IDs that we want to query
that we want to query
Returns: Returns:
dict[str, dict[str, dict]]: a map of room IDs to session IDs to room key A map of room IDs to session IDs to room key
""" """
try:
version_int = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it doesn't exist.
return {}
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_e2e_room_keys_multi", "get_e2e_room_keys_multi",
self._get_e2e_room_keys_multi_txn, self._get_e2e_room_keys_multi_txn,
user_id, user_id,
version, version_int,
room_keys, room_keys,
) )
@staticmethod @staticmethod
def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys): def _get_e2e_room_keys_multi_txn(
txn: LoggingTransaction,
user_id: str,
version: int,
room_keys: Mapping[str, Mapping[Literal["sessions"], Iterable[str]]],
) -> Dict[str, Dict[str, RoomKey]]:
if not room_keys: if not room_keys:
return {} return {}
@ -209,7 +277,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
txn.execute(sql, params) txn.execute(sql, params)
ret = {} ret: Dict[str, Dict[str, RoomKey]] = {}
for row in txn: for row in txn:
room_id = row[0] room_id = row[0]
@ -231,36 +299,49 @@ class EndToEndRoomKeyStore(SQLBaseStore):
user_id: the user whose backup we're querying user_id: the user whose backup we're querying
version: the version ID of the backup we're querying about version: the version ID of the backup we're querying about
""" """
try:
version_int = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it doesn't exist.
return 0
return await self.db_pool.simple_select_one_onecol( return await self.db_pool.simple_select_one_onecol(
table="e2e_room_keys", table="e2e_room_keys",
keyvalues={"user_id": user_id, "version": version}, keyvalues={"user_id": user_id, "version": version_int},
retcol="COUNT(*)", retcol="COUNT(*)",
desc="count_e2e_room_keys", desc="count_e2e_room_keys",
) )
@trace @trace
async def delete_e2e_room_keys( async def delete_e2e_room_keys(
self, user_id, version, room_id=None, session_id=None self,
): user_id: str,
version: str,
room_id: Optional[str] = None,
session_id: Optional[str] = None,
) -> None:
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given """Bulk delete the E2E room keys for a given backup, optionally filtered to a given
room or a given session. room or a given session.
Args: Args:
user_id(str): the user whose backup we're deleting from user_id: the user whose backup we're deleting from
version(str): the version ID of the backup for the set of keys we're deleting version: the version ID of the backup for the set of keys we're deleting
room_id(str): Optional. the ID of the room whose keys we're deleting, if any. room_id: Optional. the ID of the room whose keys we're deleting, if any.
If not specified, we delete the keys for all the rooms in the backup. If not specified, we delete the keys for all the rooms in the backup.
session_id(str): Optional. the session whose room_key we're querying, if any. session_id: Optional. the session whose room_key we're querying, if any.
If specified, we also require the room_id to be specified. If specified, we also require the room_id to be specified.
If not specified, we delete all the keys in this version of If not specified, we delete all the keys in this version of
the backup (or for the specified room) the backup (or for the specified room)
Returns:
The deletion transaction
""" """
try:
version_int = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it doesn't exist.
return
keyvalues = {"user_id": user_id, "version": int(version)} keyvalues = {"user_id": user_id, "version": version_int}
if room_id: if room_id:
keyvalues["room_id"] = room_id keyvalues["room_id"] = room_id
if session_id: if session_id:
@ -271,23 +352,27 @@ class EndToEndRoomKeyStore(SQLBaseStore):
) )
@staticmethod @staticmethod
def _get_current_version(txn, user_id): def _get_current_version(txn: LoggingTransaction, user_id: str) -> int:
txn.execute( txn.execute(
"SELECT MAX(version) FROM e2e_room_keys_versions " "SELECT MAX(version) FROM e2e_room_keys_versions "
"WHERE user_id=? AND deleted=0", "WHERE user_id=? AND deleted=0",
(user_id,), (user_id,),
) )
row = txn.fetchone() # `SELECT MAX() FROM ...` will always return 1 row. The value in that row will
if not row: # be `NULL` when there are no available versions.
row = cast(Tuple[Optional[int]], txn.fetchone())
if row[0] is None:
raise StoreError(404, "No current backup version") raise StoreError(404, "No current backup version")
return row[0] return row[0]
async def get_e2e_room_keys_version_info(self, user_id, version=None): async def get_e2e_room_keys_version_info(
self, user_id: str, version: Optional[str] = None
) -> JsonDict:
"""Get info metadata about a version of our room_keys backup. """Get info metadata about a version of our room_keys backup.
Args: Args:
user_id(str): the user whose backup we're querying user_id: the user whose backup we're querying
version(str): Optional. the version ID of the backup we're querying about version: Optional. the version ID of the backup we're querying about
If missing, we return the information about the current version. If missing, we return the information about the current version.
Raises: Raises:
StoreError: with code 404 if there are no e2e_room_keys_versions present StoreError: with code 404 if there are no e2e_room_keys_versions present
@ -300,7 +385,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
etag(int): tag of the keys in the backup etag(int): tag of the keys in the backup
""" """
def _get_e2e_room_keys_version_info_txn(txn): def _get_e2e_room_keys_version_info_txn(txn: LoggingTransaction) -> JsonDict:
if version is None: if version is None:
this_version = self._get_current_version(txn, user_id) this_version = self._get_current_version(txn, user_id)
else: else:
@ -309,14 +394,16 @@ class EndToEndRoomKeyStore(SQLBaseStore):
except ValueError: except ValueError:
# Our versions are all ints so if we can't convert it to an integer, # Our versions are all ints so if we can't convert it to an integer,
# it isn't there. # it isn't there.
raise StoreError(404, "No row found") raise StoreError(404, "No backup with that version exists")
result = self.db_pool.simple_select_one_txn( result = self.db_pool.simple_select_one_txn(
txn, txn,
table="e2e_room_keys_versions", table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0}, keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
retcols=("version", "algorithm", "auth_data", "etag"), retcols=("version", "algorithm", "auth_data", "etag"),
allow_none=False,
) )
assert result is not None # see comment on `simple_select_one_txn`
result["auth_data"] = db_to_json(result["auth_data"]) result["auth_data"] = db_to_json(result["auth_data"])
result["version"] = str(result["version"]) result["version"] = str(result["version"])
if result["etag"] is None: if result["etag"] is None:
@ -328,28 +415,28 @@ class EndToEndRoomKeyStore(SQLBaseStore):
) )
@trace @trace
async def create_e2e_room_keys_version(self, user_id: str, info: dict) -> str: async def create_e2e_room_keys_version(self, user_id: str, info: JsonDict) -> str:
"""Atomically creates a new version of this user's e2e_room_keys store """Atomically creates a new version of this user's e2e_room_keys store
with the given version info. with the given version info.
Args: Args:
user_id(str): the user whose backup we're creating a version user_id: the user whose backup we're creating a version
info(dict): the info about the backup version to be created info: the info about the backup version to be created
Returns: Returns:
The newly created version ID The newly created version ID
""" """
def _create_e2e_room_keys_version_txn(txn): def _create_e2e_room_keys_version_txn(txn: LoggingTransaction) -> str:
txn.execute( txn.execute(
"SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?", "SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?",
(user_id,), (user_id,),
) )
current_version = txn.fetchone()[0] current_version = cast(Tuple[Optional[int]], txn.fetchone())[0]
if current_version is None: if current_version is None:
current_version = "0" current_version = 0
new_version = str(int(current_version) + 1) new_version = current_version + 1
self.db_pool.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
@ -362,7 +449,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
}, },
) )
return new_version return str(new_version)
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn "create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
@ -373,7 +460,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
self, self,
user_id: str, user_id: str,
version: str, version: str,
info: Optional[dict] = None, info: Optional[JsonDict] = None,
version_etag: Optional[int] = None, version_etag: Optional[int] = None,
) -> None: ) -> None:
"""Update a given backup version """Update a given backup version
@ -386,7 +473,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
version_etag: etag of the keys in the backup. If None, then the etag version_etag: etag of the keys in the backup. If None, then the etag
is not updated. is not updated.
""" """
updatevalues = {} updatevalues: Dict[str, object] = {}
if info is not None and "auth_data" in info: if info is not None and "auth_data" in info:
updatevalues["auth_data"] = json_encoder.encode(info["auth_data"]) updatevalues["auth_data"] = json_encoder.encode(info["auth_data"])
@ -394,9 +481,16 @@ class EndToEndRoomKeyStore(SQLBaseStore):
updatevalues["etag"] = version_etag updatevalues["etag"] = version_etag
if updatevalues: if updatevalues:
await self.db_pool.simple_update( try:
version_int = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it doesn't exist.
raise StoreError(404, "No backup with that version exists")
await self.db_pool.simple_update_one(
table="e2e_room_keys_versions", table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": version}, keyvalues={"user_id": user_id, "version": version_int},
updatevalues=updatevalues, updatevalues=updatevalues,
desc="update_e2e_room_keys_version", desc="update_e2e_room_keys_version",
) )
@ -417,13 +511,16 @@ class EndToEndRoomKeyStore(SQLBaseStore):
or if the version requested doesn't exist. or if the version requested doesn't exist.
""" """
def _delete_e2e_room_keys_version_txn(txn): def _delete_e2e_room_keys_version_txn(txn: LoggingTransaction) -> None:
if version is None: if version is None:
this_version = self._get_current_version(txn, user_id) this_version = self._get_current_version(txn, user_id)
if this_version is None:
raise StoreError(404, "No current backup version")
else: else:
this_version = version try:
this_version = int(version)
except ValueError:
# Our versions are all ints so if we can't convert it to an integer,
# it isn't there.
raise StoreError(404, "No backup with that version exists")
self.db_pool.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,

View File

@ -59,9 +59,11 @@ StateKey = Tuple[str, str]
StateMap = Mapping[StateKey, T] StateMap = Mapping[StateKey, T]
MutableStateMap = MutableMapping[StateKey, T] MutableStateMap = MutableMapping[StateKey, T]
# the type of a JSON-serialisable dict. This could be made stronger, but it will # JSON types. These could be made stronger, but will do for now.
# do for now. # A JSON-serialisable dict.
JsonDict = Dict[str, Any] JsonDict = Dict[str, Any]
# A JSON-serialisable object.
JsonSerializable = object
# Note that this seems to require inheriting *directly* from Interface in order # Note that this seems to require inheriting *directly* from Interface in order

View File

@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage.databases.main.e2e_room_keys import RoomKey
from tests import unittest from tests import unittest
# sample room_key data for use in the tests # sample room_key data for use in the tests
room_key = { room_key: RoomKey = {
"first_message_index": 1, "first_message_index": 1,
"forwarded_count": 1, "forwarded_count": 1,
"is_verified": False, "is_verified": False,