mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-01-06 19:27:52 -05:00
Convert devices database to async/await. (#8069)
This commit is contained in:
parent
5dd73d029e
commit
5ecc8b5825
1
changelog.d/8069.misc
Normal file
1
changelog.d/8069.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Convert various parts of the codebase to async/await.
|
@ -15,9 +15,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional, Set, Tuple
|
from typing import Dict, Iterable, List, Optional, Set, Tuple
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.errors import Codes, StoreError
|
from synapse.api.errors import Codes, StoreError
|
||||||
from synapse.logging.opentracing import (
|
from synapse.logging.opentracing import (
|
||||||
@ -33,14 +31,9 @@ from synapse.storage.database import (
|
|||||||
LoggingTransaction,
|
LoggingTransaction,
|
||||||
make_tuple_comparison_clause,
|
make_tuple_comparison_clause,
|
||||||
)
|
)
|
||||||
from synapse.types import Collection, get_verify_key_from_cross_signing_key
|
from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
from synapse.util.caches.descriptors import (
|
from synapse.util.caches.descriptors import Cache, cached, cachedList
|
||||||
Cache,
|
|
||||||
cached,
|
|
||||||
cachedInlineCallbacks,
|
|
||||||
cachedList,
|
|
||||||
)
|
|
||||||
from synapse.util.iterutils import batch_iter
|
from synapse.util.iterutils import batch_iter
|
||||||
from synapse.util.stringutils import shortstr
|
from synapse.util.stringutils import shortstr
|
||||||
|
|
||||||
@ -54,13 +47,13 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
|
|||||||
|
|
||||||
|
|
||||||
class DeviceWorkerStore(SQLBaseStore):
|
class DeviceWorkerStore(SQLBaseStore):
|
||||||
def get_device(self, user_id, device_id):
|
def get_device(self, user_id: str, device_id: str):
|
||||||
"""Retrieve a device. Only returns devices that are not marked as
|
"""Retrieve a device. Only returns devices that are not marked as
|
||||||
hidden.
|
hidden.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The ID of the user which owns the device
|
user_id: The ID of the user which owns the device
|
||||||
device_id (str): The ID of the device to retrieve
|
device_id: The ID of the device to retrieve
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred for a dict containing the device information
|
defer.Deferred for a dict containing the device information
|
||||||
Raises:
|
Raises:
|
||||||
@ -73,19 +66,17 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
desc="get_device",
|
desc="get_device",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
|
||||||
def get_devices_by_user(self, user_id):
|
|
||||||
"""Retrieve all of a user's registered devices. Only returns devices
|
"""Retrieve all of a user's registered devices. Only returns devices
|
||||||
that are not marked as hidden.
|
that are not marked as hidden.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str):
|
user_id:
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred: resolves to a dict from device_id to a dict
|
A mapping from device_id to a dict containing "device_id", "user_id"
|
||||||
containing "device_id", "user_id" and "display_name" for each
|
and "display_name" for each device.
|
||||||
device.
|
|
||||||
"""
|
"""
|
||||||
devices = yield self.db_pool.simple_select_list(
|
devices = await self.db_pool.simple_select_list(
|
||||||
table="devices",
|
table="devices",
|
||||||
keyvalues={"user_id": user_id, "hidden": False},
|
keyvalues={"user_id": user_id, "hidden": False},
|
||||||
retcols=("user_id", "device_id", "display_name"),
|
retcols=("user_id", "device_id", "display_name"),
|
||||||
@ -95,19 +86,20 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
return {d["device_id"]: d for d in devices}
|
return {d["device_id"]: d for d in devices}
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
@defer.inlineCallbacks
|
async def get_device_updates_by_remote(
|
||||||
def get_device_updates_by_remote(self, destination, from_stream_id, limit):
|
self, destination: str, from_stream_id: int, limit: int
|
||||||
|
) -> Tuple[int, List[Tuple[str, dict]]]:
|
||||||
"""Get a stream of device updates to send to the given remote server.
|
"""Get a stream of device updates to send to the given remote server.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
destination (str): The host the device updates are intended for
|
destination: The host the device updates are intended for
|
||||||
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
|
from_stream_id: The minimum stream_id to filter updates by, exclusive
|
||||||
limit (int): Maximum number of device updates to return
|
limit: Maximum number of device updates to return
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[tuple[int, list[tuple[string,dict]]]]:
|
A mapping from the current stream id (ie, the stream id of the last
|
||||||
current stream id (ie, the stream id of the last update included in the
|
update included in the response), and the list of updates, where
|
||||||
response), and the list of updates, where each update is a pair of EDU
|
each update is a pair of EDU type and EDU contents.
|
||||||
type and EDU contents
|
|
||||||
"""
|
"""
|
||||||
now_stream_id = self._device_list_id_gen.get_current_token()
|
now_stream_id = self._device_list_id_gen.get_current_token()
|
||||||
|
|
||||||
@ -117,7 +109,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
if not has_changed:
|
if not has_changed:
|
||||||
return now_stream_id, []
|
return now_stream_id, []
|
||||||
|
|
||||||
updates = yield self.db_pool.runInteraction(
|
updates = await self.db_pool.runInteraction(
|
||||||
"get_device_updates_by_remote",
|
"get_device_updates_by_remote",
|
||||||
self._get_device_updates_by_remote_txn,
|
self._get_device_updates_by_remote_txn,
|
||||||
destination,
|
destination,
|
||||||
@ -136,9 +128,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
master_key_by_user = {}
|
master_key_by_user = {}
|
||||||
self_signing_key_by_user = {}
|
self_signing_key_by_user = {}
|
||||||
for user in users:
|
for user in users:
|
||||||
cross_signing_key = yield defer.ensureDeferred(
|
cross_signing_key = await self.get_e2e_cross_signing_key(user, "master")
|
||||||
self.get_e2e_cross_signing_key(user, "master")
|
|
||||||
)
|
|
||||||
if cross_signing_key:
|
if cross_signing_key:
|
||||||
key_id, verify_key = get_verify_key_from_cross_signing_key(
|
key_id, verify_key = get_verify_key_from_cross_signing_key(
|
||||||
cross_signing_key
|
cross_signing_key
|
||||||
@ -151,8 +141,8 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
"device_id": verify_key.version,
|
"device_id": verify_key.version,
|
||||||
}
|
}
|
||||||
|
|
||||||
cross_signing_key = yield defer.ensureDeferred(
|
cross_signing_key = await self.get_e2e_cross_signing_key(
|
||||||
self.get_e2e_cross_signing_key(user, "self_signing")
|
user, "self_signing"
|
||||||
)
|
)
|
||||||
if cross_signing_key:
|
if cross_signing_key:
|
||||||
key_id, verify_key = get_verify_key_from_cross_signing_key(
|
key_id, verify_key = get_verify_key_from_cross_signing_key(
|
||||||
@ -202,7 +192,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
if update_stream_id > previous_update_stream_id:
|
if update_stream_id > previous_update_stream_id:
|
||||||
query_map[key] = (update_stream_id, update_context)
|
query_map[key] = (update_stream_id, update_context)
|
||||||
|
|
||||||
results = yield self._get_device_update_edus_by_remote(
|
results = await self._get_device_update_edus_by_remote(
|
||||||
destination, from_stream_id, query_map
|
destination, from_stream_id, query_map
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -215,16 +205,21 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
return now_stream_id, results
|
return now_stream_id, results
|
||||||
|
|
||||||
def _get_device_updates_by_remote_txn(
|
def _get_device_updates_by_remote_txn(
|
||||||
self, txn, destination, from_stream_id, now_stream_id, limit
|
self,
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
destination: str,
|
||||||
|
from_stream_id: int,
|
||||||
|
now_stream_id: int,
|
||||||
|
limit: int,
|
||||||
):
|
):
|
||||||
"""Return device update information for a given remote destination
|
"""Return device update information for a given remote destination
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
txn (LoggingTransaction): The transaction to execute
|
txn: The transaction to execute
|
||||||
destination (str): The host the device updates are intended for
|
destination: The host the device updates are intended for
|
||||||
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
|
from_stream_id: The minimum stream_id to filter updates by, exclusive
|
||||||
now_stream_id (int): The maximum stream_id to filter updates by, inclusive
|
now_stream_id: The maximum stream_id to filter updates by, inclusive
|
||||||
limit (int): Maximum number of device updates to return
|
limit: Maximum number of device updates to return
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List: List of device updates
|
List: List of device updates
|
||||||
@ -240,23 +235,26 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
return list(txn)
|
return list(txn)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _get_device_update_edus_by_remote(
|
||||||
def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_map):
|
self,
|
||||||
|
destination: str,
|
||||||
|
from_stream_id: int,
|
||||||
|
query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]],
|
||||||
|
) -> List[Tuple[str, dict]]:
|
||||||
"""Returns a list of device update EDUs as well as E2EE keys
|
"""Returns a list of device update EDUs as well as E2EE keys
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
destination (str): The host the device updates are intended for
|
destination: The host the device updates are intended for
|
||||||
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
|
from_stream_id: The minimum stream_id to filter updates by, exclusive
|
||||||
query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping
|
query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping
|
||||||
user_id/device_id to update stream_id and the relevant json-encoded
|
user_id/device_id to update stream_id and the relevant json-encoded
|
||||||
opentracing context
|
opentracing context
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Dict]: List of objects representing an device update EDU
|
List of objects representing an device update EDU
|
||||||
|
|
||||||
"""
|
"""
|
||||||
devices = (
|
devices = (
|
||||||
yield self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"_get_e2e_device_keys_txn",
|
"_get_e2e_device_keys_txn",
|
||||||
self._get_e2e_device_keys_txn,
|
self._get_e2e_device_keys_txn,
|
||||||
query_map.keys(),
|
query_map.keys(),
|
||||||
@ -271,7 +269,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
for user_id, user_devices in devices.items():
|
for user_id, user_devices in devices.items():
|
||||||
# The prev_id for the first row is always the last row before
|
# The prev_id for the first row is always the last row before
|
||||||
# `from_stream_id`
|
# `from_stream_id`
|
||||||
prev_id = yield self._get_last_device_update_for_remote_user(
|
prev_id = await self._get_last_device_update_for_remote_user(
|
||||||
destination, user_id, from_stream_id
|
destination, user_id, from_stream_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -315,7 +313,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
def _get_last_device_update_for_remote_user(
|
def _get_last_device_update_for_remote_user(
|
||||||
self, destination, user_id, from_stream_id
|
self, destination: str, user_id: str, from_stream_id: int
|
||||||
):
|
):
|
||||||
def f(txn):
|
def f(txn):
|
||||||
prev_sent_id_sql = """
|
prev_sent_id_sql = """
|
||||||
@ -329,7 +327,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f)
|
return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f)
|
||||||
|
|
||||||
def mark_as_sent_devices_by_remote(self, destination, stream_id):
|
def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int):
|
||||||
"""Mark that updates have successfully been sent to the destination.
|
"""Mark that updates have successfully been sent to the destination.
|
||||||
"""
|
"""
|
||||||
return self.db_pool.runInteraction(
|
return self.db_pool.runInteraction(
|
||||||
@ -339,7 +337,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
stream_id,
|
stream_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
|
def _mark_as_sent_devices_by_remote_txn(
|
||||||
|
self, txn: LoggingTransaction, destination: str, stream_id: int
|
||||||
|
) -> None:
|
||||||
# We update the device_lists_outbound_last_success with the successfully
|
# We update the device_lists_outbound_last_success with the successfully
|
||||||
# poked users.
|
# poked users.
|
||||||
sql = """
|
sql = """
|
||||||
@ -367,17 +367,21 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
"""
|
"""
|
||||||
txn.execute(sql, (destination, stream_id))
|
txn.execute(sql, (destination, stream_id))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def add_user_signature_change_to_streams(
|
||||||
def add_user_signature_change_to_streams(self, from_user_id, user_ids):
|
self, from_user_id: str, user_ids: List[str]
|
||||||
|
) -> int:
|
||||||
"""Persist that a user has made new signatures
|
"""Persist that a user has made new signatures
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
from_user_id (str): the user who made the signatures
|
from_user_id: the user who made the signatures
|
||||||
user_ids (list[str]): the users who were signed
|
user_ids: the users who were signed
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
THe new stream ID.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with self._device_list_id_gen.get_next() as stream_id:
|
with self._device_list_id_gen.get_next() as stream_id:
|
||||||
yield self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"add_user_sig_change_to_streams",
|
"add_user_sig_change_to_streams",
|
||||||
self._add_user_signature_change_txn,
|
self._add_user_signature_change_txn,
|
||||||
from_user_id,
|
from_user_id,
|
||||||
@ -386,7 +390,13 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
return stream_id
|
return stream_id
|
||||||
|
|
||||||
def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id):
|
def _add_user_signature_change_txn(
|
||||||
|
self,
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
from_user_id: str,
|
||||||
|
user_ids: List[str],
|
||||||
|
stream_id: int,
|
||||||
|
) -> None:
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
self._user_signature_stream_cache.entity_has_changed,
|
self._user_signature_stream_cache.entity_has_changed,
|
||||||
from_user_id,
|
from_user_id,
|
||||||
@ -402,29 +412,30 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_device_stream_token(self):
|
def get_device_stream_token(self) -> int:
|
||||||
return self._device_list_id_gen.get_current_token()
|
return self._device_list_id_gen.get_current_token()
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
@defer.inlineCallbacks
|
async def get_user_devices_from_cache(
|
||||||
def get_user_devices_from_cache(self, query_list):
|
self, query_list: List[Tuple[str, str]]
|
||||||
|
) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
|
||||||
"""Get the devices (and keys if any) for remote users from the cache.
|
"""Get the devices (and keys if any) for remote users from the cache.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query_list(list): List of (user_id, device_ids), if device_ids is
|
query_list: List of (user_id, device_ids), if device_ids is
|
||||||
falsey then return all device ids for that user.
|
falsey then return all device ids for that user.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
|
A tuple of (user_ids_not_in_cache, results_map), where
|
||||||
a set of user_ids and results_map is a mapping of
|
user_ids_not_in_cache is a set of user_ids and results_map is a
|
||||||
user_id -> device_id -> device_info
|
mapping of user_id -> device_id -> device_info.
|
||||||
"""
|
"""
|
||||||
user_ids = {user_id for user_id, _ in query_list}
|
user_ids = {user_id for user_id, _ in query_list}
|
||||||
user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
|
user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids))
|
||||||
|
|
||||||
# We go and check if any of the users need to have their device lists
|
# We go and check if any of the users need to have their device lists
|
||||||
# resynced. If they do then we remove them from the cached list.
|
# resynced. If they do then we remove them from the cached list.
|
||||||
users_needing_resync = yield self.get_user_ids_requiring_device_list_resync(
|
users_needing_resync = await self.get_user_ids_requiring_device_list_resync(
|
||||||
user_ids
|
user_ids
|
||||||
)
|
)
|
||||||
user_ids_in_cache = {
|
user_ids_in_cache = {
|
||||||
@ -438,19 +449,19 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if device_id:
|
if device_id:
|
||||||
device = yield self._get_cached_user_device(user_id, device_id)
|
device = await self._get_cached_user_device(user_id, device_id)
|
||||||
results.setdefault(user_id, {})[device_id] = device
|
results.setdefault(user_id, {})[device_id] = device
|
||||||
else:
|
else:
|
||||||
results[user_id] = yield self.get_cached_devices_for_user(user_id)
|
results[user_id] = await self.get_cached_devices_for_user(user_id)
|
||||||
|
|
||||||
set_tag("in_cache", results)
|
set_tag("in_cache", results)
|
||||||
set_tag("not_in_cache", user_ids_not_in_cache)
|
set_tag("not_in_cache", user_ids_not_in_cache)
|
||||||
|
|
||||||
return user_ids_not_in_cache, results
|
return user_ids_not_in_cache, results
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=2, tree=True)
|
@cached(num_args=2, tree=True)
|
||||||
def _get_cached_user_device(self, user_id, device_id):
|
async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict:
|
||||||
content = yield self.db_pool.simple_select_one_onecol(
|
content = await self.db_pool.simple_select_one_onecol(
|
||||||
table="device_lists_remote_cache",
|
table="device_lists_remote_cache",
|
||||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||||
retcol="content",
|
retcol="content",
|
||||||
@ -458,9 +469,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
return db_to_json(content)
|
return db_to_json(content)
|
||||||
|
|
||||||
@cachedInlineCallbacks()
|
@cached()
|
||||||
def get_cached_devices_for_user(self, user_id):
|
async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]:
|
||||||
devices = yield self.db_pool.simple_select_list(
|
devices = await self.db_pool.simple_select_list(
|
||||||
table="device_lists_remote_cache",
|
table="device_lists_remote_cache",
|
||||||
keyvalues={"user_id": user_id},
|
keyvalues={"user_id": user_id},
|
||||||
retcols=("device_id", "content"),
|
retcols=("device_id", "content"),
|
||||||
@ -470,11 +481,11 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
device["device_id"]: db_to_json(device["content"]) for device in devices
|
device["device_id"]: db_to_json(device["content"]) for device in devices
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_devices_with_keys_by_user(self, user_id):
|
def get_devices_with_keys_by_user(self, user_id: str):
|
||||||
"""Get all devices (with any device keys) for a user
|
"""Get all devices (with any device keys) for a user
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(stream_id, devices)
|
Deferred which resolves to (stream_id, devices)
|
||||||
"""
|
"""
|
||||||
return self.db_pool.runInteraction(
|
return self.db_pool.runInteraction(
|
||||||
"get_devices_with_keys_by_user",
|
"get_devices_with_keys_by_user",
|
||||||
@ -482,7 +493,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_devices_with_keys_by_user_txn(self, txn, user_id):
|
def _get_devices_with_keys_by_user_txn(
|
||||||
|
self, txn: LoggingTransaction, user_id: str
|
||||||
|
) -> Tuple[int, List[JsonDict]]:
|
||||||
now_stream_id = self._device_list_id_gen.get_current_token()
|
now_stream_id = self._device_list_id_gen.get_current_token()
|
||||||
|
|
||||||
devices = self._get_e2e_device_keys_txn(
|
devices = self._get_e2e_device_keys_txn(
|
||||||
@ -515,17 +528,18 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
return now_stream_id, []
|
return now_stream_id, []
|
||||||
|
|
||||||
def get_users_whose_devices_changed(self, from_key, user_ids):
|
async def get_users_whose_devices_changed(
|
||||||
|
self, from_key: str, user_ids: Iterable[str]
|
||||||
|
) -> Set[str]:
|
||||||
"""Get set of users whose devices have changed since `from_key` that
|
"""Get set of users whose devices have changed since `from_key` that
|
||||||
are in the given list of user_ids.
|
are in the given list of user_ids.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
from_key (str): The device lists stream token
|
from_key: The device lists stream token
|
||||||
user_ids (Iterable[str])
|
user_ids: The user IDs to query for devices.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[set[str]]: The set of user_ids whose devices have changed
|
The set of user_ids whose devices have changed since `from_key`
|
||||||
since `from_key`
|
|
||||||
"""
|
"""
|
||||||
from_key = int(from_key)
|
from_key = int(from_key)
|
||||||
|
|
||||||
@ -536,7 +550,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not to_check:
|
if not to_check:
|
||||||
return defer.succeed(set())
|
return set()
|
||||||
|
|
||||||
def _get_users_whose_devices_changed_txn(txn):
|
def _get_users_whose_devices_changed_txn(txn):
|
||||||
changes = set()
|
changes = set()
|
||||||
@ -556,18 +570,22 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
return changes
|
return changes
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
|
"get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_users_whose_signatures_changed(
|
||||||
def get_users_whose_signatures_changed(self, user_id, from_key):
|
self, user_id: str, from_key: str
|
||||||
|
) -> Set[str]:
|
||||||
"""Get the users who have new cross-signing signatures made by `user_id` since
|
"""Get the users who have new cross-signing signatures made by `user_id` since
|
||||||
`from_key`.
|
`from_key`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): the user who made the signatures
|
user_id: the user who made the signatures
|
||||||
from_key (str): The device lists stream token
|
from_key: The device lists stream token
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A set of user IDs with updated signatures.
|
||||||
"""
|
"""
|
||||||
from_key = int(from_key)
|
from_key = int(from_key)
|
||||||
if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
|
if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
|
||||||
@ -575,7 +593,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
SELECT DISTINCT user_ids FROM user_signature_stream
|
SELECT DISTINCT user_ids FROM user_signature_stream
|
||||||
WHERE from_user_id = ? AND stream_id > ?
|
WHERE from_user_id = ? AND stream_id > ?
|
||||||
"""
|
"""
|
||||||
rows = yield self.db_pool.execute(
|
rows = await self.db_pool.execute(
|
||||||
"get_users_whose_signatures_changed", None, sql, user_id, from_key
|
"get_users_whose_signatures_changed", None, sql, user_id, from_key
|
||||||
)
|
)
|
||||||
return {user for row in rows for user in db_to_json(row[0])}
|
return {user for row in rows for user in db_to_json(row[0])}
|
||||||
@ -638,7 +656,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@cached(max_entries=10000)
|
@cached(max_entries=10000)
|
||||||
def get_device_list_last_stream_id_for_remote(self, user_id):
|
def get_device_list_last_stream_id_for_remote(self, user_id: str):
|
||||||
"""Get the last stream_id we got for a user. May be None if we haven't
|
"""Get the last stream_id we got for a user. May be None if we haven't
|
||||||
got any information for them.
|
got any information for them.
|
||||||
"""
|
"""
|
||||||
@ -655,7 +673,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
list_name="user_ids",
|
list_name="user_ids",
|
||||||
inlineCallbacks=True,
|
inlineCallbacks=True,
|
||||||
)
|
)
|
||||||
def get_device_list_last_stream_id_for_remotes(self, user_ids):
|
def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
|
||||||
rows = yield self.db_pool.simple_select_many_batch(
|
rows = yield self.db_pool.simple_select_many_batch(
|
||||||
table="device_lists_remote_extremeties",
|
table="device_lists_remote_extremeties",
|
||||||
column="user_id",
|
column="user_id",
|
||||||
@ -669,8 +687,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_user_ids_requiring_device_list_resync(
|
||||||
def get_user_ids_requiring_device_list_resync(
|
|
||||||
self, user_ids: Optional[Collection[str]] = None,
|
self, user_ids: Optional[Collection[str]] = None,
|
||||||
) -> Set[str]:
|
) -> Set[str]:
|
||||||
"""Given a list of remote users return the list of users that we
|
"""Given a list of remote users return the list of users that we
|
||||||
@ -681,7 +698,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
The IDs of users whose device lists need resync.
|
The IDs of users whose device lists need resync.
|
||||||
"""
|
"""
|
||||||
if user_ids:
|
if user_ids:
|
||||||
rows = yield self.db_pool.simple_select_many_batch(
|
rows = await self.db_pool.simple_select_many_batch(
|
||||||
table="device_lists_remote_resync",
|
table="device_lists_remote_resync",
|
||||||
column="user_id",
|
column="user_id",
|
||||||
iterable=user_ids,
|
iterable=user_ids,
|
||||||
@ -689,7 +706,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
desc="get_user_ids_requiring_device_list_resync_with_iterable",
|
desc="get_user_ids_requiring_device_list_resync_with_iterable",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
rows = yield self.db_pool.simple_select_list(
|
rows = await self.db_pool.simple_select_list(
|
||||||
table="device_lists_remote_resync",
|
table="device_lists_remote_resync",
|
||||||
keyvalues=None,
|
keyvalues=None,
|
||||||
retcols=("user_id",),
|
retcols=("user_id",),
|
||||||
@ -710,7 +727,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
desc="make_remote_user_device_cache_as_stale",
|
desc="make_remote_user_device_cache_as_stale",
|
||||||
)
|
)
|
||||||
|
|
||||||
def mark_remote_user_device_list_as_unsubscribed(self, user_id):
|
def mark_remote_user_device_list_as_unsubscribed(self, user_id: str):
|
||||||
"""Mark that we no longer track device lists for remote user.
|
"""Mark that we no longer track device lists for remote user.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -779,16 +796,15 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
|||||||
"drop_device_lists_outbound_last_success_non_unique_idx",
|
"drop_device_lists_outbound_last_success_non_unique_idx",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
|
||||||
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
|
|
||||||
def f(conn):
|
def f(conn):
|
||||||
txn = conn.cursor()
|
txn = conn.cursor()
|
||||||
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
|
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
|
||||||
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
|
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
|
||||||
txn.close()
|
txn.close()
|
||||||
|
|
||||||
yield self.db_pool.runWithConnection(f)
|
await self.db_pool.runWithConnection(f)
|
||||||
yield self.db_pool.updates._end_background_update(
|
await self.db_pool.updates._end_background_update(
|
||||||
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES
|
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES
|
||||||
)
|
)
|
||||||
return 1
|
return 1
|
||||||
@ -868,18 +884,20 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||||||
|
|
||||||
self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
|
self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def store_device(
|
||||||
def store_device(self, user_id, device_id, initial_device_display_name):
|
self, user_id: str, device_id: str, initial_device_display_name: str
|
||||||
|
) -> bool:
|
||||||
"""Ensure the given device is known; add it to the store if not
|
"""Ensure the given device is known; add it to the store if not
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): id of user associated with the device
|
user_id: id of user associated with the device
|
||||||
device_id (str): id of device
|
device_id: id of device
|
||||||
initial_device_display_name (str): initial displayname of the
|
initial_device_display_name: initial displayname of the device.
|
||||||
device. Ignored if device exists.
|
Ignored if device exists.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred: boolean whether the device was inserted or an
|
Whether the device was inserted or an existing device existed with that ID.
|
||||||
existing device existed with that ID.
|
|
||||||
Raises:
|
Raises:
|
||||||
StoreError: if the device is already in use
|
StoreError: if the device is already in use
|
||||||
"""
|
"""
|
||||||
@ -888,7 +906,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
inserted = yield self.db_pool.simple_insert(
|
inserted = await self.db_pool.simple_insert(
|
||||||
"devices",
|
"devices",
|
||||||
values={
|
values={
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
@ -902,7 +920,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||||||
if not inserted:
|
if not inserted:
|
||||||
# if the device already exists, check if it's a real device, or
|
# if the device already exists, check if it's a real device, or
|
||||||
# if the device ID is reserved by something else
|
# if the device ID is reserved by something else
|
||||||
hidden = yield self.db_pool.simple_select_one_onecol(
|
hidden = await self.db_pool.simple_select_one_onecol(
|
||||||
"devices",
|
"devices",
|
||||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||||
retcol="hidden",
|
retcol="hidden",
|
||||||
@ -927,17 +945,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||||||
)
|
)
|
||||||
raise StoreError(500, "Problem storing device.")
|
raise StoreError(500, "Problem storing device.")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def delete_device(self, user_id: str, device_id: str) -> None:
|
||||||
def delete_device(self, user_id, device_id):
|
|
||||||
"""Delete a device.
|
"""Delete a device.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The ID of the user which owns the device
|
user_id: The ID of the user which owns the device
|
||||||
device_id (str): The ID of the device to delete
|
device_id: The ID of the device to delete
|
||||||
Returns:
|
|
||||||
defer.Deferred
|
|
||||||
"""
|
"""
|
||||||
yield self.db_pool.simple_delete_one(
|
await self.db_pool.simple_delete_one(
|
||||||
table="devices",
|
table="devices",
|
||||||
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
|
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
|
||||||
desc="delete_device",
|
desc="delete_device",
|
||||||
@ -945,17 +960,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||||||
|
|
||||||
self.device_id_exists_cache.invalidate((user_id, device_id))
|
self.device_id_exists_cache.invalidate((user_id, device_id))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
|
||||||
def delete_devices(self, user_id, device_ids):
|
|
||||||
"""Deletes several devices.
|
"""Deletes several devices.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The ID of the user which owns the devices
|
user_id: The ID of the user which owns the devices
|
||||||
device_ids (list): The IDs of the devices to delete
|
device_ids: The IDs of the devices to delete
|
||||||
Returns:
|
|
||||||
defer.Deferred
|
|
||||||
"""
|
"""
|
||||||
yield self.db_pool.simple_delete_many(
|
await self.db_pool.simple_delete_many(
|
||||||
table="devices",
|
table="devices",
|
||||||
column="device_id",
|
column="device_id",
|
||||||
iterable=device_ids,
|
iterable=device_ids,
|
||||||
@ -965,26 +977,25 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||||||
for device_id in device_ids:
|
for device_id in device_ids:
|
||||||
self.device_id_exists_cache.invalidate((user_id, device_id))
|
self.device_id_exists_cache.invalidate((user_id, device_id))
|
||||||
|
|
||||||
def update_device(self, user_id, device_id, new_display_name=None):
|
async def update_device(
|
||||||
|
self, user_id: str, device_id: str, new_display_name: Optional[str] = None
|
||||||
|
) -> None:
|
||||||
"""Update a device. Only updates the device if it is not marked as
|
"""Update a device. Only updates the device if it is not marked as
|
||||||
hidden.
|
hidden.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The ID of the user which owns the device
|
user_id: The ID of the user which owns the device
|
||||||
device_id (str): The ID of the device to update
|
device_id: The ID of the device to update
|
||||||
new_display_name (str|None): new displayname for device; None
|
new_display_name: new displayname for device; None to leave unchanged
|
||||||
to leave unchanged
|
|
||||||
Raises:
|
Raises:
|
||||||
StoreError: if the device is not found
|
StoreError: if the device is not found
|
||||||
Returns:
|
|
||||||
defer.Deferred
|
|
||||||
"""
|
"""
|
||||||
updates = {}
|
updates = {}
|
||||||
if new_display_name is not None:
|
if new_display_name is not None:
|
||||||
updates["display_name"] = new_display_name
|
updates["display_name"] = new_display_name
|
||||||
if not updates:
|
if not updates:
|
||||||
return defer.succeed(None)
|
return None
|
||||||
return self.db_pool.simple_update_one(
|
await self.db_pool.simple_update_one(
|
||||||
table="devices",
|
table="devices",
|
||||||
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
|
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
|
||||||
updatevalues=updates,
|
updatevalues=updates,
|
||||||
@ -992,7 +1003,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def update_remote_device_list_cache_entry(
|
def update_remote_device_list_cache_entry(
|
||||||
self, user_id, device_id, content, stream_id
|
self, user_id: str, device_id: str, content: JsonDict, stream_id: int
|
||||||
):
|
):
|
||||||
"""Updates a single device in the cache of a remote user's devicelist.
|
"""Updates a single device in the cache of a remote user's devicelist.
|
||||||
|
|
||||||
@ -1000,10 +1011,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||||||
device list.
|
device list.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): User to update device list for
|
user_id: User to update device list for
|
||||||
device_id (str): ID of decivice being updated
|
device_id: ID of decivice being updated
|
||||||
content (dict): new data on this device
|
content: new data on this device
|
||||||
stream_id (int): the version of the device list
|
stream_id: the version of the device list
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[None]
|
Deferred[None]
|
||||||
@ -1018,8 +1029,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _update_remote_device_list_cache_entry_txn(
|
def _update_remote_device_list_cache_entry_txn(
|
||||||
self, txn, user_id, device_id, content, stream_id
|
self,
|
||||||
):
|
txn: LoggingTransaction,
|
||||||
|
user_id: str,
|
||||||
|
device_id: str,
|
||||||
|
content: JsonDict,
|
||||||
|
stream_id: int,
|
||||||
|
) -> None:
|
||||||
if content.get("deleted"):
|
if content.get("deleted"):
|
||||||
self.db_pool.simple_delete_txn(
|
self.db_pool.simple_delete_txn(
|
||||||
txn,
|
txn,
|
||||||
@ -1055,16 +1071,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||||||
lock=False,
|
lock=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_remote_device_list_cache(self, user_id, devices, stream_id):
|
def update_remote_device_list_cache(
|
||||||
|
self, user_id: str, devices: List[dict], stream_id: int
|
||||||
|
):
|
||||||
"""Replace the entire cache of the remote user's devices.
|
"""Replace the entire cache of the remote user's devices.
|
||||||
|
|
||||||
Note: assumes that we are the only thread that can be updating this user's
|
Note: assumes that we are the only thread that can be updating this user's
|
||||||
device list.
|
device list.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): User to update device list for
|
user_id: User to update device list for
|
||||||
devices (list[dict]): list of device objects supplied over federation
|
devices: list of device objects supplied over federation
|
||||||
stream_id (int): the version of the device list
|
stream_id: the version of the device list
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[None]
|
Deferred[None]
|
||||||
@ -1077,7 +1095,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||||||
stream_id,
|
stream_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id):
|
def _update_remote_device_list_cache_txn(
|
||||||
|
self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
|
||||||
|
):
|
||||||
self.db_pool.simple_delete_txn(
|
self.db_pool.simple_delete_txn(
|
||||||
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
|
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
|
||||||
)
|
)
|
||||||
@ -1118,8 +1138,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||||||
txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id},
|
txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def add_device_change_to_streams(
|
||||||
def add_device_change_to_streams(self, user_id, device_ids, hosts):
|
self, user_id: str, device_ids: Collection[str], hosts: List[str]
|
||||||
|
):
|
||||||
"""Persist that a user's devices have been updated, and which hosts
|
"""Persist that a user's devices have been updated, and which hosts
|
||||||
(if any) should be poked.
|
(if any) should be poked.
|
||||||
"""
|
"""
|
||||||
@ -1127,7 +1148,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||||||
return
|
return
|
||||||
|
|
||||||
with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
|
with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
|
||||||
yield self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"add_device_change_to_stream",
|
"add_device_change_to_stream",
|
||||||
self._add_device_change_to_stream_txn,
|
self._add_device_change_to_stream_txn,
|
||||||
user_id,
|
user_id,
|
||||||
@ -1142,7 +1163,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||||||
with self._device_list_id_gen.get_next_mult(
|
with self._device_list_id_gen.get_next_mult(
|
||||||
len(hosts) * len(device_ids)
|
len(hosts) * len(device_ids)
|
||||||
) as stream_ids:
|
) as stream_ids:
|
||||||
yield self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"add_device_outbound_poke_to_stream",
|
"add_device_outbound_poke_to_stream",
|
||||||
self._add_device_outbound_poke_to_stream_txn,
|
self._add_device_outbound_poke_to_stream_txn,
|
||||||
user_id,
|
user_id,
|
||||||
@ -1187,7 +1208,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _add_device_outbound_poke_to_stream_txn(
|
def _add_device_outbound_poke_to_stream_txn(
|
||||||
self, txn, user_id, device_ids, hosts, stream_ids, context,
|
self,
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
user_id: str,
|
||||||
|
device_ids: Collection[str],
|
||||||
|
hosts: List[str],
|
||||||
|
stream_ids: List[str],
|
||||||
|
context: Dict[str, str],
|
||||||
):
|
):
|
||||||
for host in hosts:
|
for host in hosts:
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
@ -1219,7 +1246,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000):
|
def _prune_old_outbound_device_pokes(self, prune_age: int = 24 * 60 * 60 * 1000):
|
||||||
"""Delete old entries out of the device_lists_outbound_pokes to ensure
|
"""Delete old entries out of the device_lists_outbound_pokes to ensure
|
||||||
that we don't fill up due to dead servers.
|
that we don't fill up due to dead servers.
|
||||||
|
|
||||||
|
@ -116,7 +116,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
|||||||
retry_timings_res
|
retry_timings_res
|
||||||
)
|
)
|
||||||
|
|
||||||
self.datastore.get_device_updates_by_remote.return_value = defer.succeed(
|
self.datastore.get_device_updates_by_remote.side_effect = lambda destination, from_stream_id, limit: make_awaitable(
|
||||||
(0, [])
|
(0, [])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -34,7 +34,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_store_new_device(self):
|
def test_store_new_device(self):
|
||||||
yield self.store.store_device("user_id", "device_id", "display_name")
|
yield defer.ensureDeferred(
|
||||||
|
self.store.store_device("user_id", "device_id", "display_name")
|
||||||
|
)
|
||||||
|
|
||||||
res = yield self.store.get_device("user_id", "device_id")
|
res = yield self.store.get_device("user_id", "device_id")
|
||||||
self.assertDictContainsSubset(
|
self.assertDictContainsSubset(
|
||||||
@ -48,11 +50,17 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_devices_by_user(self):
|
def test_get_devices_by_user(self):
|
||||||
yield self.store.store_device("user_id", "device1", "display_name 1")
|
yield defer.ensureDeferred(
|
||||||
yield self.store.store_device("user_id", "device2", "display_name 2")
|
self.store.store_device("user_id", "device1", "display_name 1")
|
||||||
yield self.store.store_device("user_id2", "device3", "display_name 3")
|
)
|
||||||
|
yield defer.ensureDeferred(
|
||||||
|
self.store.store_device("user_id", "device2", "display_name 2")
|
||||||
|
)
|
||||||
|
yield defer.ensureDeferred(
|
||||||
|
self.store.store_device("user_id2", "device3", "display_name 3")
|
||||||
|
)
|
||||||
|
|
||||||
res = yield self.store.get_devices_by_user("user_id")
|
res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id"))
|
||||||
self.assertEqual(2, len(res.keys()))
|
self.assertEqual(2, len(res.keys()))
|
||||||
self.assertDictContainsSubset(
|
self.assertDictContainsSubset(
|
||||||
{
|
{
|
||||||
@ -76,13 +84,13 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
|||||||
device_ids = ["device_id1", "device_id2"]
|
device_ids = ["device_id1", "device_id2"]
|
||||||
|
|
||||||
# Add two device updates with a single stream_id
|
# Add two device updates with a single stream_id
|
||||||
yield self.store.add_device_change_to_streams(
|
yield defer.ensureDeferred(
|
||||||
"user_id", device_ids, ["somehost"]
|
self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get all device updates ever meant for this remote
|
# Get all device updates ever meant for this remote
|
||||||
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
|
now_stream_id, device_updates = yield defer.ensureDeferred(
|
||||||
"somehost", -1, limit=100
|
self.store.get_device_updates_by_remote("somehost", -1, limit=100)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check original device_ids are contained within these updates
|
# Check original device_ids are contained within these updates
|
||||||
@ -99,20 +107,24 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_update_device(self):
|
def test_update_device(self):
|
||||||
yield self.store.store_device("user_id", "device_id", "display_name 1")
|
yield defer.ensureDeferred(
|
||||||
|
self.store.store_device("user_id", "device_id", "display_name 1")
|
||||||
|
)
|
||||||
|
|
||||||
res = yield self.store.get_device("user_id", "device_id")
|
res = yield self.store.get_device("user_id", "device_id")
|
||||||
self.assertEqual("display_name 1", res["display_name"])
|
self.assertEqual("display_name 1", res["display_name"])
|
||||||
|
|
||||||
# do a no-op first
|
# do a no-op first
|
||||||
yield self.store.update_device("user_id", "device_id")
|
yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
|
||||||
res = yield self.store.get_device("user_id", "device_id")
|
res = yield self.store.get_device("user_id", "device_id")
|
||||||
self.assertEqual("display_name 1", res["display_name"])
|
self.assertEqual("display_name 1", res["display_name"])
|
||||||
|
|
||||||
# do the update
|
# do the update
|
||||||
yield self.store.update_device(
|
yield defer.ensureDeferred(
|
||||||
|
self.store.update_device(
|
||||||
"user_id", "device_id", new_display_name="display_name 2"
|
"user_id", "device_id", new_display_name="display_name 2"
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# check it worked
|
# check it worked
|
||||||
res = yield self.store.get_device("user_id", "device_id")
|
res = yield self.store.get_device("user_id", "device_id")
|
||||||
@ -121,7 +133,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_update_unknown_device(self):
|
def test_update_unknown_device(self):
|
||||||
with self.assertRaises(synapse.api.errors.StoreError) as cm:
|
with self.assertRaises(synapse.api.errors.StoreError) as cm:
|
||||||
yield self.store.update_device(
|
yield defer.ensureDeferred(
|
||||||
|
self.store.update_device(
|
||||||
"user_id", "unknown_device_id", new_display_name="display_name 2"
|
"user_id", "unknown_device_id", new_display_name="display_name 2"
|
||||||
)
|
)
|
||||||
|
)
|
||||||
self.assertEqual(404, cm.exception.code)
|
self.assertEqual(404, cm.exception.code)
|
||||||
|
@ -30,7 +30,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
|
|||||||
now = 1470174257070
|
now = 1470174257070
|
||||||
json = {"key": "value"}
|
json = {"key": "value"}
|
||||||
|
|
||||||
yield self.store.store_device("user", "device", None)
|
yield defer.ensureDeferred(self.store.store_device("user", "device", None))
|
||||||
|
|
||||||
yield self.store.set_e2e_device_keys("user", "device", now, json)
|
yield self.store.set_e2e_device_keys("user", "device", now, json)
|
||||||
|
|
||||||
@ -47,7 +47,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
|
|||||||
now = 1470174257070
|
now = 1470174257070
|
||||||
json = {"key": "value"}
|
json = {"key": "value"}
|
||||||
|
|
||||||
yield self.store.store_device("user", "device", None)
|
yield defer.ensureDeferred(self.store.store_device("user", "device", None))
|
||||||
|
|
||||||
changed = yield self.store.set_e2e_device_keys("user", "device", now, json)
|
changed = yield self.store.set_e2e_device_keys("user", "device", now, json)
|
||||||
self.assertTrue(changed)
|
self.assertTrue(changed)
|
||||||
@ -63,7 +63,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
|
|||||||
json = {"key": "value"}
|
json = {"key": "value"}
|
||||||
|
|
||||||
yield self.store.set_e2e_device_keys("user", "device", now, json)
|
yield self.store.set_e2e_device_keys("user", "device", now, json)
|
||||||
yield self.store.store_device("user", "device", "display_name")
|
yield defer.ensureDeferred(
|
||||||
|
self.store.store_device("user", "device", "display_name")
|
||||||
|
)
|
||||||
|
|
||||||
res = yield defer.ensureDeferred(
|
res = yield defer.ensureDeferred(
|
||||||
self.store.get_e2e_device_keys((("user", "device"),))
|
self.store.get_e2e_device_keys((("user", "device"),))
|
||||||
@ -79,10 +81,10 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
|
|||||||
def test_multiple_devices(self):
|
def test_multiple_devices(self):
|
||||||
now = 1470174257070
|
now = 1470174257070
|
||||||
|
|
||||||
yield self.store.store_device("user1", "device1", None)
|
yield defer.ensureDeferred(self.store.store_device("user1", "device1", None))
|
||||||
yield self.store.store_device("user1", "device2", None)
|
yield defer.ensureDeferred(self.store.store_device("user1", "device2", None))
|
||||||
yield self.store.store_device("user2", "device1", None)
|
yield defer.ensureDeferred(self.store.store_device("user2", "device1", None))
|
||||||
yield self.store.store_device("user2", "device2", None)
|
yield defer.ensureDeferred(self.store.store_device("user2", "device2", None))
|
||||||
|
|
||||||
yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
|
yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
|
||||||
yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
|
yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
|
||||||
|
Loading…
Reference in New Issue
Block a user