Delete refresh tokens when deleting devices

This commit is contained in:
Richard van der Hoff 2016-07-26 11:09:47 +01:00
parent d34e9f93b7
commit 8e02494166
3 changed files with 83 additions and 15 deletions

View File

@ -138,8 +138,10 @@ class DeviceHandler(BaseHandler):
else: else:
raise raise
yield self.store.user_delete_access_tokens(user_id, yield self.store.user_delete_access_tokens(
device_id=device_id) user_id, device_id=device_id,
delete_refresh_tokens=True,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def update_device(self, user_id, device_id, content): def update_device(self, user_id, device_id, content):

View File

@ -252,20 +252,36 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def user_delete_access_tokens(self, user_id, except_token_ids=[], def user_delete_access_tokens(self, user_id, except_token_ids=[],
device_id=None): device_id=None,
def f(txn): delete_refresh_tokens=False):
sql = "SELECT token FROM access_tokens WHERE user_id = ?" """
Invalidate access/refresh tokens belonging to a user
Args:
user_id (str): ID of user the tokens belong to
except_token_ids (list[str]): list of access_tokens which should
*not* be deleted
device_id (str|None): ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
delete_refresh_tokens (bool): True to delete refresh tokens as
well as access tokens.
Returns:
defer.Deferred:
"""
def f(txn, table, except_tokens, call_after_delete):
sql = "SELECT token FROM %s WHERE user_id = ?" % table
clauses = [user_id] clauses = [user_id]
if device_id is not None: if device_id is not None:
sql += " AND device_id = ?" sql += " AND device_id = ?"
clauses.append(device_id) clauses.append(device_id)
if except_token_ids: if except_tokens:
sql += " AND id NOT IN (%s)" % ( sql += " AND id NOT IN (%s)" % (
",".join(["?" for _ in except_token_ids]), ",".join(["?" for _ in except_tokens]),
) )
clauses += except_token_ids clauses += except_tokens
txn.execute(sql, clauses) txn.execute(sql, clauses)
@ -274,16 +290,33 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
n = 100 n = 100
chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)] chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
for chunk in chunks: for chunk in chunks:
if call_after_delete:
for row in chunk: for row in chunk:
txn.call_after(self.get_user_by_access_token.invalidate, (row[0],)) txn.call_after(call_after_delete, (row[0],))
txn.execute( txn.execute(
"DELETE FROM access_tokens WHERE token in (%s)" % ( "DELETE FROM %s WHERE token in (%s)" % (
table,
",".join(["?" for _ in chunk]), ",".join(["?" for _ in chunk]),
), [r[0] for r in chunk] ), [r[0] for r in chunk]
) )
yield self.runInteraction("user_delete_access_tokens", f) # delete refresh tokens first, to stop new access tokens being
# allocated while our backs are turned
if delete_refresh_tokens:
yield self.runInteraction(
"user_delete_access_tokens", f,
table="refresh_tokens",
except_tokens=[],
call_after_delete=None,
)
yield self.runInteraction(
"user_delete_access_tokens", f,
table="access_tokens",
except_tokens=except_token_ids,
call_after_delete=self.get_user_by_access_token.invalidate,
)
def delete_access_token(self, access_token): def delete_access_token(self, access_token):
def f(txn): def f(txn):
@ -306,9 +339,8 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
Args: Args:
token (str): The access token of a user. token (str): The access token of a user.
Returns: Returns:
dict: Including the name (user_id) and the ID of their access token. defer.Deferred: None, if the token did not match, ootherwise dict
Raises: including the keys `name`, `is_guest`, `device_id`, `token_id`.
StoreError if no user was found.
""" """
return self.runInteraction( return self.runInteraction(
"get_user_by_access_token", "get_user_by_access_token",

View File

@ -128,6 +128,40 @@ class RegistrationStoreTestCase(unittest.TestCase):
with self.assertRaises(StoreError): with self.assertRaises(StoreError):
yield self.store.exchange_refresh_token(last_token, generator.generate) yield self.store.exchange_refresh_token(last_token, generator.generate)
@defer.inlineCallbacks
def test_user_delete_access_tokens(self):
# add some tokens
generator = TokenGenerator()
refresh_token = generator.generate(self.user_id)
yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
yield self.store.add_access_token_to_user(self.user_id, self.tokens[1],
self.device_id)
yield self.store.add_refresh_token_to_user(self.user_id, refresh_token,
self.device_id)
# now delete some
yield self.store.user_delete_access_tokens(
self.user_id, device_id=self.device_id, delete_refresh_tokens=True)
# check they were deleted
user = yield self.store.get_user_by_access_token(self.tokens[1])
self.assertIsNone(user, "access token was not deleted by device_id")
with self.assertRaises(StoreError):
yield self.store.exchange_refresh_token(refresh_token,
generator.generate)
# check the one not associated with the device was not deleted
user = yield self.store.get_user_by_access_token(self.tokens[0])
self.assertEqual(self.user_id, user["name"])
# now delete the rest
yield self.store.user_delete_access_tokens(
self.user_id, delete_refresh_tokens=True)
user = yield self.store.get_user_by_access_token(self.tokens[0])
self.assertIsNone(user,
"access token was not deleted without device_id")
class TokenGenerator: class TokenGenerator:
def __init__(self): def __init__(self):