From e1d3afe8d4e1ea64a209af7bf667c7fa4b1f8bb5 Mon Sep 17 00:00:00 2001 From: Malte Poll <1780588+malt3@users.noreply.github.com> Date: Fri, 2 Jun 2023 11:20:01 +0200 Subject: [PATCH] ci: use aws s3 client that invalidates cloudfront cache for places that modify Constellation api (#1839) --- hack/cli-k8s-compatibility/main.go | 4 +- hack/configapi/cmd/root.go | 7 +- image/upload/internal/cmd/aws.go | 7 +- image/upload/internal/cmd/azure.go | 7 +- image/upload/internal/cmd/flags.go | 24 ++- image/upload/internal/cmd/gcp.go | 7 +- image/upload/internal/cmd/image.go | 1 + image/upload/internal/cmd/info.go | 8 +- .../upload/internal/cmd/measurementsupload.go | 8 +- image/upload/internal/cmd/nop.go | 7 +- .../api/attestationconfig/client/BUILD.bazel | 1 + .../api/attestationconfig/client/client.go | 67 ++++++--- .../attestationconfig/client/client_test.go | 3 +- internal/api/client/BUILD.bazel | 5 +- internal/api/client/client.go | 140 +++++++++--------- internal/api/versions/cli/add.go | 4 +- internal/api/versions/cli/latest.go | 7 +- internal/api/versions/cli/list.go | 7 +- internal/api/versions/cli/rm.go | 4 +- internal/api/versions/client/client.go | 41 ++++- internal/osimage/archive/BUILD.bazel | 2 +- internal/osimage/archive/archive.go | 46 ++++-- internal/osimage/imageinfo/BUILD.bazel | 2 +- internal/osimage/imageinfo/imageinfo.go | 45 ++++-- .../osimage/measurementsuploader/BUILD.bazel | 2 +- .../measurementsuploader.go | 45 ++++-- internal/staticupload/get.go | 15 +- internal/staticupload/staticupload.go | 46 ++++-- internal/staticupload/staticupload_test.go | 22 ++- 29 files changed, 398 insertions(+), 186 deletions(-) diff --git a/hack/cli-k8s-compatibility/main.go b/hack/cli-k8s-compatibility/main.go index c3e07efd3..31c22b8b9 100644 --- a/hack/cli-k8s-compatibility/main.go +++ b/hack/cli-k8s-compatibility/main.go @@ -50,12 +50,12 @@ func main() { cliInfo.Kubernetes = append(cliInfo.Kubernetes, v.ClusterVersion) } - c, err := client.NewClient(ctx, "eu-central-1", "cdn-constellation-backend", "E1H77EZTHC3NE4", false, log) + c, cclose, err := client.NewClient(ctx, "eu-central-1", "cdn-constellation-backend", "E1H77EZTHC3NE4", false, log) if err != nil { log.Fatalf("creating s3 client: %w", err) } defer func() { - if err := c.InvalidateCache(ctx); err != nil { + if err := cclose(ctx); err != nil { log.Fatalf("invalidating cache: %w", err) } }() diff --git a/hack/configapi/cmd/root.go b/hack/configapi/cmd/root.go index c86b97d3c..1b5ece94e 100644 --- a/hack/configapi/cmd/root.go +++ b/hack/configapi/cmd/root.go @@ -76,10 +76,15 @@ func runCmd(cmd *cobra.Command, _ []string) error { return fmt.Errorf("unmarshalling version file: %w", err) } - sut, err := attestationconfigclient.New(ctx, cfg, []byte(cosignPwd), privateKey) + sut, sutClose, err := attestationconfigclient.New(ctx, cfg, []byte(cosignPwd), privateKey) if err != nil { return fmt.Errorf("creating repo: %w", err) } + defer func() { + if err := sutClose(ctx); err != nil { + fmt.Printf("closing repo: %v\n", err) + } + }() if err := sut.UploadAzureSEVSNP(ctx, versions, time.Now()); err != nil { return fmt.Errorf("uploading version: %w", err) diff --git a/image/upload/internal/cmd/aws.go b/image/upload/internal/cmd/aws.go index d64b235b1..7adfc9ece 100644 --- a/image/upload/internal/cmd/aws.go +++ b/image/upload/internal/cmd/aws.go @@ -46,10 +46,15 @@ func runAWS(cmd *cobra.Command, _ []string) error { log := logger.New(logger.PlainLog, flags.logLevel) log.Debugf("Parsed flags: %+v", flags) - archiveC, err := archive.New(cmd.Context(), flags.region, flags.bucket, log) + archiveC, archiveCClose, err := archive.New(cmd.Context(), flags.region, flags.bucket, flags.distributionID, log) if err != nil { return err } + defer func() { + if err := archiveCClose(cmd.Context()); err != nil { + log.Errorf("closing archive client: %v", err) + } + }() uploadC, err := awsupload.New(flags.awsRegion, flags.awsBucket, log) if err != nil { diff --git a/image/upload/internal/cmd/azure.go b/image/upload/internal/cmd/azure.go index 749976f83..a34f26f33 100644 --- a/image/upload/internal/cmd/azure.go +++ b/image/upload/internal/cmd/azure.go @@ -47,10 +47,15 @@ func runAzure(cmd *cobra.Command, _ []string) error { log := logger.New(logger.PlainLog, flags.logLevel) log.Debugf("Parsed flags: %+v", flags) - archiveC, err := archive.New(cmd.Context(), flags.region, flags.bucket, log) + archiveC, archiveCClose, err := archive.New(cmd.Context(), flags.region, flags.bucket, flags.distributionID, log) if err != nil { return err } + defer func() { + if err := archiveCClose(cmd.Context()); err != nil { + log.Errorf("closing archive client: %v", err) + } + }() uploadC, err := azureupload.New(flags.azSubscription, flags.azLocation, flags.azResourceGroup, log) if err != nil { diff --git a/image/upload/internal/cmd/flags.go b/image/upload/internal/cmd/flags.go index aaf76f288..eee221dd4 100644 --- a/image/upload/internal/cmd/flags.go +++ b/image/upload/internal/cmd/flags.go @@ -27,6 +27,7 @@ type commonFlags struct { timestamp time.Time region string bucket string + distributionID string out string logLevel zapcore.Level } @@ -75,6 +76,10 @@ func parseCommonFlags(cmd *cobra.Command) (commonFlags, error) { if err != nil { return commonFlags{}, err } + distributionID, err := cmd.Flags().GetString("distribution-id") + if err != nil { + return commonFlags{}, err + } out, err := cmd.Flags().GetString("out") if err != nil { return commonFlags{}, err @@ -96,6 +101,7 @@ func parseCommonFlags(cmd *cobra.Command) (commonFlags, error) { timestamp: timestmp, region: region, bucket: bucket, + distributionID: distributionID, out: out, logLevel: logLevel, }, nil @@ -201,9 +207,10 @@ func parseGCPFlags(cmd *cobra.Command) (gcpFlags, error) { } type s3Flags struct { - region string - bucket string - logLevel zapcore.Level + region string + bucket string + distributionID string + logLevel zapcore.Level } func parseS3Flags(cmd *cobra.Command) (s3Flags, error) { @@ -215,6 +222,10 @@ func parseS3Flags(cmd *cobra.Command) (s3Flags, error) { if err != nil { return s3Flags{}, err } + distributionID, err := cmd.Flags().GetString("distribution-id") + if err != nil { + return s3Flags{}, err + } verbose, err := cmd.Flags().GetBool("verbose") if err != nil { return s3Flags{}, err @@ -225,9 +236,10 @@ func parseS3Flags(cmd *cobra.Command) (s3Flags, error) { } return s3Flags{ - region: region, - bucket: bucket, - logLevel: logLevel, + region: region, + bucket: bucket, + distributionID: distributionID, + logLevel: logLevel, }, nil } diff --git a/image/upload/internal/cmd/gcp.go b/image/upload/internal/cmd/gcp.go index efdc0bd57..7973636f4 100644 --- a/image/upload/internal/cmd/gcp.go +++ b/image/upload/internal/cmd/gcp.go @@ -47,10 +47,15 @@ func runGCP(cmd *cobra.Command, _ []string) error { log := logger.New(logger.PlainLog, flags.logLevel) log.Debugf("Parsed flags: %+v", flags) - archiveC, err := archive.New(cmd.Context(), flags.region, flags.bucket, log) + archiveC, archiveCClose, err := archive.New(cmd.Context(), flags.region, flags.bucket, flags.distributionID, log) if err != nil { return err } + defer func() { + if err := archiveCClose(cmd.Context()); err != nil { + log.Errorf("closing archive client: %v", err) + } + }() uploadC, err := gcpupload.New(cmd.Context(), flags.gcpProject, flags.gcpLocation, flags.gcpBucket, log) if err != nil { diff --git a/image/upload/internal/cmd/image.go b/image/upload/internal/cmd/image.go index 5c92cb5a8..323bef7e2 100644 --- a/image/upload/internal/cmd/image.go +++ b/image/upload/internal/cmd/image.go @@ -31,6 +31,7 @@ func NewImageCmd() *cobra.Command { cmd.PersistentFlags().String("timestamp", "", "Optional timestamp to use for resource names. Uses format 2006-01-02T15:04:05Z07:00.") cmd.PersistentFlags().String("region", "eu-central-1", "AWS region of the archive S3 bucket") cmd.PersistentFlags().String("bucket", "cdn-constellation-backend", "S3 bucket name of the archive") + cmd.PersistentFlags().String("distribution-id", "E1H77EZTHC3NE4", "CloudFront distribution ID of the API") cmd.PersistentFlags().String("out", "", "Optional path to write the upload result to. If not set, the result is written to stdout.") cmd.PersistentFlags().Bool("verbose", false, "Enable verbose output") must(cmd.MarkPersistentFlagRequired("raw-image")) diff --git a/image/upload/internal/cmd/info.go b/image/upload/internal/cmd/info.go index b74a1d376..01a2bc138 100644 --- a/image/upload/internal/cmd/info.go +++ b/image/upload/internal/cmd/info.go @@ -31,6 +31,7 @@ func NewInfoCmd() *cobra.Command { cmd.Flags().String("region", "eu-central-1", "AWS region of the archive S3 bucket") cmd.Flags().String("bucket", "cdn-constellation-backend", "S3 bucket name of the archive") + cmd.Flags().String("distribution-id", "E1H77EZTHC3NE4", "CloudFront distribution ID of the API") cmd.Flags().Bool("verbose", false, "Enable verbose output") return cmd @@ -54,10 +55,15 @@ func runInfo(cmd *cobra.Command, args []string) error { return err } - uploadC, err := infoupload.New(cmd.Context(), flags.region, flags.bucket, log) + uploadC, uploadCClose, err := infoupload.New(cmd.Context(), flags.region, flags.bucket, flags.distributionID, log) if err != nil { return fmt.Errorf("uploading image info: %w", err) } + defer func() { + if err := uploadCClose(cmd.Context()); err != nil { + log.Errorf("closing upload client: %v", err) + } + }() url, err := uploadC.Upload(cmd.Context(), info) if err != nil { diff --git a/image/upload/internal/cmd/measurementsupload.go b/image/upload/internal/cmd/measurementsupload.go index 131074db3..4398a0dfc 100644 --- a/image/upload/internal/cmd/measurementsupload.go +++ b/image/upload/internal/cmd/measurementsupload.go @@ -31,6 +31,7 @@ func newMeasurementsUploadCmd() *cobra.Command { cmd.Flags().String("signature", "", "Path to signature file to upload") cmd.Flags().String("region", "eu-central-1", "AWS region of the archive S3 bucket") cmd.Flags().String("bucket", "cdn-constellation-backend", "S3 bucket name of the archive") + cmd.Flags().String("distribution-id", "E1H77EZTHC3NE4", "CloudFront distribution ID of the API") cmd.Flags().Bool("verbose", false, "Enable verbose output") must(cmd.MarkFlagRequired("measurements")) @@ -53,10 +54,15 @@ func runMeasurementsUpload(cmd *cobra.Command, _ []string) error { log := logger.New(logger.PlainLog, flags.logLevel) log.Debugf("Parsed flags: %+v", flags) - uploadC, err := measurementsuploader.New(cmd.Context(), flags.region, flags.bucket, log) + uploadC, uploadCClose, err := measurementsuploader.New(cmd.Context(), flags.region, flags.bucket, flags.distributionID, log) if err != nil { return fmt.Errorf("uploading image info: %w", err) } + defer func() { + if err := uploadCClose(cmd.Context()); err != nil { + log.Errorf("closing upload client: %v", err) + } + }() measurements, err := os.Open(flags.measurementsPath) if err != nil { diff --git a/image/upload/internal/cmd/nop.go b/image/upload/internal/cmd/nop.go index df541fd31..3d75c4f4f 100644 --- a/image/upload/internal/cmd/nop.go +++ b/image/upload/internal/cmd/nop.go @@ -33,10 +33,15 @@ func runNOP(cmd *cobra.Command, provider cloudprovider.Provider, _ []string) err log := logger.New(logger.PlainLog, flags.logLevel) log.Debugf("Parsed flags: %+v", flags) - archiveC, err := archive.New(cmd.Context(), flags.region, flags.bucket, log) + archiveC, archiveCClose, err := archive.New(cmd.Context(), flags.region, flags.bucket, flags.distributionID, log) if err != nil { return err } + defer func() { + if err := archiveCClose(cmd.Context()); err != nil { + log.Errorf("closing archive client: %v", err) + } + }() uploadC := nopupload.New(log) diff --git a/internal/api/attestationconfig/client/BUILD.bazel b/internal/api/attestationconfig/client/BUILD.bazel index ff3cf9bb8..b72d81a74 100644 --- a/internal/api/attestationconfig/client/BUILD.bazel +++ b/internal/api/attestationconfig/client/BUILD.bazel @@ -13,6 +13,7 @@ go_library( "//internal/sigstore", "//internal/staticupload", "//internal/variant", + "@com_github_aws_aws_sdk_go_v2_feature_s3_manager//:manager", "@com_github_aws_aws_sdk_go_v2_service_s3//:s3", ], ) diff --git a/internal/api/attestationconfig/client/client.go b/internal/api/attestationconfig/client/client.go index 9bb8f6db6..6f2307b16 100644 --- a/internal/api/attestationconfig/client/client.go +++ b/internal/api/attestationconfig/client/client.go @@ -16,6 +16,7 @@ import ( "sort" "time" + s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/edgelesssys/constellation/v2/internal/api/attestationconfig" "github.com/edgelesssys/constellation/v2/internal/constants" @@ -27,18 +28,38 @@ import ( // Client manages (modifies) the version information for the attestation variants. type Client struct { - *staticupload.Client - cosignPwd []byte // used to decrypt the cosign private key - privKey []byte // used to sign + s3Client + s3ClientClose func(ctx context.Context) error + bucketID string + cosignPwd []byte // used to decrypt the cosign private key + privKey []byte // used to sign } // New returns a new Client. -func New(ctx context.Context, cfg staticupload.Config, cosignPwd, privateKey []byte) (*Client, error) { - client, err := staticupload.New(ctx, cfg) +func New(ctx context.Context, cfg staticupload.Config, cosignPwd, privateKey []byte) (*Client, CloseFunc, error) { + client, clientClose, err := staticupload.New(ctx, cfg) if err != nil { - return nil, fmt.Errorf("failed to create s3 storage: %w", err) + return nil, nil, fmt.Errorf("failed to create s3 storage: %w", err) } - return &Client{client, cosignPwd, privateKey}, nil + repo := &Client{ + s3Client: client, + s3ClientClose: clientClose, + bucketID: cfg.Bucket, + cosignPwd: cosignPwd, + privKey: privateKey, + } + repoClose := func(ctx context.Context) error { + return repo.Close(ctx) + } + return repo, repoClose, nil +} + +// Close closes the Client. +func (a Client) Close(ctx context.Context) error { + if a.s3ClientClose == nil { + return nil + } + return a.s3ClientClose(ctx) } // UploadAzureSEVSNP uploads the latest version numbers of the Azure SEVSNP. @@ -51,7 +72,7 @@ func (a Client) UploadAzureSEVSNP(ctx context.Context, versions attestationconfi fname := date.Format("2006-01-02-15-04") + ".json" filePath := fmt.Sprintf("%s/%s/%s", constants.CDNAttestationConfigPrefixV1, variant.String(), fname) - err = put(ctx, a.Client, filePath, versionBytes) + err = put(ctx, a.s3Client, a.bucketID, filePath, versionBytes) if err != nil { return err } @@ -69,7 +90,7 @@ func (a Client) createAndUploadSignature(ctx context.Context, content []byte, fi if err != nil { return fmt.Errorf("sign version file: %w", err) } - err = put(ctx, a.Client, filePath+".sig", signature) + err = put(ctx, a.s3Client, a.bucketID, filePath+".sig", signature) if err != nil { return fmt.Errorf("upload signature: %w", err) } @@ -79,7 +100,7 @@ func (a Client) createAndUploadSignature(ctx context.Context, content []byte, fi // List returns the list of versions for the given attestation type. func (a Client) List(ctx context.Context, attestation variant.Variant) ([]string, error) { key := path.Join(constants.CDNAttestationConfigPrefixV1, attestation.String(), "list") - bt, err := get(ctx, a.Client, key) + bt, err := get(ctx, a.s3Client, a.bucketID, key) if err != nil { return nil, err } @@ -97,13 +118,13 @@ func (a Client) DeleteList(ctx context.Context, attestation variant.Variant) err if err != nil { return err } - return put(ctx, a.Client, path.Join(constants.CDNAttestationConfigPrefixV1, attestation.String(), "list"), bt) + return put(ctx, a.s3Client, a.bucketID, path.Join(constants.CDNAttestationConfigPrefixV1, attestation.String(), "list"), bt) } func (a Client) addVersionToList(ctx context.Context, attestation variant.Variant, fname string) error { versions := []string{} key := path.Join(constants.CDNAttestationConfigPrefixV1, attestation.String(), "list") - bt, err := get(ctx, a.Client, key) + bt, err := get(ctx, a.s3Client, a.bucketID, key) if err == nil { if err := json.Unmarshal(bt, &versions); err != nil { return err @@ -118,13 +139,13 @@ func (a Client) addVersionToList(ctx context.Context, attestation variant.Varian if err != nil { return err } - return put(ctx, a.Client, key, json) + return put(ctx, a.s3Client, a.bucketID, key, json) } // get is a convenience method. -func get(ctx context.Context, client *staticupload.Client, path string) ([]byte, error) { +func get(ctx context.Context, client s3Client, bucket, path string) ([]byte, error) { getObjectInput := &s3.GetObjectInput{ - Bucket: &client.BucketID, + Bucket: &bucket, Key: &path, } output, err := client.GetObject(ctx, getObjectInput) @@ -135,12 +156,24 @@ func get(ctx context.Context, client *staticupload.Client, path string) ([]byte, } // put is a convenience method. -func put(ctx context.Context, client *staticupload.Client, path string, data []byte) error { +func put(ctx context.Context, client s3Client, bucket, path string, data []byte) error { putObjectInput := &s3.PutObjectInput{ - Bucket: &client.BucketID, + Bucket: &bucket, Key: &path, Body: bytes.NewReader(data), } _, err := client.Upload(ctx, putObjectInput) return err } + +type s3Client interface { + GetObject( + ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options), + ) (*s3.GetObjectOutput, error) + Upload( + ctx context.Context, input *s3.PutObjectInput, opts ...func(*s3manager.Uploader), + ) (*s3manager.UploadOutput, error) +} + +// CloseFunc is a function that closes the client. +type CloseFunc func(ctx context.Context) error diff --git a/internal/api/attestationconfig/client/client_test.go b/internal/api/attestationconfig/client/client_test.go index 41b483c66..f91a31c7b 100644 --- a/internal/api/attestationconfig/client/client_test.go +++ b/internal/api/attestationconfig/client/client_test.go @@ -74,8 +74,9 @@ var versionValues = attestationconfig.AzureSEVSNPVersion{ func TestUploadAzureSEVSNPVersions(t *testing.T) { ctx := context.Background() - client, err := client.New(ctx, cfg, []byte(*cosignPwd), privateKey) + client, clientClose, err := client.New(ctx, cfg, []byte(*cosignPwd), privateKey) require.NoError(t, err) + defer func() { _ = clientClose(ctx) }() d := time.Date(2021, 1, 1, 1, 1, 1, 1, time.UTC) require.NoError(t, client.UploadAzureSEVSNP(ctx, versionValues, d)) } diff --git a/internal/api/client/BUILD.bazel b/internal/api/client/BUILD.bazel index f4c6e13fc..dc8c7f289 100644 --- a/internal/api/client/BUILD.bazel +++ b/internal/api/client/BUILD.bazel @@ -7,11 +7,8 @@ go_library( visibility = ["//:__subpackages__"], deps = [ "//internal/logger", - "@com_github_aws_aws_sdk_go_v2//aws", - "@com_github_aws_aws_sdk_go_v2_config//:config", + "//internal/staticupload", "@com_github_aws_aws_sdk_go_v2_feature_s3_manager//:manager", - "@com_github_aws_aws_sdk_go_v2_service_cloudfront//:cloudfront", - "@com_github_aws_aws_sdk_go_v2_service_cloudfront//types", "@com_github_aws_aws_sdk_go_v2_service_s3//:s3", "@com_github_aws_aws_sdk_go_v2_service_s3//types", "@org_uber_go_zap//:zap", diff --git a/internal/api/client/client.go b/internal/api/client/client.go index d76c28ac5..be15819e7 100644 --- a/internal/api/client/client.go +++ b/internal/api/client/client.go @@ -35,25 +35,20 @@ import ( "fmt" "time" - "github.com/aws/aws-sdk-go-v2/aws" - awsconfig "github.com/aws/aws-sdk-go-v2/config" s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager" - "github.com/aws/aws-sdk-go-v2/service/cloudfront" - cftypes "github.com/aws/aws-sdk-go-v2/service/cloudfront/types" "github.com/aws/aws-sdk-go-v2/service/s3" s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/edgelesssys/constellation/v2/internal/logger" + "github.com/edgelesssys/constellation/v2/internal/staticupload" "go.uber.org/zap" ) // Client is the client for the versions API. type Client struct { - config aws.Config - cloudfrontClient *cloudfront.Client - s3Client *s3.Client - uploadClient *s3manager.Uploader + uploadClient uploadClient + s3Client s3Client + s3ClientClose func(ctx context.Context) error bucket string - distributionID string cacheInvalidationWaitTimeout time.Duration dirtyPaths []string // written paths to be invalidated @@ -66,90 +61,68 @@ type Client struct { // This client can be used to fetch objects but cannot write updates. func NewReadOnlyClient(ctx context.Context, region, bucket, distributionID string, log *logger.Logger, -) (*Client, error) { - cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(region)) +) (*Client, CloseFunc, error) { + staticUploadClient, staticUploadClientClose, err := staticupload.New(ctx, staticupload.Config{ + Region: region, + Bucket: bucket, + DistributionID: distributionID, + CacheInvalidationStrategy: staticupload.CacheInvalidateBatchOnFlush, + }) if err != nil { - return nil, err + return nil, nil, err } - s3c := s3.NewFromConfig(cfg) - return &Client{ - config: cfg, - s3Client: s3c, - bucket: bucket, - distributionID: distributionID, - DryRun: true, - Log: log, - }, nil + client := &Client{ + s3Client: staticUploadClient, + s3ClientClose: staticUploadClientClose, + bucket: bucket, + DryRun: true, + Log: log, + } + clientClose := func(ctx context.Context) error { + return client.Close(ctx) + } + + return client, clientClose, nil } // NewClient creates a new client for the versions API. func NewClient(ctx context.Context, region, bucket, distributionID string, dryRun bool, log *logger.Logger, -) (*Client, error) { - cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(region)) +) (*Client, CloseFunc, error) { + staticUploadClient, staticUploadClientClose, err := staticupload.New(ctx, staticupload.Config{ + Region: region, + Bucket: bucket, + DistributionID: distributionID, + CacheInvalidationStrategy: staticupload.CacheInvalidateBatchOnFlush, + }) if err != nil { - return nil, err + return nil, nil, err } - cloudfrontC := cloudfront.NewFromConfig(cfg) - s3C := s3.NewFromConfig(cfg) - uploadC := s3manager.NewUploader(s3C) - - return &Client{ - config: cfg, - cloudfrontClient: cloudfrontC, - s3Client: s3C, - uploadClient: uploadC, + client := &Client{ + uploadClient: staticUploadClient, + s3Client: staticUploadClient, + s3ClientClose: staticUploadClientClose, bucket: bucket, - distributionID: distributionID, DryRun: dryRun, Log: log, cacheInvalidationWaitTimeout: 10 * time.Minute, - }, nil + } + clientClose := func(ctx context.Context) error { + return client.Close(ctx) + } + + return client, clientClose, nil } -// InvalidateCache invalidates the CDN cache for the paths that have been written. -// The function should be deferred after the client has been created. -func (c *Client) InvalidateCache(ctx context.Context) error { - if len(c.dirtyPaths) == 0 { - c.Log.Debugf("No dirty paths, skipping cache invalidation") +// Close closes the client. +// It invalidates the CDN cache for all uploaded files. +func (c *Client) Close(ctx context.Context) error { + if c.s3ClientClose == nil { return nil } - - if c.DryRun { - c.Log.With(zap.String("distributionID", c.distributionID), zap.Strings("dirtyPaths", c.dirtyPaths)).Debugf("DryRun: cloudfront create invalidation") - return nil - } - - c.Log.Debugf("Paths to invalidate: %v", c.dirtyPaths) - - in := &cloudfront.CreateInvalidationInput{ - DistributionId: &c.distributionID, - InvalidationBatch: &cftypes.InvalidationBatch{ - CallerReference: ptr(fmt.Sprintf("%d", time.Now().Unix())), - Paths: &cftypes.Paths{ - Items: c.dirtyPaths, - Quantity: ptr(int32(len(c.dirtyPaths))), - }, - }, - } - invalidation, err := c.cloudfrontClient.CreateInvalidation(ctx, in) - if err != nil { - return fmt.Errorf("creating invalidation: %w", err) - } - - c.Log.Debugf("Waiting for invalidation %s to complete", *invalidation.Invalidation.Id) - waiter := cloudfront.NewInvalidationCompletedWaiter(c.cloudfrontClient) - waitIn := &cloudfront.GetInvalidationInput{ - DistributionId: &c.distributionID, - Id: invalidation.Invalidation.Id, - } - if err := waiter.Wait(ctx, waitIn, c.cacheInvalidationWaitTimeout); err != nil { - return fmt.Errorf("waiting for invalidation to complete: %w", err) - } - - return nil + return c.s3ClientClose(ctx) } // DeletePath deletes all objects at a given path from a s3 bucket. @@ -289,3 +262,22 @@ func (e *NotFoundError) Error() string { func (e *NotFoundError) Unwrap() error { return e.err } + +type s3Client interface { + GetObject( + ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options), + ) (*s3.GetObjectOutput, error) + ListObjectsV2( + ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options), + ) (*s3.ListObjectsV2Output, error) + DeleteObjects( + ctx context.Context, params *s3.DeleteObjectsInput, optFns ...func(*s3.Options), + ) (*s3.DeleteObjectsOutput, error) +} + +type uploadClient interface { + Upload(ctx context.Context, input *s3.PutObjectInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error) +} + +// CloseFunc is a function that closes the client. +type CloseFunc func(ctx context.Context) error diff --git a/internal/api/versions/cli/add.go b/internal/api/versions/cli/add.go index 4d2734160..fe71f6385 100644 --- a/internal/api/versions/cli/add.go +++ b/internal/api/versions/cli/add.go @@ -73,13 +73,13 @@ func runAdd(cmd *cobra.Command, _ []string) (retErr error) { } log.Debugf("Creating versions API client") - client, err := verclient.NewClient(cmd.Context(), flags.region, flags.bucket, flags.distributionID, flags.dryRun, log) + client, clientClose, err := verclient.NewClient(cmd.Context(), flags.region, flags.bucket, flags.distributionID, flags.dryRun, log) if err != nil { return fmt.Errorf("creating client: %w", err) } defer func(retErr *error) { log.Infof("Invalidating cache. This may take some time") - if err := client.InvalidateCache(cmd.Context()); err != nil && retErr == nil { + if err := clientClose(cmd.Context()); err != nil && retErr == nil { *retErr = fmt.Errorf("invalidating cache: %w", err) } }(&retErr) diff --git a/internal/api/versions/cli/latest.go b/internal/api/versions/cli/latest.go index 9cfab9ecc..c534392ed 100644 --- a/internal/api/versions/cli/latest.go +++ b/internal/api/versions/cli/latest.go @@ -47,10 +47,15 @@ func runLatest(cmd *cobra.Command, _ []string) error { } log.Debugf("Creating versions API client") - client, err := verclient.NewReadOnlyClient(cmd.Context(), flags.region, flags.bucket, flags.distributionID, log) + client, clientClose, err := verclient.NewReadOnlyClient(cmd.Context(), flags.region, flags.bucket, flags.distributionID, log) if err != nil { return fmt.Errorf("creating client: %w", err) } + defer func() { + if err := clientClose(cmd.Context()); err != nil { + log.Errorf("Closing versions API client: %v", err) + } + }() log.Debugf("Requesting latest version") latest := versionsapi.Latest{ diff --git a/internal/api/versions/cli/list.go b/internal/api/versions/cli/list.go index 474fe5be8..8dafd5482 100644 --- a/internal/api/versions/cli/list.go +++ b/internal/api/versions/cli/list.go @@ -53,10 +53,15 @@ func runList(cmd *cobra.Command, _ []string) error { } log.Debugf("Creating versions API client") - client, err := verclient.NewReadOnlyClient(cmd.Context(), flags.region, flags.bucket, flags.distributionID, log) + client, clientClose, err := verclient.NewReadOnlyClient(cmd.Context(), flags.region, flags.bucket, flags.distributionID, log) if err != nil { return fmt.Errorf("creating client: %w", err) } + defer func() { + if err := clientClose(cmd.Context()); err != nil { + log.Errorf("Closing versions API client: %v", err) + } + }() var minorVersions []string if flags.minorVersion != "" { diff --git a/internal/api/versions/cli/rm.go b/internal/api/versions/cli/rm.go index 93ace7476..0d719fb67 100644 --- a/internal/api/versions/cli/rm.go +++ b/internal/api/versions/cli/rm.go @@ -102,13 +102,13 @@ func runRemove(cmd *cobra.Command, _ []string) (retErr error) { } log.Debugf("Creating versions API client") - verclient, err := verclient.NewClient(cmd.Context(), flags.region, flags.bucket, flags.distributionID, flags.dryrun, log) + verclient, verclientClose, err := verclient.NewClient(cmd.Context(), flags.region, flags.bucket, flags.distributionID, flags.dryrun, log) if err != nil { return fmt.Errorf("creating client: %w", err) } defer func(retErr *error) { log.Infof("Invalidating cache. This may take some time") - if err := verclient.InvalidateCache(cmd.Context()); err != nil && retErr == nil { + if err := verclientClose(cmd.Context()); err != nil && retErr == nil { *retErr = fmt.Errorf("invalidating cache: %w", err) } }(&retErr) diff --git a/internal/api/versions/client/client.go b/internal/api/versions/client/client.go index 47c3570a1..8df6944f6 100644 --- a/internal/api/versions/client/client.go +++ b/internal/api/versions/client/client.go @@ -27,23 +27,49 @@ import ( // VersionsClient is a client for the versions API. type VersionsClient struct { *apiclient.Client + clientClose func(ctx context.Context) error } // NewClient creates a new client for the versions API. func NewClient(ctx context.Context, region, bucket, distributionID string, dryRun bool, log *logger.Logger, -) (*VersionsClient, error) { - genericClient, err := apiclient.NewClient(ctx, region, bucket, distributionID, dryRun, log) - return &VersionsClient{genericClient}, err +) (*VersionsClient, CloseFunc, error) { + genericClient, genericClientClose, err := apiclient.NewClient(ctx, region, bucket, distributionID, dryRun, log) + versionsClient := &VersionsClient{ + genericClient, + genericClientClose, + } + versionsClientClose := func(ctx context.Context) error { + return versionsClient.Close(ctx) + } + return versionsClient, versionsClientClose, err } // NewReadOnlyClient creates a new read-only client. // This client can be used to fetch objects but cannot write updates. func NewReadOnlyClient(ctx context.Context, region, bucket, distributionID string, log *logger.Logger, -) (*VersionsClient, error) { - genericClient, err := apiclient.NewReadOnlyClient(ctx, region, bucket, distributionID, log) - return &VersionsClient{genericClient}, err +) (*VersionsClient, CloseFunc, error) { + genericClient, genericClientClose, err := apiclient.NewReadOnlyClient(ctx, region, bucket, distributionID, log) + if err != nil { + return nil, nil, err + } + versionsClient := &VersionsClient{ + genericClient, + genericClientClose, + } + versionsClientClose := func(ctx context.Context) error { + return versionsClient.Close(ctx) + } + return versionsClient, versionsClientClose, err +} + +// Close closes the client. +func (c *VersionsClient) Close(ctx context.Context) error { + if c.clientClose == nil { + return nil + } + return c.clientClose(ctx) } // FetchVersionList fetches the given version list from the versions API. @@ -228,3 +254,6 @@ func (c *VersionsClient) deleteVersionFromLatest(ctx context.Context, ver versio return nil } + +// CloseFunc is a function that closes the client. +type CloseFunc func(ctx context.Context) error diff --git a/internal/osimage/archive/BUILD.bazel b/internal/osimage/archive/BUILD.bazel index 24485e291..41db03c25 100644 --- a/internal/osimage/archive/BUILD.bazel +++ b/internal/osimage/archive/BUILD.bazel @@ -9,7 +9,7 @@ go_library( "//internal/api/versions", "//internal/constants", "//internal/logger", - "@com_github_aws_aws_sdk_go_v2_config//:config", + "//internal/staticupload", "@com_github_aws_aws_sdk_go_v2_feature_s3_manager//:manager", "@com_github_aws_aws_sdk_go_v2_service_s3//:s3", "@com_github_aws_aws_sdk_go_v2_service_s3//types", diff --git a/internal/osimage/archive/archive.go b/internal/osimage/archive/archive.go index 12da424ad..fb365fa18 100644 --- a/internal/osimage/archive/archive.go +++ b/internal/osimage/archive/archive.go @@ -12,18 +12,19 @@ import ( "io" "net/url" - awsconfig "github.com/aws/aws-sdk-go-v2/config" s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager" "github.com/aws/aws-sdk-go-v2/service/s3" s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" versionsapi "github.com/edgelesssys/constellation/v2/internal/api/versions" "github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/logger" + "github.com/edgelesssys/constellation/v2/internal/staticupload" ) // Archivist uploads OS images to S3. type Archivist struct { - uploadClient uploadClient + uploadClient uploadClient + uploadClientClose func(ctx context.Context) error // bucket is the name of the S3 bucket to use. bucket string @@ -31,19 +32,37 @@ type Archivist struct { } // New creates a new Archivist. -func New(ctx context.Context, region, bucket string, log *logger.Logger) (*Archivist, error) { - cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(region)) +func New(ctx context.Context, region, bucket, distributionID string, log *logger.Logger) (*Archivist, CloseFunc, error) { + staticUploadClient, staticUploadClientClose, err := staticupload.New(ctx, staticupload.Config{ + Region: region, + Bucket: bucket, + DistributionID: distributionID, + CacheInvalidationStrategy: staticupload.CacheInvalidateBatchOnFlush, + }) if err != nil { - return nil, err + return nil, nil, err } - s3client := s3.NewFromConfig(cfg) - uploadClient := s3manager.NewUploader(s3client) - return &Archivist{ - uploadClient: uploadClient, - bucket: bucket, - log: log, - }, nil + archivist := &Archivist{ + uploadClient: staticUploadClient, + uploadClientClose: staticUploadClientClose, + bucket: bucket, + log: log, + } + archivistClose := func(ctx context.Context) error { + return archivist.Close(ctx) + } + + return archivist, archivistClose, nil +} + +// Close closes the uploader. +// It invalidates the CDN cache for all uploaded files. +func (a *Archivist) Close(ctx context.Context) error { + if a.uploadClientClose == nil { + return nil + } + return a.uploadClientClose(ctx) } // Archive reads the OS image in img and uploads it as key. @@ -65,3 +84,6 @@ func (a *Archivist) Archive(ctx context.Context, version versionsapi.Version, cs type uploadClient interface { Upload(ctx context.Context, input *s3.PutObjectInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error) } + +// CloseFunc is a function that closes the client. +type CloseFunc func(ctx context.Context) error diff --git a/internal/osimage/imageinfo/BUILD.bazel b/internal/osimage/imageinfo/BUILD.bazel index 97a99ec57..2589830ec 100644 --- a/internal/osimage/imageinfo/BUILD.bazel +++ b/internal/osimage/imageinfo/BUILD.bazel @@ -9,7 +9,7 @@ go_library( "//internal/api/versions", "//internal/constants", "//internal/logger", - "@com_github_aws_aws_sdk_go_v2_config//:config", + "//internal/staticupload", "@com_github_aws_aws_sdk_go_v2_feature_s3_manager//:manager", "@com_github_aws_aws_sdk_go_v2_service_s3//:s3", "@com_github_aws_aws_sdk_go_v2_service_s3//types", diff --git a/internal/osimage/imageinfo/imageinfo.go b/internal/osimage/imageinfo/imageinfo.go index 756e2eaa1..6e4afa373 100644 --- a/internal/osimage/imageinfo/imageinfo.go +++ b/internal/osimage/imageinfo/imageinfo.go @@ -13,18 +13,19 @@ import ( "encoding/json" "net/url" - awsconfig "github.com/aws/aws-sdk-go-v2/config" s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager" "github.com/aws/aws-sdk-go-v2/service/s3" s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" versionsapi "github.com/edgelesssys/constellation/v2/internal/api/versions" "github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/logger" + "github.com/edgelesssys/constellation/v2/internal/staticupload" ) // Uploader uploads image info to S3. type Uploader struct { - uploadClient uploadClient + uploadClient uploadClient + uploadClientClose func(ctx context.Context) error // bucket is the name of the S3 bucket to use. bucket string @@ -32,19 +33,36 @@ type Uploader struct { } // New creates a new Uploader. -func New(ctx context.Context, region, bucket string, log *logger.Logger) (*Uploader, error) { - cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(region)) +func New(ctx context.Context, region, bucket, distributionID string, log *logger.Logger) (*Uploader, CloseFunc, error) { + staticUploadClient, staticUploadClientClose, err := staticupload.New(ctx, staticupload.Config{ + Region: region, + Bucket: bucket, + DistributionID: distributionID, + CacheInvalidationStrategy: staticupload.CacheInvalidateBatchOnFlush, + }) if err != nil { - return nil, err + return nil, nil, err } - s3client := s3.NewFromConfig(cfg) - uploadClient := s3manager.NewUploader(s3client) - return &Uploader{ - uploadClient: uploadClient, - bucket: bucket, - log: log, - }, nil + uploader := &Uploader{ + uploadClient: staticUploadClient, + uploadClientClose: staticUploadClientClose, + bucket: bucket, + log: log, + } + uploaderClose := func(ctx context.Context) error { + return uploader.Close(ctx) + } + return uploader, uploaderClose, nil +} + +// Close closes the uploader. +// It invalidates the CDN cache for all uploaded files. +func (a *Uploader) Close(ctx context.Context) error { + if a.uploadClientClose == nil { + return nil + } + return a.uploadClientClose(ctx) } // Upload marshals the image info to JSON and uploads it to S3. @@ -76,3 +94,6 @@ func (a *Uploader) Upload(ctx context.Context, imageInfo versionsapi.ImageInfo) type uploadClient interface { Upload(ctx context.Context, input *s3.PutObjectInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error) } + +// CloseFunc is a function that closes the client. +type CloseFunc func(ctx context.Context) error diff --git a/internal/osimage/measurementsuploader/BUILD.bazel b/internal/osimage/measurementsuploader/BUILD.bazel index 1f80dd025..e88172b72 100644 --- a/internal/osimage/measurementsuploader/BUILD.bazel +++ b/internal/osimage/measurementsuploader/BUILD.bazel @@ -10,7 +10,7 @@ go_library( "//internal/attestation/measurements", "//internal/constants", "//internal/logger", - "@com_github_aws_aws_sdk_go_v2_config//:config", + "//internal/staticupload", "@com_github_aws_aws_sdk_go_v2_feature_s3_manager//:manager", "@com_github_aws_aws_sdk_go_v2_service_s3//:s3", "@com_github_aws_aws_sdk_go_v2_service_s3//types", diff --git a/internal/osimage/measurementsuploader/measurementsuploader.go b/internal/osimage/measurementsuploader/measurementsuploader.go index 2ef6c921c..6292b8877 100644 --- a/internal/osimage/measurementsuploader/measurementsuploader.go +++ b/internal/osimage/measurementsuploader/measurementsuploader.go @@ -14,7 +14,6 @@ import ( "io" "net/url" - awsconfig "github.com/aws/aws-sdk-go-v2/config" s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager" "github.com/aws/aws-sdk-go-v2/service/s3" s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" @@ -22,11 +21,13 @@ import ( "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/logger" + "github.com/edgelesssys/constellation/v2/internal/staticupload" ) // Uploader uploads image info to S3. type Uploader struct { - uploadClient uploadClient + uploadClient uploadClient + uploadClientClose func(ctx context.Context) error // bucket is the name of the S3 bucket to use. bucket string @@ -34,19 +35,36 @@ type Uploader struct { } // New creates a new Uploader. -func New(ctx context.Context, region, bucket string, log *logger.Logger) (*Uploader, error) { - cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(region)) +func New(ctx context.Context, region, bucket, distributionID string, log *logger.Logger) (*Uploader, CloseFunc, error) { + staticUploadClient, staticUploadClientClose, err := staticupload.New(ctx, staticupload.Config{ + Region: region, + Bucket: bucket, + DistributionID: distributionID, + CacheInvalidationStrategy: staticupload.CacheInvalidateBatchOnFlush, + }) if err != nil { - return nil, err + return nil, nil, err } - s3client := s3.NewFromConfig(cfg) - uploadClient := s3manager.NewUploader(s3client) - return &Uploader{ - uploadClient: uploadClient, - bucket: bucket, - log: log, - }, nil + uploader := &Uploader{ + uploadClient: staticUploadClient, + uploadClientClose: staticUploadClientClose, + bucket: bucket, + log: log, + } + uploaderClose := func(ctx context.Context) error { + return uploader.Close(ctx) + } + return uploader, uploaderClose, nil +} + +// Close closes the uploader. +// It invalidates the CDN cache for all uploaded files. +func (a *Uploader) Close(ctx context.Context) error { + if a.uploadClientClose == nil { + return nil + } + return a.uploadClientClose(ctx) } // Upload uploads the measurements v2 JSON file and its signature to S3. @@ -97,3 +115,6 @@ func (a *Uploader) Upload(ctx context.Context, rawMeasurement, signature io.Read type uploadClient interface { Upload(ctx context.Context, input *s3.PutObjectInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error) } + +// CloseFunc is a function that closes the client. +type CloseFunc func(ctx context.Context) error diff --git a/internal/staticupload/get.go b/internal/staticupload/get.go index 5a13e5550..893fd243b 100644 --- a/internal/staticupload/get.go +++ b/internal/staticupload/get.go @@ -12,7 +12,16 @@ import ( "github.com/aws/aws-sdk-go-v2/service/s3" ) -// GetObject returns an object from from AWS S3 Storage. -func (s *Client) GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) { - return s.s3Client.GetObject(ctx, params, optFns...) +// GetObject retrieves objects from Amazon S3. +func (c *Client) GetObject( + ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options), +) (*s3.GetObjectOutput, error) { + return c.s3Client.GetObject(ctx, params, optFns...) +} + +// ListObjectsV2 returns some or all (up to 1,000) of the objects in a bucket. +func (c *Client) ListObjectsV2( + ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options), +) (*s3.ListObjectsV2Output, error) { + return c.s3Client.ListObjectsV2(ctx, params, optFns...) } diff --git a/internal/staticupload/staticupload.go b/internal/staticupload/staticupload.go index 1860c0d71..f29b15cde 100644 --- a/internal/staticupload/staticupload.go +++ b/internal/staticupload/staticupload.go @@ -36,7 +36,7 @@ type Client struct { uploadClient uploadClient s3Client objectStorageClient distributionID string - BucketID string + bucketID string cacheInvalidationStrategy CacheInvalidationStrategy cacheInvalidationWaitTimeout time.Duration @@ -73,9 +73,9 @@ type CacheInvalidationStrategy int const ( // CacheInvalidateEager invalidates the CDN cache immediately for every key that is uploaded. CacheInvalidateEager CacheInvalidationStrategy = iota - // CacheInvalidateBatchOnClose invalidates the CDN cache in batches when the client is closed. - // This is useful when uploading many files at once but will fail if Close is not called. - CacheInvalidateBatchOnClose + // CacheInvalidateBatchOnFlush invalidates the CDN cache in batches when the client is flushed / closed. + // This is useful when uploading many files at once but may fail to invalidate the cache if close is not called. + CacheInvalidateBatchOnFlush ) // InvalidationError is an error that occurs when invalidating the CDN cache. @@ -94,33 +94,39 @@ func (e InvalidationError) Unwrap() error { } // New creates a new Client. -func New(ctx context.Context, config Config) (*Client, error) { +func New(ctx context.Context, config Config) (*Client, CloseFunc, error) { config.SetsDefault() cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(config.Region)) if err != nil { - return nil, err + return nil, nil, err } s3Client := s3.NewFromConfig(cfg) uploadClient := s3manager.NewUploader(s3Client) cdnClient := cloudfront.NewFromConfig(cfg) - return &Client{ + client := &Client{ cdnClient: cdnClient, s3Client: s3Client, uploadClient: uploadClient, distributionID: config.DistributionID, cacheInvalidationStrategy: config.CacheInvalidationStrategy, cacheInvalidationWaitTimeout: config.CacheInvalidationWaitTimeout, - BucketID: config.Bucket, - }, nil + bucketID: config.Bucket, + } + clientClose := func(ctx context.Context) error { + // ensure that all keys are invalidated + return client.Flush(ctx) + } + + return client, clientClose, nil } -// Close closes the client. +// Flush flushes the client by invalidating the CDN cache for modified keys. // It waits for all invalidations to finish. // It returns nil on success or an error. // The error will be of type InvalidationError if the CDN cache could not be invalidated. -func (c *Client) Close(ctx context.Context) error { +func (c *Client) Flush(ctx context.Context) error { c.mux.Lock() defer c.mux.Unlock() @@ -138,7 +144,7 @@ func (c *Client) Close(ctx context.Context) error { // invalidate invalidates the CDN cache for the given keys. // It either performs the invalidation immediately or adds them to the list of dirty keys. func (c *Client) invalidate(ctx context.Context, keys []string) error { - if c.cacheInvalidationStrategy == CacheInvalidateBatchOnClose { + if c.cacheInvalidationStrategy == CacheInvalidateBatchOnFlush { // save as dirty key for batch invalidation on Close c.mux.Lock() defer c.mux.Unlock() @@ -218,7 +224,15 @@ type uploadClient interface { } type getClient interface { - GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) + GetObject( + ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options), + ) (*s3.GetObjectOutput, error) +} + +type listClient interface { + ListObjectsV2( + ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options), + ) (*s3.ListObjectsV2Output, error) } type deleteClient interface { @@ -241,8 +255,9 @@ type cdnClient interface { } type objectStorageClient interface { - deleteClient getClient + listClient + deleteClient } // statically assert that Client implements the uploadClient interface. @@ -254,3 +269,6 @@ var _ objectStorageClient = (*Client)(nil) func ptr[T any](t T) *T { return &t } + +// CloseFunc is a function that closes the client. +type CloseFunc func(ctx context.Context) error diff --git a/internal/staticupload/staticupload_test.go b/internal/staticupload/staticupload_test.go index 659277934..46b732964 100644 --- a/internal/staticupload/staticupload_test.go +++ b/internal/staticupload/staticupload_test.go @@ -59,7 +59,7 @@ func TestUpload(t *testing.T) { }, "lazy invalidation": { in: newInput(), - cacheInvalidationStrategy: CacheInvalidateBatchOnClose, + cacheInvalidationStrategy: CacheInvalidateBatchOnFlush, cacheInvalidationWaitTimeout: time.Microsecond, wantDirtyKeys: []string{"test-key"}, }, @@ -181,7 +181,7 @@ func TestDeleteObject(t *testing.T) { wantInvalidationIDs: []string{"test-invalidation-id-1"}, }, "lazy invalidation": { - cacheInvalidationStrategy: CacheInvalidateBatchOnClose, + cacheInvalidationStrategy: CacheInvalidateBatchOnFlush, cacheInvalidationWaitTimeout: time.Microsecond, wantDirtyKeys: []string{"test-key"}, }, @@ -273,7 +273,7 @@ func TestDeleteObject(t *testing.T) { } } -func TestClose(t *testing.T) { +func TestFlush(t *testing.T) { testCases := map[string]struct { dirtyKeys []string invalidationIDs []string @@ -389,7 +389,7 @@ func TestClose(t *testing.T) { dirtyKeys: tc.dirtyKeys, invalidationIDs: tc.invalidationIDs, } - err := client.Close(context.Background()) + err := client.Flush(context.Background()) if tc.wantCacheInvalidationErr { assert.ErrorAs(err, &InvalidationError{}) @@ -450,9 +450,9 @@ func TestConcurrency(_ *testing.T) { }, }) } - closeClient := func() { + flushClient := func() { defer wg.Done() - _ = client.Close(context.Background()) + _ = client.Flush(context.Background()) } for i := 0; i < 100; i++ { @@ -460,7 +460,7 @@ func TestConcurrency(_ *testing.T) { go upload() go deleteObject() go deleteObjects() - go closeClient() + go flushClient() } wg.Wait() @@ -559,3 +559,11 @@ func (s *stubObjectStorageClient) GetObject( ) (*s3.GetObjectOutput, error) { return nil, nil } + +// currently not needed so no-Op. +func (s *stubObjectStorageClient) ListObjectsV2( + _ context.Context, _ *s3.ListObjectsV2Input, + _ ...func(*s3.Options), +) (*s3.ListObjectsV2Output, error) { + return nil, nil +}