constellation/keyservice/internal/storage/azurestorage_test.go
Otto Bittner 90b88e1cf9 kms: rename kms to keyservice
In the light of extending our eKMS support it will be helpful
to have a tighter use of the word "KMS".
KMS should refer to the actual component that manages keys.
The keyservice, also called KMS in the constellation code,
does not manage keys itself. It talks to a KMS backend,
which in turn does the actual key management.
2023-01-16 11:56:34 +01:00

177 lines
4.2 KiB
Go

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package storage
import (
"bytes"
"context"
"errors"
"io"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container"
"github.com/stretchr/testify/assert"
)
// TestAzureGet exercises AzureStorage.Get against a stubbed blob client.
//
// It checks three things:
//   - on success, Get returns exactly the bytes served by the client;
//   - a generic download failure is reported as an error that does NOT match
//     ErrDEKUnset;
//   - a BlobNotFound response error is reported as an error that matches
//     ErrDEKUnset.
func TestAzureGet(t *testing.T) {
	testCases := map[string]struct {
		client     stubAzureBlobAPI
		unsetError bool // when wantErr is set: whether the error must match ErrDEKUnset
		wantErr    bool
	}{
		"success": {
			client: stubAzureBlobAPI{downloadData: []byte{0x1, 0x2, 0x3}},
		},
		// Named after the stubbed client method (DownloadStream) that fails.
		"DownloadStream fails": {
			client:  stubAzureBlobAPI{downloadErr: errors.New("failed")},
			wantErr: true,
		},
		"BlobNotFound error": {
			client:     stubAzureBlobAPI{downloadErr: &azcore.ResponseError{ErrorCode: string(bloberror.BlobNotFound)}},
			unsetError: true,
			wantErr:    true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			store := &AzureStorage{
				client:           &tc.client,
				connectionString: "test",
				containerName:    "test",
				opts:             &AzureOpts{},
			}

			out, err := store.Get(context.Background(), "test-key")
			if tc.wantErr {
				assert.Error(t, err)
				// Use the dedicated ErrorIs/NotErrorIs pair so a failure prints
				// the error chain instead of a bare "should be false".
				if tc.unsetError {
					assert.ErrorIs(t, err, ErrDEKUnset)
				} else {
					assert.NotErrorIs(t, err, ErrDEKUnset)
				}
				return
			}

			assert.NoError(t, err)
			assert.Equal(t, tc.client.downloadData, out)
		})
	}
}
// TestAzurePut exercises AzureStorage.Put against a stubbed blob client.
// On success the bytes recorded by the stub must equal the payload passed to
// Put; when the stub's upload fails, Put must return an error.
func TestAzurePut(t *testing.T) {
	tests := map[string]struct {
		client  stubAzureBlobAPI
		wantErr bool
	}{
		"success": {
			client: stubAzureBlobAPI{},
		},
		"Upload fails": {
			client:  stubAzureBlobAPI{uploadErr: errors.New("failed")},
			wantErr: true,
		},
	}

	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			payload := []byte{0x1, 0x2, 0x3}
			store := &AzureStorage{
				client:           &tc.client,
				connectionString: "test",
				containerName:    "test",
				opts:             &AzureOpts{},
			}

			err := store.Put(context.Background(), "test-key", payload)
			if !tc.wantErr {
				assert.NoError(t, err)
				// The stub keeps whatever it read from the upload reader.
				assert.Equal(t, payload, tc.client.uploadData)
				return
			}

			assert.Error(t, err)
		})
	}
}
// TestCreateContainerOrContinue exercises AzureStorage.createContainerOrContinue
// against a stubbed blob client. A ContainerAlreadyExists response error must
// be tolerated (no error returned), while any other creation failure must be
// reported. In the non-error cases the stub must have been called.
func TestCreateContainerOrContinue(t *testing.T) {
	tests := map[string]struct {
		client  stubAzureBlobAPI
		wantErr bool
	}{
		"success": {
			client: stubAzureBlobAPI{},
		},
		"container already exists": {
			client: stubAzureBlobAPI{createErr: &azcore.ResponseError{ErrorCode: string(bloberror.ContainerAlreadyExists)}},
		},
		"CreateContainer fails": {
			client:  stubAzureBlobAPI{createErr: errors.New("failed")},
			wantErr: true,
		},
	}

	for name, tc := range tests {
		t.Run(name, func(t *testing.T) {
			store := &AzureStorage{
				client:           &tc.client,
				connectionString: "test",
				containerName:    "test",
				opts:             &AzureOpts{},
			}

			err := store.createContainerOrContinue(context.Background())
			if tc.wantErr {
				assert.Error(t, err)
				return
			}

			assert.NoError(t, err)
			assert.True(t, tc.client.createCalled)
		})
	}
}
// stubAzureBlobAPI is a test double for the blob client used by AzureStorage.
// Its fields either configure what a stubbed method returns or record what
// the method observed, so tests can inspect them afterwards.
type stubAzureBlobAPI struct {
	// CreateContainer: error to return, and whether the method was invoked.
	createErr    error
	createCalled bool

	// DownloadStream: error to return, and the bytes served as response body.
	downloadErr  error
	downloadData []byte

	// UploadStream: error to return, and the bytes read from the upload reader.
	uploadErr  error
	uploadData []byte
}
// CreateContainer marks the stub as called and returns an empty response
// together with the configured createErr. All arguments are ignored.
func (s *stubAzureBlobAPI) CreateContainer(context.Context, string, *container.CreateOptions) (azblob.CreateContainerResponse, error) {
	s.createCalled = true

	var resp azblob.CreateContainerResponse
	return resp, s.createErr
}
// DownloadStream returns a response whose Body serves the configured
// downloadData, together with the configured downloadErr. The body is
// populated whether or not downloadErr is nil. All arguments are ignored.
func (s *stubAzureBlobAPI) DownloadStream(context.Context, string, string, *blob.DownloadStreamOptions) (blob.DownloadStreamResponse, error) {
	var resp blob.DownloadStreamResponse
	resp.Body = io.NopCloser(bytes.NewReader(s.downloadData))

	return resp, s.downloadErr
}
// UploadStream reads data to the end and stores the bytes in uploadData so
// tests can inspect what was uploaded, then returns an empty response with
// the configured uploadErr. If reading data fails, the read error is returned
// instead and uploadData is left unchanged.
func (s *stubAzureBlobAPI) UploadStream(_ context.Context, _, _ string, data io.Reader, _ *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error) {
	var resp azblob.UploadStreamResponse

	payload, err := io.ReadAll(data)
	if err != nil {
		return resp, err
	}

	s.uploadData = payload
	return resp, s.uploadErr
}