ci: use aws s3 client that invalidates cloudfront cache for places that modify Constellation api (#1839)

This commit is contained in:
Malte Poll 2023-06-02 11:20:01 +02:00 committed by GitHub
parent 93569ff54c
commit e1d3afe8d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 398 additions and 186 deletions

View File

@ -50,12 +50,12 @@ func main() {
cliInfo.Kubernetes = append(cliInfo.Kubernetes, v.ClusterVersion) cliInfo.Kubernetes = append(cliInfo.Kubernetes, v.ClusterVersion)
} }
c, err := client.NewClient(ctx, "eu-central-1", "cdn-constellation-backend", "E1H77EZTHC3NE4", false, log) c, cclose, err := client.NewClient(ctx, "eu-central-1", "cdn-constellation-backend", "E1H77EZTHC3NE4", false, log)
if err != nil { if err != nil {
log.Fatalf("creating s3 client: %w", err) log.Fatalf("creating s3 client: %w", err)
} }
defer func() { defer func() {
if err := c.InvalidateCache(ctx); err != nil { if err := cclose(ctx); err != nil {
log.Fatalf("invalidating cache: %w", err) log.Fatalf("invalidating cache: %w", err)
} }
}() }()

View File

@ -76,10 +76,15 @@ func runCmd(cmd *cobra.Command, _ []string) error {
return fmt.Errorf("unmarshalling version file: %w", err) return fmt.Errorf("unmarshalling version file: %w", err)
} }
sut, err := attestationconfigclient.New(ctx, cfg, []byte(cosignPwd), privateKey) sut, sutClose, err := attestationconfigclient.New(ctx, cfg, []byte(cosignPwd), privateKey)
if err != nil { if err != nil {
return fmt.Errorf("creating repo: %w", err) return fmt.Errorf("creating repo: %w", err)
} }
defer func() {
if err := sutClose(ctx); err != nil {
fmt.Printf("closing repo: %v\n", err)
}
}()
if err := sut.UploadAzureSEVSNP(ctx, versions, time.Now()); err != nil { if err := sut.UploadAzureSEVSNP(ctx, versions, time.Now()); err != nil {
return fmt.Errorf("uploading version: %w", err) return fmt.Errorf("uploading version: %w", err)

View File

@ -46,10 +46,15 @@ func runAWS(cmd *cobra.Command, _ []string) error {
log := logger.New(logger.PlainLog, flags.logLevel) log := logger.New(logger.PlainLog, flags.logLevel)
log.Debugf("Parsed flags: %+v", flags) log.Debugf("Parsed flags: %+v", flags)
archiveC, err := archive.New(cmd.Context(), flags.region, flags.bucket, log) archiveC, archiveCClose, err := archive.New(cmd.Context(), flags.region, flags.bucket, flags.distributionID, log)
if err != nil { if err != nil {
return err return err
} }
defer func() {
if err := archiveCClose(cmd.Context()); err != nil {
log.Errorf("closing archive client: %v", err)
}
}()
uploadC, err := awsupload.New(flags.awsRegion, flags.awsBucket, log) uploadC, err := awsupload.New(flags.awsRegion, flags.awsBucket, log)
if err != nil { if err != nil {

View File

@ -47,10 +47,15 @@ func runAzure(cmd *cobra.Command, _ []string) error {
log := logger.New(logger.PlainLog, flags.logLevel) log := logger.New(logger.PlainLog, flags.logLevel)
log.Debugf("Parsed flags: %+v", flags) log.Debugf("Parsed flags: %+v", flags)
archiveC, err := archive.New(cmd.Context(), flags.region, flags.bucket, log) archiveC, archiveCClose, err := archive.New(cmd.Context(), flags.region, flags.bucket, flags.distributionID, log)
if err != nil { if err != nil {
return err return err
} }
defer func() {
if err := archiveCClose(cmd.Context()); err != nil {
log.Errorf("closing archive client: %v", err)
}
}()
uploadC, err := azureupload.New(flags.azSubscription, flags.azLocation, flags.azResourceGroup, log) uploadC, err := azureupload.New(flags.azSubscription, flags.azLocation, flags.azResourceGroup, log)
if err != nil { if err != nil {

View File

@ -27,6 +27,7 @@ type commonFlags struct {
timestamp time.Time timestamp time.Time
region string region string
bucket string bucket string
distributionID string
out string out string
logLevel zapcore.Level logLevel zapcore.Level
} }
@ -75,6 +76,10 @@ func parseCommonFlags(cmd *cobra.Command) (commonFlags, error) {
if err != nil { if err != nil {
return commonFlags{}, err return commonFlags{}, err
} }
distributionID, err := cmd.Flags().GetString("distribution-id")
if err != nil {
return commonFlags{}, err
}
out, err := cmd.Flags().GetString("out") out, err := cmd.Flags().GetString("out")
if err != nil { if err != nil {
return commonFlags{}, err return commonFlags{}, err
@ -96,6 +101,7 @@ func parseCommonFlags(cmd *cobra.Command) (commonFlags, error) {
timestamp: timestmp, timestamp: timestmp,
region: region, region: region,
bucket: bucket, bucket: bucket,
distributionID: distributionID,
out: out, out: out,
logLevel: logLevel, logLevel: logLevel,
}, nil }, nil
@ -201,9 +207,10 @@ func parseGCPFlags(cmd *cobra.Command) (gcpFlags, error) {
} }
type s3Flags struct { type s3Flags struct {
region string region string
bucket string bucket string
logLevel zapcore.Level distributionID string
logLevel zapcore.Level
} }
func parseS3Flags(cmd *cobra.Command) (s3Flags, error) { func parseS3Flags(cmd *cobra.Command) (s3Flags, error) {
@ -215,6 +222,10 @@ func parseS3Flags(cmd *cobra.Command) (s3Flags, error) {
if err != nil { if err != nil {
return s3Flags{}, err return s3Flags{}, err
} }
distributionID, err := cmd.Flags().GetString("distribution-id")
if err != nil {
return s3Flags{}, err
}
verbose, err := cmd.Flags().GetBool("verbose") verbose, err := cmd.Flags().GetBool("verbose")
if err != nil { if err != nil {
return s3Flags{}, err return s3Flags{}, err
@ -225,9 +236,10 @@ func parseS3Flags(cmd *cobra.Command) (s3Flags, error) {
} }
return s3Flags{ return s3Flags{
region: region, region: region,
bucket: bucket, bucket: bucket,
logLevel: logLevel, distributionID: distributionID,
logLevel: logLevel,
}, nil }, nil
} }

View File

@ -47,10 +47,15 @@ func runGCP(cmd *cobra.Command, _ []string) error {
log := logger.New(logger.PlainLog, flags.logLevel) log := logger.New(logger.PlainLog, flags.logLevel)
log.Debugf("Parsed flags: %+v", flags) log.Debugf("Parsed flags: %+v", flags)
archiveC, err := archive.New(cmd.Context(), flags.region, flags.bucket, log) archiveC, archiveCClose, err := archive.New(cmd.Context(), flags.region, flags.bucket, flags.distributionID, log)
if err != nil { if err != nil {
return err return err
} }
defer func() {
if err := archiveCClose(cmd.Context()); err != nil {
log.Errorf("closing archive client: %v", err)
}
}()
uploadC, err := gcpupload.New(cmd.Context(), flags.gcpProject, flags.gcpLocation, flags.gcpBucket, log) uploadC, err := gcpupload.New(cmd.Context(), flags.gcpProject, flags.gcpLocation, flags.gcpBucket, log)
if err != nil { if err != nil {

View File

@ -31,6 +31,7 @@ func NewImageCmd() *cobra.Command {
cmd.PersistentFlags().String("timestamp", "", "Optional timestamp to use for resource names. Uses format 2006-01-02T15:04:05Z07:00.") cmd.PersistentFlags().String("timestamp", "", "Optional timestamp to use for resource names. Uses format 2006-01-02T15:04:05Z07:00.")
cmd.PersistentFlags().String("region", "eu-central-1", "AWS region of the archive S3 bucket") cmd.PersistentFlags().String("region", "eu-central-1", "AWS region of the archive S3 bucket")
cmd.PersistentFlags().String("bucket", "cdn-constellation-backend", "S3 bucket name of the archive") cmd.PersistentFlags().String("bucket", "cdn-constellation-backend", "S3 bucket name of the archive")
cmd.PersistentFlags().String("distribution-id", "E1H77EZTHC3NE4", "CloudFront distribution ID of the API")
cmd.PersistentFlags().String("out", "", "Optional path to write the upload result to. If not set, the result is written to stdout.") cmd.PersistentFlags().String("out", "", "Optional path to write the upload result to. If not set, the result is written to stdout.")
cmd.PersistentFlags().Bool("verbose", false, "Enable verbose output") cmd.PersistentFlags().Bool("verbose", false, "Enable verbose output")
must(cmd.MarkPersistentFlagRequired("raw-image")) must(cmd.MarkPersistentFlagRequired("raw-image"))

View File

@ -31,6 +31,7 @@ func NewInfoCmd() *cobra.Command {
cmd.Flags().String("region", "eu-central-1", "AWS region of the archive S3 bucket") cmd.Flags().String("region", "eu-central-1", "AWS region of the archive S3 bucket")
cmd.Flags().String("bucket", "cdn-constellation-backend", "S3 bucket name of the archive") cmd.Flags().String("bucket", "cdn-constellation-backend", "S3 bucket name of the archive")
cmd.Flags().String("distribution-id", "E1H77EZTHC3NE4", "CloudFront distribution ID of the API")
cmd.Flags().Bool("verbose", false, "Enable verbose output") cmd.Flags().Bool("verbose", false, "Enable verbose output")
return cmd return cmd
@ -54,10 +55,15 @@ func runInfo(cmd *cobra.Command, args []string) error {
return err return err
} }
uploadC, err := infoupload.New(cmd.Context(), flags.region, flags.bucket, log) uploadC, uploadCClose, err := infoupload.New(cmd.Context(), flags.region, flags.bucket, flags.distributionID, log)
if err != nil { if err != nil {
return fmt.Errorf("uploading image info: %w", err) return fmt.Errorf("uploading image info: %w", err)
} }
defer func() {
if err := uploadCClose(cmd.Context()); err != nil {
log.Errorf("closing upload client: %v", err)
}
}()
url, err := uploadC.Upload(cmd.Context(), info) url, err := uploadC.Upload(cmd.Context(), info)
if err != nil { if err != nil {

View File

@ -31,6 +31,7 @@ func newMeasurementsUploadCmd() *cobra.Command {
cmd.Flags().String("signature", "", "Path to signature file to upload") cmd.Flags().String("signature", "", "Path to signature file to upload")
cmd.Flags().String("region", "eu-central-1", "AWS region of the archive S3 bucket") cmd.Flags().String("region", "eu-central-1", "AWS region of the archive S3 bucket")
cmd.Flags().String("bucket", "cdn-constellation-backend", "S3 bucket name of the archive") cmd.Flags().String("bucket", "cdn-constellation-backend", "S3 bucket name of the archive")
cmd.Flags().String("distribution-id", "E1H77EZTHC3NE4", "CloudFront distribution ID of the API")
cmd.Flags().Bool("verbose", false, "Enable verbose output") cmd.Flags().Bool("verbose", false, "Enable verbose output")
must(cmd.MarkFlagRequired("measurements")) must(cmd.MarkFlagRequired("measurements"))
@ -53,10 +54,15 @@ func runMeasurementsUpload(cmd *cobra.Command, _ []string) error {
log := logger.New(logger.PlainLog, flags.logLevel) log := logger.New(logger.PlainLog, flags.logLevel)
log.Debugf("Parsed flags: %+v", flags) log.Debugf("Parsed flags: %+v", flags)
uploadC, err := measurementsuploader.New(cmd.Context(), flags.region, flags.bucket, log) uploadC, uploadCClose, err := measurementsuploader.New(cmd.Context(), flags.region, flags.bucket, flags.distributionID, log)
if err != nil { if err != nil {
return fmt.Errorf("uploading image info: %w", err) return fmt.Errorf("uploading image info: %w", err)
} }
defer func() {
if err := uploadCClose(cmd.Context()); err != nil {
log.Errorf("closing upload client: %v", err)
}
}()
measurements, err := os.Open(flags.measurementsPath) measurements, err := os.Open(flags.measurementsPath)
if err != nil { if err != nil {

View File

@ -33,10 +33,15 @@ func runNOP(cmd *cobra.Command, provider cloudprovider.Provider, _ []string) err
log := logger.New(logger.PlainLog, flags.logLevel) log := logger.New(logger.PlainLog, flags.logLevel)
log.Debugf("Parsed flags: %+v", flags) log.Debugf("Parsed flags: %+v", flags)
archiveC, err := archive.New(cmd.Context(), flags.region, flags.bucket, log) archiveC, archiveCClose, err := archive.New(cmd.Context(), flags.region, flags.bucket, flags.distributionID, log)
if err != nil { if err != nil {
return err return err
} }
defer func() {
if err := archiveCClose(cmd.Context()); err != nil {
log.Errorf("closing archive client: %v", err)
}
}()
uploadC := nopupload.New(log) uploadC := nopupload.New(log)

View File

@ -13,6 +13,7 @@ go_library(
"//internal/sigstore", "//internal/sigstore",
"//internal/staticupload", "//internal/staticupload",
"//internal/variant", "//internal/variant",
"@com_github_aws_aws_sdk_go_v2_feature_s3_manager//:manager",
"@com_github_aws_aws_sdk_go_v2_service_s3//:s3", "@com_github_aws_aws_sdk_go_v2_service_s3//:s3",
], ],
) )

View File

@ -16,6 +16,7 @@ import (
"sort" "sort"
"time" "time"
s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/edgelesssys/constellation/v2/internal/api/attestationconfig" "github.com/edgelesssys/constellation/v2/internal/api/attestationconfig"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
@ -27,18 +28,38 @@ import (
// Client manages (modifies) the version information for the attestation variants. // Client manages (modifies) the version information for the attestation variants.
type Client struct { type Client struct {
*staticupload.Client s3Client
cosignPwd []byte // used to decrypt the cosign private key s3ClientClose func(ctx context.Context) error
privKey []byte // used to sign bucketID string
cosignPwd []byte // used to decrypt the cosign private key
privKey []byte // used to sign
} }
// New returns a new Client. // New returns a new Client.
func New(ctx context.Context, cfg staticupload.Config, cosignPwd, privateKey []byte) (*Client, error) { func New(ctx context.Context, cfg staticupload.Config, cosignPwd, privateKey []byte) (*Client, CloseFunc, error) {
client, err := staticupload.New(ctx, cfg) client, clientClose, err := staticupload.New(ctx, cfg)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create s3 storage: %w", err) return nil, nil, fmt.Errorf("failed to create s3 storage: %w", err)
} }
return &Client{client, cosignPwd, privateKey}, nil repo := &Client{
s3Client: client,
s3ClientClose: clientClose,
bucketID: cfg.Bucket,
cosignPwd: cosignPwd,
privKey: privateKey,
}
repoClose := func(ctx context.Context) error {
return repo.Close(ctx)
}
return repo, repoClose, nil
}
// Close closes the Client.
func (a Client) Close(ctx context.Context) error {
if a.s3ClientClose == nil {
return nil
}
return a.s3ClientClose(ctx)
} }
// UploadAzureSEVSNP uploads the latest version numbers of the Azure SEVSNP. // UploadAzureSEVSNP uploads the latest version numbers of the Azure SEVSNP.
@ -51,7 +72,7 @@ func (a Client) UploadAzureSEVSNP(ctx context.Context, versions attestationconfi
fname := date.Format("2006-01-02-15-04") + ".json" fname := date.Format("2006-01-02-15-04") + ".json"
filePath := fmt.Sprintf("%s/%s/%s", constants.CDNAttestationConfigPrefixV1, variant.String(), fname) filePath := fmt.Sprintf("%s/%s/%s", constants.CDNAttestationConfigPrefixV1, variant.String(), fname)
err = put(ctx, a.Client, filePath, versionBytes) err = put(ctx, a.s3Client, a.bucketID, filePath, versionBytes)
if err != nil { if err != nil {
return err return err
} }
@ -69,7 +90,7 @@ func (a Client) createAndUploadSignature(ctx context.Context, content []byte, fi
if err != nil { if err != nil {
return fmt.Errorf("sign version file: %w", err) return fmt.Errorf("sign version file: %w", err)
} }
err = put(ctx, a.Client, filePath+".sig", signature) err = put(ctx, a.s3Client, a.bucketID, filePath+".sig", signature)
if err != nil { if err != nil {
return fmt.Errorf("upload signature: %w", err) return fmt.Errorf("upload signature: %w", err)
} }
@ -79,7 +100,7 @@ func (a Client) createAndUploadSignature(ctx context.Context, content []byte, fi
// List returns the list of versions for the given attestation type. // List returns the list of versions for the given attestation type.
func (a Client) List(ctx context.Context, attestation variant.Variant) ([]string, error) { func (a Client) List(ctx context.Context, attestation variant.Variant) ([]string, error) {
key := path.Join(constants.CDNAttestationConfigPrefixV1, attestation.String(), "list") key := path.Join(constants.CDNAttestationConfigPrefixV1, attestation.String(), "list")
bt, err := get(ctx, a.Client, key) bt, err := get(ctx, a.s3Client, a.bucketID, key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -97,13 +118,13 @@ func (a Client) DeleteList(ctx context.Context, attestation variant.Variant) err
if err != nil { if err != nil {
return err return err
} }
return put(ctx, a.Client, path.Join(constants.CDNAttestationConfigPrefixV1, attestation.String(), "list"), bt) return put(ctx, a.s3Client, a.bucketID, path.Join(constants.CDNAttestationConfigPrefixV1, attestation.String(), "list"), bt)
} }
func (a Client) addVersionToList(ctx context.Context, attestation variant.Variant, fname string) error { func (a Client) addVersionToList(ctx context.Context, attestation variant.Variant, fname string) error {
versions := []string{} versions := []string{}
key := path.Join(constants.CDNAttestationConfigPrefixV1, attestation.String(), "list") key := path.Join(constants.CDNAttestationConfigPrefixV1, attestation.String(), "list")
bt, err := get(ctx, a.Client, key) bt, err := get(ctx, a.s3Client, a.bucketID, key)
if err == nil { if err == nil {
if err := json.Unmarshal(bt, &versions); err != nil { if err := json.Unmarshal(bt, &versions); err != nil {
return err return err
@ -118,13 +139,13 @@ func (a Client) addVersionToList(ctx context.Context, attestation variant.Varian
if err != nil { if err != nil {
return err return err
} }
return put(ctx, a.Client, key, json) return put(ctx, a.s3Client, a.bucketID, key, json)
} }
// get is a convenience method. // get is a convenience method.
func get(ctx context.Context, client *staticupload.Client, path string) ([]byte, error) { func get(ctx context.Context, client s3Client, bucket, path string) ([]byte, error) {
getObjectInput := &s3.GetObjectInput{ getObjectInput := &s3.GetObjectInput{
Bucket: &client.BucketID, Bucket: &bucket,
Key: &path, Key: &path,
} }
output, err := client.GetObject(ctx, getObjectInput) output, err := client.GetObject(ctx, getObjectInput)
@ -135,12 +156,24 @@ func get(ctx context.Context, client *staticupload.Client, path string) ([]byte,
} }
// put is a convenience method. // put is a convenience method.
func put(ctx context.Context, client *staticupload.Client, path string, data []byte) error { func put(ctx context.Context, client s3Client, bucket, path string, data []byte) error {
putObjectInput := &s3.PutObjectInput{ putObjectInput := &s3.PutObjectInput{
Bucket: &client.BucketID, Bucket: &bucket,
Key: &path, Key: &path,
Body: bytes.NewReader(data), Body: bytes.NewReader(data),
} }
_, err := client.Upload(ctx, putObjectInput) _, err := client.Upload(ctx, putObjectInput)
return err return err
} }
type s3Client interface {
GetObject(
ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options),
) (*s3.GetObjectOutput, error)
Upload(
ctx context.Context, input *s3.PutObjectInput, opts ...func(*s3manager.Uploader),
) (*s3manager.UploadOutput, error)
}
// CloseFunc is a function that closes the client.
type CloseFunc func(ctx context.Context) error

View File

@ -74,8 +74,9 @@ var versionValues = attestationconfig.AzureSEVSNPVersion{
func TestUploadAzureSEVSNPVersions(t *testing.T) { func TestUploadAzureSEVSNPVersions(t *testing.T) {
ctx := context.Background() ctx := context.Background()
client, err := client.New(ctx, cfg, []byte(*cosignPwd), privateKey) client, clientClose, err := client.New(ctx, cfg, []byte(*cosignPwd), privateKey)
require.NoError(t, err) require.NoError(t, err)
defer func() { _ = clientClose(ctx) }()
d := time.Date(2021, 1, 1, 1, 1, 1, 1, time.UTC) d := time.Date(2021, 1, 1, 1, 1, 1, 1, time.UTC)
require.NoError(t, client.UploadAzureSEVSNP(ctx, versionValues, d)) require.NoError(t, client.UploadAzureSEVSNP(ctx, versionValues, d))
} }

View File

@ -7,11 +7,8 @@ go_library(
visibility = ["//:__subpackages__"], visibility = ["//:__subpackages__"],
deps = [ deps = [
"//internal/logger", "//internal/logger",
"@com_github_aws_aws_sdk_go_v2//aws", "//internal/staticupload",
"@com_github_aws_aws_sdk_go_v2_config//:config",
"@com_github_aws_aws_sdk_go_v2_feature_s3_manager//:manager", "@com_github_aws_aws_sdk_go_v2_feature_s3_manager//:manager",
"@com_github_aws_aws_sdk_go_v2_service_cloudfront//:cloudfront",
"@com_github_aws_aws_sdk_go_v2_service_cloudfront//types",
"@com_github_aws_aws_sdk_go_v2_service_s3//:s3", "@com_github_aws_aws_sdk_go_v2_service_s3//:s3",
"@com_github_aws_aws_sdk_go_v2_service_s3//types", "@com_github_aws_aws_sdk_go_v2_service_s3//types",
"@org_uber_go_zap//:zap", "@org_uber_go_zap//:zap",

View File

@ -35,25 +35,20 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/aws/aws-sdk-go-v2/aws"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager" s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/cloudfront"
cftypes "github.com/aws/aws-sdk-go-v2/service/cloudfront/types"
"github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3"
s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" s3types "github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/edgelesssys/constellation/v2/internal/logger" "github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/internal/staticupload"
"go.uber.org/zap" "go.uber.org/zap"
) )
// Client is the client for the versions API. // Client is the client for the versions API.
type Client struct { type Client struct {
config aws.Config uploadClient uploadClient
cloudfrontClient *cloudfront.Client s3Client s3Client
s3Client *s3.Client s3ClientClose func(ctx context.Context) error
uploadClient *s3manager.Uploader
bucket string bucket string
distributionID string
cacheInvalidationWaitTimeout time.Duration cacheInvalidationWaitTimeout time.Duration
dirtyPaths []string // written paths to be invalidated dirtyPaths []string // written paths to be invalidated
@ -66,90 +61,68 @@ type Client struct {
// This client can be used to fetch objects but cannot write updates. // This client can be used to fetch objects but cannot write updates.
func NewReadOnlyClient(ctx context.Context, region, bucket, distributionID string, func NewReadOnlyClient(ctx context.Context, region, bucket, distributionID string,
log *logger.Logger, log *logger.Logger,
) (*Client, error) { ) (*Client, CloseFunc, error) {
cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(region)) staticUploadClient, staticUploadClientClose, err := staticupload.New(ctx, staticupload.Config{
Region: region,
Bucket: bucket,
DistributionID: distributionID,
CacheInvalidationStrategy: staticupload.CacheInvalidateBatchOnFlush,
})
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
s3c := s3.NewFromConfig(cfg)
return &Client{ client := &Client{
config: cfg, s3Client: staticUploadClient,
s3Client: s3c, s3ClientClose: staticUploadClientClose,
bucket: bucket, bucket: bucket,
distributionID: distributionID, DryRun: true,
DryRun: true, Log: log,
Log: log, }
}, nil clientClose := func(ctx context.Context) error {
return client.Close(ctx)
}
return client, clientClose, nil
} }
// NewClient creates a new client for the versions API. // NewClient creates a new client for the versions API.
func NewClient(ctx context.Context, region, bucket, distributionID string, dryRun bool, func NewClient(ctx context.Context, region, bucket, distributionID string, dryRun bool,
log *logger.Logger, log *logger.Logger,
) (*Client, error) { ) (*Client, CloseFunc, error) {
cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(region)) staticUploadClient, staticUploadClientClose, err := staticupload.New(ctx, staticupload.Config{
Region: region,
Bucket: bucket,
DistributionID: distributionID,
CacheInvalidationStrategy: staticupload.CacheInvalidateBatchOnFlush,
})
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
cloudfrontC := cloudfront.NewFromConfig(cfg) client := &Client{
s3C := s3.NewFromConfig(cfg) uploadClient: staticUploadClient,
uploadC := s3manager.NewUploader(s3C) s3Client: staticUploadClient,
s3ClientClose: staticUploadClientClose,
return &Client{
config: cfg,
cloudfrontClient: cloudfrontC,
s3Client: s3C,
uploadClient: uploadC,
bucket: bucket, bucket: bucket,
distributionID: distributionID,
DryRun: dryRun, DryRun: dryRun,
Log: log, Log: log,
cacheInvalidationWaitTimeout: 10 * time.Minute, cacheInvalidationWaitTimeout: 10 * time.Minute,
}, nil }
clientClose := func(ctx context.Context) error {
return client.Close(ctx)
}
return client, clientClose, nil
} }
// InvalidateCache invalidates the CDN cache for the paths that have been written. // Close closes the client.
// The function should be deferred after the client has been created. // It invalidates the CDN cache for all uploaded files.
func (c *Client) InvalidateCache(ctx context.Context) error { func (c *Client) Close(ctx context.Context) error {
if len(c.dirtyPaths) == 0 { if c.s3ClientClose == nil {
c.Log.Debugf("No dirty paths, skipping cache invalidation")
return nil return nil
} }
return c.s3ClientClose(ctx)
if c.DryRun {
c.Log.With(zap.String("distributionID", c.distributionID), zap.Strings("dirtyPaths", c.dirtyPaths)).Debugf("DryRun: cloudfront create invalidation")
return nil
}
c.Log.Debugf("Paths to invalidate: %v", c.dirtyPaths)
in := &cloudfront.CreateInvalidationInput{
DistributionId: &c.distributionID,
InvalidationBatch: &cftypes.InvalidationBatch{
CallerReference: ptr(fmt.Sprintf("%d", time.Now().Unix())),
Paths: &cftypes.Paths{
Items: c.dirtyPaths,
Quantity: ptr(int32(len(c.dirtyPaths))),
},
},
}
invalidation, err := c.cloudfrontClient.CreateInvalidation(ctx, in)
if err != nil {
return fmt.Errorf("creating invalidation: %w", err)
}
c.Log.Debugf("Waiting for invalidation %s to complete", *invalidation.Invalidation.Id)
waiter := cloudfront.NewInvalidationCompletedWaiter(c.cloudfrontClient)
waitIn := &cloudfront.GetInvalidationInput{
DistributionId: &c.distributionID,
Id: invalidation.Invalidation.Id,
}
if err := waiter.Wait(ctx, waitIn, c.cacheInvalidationWaitTimeout); err != nil {
return fmt.Errorf("waiting for invalidation to complete: %w", err)
}
return nil
} }
// DeletePath deletes all objects at a given path from a s3 bucket. // DeletePath deletes all objects at a given path from a s3 bucket.
@ -289,3 +262,22 @@ func (e *NotFoundError) Error() string {
func (e *NotFoundError) Unwrap() error { func (e *NotFoundError) Unwrap() error {
return e.err return e.err
} }
type s3Client interface {
GetObject(
ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options),
) (*s3.GetObjectOutput, error)
ListObjectsV2(
ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options),
) (*s3.ListObjectsV2Output, error)
DeleteObjects(
ctx context.Context, params *s3.DeleteObjectsInput, optFns ...func(*s3.Options),
) (*s3.DeleteObjectsOutput, error)
}
type uploadClient interface {
Upload(ctx context.Context, input *s3.PutObjectInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error)
}
// CloseFunc is a function that closes the client.
type CloseFunc func(ctx context.Context) error

View File

@ -73,13 +73,13 @@ func runAdd(cmd *cobra.Command, _ []string) (retErr error) {
} }
log.Debugf("Creating versions API client") log.Debugf("Creating versions API client")
client, err := verclient.NewClient(cmd.Context(), flags.region, flags.bucket, flags.distributionID, flags.dryRun, log) client, clientClose, err := verclient.NewClient(cmd.Context(), flags.region, flags.bucket, flags.distributionID, flags.dryRun, log)
if err != nil { if err != nil {
return fmt.Errorf("creating client: %w", err) return fmt.Errorf("creating client: %w", err)
} }
defer func(retErr *error) { defer func(retErr *error) {
log.Infof("Invalidating cache. This may take some time") log.Infof("Invalidating cache. This may take some time")
if err := client.InvalidateCache(cmd.Context()); err != nil && retErr == nil { if err := clientClose(cmd.Context()); err != nil && retErr == nil {
*retErr = fmt.Errorf("invalidating cache: %w", err) *retErr = fmt.Errorf("invalidating cache: %w", err)
} }
}(&retErr) }(&retErr)

View File

@ -47,10 +47,15 @@ func runLatest(cmd *cobra.Command, _ []string) error {
} }
log.Debugf("Creating versions API client") log.Debugf("Creating versions API client")
client, err := verclient.NewReadOnlyClient(cmd.Context(), flags.region, flags.bucket, flags.distributionID, log) client, clientClose, err := verclient.NewReadOnlyClient(cmd.Context(), flags.region, flags.bucket, flags.distributionID, log)
if err != nil { if err != nil {
return fmt.Errorf("creating client: %w", err) return fmt.Errorf("creating client: %w", err)
} }
defer func() {
if err := clientClose(cmd.Context()); err != nil {
log.Errorf("Closing versions API client: %v", err)
}
}()
log.Debugf("Requesting latest version") log.Debugf("Requesting latest version")
latest := versionsapi.Latest{ latest := versionsapi.Latest{

View File

@ -53,10 +53,15 @@ func runList(cmd *cobra.Command, _ []string) error {
} }
log.Debugf("Creating versions API client") log.Debugf("Creating versions API client")
client, err := verclient.NewReadOnlyClient(cmd.Context(), flags.region, flags.bucket, flags.distributionID, log) client, clientClose, err := verclient.NewReadOnlyClient(cmd.Context(), flags.region, flags.bucket, flags.distributionID, log)
if err != nil { if err != nil {
return fmt.Errorf("creating client: %w", err) return fmt.Errorf("creating client: %w", err)
} }
defer func() {
if err := clientClose(cmd.Context()); err != nil {
log.Errorf("Closing versions API client: %v", err)
}
}()
var minorVersions []string var minorVersions []string
if flags.minorVersion != "" { if flags.minorVersion != "" {

View File

@ -102,13 +102,13 @@ func runRemove(cmd *cobra.Command, _ []string) (retErr error) {
} }
log.Debugf("Creating versions API client") log.Debugf("Creating versions API client")
verclient, err := verclient.NewClient(cmd.Context(), flags.region, flags.bucket, flags.distributionID, flags.dryrun, log) verclient, verclientClose, err := verclient.NewClient(cmd.Context(), flags.region, flags.bucket, flags.distributionID, flags.dryrun, log)
if err != nil { if err != nil {
return fmt.Errorf("creating client: %w", err) return fmt.Errorf("creating client: %w", err)
} }
defer func(retErr *error) { defer func(retErr *error) {
log.Infof("Invalidating cache. This may take some time") log.Infof("Invalidating cache. This may take some time")
if err := verclient.InvalidateCache(cmd.Context()); err != nil && retErr == nil { if err := verclientClose(cmd.Context()); err != nil && retErr == nil {
*retErr = fmt.Errorf("invalidating cache: %w", err) *retErr = fmt.Errorf("invalidating cache: %w", err)
} }
}(&retErr) }(&retErr)

View File

@ -27,23 +27,49 @@ import (
// VersionsClient is a client for the versions API. // VersionsClient is a client for the versions API.
type VersionsClient struct { type VersionsClient struct {
*apiclient.Client *apiclient.Client
clientClose func(ctx context.Context) error
} }
// NewClient creates a new client for the versions API. // NewClient creates a new client for the versions API.
func NewClient(ctx context.Context, region, bucket, distributionID string, dryRun bool, func NewClient(ctx context.Context, region, bucket, distributionID string, dryRun bool,
log *logger.Logger, log *logger.Logger,
) (*VersionsClient, error) { ) (*VersionsClient, CloseFunc, error) {
genericClient, err := apiclient.NewClient(ctx, region, bucket, distributionID, dryRun, log) genericClient, genericClientClose, err := apiclient.NewClient(ctx, region, bucket, distributionID, dryRun, log)
return &VersionsClient{genericClient}, err versionsClient := &VersionsClient{
genericClient,
genericClientClose,
}
versionsClientClose := func(ctx context.Context) error {
return versionsClient.Close(ctx)
}
return versionsClient, versionsClientClose, err
} }
// NewReadOnlyClient creates a new read-only client. // NewReadOnlyClient creates a new read-only client.
// This client can be used to fetch objects but cannot write updates. // This client can be used to fetch objects but cannot write updates.
func NewReadOnlyClient(ctx context.Context, region, bucket, distributionID string, func NewReadOnlyClient(ctx context.Context, region, bucket, distributionID string,
log *logger.Logger, log *logger.Logger,
) (*VersionsClient, error) { ) (*VersionsClient, CloseFunc, error) {
genericClient, err := apiclient.NewReadOnlyClient(ctx, region, bucket, distributionID, log) genericClient, genericClientClose, err := apiclient.NewReadOnlyClient(ctx, region, bucket, distributionID, log)
return &VersionsClient{genericClient}, err if err != nil {
return nil, nil, err
}
versionsClient := &VersionsClient{
genericClient,
genericClientClose,
}
versionsClientClose := func(ctx context.Context) error {
return versionsClient.Close(ctx)
}
return versionsClient, versionsClientClose, err
}
// Close closes the client.
func (c *VersionsClient) Close(ctx context.Context) error {
if c.clientClose == nil {
return nil
}
return c.clientClose(ctx)
} }
// FetchVersionList fetches the given version list from the versions API. // FetchVersionList fetches the given version list from the versions API.
@ -228,3 +254,6 @@ func (c *VersionsClient) deleteVersionFromLatest(ctx context.Context, ver versio
return nil return nil
} }
// CloseFunc is a function that closes the client.
type CloseFunc func(ctx context.Context) error

View File

@ -9,7 +9,7 @@ go_library(
"//internal/api/versions", "//internal/api/versions",
"//internal/constants", "//internal/constants",
"//internal/logger", "//internal/logger",
"@com_github_aws_aws_sdk_go_v2_config//:config", "//internal/staticupload",
"@com_github_aws_aws_sdk_go_v2_feature_s3_manager//:manager", "@com_github_aws_aws_sdk_go_v2_feature_s3_manager//:manager",
"@com_github_aws_aws_sdk_go_v2_service_s3//:s3", "@com_github_aws_aws_sdk_go_v2_service_s3//:s3",
"@com_github_aws_aws_sdk_go_v2_service_s3//types", "@com_github_aws_aws_sdk_go_v2_service_s3//types",

View File

@ -12,18 +12,19 @@ import (
"io" "io"
"net/url" "net/url"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager" s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3"
s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" s3types "github.com/aws/aws-sdk-go-v2/service/s3/types"
versionsapi "github.com/edgelesssys/constellation/v2/internal/api/versions" versionsapi "github.com/edgelesssys/constellation/v2/internal/api/versions"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/logger" "github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/internal/staticupload"
) )
// Archivist uploads OS images to S3. // Archivist uploads OS images to S3.
type Archivist struct { type Archivist struct {
uploadClient uploadClient uploadClient uploadClient
uploadClientClose func(ctx context.Context) error
// bucket is the name of the S3 bucket to use. // bucket is the name of the S3 bucket to use.
bucket string bucket string
@ -31,19 +32,37 @@ type Archivist struct {
} }
// New creates a new Archivist. // New creates a new Archivist.
func New(ctx context.Context, region, bucket string, log *logger.Logger) (*Archivist, error) { func New(ctx context.Context, region, bucket, distributionID string, log *logger.Logger) (*Archivist, CloseFunc, error) {
cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(region)) staticUploadClient, staticUploadClientClose, err := staticupload.New(ctx, staticupload.Config{
Region: region,
Bucket: bucket,
DistributionID: distributionID,
CacheInvalidationStrategy: staticupload.CacheInvalidateBatchOnFlush,
})
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
s3client := s3.NewFromConfig(cfg)
uploadClient := s3manager.NewUploader(s3client)
return &Archivist{ archivist := &Archivist{
uploadClient: uploadClient, uploadClient: staticUploadClient,
bucket: bucket, uploadClientClose: staticUploadClientClose,
log: log, bucket: bucket,
}, nil log: log,
}
archivistClose := func(ctx context.Context) error {
return archivist.Close(ctx)
}
return archivist, archivistClose, nil
}
// Close closes the uploader.
// It invalidates the CDN cache for all uploaded files.
func (a *Archivist) Close(ctx context.Context) error {
if a.uploadClientClose == nil {
return nil
}
return a.uploadClientClose(ctx)
} }
// Archive reads the OS image in img and uploads it as key. // Archive reads the OS image in img and uploads it as key.
@ -65,3 +84,6 @@ func (a *Archivist) Archive(ctx context.Context, version versionsapi.Version, cs
type uploadClient interface { type uploadClient interface {
Upload(ctx context.Context, input *s3.PutObjectInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error) Upload(ctx context.Context, input *s3.PutObjectInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error)
} }
// CloseFunc is a function that closes the client.
type CloseFunc func(ctx context.Context) error

View File

@ -9,7 +9,7 @@ go_library(
"//internal/api/versions", "//internal/api/versions",
"//internal/constants", "//internal/constants",
"//internal/logger", "//internal/logger",
"@com_github_aws_aws_sdk_go_v2_config//:config", "//internal/staticupload",
"@com_github_aws_aws_sdk_go_v2_feature_s3_manager//:manager", "@com_github_aws_aws_sdk_go_v2_feature_s3_manager//:manager",
"@com_github_aws_aws_sdk_go_v2_service_s3//:s3", "@com_github_aws_aws_sdk_go_v2_service_s3//:s3",
"@com_github_aws_aws_sdk_go_v2_service_s3//types", "@com_github_aws_aws_sdk_go_v2_service_s3//types",

View File

@ -13,18 +13,19 @@ import (
"encoding/json" "encoding/json"
"net/url" "net/url"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager" s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3"
s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" s3types "github.com/aws/aws-sdk-go-v2/service/s3/types"
versionsapi "github.com/edgelesssys/constellation/v2/internal/api/versions" versionsapi "github.com/edgelesssys/constellation/v2/internal/api/versions"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/logger" "github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/internal/staticupload"
) )
// Uploader uploads image info to S3. // Uploader uploads image info to S3.
type Uploader struct { type Uploader struct {
uploadClient uploadClient uploadClient uploadClient
uploadClientClose func(ctx context.Context) error
// bucket is the name of the S3 bucket to use. // bucket is the name of the S3 bucket to use.
bucket string bucket string
@ -32,19 +33,36 @@ type Uploader struct {
} }
// New creates a new Uploader. // New creates a new Uploader.
func New(ctx context.Context, region, bucket string, log *logger.Logger) (*Uploader, error) { func New(ctx context.Context, region, bucket, distributionID string, log *logger.Logger) (*Uploader, CloseFunc, error) {
cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(region)) staticUploadClient, staticUploadClientClose, err := staticupload.New(ctx, staticupload.Config{
Region: region,
Bucket: bucket,
DistributionID: distributionID,
CacheInvalidationStrategy: staticupload.CacheInvalidateBatchOnFlush,
})
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
s3client := s3.NewFromConfig(cfg)
uploadClient := s3manager.NewUploader(s3client)
return &Uploader{ uploader := &Uploader{
uploadClient: uploadClient, uploadClient: staticUploadClient,
bucket: bucket, uploadClientClose: staticUploadClientClose,
log: log, bucket: bucket,
}, nil log: log,
}
uploaderClose := func(ctx context.Context) error {
return uploader.Close(ctx)
}
return uploader, uploaderClose, nil
}
// Close closes the uploader.
// It invalidates the CDN cache for all uploaded files.
func (a *Uploader) Close(ctx context.Context) error {
if a.uploadClientClose == nil {
return nil
}
return a.uploadClientClose(ctx)
} }
// Upload marshals the image info to JSON and uploads it to S3. // Upload marshals the image info to JSON and uploads it to S3.
@ -76,3 +94,6 @@ func (a *Uploader) Upload(ctx context.Context, imageInfo versionsapi.ImageInfo)
type uploadClient interface { type uploadClient interface {
Upload(ctx context.Context, input *s3.PutObjectInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error) Upload(ctx context.Context, input *s3.PutObjectInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error)
} }
// CloseFunc is a function that closes the client.
type CloseFunc func(ctx context.Context) error

View File

@ -10,7 +10,7 @@ go_library(
"//internal/attestation/measurements", "//internal/attestation/measurements",
"//internal/constants", "//internal/constants",
"//internal/logger", "//internal/logger",
"@com_github_aws_aws_sdk_go_v2_config//:config", "//internal/staticupload",
"@com_github_aws_aws_sdk_go_v2_feature_s3_manager//:manager", "@com_github_aws_aws_sdk_go_v2_feature_s3_manager//:manager",
"@com_github_aws_aws_sdk_go_v2_service_s3//:s3", "@com_github_aws_aws_sdk_go_v2_service_s3//:s3",
"@com_github_aws_aws_sdk_go_v2_service_s3//types", "@com_github_aws_aws_sdk_go_v2_service_s3//types",

View File

@ -14,7 +14,6 @@ import (
"io" "io"
"net/url" "net/url"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager" s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3"
s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" s3types "github.com/aws/aws-sdk-go-v2/service/s3/types"
@ -22,11 +21,13 @@ import (
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/logger" "github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/internal/staticupload"
) )
// Uploader uploads image info to S3. // Uploader uploads image info to S3.
type Uploader struct { type Uploader struct {
uploadClient uploadClient uploadClient uploadClient
uploadClientClose func(ctx context.Context) error
// bucket is the name of the S3 bucket to use. // bucket is the name of the S3 bucket to use.
bucket string bucket string
@ -34,19 +35,36 @@ type Uploader struct {
} }
// New creates a new Uploader. // New creates a new Uploader.
func New(ctx context.Context, region, bucket string, log *logger.Logger) (*Uploader, error) { func New(ctx context.Context, region, bucket, distributionID string, log *logger.Logger) (*Uploader, CloseFunc, error) {
cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(region)) staticUploadClient, staticUploadClientClose, err := staticupload.New(ctx, staticupload.Config{
Region: region,
Bucket: bucket,
DistributionID: distributionID,
CacheInvalidationStrategy: staticupload.CacheInvalidateBatchOnFlush,
})
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
s3client := s3.NewFromConfig(cfg)
uploadClient := s3manager.NewUploader(s3client)
return &Uploader{ uploader := &Uploader{
uploadClient: uploadClient, uploadClient: staticUploadClient,
bucket: bucket, uploadClientClose: staticUploadClientClose,
log: log, bucket: bucket,
}, nil log: log,
}
uploaderClose := func(ctx context.Context) error {
return uploader.Close(ctx)
}
return uploader, uploaderClose, nil
}
// Close closes the uploader.
// It invalidates the CDN cache for all uploaded files.
func (a *Uploader) Close(ctx context.Context) error {
if a.uploadClientClose == nil {
return nil
}
return a.uploadClientClose(ctx)
} }
// Upload uploads the measurements v2 JSON file and its signature to S3. // Upload uploads the measurements v2 JSON file and its signature to S3.
@ -97,3 +115,6 @@ func (a *Uploader) Upload(ctx context.Context, rawMeasurement, signature io.Read
type uploadClient interface { type uploadClient interface {
Upload(ctx context.Context, input *s3.PutObjectInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error) Upload(ctx context.Context, input *s3.PutObjectInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error)
} }
// CloseFunc is a function that closes the client.
type CloseFunc func(ctx context.Context) error

View File

@ -12,7 +12,16 @@ import (
"github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3"
) )
// GetObject returns an object from from AWS S3 Storage. // GetObject retrieves objects from Amazon S3.
func (s *Client) GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) { func (c *Client) GetObject(
return s.s3Client.GetObject(ctx, params, optFns...) ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options),
) (*s3.GetObjectOutput, error) {
return c.s3Client.GetObject(ctx, params, optFns...)
}
// ListObjectsV2 returns some or all (up to 1,000) of the objects in a bucket.
func (c *Client) ListObjectsV2(
ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options),
) (*s3.ListObjectsV2Output, error) {
return c.s3Client.ListObjectsV2(ctx, params, optFns...)
} }

View File

@ -36,7 +36,7 @@ type Client struct {
uploadClient uploadClient uploadClient uploadClient
s3Client objectStorageClient s3Client objectStorageClient
distributionID string distributionID string
BucketID string bucketID string
cacheInvalidationStrategy CacheInvalidationStrategy cacheInvalidationStrategy CacheInvalidationStrategy
cacheInvalidationWaitTimeout time.Duration cacheInvalidationWaitTimeout time.Duration
@ -73,9 +73,9 @@ type CacheInvalidationStrategy int
const ( const (
// CacheInvalidateEager invalidates the CDN cache immediately for every key that is uploaded. // CacheInvalidateEager invalidates the CDN cache immediately for every key that is uploaded.
CacheInvalidateEager CacheInvalidationStrategy = iota CacheInvalidateEager CacheInvalidationStrategy = iota
// CacheInvalidateBatchOnClose invalidates the CDN cache in batches when the client is closed. // CacheInvalidateBatchOnFlush invalidates the CDN cache in batches when the client is flushed / closed.
// This is useful when uploading many files at once but will fail if Close is not called. // This is useful when uploading many files at once but may fail to invalidate the cache if close is not called.
CacheInvalidateBatchOnClose CacheInvalidateBatchOnFlush
) )
// InvalidationError is an error that occurs when invalidating the CDN cache. // InvalidationError is an error that occurs when invalidating the CDN cache.
@ -94,33 +94,39 @@ func (e InvalidationError) Unwrap() error {
} }
// New creates a new Client. // New creates a new Client.
func New(ctx context.Context, config Config) (*Client, error) { func New(ctx context.Context, config Config) (*Client, CloseFunc, error) {
config.SetsDefault() config.SetsDefault()
cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(config.Region)) cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(config.Region))
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
s3Client := s3.NewFromConfig(cfg) s3Client := s3.NewFromConfig(cfg)
uploadClient := s3manager.NewUploader(s3Client) uploadClient := s3manager.NewUploader(s3Client)
cdnClient := cloudfront.NewFromConfig(cfg) cdnClient := cloudfront.NewFromConfig(cfg)
return &Client{ client := &Client{
cdnClient: cdnClient, cdnClient: cdnClient,
s3Client: s3Client, s3Client: s3Client,
uploadClient: uploadClient, uploadClient: uploadClient,
distributionID: config.DistributionID, distributionID: config.DistributionID,
cacheInvalidationStrategy: config.CacheInvalidationStrategy, cacheInvalidationStrategy: config.CacheInvalidationStrategy,
cacheInvalidationWaitTimeout: config.CacheInvalidationWaitTimeout, cacheInvalidationWaitTimeout: config.CacheInvalidationWaitTimeout,
BucketID: config.Bucket, bucketID: config.Bucket,
}, nil }
clientClose := func(ctx context.Context) error {
// ensure that all keys are invalidated
return client.Flush(ctx)
}
return client, clientClose, nil
} }
// Close closes the client. // Flush flushes the client by invalidating the CDN cache for modified keys.
// It waits for all invalidations to finish. // It waits for all invalidations to finish.
// It returns nil on success or an error. // It returns nil on success or an error.
// The error will be of type InvalidationError if the CDN cache could not be invalidated. // The error will be of type InvalidationError if the CDN cache could not be invalidated.
func (c *Client) Close(ctx context.Context) error { func (c *Client) Flush(ctx context.Context) error {
c.mux.Lock() c.mux.Lock()
defer c.mux.Unlock() defer c.mux.Unlock()
@ -138,7 +144,7 @@ func (c *Client) Close(ctx context.Context) error {
// invalidate invalidates the CDN cache for the given keys. // invalidate invalidates the CDN cache for the given keys.
// It either performs the invalidation immediately or adds them to the list of dirty keys. // It either performs the invalidation immediately or adds them to the list of dirty keys.
func (c *Client) invalidate(ctx context.Context, keys []string) error { func (c *Client) invalidate(ctx context.Context, keys []string) error {
if c.cacheInvalidationStrategy == CacheInvalidateBatchOnClose { if c.cacheInvalidationStrategy == CacheInvalidateBatchOnFlush {
// save as dirty key for batch invalidation on Close // save as dirty key for batch invalidation on Close
c.mux.Lock() c.mux.Lock()
defer c.mux.Unlock() defer c.mux.Unlock()
@ -218,7 +224,15 @@ type uploadClient interface {
} }
type getClient interface { type getClient interface {
GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) GetObject(
ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options),
) (*s3.GetObjectOutput, error)
}
type listClient interface {
ListObjectsV2(
ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options),
) (*s3.ListObjectsV2Output, error)
} }
type deleteClient interface { type deleteClient interface {
@ -241,8 +255,9 @@ type cdnClient interface {
} }
type objectStorageClient interface { type objectStorageClient interface {
deleteClient
getClient getClient
listClient
deleteClient
} }
// statically assert that Client implements the uploadClient interface. // statically assert that Client implements the uploadClient interface.
@ -254,3 +269,6 @@ var _ objectStorageClient = (*Client)(nil)
func ptr[T any](t T) *T { func ptr[T any](t T) *T {
return &t return &t
} }
// CloseFunc is a function that closes the client.
type CloseFunc func(ctx context.Context) error

View File

@ -59,7 +59,7 @@ func TestUpload(t *testing.T) {
}, },
"lazy invalidation": { "lazy invalidation": {
in: newInput(), in: newInput(),
cacheInvalidationStrategy: CacheInvalidateBatchOnClose, cacheInvalidationStrategy: CacheInvalidateBatchOnFlush,
cacheInvalidationWaitTimeout: time.Microsecond, cacheInvalidationWaitTimeout: time.Microsecond,
wantDirtyKeys: []string{"test-key"}, wantDirtyKeys: []string{"test-key"},
}, },
@ -181,7 +181,7 @@ func TestDeleteObject(t *testing.T) {
wantInvalidationIDs: []string{"test-invalidation-id-1"}, wantInvalidationIDs: []string{"test-invalidation-id-1"},
}, },
"lazy invalidation": { "lazy invalidation": {
cacheInvalidationStrategy: CacheInvalidateBatchOnClose, cacheInvalidationStrategy: CacheInvalidateBatchOnFlush,
cacheInvalidationWaitTimeout: time.Microsecond, cacheInvalidationWaitTimeout: time.Microsecond,
wantDirtyKeys: []string{"test-key"}, wantDirtyKeys: []string{"test-key"},
}, },
@ -273,7 +273,7 @@ func TestDeleteObject(t *testing.T) {
} }
} }
func TestClose(t *testing.T) { func TestFlush(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
dirtyKeys []string dirtyKeys []string
invalidationIDs []string invalidationIDs []string
@ -389,7 +389,7 @@ func TestClose(t *testing.T) {
dirtyKeys: tc.dirtyKeys, dirtyKeys: tc.dirtyKeys,
invalidationIDs: tc.invalidationIDs, invalidationIDs: tc.invalidationIDs,
} }
err := client.Close(context.Background()) err := client.Flush(context.Background())
if tc.wantCacheInvalidationErr { if tc.wantCacheInvalidationErr {
assert.ErrorAs(err, &InvalidationError{}) assert.ErrorAs(err, &InvalidationError{})
@ -450,9 +450,9 @@ func TestConcurrency(_ *testing.T) {
}, },
}) })
} }
closeClient := func() { flushClient := func() {
defer wg.Done() defer wg.Done()
_ = client.Close(context.Background()) _ = client.Flush(context.Background())
} }
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
@ -460,7 +460,7 @@ func TestConcurrency(_ *testing.T) {
go upload() go upload()
go deleteObject() go deleteObject()
go deleteObjects() go deleteObjects()
go closeClient() go flushClient()
} }
wg.Wait() wg.Wait()
@ -559,3 +559,11 @@ func (s *stubObjectStorageClient) GetObject(
) (*s3.GetObjectOutput, error) { ) (*s3.GetObjectOutput, error) {
return nil, nil return nil, nil
} }
// currently not needed so no-Op.
func (s *stubObjectStorageClient) ListObjectsV2(
_ context.Context, _ *s3.ListObjectsV2Input,
_ ...func(*s3.Options),
) (*s3.ListObjectsV2Output, error) {
return nil, nil
}