Fix slow performance of /logout in some cases where refresh tokens are in use. The slowness existed since the initial implementation of refresh tokens. (#12056)

This commit is contained in:
reivilibre 2022-02-22 13:29:04 +00:00 committed by GitHub
parent 6a1bad511d
commit 235d2916ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 136 additions and 4 deletions

1
changelog.d/12056.bugfix Normal file
View File

@ -0,0 +1 @@
Fix slow performance of `/logout` in some cases where refresh tokens are in use. The slowness existed since the initial implementation of refresh tokens.

View File

@ -1681,7 +1681,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
user_id=row[1], user_id=row[1],
device_id=row[2], device_id=row[2],
next_token_id=row[3], next_token_id=row[3],
has_next_refresh_token_been_refreshed=row[4], # SQLite returns 0 or 1 for false/true, so convert to a bool.
has_next_refresh_token_been_refreshed=bool(row[4]),
# This column is nullable, ensure it's a boolean # This column is nullable, ensure it's a boolean
has_next_access_token_been_used=(row[5] or False), has_next_access_token_been_used=(row[5] or False),
expiry_ts=row[6], expiry_ts=row[6],
@ -1697,12 +1698,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
Set the successor of a refresh token, removing the existing successor Set the successor of a refresh token, removing the existing successor
if any. if any.
This also deletes the predecessor refresh and access tokens,
since they cannot be valid anymore.
Args: Args:
token_id: ID of the refresh token to update. token_id: ID of the refresh token to update.
next_token_id: ID of its successor. next_token_id: ID of its successor.
""" """
def _replace_refresh_token_txn(txn) -> None: def _replace_refresh_token_txn(txn: LoggingTransaction) -> None:
# First check if there was an existing refresh token # First check if there was an existing refresh token
old_next_token_id = self.db_pool.simple_select_one_onecol_txn( old_next_token_id = self.db_pool.simple_select_one_onecol_txn(
txn, txn,
@ -1728,6 +1732,16 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
{"id": old_next_token_id}, {"id": old_next_token_id},
) )
# Delete the previous refresh token, since we only want to keep the
# last 2 refresh tokens in the database.
# (The predecessor of the latest refresh token is still useful in
# case the refresh was interrupted and the client re-uses the old
# one.)
# This cascades to delete the associated access token.
self.db_pool.simple_delete_txn(
txn, "refresh_tokens", {"next_token_id": token_id}
)
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"replace_refresh_token", _replace_refresh_token_txn "replace_refresh_token", _replace_refresh_token_txn
) )

View File

@ -0,0 +1,28 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- next_token_id is a foreign key reference, so previously required a table scan
-- when a row in the referenced table was deleted.
-- As it was self-referential and cascaded deletes, this led to O(t*n) time to
-- delete a row, where t: number of rows in the table and n: number of rows in
-- the ancestral 'chain' of access tokens.
--
-- This index is partial since we only require it for rows which reference
-- another.
-- Performance was tested to be the same regardless of whether the index was
-- full or partial, but a partial index can be smaller.
CREATE INDEX refresh_tokens_next_token_id
ON refresh_tokens(next_token_id)
WHERE next_token_id IS NOT NULL;

View File

@ -13,15 +13,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from http import HTTPStatus from http import HTTPStatus
from typing import Optional, Union from typing import Optional, Tuple, Union
from twisted.internet.defer import succeed from twisted.internet.defer import succeed
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.rest.client import account, auth, devices, login, register from synapse.rest.client import account, auth, devices, login, logout, register
from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, UserID
from tests import unittest from tests import unittest
@ -527,6 +528,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
auth.register_servlets, auth.register_servlets,
account.register_servlets, account.register_servlets,
login.register_servlets, login.register_servlets,
logout.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource, synapse.rest.admin.register_servlets_for_client_rest_resource,
register.register_servlets, register.register_servlets,
] ]
@ -984,3 +986,90 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
self.assertEqual( self.assertEqual(
fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result
) )
def test_many_token_refresh(self):
"""
If a refresh is performed many times during a session, there shouldn't be
extra 'cruft' built up over time.
This test was written specifically to troubleshoot a case where logout
was very slow if a lot of refreshes had been performed for the session.
"""
def _refresh(refresh_token: str) -> Tuple[str, str]:
"""
Performs one refresh, returning the next refresh token and access token.
"""
refresh_response = self.use_refresh_token(refresh_token)
self.assertEqual(
refresh_response.code, HTTPStatus.OK, refresh_response.result
)
return (
refresh_response.json_body["refresh_token"],
refresh_response.json_body["access_token"],
)
def _table_length(table_name: str) -> int:
"""
Helper to get the size of a table, in rows.
For testing only; trivially vulnerable to SQL injection.
"""
def _txn(txn: LoggingTransaction) -> int:
txn.execute(f"SELECT COUNT(1) FROM {table_name}")
row = txn.fetchone()
# Query is infallible
assert row is not None
return row[0]
return self.get_success(
self.hs.get_datastores().main.db_pool.runInteraction(
"_table_length", _txn
)
)
# Before we log in, there are no access tokens.
self.assertEqual(_table_length("access_tokens"), 0)
self.assertEqual(_table_length("refresh_tokens"), 0)
body = {
"type": "m.login.password",
"user": "test",
"password": self.user_pass,
"refresh_token": True,
}
login_response = self.make_request(
"POST",
"/_matrix/client/v3/login",
body,
)
self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result)
access_token = login_response.json_body["access_token"]
refresh_token = login_response.json_body["refresh_token"]
# Now that we have logged in, there should be one access token and one
# refresh token
self.assertEqual(_table_length("access_tokens"), 1)
self.assertEqual(_table_length("refresh_tokens"), 1)
for _ in range(5):
refresh_token, access_token = _refresh(refresh_token)
# After 5 sequential refreshes, there should only be the latest two
# refresh/access token pairs.
# (The last one is preserved because it's in use!
# The one before that is preserved because it can still be used to
# replace the last token pair, in case of e.g. a network interruption.)
self.assertEqual(_table_length("access_tokens"), 2)
self.assertEqual(_table_length("refresh_tokens"), 2)
logout_response = self.make_request(
"POST", "/_matrix/client/v3/logout", {}, access_token=access_token
)
self.assertEqual(logout_response.code, HTTPStatus.OK, logout_response.result)
# Now that we have logged in, there should be no access token
# and no refresh token
self.assertEqual(_table_length("access_tokens"), 0)
self.assertEqual(_table_length("refresh_tokens"), 0)