Implement deleting devices

This commit is contained in:
Richard van der Hoff 2016-07-22 14:52:53 +01:00
parent 33d08e8433
commit 436bffd15f
11 changed files with 176 additions and 21 deletions

View File

@ -77,6 +77,7 @@ class AuthHandler(BaseHandler):
self.ldap_bind_password = hs.config.ldap_bind_password self.ldap_bind_password = hs.config.ldap_bind_password
self.hs = hs # FIXME better possibility to access registrationHandler later? self.hs = hs # FIXME better possibility to access registrationHandler later?
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip): def check_auth(self, flows, clientdict, clientip):
@ -374,7 +375,8 @@ class AuthHandler(BaseHandler):
return self._check_password(user_id, password) return self._check_password(user_id, password)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_login_tuple_for_user_id(self, user_id, device_id=None): def get_login_tuple_for_user_id(self, user_id, device_id=None,
initial_display_name=None):
""" """
Gets login tuple for the user with the given user ID. Gets login tuple for the user with the given user ID.
@ -383,9 +385,15 @@ class AuthHandler(BaseHandler):
The user is assumed to have been authenticated by some other The user is assumed to have been authenticated by some other
machanism (e.g. CAS), and the user_id converted to the canonical case. machanism (e.g. CAS), and the user_id converted to the canonical case.
The device will be recorded in the table if it is not there already.
Args: Args:
user_id (str): canonical User ID user_id (str): canonical User ID
device_id (str): the device ID to associate with the access token device_id (str|None): the device ID to associate with the tokens.
None to leave the tokens unassociated with a device (deprecated:
we should always have a device ID)
initial_display_name (str): display name to associate with the
device if it needs re-registering
Returns: Returns:
A tuple of: A tuple of:
The access token for the user's session. The access token for the user's session.
@ -397,6 +405,16 @@ class AuthHandler(BaseHandler):
logger.info("Logging in user %s on device %s", user_id, device_id) logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id)
refresh_token = yield self.issue_refresh_token(user_id, device_id) refresh_token = yield self.issue_refresh_token(user_id, device_id)
# the device *should* have been registered before we got here; however,
# it's possible we raced against a DELETE operation. The thing we
# really don't want is active access_tokens without a record of the
# device, so we double-check it here.
if device_id is not None:
yield self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
defer.returnValue((access_token, refresh_token)) defer.returnValue((access_token, refresh_token))
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -100,7 +100,7 @@ class DeviceHandler(BaseHandler):
Args: Args:
user_id (str): user_id (str):
device_id (str) device_id (str):
Returns: Returns:
defer.Deferred: dict[str, X]: info on the device defer.Deferred: dict[str, X]: info on the device
@ -117,6 +117,31 @@ class DeviceHandler(BaseHandler):
_update_device_from_client_ips(device, ips) _update_device_from_client_ips(device, ips)
defer.returnValue(device) defer.returnValue(device)
@defer.inlineCallbacks
def delete_device(self, user_id, device_id):
""" Delete the given device
Args:
user_id (str):
device_id (str):
Returns:
defer.Deferred:
"""
try:
yield self.store.delete_device(user_id, device_id)
except errors.StoreError, e:
if e.code == 404:
# no match
pass
else:
raise
yield self.store.user_delete_access_tokens(user_id,
device_id=device_id)
def _update_device_from_client_ips(device, client_ips): def _update_device_from_client_ips(device, client_ips):
ip = client_ips.get((device["user_id"], device["device_id"]), {}) ip = client_ips.get((device["user_id"], device["device_id"]), {})

View File

@ -152,7 +152,10 @@ class LoginRestServlet(ClientV1RestServlet):
) )
device_id = yield self._register_device(user_id, login_submission) device_id = yield self._register_device(user_id, login_submission)
access_token, refresh_token = ( access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) yield auth_handler.get_login_tuple_for_user_id(
user_id, device_id,
login_submission.get("initial_device_display_name")
)
) )
result = { result = {
"user_id": user_id, # may have changed "user_id": user_id, # may have changed
@ -173,7 +176,10 @@ class LoginRestServlet(ClientV1RestServlet):
) )
device_id = yield self._register_device(user_id, login_submission) device_id = yield self._register_device(user_id, login_submission)
access_token, refresh_token = ( access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id, device_id) yield auth_handler.get_login_tuple_for_user_id(
user_id, device_id,
login_submission.get("initial_device_display_name")
)
) )
result = { result = {
"user_id": user_id, # may have changed "user_id": user_id, # may have changed
@ -262,7 +268,8 @@ class LoginRestServlet(ClientV1RestServlet):
) )
access_token, refresh_token = ( access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id( yield auth_handler.get_login_tuple_for_user_id(
registered_user_id, device_id registered_user_id, device_id,
login_submission.get("initial_device_display_name")
) )
) )
result = { result = {

View File

@ -70,6 +70,20 @@ class DeviceRestServlet(RestServlet):
) )
defer.returnValue((200, device)) defer.returnValue((200, device))
@defer.inlineCallbacks
def on_DELETE(self, request, device_id):
# XXX: it's not completely obvious we want to expose this endpoint.
# It allows the client to delete access tokens, which feels like a
# thing which merits extra auth. But if we want to do the interactive-
# auth dance, we should really make it possible to delete more than one
# device at a time.
requester = yield self.auth.get_user_by_req(request)
yield self.device_handler.delete_device(
requester.user.to_string(),
device_id,
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
DevicesRestServlet(hs).register(http_server) DevicesRestServlet(hs).register(http_server)

View File

@ -374,13 +374,13 @@ class RegisterRestServlet(RestServlet):
""" """
device_id = yield self._register_device(user_id, params) device_id = yield self._register_device(user_id, params)
access_token = yield self.auth_handler.issue_access_token( access_token, refresh_token = (
user_id, device_id=device_id yield self.auth_handler.get_login_tuple_for_user_id(
user_id, device_id=device_id,
initial_display_name=params.get("initial_device_display_name")
)
) )
refresh_token = yield self.auth_handler.issue_refresh_token(
user_id, device_id=device_id
)
defer.returnValue({ defer.returnValue({
"user_id": user_id, "user_id": user_id,
"access_token": access_token, "access_token": access_token,

View File

@ -76,6 +76,21 @@ class DeviceStore(SQLBaseStore):
desc="get_device", desc="get_device",
) )
def delete_device(self, user_id, device_id):
"""Delete a device.
Args:
user_id (str): The ID of the user which owns the device
device_id (str): The ID of the device to retrieve
Returns:
defer.Deferred
"""
return self._simple_delete_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
desc="delete_device",
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_devices_by_user(self, user_id): def get_devices_by_user(self, user_id):
"""Retrieve all of a user's registered devices. """Retrieve all of a user's registered devices.

View File

@ -18,18 +18,31 @@ import re
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import StoreError, Codes from synapse.api.errors import StoreError, Codes
from synapse.storage import background_updates
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
class RegistrationStore(SQLBaseStore): class RegistrationStore(background_updates.BackgroundUpdateStore):
def __init__(self, hs): def __init__(self, hs):
super(RegistrationStore, self).__init__(hs) super(RegistrationStore, self).__init__(hs)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.register_background_index_update(
"access_tokens_device_index",
index_name="access_tokens_device_id",
table="access_tokens",
columns=["user_id", "device_id"],
)
self.register_background_index_update(
"refresh_tokens_device_index",
index_name="refresh_tokens_device_id",
table="refresh_tokens",
columns=["user_id", "device_id"],
)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id=None): def add_access_token_to_user(self, user_id, token, device_id=None):
"""Adds an access token for the given user. """Adds an access token for the given user.
@ -238,11 +251,16 @@ class RegistrationStore(SQLBaseStore):
self.get_user_by_id.invalidate((user_id,)) self.get_user_by_id.invalidate((user_id,))
@defer.inlineCallbacks @defer.inlineCallbacks
def user_delete_access_tokens(self, user_id, except_token_ids=[]): def user_delete_access_tokens(self, user_id, except_token_ids=[],
device_id=None):
def f(txn): def f(txn):
sql = "SELECT token FROM access_tokens WHERE user_id = ?" sql = "SELECT token FROM access_tokens WHERE user_id = ?"
clauses = [user_id] clauses = [user_id]
if device_id is not None:
sql += " AND device_id = ?"
clauses.append(device_id)
if except_token_ids: if except_token_ids:
sql += " AND id NOT IN (%s)" % ( sql += " AND id NOT IN (%s)" % (
",".join(["?" for _ in except_token_ids]), ",".join(["?" for _ in except_token_ids]),

View File

@ -0,0 +1,17 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT INTO background_updates (update_name, progress_json) VALUES
('access_tokens_device_index', '{}');

View File

@ -0,0 +1,17 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT INTO background_updates (update_name, progress_json) VALUES
('refresh_tokens_device_index', '{}');

View File

@ -12,11 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse import types
from twisted.internet import defer from twisted.internet import defer
import synapse.api.errors
import synapse.handlers.device import synapse.handlers.device
import synapse.storage import synapse.storage
from synapse import types
from tests import unittest, utils from tests import unittest, utils
user1 = "@boris:aaa" user1 = "@boris:aaa"
@ -27,7 +30,7 @@ class DeviceTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(DeviceTestCase, self).__init__(*args, **kwargs) super(DeviceTestCase, self).__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore self.store = None # type: synapse.storage.DataStore
self.handler = None # type: device.DeviceHandler self.handler = None # type: synapse.handlers.device.DeviceHandler
self.clock = None # type: utils.MockClock self.clock = None # type: utils.MockClock
@defer.inlineCallbacks @defer.inlineCallbacks
@ -123,6 +126,21 @@ class DeviceTestCase(unittest.TestCase):
"last_seen_ts": 3000000, "last_seen_ts": 3000000,
}, res) }, res)
@defer.inlineCallbacks
def test_delete_device(self):
yield self._record_users()
# delete the device
yield self.handler.delete_device(user1, "abc")
# check the device was deleted
with self.assertRaises(synapse.api.errors.NotFoundError):
yield self.handler.get_device(user1, "abc")
# we'd like to check the access token was invalidated, but that's a
# bit of a PITA.
@defer.inlineCallbacks @defer.inlineCallbacks
def _record_users(self): def _record_users(self):
# check this works for both devices which have a recorded client_ip, # check this works for both devices which have a recorded client_ip,

View File

@ -65,13 +65,16 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.registration_handler.appservice_register = Mock( self.registration_handler.appservice_register = Mock(
return_value=user_id return_value=user_id
) )
self.auth_handler.issue_access_token = Mock(return_value=token) self.auth_handler.get_login_tuple_for_user_id = Mock(
return_value=(token, "kermits_refresh_token")
)
(code, result) = yield self.servlet.on_POST(self.request) (code, result) = yield self.servlet.on_POST(self.request)
self.assertEquals(code, 200) self.assertEquals(code, 200)
det_data = { det_data = {
"user_id": user_id, "user_id": user_id,
"access_token": token, "access_token": token,
"refresh_token": "kermits_refresh_token",
"home_server": self.hs.hostname "home_server": self.hs.hostname
} }
self.assertDictContainsSubset(det_data, result) self.assertDictContainsSubset(det_data, result)
@ -121,7 +124,9 @@ class RegisterRestServletTestCase(unittest.TestCase):
"password": "monkey" "password": "monkey"
}, None) }, None)
self.registration_handler.register = Mock(return_value=(user_id, None)) self.registration_handler.register = Mock(return_value=(user_id, None))
self.auth_handler.issue_access_token = Mock(return_value=token) self.auth_handler.get_login_tuple_for_user_id = Mock(
return_value=(token, "kermits_refresh_token")
)
self.device_handler.check_device_registered = \ self.device_handler.check_device_registered = \
Mock(return_value=device_id) Mock(return_value=device_id)
@ -130,13 +135,14 @@ class RegisterRestServletTestCase(unittest.TestCase):
det_data = { det_data = {
"user_id": user_id, "user_id": user_id,
"access_token": token, "access_token": token,
"refresh_token": "kermits_refresh_token",
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"device_id": device_id, "device_id": device_id,
} }
self.assertDictContainsSubset(det_data, result) self.assertDictContainsSubset(det_data, result)
self.assertIn("refresh_token", result) self.assertIn("refresh_token", result)
self.auth_handler.issue_access_token.assert_called_once_with( self.auth_handler.get_login_tuple_for_user_id(
user_id, device_id=device_id) user_id, device_id=device_id, initial_device_display_name=None)
def test_POST_disabled_registration(self): def test_POST_disabled_registration(self):
self.hs.config.enable_registration = False self.hs.config.enable_registration = False