Fix unit tests

This commit is contained in:
Erik Johnston 2019-02-18 13:43:16 +00:00
parent 91c8a7f9f4
commit 41c3f21c3b

View File

@ -32,7 +32,18 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.identity_handler = Mock() self.identity_handler = Mock()
self.login_handler = Mock() self.login_handler = Mock()
self.device_handler = Mock() self.device_handler = Mock()
self.device_handler.check_device_registered = Mock(return_value="FAKE")
def check_device_registered(user_id, device_id, initial_display_name):
# Just echo back the given device ID, or return a new "FAKE" device
# ID
if device_id:
return device_id
else:
return "FAKE"
self.device_handler.check_device_registered = Mock(
side_effect=check_device_registered,
)
self.datastore = Mock(return_value=Mock()) self.datastore = Mock(return_value=Mock())
self.datastore.get_current_state_deltas = Mock(return_value=[]) self.datastore.get_current_state_deltas = Mock(return_value=[])
@ -106,14 +117,12 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
user_id = "@kermit:muppet" user_id = "@kermit:muppet"
token = "kermits_access_token" token = "kermits_access_token"
device_id = "frogfone" device_id = "frogfone"
request_data = json.dumps( params = {"username": "kermit", "password": "monkey", "device_id": device_id}
{"username": "kermit", "password": "monkey", "device_id": device_id} request_data = json.dumps(params)
)
self.registration_handler.check_username = Mock(return_value=True) self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) self.auth_result = (None, params, None)
self.registration_handler.register = Mock(return_value=(user_id, None)) self.registration_handler.register = Mock(return_value=(user_id, None))
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token) self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
self.device_handler.check_device_registered = Mock(return_value=device_id)
request, channel = self.make_request(b"POST", self.url, request_data) request, channel = self.make_request(b"POST", self.url, request_data)
self.render(request) self.render(request)