Add Azure KMS unit tests

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2022-03-24 18:00:17 +01:00 committed by Daniel Weiße
parent 19bb65338d
commit fefff8ee92
4 changed files with 494 additions and 9 deletions

View File

@ -31,9 +31,14 @@ const (
// VaultSuffix is the suffix added to a Vault name to create a valid Vault URL. // VaultSuffix is the suffix added to a Vault name to create a valid Vault URL.
type VaultSuffix string type VaultSuffix string
type kmsClientAPI interface {
SetSecret(ctx context.Context, secretName string, value string, options *azsecrets.SetSecretOptions) (azsecrets.SetSecretResponse, error)
GetSecret(ctx context.Context, secretName string, options *azsecrets.GetSecretOptions) (azsecrets.GetSecretResponse, error)
}
// KMSClient implements the CloudKMS interface for Azure Key Vault. // KMSClient implements the CloudKMS interface for Azure Key Vault.
type KMSClient struct { type KMSClient struct {
client *azsecrets.Client client kmsClientAPI
storage kms.Storage storage kms.Storage
opts *Opts opts *Opts
} }

153
kms/kms/azure/azure_test.go Normal file
View File

@ -0,0 +1,153 @@
package azure
import (
"context"
"encoding/base64"
"errors"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/keyvault/azsecrets"
"github.com/edgelesssys/constellation/kms/kms"
"github.com/edgelesssys/constellation/kms/storage"
"github.com/stretchr/testify/assert"
)
type stubAzureClient struct {
setSecretCalled bool
setSecretErr error
getSecretErr error
secret []byte
}
func (s *stubAzureClient) SetSecret(ctx context.Context, secretName string, value string, options *azsecrets.SetSecretOptions) (azsecrets.SetSecretResponse, error) {
s.setSecretCalled = true
return azsecrets.SetSecretResponse{}, s.setSecretErr
}
func (s *stubAzureClient) GetSecret(ctx context.Context, secretName string, options *azsecrets.GetSecretOptions) (azsecrets.GetSecretResponse, error) {
return azsecrets.GetSecretResponse{
Secret: azsecrets.Secret{Value: to.StringPtr(base64.StdEncoding.EncodeToString(s.secret))},
}, s.getSecretErr
}
func TestKMSCreateKEK(t *testing.T) {
someErr := errors.New("error")
importKey := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
testCases := map[string]struct {
client *stubAzureClient
importKey []byte
errExpected bool
}{
"create new kek successful": {
client: &stubAzureClient{},
},
"import kek successful": {
client: &stubAzureClient{},
importKey: importKey,
},
"SetSecret fails on new": {
client: &stubAzureClient{setSecretErr: someErr},
errExpected: true,
},
"SetSecret fails on import": {
client: &stubAzureClient{setSecretErr: someErr},
importKey: importKey,
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := &KMSClient{
client: tc.client,
opts: &Opts{},
}
err := client.CreateKEK(context.Background(), "test-key", tc.importKey)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
assert.True(tc.client.setSecretCalled)
}
})
}
}
func TestKMSGetDEK(t *testing.T) {
someErr := errors.New("error")
wrapKey := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
testCases := map[string]struct {
client kmsClientAPI
storage kms.Storage
errExpected bool
}{
"successful for new key": {
client: &stubAzureClient{secret: wrapKey},
storage: storage.NewMemMapStorage(),
},
"successful for existing key": {
// test keys taken from `kms/util/crypto_test.go`
client: &stubAzureClient{secret: []byte{0xD6, 0x8A, 0xED, 0xF5, 0xDB, 0x89, 0x95, 0x66, 0xA9, 0xFF, 0xD9, 0x31, 0x27, 0x4E, 0x30, 0x2D, 0x21, 0xA9, 0x46, 0x21, 0x16, 0x6C, 0x16, 0x17, 0xD1, 0x96, 0x5D, 0xB2, 0xE9, 0x0E, 0x96, 0xD1}},
storage: &stubStorage{key: []byte{0x14, 0x48, 0xC4, 0xEA, 0x4B, 0x4B, 0xCA, 0xE4, 0x5A, 0xD4, 0xCC, 0xE3, 0xF7, 0xDD, 0xD5, 0x78, 0xA5, 0xA9, 0xEF, 0x9A, 0x93, 0x36, 0x09, 0xD6, 0x23, 0x01, 0xF5, 0x5F, 0xE1, 0x20, 0xDD, 0xFC, 0xBC, 0xF3, 0xA9, 0x67, 0x8B, 0x89, 0x54, 0x96}},
},
"Get from storage fails": {
client: &stubAzureClient{},
storage: &stubStorage{getErr: someErr},
errExpected: true,
},
"Put to storage fails": {
client: &stubAzureClient{secret: wrapKey},
storage: &stubStorage{
getErr: storage.ErrDEKUnset,
putErr: someErr,
},
errExpected: true,
},
"GetSecret fails": {
client: &stubAzureClient{getSecretErr: someErr},
storage: storage.NewMemMapStorage(),
errExpected: true,
},
"GetSecret fails with unknown kek": {
client: &stubAzureClient{getSecretErr: errors.New("SecretNotFound")},
storage: storage.NewMemMapStorage(),
errExpected: true,
},
"key wrapping fails": {
client: &stubAzureClient{secret: []byte{0x1}},
storage: storage.NewMemMapStorage(),
errExpected: true,
},
"key unwrapping fails": {
client: &stubAzureClient{secret: wrapKey},
storage: &stubStorage{key: []byte{0x1}},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := KMSClient{
client: tc.client,
storage: tc.storage,
opts: &Opts{},
}
dek, err := client.GetDEK(context.Background(), "test-key", "volume-01", 32)
if tc.errExpected {
assert.Error(err)
} else {
assert.Len(dek, 32)
assert.NoError(err)
}
})
}
}

View File

@ -18,16 +18,28 @@ import (
"github.com/edgelesssys/constellation/kms/storage" "github.com/edgelesssys/constellation/kms/storage"
) )
type hsmClientAPI interface {
CreateOCTKey(ctx context.Context, name string, options *azkeys.CreateOCTKeyOptions) (azkeys.CreateOCTKeyResponse, error)
ImportKey(ctx context.Context, keyName string, key azkeys.JSONWebKey, options *azkeys.ImportKeyOptions) (azkeys.ImportKeyResponse, error)
GetKey(ctx context.Context, keyName string, options *azkeys.GetKeyOptions) (azkeys.GetKeyResponse, error)
}
type cryptoClientAPI interface {
UnwrapKey(ctx context.Context, alg crypto.KeyWrapAlgorithm, encryptedKey []byte, options *crypto.UnwrapKeyOptions) (crypto.UnwrapKeyResponse, error)
WrapKey(ctx context.Context, alg crypto.KeyWrapAlgorithm, key []byte, options *crypto.WrapKeyOptions) (crypto.WrapKeyResponse, error)
}
// Suffix for HSM Vaults. // Suffix for HSM Vaults.
const HSMDefaultCloud VaultSuffix = ".managedhsm.azure.net/" const HSMDefaultCloud VaultSuffix = ".managedhsm.azure.net/"
// HSMClient implements the CloudKMS interface for Azure managed HSM. // HSMClient implements the CloudKMS interface for Azure managed HSM.
type HSMClient struct { type HSMClient struct {
credentials azcore.TokenCredential credentials azcore.TokenCredential
client *azkeys.Client client hsmClientAPI
storage kms.Storage storage kms.Storage
vaultURL string vaultURL string
opts *Opts newCryptoClient func(keyURL string, credential azcore.TokenCredential, options *crypto.ClientOptions) (cryptoClientAPI, error)
opts *crypto.ClientOptions
} }
// NewHSM initializes a KMS client for Azure manged HSM Key Vault. // NewHSM initializes a KMS client for Azure manged HSM Key Vault.
@ -58,7 +70,14 @@ func NewHSM(ctx context.Context, vaultName string, store kms.Storage, opts *Opts
store = storage.NewMemMapStorage() store = storage.NewMemMapStorage()
} }
return &HSMClient{vaultURL: vaultURL, client: client, credentials: cred, storage: store, opts: opts}, nil return &HSMClient{
vaultURL: vaultURL,
client: client,
credentials: cred,
storage: store,
opts: (*crypto.ClientOptions)(opts.client),
newCryptoClient: cryptoClientFactory,
}, nil
} }
// CreateKEK creates a new Key Encryption Key using Azure managed HSM. // CreateKEK creates a new Key Encryption Key using Azure managed HSM.
@ -125,7 +144,7 @@ func (c *HSMClient) GetDEK(ctx context.Context, kekID string, keyID string, dekS
return nil, fmt.Errorf("unable to detect key version: %w", err) return nil, fmt.Errorf("unable to detect key version: %w", err)
} }
cryptoClient, err := crypto.NewClient(fmt.Sprintf("%skeys/%s/%s", c.vaultURL, kekID, version), c.credentials, (*crypto.ClientOptions)(c.opts.client)) cryptoClient, err := c.newCryptoClient(fmt.Sprintf("%skeys/%s/%s", c.vaultURL, kekID, version), c.credentials, c.opts)
if err != nil { if err != nil {
return nil, fmt.Errorf("creating crypto client for KEK: %s: %w", kekID, err) return nil, fmt.Errorf("creating crypto client for KEK: %s: %w", kekID, err)
} }
@ -144,7 +163,7 @@ func (c *HSMClient) putDEK(ctx context.Context, kekID, keyID string, plainDEK []
if err != nil { if err != nil {
return fmt.Errorf("unable to detect key version: %w", err) return fmt.Errorf("unable to detect key version: %w", err)
} }
cryptoClient, err := crypto.NewClient(fmt.Sprintf("%skeys/%s/%s", c.vaultURL, kekID, version), c.credentials, nil) cryptoClient, err := c.newCryptoClient(fmt.Sprintf("%skeys/%s/%s", c.vaultURL, kekID, version), c.credentials, c.opts)
if err != nil { if err != nil {
return fmt.Errorf("creating crypto client for KEK: %s: %w", kekID, err) return fmt.Errorf("creating crypto client for KEK: %s: %w", kekID, err)
} }
@ -176,3 +195,7 @@ func (c *HSMClient) getKeyVersion(ctx context.Context, kekID string) (string, er
return path[1], nil return path[1], nil
} }
func cryptoClientFactory(keyURL string, credential azcore.TokenCredential, options *crypto.ClientOptions) (cryptoClientAPI, error) {
return crypto.NewClient(keyURL, credential, options)
}

304
kms/kms/azure/hsm_test.go Normal file
View File

@ -0,0 +1,304 @@
package azure
import (
"context"
"errors"
"fmt"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys"
"github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys/crypto"
"github.com/Azure/go-autorest/autorest/to"
"github.com/edgelesssys/constellation/kms/kms"
"github.com/edgelesssys/constellation/kms/storage"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type stubHSMClient struct {
keyCreated bool
createOCTKeyErr error
importKeyErr error
getKeyErr error
keyVersion string
}
func (s *stubHSMClient) CreateOCTKey(ctx context.Context, name string, options *azkeys.CreateOCTKeyOptions) (azkeys.CreateOCTKeyResponse, error) {
s.keyCreated = true
return azkeys.CreateOCTKeyResponse{}, s.createOCTKeyErr
}
func (s *stubHSMClient) ImportKey(ctx context.Context, keyName string, key azkeys.JSONWebKey, options *azkeys.ImportKeyOptions) (azkeys.ImportKeyResponse, error) {
s.keyCreated = true
return azkeys.ImportKeyResponse{}, s.importKeyErr
}
func (s *stubHSMClient) GetKey(ctx context.Context, keyName string, options *azkeys.GetKeyOptions) (azkeys.GetKeyResponse, error) {
return azkeys.GetKeyResponse{
KeyBundle: azkeys.KeyBundle{
Key: &azkeys.JSONWebKey{
ID: to.StringPtr(s.keyVersion),
},
},
}, s.getKeyErr
}
type stubCryptoClient struct {
createErr error
unwrapKeyErr error
unwrapKeyResult []byte
wrapKeyErr error
}
func newStubCryptoClientFactory(stub *stubCryptoClient) func(keyURL string, credential azcore.TokenCredential, options *crypto.ClientOptions) (cryptoClientAPI, error) {
return func(keyURL string, credential azcore.TokenCredential, options *crypto.ClientOptions) (cryptoClientAPI, error) {
return stub, stub.createErr
}
}
func (s *stubCryptoClient) UnwrapKey(ctx context.Context, alg crypto.KeyWrapAlgorithm, encryptedKey []byte, options *crypto.UnwrapKeyOptions) (crypto.UnwrapKeyResponse, error) {
return crypto.UnwrapKeyResponse{
KeyOperationResult: crypto.KeyOperationResult{
Result: s.unwrapKeyResult,
},
}, s.unwrapKeyErr
}
func (s *stubCryptoClient) WrapKey(ctx context.Context, alg crypto.KeyWrapAlgorithm, key []byte, options *crypto.WrapKeyOptions) (crypto.WrapKeyResponse, error) {
return crypto.WrapKeyResponse{}, s.wrapKeyErr
}
type stubStorage struct {
key []byte
getErr error
putErr error
}
func (s *stubStorage) Get(context.Context, string) ([]byte, error) {
return s.key, s.getErr
}
func (s *stubStorage) Put(context.Context, string, []byte) error {
return s.putErr
}
func TestHSMCreateKEK(t *testing.T) {
someErr := errors.New("error")
importKey := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
testCases := map[string]struct {
client *stubHSMClient
importKey []byte
errExpected bool
}{
"create new kek successful": {
client: &stubHSMClient{},
},
"CreateOCTKey fails": {
client: &stubHSMClient{createOCTKeyErr: someErr},
errExpected: true,
},
"import key successful": {
client: &stubHSMClient{},
importKey: importKey,
},
"ImportKey fails": {
client: &stubHSMClient{importKeyErr: someErr},
importKey: importKey,
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := HSMClient{
client: tc.client,
storage: storage.NewMemMapStorage(),
}
err := client.CreateKEK(context.Background(), "test-key", tc.importKey)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
assert.True(tc.client.keyCreated)
}
})
}
}
func TestHSMGetNewDEK(t *testing.T) {
someErr := errors.New("error")
keyVersion := "https://test.managedhsm.azure.net/keys/test-key/test-key-version"
testCases := map[string]struct {
client hsmClientAPI
storage kms.Storage
cryptoClient *stubCryptoClient
errExpected bool
}{
"successful": {
client: &stubHSMClient{keyVersion: keyVersion},
cryptoClient: &stubCryptoClient{},
storage: storage.NewMemMapStorage(),
},
"Get from storage fails": {
client: &stubHSMClient{keyVersion: keyVersion},
cryptoClient: &stubCryptoClient{},
storage: &stubStorage{getErr: someErr},
errExpected: true,
},
"Put to storage fails": {
client: &stubHSMClient{keyVersion: keyVersion},
cryptoClient: &stubCryptoClient{},
storage: &stubStorage{
getErr: storage.ErrDEKUnset,
putErr: someErr,
},
errExpected: true,
},
"GetKey fails": {
client: &stubHSMClient{getKeyErr: someErr},
cryptoClient: &stubCryptoClient{},
storage: storage.NewMemMapStorage(),
errExpected: true,
},
"WrapKey fails": {
client: &stubHSMClient{keyVersion: keyVersion},
cryptoClient: &stubCryptoClient{wrapKeyErr: someErr},
storage: storage.NewMemMapStorage(),
errExpected: true,
},
"creating crypto client fails": {
client: &stubHSMClient{keyVersion: keyVersion},
cryptoClient: &stubCryptoClient{createErr: someErr},
storage: storage.NewMemMapStorage(),
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := HSMClient{
client: tc.client,
newCryptoClient: newStubCryptoClientFactory(tc.cryptoClient),
storage: tc.storage,
opts: &crypto.ClientOptions{},
}
dek, err := client.GetDEK(context.Background(), "test-key", "volume-01", 32)
if tc.errExpected {
assert.Error(err)
} else {
assert.Len(dek, 32)
assert.NoError(err)
}
})
}
}
func TestHSMGetExistingDEK(t *testing.T) {
someErr := errors.New("error")
keyVersion := "https://test.managedhsm.azure.net/keys/test-key/test-key-version"
testKey := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
testCases := map[string]struct {
client hsmClientAPI
cryptoClient *stubCryptoClient
errExpected bool
}{
"successful": {
client: &stubHSMClient{keyVersion: keyVersion},
cryptoClient: &stubCryptoClient{unwrapKeyResult: testKey},
},
"GetKey fails": {
client: &stubHSMClient{
keyVersion: keyVersion,
getKeyErr: someErr,
},
cryptoClient: &stubCryptoClient{},
errExpected: true,
},
"UnwrapKey fails": {
client: &stubHSMClient{keyVersion: keyVersion},
cryptoClient: &stubCryptoClient{unwrapKeyErr: someErr},
errExpected: true,
},
"creating crypto client fails": {
client: &stubHSMClient{keyVersion: keyVersion},
cryptoClient: &stubCryptoClient{createErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
keyID := "volume-01"
storage := storage.NewMemMapStorage()
require.NoError(storage.Put(context.Background(), keyID, testKey))
client := HSMClient{
client: tc.client,
newCryptoClient: newStubCryptoClientFactory(tc.cryptoClient),
storage: storage,
opts: &crypto.ClientOptions{},
}
dek, err := client.GetDEK(context.Background(), "test-key", keyID, len(testKey))
if tc.errExpected {
assert.Error(err)
} else {
assert.Len(dek, len(testKey))
assert.NoError(err)
}
})
}
}
func TestGetKeyVersion(t *testing.T) {
testVersion := "test-key-version"
testCases := map[string]struct {
client *stubHSMClient
errExpected bool
}{
"valid key version": {
client: &stubHSMClient{keyVersion: fmt.Sprintf("https://test.managedhsm.azure.net/keys/test-key/%s", testVersion)},
},
"GetKey fails": {
client: &stubHSMClient{getKeyErr: errors.New("error")},
errExpected: true,
},
"key ID is not an URL": {
client: &stubHSMClient{keyVersion: string([]byte{0x0, 0x1, 0x2})},
errExpected: true,
},
"invalid key ID URL": {
client: &stubHSMClient{keyVersion: "https://test.managedhsm.azure.net/keys/test-key/test-key-version/another-version/and-another-one"},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := HSMClient{client: tc.client}
keyVersion, err := client.getKeyVersion(context.Background(), "test")
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(testVersion, keyVersion)
}
})
}
}