Move get_or_create_user to test code (#5628)

This is only used in tests, so...
This commit is contained in:
Richard van der Hoff 2019-07-08 14:52:26 +01:00 committed by Amber Brown
parent 589d43d9cd
commit 1af2fcd492
3 changed files with 60 additions and 60 deletions

1
changelog.d/5628.misc Normal file
View File

@ -0,0 +1 @@
Move RegistrationHandler.get_or_create_user to test code.

View File

@ -505,57 +505,6 @@ class RegistrationHandler(BaseHandler):
) )
defer.returnValue(data) defer.returnValue(data)
@defer.inlineCallbacks
def get_or_create_user(self, requester, localpart, displayname, password_hash=None):
"""Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one.
Args:
localpart : The local part of the user ID to register. If None,
one will be randomly generated.
Returns:
A tuple of (user_id, access_token).
Raises:
RegistrationError if there was a problem registering.
NB this is only used in tests. TODO: move it to the test package!
"""
if localpart is None:
raise SynapseError(400, "Request must include user id")
yield self.auth.check_auth_blocking()
need_register = True
try:
yield self.check_username(localpart)
except SynapseError as e:
if e.errcode == Codes.USER_IN_USE:
need_register = False
else:
raise
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
token = self.macaroon_gen.generate_access_token(user_id)
if need_register:
yield self.register_with_store(
user_id=user_id,
token=token,
password_hash=password_hash,
create_profile_with_displayname=user.localpart,
)
else:
yield self._auth_handler.delete_access_tokens_for_user(user_id)
yield self.store.add_access_token_to_user(user_id=user_id, token=token)
if displayname is not None:
logger.info("setting user display name: %s -> %s", user_id, displayname)
yield self.profile_handler.set_displayname(
user, requester, displayname, by_admin=True
)
defer.returnValue((user_id, token))
@defer.inlineCallbacks @defer.inlineCallbacks
def _join_user_to_room(self, requester, room_identifier): def _join_user_to_room(self, requester, room_identifier):
room_id = None room_id = None

View File

@ -18,7 +18,7 @@ from mock import Mock
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import ResourceLimitError, SynapseError from synapse.api.errors import Codes, ResourceLimitError, SynapseError
from synapse.handlers.register import RegistrationHandler from synapse.handlers.register import RegistrationHandler
from synapse.types import RoomAlias, UserID, create_requester from synapse.types import RoomAlias, UserID, create_requester
@ -67,7 +67,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
user_id = frank.to_string() user_id = frank.to_string()
requester = create_requester(user_id) requester = create_requester(user_id)
result_user_id, result_token = self.get_success( result_user_id, result_token = self.get_success(
self.handler.get_or_create_user(requester, frank.localpart, "Frankie") self.get_or_create_user(requester, frank.localpart, "Frankie")
) )
self.assertEquals(result_user_id, user_id) self.assertEquals(result_user_id, user_id)
self.assertTrue(result_token is not None) self.assertTrue(result_token is not None)
@ -87,7 +87,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
user_id = frank.to_string() user_id = frank.to_string()
requester = create_requester(user_id) requester = create_requester(user_id)
result_user_id, result_token = self.get_success( result_user_id, result_token = self.get_success(
self.handler.get_or_create_user(requester, local_part, None) self.get_or_create_user(requester, local_part, None)
) )
self.assertEquals(result_user_id, user_id) self.assertEquals(result_user_id, user_id)
self.assertTrue(result_token is not None) self.assertTrue(result_token is not None)
@ -95,9 +95,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_mau_limits_when_disabled(self): def test_mau_limits_when_disabled(self):
self.hs.config.limit_usage_by_mau = False self.hs.config.limit_usage_by_mau = False
# Ensure does not throw exception # Ensure does not throw exception
self.get_success( self.get_success(self.get_or_create_user(self.requester, "a", "display_name"))
self.handler.get_or_create_user(self.requester, "a", "display_name")
)
def test_get_or_create_user_mau_not_blocked(self): def test_get_or_create_user_mau_not_blocked(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
@ -105,7 +103,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
return_value=defer.succeed(self.hs.config.max_mau_value - 1) return_value=defer.succeed(self.hs.config.max_mau_value - 1)
) )
# Ensure does not throw exception # Ensure does not throw exception
self.get_success(self.handler.get_or_create_user(self.requester, "c", "User")) self.get_success(self.get_or_create_user(self.requester, "c", "User"))
def test_get_or_create_user_mau_blocked(self): def test_get_or_create_user_mau_blocked(self):
self.hs.config.limit_usage_by_mau = True self.hs.config.limit_usage_by_mau = True
@ -113,7 +111,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
return_value=defer.succeed(self.lots_of_users) return_value=defer.succeed(self.lots_of_users)
) )
self.get_failure( self.get_failure(
self.handler.get_or_create_user(self.requester, "b", "display_name"), self.get_or_create_user(self.requester, "b", "display_name"),
ResourceLimitError, ResourceLimitError,
) )
@ -121,7 +119,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
return_value=defer.succeed(self.hs.config.max_mau_value) return_value=defer.succeed(self.hs.config.max_mau_value)
) )
self.get_failure( self.get_failure(
self.handler.get_or_create_user(self.requester, "b", "display_name"), self.get_or_create_user(self.requester, "b", "display_name"),
ResourceLimitError, ResourceLimitError,
) )
@ -232,3 +230,55 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_invalid_user_id_length(self): def test_invalid_user_id_length(self):
invalid_user_id = "x" * 256 invalid_user_id = "x" * 256
self.get_failure(self.handler.register(localpart=invalid_user_id), SynapseError) self.get_failure(self.handler.register(localpart=invalid_user_id), SynapseError)
@defer.inlineCallbacks
def get_or_create_user(self, requester, localpart, displayname, password_hash=None):
"""Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one.
XXX: this used to be in the main codebase, but was only used by this file,
so got moved here. TODO: get rid of it, probably
Args:
localpart : The local part of the user ID to register. If None,
one will be randomly generated.
Returns:
A tuple of (user_id, access_token).
Raises:
RegistrationError if there was a problem registering.
"""
if localpart is None:
raise SynapseError(400, "Request must include user id")
yield self.hs.get_auth().check_auth_blocking()
need_register = True
try:
yield self.handler.check_username(localpart)
except SynapseError as e:
if e.errcode == Codes.USER_IN_USE:
need_register = False
else:
raise
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
token = self.macaroon_generator.generate_access_token(user_id)
if need_register:
yield self.handler.register_with_store(
user_id=user_id,
token=token,
password_hash=password_hash,
create_profile_with_displayname=user.localpart,
)
else:
yield self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
yield self.store.add_access_token_to_user(user_id=user_id, token=token)
if displayname is not None:
# logger.info("setting user display name: %s -> %s", user_id, displayname)
yield self.hs.get_profile_handler().set_displayname(
user, requester, displayname, by_admin=True
)
defer.returnValue((user_id, token))