Stop generating refresh tokens

Since we're not doing refresh tokens any more, we should start killing off the
dead code paths. /tokenrefresh itself is a bit of a thornier subject, since
there might be apps out there using it, but we can at least not generate
refresh tokens on new logins.
This commit is contained in:
Richard van der Hoff 2016-11-28 09:52:02 +00:00
parent b6146537d2
commit 5c4edc83b5
4 changed files with 20 additions and 45 deletions

View File

@ -380,12 +380,10 @@ class AuthHandler(BaseHandler):
return self._check_password(user_id, password) return self._check_password(user_id, password)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_login_tuple_for_user_id(self, user_id, device_id=None, def get_access_token_for_user_id(self, user_id, device_id=None,
initial_display_name=None): initial_display_name=None):
""" """
Gets login tuple for the user with the given user ID. Creates a new access token for the user with the given user ID.
Creates a new access/refresh token for the user.
The user is assumed to have been authenticated by some other The user is assumed to have been authenticated by some other
machanism (e.g. CAS), and the user_id converted to the canonical case. machanism (e.g. CAS), and the user_id converted to the canonical case.
@ -400,16 +398,13 @@ class AuthHandler(BaseHandler):
initial_display_name (str): display name to associate with the initial_display_name (str): display name to associate with the
device if it needs re-registering device if it needs re-registering
Returns: Returns:
A tuple of:
The access token for the user's session. The access token for the user's session.
The refresh token for the user's session.
Raises: Raises:
StoreError if there was a problem storing the token. StoreError if there was a problem storing the token.
LoginError if there was an authentication problem. LoginError if there was an authentication problem.
""" """
logger.info("Logging in user %s on device %s", user_id, device_id) logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id)
refresh_token = yield self.issue_refresh_token(user_id, device_id)
# the device *should* have been registered before we got here; however, # the device *should* have been registered before we got here; however,
# it's possible we raced against a DELETE operation. The thing we # it's possible we raced against a DELETE operation. The thing we
@ -420,7 +415,7 @@ class AuthHandler(BaseHandler):
user_id, device_id, initial_display_name user_id, device_id, initial_display_name
) )
defer.returnValue((access_token, refresh_token)) defer.returnValue(access_token)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_user_exists(self, user_id): def check_user_exists(self, user_id):
@ -531,13 +526,6 @@ class AuthHandler(BaseHandler):
device_id) device_id)
defer.returnValue(access_token) defer.returnValue(access_token)
@defer.inlineCallbacks
def issue_refresh_token(self, user_id, device_id=None):
refresh_token = self.generate_refresh_token(user_id)
yield self.store.add_refresh_token_to_user(user_id, refresh_token,
device_id)
defer.returnValue(refresh_token)
def generate_access_token(self, user_id, extra_caveats=None, def generate_access_token(self, user_id, extra_caveats=None,
duration_in_ms=(60 * 60 * 1000)): duration_in_ms=(60 * 60 * 1000)):
extra_caveats = extra_caveats or [] extra_caveats = extra_caveats or []

View File

@ -137,16 +137,13 @@ class LoginRestServlet(ClientV1RestServlet):
password=login_submission["password"], password=login_submission["password"],
) )
device_id = yield self._register_device(user_id, login_submission) device_id = yield self._register_device(user_id, login_submission)
access_token, refresh_token = ( access_token = yield auth_handler.get_access_token_for_user_id(
yield auth_handler.get_login_tuple_for_user_id( user_id, device_id,
user_id, device_id, login_submission.get("initial_device_display_name"),
login_submission.get("initial_device_display_name")
)
) )
result = { result = {
"user_id": user_id, # may have changed "user_id": user_id, # may have changed
"access_token": access_token, "access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"device_id": device_id, "device_id": device_id,
} }
@ -161,16 +158,13 @@ class LoginRestServlet(ClientV1RestServlet):
yield auth_handler.validate_short_term_login_token_and_get_user_id(token) yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
) )
device_id = yield self._register_device(user_id, login_submission) device_id = yield self._register_device(user_id, login_submission)
access_token, refresh_token = ( access_token = yield auth_handler.get_access_token_for_user_id(
yield auth_handler.get_login_tuple_for_user_id( user_id, device_id,
user_id, device_id, login_submission.get("initial_device_display_name"),
login_submission.get("initial_device_display_name")
)
) )
result = { result = {
"user_id": user_id, # may have changed "user_id": user_id, # may have changed
"access_token": access_token, "access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"device_id": device_id, "device_id": device_id,
} }
@ -207,16 +201,14 @@ class LoginRestServlet(ClientV1RestServlet):
device_id = yield self._register_device( device_id = yield self._register_device(
registered_user_id, login_submission registered_user_id, login_submission
) )
access_token, refresh_token = ( access_token = yield auth_handler.get_access_token_for_user_id(
yield auth_handler.get_login_tuple_for_user_id( registered_user_id, device_id,
registered_user_id, device_id, login_submission.get("initial_device_display_name"),
login_submission.get("initial_device_display_name")
)
) )
result = { result = {
"user_id": registered_user_id, "user_id": registered_user_id,
"access_token": access_token, "access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
} }
else: else:

View File

@ -385,8 +385,8 @@ class RegisterRestServlet(RestServlet):
""" """
device_id = yield self._register_device(user_id, params) device_id = yield self._register_device(user_id, params)
access_token, refresh_token = ( access_token = (
yield self.auth_handler.get_login_tuple_for_user_id( yield self.auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id, user_id, device_id=device_id,
initial_display_name=params.get("initial_device_display_name") initial_display_name=params.get("initial_device_display_name")
) )
@ -396,7 +396,6 @@ class RegisterRestServlet(RestServlet):
"user_id": user_id, "user_id": user_id,
"access_token": access_token, "access_token": access_token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"refresh_token": refresh_token,
"device_id": device_id, "device_id": device_id,
}) })

View File

@ -67,8 +67,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.registration_handler.appservice_register = Mock( self.registration_handler.appservice_register = Mock(
return_value=user_id return_value=user_id
) )
self.auth_handler.get_login_tuple_for_user_id = Mock( self.auth_handler.get_access_token_for_user_id = Mock(
return_value=(token, "kermits_refresh_token") return_value=token
) )
(code, result) = yield self.servlet.on_POST(self.request) (code, result) = yield self.servlet.on_POST(self.request)
@ -76,11 +76,9 @@ class RegisterRestServletTestCase(unittest.TestCase):
det_data = { det_data = {
"user_id": user_id, "user_id": user_id,
"access_token": token, "access_token": token,
"refresh_token": "kermits_refresh_token",
"home_server": self.hs.hostname "home_server": self.hs.hostname
} }
self.assertDictContainsSubset(det_data, result) self.assertDictContainsSubset(det_data, result)
self.assertIn("refresh_token", result)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_POST_appservice_registration_invalid(self): def test_POST_appservice_registration_invalid(self):
@ -126,8 +124,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
"password": "monkey" "password": "monkey"
}, None) }, None)
self.registration_handler.register = Mock(return_value=(user_id, None)) self.registration_handler.register = Mock(return_value=(user_id, None))
self.auth_handler.get_login_tuple_for_user_id = Mock( self.auth_handler.get_access_token_for_user_id = Mock(
return_value=(token, "kermits_refresh_token") return_value=token
) )
self.device_handler.check_device_registered = \ self.device_handler.check_device_registered = \
Mock(return_value=device_id) Mock(return_value=device_id)
@ -137,12 +135,10 @@ class RegisterRestServletTestCase(unittest.TestCase):
det_data = { det_data = {
"user_id": user_id, "user_id": user_id,
"access_token": token, "access_token": token,
"refresh_token": "kermits_refresh_token",
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"device_id": device_id, "device_id": device_id,
} }
self.assertDictContainsSubset(det_data, result) self.assertDictContainsSubset(det_data, result)
self.assertIn("refresh_token", result)
self.auth_handler.get_login_tuple_for_user_id( self.auth_handler.get_login_tuple_for_user_id(
user_id, device_id=device_id, initial_device_display_name=None) user_id, device_id=device_id, initial_device_display_name=None)