Convert some test cases to use HomeserverTestCase. (#9377)

This has the side-effect of being able to remove use of `inlineCallbacks`
in the test-cases for cleaner tracebacks.
This commit is contained in:
Patrick Cloke 2021-02-11 10:29:09 -05:00 committed by GitHub
parent 6dade80048
commit 8a33d217bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 298 additions and 512 deletions

1
changelog.d/9377.misc Normal file
View File

@ -0,0 +1 @@
Convert tests to use `HomeserverTestCase`.

View File

@ -16,28 +16,21 @@ from mock import Mock
import pymacaroons import pymacaroons
from twisted.internet import defer from synapse.api.errors import AuthError, ResourceLimitError
import synapse
import synapse.api.errors
from synapse.api.errors import ResourceLimitError
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
class AuthTestCase(unittest.TestCase): class AuthTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks def prepare(self, reactor, clock, hs):
def setUp(self): self.auth_handler = hs.get_auth_handler()
self.hs = yield setup_test_homeserver(self.addCleanup) self.macaroon_generator = hs.get_macaroon_generator()
self.auth_handler = self.hs.get_auth_handler()
self.macaroon_generator = self.hs.get_macaroon_generator()
# MAU tests # MAU tests
# AuthBlocking reads from the hs' config on initialization. We need to # AuthBlocking reads from the hs' config on initialization. We need to
# modify its config instead of the hs' # modify its config instead of the hs'
self.auth_blocking = self.hs.get_auth()._auth_blocking self.auth_blocking = hs.get_auth()._auth_blocking
self.auth_blocking._max_mau_value = 50 self.auth_blocking._max_mau_value = 50
self.small_number_of_users = 1 self.small_number_of_users = 1
@ -52,8 +45,6 @@ class AuthTestCase(unittest.TestCase):
self.fail("some_user was not in %s" % macaroon.inspect()) self.fail("some_user was not in %s" % macaroon.inspect())
def test_macaroon_caveats(self): def test_macaroon_caveats(self):
self.hs.get_clock().now = 5000
token = self.macaroon_generator.generate_access_token("a_user") token = self.macaroon_generator.generate_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token) macaroon = pymacaroons.Macaroon.deserialize(token)
@ -76,29 +67,25 @@ class AuthTestCase(unittest.TestCase):
v.satisfy_general(verify_nonce) v.satisfy_general(verify_nonce)
v.verify(macaroon, self.hs.config.macaroon_secret_key) v.verify(macaroon, self.hs.config.macaroon_secret_key)
@defer.inlineCallbacks
def test_short_term_login_token_gives_user_id(self): def test_short_term_login_token_gives_user_id(self):
self.hs.get_clock().now = 1000
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
user_id = yield defer.ensureDeferred( user_id = self.get_success(
self.auth_handler.validate_short_term_login_token_and_get_user_id(token) self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
) )
self.assertEqual("a_user", user_id) self.assertEqual("a_user", user_id)
# when we advance the clock, the token should be rejected # when we advance the clock, the token should be rejected
self.hs.get_clock().now = 6000 self.reactor.advance(6)
with self.assertRaises(synapse.api.errors.AuthError): self.get_failure(
yield defer.ensureDeferred( self.auth_handler.validate_short_term_login_token_and_get_user_id(token),
self.auth_handler.validate_short_term_login_token_and_get_user_id(token) AuthError,
) )
@defer.inlineCallbacks
def test_short_term_login_token_cannot_replace_user_id(self): def test_short_term_login_token_cannot_replace_user_id(self):
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
macaroon = pymacaroons.Macaroon.deserialize(token) macaroon = pymacaroons.Macaroon.deserialize(token)
user_id = yield defer.ensureDeferred( user_id = self.get_success(
self.auth_handler.validate_short_term_login_token_and_get_user_id( self.auth_handler.validate_short_term_login_token_and_get_user_id(
macaroon.serialize() macaroon.serialize()
) )
@ -109,102 +96,90 @@ class AuthTestCase(unittest.TestCase):
# user_id. # user_id.
macaroon.add_first_party_caveat("user_id = b_user") macaroon.add_first_party_caveat("user_id = b_user")
with self.assertRaises(synapse.api.errors.AuthError): self.get_failure(
yield defer.ensureDeferred( self.auth_handler.validate_short_term_login_token_and_get_user_id(
self.auth_handler.validate_short_term_login_token_and_get_user_id( macaroon.serialize()
macaroon.serialize() ),
) AuthError,
) )
@defer.inlineCallbacks
def test_mau_limits_disabled(self): def test_mau_limits_disabled(self):
self.auth_blocking._limit_usage_by_mau = False self.auth_blocking._limit_usage_by_mau = False
# Ensure does not throw exception # Ensure does not throw exception
yield defer.ensureDeferred( self.get_success(
self.auth_handler.get_access_token_for_user_id( self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None "user_a", device_id=None, valid_until_ms=None
) )
) )
yield defer.ensureDeferred( self.get_success(
self.auth_handler.validate_short_term_login_token_and_get_user_id( self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize() self._get_macaroon().serialize()
) )
) )
@defer.inlineCallbacks
def test_mau_limits_exceeded_large(self): def test_mau_limits_exceeded_large(self):
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=make_awaitable(self.large_number_of_users) return_value=make_awaitable(self.large_number_of_users)
) )
with self.assertRaises(ResourceLimitError): self.get_failure(
yield defer.ensureDeferred( self.auth_handler.get_access_token_for_user_id(
self.auth_handler.get_access_token_for_user_id( "user_a", device_id=None, valid_until_ms=None
"user_a", device_id=None, valid_until_ms=None ),
) ResourceLimitError,
) )
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=make_awaitable(self.large_number_of_users) return_value=make_awaitable(self.large_number_of_users)
) )
with self.assertRaises(ResourceLimitError): self.get_failure(
yield defer.ensureDeferred( self.auth_handler.validate_short_term_login_token_and_get_user_id(
self.auth_handler.validate_short_term_login_token_and_get_user_id( self._get_macaroon().serialize()
self._get_macaroon().serialize() ),
) ResourceLimitError,
) )
@defer.inlineCallbacks
def test_mau_limits_parity(self): def test_mau_limits_parity(self):
# Ensure we're not at the unix epoch.
self.reactor.advance(1)
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
# If not in monthly active cohort # Set the server to be at the edge of too many users.
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=make_awaitable(self.auth_blocking._max_mau_value) return_value=make_awaitable(self.auth_blocking._max_mau_value)
) )
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(
self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None
)
)
self.hs.get_datastore().get_monthly_active_count = Mock( # If not in monthly active cohort
return_value=make_awaitable(self.auth_blocking._max_mau_value) self.get_failure(
self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None
),
ResourceLimitError,
) )
with self.assertRaises(ResourceLimitError): self.get_failure(
yield defer.ensureDeferred( self.auth_handler.validate_short_term_login_token_and_get_user_id(
self.auth_handler.validate_short_term_login_token_and_get_user_id( self._get_macaroon().serialize()
self._get_macaroon().serialize() ),
) ResourceLimitError,
) )
# If in monthly active cohort # If in monthly active cohort
self.hs.get_datastore().user_last_seen_monthly_active = Mock( self.hs.get_datastore().user_last_seen_monthly_active = Mock(
return_value=make_awaitable(self.hs.get_clock().time_msec()) return_value=make_awaitable(self.clock.time_msec())
) )
self.hs.get_datastore().get_monthly_active_count = Mock( self.get_success(
return_value=make_awaitable(self.auth_blocking._max_mau_value)
)
yield defer.ensureDeferred(
self.auth_handler.get_access_token_for_user_id( self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None "user_a", device_id=None, valid_until_ms=None
) )
) )
self.hs.get_datastore().user_last_seen_monthly_active = Mock( self.get_success(
return_value=make_awaitable(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=make_awaitable(self.auth_blocking._max_mau_value)
)
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id( self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize() self._get_macaroon().serialize()
) )
) )
@defer.inlineCallbacks
def test_mau_limits_not_exceeded(self): def test_mau_limits_not_exceeded(self):
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
@ -212,7 +187,7 @@ class AuthTestCase(unittest.TestCase):
return_value=make_awaitable(self.small_number_of_users) return_value=make_awaitable(self.small_number_of_users)
) )
# Ensure does not raise exception # Ensure does not raise exception
yield defer.ensureDeferred( self.get_success(
self.auth_handler.get_access_token_for_user_id( self.auth_handler.get_access_token_for_user_id(
"user_a", device_id=None, valid_until_ms=None "user_a", device_id=None, valid_until_ms=None
) )
@ -221,7 +196,7 @@ class AuthTestCase(unittest.TestCase):
self.hs.get_datastore().get_monthly_active_count = Mock( self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=make_awaitable(self.small_number_of_users) return_value=make_awaitable(self.small_number_of_users)
) )
yield defer.ensureDeferred( self.get_success(
self.auth_handler.validate_short_term_login_token_and_get_user_id( self.auth_handler.validate_short_term_login_token_and_get_user_id(
self._get_macaroon().serialize() self._get_macaroon().serialize()
) )

View File

@ -18,42 +18,27 @@ import mock
from signedjson import key as key, sign as sign from signedjson import key as key, sign as sign
from twisted.internet import defer
import synapse.handlers.e2e_keys
import synapse.storage
from synapse.api import errors
from synapse.api.constants import RoomEncryptionAlgorithms from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import Codes, SynapseError
from tests import unittest, utils from tests import unittest
class E2eKeysHandlerTestCase(unittest.TestCase): class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
def __init__(self, *args, **kwargs): def make_homeserver(self, reactor, clock):
super().__init__(*args, **kwargs) return self.setup_test_homeserver(federation_client=mock.Mock())
self.hs = None # type: synapse.server.HomeServer
self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler
self.store = None # type: synapse.storage.Storage
@defer.inlineCallbacks def prepare(self, reactor, clock, hs):
def setUp(self): self.handler = hs.get_e2e_keys_handler()
self.hs = yield utils.setup_test_homeserver(
self.addCleanup, federation_client=mock.Mock()
)
self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs)
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
@defer.inlineCallbacks
def test_query_local_devices_no_devices(self): def test_query_local_devices_no_devices(self):
"""If the user has no devices, we expect an empty list. """If the user has no devices, we expect an empty list.
""" """
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
res = yield defer.ensureDeferred( res = self.get_success(self.handler.query_local_devices({local_user: None}))
self.handler.query_local_devices({local_user: None})
)
self.assertDictEqual(res, {local_user: {}}) self.assertDictEqual(res, {local_user: {}})
@defer.inlineCallbacks
def test_reupload_one_time_keys(self): def test_reupload_one_time_keys(self):
"""we should be able to re-upload the same keys""" """we should be able to re-upload the same keys"""
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
@ -64,7 +49,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"alg2:k3": {"key": "key3"}, "alg2:k3": {"key": "key3"},
} }
res = yield defer.ensureDeferred( res = self.get_success(
self.handler.upload_keys_for_user( self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys} local_user, device_id, {"one_time_keys": keys}
) )
@ -73,14 +58,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
# we should be able to change the signature without a problem # we should be able to change the signature without a problem
keys["alg2:k2"]["signatures"]["k1"] = "sig2" keys["alg2:k2"]["signatures"]["k1"] = "sig2"
res = yield defer.ensureDeferred( res = self.get_success(
self.handler.upload_keys_for_user( self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys} local_user, device_id, {"one_time_keys": keys}
) )
) )
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
@defer.inlineCallbacks
def test_change_one_time_keys(self): def test_change_one_time_keys(self):
"""attempts to change one-time-keys should be rejected""" """attempts to change one-time-keys should be rejected"""
@ -92,75 +76,64 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"alg2:k3": {"key": "key3"}, "alg2:k3": {"key": "key3"},
} }
res = yield defer.ensureDeferred( res = self.get_success(
self.handler.upload_keys_for_user( self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys} local_user, device_id, {"one_time_keys": keys}
) )
) )
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}}) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
try: # Error when changing string key
yield defer.ensureDeferred( self.get_failure(
self.handler.upload_keys_for_user( self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}} local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
) ),
) SynapseError,
self.fail("No error when changing string key") )
except errors.SynapseError:
pass
try: # Error when replacing dict key with strin
yield defer.ensureDeferred( self.get_failure(
self.handler.upload_keys_for_user( self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}} local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
) ),
) SynapseError,
self.fail("No error when replacing dict key with string") )
except errors.SynapseError:
pass
try: # Error when replacing string key with dict
yield defer.ensureDeferred( self.get_failure(
self.handler.upload_keys_for_user( self.handler.upload_keys_for_user(
local_user, local_user, device_id, {"one_time_keys": {"alg1:k1": {"key": "key"}}},
device_id, ),
{"one_time_keys": {"alg1:k1": {"key": "key"}}}, SynapseError,
) )
)
self.fail("No error when replacing string key with dict")
except errors.SynapseError:
pass
try: # Error when replacing dict key
yield defer.ensureDeferred( self.get_failure(
self.handler.upload_keys_for_user( self.handler.upload_keys_for_user(
local_user, local_user,
device_id, device_id,
{ {
"one_time_keys": { "one_time_keys": {
"alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}} "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
} }
}, },
) ),
) SynapseError,
self.fail("No error when replacing dict key") )
except errors.SynapseError:
pass
@defer.inlineCallbacks
def test_claim_one_time_key(self): def test_claim_one_time_key(self):
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
device_id = "xyz" device_id = "xyz"
keys = {"alg1:k1": "key1"} keys = {"alg1:k1": "key1"}
res = yield defer.ensureDeferred( res = self.get_success(
self.handler.upload_keys_for_user( self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys} local_user, device_id, {"one_time_keys": keys}
) )
) )
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}}) self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}})
res2 = yield defer.ensureDeferred( res2 = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
) )
@ -173,7 +146,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
}, },
) )
@defer.inlineCallbacks
def test_fallback_key(self): def test_fallback_key(self):
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
device_id = "xyz" device_id = "xyz"
@ -181,12 +153,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
otk = {"alg1:k2": "key2"} otk = {"alg1:k2": "key2"}
# we shouldn't have any unused fallback keys yet # we shouldn't have any unused fallback keys yet
res = yield defer.ensureDeferred( res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id) self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
) )
self.assertEqual(res, []) self.assertEqual(res, [])
yield defer.ensureDeferred( self.get_success(
self.handler.upload_keys_for_user( self.handler.upload_keys_for_user(
local_user, local_user,
device_id, device_id,
@ -195,14 +167,14 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
) )
# we should now have an unused alg1 key # we should now have an unused alg1 key
res = yield defer.ensureDeferred( res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id) self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
) )
self.assertEqual(res, ["alg1"]) self.assertEqual(res, ["alg1"])
# claiming an OTK when no OTKs are available should return the fallback # claiming an OTK when no OTKs are available should return the fallback
# key # key
res = yield defer.ensureDeferred( res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
) )
@ -213,13 +185,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
) )
# we shouldn't have any unused fallback keys again # we shouldn't have any unused fallback keys again
res = yield defer.ensureDeferred( res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id) self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
) )
self.assertEqual(res, []) self.assertEqual(res, [])
# claiming an OTK again should return the same fallback key # claiming an OTK again should return the same fallback key
res = yield defer.ensureDeferred( res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
) )
@ -231,13 +203,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
# if the user uploads a one-time key, the next claim should fetch the # if the user uploads a one-time key, the next claim should fetch the
# one-time key, and then go back to the fallback # one-time key, and then go back to the fallback
yield defer.ensureDeferred( self.get_success(
self.handler.upload_keys_for_user( self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": otk} local_user, device_id, {"one_time_keys": otk}
) )
) )
res = yield defer.ensureDeferred( res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
) )
@ -246,7 +218,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}}, res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
) )
res = yield defer.ensureDeferred( res = self.get_success(
self.handler.claim_one_time_keys( self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
) )
@ -256,7 +228,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}}, {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
) )
@defer.inlineCallbacks
def test_replace_master_key(self): def test_replace_master_key(self):
"""uploading a new signing key should make the old signing key unavailable""" """uploading a new signing key should make the old signing key unavailable"""
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
@ -270,9 +241,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
}, },
} }
} }
yield defer.ensureDeferred( self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
self.handler.upload_signing_keys_for_user(local_user, keys1)
)
keys2 = { keys2 = {
"master_key": { "master_key": {
@ -284,16 +253,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
}, },
} }
} }
yield defer.ensureDeferred( self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys2))
self.handler.upload_signing_keys_for_user(local_user, keys2)
)
devices = yield defer.ensureDeferred( devices = self.get_success(
self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user) self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
) )
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
@defer.inlineCallbacks
def test_reupload_signatures(self): def test_reupload_signatures(self):
"""re-uploading a signature should not fail""" """re-uploading a signature should not fail"""
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
@ -326,9 +292,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
"2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0", "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0",
) )
yield defer.ensureDeferred( self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
self.handler.upload_signing_keys_for_user(local_user, keys1)
)
# upload two device keys, which will be signed later by the self-signing key # upload two device keys, which will be signed later by the self-signing key
device_key_1 = { device_key_1 = {
@ -358,12 +322,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"signatures": {local_user: {"ed25519:def": "base64+signature"}}, "signatures": {local_user: {"ed25519:def": "base64+signature"}},
} }
yield defer.ensureDeferred( self.get_success(
self.handler.upload_keys_for_user( self.handler.upload_keys_for_user(
local_user, "abc", {"device_keys": device_key_1} local_user, "abc", {"device_keys": device_key_1}
) )
) )
yield defer.ensureDeferred( self.get_success(
self.handler.upload_keys_for_user( self.handler.upload_keys_for_user(
local_user, "def", {"device_keys": device_key_2} local_user, "def", {"device_keys": device_key_2}
) )
@ -372,7 +336,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
# sign the first device key and upload it # sign the first device key and upload it
del device_key_1["signatures"] del device_key_1["signatures"]
sign.sign_json(device_key_1, local_user, signing_key) sign.sign_json(device_key_1, local_user, signing_key)
yield defer.ensureDeferred( self.get_success(
self.handler.upload_signatures_for_device_keys( self.handler.upload_signatures_for_device_keys(
local_user, {local_user: {"abc": device_key_1}} local_user, {local_user: {"abc": device_key_1}}
) )
@ -383,7 +347,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
# signature for it # signature for it
del device_key_2["signatures"] del device_key_2["signatures"]
sign.sign_json(device_key_2, local_user, signing_key) sign.sign_json(device_key_2, local_user, signing_key)
yield defer.ensureDeferred( self.get_success(
self.handler.upload_signatures_for_device_keys( self.handler.upload_signatures_for_device_keys(
local_user, {local_user: {"abc": device_key_1, "def": device_key_2}} local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
) )
@ -391,7 +355,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature" device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature" device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
devices = yield defer.ensureDeferred( devices = self.get_success(
self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user) self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
) )
del devices["device_keys"][local_user]["abc"]["unsigned"] del devices["device_keys"][local_user]["abc"]["unsigned"]
@ -399,7 +363,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1) self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2) self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
@defer.inlineCallbacks
def test_self_signing_key_doesnt_show_up_as_device(self): def test_self_signing_key_doesnt_show_up_as_device(self):
"""signing keys should be hidden when fetching a user's devices""" """signing keys should be hidden when fetching a user's devices"""
local_user = "@boris:" + self.hs.hostname local_user = "@boris:" + self.hs.hostname
@ -413,29 +376,22 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
}, },
} }
} }
yield defer.ensureDeferred( self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
self.handler.upload_signing_keys_for_user(local_user, keys1)
)
res = None e = self.get_failure(
try: self.hs.get_device_handler().check_device_registered(
yield defer.ensureDeferred( user_id=local_user,
self.hs.get_device_handler().check_device_registered( device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
user_id=local_user, initial_device_display_name="new display name",
device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", ),
initial_device_display_name="new display name", SynapseError,
) )
) res = e.value.code
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 400) self.assertEqual(res, 400)
res = yield defer.ensureDeferred( res = self.get_success(self.handler.query_local_devices({local_user: None}))
self.handler.query_local_devices({local_user: None})
)
self.assertDictEqual(res, {local_user: {}}) self.assertDictEqual(res, {local_user: {}})
@defer.inlineCallbacks
def test_upload_signatures(self): def test_upload_signatures(self):
"""should check signatures that are uploaded""" """should check signatures that are uploaded"""
# set up a user with cross-signing keys and a device. This user will # set up a user with cross-signing keys and a device. This user will
@ -458,7 +414,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA" "ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA"
) )
yield defer.ensureDeferred( self.get_success(
self.handler.upload_keys_for_user( self.handler.upload_keys_for_user(
local_user, device_id, {"device_keys": device_key} local_user, device_id, {"device_keys": device_key}
) )
@ -501,7 +457,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"user_signing_key": usersigning_key, "user_signing_key": usersigning_key,
"self_signing_key": selfsigning_key, "self_signing_key": selfsigning_key,
} }
yield defer.ensureDeferred( self.get_success(
self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys) self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
) )
@ -515,14 +471,14 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"usage": ["master"], "usage": ["master"],
"keys": {"ed25519:" + other_master_pubkey: other_master_pubkey}, "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
} }
yield defer.ensureDeferred( self.get_success(
self.handler.upload_signing_keys_for_user( self.handler.upload_signing_keys_for_user(
other_user, {"master_key": other_master_key} other_user, {"master_key": other_master_key}
) )
) )
# test various signature failures (see below) # test various signature failures (see below)
ret = yield defer.ensureDeferred( ret = self.get_success(
self.handler.upload_signatures_for_device_keys( self.handler.upload_signatures_for_device_keys(
local_user, local_user,
{ {
@ -602,20 +558,16 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
) )
user_failures = ret["failures"][local_user] user_failures = ret["failures"][local_user]
self.assertEqual(user_failures[device_id]["errcode"], Codes.INVALID_SIGNATURE)
self.assertEqual( self.assertEqual(
user_failures[device_id]["errcode"], errors.Codes.INVALID_SIGNATURE user_failures[master_pubkey]["errcode"], Codes.INVALID_SIGNATURE
) )
self.assertEqual( self.assertEqual(user_failures["unknown"]["errcode"], Codes.NOT_FOUND)
user_failures[master_pubkey]["errcode"], errors.Codes.INVALID_SIGNATURE
)
self.assertEqual(user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND)
other_user_failures = ret["failures"][other_user] other_user_failures = ret["failures"][other_user]
self.assertEqual(other_user_failures["unknown"]["errcode"], Codes.NOT_FOUND)
self.assertEqual( self.assertEqual(
other_user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND other_user_failures[other_master_pubkey]["errcode"], Codes.UNKNOWN
)
self.assertEqual(
other_user_failures[other_master_pubkey]["errcode"], errors.Codes.UNKNOWN
) )
# test successful signatures # test successful signatures
@ -623,7 +575,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
sign.sign_json(device_key, local_user, selfsigning_signing_key) sign.sign_json(device_key, local_user, selfsigning_signing_key)
sign.sign_json(master_key, local_user, device_signing_key) sign.sign_json(master_key, local_user, device_signing_key)
sign.sign_json(other_master_key, local_user, usersigning_signing_key) sign.sign_json(other_master_key, local_user, usersigning_signing_key)
ret = yield defer.ensureDeferred( ret = self.get_success(
self.handler.upload_signatures_for_device_keys( self.handler.upload_signatures_for_device_keys(
local_user, local_user,
{ {
@ -636,7 +588,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
self.assertEqual(ret["failures"], {}) self.assertEqual(ret["failures"], {})
# fetch the signed keys/devices and make sure that the signatures are there # fetch the signed keys/devices and make sure that the signatures are there
ret = yield defer.ensureDeferred( ret = self.get_success(
self.handler.query_devices( self.handler.query_devices(
{"device_keys": {local_user: [], other_user: []}}, 0, local_user {"device_keys": {local_user: [], other_user: []}}, 0, local_user
) )

View File

@ -19,14 +19,9 @@ import copy
import mock import mock
from twisted.internet import defer from synapse.api.errors import SynapseError
import synapse.api.errors from tests import unittest
import synapse.handlers.e2e_room_keys
import synapse.storage
from synapse.api import errors
from tests import unittest, utils
# sample room_key data for use in the tests # sample room_key data for use in the tests
room_keys = { room_keys = {
@ -45,51 +40,39 @@ room_keys = {
} }
class E2eRoomKeysHandlerTestCase(unittest.TestCase): class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
def __init__(self, *args, **kwargs): def make_homeserver(self, reactor, clock):
super().__init__(*args, **kwargs) return self.setup_test_homeserver(replication_layer=mock.Mock())
self.hs = None # type: synapse.server.HomeServer
self.handler = None # type: synapse.handlers.e2e_keys.E2eRoomKeysHandler
@defer.inlineCallbacks def prepare(self, reactor, clock, hs):
def setUp(self): self.handler = hs.get_e2e_room_keys_handler()
self.hs = yield utils.setup_test_homeserver( self.local_user = "@boris:" + hs.hostname
self.addCleanup, replication_layer=mock.Mock()
)
self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs)
self.local_user = "@boris:" + self.hs.hostname
@defer.inlineCallbacks
def test_get_missing_current_version_info(self): def test_get_missing_current_version_info(self):
"""Check that we get a 404 if we ask for info about the current version """Check that we get a 404 if we ask for info about the current version
if there is no version. if there is no version.
""" """
res = None e = self.get_failure(
try: self.handler.get_version_info(self.local_user), SynapseError
yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) )
except errors.SynapseError as e: res = e.value.code
res = e.code
self.assertEqual(res, 404) self.assertEqual(res, 404)
@defer.inlineCallbacks
def test_get_missing_version_info(self): def test_get_missing_version_info(self):
"""Check that we get a 404 if we ask for info about a specific version """Check that we get a 404 if we ask for info about a specific version
if it doesn't exist. if it doesn't exist.
""" """
res = None e = self.get_failure(
try: self.handler.get_version_info(self.local_user, "bogus_version"),
yield defer.ensureDeferred( SynapseError,
self.handler.get_version_info(self.local_user, "bogus_version") )
) res = e.value.code
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404) self.assertEqual(res, 404)
@defer.inlineCallbacks
def test_create_version(self): def test_create_version(self):
"""Check that we can create and then retrieve versions. """Check that we can create and then retrieve versions.
""" """
res = yield defer.ensureDeferred( res = self.get_success(
self.handler.create_version( self.handler.create_version(
self.local_user, self.local_user,
{ {
@ -101,7 +84,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertEqual(res, "1") self.assertEqual(res, "1")
# check we can retrieve it as the current version # check we can retrieve it as the current version
res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) res = self.get_success(self.handler.get_version_info(self.local_user))
version_etag = res["etag"] version_etag = res["etag"]
self.assertIsInstance(version_etag, str) self.assertIsInstance(version_etag, str)
del res["etag"] del res["etag"]
@ -116,9 +99,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
) )
# check we can retrieve it as a specific version # check we can retrieve it as a specific version
res = yield defer.ensureDeferred( res = self.get_success(self.handler.get_version_info(self.local_user, "1"))
self.handler.get_version_info(self.local_user, "1")
)
self.assertEqual(res["etag"], version_etag) self.assertEqual(res["etag"], version_etag)
del res["etag"] del res["etag"]
self.assertDictEqual( self.assertDictEqual(
@ -132,7 +113,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
) )
# upload a new one... # upload a new one...
res = yield defer.ensureDeferred( res = self.get_success(
self.handler.create_version( self.handler.create_version(
self.local_user, self.local_user,
{ {
@ -144,7 +125,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertEqual(res, "2") self.assertEqual(res, "2")
# check we can retrieve it as the current version # check we can retrieve it as the current version
res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) res = self.get_success(self.handler.get_version_info(self.local_user))
del res["etag"] del res["etag"]
self.assertDictEqual( self.assertDictEqual(
res, res,
@ -156,11 +137,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
}, },
) )
@defer.inlineCallbacks
def test_update_version(self): def test_update_version(self):
"""Check that we can update versions. """Check that we can update versions.
""" """
version = yield defer.ensureDeferred( version = self.get_success(
self.handler.create_version( self.handler.create_version(
self.local_user, self.local_user,
{ {
@ -171,7 +151,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
) )
self.assertEqual(version, "1") self.assertEqual(version, "1")
res = yield defer.ensureDeferred( res = self.get_success(
self.handler.update_version( self.handler.update_version(
self.local_user, self.local_user,
version, version,
@ -185,7 +165,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertDictEqual(res, {}) self.assertDictEqual(res, {})
# check we can retrieve it as the current version # check we can retrieve it as the current version
res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) res = self.get_success(self.handler.get_version_info(self.local_user))
del res["etag"] del res["etag"]
self.assertDictEqual( self.assertDictEqual(
res, res,
@ -197,32 +177,28 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
}, },
) )
@defer.inlineCallbacks
def test_update_missing_version(self): def test_update_missing_version(self):
"""Check that we get a 404 on updating nonexistent versions """Check that we get a 404 on updating nonexistent versions
""" """
res = None e = self.get_failure(
try: self.handler.update_version(
yield defer.ensureDeferred( self.local_user,
self.handler.update_version( "1",
self.local_user, {
"1", "algorithm": "m.megolm_backup.v1",
{ "auth_data": "revised_first_version_auth_data",
"algorithm": "m.megolm_backup.v1", "version": "1",
"auth_data": "revised_first_version_auth_data", },
"version": "1", ),
}, SynapseError,
) )
) res = e.value.code
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404) self.assertEqual(res, 404)
@defer.inlineCallbacks
def test_update_omitted_version(self): def test_update_omitted_version(self):
"""Check that the update succeeds if the version is missing from the body """Check that the update succeeds if the version is missing from the body
""" """
version = yield defer.ensureDeferred( version = self.get_success(
self.handler.create_version( self.handler.create_version(
self.local_user, self.local_user,
{ {
@ -233,7 +209,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
) )
self.assertEqual(version, "1") self.assertEqual(version, "1")
yield defer.ensureDeferred( self.get_success(
self.handler.update_version( self.handler.update_version(
self.local_user, self.local_user,
version, version,
@ -245,7 +221,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
) )
# check we can retrieve it as the current version # check we can retrieve it as the current version
res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) res = self.get_success(self.handler.get_version_info(self.local_user))
del res["etag"] # etag is opaque, so don't test its contents del res["etag"] # etag is opaque, so don't test its contents
self.assertDictEqual( self.assertDictEqual(
res, res,
@ -257,11 +233,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
}, },
) )
@defer.inlineCallbacks
def test_update_bad_version(self): def test_update_bad_version(self):
"""Check that we get a 400 if the version in the body doesn't match """Check that we get a 400 if the version in the body doesn't match
""" """
version = yield defer.ensureDeferred( version = self.get_success(
self.handler.create_version( self.handler.create_version(
self.local_user, self.local_user,
{ {
@ -272,52 +247,41 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
) )
self.assertEqual(version, "1") self.assertEqual(version, "1")
res = None e = self.get_failure(
try: self.handler.update_version(
yield defer.ensureDeferred( self.local_user,
self.handler.update_version( version,
self.local_user, {
version, "algorithm": "m.megolm_backup.v1",
{ "auth_data": "revised_first_version_auth_data",
"algorithm": "m.megolm_backup.v1", "version": "incorrect",
"auth_data": "revised_first_version_auth_data", },
"version": "incorrect", ),
}, SynapseError,
) )
) res = e.value.code
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 400) self.assertEqual(res, 400)
@defer.inlineCallbacks
def test_delete_missing_version(self): def test_delete_missing_version(self):
"""Check that we get a 404 on deleting nonexistent versions """Check that we get a 404 on deleting nonexistent versions
""" """
res = None e = self.get_failure(
try: self.handler.delete_version(self.local_user, "1"), SynapseError
yield defer.ensureDeferred( )
self.handler.delete_version(self.local_user, "1") res = e.value.code
)
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404) self.assertEqual(res, 404)
@defer.inlineCallbacks
def test_delete_missing_current_version(self): def test_delete_missing_current_version(self):
"""Check that we get a 404 on deleting nonexistent current version """Check that we get a 404 on deleting nonexistent current version
""" """
res = None e = self.get_failure(self.handler.delete_version(self.local_user), SynapseError)
try: res = e.value.code
yield defer.ensureDeferred(self.handler.delete_version(self.local_user))
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404) self.assertEqual(res, 404)
@defer.inlineCallbacks
def test_delete_version(self): def test_delete_version(self):
"""Check that we can create and then delete versions. """Check that we can create and then delete versions.
""" """
res = yield defer.ensureDeferred( res = self.get_success(
self.handler.create_version( self.handler.create_version(
self.local_user, self.local_user,
{ {
@ -329,36 +293,28 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertEqual(res, "1") self.assertEqual(res, "1")
# check we can delete it # check we can delete it
yield defer.ensureDeferred(self.handler.delete_version(self.local_user, "1")) self.get_success(self.handler.delete_version(self.local_user, "1"))
# check that it's gone # check that it's gone
res = None e = self.get_failure(
try: self.handler.get_version_info(self.local_user, "1"), SynapseError
yield defer.ensureDeferred( )
self.handler.get_version_info(self.local_user, "1") res = e.value.code
)
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404) self.assertEqual(res, 404)
@defer.inlineCallbacks
def test_get_missing_backup(self): def test_get_missing_backup(self):
"""Check that we get a 404 on querying missing backup """Check that we get a 404 on querying missing backup
""" """
res = None e = self.get_failure(
try: self.handler.get_room_keys(self.local_user, "bogus_version"), SynapseError
yield defer.ensureDeferred( )
self.handler.get_room_keys(self.local_user, "bogus_version") res = e.value.code
)
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404) self.assertEqual(res, 404)
@defer.inlineCallbacks
def test_get_missing_room_keys(self): def test_get_missing_room_keys(self):
"""Check we get an empty response from an empty backup """Check we get an empty response from an empty backup
""" """
version = yield defer.ensureDeferred( version = self.get_success(
self.handler.create_version( self.handler.create_version(
self.local_user, self.local_user,
{ {
@ -369,33 +325,27 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
) )
self.assertEqual(version, "1") self.assertEqual(version, "1")
res = yield defer.ensureDeferred( res = self.get_success(self.handler.get_room_keys(self.local_user, version))
self.handler.get_room_keys(self.local_user, version)
)
self.assertDictEqual(res, {"rooms": {}}) self.assertDictEqual(res, {"rooms": {}})
# TODO: test the locking semantics when uploading room_keys, # TODO: test the locking semantics when uploading room_keys,
# although this is probably best done in sytest # although this is probably best done in sytest
@defer.inlineCallbacks
def test_upload_room_keys_no_versions(self): def test_upload_room_keys_no_versions(self):
"""Check that we get a 404 on uploading keys when no versions are defined """Check that we get a 404 on uploading keys when no versions are defined
""" """
res = None e = self.get_failure(
try: self.handler.upload_room_keys(self.local_user, "no_version", room_keys),
yield defer.ensureDeferred( SynapseError,
self.handler.upload_room_keys(self.local_user, "no_version", room_keys) )
) res = e.value.code
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404) self.assertEqual(res, 404)
@defer.inlineCallbacks
def test_upload_room_keys_bogus_version(self): def test_upload_room_keys_bogus_version(self):
"""Check that we get a 404 on uploading keys when an nonexistent version """Check that we get a 404 on uploading keys when an nonexistent version
is specified is specified
""" """
version = yield defer.ensureDeferred( version = self.get_success(
self.handler.create_version( self.handler.create_version(
self.local_user, self.local_user,
{ {
@ -406,22 +356,17 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
) )
self.assertEqual(version, "1") self.assertEqual(version, "1")
res = None e = self.get_failure(
try: self.handler.upload_room_keys(self.local_user, "bogus_version", room_keys),
yield defer.ensureDeferred( SynapseError,
self.handler.upload_room_keys( )
self.local_user, "bogus_version", room_keys res = e.value.code
)
)
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404) self.assertEqual(res, 404)
@defer.inlineCallbacks
def test_upload_room_keys_wrong_version(self): def test_upload_room_keys_wrong_version(self):
"""Check that we get a 403 on uploading keys for an old version """Check that we get a 403 on uploading keys for an old version
""" """
version = yield defer.ensureDeferred( version = self.get_success(
self.handler.create_version( self.handler.create_version(
self.local_user, self.local_user,
{ {
@ -432,7 +377,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
) )
self.assertEqual(version, "1") self.assertEqual(version, "1")
version = yield defer.ensureDeferred( version = self.get_success(
self.handler.create_version( self.handler.create_version(
self.local_user, self.local_user,
{ {
@ -443,20 +388,16 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
) )
self.assertEqual(version, "2") self.assertEqual(version, "2")
res = None e = self.get_failure(
try: self.handler.upload_room_keys(self.local_user, "1", room_keys), SynapseError
yield defer.ensureDeferred( )
self.handler.upload_room_keys(self.local_user, "1", room_keys) res = e.value.code
)
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 403) self.assertEqual(res, 403)
@defer.inlineCallbacks
def test_upload_room_keys_insert(self): def test_upload_room_keys_insert(self):
"""Check that we can insert and retrieve keys for a session """Check that we can insert and retrieve keys for a session
""" """
version = yield defer.ensureDeferred( version = self.get_success(
self.handler.create_version( self.handler.create_version(
self.local_user, self.local_user,
{ {
@ -467,17 +408,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
) )
self.assertEqual(version, "1") self.assertEqual(version, "1")
yield defer.ensureDeferred( self.get_success(
self.handler.upload_room_keys(self.local_user, version, room_keys) self.handler.upload_room_keys(self.local_user, version, room_keys)
) )
res = yield defer.ensureDeferred( res = self.get_success(self.handler.get_room_keys(self.local_user, version))
self.handler.get_room_keys(self.local_user, version)
)
self.assertDictEqual(res, room_keys) self.assertDictEqual(res, room_keys)
# check getting room_keys for a given room # check getting room_keys for a given room
res = yield defer.ensureDeferred( res = self.get_success(
self.handler.get_room_keys( self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org" self.local_user, version, room_id="!abc:matrix.org"
) )
@ -485,18 +424,17 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertDictEqual(res, room_keys) self.assertDictEqual(res, room_keys)
# check getting room_keys for a given session_id # check getting room_keys for a given session_id
res = yield defer.ensureDeferred( res = self.get_success(
self.handler.get_room_keys( self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
) )
) )
self.assertDictEqual(res, room_keys) self.assertDictEqual(res, room_keys)
@defer.inlineCallbacks
def test_upload_room_keys_merge(self): def test_upload_room_keys_merge(self):
"""Check that we can upload a new room_key for an existing session and """Check that we can upload a new room_key for an existing session and
have it correctly merged""" have it correctly merged"""
version = yield defer.ensureDeferred( version = self.get_success(
self.handler.create_version( self.handler.create_version(
self.local_user, self.local_user,
{ {
@ -507,12 +445,12 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
) )
self.assertEqual(version, "1") self.assertEqual(version, "1")
yield defer.ensureDeferred( self.get_success(
self.handler.upload_room_keys(self.local_user, version, room_keys) self.handler.upload_room_keys(self.local_user, version, room_keys)
) )
# get the etag to compare to future versions # get the etag to compare to future versions
res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) res = self.get_success(self.handler.get_version_info(self.local_user))
backup_etag = res["etag"] backup_etag = res["etag"]
self.assertEqual(res["count"], 1) self.assertEqual(res["count"], 1)
@ -522,37 +460,33 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# test that increasing the message_index doesn't replace the existing session # test that increasing the message_index doesn't replace the existing session
new_room_key["first_message_index"] = 2 new_room_key["first_message_index"] = 2
new_room_key["session_data"] = "new" new_room_key["session_data"] = "new"
yield defer.ensureDeferred( self.get_success(
self.handler.upload_room_keys(self.local_user, version, new_room_keys) self.handler.upload_room_keys(self.local_user, version, new_room_keys)
) )
res = yield defer.ensureDeferred( res = self.get_success(self.handler.get_room_keys(self.local_user, version))
self.handler.get_room_keys(self.local_user, version)
)
self.assertEqual( self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
"SSBBTSBBIEZJU0gK", "SSBBTSBBIEZJU0gK",
) )
# the etag should be the same since the session did not change # the etag should be the same since the session did not change
res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) res = self.get_success(self.handler.get_version_info(self.local_user))
self.assertEqual(res["etag"], backup_etag) self.assertEqual(res["etag"], backup_etag)
# test that marking the session as verified however /does/ replace it # test that marking the session as verified however /does/ replace it
new_room_key["is_verified"] = True new_room_key["is_verified"] = True
yield defer.ensureDeferred( self.get_success(
self.handler.upload_room_keys(self.local_user, version, new_room_keys) self.handler.upload_room_keys(self.local_user, version, new_room_keys)
) )
res = yield defer.ensureDeferred( res = self.get_success(self.handler.get_room_keys(self.local_user, version))
self.handler.get_room_keys(self.local_user, version)
)
self.assertEqual( self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
) )
# the etag should NOT be equal now, since the key changed # the etag should NOT be equal now, since the key changed
res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) res = self.get_success(self.handler.get_version_info(self.local_user))
self.assertNotEqual(res["etag"], backup_etag) self.assertNotEqual(res["etag"], backup_etag)
backup_etag = res["etag"] backup_etag = res["etag"]
@ -560,28 +494,25 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# with a lower forwarding count # with a lower forwarding count
new_room_key["forwarded_count"] = 2 new_room_key["forwarded_count"] = 2
new_room_key["session_data"] = "other" new_room_key["session_data"] = "other"
yield defer.ensureDeferred( self.get_success(
self.handler.upload_room_keys(self.local_user, version, new_room_keys) self.handler.upload_room_keys(self.local_user, version, new_room_keys)
) )
res = yield defer.ensureDeferred( res = self.get_success(self.handler.get_room_keys(self.local_user, version))
self.handler.get_room_keys(self.local_user, version)
)
self.assertEqual( self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new" res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
) )
# the etag should be the same since the session did not change # the etag should be the same since the session did not change
res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user)) res = self.get_success(self.handler.get_version_info(self.local_user))
self.assertEqual(res["etag"], backup_etag) self.assertEqual(res["etag"], backup_etag)
# TODO: check edge cases as well as the common variations here # TODO: check edge cases as well as the common variations here
@defer.inlineCallbacks
def test_delete_room_keys(self): def test_delete_room_keys(self):
"""Check that we can insert and delete keys for a session """Check that we can insert and delete keys for a session
""" """
version = yield defer.ensureDeferred( version = self.get_success(
self.handler.create_version( self.handler.create_version(
self.local_user, self.local_user,
{ {
@ -593,13 +524,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertEqual(version, "1") self.assertEqual(version, "1")
# check for bulk-delete # check for bulk-delete
yield defer.ensureDeferred( self.get_success(
self.handler.upload_room_keys(self.local_user, version, room_keys) self.handler.upload_room_keys(self.local_user, version, room_keys)
) )
yield defer.ensureDeferred( self.get_success(self.handler.delete_room_keys(self.local_user, version))
self.handler.delete_room_keys(self.local_user, version) res = self.get_success(
)
res = yield defer.ensureDeferred(
self.handler.get_room_keys( self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
) )
@ -607,15 +536,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertDictEqual(res, {"rooms": {}}) self.assertDictEqual(res, {"rooms": {}})
# check for bulk-delete per room # check for bulk-delete per room
yield defer.ensureDeferred( self.get_success(
self.handler.upload_room_keys(self.local_user, version, room_keys) self.handler.upload_room_keys(self.local_user, version, room_keys)
) )
yield defer.ensureDeferred( self.get_success(
self.handler.delete_room_keys( self.handler.delete_room_keys(
self.local_user, version, room_id="!abc:matrix.org" self.local_user, version, room_id="!abc:matrix.org"
) )
) )
res = yield defer.ensureDeferred( res = self.get_success(
self.handler.get_room_keys( self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
) )
@ -623,15 +552,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
self.assertDictEqual(res, {"rooms": {}}) self.assertDictEqual(res, {"rooms": {}})
# check for bulk-delete per session # check for bulk-delete per session
yield defer.ensureDeferred( self.get_success(
self.handler.upload_room_keys(self.local_user, version, room_keys) self.handler.upload_room_keys(self.local_user, version, room_keys)
) )
yield defer.ensureDeferred( self.get_success(
self.handler.delete_room_keys( self.handler.delete_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
) )
) )
res = yield defer.ensureDeferred( res = self.get_success(
self.handler.get_room_keys( self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33" self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
) )

View File

@ -13,25 +13,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from mock import Mock from mock import Mock
from twisted.internet import defer
import synapse.types import synapse.types
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError
from synapse.types import UserID from synapse.types import UserID
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
class ProfileTestCase(unittest.TestCase): class ProfileTestCase(unittest.HomeserverTestCase):
""" Tests profile management. """ """ Tests profile management. """
@defer.inlineCallbacks def make_homeserver(self, reactor, clock):
def setUp(self):
self.mock_federation = Mock() self.mock_federation = Mock()
self.mock_registry = Mock() self.mock_registry = Mock()
@ -42,39 +37,35 @@ class ProfileTestCase(unittest.TestCase):
self.mock_registry.register_query_handler = register_query_handler self.mock_registry.register_query_handler = register_query_handler
hs = yield setup_test_homeserver( hs = self.setup_test_homeserver(
self.addCleanup,
federation_client=self.mock_federation, federation_client=self.mock_federation,
federation_server=Mock(), federation_server=Mock(),
federation_registry=self.mock_registry, federation_registry=self.mock_registry,
) )
return hs
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.frank = UserID.from_string("@1234ABCD:test") self.frank = UserID.from_string("@1234ABCD:test")
self.bob = UserID.from_string("@4567:test") self.bob = UserID.from_string("@4567:test")
self.alice = UserID.from_string("@alice:remote") self.alice = UserID.from_string("@alice:remote")
yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart)) self.get_success(self.store.create_profile(self.frank.localpart))
self.handler = hs.get_profile_handler() self.handler = hs.get_profile_handler()
self.hs = hs
@defer.inlineCallbacks
def test_get_my_name(self): def test_get_my_name(self):
yield defer.ensureDeferred( self.get_success(
self.store.set_profile_displayname(self.frank.localpart, "Frank") self.store.set_profile_displayname(self.frank.localpart, "Frank")
) )
displayname = yield defer.ensureDeferred( displayname = self.get_success(self.handler.get_displayname(self.frank))
self.handler.get_displayname(self.frank)
)
self.assertEquals("Frank", displayname) self.assertEquals("Frank", displayname)
@defer.inlineCallbacks
def test_set_my_name(self): def test_set_my_name(self):
yield defer.ensureDeferred( self.get_success(
self.handler.set_displayname( self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr." self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
) )
@ -82,7 +73,7 @@ class ProfileTestCase(unittest.TestCase):
self.assertEquals( self.assertEquals(
( (
yield defer.ensureDeferred( self.get_success(
self.store.get_profile_displayname(self.frank.localpart) self.store.get_profile_displayname(self.frank.localpart)
) )
), ),
@ -90,7 +81,7 @@ class ProfileTestCase(unittest.TestCase):
) )
# Set displayname again # Set displayname again
yield defer.ensureDeferred( self.get_success(
self.handler.set_displayname( self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank" self.frank, synapse.types.create_requester(self.frank), "Frank"
) )
@ -98,7 +89,7 @@ class ProfileTestCase(unittest.TestCase):
self.assertEquals( self.assertEquals(
( (
yield defer.ensureDeferred( self.get_success(
self.store.get_profile_displayname(self.frank.localpart) self.store.get_profile_displayname(self.frank.localpart)
) )
), ),
@ -106,32 +97,27 @@ class ProfileTestCase(unittest.TestCase):
) )
# Set displayname to an empty string # Set displayname to an empty string
yield defer.ensureDeferred( self.get_success(
self.handler.set_displayname( self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "" self.frank, synapse.types.create_requester(self.frank), ""
) )
) )
self.assertIsNone( self.assertIsNone(
( (self.get_success(self.store.get_profile_displayname(self.frank.localpart)))
yield defer.ensureDeferred(
self.store.get_profile_displayname(self.frank.localpart)
)
)
) )
@defer.inlineCallbacks
def test_set_my_name_if_disabled(self): def test_set_my_name_if_disabled(self):
self.hs.config.enable_set_displayname = False self.hs.config.enable_set_displayname = False
# Setting displayname for the first time is allowed # Setting displayname for the first time is allowed
yield defer.ensureDeferred( self.get_success(
self.store.set_profile_displayname(self.frank.localpart, "Frank") self.store.set_profile_displayname(self.frank.localpart, "Frank")
) )
self.assertEquals( self.assertEquals(
( (
yield defer.ensureDeferred( self.get_success(
self.store.get_profile_displayname(self.frank.localpart) self.store.get_profile_displayname(self.frank.localpart)
) )
), ),
@ -139,33 +125,27 @@ class ProfileTestCase(unittest.TestCase):
) )
# Setting displayname a second time is forbidden # Setting displayname a second time is forbidden
d = defer.ensureDeferred( self.get_failure(
self.handler.set_displayname( self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr." self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
) ),
SynapseError,
) )
yield self.assertFailure(d, SynapseError)
@defer.inlineCallbacks
def test_set_my_name_noauth(self): def test_set_my_name_noauth(self):
d = defer.ensureDeferred( self.get_failure(
self.handler.set_displayname( self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.bob), "Frank Jr." self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
) ),
AuthError,
) )
yield self.assertFailure(d, AuthError)
@defer.inlineCallbacks
def test_get_other_name(self): def test_get_other_name(self):
self.mock_federation.make_query.return_value = make_awaitable( self.mock_federation.make_query.return_value = make_awaitable(
{"displayname": "Alice"} {"displayname": "Alice"}
) )
displayname = yield defer.ensureDeferred( displayname = self.get_success(self.handler.get_displayname(self.alice))
self.handler.get_displayname(self.alice)
)
self.assertEquals(displayname, "Alice") self.assertEquals(displayname, "Alice")
self.mock_federation.make_query.assert_called_with( self.mock_federation.make_query.assert_called_with(
@ -175,14 +155,11 @@ class ProfileTestCase(unittest.TestCase):
ignore_backoff=True, ignore_backoff=True,
) )
@defer.inlineCallbacks
def test_incoming_fed_query(self): def test_incoming_fed_query(self):
yield defer.ensureDeferred(self.store.create_profile("caroline")) self.get_success(self.store.create_profile("caroline"))
yield defer.ensureDeferred( self.get_success(self.store.set_profile_displayname("caroline", "Caroline"))
self.store.set_profile_displayname("caroline", "Caroline")
)
response = yield defer.ensureDeferred( response = self.get_success(
self.query_handlers["profile"]( self.query_handlers["profile"](
{"user_id": "@caroline:test", "field": "displayname"} {"user_id": "@caroline:test", "field": "displayname"}
) )
@ -190,20 +167,18 @@ class ProfileTestCase(unittest.TestCase):
self.assertEquals({"displayname": "Caroline"}, response) self.assertEquals({"displayname": "Caroline"}, response)
@defer.inlineCallbacks
def test_get_my_avatar(self): def test_get_my_avatar(self):
yield defer.ensureDeferred( self.get_success(
self.store.set_profile_avatar_url( self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png" self.frank.localpart, "http://my.server/me.png"
) )
) )
avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank)) avatar_url = self.get_success(self.handler.get_avatar_url(self.frank))
self.assertEquals("http://my.server/me.png", avatar_url) self.assertEquals("http://my.server/me.png", avatar_url)
@defer.inlineCallbacks
def test_set_my_avatar(self): def test_set_my_avatar(self):
yield defer.ensureDeferred( self.get_success(
self.handler.set_avatar_url( self.handler.set_avatar_url(
self.frank, self.frank,
synapse.types.create_requester(self.frank), synapse.types.create_requester(self.frank),
@ -212,16 +187,12 @@ class ProfileTestCase(unittest.TestCase):
) )
self.assertEquals( self.assertEquals(
( (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
yield defer.ensureDeferred(
self.store.get_profile_avatar_url(self.frank.localpart)
)
),
"http://my.server/pic.gif", "http://my.server/pic.gif",
) )
# Set avatar again # Set avatar again
yield defer.ensureDeferred( self.get_success(
self.handler.set_avatar_url( self.handler.set_avatar_url(
self.frank, self.frank,
synapse.types.create_requester(self.frank), synapse.types.create_requester(self.frank),
@ -230,56 +201,42 @@ class ProfileTestCase(unittest.TestCase):
) )
self.assertEquals( self.assertEquals(
( (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
yield defer.ensureDeferred(
self.store.get_profile_avatar_url(self.frank.localpart)
)
),
"http://my.server/me.png", "http://my.server/me.png",
) )
# Set avatar to an empty string # Set avatar to an empty string
yield defer.ensureDeferred( self.get_success(
self.handler.set_avatar_url( self.handler.set_avatar_url(
self.frank, synapse.types.create_requester(self.frank), "", self.frank, synapse.types.create_requester(self.frank), "",
) )
) )
self.assertIsNone( self.assertIsNone(
( (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
yield defer.ensureDeferred(
self.store.get_profile_avatar_url(self.frank.localpart)
)
),
) )
@defer.inlineCallbacks
def test_set_my_avatar_if_disabled(self): def test_set_my_avatar_if_disabled(self):
self.hs.config.enable_set_avatar_url = False self.hs.config.enable_set_avatar_url = False
# Setting displayname for the first time is allowed # Setting displayname for the first time is allowed
yield defer.ensureDeferred( self.get_success(
self.store.set_profile_avatar_url( self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png" self.frank.localpart, "http://my.server/me.png"
) )
) )
self.assertEquals( self.assertEquals(
( (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
yield defer.ensureDeferred(
self.store.get_profile_avatar_url(self.frank.localpart)
)
),
"http://my.server/me.png", "http://my.server/me.png",
) )
# Set avatar a second time is forbidden # Set avatar a second time is forbidden
d = defer.ensureDeferred( self.get_failure(
self.handler.set_avatar_url( self.handler.set_avatar_url(
self.frank, self.frank,
synapse.types.create_requester(self.frank), synapse.types.create_requester(self.frank),
"http://my.server/pic.gif", "http://my.server/pic.gif",
) ),
SynapseError,
) )
yield self.assertFailure(d, SynapseError)

View File

@ -18,8 +18,6 @@
from mock import Mock from mock import Mock
from twisted.internet import defer
from synapse.rest.client.v1 import room from synapse.rest.client.v1 import room
from synapse.types import UserID from synapse.types import UserID
@ -60,32 +58,6 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
hs.get_datastore().insert_client_ip = _insert_client_ip hs.get_datastore().insert_client_ip = _insert_client_ip
def get_room_members(room_id):
if room_id == self.room_id:
return defer.succeed([self.user])
else:
return defer.succeed([])
@defer.inlineCallbacks
def fetch_room_distributions_into(
room_id, localusers=None, remotedomains=None, ignore_user=None
):
members = yield get_room_members(room_id)
for member in members:
if ignore_user is not None and member == ignore_user:
continue
if hs.is_mine(member):
if localusers is not None:
localusers.add(member)
else:
if remotedomains is not None:
remotedomains.add(member.domain)
hs.get_room_member_handler().fetch_room_distributions_into = (
fetch_room_distributions_into
)
return hs return hs
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):