constellation/internal/staticupload/staticupload.go

293 lines
9.2 KiB
Go
Raw Normal View History

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
/*
Package staticupload provides a static file uploader/updater/remover for the CDN / static API.
This uploader uses AWS S3 as a backend and Amazon CloudFront as a CDN.
It understands how to upload files and invalidate the CDN cache accordingly.
*/
package staticupload
import (
"context"
"fmt"
2024-02-08 09:20:01 -05:00
"log/slog"
"strings"
"sync"
"time"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/cloudfront"
cftypes "github.com/aws/aws-sdk-go-v2/service/cloudfront/types"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/google/uuid"
)
// Client is a static file uploader/updater/remover for the CDN / static API.
// It has the same interface as the S3 uploader.
type Client struct {
mux sync.Mutex
cdnClient cdnClient
uploadClient uploadClient
s3Client objectStorageClient
distributionID string
bucketID string
cacheInvalidationStrategy CacheInvalidationStrategy
cacheInvalidationWaitTimeout time.Duration
// dirtyKeys is a list of keys that still needs to be invalidated by us.
dirtyKeys []string
// invalidationIDs is a list of invalidation IDs that are currently in progress.
invalidationIDs []string
2024-02-08 09:20:01 -05:00
logger *slog.Logger
}
// Config is the configuration for the Client.
type Config struct {
	// Region is the AWS region to use.
	Region string

	// Bucket is the name of the S3 bucket to use.
	Bucket string

	// DistributionID is the ID of the CloudFront distribution to use.
	// If empty, SetsDefault replaces it with constants.CDNDefaultDistributionID.
	DistributionID string

	// CacheInvalidationStrategy selects when the CDN cache is invalidated:
	// immediately for every uploaded key (CacheInvalidateEager, the zero value)
	// or in one batch when the client is flushed (CacheInvalidateBatchOnFlush).
	CacheInvalidationStrategy CacheInvalidationStrategy

	// CacheInvalidationWaitTimeout is the timeout to wait for each CDN cache
	// invalidation to complete.
	// Set to 0 to disable waiting for the CDN cache to invalidate.
	CacheInvalidationWaitTimeout time.Duration
}
// SetsDefault fills in default values for optional fields left unset.
// Currently only an empty DistributionID is replaced, with
// constants.CDNDefaultDistributionID; explicitly set values are kept.
func (c *Config) SetsDefault() {
	if c.DistributionID != "" {
		return
	}
	c.DistributionID = constants.CDNDefaultDistributionID
}
// CacheInvalidationStrategy is the strategy to use for invalidating the CDN cache.
// Its zero value is CacheInvalidateEager.
type CacheInvalidationStrategy int

// Supported cache invalidation strategies.
const (
	// CacheInvalidateEager invalidates the CDN cache immediately for every key that is uploaded.
	CacheInvalidateEager CacheInvalidationStrategy = 0

	// CacheInvalidateBatchOnFlush collects modified keys and invalidates the CDN
	// cache for all of them in one batch when the client is flushed (Flush, or
	// the CloseFunc returned by New).
	// This is useful when uploading many files at once, but the cache is not
	// invalidated if the client is never flushed.
	CacheInvalidateBatchOnFlush CacheInvalidationStrategy = 1
)
// InvalidationError reports a failure to invalidate the CDN cache.
// The underlying cause can be retrieved with errors.Unwrap / errors.As.
type InvalidationError struct {
	inner error
}

// NewInvalidationError wraps err in an *InvalidationError.
func NewInvalidationError(err error) *InvalidationError {
	wrapped := new(InvalidationError)
	wrapped.inner = err
	return wrapped
}

// Error returns the error message, prefixed with "invalidating CDN cache: ".
func (ie *InvalidationError) Error() string {
	return "invalidating CDN cache: " + fmt.Sprint(ie.inner)
}

// Unwrap returns the wrapped cause.
func (ie *InvalidationError) Unwrap() error {
	return ie.inner
}
// New creates a new Client. Call CloseFunc when done with operations.
2024-02-08 09:20:01 -05:00
func New(ctx context.Context, config Config, log *slog.Logger) (*Client, CloseFunc, error) {
config.SetsDefault()
cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(config.Region))
if err != nil {
return nil, nil, err
}
s3Client := s3.NewFromConfig(cfg)
uploadClient := s3manager.NewUploader(s3Client)
cdnClient := cloudfront.NewFromConfig(cfg)
client := &Client{
cdnClient: cdnClient,
s3Client: s3Client,
uploadClient: uploadClient,
distributionID: config.DistributionID,
cacheInvalidationStrategy: config.CacheInvalidationStrategy,
cacheInvalidationWaitTimeout: config.CacheInvalidationWaitTimeout,
bucketID: config.Bucket,
logger: log,
}
return client, client.Flush, nil
}
// Flush flushes the client by invalidating the CDN cache for modified keys.
// It waits for all invalidations to finish.
// It returns nil on success or an error.
// The error will be of type InvalidationError if the CDN cache could not be invalidated.
func (c *Client) Flush(ctx context.Context) error {
c.mux.Lock()
defer c.mux.Unlock()
2024-02-08 09:20:01 -05:00
c.logger.Debug(fmt.Sprintf("Invalidating keys: %s", c.dirtyKeys))
if len(c.dirtyKeys) == 0 {
return nil
}
// invalidate all dirty keys that have not been invalidated yet
invalidationID, err := c.invalidateCacheForKeys(ctx, c.dirtyKeys)
if err != nil {
return err
}
c.invalidationIDs = append(c.invalidationIDs, invalidationID)
c.dirtyKeys = nil
return c.waitForInvalidations(ctx)
}
// invalidate handles CDN cache invalidation for keys according to the
// configured strategy. Under CacheInvalidateBatchOnFlush the keys are only
// recorded in dirtyKeys. Otherwise a CDN invalidation is created right away
// (unless keys is empty) and its ID is appended to invalidationIDs.
func (c *Client) invalidate(ctx context.Context, keys []string) error {
	switch {
	case c.cacheInvalidationStrategy == CacheInvalidateBatchOnFlush:
		// remember the keys; the batch invalidation happens on Flush
		c.mux.Lock()
		c.dirtyKeys = append(c.dirtyKeys, keys...)
		c.mux.Unlock()
		return nil
	case len(keys) == 0:
		return nil
	}

	// eager strategy: invalidate the CDN cache now
	id, err := c.invalidateCacheForKeys(ctx, keys)
	if err != nil {
		return err
	}

	c.mux.Lock()
	defer c.mux.Unlock()
	c.invalidationIDs = append(c.invalidationIDs, id)
	return nil
}
// invalidateCacheForKeys invalidates the CDN cache for the given list of keys.
// It returns the invalidation ID without waiting for the invalidation to finish.
// Keys without a leading "/" are prefixed with one in the request. The caller's
// keys slice is not modified.
// The list of keys must not be longer than 3000 as specified by AWS:
// https://docs.aws.amazon.com/AmazonCloudFront/latest/DeveloperGuide/Invalidation.html#InvalidationLimits
// All errors are returned as *InvalidationError.
func (c *Client) invalidateCacheForKeys(ctx context.Context, keys []string) (string, error) {
	const maxInvalidationPaths = 3000
	if len(keys) > maxInvalidationPaths {
		return "", NewInvalidationError(fmt.Errorf("too many keys to invalidate: %d", len(keys)))
	}

	// Build a separate paths slice: rewriting keys in place would silently
	// alter the caller's data (e.g. the keys passed to invalidate in eager mode).
	paths := make([]string, len(keys))
	for i, key := range keys {
		if strings.HasPrefix(key, "/") {
			paths[i] = key
		} else {
			paths[i] = "/" + key
		}
	}

	in := &cloudfront.CreateInvalidationInput{
		DistributionId: &c.distributionID,
		InvalidationBatch: &cftypes.InvalidationBatch{
			// a fresh UUID makes each request a distinct invalidation
			CallerReference: ptr(uuid.New().String()),
			Paths: &cftypes.Paths{
				Items:    paths,
				Quantity: ptr(int32(len(paths))),
			},
		},
	}
	invalidation, err := c.cdnClient.CreateInvalidation(ctx, in)
	if err != nil {
		return "", NewInvalidationError(fmt.Errorf("creating invalidation: %w", err))
	}
	if invalidation.Invalidation == nil || invalidation.Invalidation.Id == nil {
		return "", NewInvalidationError(fmt.Errorf("invalidation ID is not set"))
	}
	return *invalidation.Invalidation.Id, nil
}
// waitForInvalidations waits for all invalidations to finish.
func (c *Client) waitForInvalidations(ctx context.Context) error {
if c.cacheInvalidationWaitTimeout == 0 {
2024-02-08 09:20:01 -05:00
c.logger.Warn("cacheInvalidationWaitTimeout set to 0, not waiting for invalidations to finish")
return nil
}
waiter := cloudfront.NewInvalidationCompletedWaiter(c.cdnClient)
2024-02-08 09:20:01 -05:00
c.logger.Debug(fmt.Sprintf("Waiting for invalidations %s in distribution %s", c.invalidationIDs, c.distributionID))
for _, invalidationID := range c.invalidationIDs {
waitIn := &cloudfront.GetInvalidationInput{
DistributionId: &c.distributionID,
Id: &invalidationID,
}
if err := waiter.Wait(ctx, waitIn, c.cacheInvalidationWaitTimeout); err != nil {
return NewInvalidationError(fmt.Errorf("waiting for invalidation to complete: %w", err))
}
}
2024-02-08 09:20:01 -05:00
c.logger.Debug("Invalidations finished")
c.invalidationIDs = nil
return nil
}
// uploadClient uploads objects to S3.
// New fills it with an *s3manager.Uploader; *Client implements it as well.
type uploadClient interface {
	Upload(ctx context.Context, in *s3.PutObjectInput, optFns ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error)
}
// getClient downloads a single object from S3.
type getClient interface {
	GetObject(ctx context.Context, in *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error)
}
// listClient lists objects in an S3 bucket.
type listClient interface {
	ListObjectsV2(ctx context.Context, in *s3.ListObjectsV2Input, optFns ...func(*s3.Options)) (*s3.ListObjectsV2Output, error)
}
// deleteClient removes objects from S3, either one at a time (DeleteObject)
// or several per request (DeleteObjects).
type deleteClient interface {
	DeleteObject(ctx context.Context, in *s3.DeleteObjectInput, optFns ...func(*s3.Options)) (*s3.DeleteObjectOutput, error)
	DeleteObjects(ctx context.Context, in *s3.DeleteObjectsInput, optFns ...func(*s3.Options)) (*s3.DeleteObjectsOutput, error)
}
// cdnClient is the subset of the CloudFront API used by Client.
// CreateInvalidation starts cache invalidations. GetInvalidation is required
// because the client is also handed to cloudfront.NewInvalidationCompletedWaiter
// to poll for completion.
type cdnClient interface {
	CreateInvalidation(ctx context.Context, in *cloudfront.CreateInvalidationInput, optFns ...func(*cloudfront.Options)) (*cloudfront.CreateInvalidationOutput, error)
	GetInvalidation(ctx context.Context, in *cloudfront.GetInvalidationInput, optFns ...func(*cloudfront.Options)) (*cloudfront.GetInvalidationOutput, error)
}
// objectStorageClient bundles the read, list and delete operations on S3
// objects. New fills it with an *s3.Client.
type objectStorageClient interface {
	getClient
	listClient
	deleteClient
}
// Compile-time checks that *Client exposes the same operations as the S3
// clients it wraps.
var (
	// statically assert that Client implements the uploadClient interface.
	_ uploadClient = (*Client)(nil)
	// statically assert that Client implements the objectStorageClient
	// interface (get, list and delete operations).
	_ objectStorageClient = (*Client)(nil)
)
// ptr returns a pointer to a fresh copy of v. It is used to fill the
// pointer-typed fields of AWS SDK input structs from plain values.
func ptr[T any](v T) *T {
	p := new(T)
	*p = v
	return p
}
// CloseFunc finishes work with a Client once the caller is done with it.
// The CloseFunc returned by New is the client's Flush method.
type CloseFunc func(context.Context) error