Delete devices in various logout situations

Make sure that we delete devices whenever a user is logged out due to any of
the following situations:

 * /logout
 * /logout_all
 * change password
 * deactivate account (by the user or by an admin)
 * invalidate access token from a dynamic module

Fixes #2672.
This commit is contained in:
Richard van der Hoff 2017-11-29 15:44:59 +00:00
parent ae31f8ce45
commit ad7e570d07
5 changed files with 75 additions and 5 deletions

View File

@ -26,6 +26,7 @@ class DeactivateAccountHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(DeactivateAccountHandler, self).__init__(hs) super(DeactivateAccountHandler, self).__init__(hs)
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def deactivate_account(self, user_id): def deactivate_account(self, user_id):
@ -39,6 +40,13 @@ class DeactivateAccountHandler(BaseHandler):
""" """
# FIXME: Theoretically there is a race here wherein user resets # FIXME: Theoretically there is a race here wherein user resets
# password using threepid. # password using threepid.
# first delete any devices belonging to the user, which will also
# delete corresponding access tokens.
yield self._device_handler.delete_all_devices_for_user(user_id)
# then delete any remaining access tokens which weren't associated with
# a device.
yield self._auth_handler.delete_access_tokens_for_user(user_id) yield self._auth_handler.delete_access_tokens_for_user(user_id)
yield self.store.user_delete_threepids(user_id) yield self.store.user_delete_threepids(user_id)
yield self.store.user_set_password_hash(user_id, None) yield self.store.user_set_password_hash(user_id, None)

View File

@ -170,13 +170,31 @@ class DeviceHandler(BaseHandler):
yield self.notify_device_update(user_id, [device_id]) yield self.notify_device_update(user_id, [device_id])
@defer.inlineCallbacks
def delete_all_devices_for_user(self, user_id, except_device_id=None):
"""Delete all of the user's devices
Args:
user_id (str):
except_device_id (str|None): optional device id which should not
be deleted
Returns:
defer.Deferred:
"""
device_map = yield self.store.get_devices_by_user(user_id)
device_ids = device_map.keys()
if except_device_id is not None:
device_ids = [d for d in device_ids if d != except_device_id]
yield self.delete_devices(user_id, device_ids)
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_devices(self, user_id, device_ids): def delete_devices(self, user_id, device_ids):
""" Delete several devices """ Delete several devices
Args: Args:
user_id (str): user_id (str):
device_ids (str): The list of device IDs to delete device_ids (List[str]): The list of device IDs to delete
Returns: Returns:
defer.Deferred: defer.Deferred:

View File

@ -27,11 +27,13 @@ class SetPasswordHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(SetPasswordHandler, self).__init__(hs) super(SetPasswordHandler, self).__init__(hs)
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def set_password(self, user_id, newpassword, requester=None): def set_password(self, user_id, newpassword, requester=None):
password_hash = self._auth_handler.hash(newpassword) password_hash = self._auth_handler.hash(newpassword)
except_device_id = requester.device_id if requester else None
except_access_token_id = requester.access_token_id if requester else None except_access_token_id = requester.access_token_id if requester else None
try: try:
@ -40,6 +42,15 @@ class SetPasswordHandler(BaseHandler):
if e.code == 404: if e.code == 404:
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e raise e
# we want to log out all of the user's other sessions. First delete
# all his other devices.
yield self._device_handler.delete_all_devices_for_user(
user_id, except_device_id=except_device_id,
)
# and now delete any access tokens which weren't associated with
# devices (or were associated with this device).
yield self._auth_handler.delete_access_tokens_for_user( yield self._auth_handler.delete_access_tokens_for_user(
user_id, except_token_id=except_access_token_id, user_id, except_token_id=except_access_token_id,
) )

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.types import UserID from synapse.types import UserID
@ -81,6 +82,7 @@ class ModuleApi(object):
reg = self.hs.get_handlers().registration_handler reg = self.hs.get_handlers().registration_handler
return reg.register(localpart=localpart) return reg.register(localpart=localpart)
@defer.inlineCallbacks
def invalidate_access_token(self, access_token): def invalidate_access_token(self, access_token):
"""Invalidate an access token for a user """Invalidate an access token for a user
@ -94,8 +96,16 @@ class ModuleApi(object):
Raises: Raises:
synapse.api.errors.AuthError: the access token is invalid synapse.api.errors.AuthError: the access token is invalid
""" """
# see if the access token corresponds to a device
return self._auth_handler.delete_access_token(access_token) user_info = yield self._auth.get_user_by_access_token(access_token)
device_id = user_info.get("device_id")
user_id = user_info["user"].to_string()
if device_id:
# delete the device, which will also delete its access tokens
yield self.hs.get_device_handler().delete_device(user_id, device_id)
else:
# no associated device. Just delete the access token.
yield self._auth_handler.delete_access_token(access_token)
def run_db_interaction(self, desc, func, *args, **kwargs): def run_db_interaction(self, desc, func, *args, **kwargs):
"""Run a function with a database connection """Run a function with a database connection

View File

@ -16,6 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.auth import get_access_token_from_request from synapse.api.auth import get_access_token_from_request
from synapse.api.errors import AuthError
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
@ -30,15 +31,30 @@ class LogoutRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(LogoutRestServlet, self).__init__(hs) super(LogoutRestServlet, self).__init__(hs)
self._auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
def on_OPTIONS(self, request): def on_OPTIONS(self, request):
return (200, {}) return (200, {})
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
access_token = get_access_token_from_request(request) try:
yield self._auth_handler.delete_access_token(access_token) requester = yield self.auth.get_user_by_req(request)
except AuthError:
# this implies the access token has already been deleted.
pass
else:
if requester.device_id is None:
# the acccess token wasn't associated with a device.
# Just delete the access token
access_token = get_access_token_from_request(request)
yield self._auth_handler.delete_access_token(access_token)
else:
yield self._device_handler.delete_device(
requester.user.to_string(), requester.device_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -49,6 +65,7 @@ class LogoutAllRestServlet(ClientV1RestServlet):
super(LogoutAllRestServlet, self).__init__(hs) super(LogoutAllRestServlet, self).__init__(hs)
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
def on_OPTIONS(self, request): def on_OPTIONS(self, request):
return (200, {}) return (200, {})
@ -57,6 +74,12 @@ class LogoutAllRestServlet(ClientV1RestServlet):
def on_POST(self, request): def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
# first delete all of the user's devices
yield self._device_handler.delete_all_devices_for_user(user_id)
# .. and then delete any access tokens which weren't associated with
# devices.
yield self._auth_handler.delete_access_tokens_for_user(user_id) yield self._auth_handler.delete_access_tokens_for_user(user_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))