mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-09-21 13:34:48 -04:00
Add Azure storage tests
Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
parent
436ade2dc9
commit
ef5c85dad2
4 changed files with 359 additions and 82 deletions
|
@ -6,23 +6,42 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
|
||||
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
|
||||
"github.com/edgelesssys/constellation/kms/config"
|
||||
)
|
||||
|
||||
type azureContainerAPI interface {
|
||||
Create(ctx context.Context, options *azblob.CreateContainerOptions) (azblob.ContainerCreateResponse, error)
|
||||
NewBlockBlobClient(blobName string) azureBlobAPI
|
||||
}
|
||||
|
||||
type azureBlobAPI interface {
|
||||
DownloadBlobToWriterAt(ctx context.Context, offset int64, count int64, writer io.WriterAt, o azblob.HighLevelDownloadFromBlobOptions) error
|
||||
Upload(ctx context.Context, body io.ReadSeekCloser, options *azblob.UploadBlockBlobOptions) (azblob.BlockBlobUploadResponse, error)
|
||||
}
|
||||
|
||||
type wrappedAzureClient struct {
|
||||
azblob.ContainerClient
|
||||
}
|
||||
|
||||
func (c wrappedAzureClient) NewBlockBlobClient(blobName string) azureBlobAPI {
|
||||
return c.ContainerClient.NewBlockBlobClient(blobName)
|
||||
}
|
||||
|
||||
// AzureStorage is an implementation of the Storage interface, storing keys in the Azure Blob Store.
|
||||
type AzureStorage struct {
|
||||
client azblob.ContainerClient
|
||||
opts *AzureOpts
|
||||
newClient func(ctx context.Context, connectionString, containerName string, opts *azblob.ClientOptions) (azureContainerAPI, error)
|
||||
connectionString string
|
||||
containerName string
|
||||
opts *AzureOpts
|
||||
}
|
||||
|
||||
// AzureOpts are additional options to be used when interacting with the Azure API.
|
||||
type AzureOpts struct {
|
||||
download *azblob.DownloadBlobOptions
|
||||
upload *azblob.UploadBlockBlobOptions
|
||||
service *azblob.ClientOptions
|
||||
upload *azblob.UploadBlockBlobOptions
|
||||
service *azblob.ClientOptions
|
||||
}
|
||||
|
||||
// NewAzureStorage initializes a storage client using Azure's Blob Storage: https://azure.microsoft.com/en-us/services/storage/blobs/
|
||||
|
@ -34,28 +53,40 @@ func NewAzureStorage(ctx context.Context, connectionString, containerName string
|
|||
if opts == nil {
|
||||
opts = &AzureOpts{}
|
||||
}
|
||||
service, err := azblob.NewServiceClientFromConnectionString(connectionString, opts.service)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating storage client from connection string: %w", err)
|
||||
|
||||
s := &AzureStorage{
|
||||
newClient: azureContainerClientFactory,
|
||||
connectionString: connectionString,
|
||||
containerName: containerName,
|
||||
opts: opts,
|
||||
}
|
||||
client := service.NewContainerClient(containerName)
|
||||
|
||||
// Try to create a new storage container, continue if it already exists
|
||||
_, err = client.Create(ctx, &azblob.CreateContainerOptions{
|
||||
Metadata: config.StorageTags,
|
||||
})
|
||||
if (err != nil) && !strings.Contains(err.Error(), string(azblob.StorageErrorCodeContainerAlreadyExists)) {
|
||||
return nil, fmt.Errorf("creating storage container: %w", err)
|
||||
if err := s.createContainerOrContinue(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AzureStorage{client: client, opts: opts}, nil
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Get returns a DEK from from Azure Blob Storage by key ID.
|
||||
func (s *AzureStorage) Get(ctx context.Context, keyID string) ([]byte, error) {
|
||||
client := s.client.NewBlockBlobClient(keyID)
|
||||
res, err := client.Download(ctx, s.opts.download)
|
||||
client, err := s.newBlobClient(ctx, keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// the Azure SDK requires an io.WriterAt, the AWS SDK provides a utility function to create one from a byte slice
|
||||
keyBuffer := manager.NewWriteAtBuffer([]byte{})
|
||||
|
||||
opts := azblob.HighLevelDownloadFromBlobOptions{
|
||||
RetryReaderOptionsPerBlock: azblob.RetryReaderOptions{
|
||||
MaxRetryRequests: 5,
|
||||
TreatEarlyCloseAsError: true,
|
||||
},
|
||||
}
|
||||
|
||||
if err := client.DownloadBlobToWriterAt(ctx, 0, 0, keyBuffer, opts); err != nil {
|
||||
var storeErr *azblob.StorageError
|
||||
if errors.As(err, &storeErr) && (storeErr.ErrorCode == azblob.StorageErrorCodeBlobNotFound) {
|
||||
return nil, ErrDEKUnset
|
||||
|
@ -63,36 +94,63 @@ func (s *AzureStorage) Get(ctx context.Context, keyID string) ([]byte, error) {
|
|||
return nil, fmt.Errorf("downloading DEK from storage: %w", err)
|
||||
}
|
||||
|
||||
key := &bytes.Buffer{}
|
||||
reader := res.Body(&azblob.RetryReaderOptions{MaxRetryRequests: 5, TreatEarlyCloseAsError: true})
|
||||
defer reader.Close()
|
||||
_, err = key.ReadFrom(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("downloading DEK from storage: %w", err)
|
||||
}
|
||||
|
||||
return key.Bytes(), nil
|
||||
return keyBuffer.Bytes(), nil
|
||||
}
|
||||
|
||||
// Put saves a DEK to Azure Blob Storage by key ID.
|
||||
func (s *AzureStorage) Put(ctx context.Context, keyID string, encDEK []byte) error {
|
||||
client := s.client.NewBlockBlobClient(keyID)
|
||||
if _, err := client.Upload(ctx, newNopCloser(bytes.NewReader(encDEK)), s.opts.upload); err != nil {
|
||||
client, err := s.newBlobClient(ctx, keyID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := client.Upload(ctx, readSeekNopCloser{bytes.NewReader(encDEK)}, s.opts.upload); err != nil {
|
||||
return fmt.Errorf("uploading DEK to storage: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// nopCloser is a wrapper for io.ReadSeeker implementing the Close method. This is required by the Azure SDK.
|
||||
type nopCloser struct {
|
||||
// createContainerOrContinue creates a new storage container if necessary, or continues if it already exists.
|
||||
func (s *AzureStorage) createContainerOrContinue(ctx context.Context) error {
|
||||
client, err := s.newClient(ctx, s.connectionString, s.containerName, s.opts.service)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var storeErr *azblob.StorageError
|
||||
_, err = client.Create(ctx, &azblob.CreateContainerOptions{
|
||||
Metadata: config.StorageTags,
|
||||
})
|
||||
if (err == nil) || (errors.As(err, &storeErr) && (storeErr.ErrorCode == azblob.StorageErrorCodeContainerAlreadyExists)) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("creating storage container: %w", err)
|
||||
}
|
||||
|
||||
// newBlobClient is a convenience function to create BlockBlobClients.
|
||||
func (s *AzureStorage) newBlobClient(ctx context.Context, blobName string) (azureBlobAPI, error) {
|
||||
c, err := s.newClient(ctx, s.connectionString, s.containerName, s.opts.service)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.NewBlockBlobClient(blobName), nil
|
||||
}
|
||||
|
||||
func azureContainerClientFactory(ctx context.Context, connectionString, containerName string, opts *azblob.ClientOptions) (azureContainerAPI, error) {
|
||||
service, err := azblob.NewServiceClientFromConnectionString(connectionString, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating storage client from connection string: %w", err)
|
||||
}
|
||||
|
||||
return wrappedAzureClient{service.NewContainerClient(containerName)}, nil
|
||||
}
|
||||
|
||||
// readSeekNopCloser is a wrapper for io.ReadSeeker implementing the Close method. This is required by the Azure SDK.
|
||||
type readSeekNopCloser struct {
|
||||
io.ReadSeeker
|
||||
}
|
||||
|
||||
func (n nopCloser) Close() error {
|
||||
func (n readSeekNopCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// newNopCloser returns a ReadSeekCloser with a no-op close method wrapping the provided io.ReadSeeker.
|
||||
func newNopCloser(rs io.ReadSeeker) io.ReadSeekCloser {
|
||||
return nopCloser{rs}
|
||||
}
|
||||
|
|
213
kms/storage/azurestorage_test.go
Normal file
213
kms/storage/azurestorage_test.go
Normal file
|
@ -0,0 +1,213 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type stubAzureContainerAPI struct {
|
||||
newClientErr error
|
||||
createErr error
|
||||
createCalled *bool
|
||||
blockBlobAPI stubAzureBlockBlobAPI
|
||||
}
|
||||
|
||||
func newStubClientFactory(stub stubAzureContainerAPI) func(ctx context.Context, connectionString, containerName string, opts *azblob.ClientOptions) (azureContainerAPI, error) {
|
||||
return func(ctx context.Context, connectionString, containerName string, opts *azblob.ClientOptions) (azureContainerAPI, error) {
|
||||
return stub, stub.newClientErr
|
||||
}
|
||||
}
|
||||
|
||||
func (s stubAzureContainerAPI) Create(ctx context.Context, options *azblob.CreateContainerOptions) (azblob.ContainerCreateResponse, error) {
|
||||
*s.createCalled = true
|
||||
return azblob.ContainerCreateResponse{}, s.createErr
|
||||
}
|
||||
|
||||
func (s stubAzureContainerAPI) NewBlockBlobClient(blobName string) azureBlobAPI {
|
||||
return s.blockBlobAPI
|
||||
}
|
||||
|
||||
type stubAzureBlockBlobAPI struct {
|
||||
downloadBlobToWriterAtErr error
|
||||
downloadBlobToWriterOutput []byte
|
||||
uploadErr error
|
||||
uploadData chan []byte
|
||||
}
|
||||
|
||||
func (s stubAzureBlockBlobAPI) DownloadBlobToWriterAt(ctx context.Context, offset int64, count int64, writer io.WriterAt, o azblob.HighLevelDownloadFromBlobOptions) error {
|
||||
if _, err := writer.WriteAt(s.downloadBlobToWriterOutput, 0); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return s.downloadBlobToWriterAtErr
|
||||
}
|
||||
|
||||
func (s stubAzureBlockBlobAPI) Upload(ctx context.Context, body io.ReadSeekCloser, options *azblob.UploadBlockBlobOptions) (azblob.BlockBlobUploadResponse, error) {
|
||||
res, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
s.uploadData <- res
|
||||
return azblob.BlockBlobUploadResponse{}, s.uploadErr
|
||||
}
|
||||
|
||||
func TestAzureGet(t *testing.T) {
|
||||
someErr := errors.New("error")
|
||||
|
||||
testCases := map[string]struct {
|
||||
client stubAzureContainerAPI
|
||||
unsetError bool
|
||||
errExpected bool
|
||||
}{
|
||||
"success": {
|
||||
client: stubAzureContainerAPI{
|
||||
blockBlobAPI: stubAzureBlockBlobAPI{downloadBlobToWriterOutput: []byte("test-data")},
|
||||
},
|
||||
},
|
||||
"creating client fails": {
|
||||
client: stubAzureContainerAPI{newClientErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
"DownloadBlobToBuffer fails": {
|
||||
client: stubAzureContainerAPI{
|
||||
blockBlobAPI: stubAzureBlockBlobAPI{downloadBlobToWriterAtErr: someErr},
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
"BlobNotFound error": {
|
||||
client: stubAzureContainerAPI{
|
||||
blockBlobAPI: stubAzureBlockBlobAPI{
|
||||
downloadBlobToWriterAtErr: &azblob.StorageError{
|
||||
ErrorCode: azblob.StorageErrorCodeBlobNotFound,
|
||||
},
|
||||
},
|
||||
},
|
||||
unsetError: true,
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
client := &AzureStorage{
|
||||
newClient: newStubClientFactory(tc.client),
|
||||
connectionString: "test",
|
||||
containerName: "test",
|
||||
opts: &AzureOpts{},
|
||||
}
|
||||
|
||||
out, err := client.Get(context.Background(), "test-key")
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
|
||||
if tc.unsetError {
|
||||
assert.ErrorIs(err, ErrDEKUnset)
|
||||
} else {
|
||||
assert.False(errors.Is(err, ErrDEKUnset))
|
||||
}
|
||||
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.client.blockBlobAPI.downloadBlobToWriterOutput, out)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzurePut(t *testing.T) {
|
||||
someErr := errors.New("error")
|
||||
|
||||
testCases := map[string]struct {
|
||||
client stubAzureContainerAPI
|
||||
errExpected bool
|
||||
}{
|
||||
"success": {
|
||||
client: stubAzureContainerAPI{},
|
||||
},
|
||||
"creating client fails": {
|
||||
client: stubAzureContainerAPI{newClientErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
"Upload fails": {
|
||||
client: stubAzureContainerAPI{
|
||||
blockBlobAPI: stubAzureBlockBlobAPI{uploadErr: someErr},
|
||||
},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
testData := []byte{0x1, 0x2, 0x3}
|
||||
tc.client.blockBlobAPI.uploadData = make(chan []byte, len(testData))
|
||||
|
||||
client := &AzureStorage{
|
||||
newClient: newStubClientFactory(tc.client),
|
||||
connectionString: "test",
|
||||
containerName: "test",
|
||||
opts: &AzureOpts{},
|
||||
}
|
||||
|
||||
err := client.Put(context.Background(), "test-key", testData)
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.Equal(testData, <-tc.client.blockBlobAPI.uploadData)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateContainerOrContinue(t *testing.T) {
|
||||
someErr := errors.New("error")
|
||||
testCases := map[string]struct {
|
||||
client stubAzureContainerAPI
|
||||
errExpected bool
|
||||
}{
|
||||
"success": {
|
||||
client: stubAzureContainerAPI{},
|
||||
},
|
||||
"container already exists": {
|
||||
client: stubAzureContainerAPI{createErr: &azblob.StorageError{ErrorCode: azblob.StorageErrorCodeContainerAlreadyExists}},
|
||||
},
|
||||
"creating client fails": {
|
||||
client: stubAzureContainerAPI{newClientErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
"Create fails": {
|
||||
client: stubAzureContainerAPI{createErr: someErr},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
tc.client.createCalled = new(bool)
|
||||
client := &AzureStorage{
|
||||
newClient: newStubClientFactory(tc.client),
|
||||
connectionString: "test",
|
||||
containerName: "test",
|
||||
opts: &AzureOpts{},
|
||||
}
|
||||
|
||||
err := client.createContainerOrContinue(context.Background())
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.True(*tc.client.createCalled)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue