/* Copyright (c) Edgeless Systems GmbH SPDX-License-Identifier: AGPL-3.0-only */ package storage import ( "bytes" "context" "errors" "io" "testing" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container" "github.com/stretchr/testify/assert" ) func TestAzureGet(t *testing.T) { testCases := map[string]struct { client stubAzureBlobAPI unsetError bool wantErr bool }{ "success": { client: stubAzureBlobAPI{downloadData: []byte{0x1, 0x2, 0x3}}, }, "DownloadBuffer fails": { client: stubAzureBlobAPI{downloadErr: errors.New("failed")}, wantErr: true, }, "BlobNotFound error": { client: stubAzureBlobAPI{downloadErr: &azcore.ResponseError{ErrorCode: string(bloberror.BlobNotFound)}}, unsetError: true, wantErr: true, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { assert := assert.New(t) client := &AzureStorage{ client: &tc.client, connectionString: "test", containerName: "test", opts: &AzureOpts{}, } out, err := client.Get(context.Background(), "test-key") if tc.wantErr { assert.Error(err) if tc.unsetError { assert.ErrorIs(err, ErrDEKUnset) } else { assert.False(errors.Is(err, ErrDEKUnset)) } return } assert.NoError(err) assert.Equal(tc.client.downloadData, out) }) } } func TestAzurePut(t *testing.T) { testCases := map[string]struct { client stubAzureBlobAPI wantErr bool }{ "success": { client: stubAzureBlobAPI{}, }, "Upload fails": { client: stubAzureBlobAPI{uploadErr: errors.New("failed")}, wantErr: true, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { assert := assert.New(t) testData := []byte{0x1, 0x2, 0x3} client := &AzureStorage{ client: &tc.client, connectionString: "test", containerName: "test", opts: &AzureOpts{}, } err := client.Put(context.Background(), "test-key", testData) if tc.wantErr { assert.Error(err) return } assert.NoError(err) assert.Equal(testData, tc.client.uploadData) }) } } func TestCreateContainerOrContinue(t *testing.T) { testCases := map[string]struct { client stubAzureBlobAPI wantErr bool }{ "success": { client: stubAzureBlobAPI{}, }, "container already exists": { client: stubAzureBlobAPI{createErr: &azcore.ResponseError{ErrorCode: string(bloberror.ContainerAlreadyExists)}}, }, "CreateContainer fails": { client: stubAzureBlobAPI{createErr: errors.New("failed")}, wantErr: true, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { assert := assert.New(t) client := &AzureStorage{ client: &tc.client, connectionString: "test", containerName: "test", opts: &AzureOpts{}, } err := client.createContainerOrContinue(context.Background()) if tc.wantErr { assert.Error(err) } else { assert.NoError(err) assert.True(tc.client.createCalled) } }) } } type stubAzureBlobAPI struct { createErr error createCalled bool downloadErr error downloadData []byte uploadErr error uploadData []byte } func (s *stubAzureBlobAPI) CreateContainer(context.Context, string, *container.CreateOptions) (azblob.CreateContainerResponse, error) { s.createCalled = true return azblob.CreateContainerResponse{}, s.createErr } func (s *stubAzureBlobAPI) DownloadStream(context.Context, string, string, *blob.DownloadStreamOptions) (blob.DownloadStreamResponse, error) { res := blob.DownloadStreamResponse{} res.Body = io.NopCloser(bytes.NewReader(s.downloadData)) return res, s.downloadErr } func (s *stubAzureBlobAPI) UploadStream(_ context.Context, _, _ string, data io.Reader, _ *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error) { uploadData, err := io.ReadAll(data) if err != nil { return azblob.UploadStreamResponse{}, err } s.uploadData = uploadData return azblob.UploadStreamResponse{}, s.uploadErr }