mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-01-11 23:49:30 -05:00
ci: static file uploader with automatic cache invalidation (#1833)
This commit is contained in:
parent
8686c5e7e2
commit
29b93065b3
36
internal/staticupload/BUILD.bazel
Normal file
36
internal/staticupload/BUILD.bazel
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
load("@io_bazel_rules_go//go:def.bzl", "go_library")
|
||||||
|
load("//bazel/go:go_test.bzl", "go_test")
|
||||||
|
|
||||||
|
go_library(
|
||||||
|
name = "staticupload",
|
||||||
|
srcs = [
|
||||||
|
"delete.go",
|
||||||
|
"staticupload.go",
|
||||||
|
"upload.go",
|
||||||
|
],
|
||||||
|
importpath = "github.com/edgelesssys/constellation/v2/internal/staticupload",
|
||||||
|
visibility = ["//:__subpackages__"],
|
||||||
|
deps = [
|
||||||
|
"@com_github_aws_aws_sdk_go_v2_config//:config",
|
||||||
|
"@com_github_aws_aws_sdk_go_v2_feature_s3_manager//:manager",
|
||||||
|
"@com_github_aws_aws_sdk_go_v2_service_cloudfront//:cloudfront",
|
||||||
|
"@com_github_aws_aws_sdk_go_v2_service_cloudfront//types",
|
||||||
|
"@com_github_aws_aws_sdk_go_v2_service_s3//:s3",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
go_test(
|
||||||
|
name = "staticupload_test",
|
||||||
|
srcs = ["staticupload_test.go"],
|
||||||
|
embed = [":staticupload"],
|
||||||
|
deps = [
|
||||||
|
"@com_github_aws_aws_sdk_go_v2_feature_s3_manager//:manager",
|
||||||
|
"@com_github_aws_aws_sdk_go_v2_service_cloudfront//:cloudfront",
|
||||||
|
"@com_github_aws_aws_sdk_go_v2_service_cloudfront//types",
|
||||||
|
"@com_github_aws_aws_sdk_go_v2_service_s3//:s3",
|
||||||
|
"@com_github_aws_aws_sdk_go_v2_service_s3//types",
|
||||||
|
"@com_github_stretchr_testify//assert",
|
||||||
|
"@com_github_stretchr_testify//require",
|
||||||
|
"@org_uber_go_goleak//:goleak",
|
||||||
|
],
|
||||||
|
)
|
66
internal/staticupload/delete.go
Normal file
66
internal/staticupload/delete.go
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) Edgeless Systems GmbH
|
||||||
|
|
||||||
|
SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package staticupload
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DeleteObject deletes the given key from S3 and invalidates the CDN cache.
|
||||||
|
// It returns the delete output or an error.
|
||||||
|
// The error will be of type InvalidationError if the CDN cache could not be invalidated.
|
||||||
|
func (c *Client) DeleteObject(ctx context.Context, params *s3.DeleteObjectInput,
|
||||||
|
optFns ...func(*s3.Options),
|
||||||
|
) (*s3.DeleteObjectOutput, error) {
|
||||||
|
if params == nil || params.Key == nil {
|
||||||
|
return nil, errors.New("key is not set")
|
||||||
|
}
|
||||||
|
output, err := c.s3Client.DeleteObject(ctx, params, optFns...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("deleting object: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.invalidate(ctx, []string{*params.Key}); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return output, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteObjects deletes the given objects from S3 and invalidates the CDN cache.
|
||||||
|
// It returns the delete output or an error.
|
||||||
|
// The error will be of type InvalidationError if the CDN cache could not be invalidated.
|
||||||
|
func (c *Client) DeleteObjects(
|
||||||
|
ctx context.Context, params *s3.DeleteObjectsInput,
|
||||||
|
optFns ...func(*s3.Options),
|
||||||
|
) (*s3.DeleteObjectsOutput, error) {
|
||||||
|
if params == nil || params.Delete == nil || params.Delete.Objects == nil {
|
||||||
|
return nil, errors.New("objects are not set")
|
||||||
|
}
|
||||||
|
for _, obj := range params.Delete.Objects {
|
||||||
|
if obj.Key == nil {
|
||||||
|
return nil, errors.New("key is not set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output, deleteErr := c.s3Client.DeleteObjects(ctx, params, optFns...)
|
||||||
|
if deleteErr != nil {
|
||||||
|
return nil, fmt.Errorf("deleting objects: %w", deleteErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := make([]string, len(params.Delete.Objects))
|
||||||
|
for i, obj := range params.Delete.Objects {
|
||||||
|
keys[i] = *obj.Key
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.invalidate(ctx, keys); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return output, nil
|
||||||
|
}
|
242
internal/staticupload/staticupload.go
Normal file
242
internal/staticupload/staticupload.go
Normal file
@ -0,0 +1,242 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) Edgeless Systems GmbH
|
||||||
|
|
||||||
|
SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*
|
||||||
|
Package staticupload provides a static file uploader/updater/remover for the CDN / static API.
|
||||||
|
|
||||||
|
This uploader uses AWS S3 as a backend and cloudfront as a CDN.
|
||||||
|
It understands how to upload files and invalidate the CDN cache accordingly.
|
||||||
|
*/
|
||||||
|
package staticupload
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
awsconfig "github.com/aws/aws-sdk-go-v2/config"
|
||||||
|
s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/service/cloudfront"
|
||||||
|
cftypes "github.com/aws/aws-sdk-go-v2/service/cloudfront/types"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client is a static file uploader/updater/remover for the CDN / static API.
|
||||||
|
// It has the same interface as the S3 uploader.
|
||||||
|
type Client struct {
|
||||||
|
mux sync.Mutex
|
||||||
|
cdnClient cdnClient
|
||||||
|
uploadClient uploadClient
|
||||||
|
s3Client objectStorageClient
|
||||||
|
distributionID string
|
||||||
|
|
||||||
|
cacheInvalidationStrategy CacheInvalidationStrategy
|
||||||
|
cacheInvalidationWaitTimeout time.Duration
|
||||||
|
// dirtyKeys is a list of keys that still needs to be invalidated by us.
|
||||||
|
dirtyKeys []string
|
||||||
|
// invalidationIDs is a list of invalidation IDs that are currently in progress.
|
||||||
|
invalidationIDs []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config is the configuration for the Client.
|
||||||
|
type Config struct {
|
||||||
|
// Region is the AWS region to use.
|
||||||
|
Region string
|
||||||
|
// Bucket is the name of the S3 bucket to use.
|
||||||
|
Bucket string
|
||||||
|
// DistributionID is the ID of the CloudFront distribution to use.
|
||||||
|
DistributionID string
|
||||||
|
CacheInvalidationStrategy CacheInvalidationStrategy
|
||||||
|
// CacheInvalidationWaitTimeout is the timeout to wait for the CDN cache to invalidate.
|
||||||
|
// set to 0 to disable waiting for the CDN cache to invalidate.
|
||||||
|
CacheInvalidationWaitTimeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// CacheInvalidationStrategy is the strategy to use for invalidating the CDN cache.
|
||||||
|
type CacheInvalidationStrategy int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// CacheInvalidateEager invalidates the CDN cache immediately for every key that is uploaded.
|
||||||
|
CacheInvalidateEager CacheInvalidationStrategy = iota
|
||||||
|
// CacheInvalidateBatchOnClose invalidates the CDN cache in batches when the client is closed.
|
||||||
|
// This is useful when uploading many files at once but will fail if Close is not called.
|
||||||
|
CacheInvalidateBatchOnClose
|
||||||
|
)
|
||||||
|
|
||||||
|
// InvalidationError is an error that occurs when invalidating the CDN cache.
|
||||||
|
type InvalidationError struct {
|
||||||
|
inner error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error returns the error message.
|
||||||
|
func (e InvalidationError) Error() string {
|
||||||
|
return fmt.Sprintf("invalidating CDN cache: %v", e.inner)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unwrap returns the inner error.
|
||||||
|
func (e InvalidationError) Unwrap() error {
|
||||||
|
return e.inner
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new Client.
|
||||||
|
func New(
|
||||||
|
ctx context.Context,
|
||||||
|
config Config,
|
||||||
|
) (*Client, error) {
|
||||||
|
cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(config.Region))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s3Client := s3.NewFromConfig(cfg)
|
||||||
|
uploadClient := s3manager.NewUploader(s3Client)
|
||||||
|
|
||||||
|
cdnClient := cloudfront.NewFromConfig(cfg)
|
||||||
|
|
||||||
|
return &Client{
|
||||||
|
cdnClient: cdnClient,
|
||||||
|
s3Client: s3Client,
|
||||||
|
uploadClient: uploadClient,
|
||||||
|
distributionID: config.DistributionID,
|
||||||
|
cacheInvalidationStrategy: config.CacheInvalidationStrategy,
|
||||||
|
cacheInvalidationWaitTimeout: config.CacheInvalidationWaitTimeout,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the client.
|
||||||
|
// It waits for all invalidations to finish.
|
||||||
|
// It returns nil on success or an error.
|
||||||
|
// The error will be of type InvalidationError if the CDN cache could not be invalidated.
|
||||||
|
func (c *Client) Close(ctx context.Context) error {
|
||||||
|
c.mux.Lock()
|
||||||
|
defer c.mux.Unlock()
|
||||||
|
|
||||||
|
// invalidate all dirty keys that have not been invalidated yet
|
||||||
|
invalidationID, err := c.invalidateCacheForKeys(ctx, c.dirtyKeys)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.invalidationIDs = append(c.invalidationIDs, invalidationID)
|
||||||
|
c.dirtyKeys = nil
|
||||||
|
|
||||||
|
return c.waitForInvalidations(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// invalidate invalidates the CDN cache for the given keys.
|
||||||
|
// It either performs the invalidation immediately or adds them to the list of dirty keys.
|
||||||
|
func (c *Client) invalidate(ctx context.Context, keys []string) error {
|
||||||
|
if c.cacheInvalidationStrategy == CacheInvalidateBatchOnClose {
|
||||||
|
// save as dirty key for batch invalidation on Close
|
||||||
|
c.mux.Lock()
|
||||||
|
defer c.mux.Unlock()
|
||||||
|
c.dirtyKeys = append(c.dirtyKeys, keys...)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// eagerly invalidate the CDN cache
|
||||||
|
invalidationID, err := c.invalidateCacheForKeys(ctx, keys)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.mux.Lock()
|
||||||
|
defer c.mux.Unlock()
|
||||||
|
c.invalidationIDs = append(c.invalidationIDs, invalidationID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// invalidateCacheForKeys invalidates the CDN cache for the given list of keys.
|
||||||
|
// It returns the invalidation ID without waiting for the invalidation to finish.
|
||||||
|
// The list of keys must not be longer than 3000 as specified by AWS:
|
||||||
|
// https://docs.aws.amazon.com/AmazonCloudFront/latest/DeveloperGuide/Invalidation.html#InvalidationLimits
|
||||||
|
func (c *Client) invalidateCacheForKeys(ctx context.Context, keys []string) (string, error) {
|
||||||
|
if len(keys) > 3000 {
|
||||||
|
return "", InvalidationError{inner: fmt.Errorf("too many keys to invalidate: %d", len(keys))}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, key := range keys {
|
||||||
|
if !strings.HasPrefix(key, "/") {
|
||||||
|
keys[i] = "/" + key
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
in := &cloudfront.CreateInvalidationInput{
|
||||||
|
DistributionId: &c.distributionID,
|
||||||
|
InvalidationBatch: &cftypes.InvalidationBatch{
|
||||||
|
CallerReference: ptr(fmt.Sprintf("%d", time.Now().Unix())),
|
||||||
|
Paths: &cftypes.Paths{
|
||||||
|
Items: keys,
|
||||||
|
Quantity: ptr(int32(len(keys))),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
invalidation, err := c.cdnClient.CreateInvalidation(ctx, in)
|
||||||
|
if err != nil {
|
||||||
|
return "", InvalidationError{inner: fmt.Errorf("creating invalidation: %w", err)}
|
||||||
|
}
|
||||||
|
if invalidation.Invalidation == nil || invalidation.Invalidation.Id == nil {
|
||||||
|
return "", InvalidationError{inner: fmt.Errorf("invalidation ID is not set")}
|
||||||
|
}
|
||||||
|
return *invalidation.Invalidation.Id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForInvalidations waits for all invalidations to finish.
|
||||||
|
func (c *Client) waitForInvalidations(ctx context.Context) error {
|
||||||
|
if c.cacheInvalidationWaitTimeout == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
waiter := cloudfront.NewInvalidationCompletedWaiter(c.cdnClient)
|
||||||
|
for _, invalidationID := range c.invalidationIDs {
|
||||||
|
waitIn := &cloudfront.GetInvalidationInput{
|
||||||
|
DistributionId: &c.distributionID,
|
||||||
|
Id: &invalidationID,
|
||||||
|
}
|
||||||
|
if err := waiter.Wait(ctx, waitIn, c.cacheInvalidationWaitTimeout); err != nil {
|
||||||
|
return InvalidationError{inner: fmt.Errorf("waiting for invalidation to complete: %w", err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.invalidationIDs = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type uploadClient interface {
|
||||||
|
Upload(
|
||||||
|
ctx context.Context, input *s3.PutObjectInput, opts ...func(*s3manager.Uploader),
|
||||||
|
) (*s3manager.UploadOutput, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type deleteClient interface {
|
||||||
|
DeleteObject(ctx context.Context, params *s3.DeleteObjectInput,
|
||||||
|
optFns ...func(*s3.Options),
|
||||||
|
) (*s3.DeleteObjectOutput, error)
|
||||||
|
DeleteObjects(
|
||||||
|
ctx context.Context, params *s3.DeleteObjectsInput,
|
||||||
|
optFns ...func(*s3.Options),
|
||||||
|
) (*s3.DeleteObjectsOutput, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type cdnClient interface {
|
||||||
|
CreateInvalidation(
|
||||||
|
ctx context.Context, params *cloudfront.CreateInvalidationInput, optFns ...func(*cloudfront.Options),
|
||||||
|
) (*cloudfront.CreateInvalidationOutput, error)
|
||||||
|
GetInvalidation(
|
||||||
|
context.Context, *cloudfront.GetInvalidationInput, ...func(*cloudfront.Options),
|
||||||
|
) (*cloudfront.GetInvalidationOutput, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type objectStorageClient interface {
|
||||||
|
deleteClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// statically assert that Client implements the uploadClient interface.
|
||||||
|
var _ uploadClient = (*Client)(nil)
|
||||||
|
|
||||||
|
// statically assert that Client implements the deleteClient interface.
|
||||||
|
var _ objectStorageClient = (*Client)(nil)
|
||||||
|
|
||||||
|
func ptr[T any](t T) *T {
|
||||||
|
return &t
|
||||||
|
}
|
553
internal/staticupload/staticupload_test.go
Normal file
553
internal/staticupload/staticupload_test.go
Normal file
@ -0,0 +1,553 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) Edgeless Systems GmbH
|
||||||
|
|
||||||
|
SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package staticupload
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/service/cloudfront"
|
||||||
|
cftypes "github.com/aws/aws-sdk-go-v2/service/cloudfront/types"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||||
|
s3types "github.com/aws/aws-sdk-go-v2/service/s3/types"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.uber.org/goleak"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
goleak.VerifyTestMain(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpload(t *testing.T) {
|
||||||
|
newInput := func() *s3.PutObjectInput {
|
||||||
|
return &s3.PutObjectInput{
|
||||||
|
Bucket: ptr("test-bucket"),
|
||||||
|
Key: ptr("test-key"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
in *s3.PutObjectInput
|
||||||
|
cacheInvalidationStrategy CacheInvalidationStrategy
|
||||||
|
cacheInvalidationWaitTimeout time.Duration
|
||||||
|
uploadFails bool
|
||||||
|
invalidationFails bool
|
||||||
|
wantInvalidations int
|
||||||
|
wantCacheInvalidationErr bool
|
||||||
|
wantErr bool
|
||||||
|
wantDirtyKeys []string
|
||||||
|
wantInvalidationIDs []string
|
||||||
|
}{
|
||||||
|
"eager invalidation": {
|
||||||
|
in: newInput(),
|
||||||
|
cacheInvalidationStrategy: CacheInvalidateEager,
|
||||||
|
cacheInvalidationWaitTimeout: time.Microsecond,
|
||||||
|
wantInvalidations: 1,
|
||||||
|
wantInvalidationIDs: []string{"test-invalidation-id-1"},
|
||||||
|
},
|
||||||
|
"lazy invalidation": {
|
||||||
|
in: newInput(),
|
||||||
|
cacheInvalidationStrategy: CacheInvalidateBatchOnClose,
|
||||||
|
cacheInvalidationWaitTimeout: time.Microsecond,
|
||||||
|
wantDirtyKeys: []string{"test-key"},
|
||||||
|
},
|
||||||
|
"upload fails": {
|
||||||
|
in: newInput(),
|
||||||
|
uploadFails: true,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
"invalidation fails": {
|
||||||
|
in: newInput(),
|
||||||
|
invalidationFails: true,
|
||||||
|
wantCacheInvalidationErr: true,
|
||||||
|
},
|
||||||
|
"input is nil": {
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
"key is nil": {
|
||||||
|
in: &s3.PutObjectInput{},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
require := require.New(t)
|
||||||
|
|
||||||
|
cdnClient := &fakeCDNClient{}
|
||||||
|
uploadClient := &stubUploadClient{}
|
||||||
|
if tc.invalidationFails {
|
||||||
|
cdnClient.createInvalidationErr = errors.New("invalidation failed")
|
||||||
|
}
|
||||||
|
if tc.uploadFails {
|
||||||
|
uploadClient.uploadErr = errors.New("upload failed")
|
||||||
|
}
|
||||||
|
if tc.in != nil {
|
||||||
|
tc.in.Body = bytes.NewReader([]byte("test-data"))
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &Client{
|
||||||
|
cdnClient: cdnClient,
|
||||||
|
uploadClient: uploadClient,
|
||||||
|
distributionID: "test-distribution-id",
|
||||||
|
cacheInvalidationStrategy: tc.cacheInvalidationStrategy,
|
||||||
|
cacheInvalidationWaitTimeout: tc.cacheInvalidationWaitTimeout,
|
||||||
|
}
|
||||||
|
_, err := client.Upload(context.Background(), tc.in)
|
||||||
|
|
||||||
|
if tc.wantCacheInvalidationErr {
|
||||||
|
assert.ErrorAs(err, &InvalidationError{})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tc.wantErr {
|
||||||
|
assert.False(errors.As(err, &InvalidationError{}))
|
||||||
|
assert.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(err)
|
||||||
|
assert.Equal(tc.wantDirtyKeys, client.dirtyKeys)
|
||||||
|
assert.Equal(tc.wantInvalidationIDs, client.invalidationIDs)
|
||||||
|
assert.Equal("test-data", string(uploadClient.uploadedData))
|
||||||
|
assert.Equal(tc.wantInvalidations, cdnClient.createInvalidationCounter)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteObject(t *testing.T) {
|
||||||
|
newObjectInput := func(nilInput, nilKey bool) *s3.DeleteObjectInput {
|
||||||
|
if nilInput {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if nilKey {
|
||||||
|
return &s3.DeleteObjectInput{}
|
||||||
|
}
|
||||||
|
return &s3.DeleteObjectInput{
|
||||||
|
Bucket: ptr("test-bucket"),
|
||||||
|
Key: ptr("test-key"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
newObjectsInput := func(nilInput, nilKey bool) *s3.DeleteObjectsInput {
|
||||||
|
if nilInput {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if nilKey {
|
||||||
|
return &s3.DeleteObjectsInput{
|
||||||
|
Delete: &s3types.Delete{
|
||||||
|
Objects: []s3types.ObjectIdentifier{{}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &s3.DeleteObjectsInput{
|
||||||
|
Bucket: ptr("test-bucket"),
|
||||||
|
Delete: &s3types.Delete{
|
||||||
|
Objects: []s3types.ObjectIdentifier{
|
||||||
|
{Key: ptr("test-key")},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
nilInput bool
|
||||||
|
nilKey bool
|
||||||
|
cacheInvalidationStrategy CacheInvalidationStrategy
|
||||||
|
cacheInvalidationWaitTimeout time.Duration
|
||||||
|
deleteFails bool
|
||||||
|
invalidationFails bool
|
||||||
|
wantInvalidations int
|
||||||
|
wantCacheInvalidationErr bool
|
||||||
|
wantErr bool
|
||||||
|
wantDirtyKeys []string
|
||||||
|
wantInvalidationIDs []string
|
||||||
|
}{
|
||||||
|
"eager invalidation": {
|
||||||
|
cacheInvalidationStrategy: CacheInvalidateEager,
|
||||||
|
cacheInvalidationWaitTimeout: time.Microsecond,
|
||||||
|
wantInvalidations: 1,
|
||||||
|
wantInvalidationIDs: []string{"test-invalidation-id-1"},
|
||||||
|
},
|
||||||
|
"lazy invalidation": {
|
||||||
|
cacheInvalidationStrategy: CacheInvalidateBatchOnClose,
|
||||||
|
cacheInvalidationWaitTimeout: time.Microsecond,
|
||||||
|
wantDirtyKeys: []string{"test-key"},
|
||||||
|
},
|
||||||
|
"delete fails": {
|
||||||
|
deleteFails: true,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
"invalidation fails": {
|
||||||
|
invalidationFails: true,
|
||||||
|
wantCacheInvalidationErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
require := require.New(t)
|
||||||
|
|
||||||
|
cdnClient := &fakeCDNClient{}
|
||||||
|
s3Client := &stubObjectStorageClient{}
|
||||||
|
if tc.invalidationFails {
|
||||||
|
cdnClient.createInvalidationErr = errors.New("invalidation failed")
|
||||||
|
}
|
||||||
|
if tc.deleteFails {
|
||||||
|
s3Client.err = errors.New("delete failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &Client{
|
||||||
|
cdnClient: cdnClient,
|
||||||
|
s3Client: s3Client,
|
||||||
|
distributionID: "test-distribution-id",
|
||||||
|
cacheInvalidationStrategy: tc.cacheInvalidationStrategy,
|
||||||
|
cacheInvalidationWaitTimeout: tc.cacheInvalidationWaitTimeout,
|
||||||
|
}
|
||||||
|
_, err := client.DeleteObject(context.Background(), newObjectInput(tc.nilInput, tc.nilKey))
|
||||||
|
|
||||||
|
if tc.wantCacheInvalidationErr {
|
||||||
|
assert.ErrorAs(err, &InvalidationError{})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tc.wantErr {
|
||||||
|
assert.False(errors.As(err, &InvalidationError{}))
|
||||||
|
assert.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(err)
|
||||||
|
assert.Equal(tc.wantDirtyKeys, client.dirtyKeys)
|
||||||
|
assert.Equal(tc.wantInvalidationIDs, client.invalidationIDs)
|
||||||
|
assert.Equal(tc.wantInvalidations, cdnClient.createInvalidationCounter)
|
||||||
|
})
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
require := require.New(t)
|
||||||
|
|
||||||
|
cdnClient := &fakeCDNClient{}
|
||||||
|
s3Client := &stubObjectStorageClient{}
|
||||||
|
if tc.invalidationFails {
|
||||||
|
cdnClient.createInvalidationErr = errors.New("invalidation failed")
|
||||||
|
}
|
||||||
|
if tc.deleteFails {
|
||||||
|
s3Client.err = errors.New("delete failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &Client{
|
||||||
|
cdnClient: cdnClient,
|
||||||
|
s3Client: s3Client,
|
||||||
|
distributionID: "test-distribution-id",
|
||||||
|
cacheInvalidationStrategy: tc.cacheInvalidationStrategy,
|
||||||
|
cacheInvalidationWaitTimeout: tc.cacheInvalidationWaitTimeout,
|
||||||
|
}
|
||||||
|
_, err := client.DeleteObjects(context.Background(), newObjectsInput(tc.nilInput, tc.nilKey))
|
||||||
|
|
||||||
|
if tc.wantCacheInvalidationErr {
|
||||||
|
assert.ErrorAs(err, &InvalidationError{})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tc.wantErr {
|
||||||
|
assert.False(errors.As(err, &InvalidationError{}))
|
||||||
|
assert.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(err)
|
||||||
|
assert.Equal(tc.wantDirtyKeys, client.dirtyKeys)
|
||||||
|
assert.Equal(tc.wantInvalidationIDs, client.invalidationIDs)
|
||||||
|
assert.Equal(tc.wantInvalidations, cdnClient.createInvalidationCounter)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClose(t *testing.T) {
|
||||||
|
testCases := map[string]struct {
|
||||||
|
dirtyKeys []string
|
||||||
|
invalidationIDs []string
|
||||||
|
cacheInvalidationWaitTimeout time.Duration
|
||||||
|
invalidationFails bool
|
||||||
|
invalidationStatus map[string]*string
|
||||||
|
wantInvalidations int
|
||||||
|
wantDanglingInvalidationIDs []string
|
||||||
|
wantStatusChecks map[string]int
|
||||||
|
wantCacheInvalidationErr bool
|
||||||
|
}{
|
||||||
|
"mixed invalidation": {
|
||||||
|
dirtyKeys: []string{"test-key-1", "test-key-2"},
|
||||||
|
invalidationIDs: []string{
|
||||||
|
"test-invalidation-id-2",
|
||||||
|
"test-invalidation-id-3",
|
||||||
|
},
|
||||||
|
cacheInvalidationWaitTimeout: time.Microsecond,
|
||||||
|
invalidationStatus: map[string]*string{
|
||||||
|
"test-invalidation-id-1": ptr("Completed"),
|
||||||
|
"test-invalidation-id-2": ptr("Completed"),
|
||||||
|
"test-invalidation-id-3": ptr("Completed"),
|
||||||
|
},
|
||||||
|
wantInvalidations: 1, // keys are batched
|
||||||
|
wantStatusChecks: map[string]int{
|
||||||
|
"test-invalidation-id-1": 1,
|
||||||
|
"test-invalidation-id-2": 1,
|
||||||
|
"test-invalidation-id-3": 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"dirty key invalidation": {
|
||||||
|
dirtyKeys: []string{"test-key-1", "test-key-2"},
|
||||||
|
cacheInvalidationWaitTimeout: time.Microsecond,
|
||||||
|
invalidationStatus: map[string]*string{
|
||||||
|
"test-invalidation-id-1": ptr("Completed"),
|
||||||
|
},
|
||||||
|
wantInvalidations: 1, // keys are batched
|
||||||
|
wantStatusChecks: map[string]int{
|
||||||
|
"test-invalidation-id-1": 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"not waiting for invalidation": {
|
||||||
|
dirtyKeys: []string{"test-key-1", "test-key-2"},
|
||||||
|
invalidationIDs: []string{
|
||||||
|
"test-invalidation-id-2",
|
||||||
|
"test-invalidation-id-3",
|
||||||
|
},
|
||||||
|
wantInvalidations: 1, // keys are batched
|
||||||
|
wantDanglingInvalidationIDs: []string{
|
||||||
|
"test-invalidation-id-2",
|
||||||
|
"test-invalidation-id-3",
|
||||||
|
"test-invalidation-id-1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"invalidation fails": {
|
||||||
|
dirtyKeys: []string{"test-key-1", "test-key-2"},
|
||||||
|
invalidationFails: true,
|
||||||
|
wantCacheInvalidationErr: true,
|
||||||
|
},
|
||||||
|
"many keys": {
|
||||||
|
dirtyKeys: func() []string {
|
||||||
|
keys := make([]string, 3000)
|
||||||
|
for i := range keys {
|
||||||
|
keys[i] = fmt.Sprintf("test-key-%d", i)
|
||||||
|
}
|
||||||
|
return keys
|
||||||
|
}(),
|
||||||
|
wantInvalidations: 1, // keys are batched
|
||||||
|
wantDanglingInvalidationIDs: []string{
|
||||||
|
"test-invalidation-id-1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"too many keys": {
|
||||||
|
dirtyKeys: func() []string {
|
||||||
|
keys := make([]string, 3001)
|
||||||
|
for i := range keys {
|
||||||
|
keys[i] = fmt.Sprintf("test-key-%d", i)
|
||||||
|
}
|
||||||
|
return keys
|
||||||
|
}(),
|
||||||
|
wantCacheInvalidationErr: true,
|
||||||
|
},
|
||||||
|
"waiting for invalidation times out": {
|
||||||
|
invalidationIDs: []string{
|
||||||
|
"test-invalidation-id-2",
|
||||||
|
"test-invalidation-id-3",
|
||||||
|
},
|
||||||
|
cacheInvalidationWaitTimeout: time.Microsecond,
|
||||||
|
wantCacheInvalidationErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
require := require.New(t)
|
||||||
|
|
||||||
|
cdnClient := &fakeCDNClient{
|
||||||
|
status: tc.invalidationStatus,
|
||||||
|
}
|
||||||
|
uploadClient := &stubUploadClient{
|
||||||
|
uploadErr: errors.New("Upload should not be called"),
|
||||||
|
}
|
||||||
|
if tc.invalidationFails {
|
||||||
|
cdnClient.createInvalidationErr = errors.New("invalidation failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &Client{
|
||||||
|
cdnClient: cdnClient,
|
||||||
|
uploadClient: uploadClient,
|
||||||
|
distributionID: "test-distribution-id",
|
||||||
|
cacheInvalidationWaitTimeout: tc.cacheInvalidationWaitTimeout,
|
||||||
|
dirtyKeys: tc.dirtyKeys,
|
||||||
|
invalidationIDs: tc.invalidationIDs,
|
||||||
|
}
|
||||||
|
err := client.Close(context.Background())
|
||||||
|
|
||||||
|
if tc.wantCacheInvalidationErr {
|
||||||
|
assert.ErrorAs(err, &InvalidationError{})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(err)
|
||||||
|
assert.Empty(client.dirtyKeys)
|
||||||
|
assert.Equal(tc.wantDanglingInvalidationIDs, client.invalidationIDs)
|
||||||
|
assert.Equal(tc.wantInvalidations, cdnClient.createInvalidationCounter)
|
||||||
|
assert.Equal(tc.wantStatusChecks, cdnClient.statusCheckCounter)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrency(_ *testing.T) {
|
||||||
|
newInput := func() *s3.PutObjectInput {
|
||||||
|
return &s3.PutObjectInput{
|
||||||
|
Bucket: ptr("test-bucket"),
|
||||||
|
Key: ptr("test-key"),
|
||||||
|
Body: bytes.NewReader([]byte("test-data")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cdnClient := &fakeCDNClient{}
|
||||||
|
s3Client := &stubObjectStorageClient{}
|
||||||
|
uploadClient := &stubUploadClient{}
|
||||||
|
|
||||||
|
client := &Client{
|
||||||
|
cdnClient: cdnClient,
|
||||||
|
s3Client: s3Client,
|
||||||
|
uploadClient: uploadClient,
|
||||||
|
distributionID: "test-distribution-id",
|
||||||
|
cacheInvalidationWaitTimeout: 50 * time.Millisecond,
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
upload := func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_, _ = client.Upload(context.Background(), newInput())
|
||||||
|
}
|
||||||
|
deleteObject := func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_, _ = client.DeleteObject(context.Background(), &s3.DeleteObjectInput{
|
||||||
|
Bucket: ptr("test-bucket"),
|
||||||
|
Key: ptr("test-key"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
deleteObjects := func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_, _ = client.DeleteObjects(context.Background(), &s3.DeleteObjectsInput{
|
||||||
|
Bucket: ptr("test-bucket"),
|
||||||
|
Delete: &s3types.Delete{
|
||||||
|
Objects: []s3types.ObjectIdentifier{
|
||||||
|
{Key: ptr("test-key")},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
closeClient := func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_ = client.Close(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
wg.Add(4)
|
||||||
|
go upload()
|
||||||
|
go deleteObject()
|
||||||
|
go deleteObjects()
|
||||||
|
go closeClient()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeCDNClient struct {
|
||||||
|
mux sync.Mutex
|
||||||
|
createInvalidationCounter int
|
||||||
|
statusCheckCounter map[string]int
|
||||||
|
status map[string]*string
|
||||||
|
|
||||||
|
createInvalidationErr error
|
||||||
|
getInvalidationErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeCDNClient) CreateInvalidation(
|
||||||
|
_ context.Context, _ *cloudfront.CreateInvalidationInput, _ ...func(*cloudfront.Options),
|
||||||
|
) (*cloudfront.CreateInvalidationOutput, error) {
|
||||||
|
c.mux.Lock()
|
||||||
|
defer c.mux.Unlock()
|
||||||
|
c.createInvalidationCounter++
|
||||||
|
ctr := c.createInvalidationCounter
|
||||||
|
return &cloudfront.CreateInvalidationOutput{
|
||||||
|
Invalidation: &cftypes.Invalidation{
|
||||||
|
Id: ptr(fmt.Sprintf("test-invalidation-id-%d", ctr)),
|
||||||
|
},
|
||||||
|
}, c.createInvalidationErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *fakeCDNClient) GetInvalidation(
|
||||||
|
_ context.Context, input *cloudfront.GetInvalidationInput, _ ...func(*cloudfront.Options),
|
||||||
|
) (*cloudfront.GetInvalidationOutput, error) {
|
||||||
|
c.mux.Lock()
|
||||||
|
defer c.mux.Unlock()
|
||||||
|
if c.statusCheckCounter == nil {
|
||||||
|
c.statusCheckCounter = make(map[string]int)
|
||||||
|
}
|
||||||
|
c.statusCheckCounter[*input.Id]++
|
||||||
|
status := "Unknown"
|
||||||
|
if s, ok := c.status[*input.Id]; ok {
|
||||||
|
status = *s
|
||||||
|
}
|
||||||
|
|
||||||
|
return &cloudfront.GetInvalidationOutput{
|
||||||
|
Invalidation: &cftypes.Invalidation{
|
||||||
|
CreateTime: ptr(time.Now()),
|
||||||
|
Id: input.Id,
|
||||||
|
Status: ptr(status),
|
||||||
|
},
|
||||||
|
}, c.getInvalidationErr
|
||||||
|
}
|
||||||
|
|
||||||
|
type stubUploadClient struct {
|
||||||
|
mux sync.Mutex
|
||||||
|
uploadErr error
|
||||||
|
uploadedData []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubUploadClient) Upload(
|
||||||
|
_ context.Context, input *s3.PutObjectInput,
|
||||||
|
_ ...func(*s3manager.Uploader),
|
||||||
|
) (*s3manager.UploadOutput, error) {
|
||||||
|
var err error
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
s.uploadedData, err = io.ReadAll(input.Body)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return nil, s.uploadErr
|
||||||
|
}
|
||||||
|
|
||||||
|
type stubObjectStorageClient struct {
|
||||||
|
deleteObjectOut *s3.DeleteObjectOutput
|
||||||
|
deleteObjectsOut *s3.DeleteObjectsOutput
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubObjectStorageClient) DeleteObject(_ context.Context, _ *s3.DeleteObjectInput,
|
||||||
|
_ ...func(*s3.Options),
|
||||||
|
) (*s3.DeleteObjectOutput, error) {
|
||||||
|
return s.deleteObjectOut, s.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubObjectStorageClient) DeleteObjects(
|
||||||
|
_ context.Context, _ *s3.DeleteObjectsInput,
|
||||||
|
_ ...func(*s3.Options),
|
||||||
|
) (*s3.DeleteObjectsOutput, error) {
|
||||||
|
return s.deleteObjectsOut, s.err
|
||||||
|
}
|
36
internal/staticupload/upload.go
Normal file
36
internal/staticupload/upload.go
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) Edgeless Systems GmbH
|
||||||
|
|
||||||
|
SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package staticupload
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Upload uploads the given object to S3 and invalidates the CDN cache.
|
||||||
|
// It returns the upload output or an error.
|
||||||
|
// The error will be of type InvalidationError if the CDN cache could not be invalidated.
|
||||||
|
func (c *Client) Upload(
|
||||||
|
ctx context.Context, input *s3.PutObjectInput, opts ...func(*s3manager.Uploader),
|
||||||
|
) (*s3manager.UploadOutput, error) {
|
||||||
|
if input == nil || input.Key == nil {
|
||||||
|
return nil, errors.New("key is not set")
|
||||||
|
}
|
||||||
|
output, uploadErr := c.uploadClient.Upload(ctx, input, opts...)
|
||||||
|
if uploadErr != nil {
|
||||||
|
return nil, fmt.Errorf("uploading object: %w", uploadErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.invalidate(ctx, []string{*input.Key}); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return output, nil
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user