Merge pull request #1010 from matrix-org/erikj/refactor_deletions

Refactor user_delete_access_tokens. Invalidate get_user_by_access_token to slaves.
This commit is contained in:
Erik Johnston 2016-08-16 11:37:53 +01:00 committed by GitHub
commit 25c2332071
6 changed files with 46 additions and 51 deletions

View File

@ -741,7 +741,7 @@ class AuthHandler(BaseHandler):
def set_password(self, user_id, newpassword, requester=None): def set_password(self, user_id, newpassword, requester=None):
password_hash = self.hash(newpassword) password_hash = self.hash(newpassword)
except_access_token_ids = [requester.access_token_id] if requester else [] except_access_token_id = requester.access_token_id if requester else None
try: try:
yield self.store.user_set_password_hash(user_id, password_hash) yield self.store.user_set_password_hash(user_id, password_hash)
@ -750,10 +750,10 @@ class AuthHandler(BaseHandler):
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e raise e
yield self.store.user_delete_access_tokens( yield self.store.user_delete_access_tokens(
user_id, except_access_token_ids user_id, except_access_token_id
) )
yield self.hs.get_pusherpool().remove_pushers_by_user( yield self.hs.get_pusherpool().remove_pushers_by_user(
user_id, except_access_token_ids user_id, except_access_token_id
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -102,14 +102,14 @@ class PusherPool:
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks @defer.inlineCallbacks
def remove_pushers_by_user(self, user_id, except_token_ids=[]): def remove_pushers_by_user(self, user_id, except_access_token_id=None):
all = yield self.store.get_all_pushers() all = yield self.store.get_all_pushers()
logger.info( logger.info(
"Removing all pushers for user %s except access tokens ids %r", "Removing all pushers for user %s except access tokens id %r",
user_id, except_token_ids user_id, except_access_token_id
) )
for p in all: for p in all:
if p['user_name'] == user_id and p['access_token'] not in except_token_ids: if p['user_name'] == user_id and p['access_token'] != except_access_token_id:
logger.info( logger.info(
"Removing pusher for app id %s, pushkey %s, user %s", "Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name'] p['app_id'], p['pushkey'], p['user_name']

View File

@ -51,6 +51,6 @@ class BaseSlavedStore(SQLBaseStore):
try: try:
getattr(self, cache_func).invalidate(tuple(keys)) getattr(self, cache_func).invalidate(tuple(keys))
except AttributeError: except AttributeError:
logger.warn("Got unexpected cache_func: %r", cache_func) logger.info("Got unexpected cache_func: %r", cache_func)
self._cache_id_gen.advance(int(stream["position"])) self._cache_id_gen.advance(int(stream["position"]))
return defer.succeed(None) return defer.succeed(None)

View File

@ -25,6 +25,6 @@ class SlavedRegistrationStore(BaseSlavedStore):
# TODO: use the cached version and invalidate deleted tokens # TODO: use the cached version and invalidate deleted tokens
get_user_by_access_token = RegistrationStore.__dict__[ get_user_by_access_token = RegistrationStore.__dict__[
"get_user_by_access_token" "get_user_by_access_token"
].orig ]
_query_for_auth = DataStore._query_for_auth.__func__ _query_for_auth = DataStore._query_for_auth.__func__

View File

@ -880,6 +880,7 @@ class SQLBaseStore(object):
ctx = self._cache_id_gen.get_next() ctx = self._cache_id_gen.get_next()
stream_id = ctx.__enter__() stream_id = ctx.__enter__()
txn.call_after(ctx.__exit__, None, None, None) txn.call_after(ctx.__exit__, None, None, None)
txn.call_after(self.hs.get_notifier().on_new_replication_data)
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,

View File

@ -251,7 +251,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
self.get_user_by_id.invalidate((user_id,)) self.get_user_by_id.invalidate((user_id,))
@defer.inlineCallbacks @defer.inlineCallbacks
def user_delete_access_tokens(self, user_id, except_token_ids=[], def user_delete_access_tokens(self, user_id, except_token_id=None,
device_id=None, device_id=None,
delete_refresh_tokens=False): delete_refresh_tokens=False):
""" """
@ -259,7 +259,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
Args: Args:
user_id (str): ID of user the tokens belong to user_id (str): ID of user the tokens belong to
except_token_ids (list[str]): list of access_tokens which should except_token_id (str): list of access_tokens IDs which should
*not* be deleted *not* be deleted
device_id (str|None): ID of device the tokens are associated with. device_id (str|None): ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will If None, tokens associated with any device (or no device) will
@ -269,53 +269,45 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
Returns: Returns:
defer.Deferred: defer.Deferred:
""" """
def f(txn, table, except_tokens, call_after_delete): def f(txn):
sql = "SELECT token FROM %s WHERE user_id = ?" % table keyvalues = {
clauses = [user_id] "user_id": user_id,
}
if device_id is not None: if device_id is not None:
sql += " AND device_id = ?" keyvalues["device_id"] = device_id
clauses.append(device_id)
if except_tokens: if delete_refresh_tokens:
sql += " AND id NOT IN (%s)" % ( self._simple_delete_txn(
",".join(["?" for _ in except_tokens]), txn,
table="refresh_tokens",
keyvalues=keyvalues,
) )
clauses += except_tokens
txn.execute(sql, clauses) items = keyvalues.items()
where_clause = " AND ".join(k + " = ?" for k, _ in items)
rows = txn.fetchall() values = [v for _, v in items]
if except_token_id:
n = 100 where_clause += " AND id != ?"
chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)] values.append(except_token_id)
for chunk in chunks:
if call_after_delete:
for row in chunk:
txn.call_after(call_after_delete, (row[0],))
txn.execute( txn.execute(
"DELETE FROM %s WHERE token in (%s)" % ( "SELECT token FROM access_tokens WHERE %s" % where_clause,
table, values
",".join(["?" for _ in chunk]), )
), [r[0] for r in chunk] rows = self.cursor_to_dict(txn)
for row in rows:
self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (row["token"],)
) )
# delete refresh tokens first, to stop new access tokens being txn.execute(
# allocated while our backs are turned "DELETE FROM access_tokens WHERE %s" % where_clause,
if delete_refresh_tokens: values
yield self.runInteraction(
"user_delete_access_tokens", f,
table="refresh_tokens",
except_tokens=[],
call_after_delete=None,
) )
yield self.runInteraction( yield self.runInteraction(
"user_delete_access_tokens", f, "user_delete_access_tokens", f,
table="access_tokens",
except_tokens=except_token_ids,
call_after_delete=self.get_user_by_access_token.invalidate,
) )
def delete_access_token(self, access_token): def delete_access_token(self, access_token):
@ -328,7 +320,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
}, },
) )
txn.call_after(self.get_user_by_access_token.invalidate, (access_token,)) self._invalidate_cache_and_stream(
txn, self.get_user_by_access_token, (access_token,)
)
return self.runInteraction("delete_access_token", f) return self.runInteraction("delete_access_token", f)