Add basic implementation of local device list changes

This commit is contained in:
Erik Johnston 2017-01-25 14:27:27 +00:00
parent ba8e144554
commit 2367c5568c
14 changed files with 348 additions and 39 deletions

View file

@ -15,6 +15,7 @@
from synapse.api import errors
from synapse.util import stringutils
from synapse.types import get_domain_from_id
from twisted.internet import defer
from ._base import BaseHandler
@ -27,6 +28,8 @@ class DeviceHandler(BaseHandler):
def __init__(self, hs):
super(DeviceHandler, self).__init__(hs)
self.state = hs.get_state_handler()
@defer.inlineCallbacks
def check_device_registered(self, user_id, device_id,
initial_device_display_name=None):
@ -45,29 +48,29 @@ class DeviceHandler(BaseHandler):
str: device id (generated if none was supplied)
"""
if device_id is not None:
yield self.store.store_device(
new_device = yield self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
ignore_if_known=True,
)
if new_device:
yield self.notify_device_update(user_id, device_id)
defer.returnValue(device_id)
# if the device id is not specified, we'll autogen one, but loop a few
# times in case of a clash.
attempts = 0
while attempts < 5:
try:
device_id = stringutils.random_string(10).upper()
yield self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
ignore_if_known=False,
)
device_id = stringutils.random_string(10).upper()
new_device = yield self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
)
if new_device:
yield self.notify_device_update(user_id, device_id)
defer.returnValue(device_id)
except errors.StoreError:
attempts += 1
attempts += 1
raise errors.StoreError(500, "Couldn't generate a device ID.")
@ -147,6 +150,8 @@ class DeviceHandler(BaseHandler):
user_id=user_id, device_id=device_id
)
yield self.notify_device_update(user_id, device_id)
@defer.inlineCallbacks
def update_device(self, user_id, device_id, content):
""" Update the given device
@ -166,12 +171,48 @@ class DeviceHandler(BaseHandler):
device_id,
new_display_name=content.get("display_name")
)
yield self.notify_device_update(user_id, device_id)
except errors.StoreError, e:
if e.code == 404:
raise errors.NotFoundError()
else:
raise
@defer.inlineCallbacks
def notify_device_update(self, user_id, device_id):
rooms = yield self.store.get_rooms_for_user(user_id)
room_ids = [r.room_id for r in rooms]
hosts = set()
for room_id in room_ids:
users = yield self.state.get_current_user_in_room(room_id)
hosts.update(get_domain_from_id(u) for u in users)
hosts.discard(self.server_name)
position = yield self.store.add_device_change_to_streams(
user_id, device_id, list(hosts)
)
yield self.notifier.on_new_event(
"device_list_key", position, rooms=room_ids,
)
for host in hosts:
self.federation.send_device_messages(host)
@defer.inlineCallbacks
def get_device_list_changes(self, user_id, room_ids, from_key):
room_ids = frozenset(room_ids)
user_ids_changed = set()
changed = yield self.store.get_user_whose_devices_changed(from_key)
for other_user_id in changed:
other_rooms = yield self.store.get_rooms_for_user(other_user_id)
if room_ids.intersection(e.room_id for e in other_rooms):
user_ids_changed.add(other_user_id)
defer.returnValue(user_ids_changed)
def _update_device_from_client_ips(device, client_ips):
ip = client_ips.get((device["user_id"], device["device_id"]), {})

View file

@ -259,6 +259,7 @@ class E2eKeysHandler(object):
user_id, device_id, time_now,
encode_canonical_json(device_keys)
)
yield self.device_handler.notify_device_update(user_id, device_id)
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:

View file

@ -115,6 +115,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
"invited", # InvitedSyncResult for each invited room.
"archived", # ArchivedSyncResult for each archived room.
"to_device", # List of direct messages for the device.
"device_lists", # List of user_ids whose devices have chanegd
])):
__slots__ = []
@ -143,6 +144,7 @@ class SyncHandler(object):
self.clock = hs.get_clock()
self.response_cache = ResponseCache(hs)
self.state = hs.get_state_handler()
self.device_handler = hs.get_device_handler()
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
full_state=False):
@ -544,6 +546,16 @@ class SyncHandler(object):
yield self._generate_sync_entry_for_to_device(sync_result_builder)
if since_token and since_token.device_list_key:
user_id = sync_config.user.to_string()
rooms = yield self.store.get_rooms_for_user(user_id)
joined_room_ids = set(r.room_id for r in rooms)
device_lists = yield self.device_handler.get_device_list_changes(
user_id, joined_room_ids, since_token.device_list_key
)
else:
device_lists = []
defer.returnValue(SyncResult(
presence=sync_result_builder.presence,
account_data=sync_result_builder.account_data,
@ -551,6 +563,7 @@ class SyncHandler(object):
invited=sync_result_builder.invited,
archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device,
device_lists=device_lists,
next_batch=sync_result_builder.now_token,
))