diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index 546050f8e..7bc174070 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -327,20 +327,35 @@ class E2eRoomKeysHandler(object): Returns: A deferred of an empty dict. """ - try: - old_info = yield self.store.get_e2e_room_keys_version_info(user_id, version) - except StoreError as e: - if e.code == 404: - raise NotFoundError("Unknown backup version") - else: - raise - if old_info["algorithm"] != version_info["algorithm"]: + if "version" not in version_info: raise SynapseError( 400, - "Algorithm does not match", + "Missing version in body", + Codes.MISSING_PARAM + ) + if version_info["version"] != version: + raise SynapseError( + 400, + "Version in body does not match", Codes.INVALID_PARAM ) + with (yield self._upload_linearizer.queue(user_id)): + try: + old_info = yield self.store.get_e2e_room_keys_version_info( + user_id, version + ) + except StoreError as e: + if e.code == 404: + raise NotFoundError("Unknown backup version") + else: + raise + if old_info["algorithm"] != version_info["algorithm"]: + raise SynapseError( + 400, + "Algorithm does not match", + Codes.INVALID_PARAM + ) - yield self.store.update_e2e_room_keys_version(user_id, version, version_info) + yield self.store.update_e2e_room_keys_version(user_id, version, version_info) - defer.returnValue({}) + defer.returnValue({}) diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index 1c39d2af1..220a0de30 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -394,7 +394,8 @@ class RoomKeysVersionServlet(RestServlet): "signatures": { "ed25519:something": "hijklmnop" } - } + }, + "version": "42" } HTTP/1.1 200 OK diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index c8994f416..1c49bbbc3 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -125,6 +125,78 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase): "auth_data": "second_version_auth_data", }) + @defer.inlineCallbacks + def test_update_version(self): + """Check that we can update versions. + """ + version = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + self.assertEqual(version, "1") + + res = yield self.handler.update_version(self.local_user, version, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": version + }) + self.assertDictEqual(res, {}) + + # check we can retrieve it as the current version + res = yield self.handler.get_version_info(self.local_user) + self.assertDictEqual(res, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": version + }) + + @defer.inlineCallbacks + def test_update_missing_version(self): + """Check that we get a 404 on updating nonexistent versions + """ + res = None + try: + yield self.handler.update_version(self.local_user, "1", { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": "1" + }) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 404) + + @defer.inlineCallbacks + def test_update_bad_version(self): + """Check that we get a 400 if the version in the body is missing or + doesn't match + """ + version = yield self.handler.create_version(self.local_user, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "first_version_auth_data", + }) + self.assertEqual(version, "1") + + res = None + try: + yield self.handler.update_version(self.local_user, version, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data" + }) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 400) + + res = None + try: + yield self.handler.update_version(self.local_user, version, { + "algorithm": "m.megolm_backup.v1", + "auth_data": "revised_first_version_auth_data", + "version": "incorrect" + }) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 400) + @defer.inlineCallbacks def test_delete_missing_version(self): """Check that we get a 404 on deleting nonexistent versions