staticupload: correctly set invalidation timeout

Previously the timeout was not set in the client's constructor, thus the
zero value was used. The client did not wait for invalidation.
To prevent this in the future a warning is logged if wait is disabled.

Co-authored-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Otto Bittner 2023-08-31 10:31:45 +02:00
parent fdaa5aab3c
commit 97dc15b1d1
14 changed files with 113 additions and 89 deletions

View File

@ -26,7 +26,8 @@ jobs:
id: checkout
uses: actions/checkout@c85c95e3d7251135ab7dc9ce3241c5835cc595a9 # v3.5.3
with:
ref: ${{ !github.event.pull_request.head.repo.fork && github.head_ref || '' }}
# Don't trigger in forks, use head on pull requests, use default otherwise.
ref: ${{ !github.event.pull_request.head.repo.fork && github.head_ref || github.event.pull_request.head.sha || '' }}
- name: Run Attestationconfig API E2E
uses: ./.github/actions/e2e_attestationconfigapi

View File

@ -14,6 +14,7 @@ package main
import (
"encoding/json"
"errors"
"fmt"
"os"
"time"
@ -125,12 +126,12 @@ func runCmd(cmd *cobra.Command, _ []string) (retErr error) {
log.Infof("Input version: %+v is newer than latest API version: %+v", inputVersion, latestAPIVersion)
client, clientClose, err := attestationconfigapi.NewClient(ctx, cfg, []byte(cosignPwd), []byte(privateKey), false, log)
defer func(retErr *error) {
log.Infof("Invalidating cache. This may take some time")
if err := clientClose(cmd.Context()); err != nil && retErr == nil {
*retErr = fmt.Errorf("invalidating cache: %w", err)
defer func() {
err := clientClose(cmd.Context())
if err != nil {
retErr = errors.Join(retErr, fmt.Errorf("failed to invalidate cache: %w", err))
}
}(&retErr)
}()
if err != nil {
return fmt.Errorf("creating client: %w", err)

View File

@ -50,12 +50,11 @@ type Client struct {
s3Client
s3ClientClose func(ctx context.Context) error
bucket string
cacheInvalidationWaitTimeout time.Duration
dirtyPaths []string // written paths to be invalidated
DryRun bool // no write operations are performed
Log *logger.Logger
Logger *logger.Logger
}
// NewReadOnlyClient creates a new read-only client.
@ -68,7 +67,8 @@ func NewReadOnlyClient(ctx context.Context, region, bucket, distributionID strin
Bucket: bucket,
DistributionID: distributionID,
CacheInvalidationStrategy: staticupload.CacheInvalidateBatchOnFlush,
})
CacheInvalidationWaitTimeout: 10 * time.Minute,
}, log)
if err != nil {
return nil, nil, err
}
@ -78,7 +78,7 @@ func NewReadOnlyClient(ctx context.Context, region, bucket, distributionID strin
s3ClientClose: staticUploadClientClose,
bucket: bucket,
DryRun: true,
Log: log,
Logger: log,
}
clientClose := func(ctx context.Context) error {
return client.Close(ctx)
@ -96,7 +96,8 @@ func NewClient(ctx context.Context, region, bucket, distributionID string, dryRu
Bucket: bucket,
DistributionID: distributionID,
CacheInvalidationStrategy: staticupload.CacheInvalidateBatchOnFlush,
})
CacheInvalidationWaitTimeout: 10 * time.Minute,
}, log)
if err != nil {
return nil, nil, err
}
@ -106,8 +107,7 @@ func NewClient(ctx context.Context, region, bucket, distributionID string, dryRu
s3ClientClose: staticUploadClientClose,
bucket: bucket,
DryRun: dryRun,
Log: log,
cacheInvalidationWaitTimeout: 10 * time.Minute,
Logger: log,
}
clientClose := func(ctx context.Context) error {
return client.Close(ctx)
@ -120,6 +120,7 @@ func NewClient(ctx context.Context, region, bucket, distributionID string, dryRu
// It invalidates the CDN cache for all uploaded files.
func (c *Client) Close(ctx context.Context) error {
if c.s3ClientClose == nil {
c.Logger.Debugf("Client has no s3ClientClose")
return nil
}
return c.s3ClientClose(ctx)
@ -131,7 +132,7 @@ func (c *Client) DeletePath(ctx context.Context, path string) error {
Bucket: &c.bucket,
Prefix: &path,
}
c.Log.Debugf("Listing objects in %s", path)
c.Logger.Debugf("Listing objects in %s", path)
objs := []s3types.Object{}
out := &s3.ListObjectsV2Output{IsTruncated: true}
for out.IsTruncated {
@ -142,10 +143,10 @@ func (c *Client) DeletePath(ctx context.Context, path string) error {
}
objs = append(objs, out.Contents...)
}
c.Log.Debugf("Found %d objects in %s", len(objs), path)
c.Logger.Debugf("Found %d objects in %s", len(objs), path)
if len(objs) == 0 {
c.Log.Warnf("Path %s is already empty", path)
c.Logger.Warnf("Path %s is already empty", path)
return nil
}
@ -155,7 +156,7 @@ func (c *Client) DeletePath(ctx context.Context, path string) error {
}
if c.DryRun {
c.Log.Debugf("DryRun: Deleting %d objects with IDs %v", len(objs), objIDs)
c.Logger.Debugf("DryRun: Deleting %d objects with IDs %v", len(objs), objIDs)
return nil
}
@ -167,7 +168,7 @@ func (c *Client) DeletePath(ctx context.Context, path string) error {
Objects: objIDs,
},
}
c.Log.Debugf("Deleting %d objects in %s", len(objs), path)
c.Logger.Debugf("Deleting %d objects in %s", len(objs), path)
if _, err := c.s3Client.DeleteObjects(ctx, deleteIn); err != nil {
return fmt.Errorf("deleting objects in %s: %w", path, err)
}
@ -197,7 +198,7 @@ func Fetch[T APIObject](ctx context.Context, c *Client, obj T) (T, error) {
Key: ptr(obj.JSONPath()),
}
c.Log.Debugf("Fetching %T from s3: %s", obj, obj.JSONPath())
c.Logger.Debugf("Fetching %T from s3: %s", obj, obj.JSONPath())
out, err := c.s3Client.GetObject(ctx, in)
var noSuchkey *s3types.NoSuchKey
if errors.As(err, &noSuchkey) {
@ -231,7 +232,7 @@ func Update(ctx context.Context, c *Client, obj APIObject) error {
}
if c.DryRun {
c.Log.With(zap.String("bucket", c.bucket), zap.String("key", obj.JSONPath()), zap.String("body", string(rawJSON))).Debugf("DryRun: s3 put object")
c.Logger.With(zap.String("bucket", c.bucket), zap.String("key", obj.JSONPath()), zap.String("body", string(rawJSON))).Debugf("DryRun: s3 put object")
return nil
}
@ -243,7 +244,7 @@ func Update(ctx context.Context, c *Client, obj APIObject) error {
c.dirtyPaths = append(c.dirtyPaths, "/"+obj.JSONPath())
c.Log.Debugf("Uploading %T to s3: %v", obj, obj.JSONPath())
c.Logger.Debugf("Uploading %T to s3: %v", obj, obj.JSONPath())
if _, err := c.Upload(ctx, in); err != nil {
return fmt.Errorf("uploading %T: %w", obj, err)
}
@ -306,7 +307,7 @@ func Delete(ctx context.Context, c *Client, obj APIObject) error {
Key: ptr(obj.JSONPath()),
}
c.Log.Debugf("Deleting %T from s3: %s", obj, obj.JSONPath())
c.Logger.Debugf("Deleting %T from s3: %s", obj, obj.JSONPath())
if _, err := c.DeleteObject(ctx, in); err != nil {
return fmt.Errorf("deleting s3 object at %s: %w", obj.JSONPath(), err)
}

View File

@ -71,12 +71,12 @@ func runAdd(cmd *cobra.Command, _ []string) (retErr error) {
if err != nil {
return fmt.Errorf("creating client: %w", err)
}
defer func(retErr *error) {
log.Infof("Invalidating cache. This may take some time")
if err := clientClose(cmd.Context()); err != nil && retErr == nil {
*retErr = fmt.Errorf("invalidating cache: %w", err)
defer func() {
err := clientClose(cmd.Context())
if err != nil {
retErr = errors.Join(retErr, fmt.Errorf("failed to invalidate cache: %w", err))
}
}(&retErr)
}()
log.Infof("Adding version")
if err := ensureVersion(cmd.Context(), client, flags.kind, ver, versionsapi.GranularityMajor, log); err != nil {

View File

@ -8,6 +8,7 @@ package main
import (
"encoding/json"
"errors"
"fmt"
"github.com/edgelesssys/constellation/v2/internal/api/versionsapi"
@ -32,7 +33,7 @@ func newLatestCmd() *cobra.Command {
return cmd
}
func runLatest(cmd *cobra.Command, _ []string) error {
func runLatest(cmd *cobra.Command, _ []string) (retErr error) {
flags, err := parseLatestFlags(cmd)
if err != nil {
return err
@ -51,8 +52,9 @@ func runLatest(cmd *cobra.Command, _ []string) error {
return fmt.Errorf("creating client: %w", err)
}
defer func() {
if err := clientClose(cmd.Context()); err != nil {
log.Errorf("Closing versions API client: %v", err)
err := clientClose(cmd.Context())
if err != nil {
retErr = errors.Join(retErr, fmt.Errorf("failed to invalidate cache: %w", err))
}
}()

View File

@ -38,7 +38,7 @@ func newListCmd() *cobra.Command {
return cmd
}
func runList(cmd *cobra.Command, _ []string) error {
func runList(cmd *cobra.Command, _ []string) (retErr error) {
flags, err := parseListFlags(cmd)
if err != nil {
return err
@ -57,8 +57,9 @@ func runList(cmd *cobra.Command, _ []string) error {
return fmt.Errorf("creating client: %w", err)
}
defer func() {
if err := clientClose(cmd.Context()); err != nil {
log.Errorf("Closing versions API client: %v", err)
err := clientClose(cmd.Context())
if err != nil {
retErr = errors.Join(retErr, fmt.Errorf("failed to invalidate cache: %w", err))
}
}()

View File

@ -105,12 +105,12 @@ func runRemove(cmd *cobra.Command, _ []string) (retErr error) {
if err != nil {
return fmt.Errorf("creating client: %w", err)
}
defer func(retErr *error) {
log.Infof("Invalidating cache. This may take some time")
if err := verclientClose(cmd.Context()); err != nil && retErr == nil {
*retErr = fmt.Errorf("invalidating cache: %w", err)
defer func() {
err := verclientClose(cmd.Context())
if err != nil {
retErr = errors.Join(retErr, fmt.Errorf("failed to invalidate cache: %w", err))
}
}(&retErr)
}()
imageClients := rmImageClients{
version: verclient,

View File

@ -131,18 +131,18 @@ func (c *Client) DeleteRef(ctx context.Context, ref string) error {
func (c *Client) DeleteVersion(ctx context.Context, ver Version) error {
var retErr error
c.Client.Log.Debugf("Deleting version %s from minor version list", ver.version)
c.Client.Logger.Debugf("Deleting version %s from minor version list", ver.version)
possibleNewLatest, err := c.deleteVersionFromMinorVersionList(ctx, ver)
if err != nil {
retErr = errors.Join(retErr, fmt.Errorf("removing from minor version list: %w", err))
}
c.Client.Log.Debugf("Checking latest version for %s", ver.version)
c.Client.Logger.Debugf("Checking latest version for %s", ver.version)
if err := c.deleteVersionFromLatest(ctx, ver, possibleNewLatest); err != nil {
retErr = errors.Join(retErr, fmt.Errorf("updating latest version: %w", err))
}
c.Client.Log.Debugf("Deleting artifact path %s for %s", ver.ArtifactPath(APIV1), ver.version)
c.Client.Logger.Debugf("Deleting artifact path %s for %s", ver.ArtifactPath(APIV1), ver.version)
if err := c.Client.DeletePath(ctx, ver.ArtifactPath(APIV1)); err != nil {
retErr = errors.Join(retErr, fmt.Errorf("deleting artifact path: %w", err))
}
@ -159,20 +159,20 @@ func (c *Client) deleteVersionFromMinorVersionList(ctx context.Context, ver Vers
Base: ver.WithGranularity(GranularityMinor),
Kind: VersionKindImage,
}
c.Client.Log.Debugf("Fetching minor version list for version %s", ver.version)
c.Client.Logger.Debugf("Fetching minor version list for version %s", ver.version)
minorList, err := c.FetchVersionList(ctx, minorList)
var notFoundErr *apiclient.NotFoundError
if errors.As(err, &notFoundErr) {
c.Client.Log.Warnf("Minor version list for version %s not found", ver.version)
c.Client.Log.Warnf("Skipping update of minor version list")
c.Client.Logger.Warnf("Minor version list for version %s not found", ver.version)
c.Client.Logger.Warnf("Skipping update of minor version list")
return nil, nil
} else if err != nil {
return nil, fmt.Errorf("fetching minor version list for version %s: %w", ver.version, err)
}
if !minorList.Contains(ver.version) {
c.Client.Log.Warnf("Version %s is not in minor version list %s", ver.version, minorList.JSONPath())
c.Client.Log.Warnf("Skipping update of minor version list")
c.Client.Logger.Warnf("Version %s is not in minor version list %s", ver.version, minorList.JSONPath())
c.Client.Logger.Warnf("Skipping update of minor version list")
return nil, nil
}
@ -192,20 +192,20 @@ func (c *Client) deleteVersionFromMinorVersionList(ctx context.Context, ver Vers
Kind: VersionKindImage,
Version: minorList.Versions[len(minorList.Versions)-1],
}
c.Client.Log.Debugf("Possible latest version replacement %q", latest.Version)
c.Client.Logger.Debugf("Possible latest version replacement %q", latest.Version)
}
if c.Client.DryRun {
c.Client.Log.Debugf("DryRun: Updating minor version list %s to %v", minorList.JSONPath(), minorList)
c.Client.Logger.Debugf("DryRun: Updating minor version list %s to %v", minorList.JSONPath(), minorList)
return latest, nil
}
c.Client.Log.Debugf("Updating minor version list %s", minorList.JSONPath())
c.Client.Logger.Debugf("Updating minor version list %s", minorList.JSONPath())
if err := c.UpdateVersionList(ctx, minorList); err != nil {
return latest, fmt.Errorf("updating minor version list %s: %w", minorList.JSONPath(), err)
}
c.Client.Log.Debugf("Removed version %s from minor version list %s", ver.version, minorList.JSONPath())
c.Client.Logger.Debugf("Removed version %s from minor version list %s", ver.version, minorList.JSONPath())
return latest, nil
}
@ -216,33 +216,33 @@ func (c *Client) deleteVersionFromLatest(ctx context.Context, ver Version, possi
Stream: ver.stream,
Kind: VersionKindImage,
}
c.Client.Log.Debugf("Fetching latest version from %s", latest.JSONPath())
c.Client.Logger.Debugf("Fetching latest version from %s", latest.JSONPath())
latest, err := c.FetchVersionLatest(ctx, latest)
var notFoundErr *apiclient.NotFoundError
if errors.As(err, &notFoundErr) {
c.Client.Log.Warnf("Latest version for %s not found", latest.JSONPath())
c.Client.Logger.Warnf("Latest version for %s not found", latest.JSONPath())
return nil
} else if err != nil {
return fmt.Errorf("fetching latest version: %w", err)
}
if latest.Version != ver.version {
c.Client.Log.Debugf("Latest version is %s, not the deleted version %s", latest.Version, ver.version)
c.Client.Logger.Debugf("Latest version is %s, not the deleted version %s", latest.Version, ver.version)
return nil
}
if possibleNewLatest == nil {
c.Client.Log.Errorf("Latest version is %s, but no new latest version was found", latest.Version)
c.Client.Log.Errorf("A manual update of latest at %s might be needed", latest.JSONPath())
c.Client.Logger.Errorf("Latest version is %s, but no new latest version was found", latest.Version)
c.Client.Logger.Errorf("A manual update of latest at %s might be needed", latest.JSONPath())
return fmt.Errorf("latest version is %s, but no new latest version was found", latest.Version)
}
if c.Client.DryRun {
c.Client.Log.Debugf("Would update latest version from %s to %s", latest.Version, possibleNewLatest.Version)
c.Client.Logger.Debugf("Would update latest version from %s to %s", latest.Version, possibleNewLatest.Version)
return nil
}
c.Client.Log.Infof("Updating latest version from %s to %s", latest.Version, possibleNewLatest.Version)
c.Client.Logger.Infof("Updating latest version from %s to %s", latest.Version, possibleNewLatest.Version)
if err := c.UpdateVersionLatest(ctx, *possibleNewLatest); err != nil {
return fmt.Errorf("updating latest version: %w", err)
}

View File

@ -11,6 +11,7 @@ import (
"context"
"io"
"net/url"
"time"
s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
@ -38,7 +39,8 @@ func New(ctx context.Context, region, bucket, distributionID string, log *logger
Bucket: bucket,
DistributionID: distributionID,
CacheInvalidationStrategy: staticupload.CacheInvalidateBatchOnFlush,
})
CacheInvalidationWaitTimeout: 10 * time.Minute,
}, log)
if err != nil {
return nil, nil, err
}

View File

@ -13,6 +13,7 @@ import (
"encoding/json"
"fmt"
"net/url"
"time"
s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
@ -40,7 +41,8 @@ func New(ctx context.Context, region, bucket, distributionID string, log *logger
Bucket: bucket,
DistributionID: distributionID,
CacheInvalidationStrategy: staticupload.CacheInvalidateBatchOnFlush,
})
CacheInvalidationWaitTimeout: 10 * time.Minute,
}, log)
if err != nil {
return nil, nil, err
}

View File

@ -13,6 +13,7 @@ import (
"fmt"
"io"
"net/url"
"time"
s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
@ -41,7 +42,8 @@ func New(ctx context.Context, region, bucket, distributionID string, log *logger
Bucket: bucket,
DistributionID: distributionID,
CacheInvalidationStrategy: staticupload.CacheInvalidateBatchOnFlush,
})
CacheInvalidationWaitTimeout: 10 * time.Minute,
}, log)
if err != nil {
return nil, nil, err
}

View File

@ -13,6 +13,7 @@ go_library(
visibility = ["//:__subpackages__"],
deps = [
"//internal/constants",
"//internal/logger",
"@com_github_aws_aws_sdk_go_v2_config//:config",
"@com_github_aws_aws_sdk_go_v2_feature_s3_manager//:manager",
"@com_github_aws_aws_sdk_go_v2_service_cloudfront//:cloudfront",
@ -27,6 +28,7 @@ go_test(
srcs = ["staticupload_test.go"],
embed = [":staticupload"],
deps = [
"//internal/logger",
"@com_github_aws_aws_sdk_go_v2_feature_s3_manager//:manager",
"@com_github_aws_aws_sdk_go_v2_service_cloudfront//:cloudfront",
"@com_github_aws_aws_sdk_go_v2_service_cloudfront//types",

View File

@ -25,6 +25,7 @@ import (
cftypes "github.com/aws/aws-sdk-go-v2/service/cloudfront/types"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/google/uuid"
)
@ -100,7 +101,7 @@ func (e *InvalidationError) Unwrap() error {
}
// New creates a new Client. Call CloseFunc when done with operations.
func New(ctx context.Context, config Config) (*Client, CloseFunc, error) {
func New(ctx context.Context, config Config, log *logger.Logger) (*Client, CloseFunc, error) {
config.SetsDefault()
cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(config.Region))
if err != nil {
@ -121,6 +122,7 @@ func New(ctx context.Context, config Config) (*Client, CloseFunc, error) {
bucketID: config.Bucket,
logger: log,
}
return client, client.Flush, nil
}
@ -132,6 +134,7 @@ func (c *Client) Flush(ctx context.Context) error {
c.mux.Lock()
defer c.mux.Unlock()
c.logger.Debugf("Invalidating keys: %s", c.dirtyKeys)
if len(c.dirtyKeys) == 0 {
return nil
}
@ -211,16 +214,17 @@ func (c *Client) invalidateCacheForKeys(ctx context.Context, keys []string) (str
// waitForInvalidations waits for all invalidations to finish.
func (c *Client) waitForInvalidations(ctx context.Context) error {
if c.cacheInvalidationWaitTimeout == 0 {
c.logger.Warnf("cacheInvalidationWaitTimeout set to 0, not waiting for invalidations to finish")
return nil
}
waiter := cloudfront.NewInvalidationCompletedWaiter(c.cdnClient)
c.logger.Debugf("Waiting for invalidations %s in distribution %s", c.invalidationIDs, c.distributionID)
for _, invalidationID := range c.invalidationIDs {
waitIn := &cloudfront.GetInvalidationInput{
DistributionId: &c.distributionID,
Id: &invalidationID,
}
c.logger.Debugf("Waiting for invalidation %s in distribution %s", invalidationID, c.distributionID)
if err := waiter.Wait(ctx, waitIn, c.cacheInvalidationWaitTimeout); err != nil {
return NewInvalidationError(fmt.Errorf("waiting for invalidation to complete: %w", err))
}

View File

@ -21,6 +21,7 @@ import (
cftypes "github.com/aws/aws-sdk-go-v2/service/cloudfront/types"
"github.com/aws/aws-sdk-go-v2/service/s3"
s3types "github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
@ -105,6 +106,7 @@ func TestUpload(t *testing.T) {
distributionID: "test-distribution-id",
cacheInvalidationStrategy: tc.cacheInvalidationStrategy,
cacheInvalidationWaitTimeout: tc.cacheInvalidationWaitTimeout,
logger: logger.NewTest(t),
}
_, err := client.Upload(context.Background(), tc.in)
@ -216,6 +218,7 @@ func TestDeleteObject(t *testing.T) {
distributionID: "test-distribution-id",
cacheInvalidationStrategy: tc.cacheInvalidationStrategy,
cacheInvalidationWaitTimeout: tc.cacheInvalidationWaitTimeout,
logger: logger.NewTest(t),
}
_, err := client.DeleteObject(context.Background(), newObjectInput(tc.nilInput, tc.nilKey))
@ -254,6 +257,7 @@ func TestDeleteObject(t *testing.T) {
distributionID: "test-distribution-id",
cacheInvalidationStrategy: tc.cacheInvalidationStrategy,
cacheInvalidationWaitTimeout: tc.cacheInvalidationWaitTimeout,
logger: logger.NewTest(t),
}
_, err := client.DeleteObjects(context.Background(), newObjectsInput(tc.nilInput, tc.nilKey))
@ -395,6 +399,7 @@ func TestFlush(t *testing.T) {
cacheInvalidationWaitTimeout: tc.cacheInvalidationWaitTimeout,
dirtyKeys: tc.dirtyKeys,
invalidationIDs: tc.invalidationIDs,
logger: logger.NewTest(t),
}
err := client.Flush(context.Background())
@ -413,7 +418,7 @@ func TestFlush(t *testing.T) {
}
}
func TestConcurrency(_ *testing.T) {
func TestConcurrency(t *testing.T) {
newInput := func() *s3.PutObjectInput {
return &s3.PutObjectInput{
Bucket: ptr("test-bucket"),
@ -432,6 +437,7 @@ func TestConcurrency(_ *testing.T) {
uploadClient: uploadClient,
distributionID: "test-distribution-id",
cacheInvalidationWaitTimeout: 50 * time.Millisecond,
logger: logger.NewTest(t),
}
var wg sync.WaitGroup