constellation/keyservice/kms/azure/hsm_test.go

221 lines
5.4 KiB
Go
Raw Normal View History

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package azure
import (
"context"
"errors"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys"
"github.com/edgelesssys/constellation/v2/keyservice/internal/storage"
"github.com/edgelesssys/constellation/v2/keyservice/kms"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// stubHSMClient is a test double that stands in for the HSM key client used
// by HSMClient. Its methods perform no real work: they hand back the errors
// and results preconfigured in the fields below.
type stubHSMClient struct {
	// keyCreated is flipped to true by both CreateKey and ImportKey.
	keyCreated bool

	// keyID is reported by GetKey as the KID of the returned key.
	keyID string
	// unwrapKeyResult is reported by UnwrapKey as the operation result.
	unwrapKeyResult []byte

	// Errors returned by CreateKey, ImportKey, GetKey, UnwrapKey and WrapKey,
	// in that order.
	createOCTKeyErr error
	importKeyErr    error
	getKeyErr       error
	unwrapKeyErr    error
	wrapKeyErr      error
}
// CreateKey marks the stub as having created a key and returns an empty
// response together with the configured createOCTKeyErr. All arguments are
// ignored.
func (s *stubHSMClient) CreateKey(_ context.Context, _ string, _ azkeys.CreateKeyParameters, _ *azkeys.CreateKeyOptions) (azkeys.CreateKeyResponse, error) {
	var resp azkeys.CreateKeyResponse
	s.keyCreated = true
	return resp, s.createOCTKeyErr
}
// ImportKey marks the stub as having created a key and returns an empty
// response together with the configured importKeyErr. All arguments are
// ignored.
func (s *stubHSMClient) ImportKey(_ context.Context, _ string, _ azkeys.ImportKeyParameters, _ *azkeys.ImportKeyOptions) (azkeys.ImportKeyResponse, error) {
	var resp azkeys.ImportKeyResponse
	s.keyCreated = true
	return resp, s.importKeyErr
}
// GetKey returns a key bundle whose key carries the configured keyID as its
// KID, together with the configured getKeyErr. All arguments are ignored.
func (s *stubHSMClient) GetKey(_ context.Context, _ string, _ string, _ *azkeys.GetKeyOptions) (azkeys.GetKeyResponse, error) {
	key := &azkeys.JSONWebKey{
		KID: to.Ptr(azkeys.ID(s.keyID)),
	}
	resp := azkeys.GetKeyResponse{
		KeyBundle: azkeys.KeyBundle{Key: key},
	}
	return resp, s.getKeyErr
}
// UnwrapKey returns the configured unwrapKeyResult as the operation result,
// together with the configured unwrapKeyErr. All arguments are ignored.
func (s *stubHSMClient) UnwrapKey(_ context.Context, _ string, _ string, _ azkeys.KeyOperationsParameters, _ *azkeys.UnwrapKeyOptions) (azkeys.UnwrapKeyResponse, error) {
	result := azkeys.KeyOperationResult{Result: s.unwrapKeyResult}
	return azkeys.UnwrapKeyResponse{KeyOperationResult: result}, s.unwrapKeyErr
}
// WrapKey returns an empty response together with the configured wrapKeyErr.
// All arguments are ignored.
func (s *stubHSMClient) WrapKey(_ context.Context, _ string, _ string, _ azkeys.KeyOperationsParameters, _ *azkeys.WrapKeyOptions) (azkeys.WrapKeyResponse, error) {
	var resp azkeys.WrapKeyResponse
	return resp, s.wrapKeyErr
}
// stubStorage is a test double used where the tests need a kms.Storage whose
// Get and Put outcomes are fixed up front.
type stubStorage struct {
	// key is the value handed back by Get.
	key []byte
	// getErr and putErr are returned by Get and Put respectively.
	getErr, putErr error
}
// Get ignores its arguments and returns the configured key and getErr.
func (s *stubStorage) Get(_ context.Context, _ string) ([]byte, error) {
	return s.key, s.getErr
}
// Put ignores its arguments, stores nothing, and returns the configured putErr.
func (s *stubStorage) Put(_ context.Context, _ string, _ []byte) error {
	return s.putErr
}
// TestHSMCreateKEK runs HSMClient.CreateKEK against a stubbed HSM client,
// once without key material and once with key material to import, and for
// each variant with and without an error injected into the stub.
// On success the stub must report that a key was created.
func TestHSMCreateKEK(t *testing.T) {
	errStub := errors.New("error")
	keyMaterial := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")

	type testCase struct {
		client    *stubHSMClient
		importKey []byte
		wantErr   bool
	}

	tests := map[string]testCase{
		"create new kek successful": {
			client: &stubHSMClient{},
		},
		"CreateOCTKey fails": {
			client:  &stubHSMClient{createOCTKeyErr: errStub},
			wantErr: true,
		},
		"import key successful": {
			client:    &stubHSMClient{},
			importKey: keyMaterial,
		},
		"ImportKey fails": {
			client:    &stubHSMClient{importKeyErr: errStub},
			importKey: keyMaterial,
			wantErr:   true,
		},
	}

	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)

			hsm := HSMClient{
				client:  tc.client,
				storage: storage.NewMemMapStorage(),
			}

			err := hsm.CreateKEK(context.Background(), "test-key", tc.importKey)
			if tc.wantErr {
				assert.Error(err)
				return
			}

			assert.NoError(err)
			assert.True(tc.client.keyCreated)
		})
	}
}
// TestHSMGetNewDEK runs HSMClient.GetDEK for a 32-byte DEK against a stubbed
// HSM client, using either a freshly created in-memory storage or a stub
// storage with injected errors. On success the returned DEK must be 32 bytes
// long.
func TestHSMGetNewDEK(t *testing.T) {
	errStub := errors.New("error")
	stubKeyID := "https://test.managedhsm.azure.net/keys/test-key/test-key-version"

	type testCase struct {
		client  hsmClientAPI
		storage kms.Storage
		wantErr bool
	}

	tests := map[string]testCase{
		"successful": {
			client:  &stubHSMClient{keyID: stubKeyID},
			storage: storage.NewMemMapStorage(),
		},
		"Get from storage fails": {
			client:  &stubHSMClient{keyID: stubKeyID},
			storage: &stubStorage{getErr: errStub},
			wantErr: true,
		},
		"Put to storage fails": {
			client: &stubHSMClient{keyID: stubKeyID},
			// Get reports the DEK as unset while the subsequent Put fails.
			storage: &stubStorage{
				getErr: storage.ErrDEKUnset,
				putErr: errStub,
			},
			wantErr: true,
		},
		"WrapKey fails": {
			client:  &stubHSMClient{keyID: stubKeyID, wrapKeyErr: errStub},
			storage: storage.NewMemMapStorage(),
			wantErr: true,
		},
	}

	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)

			hsm := HSMClient{
				client:  tc.client,
				storage: tc.storage,
			}

			dek, err := hsm.GetDEK(context.Background(), "test-key", "volume-01", 32)
			if tc.wantErr {
				assert.Error(err)
				return
			}

			assert.Len(dek, 32)
			assert.NoError(err)
		})
	}
}
// TestHSMGetExistingDEK runs HSMClient.GetDEK for a key ID that has already
// been written to an in-memory storage, with a stubbed HSM client whose
// UnwrapKey either yields the test key or fails. On success the returned DEK
// must have the same length as the test key.
func TestHSMGetExistingDEK(t *testing.T) {
	errStub := errors.New("error")
	stubKeyID := "https://test.managedhsm.azure.net/keys/test-key/test-key-version"
	testKey := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")

	type testCase struct {
		client  hsmClientAPI
		wantErr bool
	}

	tests := map[string]testCase{
		"successful": {
			client: &stubHSMClient{keyID: stubKeyID, unwrapKeyResult: testKey},
		},
		"UnwrapKey fails": {
			client:  &stubHSMClient{keyID: stubKeyID, unwrapKeyErr: errStub},
			wantErr: true,
		},
	}

	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			require := require.New(t)

			// Seed the storage so that an entry already exists for the volume.
			volumeID := "volume-01"
			store := storage.NewMemMapStorage()
			require.NoError(store.Put(context.Background(), volumeID, testKey))

			hsm := HSMClient{
				client:  tc.client,
				storage: store,
			}

			dek, err := hsm.GetDEK(context.Background(), "test-key", volumeID, len(testKey))
			if tc.wantErr {
				assert.Error(err)
				return
			}

			assert.Len(dek, len(testKey))
			assert.NoError(err)
		})
	}
}