mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Add type hints to synapse/storage/databases/main/e2e_room_keys.py
(#11549)
This commit is contained in:
parent
0147b3de20
commit
ecfcd9bbbe
1
changelog.d/11549.misc
Normal file
1
changelog.d/11549.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add missing type hints to storage classes.
|
4
mypy.ini
4
mypy.ini
@ -27,7 +27,6 @@ exclude = (?x)
|
|||||||
|synapse/storage/databases/main/__init__.py
|
|synapse/storage/databases/main/__init__.py
|
||||||
|synapse/storage/databases/main/cache.py
|
|synapse/storage/databases/main/cache.py
|
||||||
|synapse/storage/databases/main/devices.py
|
|synapse/storage/databases/main/devices.py
|
||||||
|synapse/storage/databases/main/e2e_room_keys.py
|
|
||||||
|synapse/storage/databases/main/event_federation.py
|
|synapse/storage/databases/main/event_federation.py
|
||||||
|synapse/storage/databases/main/event_push_actions.py
|
|synapse/storage/databases/main/event_push_actions.py
|
||||||
|synapse/storage/databases/main/events_bg_updates.py
|
|synapse/storage/databases/main/events_bg_updates.py
|
||||||
@ -197,6 +196,9 @@ disallow_untyped_defs = True
|
|||||||
[mypy-synapse.storage.databases.main.directory]
|
[mypy-synapse.storage.databases.main.directory]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-synapse.storage.databases.main.e2e_room_keys]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-synapse.storage.databases.main.end_to_end_keys]
|
[mypy-synapse.storage.databases.main.end_to_end_keys]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
@ -14,7 +14,9 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, Dict, Optional
|
||||||
|
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
Codes,
|
Codes,
|
||||||
@ -24,6 +26,7 @@ from synapse.api.errors import (
|
|||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
from synapse.logging.opentracing import log_kv, trace
|
from synapse.logging.opentracing import log_kv, trace
|
||||||
|
from synapse.storage.databases.main.e2e_room_keys import RoomKey
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util.async_helpers import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
|
|
||||||
@ -58,7 +61,9 @@ class E2eRoomKeysHandler:
|
|||||||
version: str,
|
version: str,
|
||||||
room_id: Optional[str] = None,
|
room_id: Optional[str] = None,
|
||||||
session_id: Optional[str] = None,
|
session_id: Optional[str] = None,
|
||||||
) -> List[JsonDict]:
|
) -> Dict[
|
||||||
|
Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
|
||||||
|
]:
|
||||||
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
|
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
|
||||||
room, or a given session.
|
room, or a given session.
|
||||||
See EndToEndRoomKeyStore.get_e2e_room_keys for full details.
|
See EndToEndRoomKeyStore.get_e2e_room_keys for full details.
|
||||||
@ -72,8 +77,8 @@ class E2eRoomKeysHandler:
|
|||||||
Raises:
|
Raises:
|
||||||
NotFoundError: if the backup version does not exist
|
NotFoundError: if the backup version does not exist
|
||||||
Returns:
|
Returns:
|
||||||
A list of dicts giving the session_data and message metadata for
|
A dict giving the session_data and message metadata for these room keys.
|
||||||
these room keys.
|
`{"rooms": {room_id: {"sessions": {session_id: room_key}}}}`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# we deliberately take the lock to get keys so that changing the version
|
# we deliberately take the lock to get keys so that changing the version
|
||||||
@ -273,7 +278,7 @@ class E2eRoomKeysHandler:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _should_replace_room_key(
|
def _should_replace_room_key(
|
||||||
current_room_key: Optional[JsonDict], room_key: JsonDict
|
current_room_key: Optional[RoomKey], room_key: RoomKey
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Determine whether to replace a given current_room_key (if any)
|
Determine whether to replace a given current_room_key (if any)
|
||||||
|
@ -13,35 +13,71 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Dict, Iterable, Mapping, Optional, Tuple, cast
|
||||||
|
|
||||||
|
from typing_extensions import Literal, TypedDict
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.logging.opentracing import log_kv, trace
|
from synapse.logging.opentracing import log_kv, trace
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
|
from synapse.storage.database import LoggingTransaction
|
||||||
|
from synapse.types import JsonDict, JsonSerializable
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
|
|
||||||
|
|
||||||
|
class RoomKey(TypedDict):
|
||||||
|
"""`KeyBackupData` in the Matrix spec.
|
||||||
|
|
||||||
|
https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3room_keyskeysroomidsessionid
|
||||||
|
"""
|
||||||
|
|
||||||
|
first_message_index: int
|
||||||
|
forwarded_count: int
|
||||||
|
is_verified: bool
|
||||||
|
session_data: JsonSerializable
|
||||||
|
|
||||||
|
|
||||||
class EndToEndRoomKeyStore(SQLBaseStore):
|
class EndToEndRoomKeyStore(SQLBaseStore):
|
||||||
|
"""The store for end to end room key backups.
|
||||||
|
|
||||||
|
See https://spec.matrix.org/v1.1/client-server-api/#server-side-key-backups
|
||||||
|
|
||||||
|
As per the spec, backups are identified by an opaque version string. Internally,
|
||||||
|
version identifiers are assigned using incrementing integers. Non-numeric version
|
||||||
|
strings are treated as if they do not exist, since we would have never issued them.
|
||||||
|
"""
|
||||||
|
|
||||||
async def update_e2e_room_key(
|
async def update_e2e_room_key(
|
||||||
self, user_id, version, room_id, session_id, room_key
|
self,
|
||||||
):
|
user_id: str,
|
||||||
|
version: str,
|
||||||
|
room_id: str,
|
||||||
|
session_id: str,
|
||||||
|
room_key: RoomKey,
|
||||||
|
) -> None:
|
||||||
"""Replaces the encrypted E2E room key for a given session in a given backup
|
"""Replaces the encrypted E2E room key for a given session in a given backup
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id(str): the user whose backup we're setting
|
user_id: the user whose backup we're setting
|
||||||
version(str): the version ID of the backup we're updating
|
version: the version ID of the backup we're updating
|
||||||
room_id(str): the ID of the room whose keys we're setting
|
room_id: the ID of the room whose keys we're setting
|
||||||
session_id(str): the session whose room_key we're setting
|
session_id: the session whose room_key we're setting
|
||||||
room_key(dict): the room_key being set
|
room_key: the room_key being set
|
||||||
Raises:
|
Raises:
|
||||||
StoreError
|
StoreError
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
|
version_int = int(version)
|
||||||
|
except ValueError:
|
||||||
|
# Our versions are all ints so if we can't convert it to an integer,
|
||||||
|
# it doesn't exist.
|
||||||
|
raise StoreError(404, "No backup with that version exists")
|
||||||
|
|
||||||
await self.db_pool.simple_update_one(
|
await self.db_pool.simple_update_one(
|
||||||
table="e2e_room_keys",
|
table="e2e_room_keys",
|
||||||
keyvalues={
|
keyvalues={
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"version": version,
|
"version": version_int,
|
||||||
"room_id": room_id,
|
"room_id": room_id,
|
||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
},
|
},
|
||||||
@ -54,22 +90,29 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
|||||||
desc="update_e2e_room_key",
|
desc="update_e2e_room_key",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def add_e2e_room_keys(self, user_id, version, room_keys):
|
async def add_e2e_room_keys(
|
||||||
|
self, user_id: str, version: str, room_keys: Iterable[Tuple[str, str, RoomKey]]
|
||||||
|
) -> None:
|
||||||
"""Bulk add room keys to a given backup.
|
"""Bulk add room keys to a given backup.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): the user whose backup we're adding to
|
user_id: the user whose backup we're adding to
|
||||||
version (str): the version ID of the backup for the set of keys we're adding to
|
version: the version ID of the backup for the set of keys we're adding to
|
||||||
room_keys (iterable[(str, str, dict)]): the keys to add, in the form
|
room_keys: the keys to add, in the form (roomID, sessionID, keyData)
|
||||||
(roomID, sessionID, keyData)
|
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
|
version_int = int(version)
|
||||||
|
except ValueError:
|
||||||
|
# Our versions are all ints so if we can't convert it to an integer,
|
||||||
|
# it doesn't exist.
|
||||||
|
raise StoreError(404, "No backup with that version exists")
|
||||||
|
|
||||||
values = []
|
values = []
|
||||||
for (room_id, session_id, room_key) in room_keys:
|
for (room_id, session_id, room_key) in room_keys:
|
||||||
values.append(
|
values.append(
|
||||||
{
|
{
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"version": version,
|
"version": version_int,
|
||||||
"room_id": room_id,
|
"room_id": room_id,
|
||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
"first_message_index": room_key["first_message_index"],
|
"first_message_index": room_key["first_message_index"],
|
||||||
@ -92,31 +135,39 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
|
async def get_e2e_room_keys(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
version: str,
|
||||||
|
room_id: Optional[str] = None,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
) -> Dict[
|
||||||
|
Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
|
||||||
|
]:
|
||||||
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
|
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
|
||||||
room, or a given session.
|
room, or a given session.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): the user whose backup we're querying
|
user_id: the user whose backup we're querying
|
||||||
version (str): the version ID of the backup for the set of keys we're querying
|
version: the version ID of the backup for the set of keys we're querying
|
||||||
room_id (str): Optional. the ID of the room whose keys we're querying, if any.
|
room_id: Optional. the ID of the room whose keys we're querying, if any.
|
||||||
If not specified, we return the keys for all the rooms in the backup.
|
If not specified, we return the keys for all the rooms in the backup.
|
||||||
session_id (str): Optional. the session whose room_key we're querying, if any.
|
session_id: Optional. the session whose room_key we're querying, if any.
|
||||||
If specified, we also require the room_id to be specified.
|
If specified, we also require the room_id to be specified.
|
||||||
If not specified, we return all the keys in this version of
|
If not specified, we return all the keys in this version of
|
||||||
the backup (or for the specified room)
|
the backup (or for the specified room)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of dicts giving the session_data and message metadata for
|
A dict giving the session_data and message metadata for these room keys.
|
||||||
these room keys.
|
`{"rooms": {room_id: {"sessions": {session_id: room_key}}}}`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
version = int(version)
|
version_int = int(version)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return {"rooms": {}}
|
return {"rooms": {}}
|
||||||
|
|
||||||
keyvalues = {"user_id": user_id, "version": version}
|
keyvalues = {"user_id": user_id, "version": version_int}
|
||||||
if room_id:
|
if room_id:
|
||||||
keyvalues["room_id"] = room_id
|
keyvalues["room_id"] = room_id
|
||||||
if session_id:
|
if session_id:
|
||||||
@ -137,7 +188,9 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
|||||||
desc="get_e2e_room_keys",
|
desc="get_e2e_room_keys",
|
||||||
)
|
)
|
||||||
|
|
||||||
sessions = {"rooms": {}}
|
sessions: Dict[
|
||||||
|
Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
|
||||||
|
] = {"rooms": {}}
|
||||||
for row in rows:
|
for row in rows:
|
||||||
room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}})
|
room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}})
|
||||||
room_entry["sessions"][row["session_id"]] = {
|
room_entry["sessions"][row["session_id"]] = {
|
||||||
@ -150,7 +203,12 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
|||||||
|
|
||||||
return sessions
|
return sessions
|
||||||
|
|
||||||
async def get_e2e_room_keys_multi(self, user_id, version, room_keys):
|
async def get_e2e_room_keys_multi(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
version: str,
|
||||||
|
room_keys: Mapping[str, Mapping[Literal["sessions"], Iterable[str]]],
|
||||||
|
) -> Dict[str, Dict[str, RoomKey]]:
|
||||||
"""Get multiple room keys at a time. The difference between this function and
|
"""Get multiple room keys at a time. The difference between this function and
|
||||||
get_e2e_room_keys is that this function can be used to retrieve
|
get_e2e_room_keys is that this function can be used to retrieve
|
||||||
multiple specific keys at a time, whereas get_e2e_room_keys is used for
|
multiple specific keys at a time, whereas get_e2e_room_keys is used for
|
||||||
@ -158,26 +216,36 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
|||||||
specific key.
|
specific key.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): the user whose backup we're querying
|
user_id: the user whose backup we're querying
|
||||||
version (str): the version ID of the backup we're querying about
|
version: the version ID of the backup we're querying about
|
||||||
room_keys (dict[str, dict[str, iterable[str]]]): a map from
|
room_keys: a map from room ID -> {"sessions": [session ids]}
|
||||||
room ID -> {"session": [session ids]} indicating the session IDs
|
indicating the session IDs that we want to query
|
||||||
that we want to query
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict[str, dict[str, dict]]: a map of room IDs to session IDs to room key
|
A map of room IDs to session IDs to room key
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
|
version_int = int(version)
|
||||||
|
except ValueError:
|
||||||
|
# Our versions are all ints so if we can't convert it to an integer,
|
||||||
|
# it doesn't exist.
|
||||||
|
return {}
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_e2e_room_keys_multi",
|
"get_e2e_room_keys_multi",
|
||||||
self._get_e2e_room_keys_multi_txn,
|
self._get_e2e_room_keys_multi_txn,
|
||||||
user_id,
|
user_id,
|
||||||
version,
|
version_int,
|
||||||
room_keys,
|
room_keys,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys):
|
def _get_e2e_room_keys_multi_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
user_id: str,
|
||||||
|
version: int,
|
||||||
|
room_keys: Mapping[str, Mapping[Literal["sessions"], Iterable[str]]],
|
||||||
|
) -> Dict[str, Dict[str, RoomKey]]:
|
||||||
if not room_keys:
|
if not room_keys:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@ -209,7 +277,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
|||||||
|
|
||||||
txn.execute(sql, params)
|
txn.execute(sql, params)
|
||||||
|
|
||||||
ret = {}
|
ret: Dict[str, Dict[str, RoomKey]] = {}
|
||||||
|
|
||||||
for row in txn:
|
for row in txn:
|
||||||
room_id = row[0]
|
room_id = row[0]
|
||||||
@ -231,36 +299,49 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
|||||||
user_id: the user whose backup we're querying
|
user_id: the user whose backup we're querying
|
||||||
version: the version ID of the backup we're querying about
|
version: the version ID of the backup we're querying about
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
|
version_int = int(version)
|
||||||
|
except ValueError:
|
||||||
|
# Our versions are all ints so if we can't convert it to an integer,
|
||||||
|
# it doesn't exist.
|
||||||
|
return 0
|
||||||
|
|
||||||
return await self.db_pool.simple_select_one_onecol(
|
return await self.db_pool.simple_select_one_onecol(
|
||||||
table="e2e_room_keys",
|
table="e2e_room_keys",
|
||||||
keyvalues={"user_id": user_id, "version": version},
|
keyvalues={"user_id": user_id, "version": version_int},
|
||||||
retcol="COUNT(*)",
|
retcol="COUNT(*)",
|
||||||
desc="count_e2e_room_keys",
|
desc="count_e2e_room_keys",
|
||||||
)
|
)
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
async def delete_e2e_room_keys(
|
async def delete_e2e_room_keys(
|
||||||
self, user_id, version, room_id=None, session_id=None
|
self,
|
||||||
):
|
user_id: str,
|
||||||
|
version: str,
|
||||||
|
room_id: Optional[str] = None,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
|
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
|
||||||
room or a given session.
|
room or a given session.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id(str): the user whose backup we're deleting from
|
user_id: the user whose backup we're deleting from
|
||||||
version(str): the version ID of the backup for the set of keys we're deleting
|
version: the version ID of the backup for the set of keys we're deleting
|
||||||
room_id(str): Optional. the ID of the room whose keys we're deleting, if any.
|
room_id: Optional. the ID of the room whose keys we're deleting, if any.
|
||||||
If not specified, we delete the keys for all the rooms in the backup.
|
If not specified, we delete the keys for all the rooms in the backup.
|
||||||
session_id(str): Optional. the session whose room_key we're querying, if any.
|
session_id: Optional. the session whose room_key we're querying, if any.
|
||||||
If specified, we also require the room_id to be specified.
|
If specified, we also require the room_id to be specified.
|
||||||
If not specified, we delete all the keys in this version of
|
If not specified, we delete all the keys in this version of
|
||||||
the backup (or for the specified room)
|
the backup (or for the specified room)
|
||||||
|
|
||||||
Returns:
|
|
||||||
The deletion transaction
|
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
|
version_int = int(version)
|
||||||
|
except ValueError:
|
||||||
|
# Our versions are all ints so if we can't convert it to an integer,
|
||||||
|
# it doesn't exist.
|
||||||
|
return
|
||||||
|
|
||||||
keyvalues = {"user_id": user_id, "version": int(version)}
|
keyvalues = {"user_id": user_id, "version": version_int}
|
||||||
if room_id:
|
if room_id:
|
||||||
keyvalues["room_id"] = room_id
|
keyvalues["room_id"] = room_id
|
||||||
if session_id:
|
if session_id:
|
||||||
@ -271,23 +352,27 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_current_version(txn, user_id):
|
def _get_current_version(txn: LoggingTransaction, user_id: str) -> int:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"SELECT MAX(version) FROM e2e_room_keys_versions "
|
"SELECT MAX(version) FROM e2e_room_keys_versions "
|
||||||
"WHERE user_id=? AND deleted=0",
|
"WHERE user_id=? AND deleted=0",
|
||||||
(user_id,),
|
(user_id,),
|
||||||
)
|
)
|
||||||
row = txn.fetchone()
|
# `SELECT MAX() FROM ...` will always return 1 row. The value in that row will
|
||||||
if not row:
|
# be `NULL` when there are no available versions.
|
||||||
|
row = cast(Tuple[Optional[int]], txn.fetchone())
|
||||||
|
if row[0] is None:
|
||||||
raise StoreError(404, "No current backup version")
|
raise StoreError(404, "No current backup version")
|
||||||
return row[0]
|
return row[0]
|
||||||
|
|
||||||
async def get_e2e_room_keys_version_info(self, user_id, version=None):
|
async def get_e2e_room_keys_version_info(
|
||||||
|
self, user_id: str, version: Optional[str] = None
|
||||||
|
) -> JsonDict:
|
||||||
"""Get info metadata about a version of our room_keys backup.
|
"""Get info metadata about a version of our room_keys backup.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id(str): the user whose backup we're querying
|
user_id: the user whose backup we're querying
|
||||||
version(str): Optional. the version ID of the backup we're querying about
|
version: Optional. the version ID of the backup we're querying about
|
||||||
If missing, we return the information about the current version.
|
If missing, we return the information about the current version.
|
||||||
Raises:
|
Raises:
|
||||||
StoreError: with code 404 if there are no e2e_room_keys_versions present
|
StoreError: with code 404 if there are no e2e_room_keys_versions present
|
||||||
@ -300,7 +385,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
|||||||
etag(int): tag of the keys in the backup
|
etag(int): tag of the keys in the backup
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _get_e2e_room_keys_version_info_txn(txn):
|
def _get_e2e_room_keys_version_info_txn(txn: LoggingTransaction) -> JsonDict:
|
||||||
if version is None:
|
if version is None:
|
||||||
this_version = self._get_current_version(txn, user_id)
|
this_version = self._get_current_version(txn, user_id)
|
||||||
else:
|
else:
|
||||||
@ -309,14 +394,16 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
# Our versions are all ints so if we can't convert it to an integer,
|
# Our versions are all ints so if we can't convert it to an integer,
|
||||||
# it isn't there.
|
# it isn't there.
|
||||||
raise StoreError(404, "No row found")
|
raise StoreError(404, "No backup with that version exists")
|
||||||
|
|
||||||
result = self.db_pool.simple_select_one_txn(
|
result = self.db_pool.simple_select_one_txn(
|
||||||
txn,
|
txn,
|
||||||
table="e2e_room_keys_versions",
|
table="e2e_room_keys_versions",
|
||||||
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
|
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
|
||||||
retcols=("version", "algorithm", "auth_data", "etag"),
|
retcols=("version", "algorithm", "auth_data", "etag"),
|
||||||
|
allow_none=False,
|
||||||
)
|
)
|
||||||
|
assert result is not None # see comment on `simple_select_one_txn`
|
||||||
result["auth_data"] = db_to_json(result["auth_data"])
|
result["auth_data"] = db_to_json(result["auth_data"])
|
||||||
result["version"] = str(result["version"])
|
result["version"] = str(result["version"])
|
||||||
if result["etag"] is None:
|
if result["etag"] is None:
|
||||||
@ -328,28 +415,28 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
async def create_e2e_room_keys_version(self, user_id: str, info: dict) -> str:
|
async def create_e2e_room_keys_version(self, user_id: str, info: JsonDict) -> str:
|
||||||
"""Atomically creates a new version of this user's e2e_room_keys store
|
"""Atomically creates a new version of this user's e2e_room_keys store
|
||||||
with the given version info.
|
with the given version info.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id(str): the user whose backup we're creating a version
|
user_id: the user whose backup we're creating a version
|
||||||
info(dict): the info about the backup version to be created
|
info: the info about the backup version to be created
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The newly created version ID
|
The newly created version ID
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _create_e2e_room_keys_version_txn(txn):
|
def _create_e2e_room_keys_version_txn(txn: LoggingTransaction) -> str:
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?",
|
"SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?",
|
||||||
(user_id,),
|
(user_id,),
|
||||||
)
|
)
|
||||||
current_version = txn.fetchone()[0]
|
current_version = cast(Tuple[Optional[int]], txn.fetchone())[0]
|
||||||
if current_version is None:
|
if current_version is None:
|
||||||
current_version = "0"
|
current_version = 0
|
||||||
|
|
||||||
new_version = str(int(current_version) + 1)
|
new_version = current_version + 1
|
||||||
|
|
||||||
self.db_pool.simple_insert_txn(
|
self.db_pool.simple_insert_txn(
|
||||||
txn,
|
txn,
|
||||||
@ -362,7 +449,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
return new_version
|
return str(new_version)
|
||||||
|
|
||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
|
"create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
|
||||||
@ -373,7 +460,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
|||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
version: str,
|
version: str,
|
||||||
info: Optional[dict] = None,
|
info: Optional[JsonDict] = None,
|
||||||
version_etag: Optional[int] = None,
|
version_etag: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update a given backup version
|
"""Update a given backup version
|
||||||
@ -386,7 +473,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
|||||||
version_etag: etag of the keys in the backup. If None, then the etag
|
version_etag: etag of the keys in the backup. If None, then the etag
|
||||||
is not updated.
|
is not updated.
|
||||||
"""
|
"""
|
||||||
updatevalues = {}
|
updatevalues: Dict[str, object] = {}
|
||||||
|
|
||||||
if info is not None and "auth_data" in info:
|
if info is not None and "auth_data" in info:
|
||||||
updatevalues["auth_data"] = json_encoder.encode(info["auth_data"])
|
updatevalues["auth_data"] = json_encoder.encode(info["auth_data"])
|
||||||
@ -394,9 +481,16 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
|||||||
updatevalues["etag"] = version_etag
|
updatevalues["etag"] = version_etag
|
||||||
|
|
||||||
if updatevalues:
|
if updatevalues:
|
||||||
await self.db_pool.simple_update(
|
try:
|
||||||
|
version_int = int(version)
|
||||||
|
except ValueError:
|
||||||
|
# Our versions are all ints so if we can't convert it to an integer,
|
||||||
|
# it doesn't exist.
|
||||||
|
raise StoreError(404, "No backup with that version exists")
|
||||||
|
|
||||||
|
await self.db_pool.simple_update_one(
|
||||||
table="e2e_room_keys_versions",
|
table="e2e_room_keys_versions",
|
||||||
keyvalues={"user_id": user_id, "version": version},
|
keyvalues={"user_id": user_id, "version": version_int},
|
||||||
updatevalues=updatevalues,
|
updatevalues=updatevalues,
|
||||||
desc="update_e2e_room_keys_version",
|
desc="update_e2e_room_keys_version",
|
||||||
)
|
)
|
||||||
@ -417,13 +511,16 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
|||||||
or if the version requested doesn't exist.
|
or if the version requested doesn't exist.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _delete_e2e_room_keys_version_txn(txn):
|
def _delete_e2e_room_keys_version_txn(txn: LoggingTransaction) -> None:
|
||||||
if version is None:
|
if version is None:
|
||||||
this_version = self._get_current_version(txn, user_id)
|
this_version = self._get_current_version(txn, user_id)
|
||||||
if this_version is None:
|
|
||||||
raise StoreError(404, "No current backup version")
|
|
||||||
else:
|
else:
|
||||||
this_version = version
|
try:
|
||||||
|
this_version = int(version)
|
||||||
|
except ValueError:
|
||||||
|
# Our versions are all ints so if we can't convert it to an integer,
|
||||||
|
# it isn't there.
|
||||||
|
raise StoreError(404, "No backup with that version exists")
|
||||||
|
|
||||||
self.db_pool.simple_delete_txn(
|
self.db_pool.simple_delete_txn(
|
||||||
txn,
|
txn,
|
||||||
|
@ -59,9 +59,11 @@ StateKey = Tuple[str, str]
|
|||||||
StateMap = Mapping[StateKey, T]
|
StateMap = Mapping[StateKey, T]
|
||||||
MutableStateMap = MutableMapping[StateKey, T]
|
MutableStateMap = MutableMapping[StateKey, T]
|
||||||
|
|
||||||
# the type of a JSON-serialisable dict. This could be made stronger, but it will
|
# JSON types. These could be made stronger, but will do for now.
|
||||||
# do for now.
|
# A JSON-serialisable dict.
|
||||||
JsonDict = Dict[str, Any]
|
JsonDict = Dict[str, Any]
|
||||||
|
# A JSON-serialisable object.
|
||||||
|
JsonSerializable = object
|
||||||
|
|
||||||
|
|
||||||
# Note that this seems to require inheriting *directly* from Interface in order
|
# Note that this seems to require inheriting *directly* from Interface in order
|
||||||
|
@ -12,10 +12,12 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from synapse.storage.databases.main.e2e_room_keys import RoomKey
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
# sample room_key data for use in the tests
|
# sample room_key data for use in the tests
|
||||||
room_key = {
|
room_key: RoomKey = {
|
||||||
"first_message_index": 1,
|
"first_message_index": 1,
|
||||||
"forwarded_count": 1,
|
"forwarded_count": 1,
|
||||||
"is_verified": False,
|
"is_verified": False,
|
||||||
|
Loading…
Reference in New Issue
Block a user