Merge pull request #240 from matrix-org/refresh

/tokenrefresh POST endpoint
This commit is contained in:
Daniel Wagner-Hall 2015-08-20 17:44:46 +01:00
commit b1e35eabf2
19 changed files with 303 additions and 76 deletions

View File

@ -361,7 +361,7 @@ class Auth(object):
except KeyError: except KeyError:
pass # normal users won't have the user_id query parameter set. pass # normal users won't have the user_id query parameter set.
user_info = yield self.get_user_by_token(access_token) user_info = yield self.get_user_by_access_token(access_token)
user = user_info["user"] user = user_info["user"]
device_id = user_info["device_id"] device_id = user_info["device_id"]
token_id = user_info["token_id"] token_id = user_info["token_id"]
@ -390,7 +390,7 @@ class Auth(object):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_by_token(self, token): def get_user_by_access_token(self, token):
""" Get a registered user's ID. """ Get a registered user's ID.
Args: Args:
@ -401,7 +401,7 @@ class Auth(object):
Raises: Raises:
AuthError if no user by that token exists or the token is invalid. AuthError if no user by that token exists or the token is invalid.
""" """
ret = yield self.store.get_user_by_token(token) ret = yield self.store.get_user_by_access_token(token)
if not ret: if not ret:
raise AuthError( raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.", self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",

View File

@ -26,6 +26,7 @@ from twisted.web.client import PartialDownloadError
import logging import logging
import bcrypt import bcrypt
import pymacaroons
import simplejson import simplejson
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
@ -278,18 +279,18 @@ class AuthHandler(BaseHandler):
user_id (str): User ID user_id (str): User ID
password (str): Password password (str): Password
Returns: Returns:
The access token for the user's session. A tuple of:
The access token for the user's session.
The refresh token for the user's session.
Raises: Raises:
StoreError if there was a problem storing the token. StoreError if there was a problem storing the token.
LoginError if there was an authentication problem. LoginError if there was an authentication problem.
""" """
yield self._check_password(user_id, password) yield self._check_password(user_id, password)
reg_handler = self.hs.get_handlers().registration_handler
access_token = reg_handler.generate_token(user_id)
logger.info("Logging in user %s", user_id) logger.info("Logging in user %s", user_id)
yield self.store.add_access_token_to_user(user_id, access_token) access_token = yield self.issue_access_token(user_id)
defer.returnValue(access_token) refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((access_token, refresh_token))
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_password(self, user_id, password): def _check_password(self, user_id, password):
@ -304,6 +305,45 @@ class AuthHandler(BaseHandler):
logger.warn("Failed password login for user %s", user_id) logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN) raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
def issue_access_token(self, user_id):
access_token = self.generate_access_token(user_id)
yield self.store.add_access_token_to_user(user_id, access_token)
defer.returnValue(access_token)
@defer.inlineCallbacks
def issue_refresh_token(self, user_id):
refresh_token = self.generate_refresh_token(user_id)
yield self.store.add_refresh_token_to_user(user_id, refresh_token)
defer.returnValue(refresh_token)
def generate_access_token(self, user_id):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access")
now = self.hs.get_clock().time_msec()
expiry = now + (60 * 60 * 1000)
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def generate_refresh_token(self, user_id):
m = self._generate_base_macaroon(user_id)
m.add_first_party_caveat("type = refresh")
# Important to add a nonce, because otherwise every refresh token for a
# user will be the same.
m.add_first_party_caveat("nonce = %s" % (
stringutils.random_string_with_symbols(16),
))
return m.serialize()
def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
@defer.inlineCallbacks @defer.inlineCallbacks
def set_password(self, user_id, newpassword): def set_password(self, user_id, newpassword):
password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt()) password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())

View File

@ -27,7 +27,6 @@ from synapse.http.client import CaptchaServerHttpClient
import bcrypt import bcrypt
import logging import logging
import pymacaroons
import urllib import urllib
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -91,7 +90,7 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
token = self.generate_token(user_id) token = self.auth_handler().generate_access_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token, token=token,
@ -111,7 +110,7 @@ class RegistrationHandler(BaseHandler):
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_is_valid(user_id) yield self.check_user_id_is_valid(user_id)
token = self.generate_token(user_id) token = self.auth_handler().generate_access_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token, token=token,
@ -161,7 +160,7 @@ class RegistrationHandler(BaseHandler):
400, "Invalid user localpart for this application service.", 400, "Invalid user localpart for this application service.",
errcode=Codes.EXCLUSIVE errcode=Codes.EXCLUSIVE
) )
token = self.generate_token(user_id) token = self.auth_handler().generate_access_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token, token=token,
@ -208,7 +207,7 @@ class RegistrationHandler(BaseHandler):
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_is_valid(user_id) yield self.check_user_id_is_valid(user_id)
token = self.generate_token(user_id) token = self.auth_handler().generate_access_token(user_id)
try: try:
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
@ -273,20 +272,6 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE errcode=Codes.EXCLUSIVE
) )
def generate_token(self, user_id):
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
macaroon.add_first_party_caveat("type = access")
now = self.hs.get_clock().time_msec()
expiry = now + (60 * 60 * 1000)
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def _generate_user_id(self): def _generate_user_id(self):
return "-" + stringutils.random_string(18) return "-" + stringutils.random_string(18)
@ -329,3 +314,6 @@ class RegistrationHandler(BaseHandler):
} }
) )
defer.returnValue(data) defer.returnValue(data)
def auth_handler(self):
return self.hs.get_handlers().auth_handler

View File

@ -85,13 +85,15 @@ class LoginRestServlet(ClientV1RestServlet):
user_id = UserID.create( user_id = UserID.create(
user_id, self.hs.hostname).to_string() user_id, self.hs.hostname).to_string()
token = yield self.handlers.auth_handler.login_with_password( auth_handler = self.handlers.auth_handler
access_token, refresh_token = yield auth_handler.login_with_password(
user_id=user_id, user_id=user_id,
password=login_submission["password"]) password=login_submission["password"])
result = { result = {
"user_id": user_id, # may have changed "user_id": login_submission["user"], # may have changed
"access_token": token, "access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
} }

View File

@ -21,6 +21,7 @@ from . import (
auth, auth,
receipts, receipts,
keys, keys,
tokenrefresh,
) )
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
@ -42,3 +43,4 @@ class ClientV2AlphaRestResource(JsonResource):
auth.register_servlets(hs, client_resource) auth.register_servlets(hs, client_resource)
receipts.register_servlets(hs, client_resource) receipts.register_servlets(hs, client_resource)
keys.register_servlets(hs, client_resource) keys.register_servlets(hs, client_resource)
tokenrefresh.register_servlets(hs, client_resource)

View File

@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern, parse_json_dict_from_request
class TokenRefreshRestServlet(RestServlet):
"""
Exchanges refresh tokens for a pair of an access token and a new refresh
token.
"""
PATTERN = client_v2_pattern("/tokenrefresh")
def __init__(self, hs):
super(TokenRefreshRestServlet, self).__init__()
self.hs = hs
self.store = hs.get_datastore()
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_dict_from_request(request)
try:
old_refresh_token = body["refresh_token"]
auth_handler = self.hs.get_handlers().auth_handler
(user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
old_refresh_token, auth_handler.generate_refresh_token)
new_access_token = yield auth_handler.issue_access_token(user_id)
defer.returnValue((200, {
"access_token": new_access_token,
"refresh_token": new_refresh_token,
}))
except KeyError:
raise SynapseError(400, "Missing required key 'refresh_token'.")
except StoreError:
raise AuthError(403, "Did not recognize refresh token")
def register_servlets(hs, http_server):
TokenRefreshRestServlet(hs).register(http_server)

View File

@ -54,7 +54,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 22 SCHEMA_VERSION = 23
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))

View File

@ -181,6 +181,7 @@ class SQLBaseStore(object):
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
self._state_groups_id_gen = IdGenerator("state_groups", "id", self) self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
self._pushers_id_gen = IdGenerator("pushers", "id", self) self._pushers_id_gen = IdGenerator("pushers", "id", self)
self._push_rule_id_gen = IdGenerator("push_rules", "id", self) self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)

View File

@ -50,6 +50,28 @@ class RegistrationStore(SQLBaseStore):
desc="add_access_token_to_user", desc="add_access_token_to_user",
) )
@defer.inlineCallbacks
def add_refresh_token_to_user(self, user_id, token):
"""Adds a refresh token for the given user.
Args:
user_id (str): The user ID.
token (str): The new refresh token to add.
Raises:
StoreError if there was a problem adding this.
"""
next_id = yield self._refresh_tokens_id_gen.get_next()
yield self._simple_insert(
"refresh_tokens",
{
"id": next_id,
"user_id": user_id,
"token": token
},
desc="add_refresh_token_to_user",
)
@defer.inlineCallbacks @defer.inlineCallbacks
def register(self, user_id, token, password_hash): def register(self, user_id, token, password_hash):
"""Attempts to register an account. """Attempts to register an account.
@ -132,10 +154,10 @@ class RegistrationStore(SQLBaseStore):
user_id user_id
) )
for r in rows: for r in rows:
self.get_user_by_token.invalidate((r,)) self.get_user_by_access_token.invalidate((r,))
@cached() @cached()
def get_user_by_token(self, token): def get_user_by_access_token(self, token):
"""Get a user from the given access token. """Get a user from the given access token.
Args: Args:
@ -147,11 +169,51 @@ class RegistrationStore(SQLBaseStore):
StoreError if no user was found. StoreError if no user was found.
""" """
return self.runInteraction( return self.runInteraction(
"get_user_by_token", "get_user_by_access_token",
self._query_for_auth, self._query_for_auth,
token token
) )
def exchange_refresh_token(self, refresh_token, token_generator):
"""Exchange a refresh token for a new access token and refresh token.
Doing so invalidates the old refresh token - refresh tokens are single
use.
Args:
token (str): The refresh token of a user.
token_generator (fn: str -> str): Function which, when given a
user ID, returns a unique refresh token for that user. This
function must never return the same value twice.
Returns:
tuple of (user_id, refresh_token)
Raises:
StoreError if no user was found with that refresh token.
"""
return self.runInteraction(
"exchange_refresh_token",
self._exchange_refresh_token,
refresh_token,
token_generator
)
def _exchange_refresh_token(self, txn, old_token, token_generator):
sql = "SELECT user_id FROM refresh_tokens WHERE token = ?"
txn.execute(sql, (old_token,))
rows = self.cursor_to_dict(txn)
if not rows:
raise StoreError(403, "Did not recognize refresh token")
user_id = rows[0]["user_id"]
# TODO(danielwh): Maybe perform a validation on the macaroon that
# macaroon.user_id == user_id.
new_token = token_generator(user_id)
sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?"
txn.execute(sql, (new_token, old_token,))
return user_id, new_token
@defer.inlineCallbacks @defer.inlineCallbacks
def is_server_admin(self, user): def is_server_admin(self, user):
res = yield self._simple_select_one_onecol( res = yield self._simple_select_one_onecol(

View File

@ -0,0 +1,21 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS refresh_tokens(
id INTEGER PRIMARY KEY AUTOINCREMENT,
token TEXT NOT NULL,
user_id TEXT NOT NULL,
UNIQUE (token)
);

View File

@ -44,7 +44,7 @@ class AuthTestCase(unittest.TestCase):
"token_id": "ditto", "token_id": "ditto",
"admin": False "admin": False
} }
self.store.get_user_by_token = Mock(return_value=user_info) self.store.get_user_by_access_token = Mock(return_value=user_info)
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
@ -54,7 +54,7 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_by_req_user_bad_token(self): def test_get_user_by_req_user_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None) self.store.get_app_service_by_token = Mock(return_value=None)
self.store.get_user_by_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
@ -70,7 +70,7 @@ class AuthTestCase(unittest.TestCase):
"token_id": "ditto", "token_id": "ditto",
"admin": False "admin": False
} }
self.store.get_user_by_token = Mock(return_value=user_info) self.store.get_user_by_access_token = Mock(return_value=user_info)
request = Mock(args={}) request = Mock(args={})
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = Mock(return_value=[""])
@ -81,7 +81,7 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_by_req_appservice_valid_token(self): def test_get_user_by_req_appservice_valid_token(self):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user) app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
@ -91,7 +91,7 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_by_req_appservice_bad_token(self): def test_get_user_by_req_appservice_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None) self.store.get_app_service_by_token = Mock(return_value=None)
self.store.get_user_by_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
@ -102,7 +102,7 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_by_req_appservice_missing_token(self): def test_get_user_by_req_appservice_missing_token(self):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user) app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={}) request = Mock(args={})
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = Mock(return_value=[""])
@ -115,7 +115,7 @@ class AuthTestCase(unittest.TestCase):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user) app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
app_service.is_interested_in_user = Mock(return_value=True) app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
@ -129,7 +129,7 @@ class AuthTestCase(unittest.TestCase):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user) app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
app_service.is_interested_in_user = Mock(return_value=False) app_service.is_interested_in_user = Mock(return_value=False)
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=None)
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]

View File

@ -16,27 +16,27 @@
import pymacaroons import pymacaroons
from mock import Mock, NonCallableMock from mock import Mock, NonCallableMock
from synapse.handlers.register import RegistrationHandler from synapse.handlers.auth import AuthHandler
from tests import unittest from tests import unittest
from tests.utils import setup_test_homeserver from tests.utils import setup_test_homeserver
from twisted.internet import defer from twisted.internet import defer
class RegisterHandlers(object): class AuthHandlers(object):
def __init__(self, hs): def __init__(self, hs):
self.registration_handler = RegistrationHandler(hs) self.auth_handler = AuthHandler(hs)
class RegisterTestCase(unittest.TestCase): class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
self.hs = yield setup_test_homeserver(handlers=None) self.hs = yield setup_test_homeserver(handlers=None)
self.hs.handlers = RegisterHandlers(self.hs) self.hs.handlers = AuthHandlers(self.hs)
def test_token_is_a_macaroon(self): def test_token_is_a_macaroon(self):
self.hs.config.macaroon_secret_key = "this key is a huge secret" self.hs.config.macaroon_secret_key = "this key is a huge secret"
token = self.hs.handlers.registration_handler.generate_token("some_user") token = self.hs.handlers.auth_handler.generate_access_token("some_user")
# Check that we can parse the thing with pymacaroons # Check that we can parse the thing with pymacaroons
macaroon = pymacaroons.Macaroon.deserialize(token) macaroon = pymacaroons.Macaroon.deserialize(token)
# The most basic of sanity checks # The most basic of sanity checks
@ -47,7 +47,7 @@ class RegisterTestCase(unittest.TestCase):
self.hs.config.macaroon_secret_key = "this key is a massive secret" self.hs.config.macaroon_secret_key = "this key is a massive secret"
self.hs.clock.now = 5000 self.hs.clock.now = 5000
token = self.hs.handlers.registration_handler.generate_token("a_user") token = self.hs.handlers.auth_handler.generate_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token) macaroon = pymacaroons.Macaroon.deserialize(token)
def verify_gen(caveat): def verify_gen(caveat):

View File

@ -70,7 +70,7 @@ class PresenceStateTestCase(unittest.TestCase):
return defer.succeed([]) return defer.succeed([])
self.datastore.get_presence_list = get_presence_list self.datastore.get_presence_list = get_presence_list
def _get_user_by_token(token=None): def _get_user_by_access_token(token=None):
return { return {
"user": UserID.from_string(myid), "user": UserID.from_string(myid),
"admin": False, "admin": False,
@ -78,7 +78,7 @@ class PresenceStateTestCase(unittest.TestCase):
"token_id": 1, "token_id": 1,
} }
hs.get_v1auth().get_user_by_token = _get_user_by_token hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
room_member_handler = hs.handlers.room_member_handler = Mock( room_member_handler = hs.handlers.room_member_handler = Mock(
spec=[ spec=[
@ -159,7 +159,7 @@ class PresenceListTestCase(unittest.TestCase):
) )
self.datastore.has_presence_state = has_presence_state self.datastore.has_presence_state = has_presence_state
def _get_user_by_token(token=None): def _get_user_by_access_token(token=None):
return { return {
"user": UserID.from_string(myid), "user": UserID.from_string(myid),
"admin": False, "admin": False,
@ -173,7 +173,7 @@ class PresenceListTestCase(unittest.TestCase):
] ]
) )
hs.get_v1auth().get_user_by_token = _get_user_by_token hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
presence.register_servlets(hs, self.mock_resource) presence.register_servlets(hs, self.mock_resource)

View File

@ -54,14 +54,14 @@ class RoomPermissionsTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_token(token=None): def _get_user_by_access_token(token=None):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1, "token_id": 1,
} }
hs.get_v1auth().get_user_by_token = _get_user_by_token hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -441,14 +441,14 @@ class RoomsMemberListTestCase(RestTestCase):
self.auth_user_id = self.user_id self.auth_user_id = self.user_id
def _get_user_by_token(token=None): def _get_user_by_access_token(token=None):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1, "token_id": 1,
} }
hs.get_v1auth().get_user_by_token = _get_user_by_token hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -521,14 +521,14 @@ class RoomsCreateTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_token(token=None): def _get_user_by_access_token(token=None):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1, "token_id": 1,
} }
hs.get_v1auth().get_user_by_token = _get_user_by_token hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -614,7 +614,7 @@ class RoomTopicTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_token(token=None): def _get_user_by_access_token(token=None):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"admin": False, "admin": False,
@ -622,7 +622,7 @@ class RoomTopicTestCase(RestTestCase):
"token_id": 1, "token_id": 1,
} }
hs.get_v1auth().get_user_by_token = _get_user_by_token hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -721,14 +721,14 @@ class RoomMemberStateTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_token(token=None): def _get_user_by_access_token(token=None):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1, "token_id": 1,
} }
hs.get_v1auth().get_user_by_token = _get_user_by_token hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -848,14 +848,14 @@ class RoomMessagesTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_token(token=None): def _get_user_by_access_token(token=None):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1, "token_id": 1,
} }
hs.get_v1auth().get_user_by_token = _get_user_by_token hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -945,14 +945,14 @@ class RoomInitialSyncTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_token(token=None): def _get_user_by_access_token(token=None):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1, "token_id": 1,
} }
hs.get_v1auth().get_user_by_token = _get_user_by_token hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)

View File

@ -61,7 +61,7 @@ class RoomTypingTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_token(token=None): def _get_user_by_access_token(token=None):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"admin": False, "admin": False,
@ -69,7 +69,7 @@ class RoomTypingTestCase(RestTestCase):
"token_id": 1, "token_id": 1,
} }
hs.get_v1auth().get_user_by_token = _get_user_by_token hs.get_v1auth().get_user_by_access_token = _get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)

View File

@ -37,7 +37,7 @@ class RestTestCase(unittest.TestCase):
self.mock_resource = None self.mock_resource = None
self.auth_user_id = None self.auth_user_id = None
def mock_get_user_by_token(self, token=None): def mock_get_user_by_access_token(self, token=None):
return self.auth_user_id return self.auth_user_id
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -43,14 +43,14 @@ class V2AlphaRestTestCase(unittest.TestCase):
resource_for_federation=self.mock_resource, resource_for_federation=self.mock_resource,
) )
def _get_user_by_token(token=None): def _get_user_by_access_token(token=None):
return { return {
"user": UserID.from_string(self.USER_ID), "user": UserID.from_string(self.USER_ID),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1, "token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_auth().get_user_by_access_token = _get_user_by_access_token
for r in self.TO_REGISTER: for r in self.TO_REGISTER:
r.register_servlets(hs, self.mock_resource) r.register_servlets(hs, self.mock_resource)

View File

@ -17,7 +17,9 @@
from tests import unittest from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.storage.registration import RegistrationStore from synapse.storage.registration import RegistrationStore
from synapse.util import stringutils
from tests.utils import setup_test_homeserver from tests.utils import setup_test_homeserver
@ -27,6 +29,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
hs = yield setup_test_homeserver() hs = yield setup_test_homeserver()
self.db_pool = hs.get_db_pool()
self.store = RegistrationStore(hs) self.store = RegistrationStore(hs)
@ -46,7 +49,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
(yield self.store.get_user_by_id(self.user_id)) (yield self.store.get_user_by_id(self.user_id))
) )
result = yield self.store.get_user_by_token(self.tokens[0]) result = yield self.store.get_user_by_access_token(self.tokens[0])
self.assertDictContainsSubset( self.assertDictContainsSubset(
{ {
@ -64,7 +67,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
yield self.store.register(self.user_id, self.tokens[0], self.pwhash) yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
yield self.store.add_access_token_to_user(self.user_id, self.tokens[1]) yield self.store.add_access_token_to_user(self.user_id, self.tokens[1])
result = yield self.store.get_user_by_token(self.tokens[1]) result = yield self.store.get_user_by_access_token(self.tokens[1])
self.assertDictContainsSubset( self.assertDictContainsSubset(
{ {
@ -77,3 +80,55 @@ class RegistrationStoreTestCase(unittest.TestCase):
self.assertTrue("token_id" in result) self.assertTrue("token_id" in result)
@defer.inlineCallbacks
def test_exchange_refresh_token_valid(self):
uid = stringutils.random_string(32)
generator = TokenGenerator()
last_token = generator.generate(uid)
self.db_pool.runQuery(
"INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)",
(uid, last_token,))
(found_user_id, refresh_token) = yield self.store.exchange_refresh_token(
last_token, generator.generate)
self.assertEqual(uid, found_user_id)
rows = yield self.db_pool.runQuery(
"SELECT token FROM refresh_tokens WHERE user_id = ?", (uid, ))
self.assertEqual([(refresh_token,)], rows)
# We issued token 1, then exchanged it for token 2
expected_refresh_token = u"%s-%d" % (uid, 2,)
self.assertEqual(expected_refresh_token, refresh_token)
@defer.inlineCallbacks
def test_exchange_refresh_token_none(self):
uid = stringutils.random_string(32)
generator = TokenGenerator()
last_token = generator.generate(uid)
with self.assertRaises(StoreError):
yield self.store.exchange_refresh_token(last_token, generator.generate)
@defer.inlineCallbacks
def test_exchange_refresh_token_invalid(self):
uid = stringutils.random_string(32)
generator = TokenGenerator()
last_token = generator.generate(uid)
wrong_token = "%s-wrong" % (last_token,)
self.db_pool.runQuery(
"INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)",
(uid, wrong_token,))
with self.assertRaises(StoreError):
yield self.store.exchange_refresh_token(last_token, generator.generate)
class TokenGenerator:
def __init__(self):
self._last_issued_token = 0
def generate(self, user_id):
self._last_issued_token += 1
return u"%s-%d" % (user_id, self._last_issued_token,)

View File

@ -277,7 +277,7 @@ class MemoryDataStore(object):
raise StoreError(400, "User in use.") raise StoreError(400, "User in use.")
self.tokens_to_users[token] = user_id self.tokens_to_users[token] = user_id
def get_user_by_token(self, token): def get_user_by_access_token(self, token):
try: try:
return { return {
"name": self.tokens_to_users[token], "name": self.tokens_to_users[token],