rest/client/v1/register: use the correct requester in createUser

Signed-off-by: Patrik Oldsberg <patrik.oldsberg@ericsson.com>
This commit is contained in:
Patrik Oldsberg 2016-10-04 22:32:58 +02:00
parent 3de7c8a4d0
commit 7b5546d077
4 changed files with 23 additions and 32 deletions

View File

@ -19,7 +19,6 @@ import urllib
from twisted.internet import defer from twisted.internet import defer
import synapse.types
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
) )
@ -370,7 +369,7 @@ class RegistrationHandler(BaseHandler):
defer.returnValue(data) defer.returnValue(data)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_or_create_user(self, localpart, displayname, duration_in_ms, def get_or_create_user(self, requester, localpart, displayname, duration_in_ms,
password_hash=None): password_hash=None):
"""Creates a new user if the user does not exist, """Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one. else revokes all previous access tokens and generates a new one.
@ -417,9 +416,8 @@ class RegistrationHandler(BaseHandler):
if displayname is not None: if displayname is not None:
logger.info("setting user display name: %s -> %s", user_id, displayname) logger.info("setting user display name: %s -> %s", user_id, displayname)
profile_handler = self.hs.get_handlers().profile_handler profile_handler = self.hs.get_handlers().profile_handler
requester = synapse.types.create_requester(user)
yield profile_handler.set_displayname( yield profile_handler.set_displayname(
user, requester, displayname user, requester, displayname, by_admin=True,
) )
defer.returnValue((user_id, token)) defer.returnValue((user_id, token))

View File

@ -22,6 +22,7 @@ from synapse.api.auth import get_access_token_from_request
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.types import create_requester
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
@ -397,9 +398,10 @@ class CreateUserRestServlet(ClientV1RestServlet):
if not app_service: if not app_service:
raise SynapseError(403, "Invalid application service token.") raise SynapseError(403, "Invalid application service token.")
logger.debug("creating user: %s", user_json) requester = create_requester(app_service.sender)
response = yield self._do_create(user_json) logger.debug("creating user: %s", user_json)
response = yield self._do_create(requester, user_json)
defer.returnValue((200, response)) defer.returnValue((200, response))
@ -407,7 +409,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
return 403, {} return 403, {}
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_create(self, user_json): def _do_create(self, requester, user_json):
yield run_on_reactor() yield run_on_reactor()
if "localpart" not in user_json: if "localpart" not in user_json:
@ -433,6 +435,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
handler = self.handlers.registration_handler handler = self.handlers.registration_handler
user_id, token = yield handler.get_or_create_user( user_id, token = yield handler.get_or_create_user(
requester=requester,
localpart=localpart, localpart=localpart,
displayname=displayname, displayname=displayname,
duration_in_ms=(duration_seconds * 1000), duration_in_ms=(duration_seconds * 1000),

View File

@ -17,7 +17,7 @@ from twisted.internet import defer
from .. import unittest from .. import unittest
from synapse.handlers.register import RegistrationHandler from synapse.handlers.register import RegistrationHandler
from synapse.types import UserID from synapse.types import UserID, create_requester
from tests.utils import setup_test_homeserver from tests.utils import setup_test_homeserver
@ -57,8 +57,9 @@ class RegistrationTestCase(unittest.TestCase):
local_part = "someone" local_part = "someone"
display_name = "someone" display_name = "someone"
user_id = "@someone:test" user_id = "@someone:test"
requester = create_requester("@as:test")
result_user_id, result_token = yield self.handler.get_or_create_user( result_user_id, result_token = yield self.handler.get_or_create_user(
local_part, display_name, duration_ms) requester, local_part, display_name, duration_ms)
self.assertEquals(result_user_id, user_id) self.assertEquals(result_user_id, user_id)
self.assertEquals(result_token, 'secret') self.assertEquals(result_token, 'secret')
@ -74,7 +75,8 @@ class RegistrationTestCase(unittest.TestCase):
local_part = "frank" local_part = "frank"
display_name = "Frank" display_name = "Frank"
user_id = "@frank:test" user_id = "@frank:test"
requester = create_requester("@as:test")
result_user_id, result_token = yield self.handler.get_or_create_user( result_user_id, result_token = yield self.handler.get_or_create_user(
local_part, display_name, duration_ms) requester, local_part, display_name, duration_ms)
self.assertEquals(result_user_id, user_id) self.assertEquals(result_user_id, user_id)
self.assertEquals(result_token, 'secret') self.assertEquals(result_token, 'secret')

View File

@ -31,33 +31,21 @@ class CreateUserServletTestCase(unittest.TestCase):
) )
self.request.args = {} self.request.args = {}
self.appservice = None
self.auth = Mock(get_appservice_by_req=Mock(
side_effect=lambda x: defer.succeed(self.appservice))
)
self.auth_result = (False, None, None, None)
self.auth_handler = Mock(
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
get_session_data=Mock(return_value=None)
)
self.registration_handler = Mock() self.registration_handler = Mock()
self.identity_handler = Mock()
self.login_handler = Mock()
# do the dance to hook it up to the hs global self.appservice = Mock(sender="@as:test")
self.handlers = Mock( self.datastore = Mock(
auth_handler=self.auth_handler, get_app_service_by_token=Mock(return_value=self.appservice)
)
# do the dance to hook things up to the hs global
handlers = Mock(
registration_handler=self.registration_handler, registration_handler=self.registration_handler,
identity_handler=self.identity_handler,
login_handler=self.login_handler
) )
self.hs = Mock() self.hs = Mock()
self.hs.hostname = "supergbig~testing~thing.com" self.hs.hostname = "superbig~testing~thing.com"
self.hs.get_auth = Mock(return_value=self.auth) self.hs.get_datastore = Mock(return_value=self.datastore)
self.hs.get_handlers = Mock(return_value=self.handlers) self.hs.get_handlers = Mock(return_value=handlers)
self.hs.config.enable_registration = True
# init the thing we're testing
self.servlet = CreateUserRestServlet(self.hs) self.servlet = CreateUserRestServlet(self.hs)
@defer.inlineCallbacks @defer.inlineCallbacks