Merge pull request #7157 from matrix-org/rev.outbound_device_pokes_tests

Add tests for outbound device pokes
This commit is contained in:
Richard van der Hoff 2020-03-30 13:59:07 +01:00 committed by GitHub
commit 6486c96b65
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 318 additions and 22 deletions

1
changelog.d/7157.misc Normal file
View File

@ -0,0 +1 @@
Add tests for outbound device pokes.

View File

@ -27,8 +27,8 @@ class FrontendProxyTests(HomeserverTestCase):
return hs return hs
def default_config(self, name="test"): def default_config(self):
c = super().default_config(name) c = super().default_config()
c["worker_app"] = "synapse.app.frontend_proxy" c["worker_app"] = "synapse.app.frontend_proxy"
return c return c

View File

@ -29,8 +29,8 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
) )
return hs return hs
def default_config(self, name="test"): def default_config(self):
conf = super().default_config(name) conf = super().default_config()
# we're using FederationReaderServer, which uses a SlavedStore, so we # we're using FederationReaderServer, which uses a SlavedStore, so we
# have to tell the FederationHandler not to try to access stuff that is only # have to tell the FederationHandler not to try to access stuff that is only
# in the primary store. # in the primary store.

View File

@ -33,8 +33,8 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
login.register_servlets, login.register_servlets,
] ]
def default_config(self, name="test"): def default_config(self):
config = super().default_config(name=name) config = super().default_config()
config["limit_remote_rooms"] = {"enabled": True, "complexity": 0.05} config["limit_remote_rooms"] = {"enabled": True, "complexity": 0.05}
return config return config

View File

@ -12,19 +12,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional
from mock import Mock from mock import Mock
from signedjson import key, sign
from signedjson.types import BaseKey, SigningKey
from twisted.internet import defer from twisted.internet import defer
from synapse.types import ReadReceipt from synapse.rest import admin
from synapse.rest.client.v1 import login
from synapse.types import JsonDict, ReadReceipt
from tests.unittest import HomeserverTestCase, override_config from tests.unittest import HomeserverTestCase, override_config
class FederationSenderTestCases(HomeserverTestCase): class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
return super(FederationSenderTestCases, self).setup_test_homeserver( return self.setup_test_homeserver(
state_handler=Mock(spec=["get_current_hosts_in_room"]), state_handler=Mock(spec=["get_current_hosts_in_room"]),
federation_transport_client=Mock(spec=["send_transaction"]), federation_transport_client=Mock(spec=["send_transaction"]),
) )
@ -147,3 +153,294 @@ class FederationSenderTestCases(HomeserverTestCase):
} }
], ],
) )
class FederationSenderDevicesTestCases(HomeserverTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
]
def make_homeserver(self, reactor, clock):
return self.setup_test_homeserver(
state_handler=Mock(spec=["get_current_hosts_in_room"]),
federation_transport_client=Mock(spec=["send_transaction"]),
)
def default_config(self):
c = super().default_config()
c["send_federation"] = True
return c
def prepare(self, reactor, clock, hs):
# stub out get_current_hosts_in_room
mock_state_handler = hs.get_state_handler()
mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
# stub out get_users_who_share_room_with_user so that it claims that
# `@user2:host2` is in the room
def get_users_who_share_room_with_user(user_id):
return defer.succeed({"@user2:host2"})
hs.get_datastore().get_users_who_share_room_with_user = (
get_users_who_share_room_with_user
)
# whenever send_transaction is called, record the edu data
self.edus = []
self.hs.get_federation_transport_client().send_transaction.side_effect = (
self.record_transaction
)
def record_transaction(self, txn, json_cb):
data = json_cb()
self.edus.extend(data["edus"])
return defer.succeed({})
def test_send_device_updates(self):
"""Basic case: each device update should result in an EDU"""
# create a device
u1 = self.register_user("user", "pass")
self.login(u1, "pass", device_id="D1")
# expect one edu
self.assertEqual(len(self.edus), 1)
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
# a second call should produce no new device EDUs
self.hs.get_federation_sender().send_device_messages("host2")
self.pump()
self.assertEqual(self.edus, [])
# a second device
self.login("user", "pass", device_id="D2")
self.assertEqual(len(self.edus), 1)
self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)
def test_upload_signatures(self):
"""Uploading signatures on some devices should produce updates for that user"""
e2e_handler = self.hs.get_e2e_keys_handler()
# register two devices
u1 = self.register_user("user", "pass")
self.login(u1, "pass", device_id="D1")
self.login(u1, "pass", device_id="D2")
# expect two edus
self.assertEqual(len(self.edus), 2)
stream_id = None
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id)
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)
# upload signing keys for each device
device1_signing_key = self.generate_and_upload_device_signing_key(u1, "D1")
device2_signing_key = self.generate_and_upload_device_signing_key(u1, "D2")
# expect two more edus
self.assertEqual(len(self.edus), 2)
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id)
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)
# upload master key and self-signing key
master_signing_key = generate_self_id_key()
master_key = {
"user_id": u1,
"usage": ["master"],
"keys": {key_id(master_signing_key): encode_pubkey(master_signing_key)},
}
# private key: HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8
selfsigning_signing_key = generate_self_id_key()
selfsigning_key = {
"user_id": u1,
"usage": ["self_signing"],
"keys": {
key_id(selfsigning_signing_key): encode_pubkey(selfsigning_signing_key)
},
}
sign.sign_json(selfsigning_key, u1, master_signing_key)
cross_signing_keys = {
"master_key": master_key,
"self_signing_key": selfsigning_key,
}
self.get_success(
e2e_handler.upload_signing_keys_for_user(u1, cross_signing_keys)
)
# expect signing key update edu
self.assertEqual(len(self.edus), 1)
self.assertEqual(self.edus.pop(0)["edu_type"], "org.matrix.signing_key_update")
# sign the devices
d1_json = build_device_dict(u1, "D1", device1_signing_key)
sign.sign_json(d1_json, u1, selfsigning_signing_key)
d2_json = build_device_dict(u1, "D2", device2_signing_key)
sign.sign_json(d2_json, u1, selfsigning_signing_key)
ret = self.get_success(
e2e_handler.upload_signatures_for_device_keys(
u1, {u1: {"D1": d1_json, "D2": d2_json}},
)
)
self.assertEqual(ret["failures"], {})
# expect two edus, in one or two transactions. We don't know what order the
# devices will be updated.
self.assertEqual(len(self.edus), 2)
stream_id = None # FIXME: there is a discontinuity in the stream IDs: see #7142
for edu in self.edus:
self.assertEqual(edu["edu_type"], "m.device_list_update")
c = edu["content"]
if stream_id is not None:
self.assertEqual(c["prev_id"], [stream_id])
stream_id = c["stream_id"]
devices = {edu["content"]["device_id"] for edu in self.edus}
self.assertEqual({"D1", "D2"}, devices)
def test_delete_devices(self):
"""If devices are deleted, that should result in EDUs too"""
# create devices
u1 = self.register_user("user", "pass")
self.login("user", "pass", device_id="D1")
self.login("user", "pass", device_id="D2")
self.login("user", "pass", device_id="D3")
# expect three edus
self.assertEqual(len(self.edus), 3)
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D3", stream_id)
# delete them again
self.get_success(
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
)
# expect three edus, in an unknown order
self.assertEqual(len(self.edus), 3)
for edu in self.edus:
self.assertEqual(edu["edu_type"], "m.device_list_update")
c = edu["content"]
self.assertGreaterEqual(
c.items(),
{"user_id": u1, "prev_id": [stream_id], "deleted": True}.items(),
)
stream_id = c["stream_id"]
devices = {edu["content"]["device_id"] for edu in self.edus}
self.assertEqual({"D1", "D2", "D3"}, devices)
def test_unreachable_server(self):
"""If the destination server is unreachable, all the updates should get sent on
recovery
"""
mock_send_txn = self.hs.get_federation_transport_client().send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail("fail")
# create devices
u1 = self.register_user("user", "pass")
self.login("user", "pass", device_id="D1")
self.login("user", "pass", device_id="D2")
self.login("user", "pass", device_id="D3")
# delete them again
self.get_success(
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
)
self.assertGreaterEqual(mock_send_txn.call_count, 4)
# recover the server
mock_send_txn.side_effect = self.record_transaction
self.hs.get_federation_sender().send_device_messages("host2")
self.pump()
# for each device, there should be a single update
self.assertEqual(len(self.edus), 3)
stream_id = None
for edu in self.edus:
self.assertEqual(edu["edu_type"], "m.device_list_update")
c = edu["content"]
self.assertEqual(c["prev_id"], [stream_id] if stream_id is not None else [])
stream_id = c["stream_id"]
devices = {edu["content"]["device_id"] for edu in self.edus}
self.assertEqual({"D1", "D2", "D3"}, devices)
def check_device_update_edu(
self,
edu: JsonDict,
user_id: str,
device_id: str,
prev_stream_id: Optional[int],
) -> int:
"""Check that the given EDU is an update for the given device
Returns the stream_id.
"""
self.assertEqual(edu["edu_type"], "m.device_list_update")
content = edu["content"]
expected = {
"user_id": user_id,
"device_id": device_id,
"prev_id": [prev_stream_id] if prev_stream_id is not None else [],
}
self.assertLessEqual(expected.items(), content.items())
return content["stream_id"]
def check_signing_key_update_txn(self, txn: JsonDict,) -> None:
"""Check that the txn has an EDU with a signing key update.
"""
edus = txn["edus"]
self.assertEqual(len(edus), 1)
def generate_and_upload_device_signing_key(
self, user_id: str, device_id: str
) -> SigningKey:
"""Generate a signing keypair for the given device, and upload it"""
sk = key.generate_signing_key(device_id)
device_dict = build_device_dict(user_id, device_id, sk)
self.get_success(
self.hs.get_e2e_keys_handler().upload_keys_for_user(
user_id, device_id, {"device_keys": device_dict},
)
)
return sk
def generate_self_id_key() -> SigningKey:
"""generate a signing key whose version is its public key
... as used by the cross-signing-keys.
"""
k = key.generate_signing_key("x")
k.version = encode_pubkey(k)
return k
def key_id(k: BaseKey) -> str:
return "%s:%s" % (k.alg, k.version)
def encode_pubkey(sk: SigningKey) -> str:
"""Encode the public key corresponding to the given signing key as base64"""
return key.encode_verify_key_base64(key.get_verify_key(sk))
def build_device_dict(user_id: str, device_id: str, sk: SigningKey):
"""Build a dict representing the given device"""
return {
"user_id": user_id,
"device_id": device_id,
"algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"],
"keys": {
"curve25519:" + device_id: "curve25519+key",
key_id(sk): encode_pubkey(sk),
},
}

View File

@ -34,7 +34,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
""" Tests the RegistrationHandler. """ """ Tests the RegistrationHandler. """
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
hs_config = self.default_config("test") hs_config = self.default_config()
# some of the tests rely on us having a user consent version # some of the tests rely on us having a user consent version
hs_config["user_consent"] = { hs_config["user_consent"] = {

View File

@ -36,8 +36,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
servlets = [register.register_servlets] servlets = [register.register_servlets]
url = b"/_matrix/client/r0/register" url = b"/_matrix/client/r0/register"
def default_config(self, name="test"): def default_config(self):
config = super().default_config(name) config = super().default_config()
config["allow_guest_access"] = True config["allow_guest_access"] = True
return config return config

View File

@ -143,8 +143,8 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
endpoint, to check that the two implementations are compatible. endpoint, to check that the two implementations are compatible.
""" """
def default_config(self, *args, **kwargs): def default_config(self):
config = super().default_config(*args, **kwargs) config = super().default_config()
# replace the signing key with our own # replace the signing key with our own
self.hs_signing_key = signedjson.key.generate_signing_key("kssk") self.hs_signing_key = signedjson.key.generate_signing_key("kssk")

View File

@ -28,7 +28,7 @@ from tests import unittest
class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
hs_config = self.default_config("test") hs_config = self.default_config()
hs_config["server_notices"] = { hs_config["server_notices"] = {
"system_mxid_localpart": "server", "system_mxid_localpart": "server",
"system_mxid_display_name": "test display name", "system_mxid_display_name": "test display name",

View File

@ -28,8 +28,8 @@ from tests import unittest
class TermsTestCase(unittest.HomeserverTestCase): class TermsTestCase(unittest.HomeserverTestCase):
servlets = [register_servlets] servlets = [register_servlets]
def default_config(self, name="test"): def default_config(self):
config = super().default_config(name) config = super().default_config()
config.update( config.update(
{ {
"public_baseurl": "https://example.org/", "public_baseurl": "https://example.org/",

View File

@ -315,14 +315,11 @@ class HomeserverTestCase(TestCase):
return resource return resource
def default_config(self, name="test"): def default_config(self):
""" """
Get a default HomeServer config dict. Get a default HomeServer config dict.
Args:
name (str): The homeserver name/domain.
""" """
config = default_config(name) config = default_config("test")
# apply any additional config which was specified via the override_config # apply any additional config which was specified via the override_config
# decorator. # decorator.
@ -497,6 +494,7 @@ class HomeserverTestCase(TestCase):
"password": password, "password": password,
"admin": admin, "admin": admin,
"mac": want_mac, "mac": want_mac,
"inhibit_login": True,
} }
) )
request, channel = self.make_request( request, channel = self.make_request(