diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 9e65d85e6..eaead5080 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -138,8 +138,10 @@ class DeviceHandler(BaseHandler): else: raise - yield self.store.user_delete_access_tokens(user_id, - device_id=device_id) + yield self.store.user_delete_access_tokens( + user_id, device_id=device_id, + delete_refresh_tokens=True, + ) @defer.inlineCallbacks def update_device(self, user_id, device_id, content): diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 935e82bf7..d9555e073 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -252,20 +252,36 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): @defer.inlineCallbacks def user_delete_access_tokens(self, user_id, except_token_ids=[], - device_id=None): - def f(txn): - sql = "SELECT token FROM access_tokens WHERE user_id = ?" + device_id=None, + delete_refresh_tokens=False): + """ + Invalidate access/refresh tokens belonging to a user + + Args: + user_id (str): ID of user the tokens belong to + except_token_ids (list[str]): list of access_tokens which should + *not* be deleted + device_id (str|None): ID of device the tokens are associated with. + If None, tokens associated with any device (or no device) will + be deleted + delete_refresh_tokens (bool): True to delete refresh tokens as + well as access tokens. + Returns: + defer.Deferred: + """ + def f(txn, table, except_tokens, call_after_delete): + sql = "SELECT token FROM %s WHERE user_id = ?" % table clauses = [user_id] if device_id is not None: sql += " AND device_id = ?" clauses.append(device_id) - if except_token_ids: + if except_tokens: sql += " AND id NOT IN (%s)" % ( - ",".join(["?" for _ in except_token_ids]), + ",".join(["?" for _ in except_tokens]), ) - clauses += except_token_ids + clauses += except_tokens txn.execute(sql, clauses) @@ -274,16 +290,33 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): n = 100 chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)] for chunk in chunks: - for row in chunk: - txn.call_after(self.get_user_by_access_token.invalidate, (row[0],)) + if call_after_delete: + for row in chunk: + txn.call_after(call_after_delete, (row[0],)) txn.execute( - "DELETE FROM access_tokens WHERE token in (%s)" % ( + "DELETE FROM %s WHERE token in (%s)" % ( + table, ",".join(["?" for _ in chunk]), ), [r[0] for r in chunk] ) - yield self.runInteraction("user_delete_access_tokens", f) + # delete refresh tokens first, to stop new access tokens being + # allocated while our backs are turned + if delete_refresh_tokens: + yield self.runInteraction( + "user_delete_access_tokens", f, + table="refresh_tokens", + except_tokens=[], + call_after_delete=None, + ) + + yield self.runInteraction( + "user_delete_access_tokens", f, + table="access_tokens", + except_tokens=except_token_ids, + call_after_delete=self.get_user_by_access_token.invalidate, + ) def delete_access_token(self, access_token): def f(txn): @@ -306,9 +339,8 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): Args: token (str): The access token of a user. Returns: - dict: Including the name (user_id) and the ID of their access token. - Raises: - StoreError if no user was found. + defer.Deferred: None, if the token did not match, ootherwise dict + including the keys `name`, `is_guest`, `device_id`, `token_id`. """ return self.runInteraction( "get_user_by_access_token", diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index b03ca303a..f7d74dea8 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -128,6 +128,40 @@ class RegistrationStoreTestCase(unittest.TestCase): with self.assertRaises(StoreError): yield self.store.exchange_refresh_token(last_token, generator.generate) + @defer.inlineCallbacks + def test_user_delete_access_tokens(self): + # add some tokens + generator = TokenGenerator() + refresh_token = generator.generate(self.user_id) + yield self.store.register(self.user_id, self.tokens[0], self.pwhash) + yield self.store.add_access_token_to_user(self.user_id, self.tokens[1], + self.device_id) + yield self.store.add_refresh_token_to_user(self.user_id, refresh_token, + self.device_id) + + # now delete some + yield self.store.user_delete_access_tokens( + self.user_id, device_id=self.device_id, delete_refresh_tokens=True) + + # check they were deleted + user = yield self.store.get_user_by_access_token(self.tokens[1]) + self.assertIsNone(user, "access token was not deleted by device_id") + with self.assertRaises(StoreError): + yield self.store.exchange_refresh_token(refresh_token, + generator.generate) + + # check the one not associated with the device was not deleted + user = yield self.store.get_user_by_access_token(self.tokens[0]) + self.assertEqual(self.user_id, user["name"]) + + # now delete the rest + yield self.store.user_delete_access_tokens( + self.user_id, delete_refresh_tokens=True) + + user = yield self.store.get_user_by_access_token(self.tokens[0]) + self.assertIsNone(user, + "access token was not deleted without device_id") + class TokenGenerator: def __init__(self):