Fix cache invalidation so deleting access tokens (which we did when changing password) actually takes effect without HS restart. Reinstate the code to avoid logging out the session that changed the password, removed in 415c2f0549

This commit is contained in:
David Baker 2016-03-11 13:14:18 +00:00
parent 379c60b08d
commit aa11db5f11
4 changed files with 34 additions and 17 deletions

View File

@ -432,13 +432,18 @@ class AuthHandler(BaseHandler):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def set_password(self, user_id, newpassword): def set_password(self, user_id, newpassword, requester=None):
password_hash = self.hash(newpassword) password_hash = self.hash(newpassword)
except_access_token_ids = [requester.access_token_id] if requester else []
yield self.store.user_set_password_hash(user_id, password_hash) yield self.store.user_set_password_hash(user_id, password_hash)
yield self.store.user_delete_access_tokens(user_id) yield self.store.user_delete_access_tokens_except(
yield self.hs.get_pusherpool().remove_pushers_by_user(user_id) user_id, except_access_token_ids
yield self.store.flush_user(user_id) )
yield self.hs.get_pusherpool().remove_pushers_by_user_except_access_tokens(
user_id, except_access_token_ids
)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at): def add_threepid(self, user_id, medium, address, validated_at):

View File

@ -92,14 +92,14 @@ class PusherPool:
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks @defer.inlineCallbacks
def remove_pushers_by_user(self, user_id): def remove_pushers_by_user_except_access_tokens(self, user_id, except_token_ids):
all = yield self.store.get_all_pushers() all = yield self.store.get_all_pushers()
logger.info( logger.info(
"Removing all pushers for user %s", "Removing all pushers for user %s except access tokens ids %r",
user_id, user_id, except_token_ids
) )
for p in all: for p in all:
if p['user_name'] == user_id: if p['user_name'] == user_id and p['access_token'] not in except_token_ids:
logger.info( logger.info(
"Removing pusher for app id %s, pushkey %s, user %s", "Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name'] p['app_id'], p['pushkey'], p['user_name']

View File

@ -79,7 +79,7 @@ class PasswordRestServlet(RestServlet):
new_password = params['new_password'] new_password = params['new_password']
yield self.auth_handler.set_password( yield self.auth_handler.set_password(
user_id, new_password user_id, new_password, requester
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))

View File

@ -208,14 +208,26 @@ class RegistrationStore(SQLBaseStore):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def flush_user(self, user_id): def user_delete_access_tokens_except(self, user_id, except_token_ids):
rows = yield self._execute( def f(txn):
'flush_user', None, txn.execute(
"SELECT token FROM access_tokens WHERE user_id = ?", "SELECT id, token FROM access_tokens WHERE user_id = ? LIMIT 50",
user_id (user_id,)
) )
rows = txn.fetchall()
for r in rows: for r in rows:
self.get_user_by_access_token.invalidate((r,)) if r[0] in except_token_ids:
continue
txn.call_after(self.get_user_by_access_token.invalidate, (r[1],))
txn.execute(
"DELETE FROM access_tokens WHERE id in (%s)" % ",".join(
["?" for _ in rows]
), [r[0] for r in rows]
)
return len(rows) == 50
while (yield self.runInteraction("user_delete_access_tokens_except", f)):
pass
@cached() @cached()
def get_user_by_access_token(self, token): def get_user_by_access_token(self, token):