Convert device handler to async/await (#7871)

This commit is contained in:
Patrick Cloke 2020-07-17 07:09:25 -04:00 committed by GitHub
parent 00e57b755c
commit 6b3ac3b8cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 162 additions and 166 deletions

1
changelog.d/7871.misc Normal file
View File

@ -0,0 +1 @@
Convert device handler to async/await.

View File

@ -15,9 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Any, Dict, Optional from typing import Any, Dict, List, Optional
from twisted.internet import defer
from synapse.api import errors from synapse.api import errors
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
@ -57,21 +55,20 @@ class DeviceWorkerHandler(BaseHandler):
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
@trace @trace
@defer.inlineCallbacks async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]:
def get_devices_by_user(self, user_id):
""" """
Retrieve the given user's devices Retrieve the given user's devices
Args: Args:
user_id (str): user_id: The user ID to query for devices.
Returns: Returns:
defer.Deferred: list[dict[str, X]]: info on each device info on each device
""" """
set_tag("user_id", user_id) set_tag("user_id", user_id)
device_map = yield self.store.get_devices_by_user(user_id) device_map = await self.store.get_devices_by_user(user_id)
ips = yield self.store.get_last_client_ip_by_device(user_id, device_id=None) ips = await self.store.get_last_client_ip_by_device(user_id, device_id=None)
devices = list(device_map.values()) devices = list(device_map.values())
for device in devices: for device in devices:
@ -81,24 +78,23 @@ class DeviceWorkerHandler(BaseHandler):
return devices return devices
@trace @trace
@defer.inlineCallbacks async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
def get_device(self, user_id, device_id):
""" Retrieve the given device """ Retrieve the given device
Args: Args:
user_id (str): user_id: The user to get the device from
device_id (str): device_id: The device to fetch.
Returns: Returns:
defer.Deferred: dict[str, X]: info on the device info on the device
Raises: Raises:
errors.NotFoundError: if the device was not found errors.NotFoundError: if the device was not found
""" """
try: try:
device = yield self.store.get_device(user_id, device_id) device = await self.store.get_device(user_id, device_id)
except errors.StoreError: except errors.StoreError:
raise errors.NotFoundError raise errors.NotFoundError
ips = yield self.store.get_last_client_ip_by_device(user_id, device_id) ips = await self.store.get_last_client_ip_by_device(user_id, device_id)
_update_device_from_client_ips(device, ips) _update_device_from_client_ips(device, ips)
set_tag("device", device) set_tag("device", device)
@ -106,10 +102,9 @@ class DeviceWorkerHandler(BaseHandler):
return device return device
@measure_func("device.get_user_ids_changed")
@trace @trace
@defer.inlineCallbacks @measure_func("device.get_user_ids_changed")
def get_user_ids_changed(self, user_id, from_token): async def get_user_ids_changed(self, user_id, from_token):
"""Get list of users that have had the devices updated, or have newly """Get list of users that have had the devices updated, or have newly
joined a room, that `user_id` may be interested in. joined a room, that `user_id` may be interested in.
@ -120,13 +115,13 @@ class DeviceWorkerHandler(BaseHandler):
set_tag("user_id", user_id) set_tag("user_id", user_id)
set_tag("from_token", from_token) set_tag("from_token", from_token)
now_room_key = yield self.store.get_room_events_max_id() now_room_key = await self.store.get_room_events_max_id()
room_ids = yield self.store.get_rooms_for_user(user_id) room_ids = await self.store.get_rooms_for_user(user_id)
# First we check if any devices have changed for users that we share # First we check if any devices have changed for users that we share
# rooms with. # rooms with.
users_who_share_room = yield self.store.get_users_who_share_room_with_user( users_who_share_room = await self.store.get_users_who_share_room_with_user(
user_id user_id
) )
@ -135,14 +130,14 @@ class DeviceWorkerHandler(BaseHandler):
# Always tell the user about their own devices # Always tell the user about their own devices
tracked_users.add(user_id) tracked_users.add(user_id)
changed = yield self.store.get_users_whose_devices_changed( changed = await self.store.get_users_whose_devices_changed(
from_token.device_list_key, tracked_users from_token.device_list_key, tracked_users
) )
# Then work out if any users have since joined # Then work out if any users have since joined
rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key) rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
member_events = yield self.store.get_membership_changes_for_user( member_events = await self.store.get_membership_changes_for_user(
user_id, from_token.room_key, now_room_key user_id, from_token.room_key, now_room_key
) )
rooms_changed.update(event.room_id for event in member_events) rooms_changed.update(event.room_id for event in member_events)
@ -152,7 +147,7 @@ class DeviceWorkerHandler(BaseHandler):
possibly_changed = set(changed) possibly_changed = set(changed)
possibly_left = set() possibly_left = set()
for room_id in rooms_changed: for room_id in rooms_changed:
current_state_ids = yield self.store.get_current_state_ids(room_id) current_state_ids = await self.store.get_current_state_ids(room_id)
# The user may have left the room # The user may have left the room
# TODO: Check if they actually did or if we were just invited. # TODO: Check if they actually did or if we were just invited.
@ -166,7 +161,7 @@ class DeviceWorkerHandler(BaseHandler):
# Fetch the current state at the time. # Fetch the current state at the time.
try: try:
event_ids = yield self.store.get_forward_extremeties_for_room( event_ids = await self.store.get_forward_extremeties_for_room(
room_id, stream_ordering=stream_ordering room_id, stream_ordering=stream_ordering
) )
except errors.StoreError: except errors.StoreError:
@ -192,7 +187,7 @@ class DeviceWorkerHandler(BaseHandler):
continue continue
# mapping from event_id -> state_dict # mapping from event_id -> state_dict
prev_state_ids = yield self.state_store.get_state_ids_for_events(event_ids) prev_state_ids = await self.state_store.get_state_ids_for_events(event_ids)
# Check if we've joined the room? If so we just blindly add all the users to # Check if we've joined the room? If so we just blindly add all the users to
# the "possibly changed" users. # the "possibly changed" users.
@ -238,11 +233,10 @@ class DeviceWorkerHandler(BaseHandler):
return result return result
@defer.inlineCallbacks async def on_federation_query_user_devices(self, user_id):
def on_federation_query_user_devices(self, user_id): stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id)
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master") self_signing_key = await self.store.get_e2e_cross_signing_key(
self_signing_key = yield self.store.get_e2e_cross_signing_key(
user_id, "self_signing" user_id, "self_signing"
) )
@ -271,8 +265,7 @@ class DeviceHandler(DeviceWorkerHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room) hs.get_distributor().observe("user_left_room", self.user_left_room)
@defer.inlineCallbacks async def check_device_registered(
def check_device_registered(
self, user_id, device_id, initial_device_display_name=None self, user_id, device_id, initial_device_display_name=None
): ):
""" """
@ -290,13 +283,13 @@ class DeviceHandler(DeviceWorkerHandler):
str: device id (generated if none was supplied) str: device id (generated if none was supplied)
""" """
if device_id is not None: if device_id is not None:
new_device = yield self.store.store_device( new_device = await self.store.store_device(
user_id=user_id, user_id=user_id,
device_id=device_id, device_id=device_id,
initial_device_display_name=initial_device_display_name, initial_device_display_name=initial_device_display_name,
) )
if new_device: if new_device:
yield self.notify_device_update(user_id, [device_id]) await self.notify_device_update(user_id, [device_id])
return device_id return device_id
# if the device id is not specified, we'll autogen one, but loop a few # if the device id is not specified, we'll autogen one, but loop a few
@ -304,33 +297,29 @@ class DeviceHandler(DeviceWorkerHandler):
attempts = 0 attempts = 0
while attempts < 5: while attempts < 5:
device_id = stringutils.random_string(10).upper() device_id = stringutils.random_string(10).upper()
new_device = yield self.store.store_device( new_device = await self.store.store_device(
user_id=user_id, user_id=user_id,
device_id=device_id, device_id=device_id,
initial_device_display_name=initial_device_display_name, initial_device_display_name=initial_device_display_name,
) )
if new_device: if new_device:
yield self.notify_device_update(user_id, [device_id]) await self.notify_device_update(user_id, [device_id])
return device_id return device_id
attempts += 1 attempts += 1
raise errors.StoreError(500, "Couldn't generate a device ID.") raise errors.StoreError(500, "Couldn't generate a device ID.")
@trace @trace
@defer.inlineCallbacks async def delete_device(self, user_id: str, device_id: str) -> None:
def delete_device(self, user_id, device_id):
""" Delete the given device """ Delete the given device
Args: Args:
user_id (str): user_id: The user to delete the device from.
device_id (str): device_id: The device to delete.
Returns:
defer.Deferred:
""" """
try: try:
yield self.store.delete_device(user_id, device_id) await self.store.delete_device(user_id, device_id)
except errors.StoreError as e: except errors.StoreError as e:
if e.code == 404: if e.code == 404:
# no match # no match
@ -342,49 +331,40 @@ class DeviceHandler(DeviceWorkerHandler):
else: else:
raise raise
yield defer.ensureDeferred( await self._auth_handler.delete_access_tokens_for_user(
self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id user_id, device_id=device_id
) )
)
yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id) await self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
yield self.notify_device_update(user_id, [device_id]) await self.notify_device_update(user_id, [device_id])
@trace @trace
@defer.inlineCallbacks async def delete_all_devices_for_user(
def delete_all_devices_for_user(self, user_id, except_device_id=None): self, user_id: str, except_device_id: Optional[str] = None
) -> None:
"""Delete all of the user's devices """Delete all of the user's devices
Args: Args:
user_id (str): user_id: The user to remove all devices from
except_device_id (str|None): optional device id which should not except_device_id: optional device id which should not be deleted
be deleted
Returns:
defer.Deferred:
""" """
device_map = yield self.store.get_devices_by_user(user_id) device_map = await self.store.get_devices_by_user(user_id)
device_ids = list(device_map) device_ids = list(device_map)
if except_device_id is not None: if except_device_id is not None:
device_ids = [d for d in device_ids if d != except_device_id] device_ids = [d for d in device_ids if d != except_device_id]
yield self.delete_devices(user_id, device_ids) await self.delete_devices(user_id, device_ids)
@defer.inlineCallbacks async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
def delete_devices(self, user_id, device_ids):
""" Delete several devices """ Delete several devices
Args: Args:
user_id (str): user_id: The user to delete devices from.
device_ids (List[str]): The list of device IDs to delete device_ids: The list of device IDs to delete
Returns:
defer.Deferred:
""" """
try: try:
yield self.store.delete_devices(user_id, device_ids) await self.store.delete_devices(user_id, device_ids)
except errors.StoreError as e: except errors.StoreError as e:
if e.code == 404: if e.code == 404:
# no match # no match
@ -397,28 +377,22 @@ class DeviceHandler(DeviceWorkerHandler):
# Delete access tokens and e2e keys for each device. Not optimised as it is not # Delete access tokens and e2e keys for each device. Not optimised as it is not
# considered as part of a critical path. # considered as part of a critical path.
for device_id in device_ids: for device_id in device_ids:
yield defer.ensureDeferred( await self._auth_handler.delete_access_tokens_for_user(
self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id user_id, device_id=device_id
) )
) await self.store.delete_e2e_keys_by_device(
yield self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id user_id=user_id, device_id=device_id
) )
yield self.notify_device_update(user_id, device_ids) await self.notify_device_update(user_id, device_ids)
@defer.inlineCallbacks async def update_device(self, user_id: str, device_id: str, content: dict) -> None:
def update_device(self, user_id, device_id, content):
""" Update the given device """ Update the given device
Args: Args:
user_id (str): user_id: The user to update devices of.
device_id (str): device_id: The device to update.
content (dict): body of update request content: body of update request
Returns:
defer.Deferred:
""" """
# Reject a new displayname which is too long. # Reject a new displayname which is too long.
@ -431,10 +405,10 @@ class DeviceHandler(DeviceWorkerHandler):
) )
try: try:
yield self.store.update_device( await self.store.update_device(
user_id, device_id, new_display_name=new_display_name user_id, device_id, new_display_name=new_display_name
) )
yield self.notify_device_update(user_id, [device_id]) await self.notify_device_update(user_id, [device_id])
except errors.StoreError as e: except errors.StoreError as e:
if e.code == 404: if e.code == 404:
raise errors.NotFoundError() raise errors.NotFoundError()
@ -443,12 +417,11 @@ class DeviceHandler(DeviceWorkerHandler):
@trace @trace
@measure_func("notify_device_update") @measure_func("notify_device_update")
@defer.inlineCallbacks async def notify_device_update(self, user_id, device_ids):
def notify_device_update(self, user_id, device_ids):
"""Notify that a user's device(s) has changed. Pokes the notifier, and """Notify that a user's device(s) has changed. Pokes the notifier, and
remote servers if the user is local. remote servers if the user is local.
""" """
users_who_share_room = yield self.store.get_users_who_share_room_with_user( users_who_share_room = await self.store.get_users_who_share_room_with_user(
user_id user_id
) )
@ -459,7 +432,7 @@ class DeviceHandler(DeviceWorkerHandler):
set_tag("target_hosts", hosts) set_tag("target_hosts", hosts)
position = yield self.store.add_device_change_to_streams( position = await self.store.add_device_change_to_streams(
user_id, device_ids, list(hosts) user_id, device_ids, list(hosts)
) )
@ -468,11 +441,11 @@ class DeviceHandler(DeviceWorkerHandler):
"Notifying about update %r/%r, ID: %r", user_id, device_id, position "Notifying about update %r/%r, ID: %r", user_id, device_id, position
) )
room_ids = yield self.store.get_rooms_for_user(user_id) room_ids = await self.store.get_rooms_for_user(user_id)
# specify the user ID too since the user should always get their own device list # specify the user ID too since the user should always get their own device list
# updates, even if they aren't in any rooms. # updates, even if they aren't in any rooms.
yield self.notifier.on_new_event( self.notifier.on_new_event(
"device_list_key", position, users=[user_id], rooms=room_ids "device_list_key", position, users=[user_id], rooms=room_ids
) )
@ -484,29 +457,29 @@ class DeviceHandler(DeviceWorkerHandler):
self.federation_sender.send_device_messages(host) self.federation_sender.send_device_messages(host)
log_kv({"message": "sent device update to host", "host": host}) log_kv({"message": "sent device update to host", "host": host})
@defer.inlineCallbacks async def notify_user_signature_update(
def notify_user_signature_update(self, from_user_id, user_ids): self, from_user_id: str, user_ids: List[str]
) -> None:
"""Notify a user that they have made new signatures of other users. """Notify a user that they have made new signatures of other users.
Args: Args:
from_user_id (str): the user who made the signature from_user_id: the user who made the signature
user_ids (list[str]): the users IDs that have new signatures user_ids: the users IDs that have new signatures
""" """
position = yield self.store.add_user_signature_change_to_streams( position = await self.store.add_user_signature_change_to_streams(
from_user_id, user_ids from_user_id, user_ids
) )
self.notifier.on_new_event("device_list_key", position, users=[from_user_id]) self.notifier.on_new_event("device_list_key", position, users=[from_user_id])
@defer.inlineCallbacks async def user_left_room(self, user, room_id):
def user_left_room(self, user, room_id):
user_id = user.to_string() user_id = user.to_string()
room_ids = yield self.store.get_rooms_for_user(user_id) room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids: if not room_ids:
# We no longer share rooms with this user, so we'll no longer # We no longer share rooms with this user, so we'll no longer
# receive device updates. Mark this in DB. # receive device updates. Mark this in DB.
yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id) await self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
def _update_device_from_client_ips(device, client_ips): def _update_device_from_client_ips(device, client_ips):
@ -549,8 +522,7 @@ class DeviceListUpdater(object):
) )
@trace @trace
@defer.inlineCallbacks async def incoming_device_list_update(self, origin, edu_content):
def incoming_device_list_update(self, origin, edu_content):
"""Called on incoming device list update from federation. Responsible """Called on incoming device list update from federation. Responsible
for parsing the EDU and adding to pending updates list. for parsing the EDU and adding to pending updates list.
""" """
@ -583,7 +555,7 @@ class DeviceListUpdater(object):
) )
return return
room_ids = yield self.store.get_rooms_for_user(user_id) room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids: if not room_ids:
# We don't share any rooms with this user. Ignore update, as we # We don't share any rooms with this user. Ignore update, as we
# probably won't get any further updates. # probably won't get any further updates.
@ -608,14 +580,13 @@ class DeviceListUpdater(object):
(device_id, stream_id, prev_ids, edu_content) (device_id, stream_id, prev_ids, edu_content)
) )
yield self._handle_device_updates(user_id) await self._handle_device_updates(user_id)
@measure_func("_incoming_device_list_update") @measure_func("_incoming_device_list_update")
@defer.inlineCallbacks async def _handle_device_updates(self, user_id):
def _handle_device_updates(self, user_id):
"Actually handle pending updates." "Actually handle pending updates."
with (yield self._remote_edu_linearizer.queue(user_id)): with (await self._remote_edu_linearizer.queue(user_id)):
pending_updates = self._pending_updates.pop(user_id, []) pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates: if not pending_updates:
# This can happen since we batch updates # This can happen since we batch updates
@ -632,7 +603,7 @@ class DeviceListUpdater(object):
# Given a list of updates we check if we need to resync. This # Given a list of updates we check if we need to resync. This
# happens if we've missed updates. # happens if we've missed updates.
resync = yield self._need_to_do_resync(user_id, pending_updates) resync = await self._need_to_do_resync(user_id, pending_updates)
if logger.isEnabledFor(logging.INFO): if logger.isEnabledFor(logging.INFO):
logger.info( logger.info(
@ -643,16 +614,16 @@ class DeviceListUpdater(object):
) )
if resync: if resync:
yield self.user_device_resync(user_id) await self.user_device_resync(user_id)
else: else:
# Simply update the single device, since we know that is the only # Simply update the single device, since we know that is the only
# change (because of the single prev_id matching the current cache) # change (because of the single prev_id matching the current cache)
for device_id, stream_id, prev_ids, content in pending_updates: for device_id, stream_id, prev_ids, content in pending_updates:
yield self.store.update_remote_device_list_cache_entry( await self.store.update_remote_device_list_cache_entry(
user_id, device_id, content, stream_id user_id, device_id, content, stream_id
) )
yield self.device_handler.notify_device_update( await self.device_handler.notify_device_update(
user_id, [device_id for device_id, _, _, _ in pending_updates] user_id, [device_id for device_id, _, _, _ in pending_updates]
) )
@ -660,14 +631,13 @@ class DeviceListUpdater(object):
stream_id for _, stream_id, _, _ in pending_updates stream_id for _, stream_id, _, _ in pending_updates
) )
@defer.inlineCallbacks async def _need_to_do_resync(self, user_id, updates):
def _need_to_do_resync(self, user_id, updates):
"""Given a list of updates for a user figure out if we need to do a full """Given a list of updates for a user figure out if we need to do a full
resync, or whether we have enough data that we can just apply the delta. resync, or whether we have enough data that we can just apply the delta.
""" """
seen_updates = self._seen_updates.get(user_id, set()) seen_updates = self._seen_updates.get(user_id, set())
extremity = yield self.store.get_device_list_last_stream_id_for_remote(user_id) extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)
logger.debug("Current extremity for %r: %r", user_id, extremity) logger.debug("Current extremity for %r: %r", user_id, extremity)
@ -692,8 +662,7 @@ class DeviceListUpdater(object):
return False return False
@trace @trace
@defer.inlineCallbacks async def _maybe_retry_device_resync(self):
def _maybe_retry_device_resync(self):
"""Retry to resync device lists that are out of sync, except if another retry is """Retry to resync device lists that are out of sync, except if another retry is
in progress. in progress.
""" """
@ -705,12 +674,12 @@ class DeviceListUpdater(object):
# we don't send too many requests. # we don't send too many requests.
self._resync_retry_in_progress = True self._resync_retry_in_progress = True
# Get all of the users that need resyncing. # Get all of the users that need resyncing.
need_resync = yield self.store.get_user_ids_requiring_device_list_resync() need_resync = await self.store.get_user_ids_requiring_device_list_resync()
# Iterate over the set of user IDs. # Iterate over the set of user IDs.
for user_id in need_resync: for user_id in need_resync:
try: try:
# Try to resync the current user's devices list. # Try to resync the current user's devices list.
result = yield self.user_device_resync( result = await self.user_device_resync(
user_id=user_id, mark_failed_as_stale=False, user_id=user_id, mark_failed_as_stale=False,
) )
@ -734,16 +703,17 @@ class DeviceListUpdater(object):
# Allow future calls to retry resyncinc out of sync device lists. # Allow future calls to retry resyncinc out of sync device lists.
self._resync_retry_in_progress = False self._resync_retry_in_progress = False
@defer.inlineCallbacks async def user_device_resync(
def user_device_resync(self, user_id, mark_failed_as_stale=True): self, user_id: str, mark_failed_as_stale: bool = True
) -> Optional[dict]:
"""Fetches all devices for a user and updates the device cache with them. """Fetches all devices for a user and updates the device cache with them.
Args: Args:
user_id (str): The user's id whose device_list will be updated. user_id: The user's id whose device_list will be updated.
mark_failed_as_stale (bool): Whether to mark the user's device list as stale mark_failed_as_stale: Whether to mark the user's device list as stale
if the attempt to resync failed. if the attempt to resync failed.
Returns: Returns:
Deferred[dict]: a dict with device info as under the "devices" in the result of this A dict with device info as under the "devices" in the result of this
request: request:
https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
""" """
@ -752,12 +722,12 @@ class DeviceListUpdater(object):
# Fetch all devices for the user. # Fetch all devices for the user.
origin = get_domain_from_id(user_id) origin = get_domain_from_id(user_id)
try: try:
result = yield self.federation.query_user_devices(origin, user_id) result = await self.federation.query_user_devices(origin, user_id)
except NotRetryingDestination: except NotRetryingDestination:
if mark_failed_as_stale: if mark_failed_as_stale:
# Mark the remote user's device list as stale so we know we need to retry # Mark the remote user's device list as stale so we know we need to retry
# it later. # it later.
yield self.store.mark_remote_user_device_cache_as_stale(user_id) await self.store.mark_remote_user_device_cache_as_stale(user_id)
return return
except (RequestSendFailed, HttpResponseException) as e: except (RequestSendFailed, HttpResponseException) as e:
@ -768,7 +738,7 @@ class DeviceListUpdater(object):
if mark_failed_as_stale: if mark_failed_as_stale:
# Mark the remote user's device list as stale so we know we need to retry # Mark the remote user's device list as stale so we know we need to retry
# it later. # it later.
yield self.store.mark_remote_user_device_cache_as_stale(user_id) await self.store.mark_remote_user_device_cache_as_stale(user_id)
# We abort on exceptions rather than accepting the update # We abort on exceptions rather than accepting the update
# as otherwise synapse will 'forget' that its device list # as otherwise synapse will 'forget' that its device list
@ -792,7 +762,7 @@ class DeviceListUpdater(object):
if mark_failed_as_stale: if mark_failed_as_stale:
# Mark the remote user's device list as stale so we know we need to retry # Mark the remote user's device list as stale so we know we need to retry
# it later. # it later.
yield self.store.mark_remote_user_device_cache_as_stale(user_id) await self.store.mark_remote_user_device_cache_as_stale(user_id)
return return
log_kv({"result": result}) log_kv({"result": result})
@ -833,25 +803,24 @@ class DeviceListUpdater(object):
stream_id, stream_id,
) )
yield self.store.update_remote_device_list_cache(user_id, devices, stream_id) await self.store.update_remote_device_list_cache(user_id, devices, stream_id)
device_ids = [device["device_id"] for device in devices] device_ids = [device["device_id"] for device in devices]
# Handle cross-signing keys. # Handle cross-signing keys.
cross_signing_device_ids = yield self.process_cross_signing_key_update( cross_signing_device_ids = await self.process_cross_signing_key_update(
user_id, master_key, self_signing_key, user_id, master_key, self_signing_key,
) )
device_ids = device_ids + cross_signing_device_ids device_ids = device_ids + cross_signing_device_ids
yield self.device_handler.notify_device_update(user_id, device_ids) await self.device_handler.notify_device_update(user_id, device_ids)
# We clobber the seen updates since we've re-synced from a given # We clobber the seen updates since we've re-synced from a given
# point. # point.
self._seen_updates[user_id] = {stream_id} self._seen_updates[user_id] = {stream_id}
defer.returnValue(result) return result
@defer.inlineCallbacks async def process_cross_signing_key_update(
def process_cross_signing_key_update(
self, self,
user_id: str, user_id: str,
master_key: Optional[Dict[str, Any]], master_key: Optional[Dict[str, Any]],
@ -872,14 +841,14 @@ class DeviceListUpdater(object):
device_ids = [] device_ids = []
if master_key: if master_key:
yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key) await self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
_, verify_key = get_verify_key_from_cross_signing_key(master_key) _, verify_key = get_verify_key_from_cross_signing_key(master_key)
# verify_key is a VerifyKey from signedjson, which uses # verify_key is a VerifyKey from signedjson, which uses
# .version to denote the portion of the key ID after the # .version to denote the portion of the key ID after the
# algorithm and colon, which is the device ID # algorithm and colon, which is the device ID
device_ids.append(verify_key.version) device_ids.append(verify_key.version)
if self_signing_key: if self_signing_key:
yield self.store.set_e2e_cross_signing_key( await self.store.set_e2e_cross_signing_key(
user_id, "self_signing", self_signing_key user_id, "self_signing", self_signing_key
) )
_, verify_key = get_verify_key_from_cross_signing_key(self_signing_key) _, verify_key = get_verify_key_from_cross_signing_key(self_signing_key)

View File

@ -12,10 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect
import logging import logging
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.defer import Deferred, fail, succeed
from twisted.python import failure
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
@ -79,6 +81,28 @@ class Distributor(object):
run_as_background_process(name, self.signals[name].fire, *args, **kwargs) run_as_background_process(name, self.signals[name].fire, *args, **kwargs)
def maybeAwaitableDeferred(f, *args, **kw):
"""
Invoke a function that may or may not return a Deferred or an Awaitable.
This is a modified version of twisted.internet.defer.maybeDeferred.
"""
try:
result = f(*args, **kw)
except Exception:
return fail(failure.Failure(captureVars=Deferred.debug))
if isinstance(result, Deferred):
return result
# Handle the additional case of an awaitable being returned.
elif inspect.isawaitable(result):
return defer.ensureDeferred(result)
elif isinstance(result, failure.Failure):
return fail(result)
else:
return succeed(result)
class Signal(object): class Signal(object):
"""A Signal is a dispatch point that stores a list of callables as """A Signal is a dispatch point that stores a list of callables as
observers of it. observers of it.
@ -122,7 +146,7 @@ class Signal(object):
), ),
) )
return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb) return maybeAwaitableDeferred(observer, *args, **kwargs).addErrback(eb)
deferreds = [run_in_background(do, o) for o in self.observers] deferreds = [run_in_background(do, o) for o in self.observers]

View File

@ -142,10 +142,8 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.get_success(self.handler.delete_device(user1, "abc")) self.get_success(self.handler.delete_device(user1, "abc"))
# check the device was deleted # check the device was deleted
res = self.handler.get_device(user1, "abc") self.get_failure(
self.pump() self.handler.get_device(user1, "abc"), synapse.api.errors.NotFoundError
self.assertIsInstance(
self.failureResultOf(res).value, synapse.api.errors.NotFoundError
) )
# we'd like to check the access token was invalidated, but that's a # we'd like to check the access token was invalidated, but that's a
@ -180,10 +178,9 @@ class DeviceTestCase(unittest.HomeserverTestCase):
def test_update_unknown_device(self): def test_update_unknown_device(self):
update = {"display_name": "new_display"} update = {"display_name": "new_display"}
res = self.handler.update_device("user_id", "unknown_device_id", update) self.get_failure(
self.pump() self.handler.update_device("user_id", "unknown_device_id", update),
self.assertIsInstance( synapse.api.errors.NotFoundError,
self.failureResultOf(res).value, synapse.api.errors.NotFoundError
) )
def _record_users(self): def _record_users(self):

View File

@ -334,11 +334,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
res = None res = None
try: try:
yield self.hs.get_device_handler().check_device_registered( yield defer.ensureDeferred(
self.hs.get_device_handler().check_device_registered(
user_id=local_user, user_id=local_user,
device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
initial_device_display_name="new display name", initial_device_display_name="new display name",
) )
)
except errors.SynapseError as e: except errors.SynapseError as e:
res = e.code res = e.code
self.assertEqual(res, 400) self.assertEqual(res, 400)

View File

@ -173,7 +173,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register a mock on the store so that the incoming update doesn't fail because # Register a mock on the store so that the incoming update doesn't fail because
# we don't share a room with the user. # we don't share a room with the user.
store = self.homeserver.get_datastore() store = self.homeserver.get_datastore()
store.get_rooms_for_user = Mock(return_value=["!someroom:test"]) store.get_rooms_for_user = Mock(return_value=succeed(["!someroom:test"]))
# Manually inject a fake device list update. We need this update to include at # Manually inject a fake device list update. We need this update to include at
# least one prev_id so that the user's device list will need to be retried. # least one prev_id so that the user's device list will need to be retried.
@ -218,7 +218,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register mock device list retrieval on the federation client. # Register mock device list retrieval on the federation client.
federation_client = self.homeserver.get_federation_client() federation_client = self.homeserver.get_federation_client()
federation_client.query_user_devices = Mock( federation_client.query_user_devices = Mock(
return_value={ return_value=succeed(
{
"user_id": remote_user_id, "user_id": remote_user_id,
"stream_id": 1, "stream_id": 1,
"devices": [], "devices": [],
@ -231,11 +232,13 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
"user_id": remote_user_id, "user_id": remote_user_id,
"usage": ["self_signing"], "usage": ["self_signing"],
"keys": { "keys": {
"ed25519:" + remote_self_signing_key: remote_self_signing_key "ed25519:"
+ remote_self_signing_key: remote_self_signing_key
}, },
}, },
} }
) )
)
# Resync the device list. # Resync the device list.
device_handler = self.homeserver.get_device_handler() device_handler = self.homeserver.get_device_handler()