staticupload: correctly set invalidation timeout

Previously the timeout was not set in the client's constructor, thus the
zero value was used. The client did not wait for invalidation.
To prevent this in the future a warning is logged if wait is disabled.

Co-authored-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Otto Bittner 2023-08-31 10:31:45 +02:00
parent fdaa5aab3c
commit 97dc15b1d1
14 changed files with 113 additions and 89 deletions

View File

@ -26,7 +26,8 @@ jobs:
id: checkout id: checkout
uses: actions/checkout@c85c95e3d7251135ab7dc9ce3241c5835cc595a9 # v3.5.3 uses: actions/checkout@c85c95e3d7251135ab7dc9ce3241c5835cc595a9 # v3.5.3
with: with:
ref: ${{ !github.event.pull_request.head.repo.fork && github.head_ref || '' }} # Don't trigger in forks, use head on pull requests, use default otherwise.
ref: ${{ !github.event.pull_request.head.repo.fork && github.head_ref || github.event.pull_request.head.sha || '' }}
- name: Run Attestationconfig API E2E - name: Run Attestationconfig API E2E
uses: ./.github/actions/e2e_attestationconfigapi uses: ./.github/actions/e2e_attestationconfigapi

View File

@ -14,6 +14,7 @@ package main
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"os" "os"
"time" "time"
@ -125,12 +126,12 @@ func runCmd(cmd *cobra.Command, _ []string) (retErr error) {
log.Infof("Input version: %+v is newer than latest API version: %+v", inputVersion, latestAPIVersion) log.Infof("Input version: %+v is newer than latest API version: %+v", inputVersion, latestAPIVersion)
client, clientClose, err := attestationconfigapi.NewClient(ctx, cfg, []byte(cosignPwd), []byte(privateKey), false, log) client, clientClose, err := attestationconfigapi.NewClient(ctx, cfg, []byte(cosignPwd), []byte(privateKey), false, log)
defer func(retErr *error) { defer func() {
log.Infof("Invalidating cache. This may take some time") err := clientClose(cmd.Context())
if err := clientClose(cmd.Context()); err != nil && retErr == nil { if err != nil {
*retErr = fmt.Errorf("invalidating cache: %w", err) retErr = errors.Join(retErr, fmt.Errorf("failed to invalidate cache: %w", err))
} }
}(&retErr) }()
if err != nil { if err != nil {
return fmt.Errorf("creating client: %w", err) return fmt.Errorf("creating client: %w", err)

View File

@ -48,14 +48,13 @@ import (
// Client is the a general client for all APIs. // Client is the a general client for all APIs.
type Client struct { type Client struct {
s3Client s3Client
s3ClientClose func(ctx context.Context) error s3ClientClose func(ctx context.Context) error
bucket string bucket string
cacheInvalidationWaitTimeout time.Duration
dirtyPaths []string // written paths to be invalidated dirtyPaths []string // written paths to be invalidated
DryRun bool // no write operations are performed DryRun bool // no write operations are performed
Log *logger.Logger Logger *logger.Logger
} }
// NewReadOnlyClient creates a new read-only client. // NewReadOnlyClient creates a new read-only client.
@ -64,11 +63,12 @@ func NewReadOnlyClient(ctx context.Context, region, bucket, distributionID strin
log *logger.Logger, log *logger.Logger,
) (*Client, CloseFunc, error) { ) (*Client, CloseFunc, error) {
staticUploadClient, staticUploadClientClose, err := staticupload.New(ctx, staticupload.Config{ staticUploadClient, staticUploadClientClose, err := staticupload.New(ctx, staticupload.Config{
Region: region, Region: region,
Bucket: bucket, Bucket: bucket,
DistributionID: distributionID, DistributionID: distributionID,
CacheInvalidationStrategy: staticupload.CacheInvalidateBatchOnFlush, CacheInvalidationStrategy: staticupload.CacheInvalidateBatchOnFlush,
}) CacheInvalidationWaitTimeout: 10 * time.Minute,
}, log)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -78,7 +78,7 @@ func NewReadOnlyClient(ctx context.Context, region, bucket, distributionID strin
s3ClientClose: staticUploadClientClose, s3ClientClose: staticUploadClientClose,
bucket: bucket, bucket: bucket,
DryRun: true, DryRun: true,
Log: log, Logger: log,
} }
clientClose := func(ctx context.Context) error { clientClose := func(ctx context.Context) error {
return client.Close(ctx) return client.Close(ctx)
@ -92,22 +92,22 @@ func NewClient(ctx context.Context, region, bucket, distributionID string, dryRu
log *logger.Logger, log *logger.Logger,
) (*Client, CloseFunc, error) { ) (*Client, CloseFunc, error) {
staticUploadClient, staticUploadClientClose, err := staticupload.New(ctx, staticupload.Config{ staticUploadClient, staticUploadClientClose, err := staticupload.New(ctx, staticupload.Config{
Region: region, Region: region,
Bucket: bucket, Bucket: bucket,
DistributionID: distributionID, DistributionID: distributionID,
CacheInvalidationStrategy: staticupload.CacheInvalidateBatchOnFlush, CacheInvalidationStrategy: staticupload.CacheInvalidateBatchOnFlush,
}) CacheInvalidationWaitTimeout: 10 * time.Minute,
}, log)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
client := &Client{ client := &Client{
s3Client: staticUploadClient, s3Client: staticUploadClient,
s3ClientClose: staticUploadClientClose, s3ClientClose: staticUploadClientClose,
bucket: bucket, bucket: bucket,
DryRun: dryRun, DryRun: dryRun,
Log: log, Logger: log,
cacheInvalidationWaitTimeout: 10 * time.Minute,
} }
clientClose := func(ctx context.Context) error { clientClose := func(ctx context.Context) error {
return client.Close(ctx) return client.Close(ctx)
@ -120,6 +120,7 @@ func NewClient(ctx context.Context, region, bucket, distributionID string, dryRu
// It invalidates the CDN cache for all uploaded files. // It invalidates the CDN cache for all uploaded files.
func (c *Client) Close(ctx context.Context) error { func (c *Client) Close(ctx context.Context) error {
if c.s3ClientClose == nil { if c.s3ClientClose == nil {
c.Logger.Debugf("Client has no s3ClientClose")
return nil return nil
} }
return c.s3ClientClose(ctx) return c.s3ClientClose(ctx)
@ -131,7 +132,7 @@ func (c *Client) DeletePath(ctx context.Context, path string) error {
Bucket: &c.bucket, Bucket: &c.bucket,
Prefix: &path, Prefix: &path,
} }
c.Log.Debugf("Listing objects in %s", path) c.Logger.Debugf("Listing objects in %s", path)
objs := []s3types.Object{} objs := []s3types.Object{}
out := &s3.ListObjectsV2Output{IsTruncated: true} out := &s3.ListObjectsV2Output{IsTruncated: true}
for out.IsTruncated { for out.IsTruncated {
@ -142,10 +143,10 @@ func (c *Client) DeletePath(ctx context.Context, path string) error {
} }
objs = append(objs, out.Contents...) objs = append(objs, out.Contents...)
} }
c.Log.Debugf("Found %d objects in %s", len(objs), path) c.Logger.Debugf("Found %d objects in %s", len(objs), path)
if len(objs) == 0 { if len(objs) == 0 {
c.Log.Warnf("Path %s is already empty", path) c.Logger.Warnf("Path %s is already empty", path)
return nil return nil
} }
@ -155,7 +156,7 @@ func (c *Client) DeletePath(ctx context.Context, path string) error {
} }
if c.DryRun { if c.DryRun {
c.Log.Debugf("DryRun: Deleting %d objects with IDs %v", len(objs), objIDs) c.Logger.Debugf("DryRun: Deleting %d objects with IDs %v", len(objs), objIDs)
return nil return nil
} }
@ -167,7 +168,7 @@ func (c *Client) DeletePath(ctx context.Context, path string) error {
Objects: objIDs, Objects: objIDs,
}, },
} }
c.Log.Debugf("Deleting %d objects in %s", len(objs), path) c.Logger.Debugf("Deleting %d objects in %s", len(objs), path)
if _, err := c.s3Client.DeleteObjects(ctx, deleteIn); err != nil { if _, err := c.s3Client.DeleteObjects(ctx, deleteIn); err != nil {
return fmt.Errorf("deleting objects in %s: %w", path, err) return fmt.Errorf("deleting objects in %s: %w", path, err)
} }
@ -197,7 +198,7 @@ func Fetch[T APIObject](ctx context.Context, c *Client, obj T) (T, error) {
Key: ptr(obj.JSONPath()), Key: ptr(obj.JSONPath()),
} }
c.Log.Debugf("Fetching %T from s3: %s", obj, obj.JSONPath()) c.Logger.Debugf("Fetching %T from s3: %s", obj, obj.JSONPath())
out, err := c.s3Client.GetObject(ctx, in) out, err := c.s3Client.GetObject(ctx, in)
var noSuchkey *s3types.NoSuchKey var noSuchkey *s3types.NoSuchKey
if errors.As(err, &noSuchkey) { if errors.As(err, &noSuchkey) {
@ -231,7 +232,7 @@ func Update(ctx context.Context, c *Client, obj APIObject) error {
} }
if c.DryRun { if c.DryRun {
c.Log.With(zap.String("bucket", c.bucket), zap.String("key", obj.JSONPath()), zap.String("body", string(rawJSON))).Debugf("DryRun: s3 put object") c.Logger.With(zap.String("bucket", c.bucket), zap.String("key", obj.JSONPath()), zap.String("body", string(rawJSON))).Debugf("DryRun: s3 put object")
return nil return nil
} }
@ -243,7 +244,7 @@ func Update(ctx context.Context, c *Client, obj APIObject) error {
c.dirtyPaths = append(c.dirtyPaths, "/"+obj.JSONPath()) c.dirtyPaths = append(c.dirtyPaths, "/"+obj.JSONPath())
c.Log.Debugf("Uploading %T to s3: %v", obj, obj.JSONPath()) c.Logger.Debugf("Uploading %T to s3: %v", obj, obj.JSONPath())
if _, err := c.Upload(ctx, in); err != nil { if _, err := c.Upload(ctx, in); err != nil {
return fmt.Errorf("uploading %T: %w", obj, err) return fmt.Errorf("uploading %T: %w", obj, err)
} }
@ -306,7 +307,7 @@ func Delete(ctx context.Context, c *Client, obj APIObject) error {
Key: ptr(obj.JSONPath()), Key: ptr(obj.JSONPath()),
} }
c.Log.Debugf("Deleting %T from s3: %s", obj, obj.JSONPath()) c.Logger.Debugf("Deleting %T from s3: %s", obj, obj.JSONPath())
if _, err := c.DeleteObject(ctx, in); err != nil { if _, err := c.DeleteObject(ctx, in); err != nil {
return fmt.Errorf("deleting s3 object at %s: %w", obj.JSONPath(), err) return fmt.Errorf("deleting s3 object at %s: %w", obj.JSONPath(), err)
} }

View File

@ -71,12 +71,12 @@ func runAdd(cmd *cobra.Command, _ []string) (retErr error) {
if err != nil { if err != nil {
return fmt.Errorf("creating client: %w", err) return fmt.Errorf("creating client: %w", err)
} }
defer func(retErr *error) { defer func() {
log.Infof("Invalidating cache. This may take some time") err := clientClose(cmd.Context())
if err := clientClose(cmd.Context()); err != nil && retErr == nil { if err != nil {
*retErr = fmt.Errorf("invalidating cache: %w", err) retErr = errors.Join(retErr, fmt.Errorf("failed to invalidate cache: %w", err))
} }
}(&retErr) }()
log.Infof("Adding version") log.Infof("Adding version")
if err := ensureVersion(cmd.Context(), client, flags.kind, ver, versionsapi.GranularityMajor, log); err != nil { if err := ensureVersion(cmd.Context(), client, flags.kind, ver, versionsapi.GranularityMajor, log); err != nil {

View File

@ -8,6 +8,7 @@ package main
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"github.com/edgelesssys/constellation/v2/internal/api/versionsapi" "github.com/edgelesssys/constellation/v2/internal/api/versionsapi"
@ -32,7 +33,7 @@ func newLatestCmd() *cobra.Command {
return cmd return cmd
} }
func runLatest(cmd *cobra.Command, _ []string) error { func runLatest(cmd *cobra.Command, _ []string) (retErr error) {
flags, err := parseLatestFlags(cmd) flags, err := parseLatestFlags(cmd)
if err != nil { if err != nil {
return err return err
@ -51,8 +52,9 @@ func runLatest(cmd *cobra.Command, _ []string) error {
return fmt.Errorf("creating client: %w", err) return fmt.Errorf("creating client: %w", err)
} }
defer func() { defer func() {
if err := clientClose(cmd.Context()); err != nil { err := clientClose(cmd.Context())
log.Errorf("Closing versions API client: %v", err) if err != nil {
retErr = errors.Join(retErr, fmt.Errorf("failed to invalidate cache: %w", err))
} }
}() }()

View File

@ -38,7 +38,7 @@ func newListCmd() *cobra.Command {
return cmd return cmd
} }
func runList(cmd *cobra.Command, _ []string) error { func runList(cmd *cobra.Command, _ []string) (retErr error) {
flags, err := parseListFlags(cmd) flags, err := parseListFlags(cmd)
if err != nil { if err != nil {
return err return err
@ -57,8 +57,9 @@ func runList(cmd *cobra.Command, _ []string) error {
return fmt.Errorf("creating client: %w", err) return fmt.Errorf("creating client: %w", err)
} }
defer func() { defer func() {
if err := clientClose(cmd.Context()); err != nil { err := clientClose(cmd.Context())
log.Errorf("Closing versions API client: %v", err) if err != nil {
retErr = errors.Join(retErr, fmt.Errorf("failed to invalidate cache: %w", err))
} }
}() }()

View File

@ -105,12 +105,12 @@ func runRemove(cmd *cobra.Command, _ []string) (retErr error) {
if err != nil { if err != nil {
return fmt.Errorf("creating client: %w", err) return fmt.Errorf("creating client: %w", err)
} }
defer func(retErr *error) { defer func() {
log.Infof("Invalidating cache. This may take some time") err := verclientClose(cmd.Context())
if err := verclientClose(cmd.Context()); err != nil && retErr == nil { if err != nil {
*retErr = fmt.Errorf("invalidating cache: %w", err) retErr = errors.Join(retErr, fmt.Errorf("failed to invalidate cache: %w", err))
} }
}(&retErr) }()
imageClients := rmImageClients{ imageClients := rmImageClients{
version: verclient, version: verclient,

View File

@ -131,18 +131,18 @@ func (c *Client) DeleteRef(ctx context.Context, ref string) error {
func (c *Client) DeleteVersion(ctx context.Context, ver Version) error { func (c *Client) DeleteVersion(ctx context.Context, ver Version) error {
var retErr error var retErr error
c.Client.Log.Debugf("Deleting version %s from minor version list", ver.version) c.Client.Logger.Debugf("Deleting version %s from minor version list", ver.version)
possibleNewLatest, err := c.deleteVersionFromMinorVersionList(ctx, ver) possibleNewLatest, err := c.deleteVersionFromMinorVersionList(ctx, ver)
if err != nil { if err != nil {
retErr = errors.Join(retErr, fmt.Errorf("removing from minor version list: %w", err)) retErr = errors.Join(retErr, fmt.Errorf("removing from minor version list: %w", err))
} }
c.Client.Log.Debugf("Checking latest version for %s", ver.version) c.Client.Logger.Debugf("Checking latest version for %s", ver.version)
if err := c.deleteVersionFromLatest(ctx, ver, possibleNewLatest); err != nil { if err := c.deleteVersionFromLatest(ctx, ver, possibleNewLatest); err != nil {
retErr = errors.Join(retErr, fmt.Errorf("updating latest version: %w", err)) retErr = errors.Join(retErr, fmt.Errorf("updating latest version: %w", err))
} }
c.Client.Log.Debugf("Deleting artifact path %s for %s", ver.ArtifactPath(APIV1), ver.version) c.Client.Logger.Debugf("Deleting artifact path %s for %s", ver.ArtifactPath(APIV1), ver.version)
if err := c.Client.DeletePath(ctx, ver.ArtifactPath(APIV1)); err != nil { if err := c.Client.DeletePath(ctx, ver.ArtifactPath(APIV1)); err != nil {
retErr = errors.Join(retErr, fmt.Errorf("deleting artifact path: %w", err)) retErr = errors.Join(retErr, fmt.Errorf("deleting artifact path: %w", err))
} }
@ -159,20 +159,20 @@ func (c *Client) deleteVersionFromMinorVersionList(ctx context.Context, ver Vers
Base: ver.WithGranularity(GranularityMinor), Base: ver.WithGranularity(GranularityMinor),
Kind: VersionKindImage, Kind: VersionKindImage,
} }
c.Client.Log.Debugf("Fetching minor version list for version %s", ver.version) c.Client.Logger.Debugf("Fetching minor version list for version %s", ver.version)
minorList, err := c.FetchVersionList(ctx, minorList) minorList, err := c.FetchVersionList(ctx, minorList)
var notFoundErr *apiclient.NotFoundError var notFoundErr *apiclient.NotFoundError
if errors.As(err, &notFoundErr) { if errors.As(err, &notFoundErr) {
c.Client.Log.Warnf("Minor version list for version %s not found", ver.version) c.Client.Logger.Warnf("Minor version list for version %s not found", ver.version)
c.Client.Log.Warnf("Skipping update of minor version list") c.Client.Logger.Warnf("Skipping update of minor version list")
return nil, nil return nil, nil
} else if err != nil { } else if err != nil {
return nil, fmt.Errorf("fetching minor version list for version %s: %w", ver.version, err) return nil, fmt.Errorf("fetching minor version list for version %s: %w", ver.version, err)
} }
if !minorList.Contains(ver.version) { if !minorList.Contains(ver.version) {
c.Client.Log.Warnf("Version %s is not in minor version list %s", ver.version, minorList.JSONPath()) c.Client.Logger.Warnf("Version %s is not in minor version list %s", ver.version, minorList.JSONPath())
c.Client.Log.Warnf("Skipping update of minor version list") c.Client.Logger.Warnf("Skipping update of minor version list")
return nil, nil return nil, nil
} }
@ -192,20 +192,20 @@ func (c *Client) deleteVersionFromMinorVersionList(ctx context.Context, ver Vers
Kind: VersionKindImage, Kind: VersionKindImage,
Version: minorList.Versions[len(minorList.Versions)-1], Version: minorList.Versions[len(minorList.Versions)-1],
} }
c.Client.Log.Debugf("Possible latest version replacement %q", latest.Version) c.Client.Logger.Debugf("Possible latest version replacement %q", latest.Version)
} }
if c.Client.DryRun { if c.Client.DryRun {
c.Client.Log.Debugf("DryRun: Updating minor version list %s to %v", minorList.JSONPath(), minorList) c.Client.Logger.Debugf("DryRun: Updating minor version list %s to %v", minorList.JSONPath(), minorList)
return latest, nil return latest, nil
} }
c.Client.Log.Debugf("Updating minor version list %s", minorList.JSONPath()) c.Client.Logger.Debugf("Updating minor version list %s", minorList.JSONPath())
if err := c.UpdateVersionList(ctx, minorList); err != nil { if err := c.UpdateVersionList(ctx, minorList); err != nil {
return latest, fmt.Errorf("updating minor version list %s: %w", minorList.JSONPath(), err) return latest, fmt.Errorf("updating minor version list %s: %w", minorList.JSONPath(), err)
} }
c.Client.Log.Debugf("Removed version %s from minor version list %s", ver.version, minorList.JSONPath()) c.Client.Logger.Debugf("Removed version %s from minor version list %s", ver.version, minorList.JSONPath())
return latest, nil return latest, nil
} }
@ -216,33 +216,33 @@ func (c *Client) deleteVersionFromLatest(ctx context.Context, ver Version, possi
Stream: ver.stream, Stream: ver.stream,
Kind: VersionKindImage, Kind: VersionKindImage,
} }
c.Client.Log.Debugf("Fetching latest version from %s", latest.JSONPath()) c.Client.Logger.Debugf("Fetching latest version from %s", latest.JSONPath())
latest, err := c.FetchVersionLatest(ctx, latest) latest, err := c.FetchVersionLatest(ctx, latest)
var notFoundErr *apiclient.NotFoundError var notFoundErr *apiclient.NotFoundError
if errors.As(err, &notFoundErr) { if errors.As(err, &notFoundErr) {
c.Client.Log.Warnf("Latest version for %s not found", latest.JSONPath()) c.Client.Logger.Warnf("Latest version for %s not found", latest.JSONPath())
return nil return nil
} else if err != nil { } else if err != nil {
return fmt.Errorf("fetching latest version: %w", err) return fmt.Errorf("fetching latest version: %w", err)
} }
if latest.Version != ver.version { if latest.Version != ver.version {
c.Client.Log.Debugf("Latest version is %s, not the deleted version %s", latest.Version, ver.version) c.Client.Logger.Debugf("Latest version is %s, not the deleted version %s", latest.Version, ver.version)
return nil return nil
} }
if possibleNewLatest == nil { if possibleNewLatest == nil {
c.Client.Log.Errorf("Latest version is %s, but no new latest version was found", latest.Version) c.Client.Logger.Errorf("Latest version is %s, but no new latest version was found", latest.Version)
c.Client.Log.Errorf("A manual update of latest at %s might be needed", latest.JSONPath()) c.Client.Logger.Errorf("A manual update of latest at %s might be needed", latest.JSONPath())
return fmt.Errorf("latest version is %s, but no new latest version was found", latest.Version) return fmt.Errorf("latest version is %s, but no new latest version was found", latest.Version)
} }
if c.Client.DryRun { if c.Client.DryRun {
c.Client.Log.Debugf("Would update latest version from %s to %s", latest.Version, possibleNewLatest.Version) c.Client.Logger.Debugf("Would update latest version from %s to %s", latest.Version, possibleNewLatest.Version)
return nil return nil
} }
c.Client.Log.Infof("Updating latest version from %s to %s", latest.Version, possibleNewLatest.Version) c.Client.Logger.Infof("Updating latest version from %s to %s", latest.Version, possibleNewLatest.Version)
if err := c.UpdateVersionLatest(ctx, *possibleNewLatest); err != nil { if err := c.UpdateVersionLatest(ctx, *possibleNewLatest); err != nil {
return fmt.Errorf("updating latest version: %w", err) return fmt.Errorf("updating latest version: %w", err)
} }

View File

@ -11,6 +11,7 @@ import (
"context" "context"
"io" "io"
"net/url" "net/url"
"time"
s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager" s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3"
@ -34,11 +35,12 @@ type Archivist struct {
// New creates a new Archivist. // New creates a new Archivist.
func New(ctx context.Context, region, bucket, distributionID string, log *logger.Logger) (*Archivist, CloseFunc, error) { func New(ctx context.Context, region, bucket, distributionID string, log *logger.Logger) (*Archivist, CloseFunc, error) {
staticUploadClient, staticUploadClientClose, err := staticupload.New(ctx, staticupload.Config{ staticUploadClient, staticUploadClientClose, err := staticupload.New(ctx, staticupload.Config{
Region: region, Region: region,
Bucket: bucket, Bucket: bucket,
DistributionID: distributionID, DistributionID: distributionID,
CacheInvalidationStrategy: staticupload.CacheInvalidateBatchOnFlush, CacheInvalidationStrategy: staticupload.CacheInvalidateBatchOnFlush,
}) CacheInvalidationWaitTimeout: 10 * time.Minute,
}, log)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View File

@ -13,6 +13,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/url" "net/url"
"time"
s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager" s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3"
@ -36,11 +37,12 @@ type Uploader struct {
// New creates a new Uploader. // New creates a new Uploader.
func New(ctx context.Context, region, bucket, distributionID string, log *logger.Logger) (*Uploader, CloseFunc, error) { func New(ctx context.Context, region, bucket, distributionID string, log *logger.Logger) (*Uploader, CloseFunc, error) {
staticUploadClient, staticUploadClientClose, err := staticupload.New(ctx, staticupload.Config{ staticUploadClient, staticUploadClientClose, err := staticupload.New(ctx, staticupload.Config{
Region: region, Region: region,
Bucket: bucket, Bucket: bucket,
DistributionID: distributionID, DistributionID: distributionID,
CacheInvalidationStrategy: staticupload.CacheInvalidateBatchOnFlush, CacheInvalidationStrategy: staticupload.CacheInvalidateBatchOnFlush,
}) CacheInvalidationWaitTimeout: 10 * time.Minute,
}, log)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View File

@ -13,6 +13,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/url" "net/url"
"time"
s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager" s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3"
@ -37,11 +38,12 @@ type Uploader struct {
// New creates a new Uploader. // New creates a new Uploader.
func New(ctx context.Context, region, bucket, distributionID string, log *logger.Logger) (*Uploader, CloseFunc, error) { func New(ctx context.Context, region, bucket, distributionID string, log *logger.Logger) (*Uploader, CloseFunc, error) {
staticUploadClient, staticUploadClientClose, err := staticupload.New(ctx, staticupload.Config{ staticUploadClient, staticUploadClientClose, err := staticupload.New(ctx, staticupload.Config{
Region: region, Region: region,
Bucket: bucket, Bucket: bucket,
DistributionID: distributionID, DistributionID: distributionID,
CacheInvalidationStrategy: staticupload.CacheInvalidateBatchOnFlush, CacheInvalidationStrategy: staticupload.CacheInvalidateBatchOnFlush,
}) CacheInvalidationWaitTimeout: 10 * time.Minute,
}, log)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View File

@ -13,6 +13,7 @@ go_library(
visibility = ["//:__subpackages__"], visibility = ["//:__subpackages__"],
deps = [ deps = [
"//internal/constants", "//internal/constants",
"//internal/logger",
"@com_github_aws_aws_sdk_go_v2_config//:config", "@com_github_aws_aws_sdk_go_v2_config//:config",
"@com_github_aws_aws_sdk_go_v2_feature_s3_manager//:manager", "@com_github_aws_aws_sdk_go_v2_feature_s3_manager//:manager",
"@com_github_aws_aws_sdk_go_v2_service_cloudfront//:cloudfront", "@com_github_aws_aws_sdk_go_v2_service_cloudfront//:cloudfront",
@ -27,6 +28,7 @@ go_test(
srcs = ["staticupload_test.go"], srcs = ["staticupload_test.go"],
embed = [":staticupload"], embed = [":staticupload"],
deps = [ deps = [
"//internal/logger",
"@com_github_aws_aws_sdk_go_v2_feature_s3_manager//:manager", "@com_github_aws_aws_sdk_go_v2_feature_s3_manager//:manager",
"@com_github_aws_aws_sdk_go_v2_service_cloudfront//:cloudfront", "@com_github_aws_aws_sdk_go_v2_service_cloudfront//:cloudfront",
"@com_github_aws_aws_sdk_go_v2_service_cloudfront//types", "@com_github_aws_aws_sdk_go_v2_service_cloudfront//types",

View File

@ -25,6 +25,7 @@ import (
cftypes "github.com/aws/aws-sdk-go-v2/service/cloudfront/types" cftypes "github.com/aws/aws-sdk-go-v2/service/cloudfront/types"
"github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/google/uuid" "github.com/google/uuid"
) )
@ -100,7 +101,7 @@ func (e *InvalidationError) Unwrap() error {
} }
// New creates a new Client. Call CloseFunc when done with operations. // New creates a new Client. Call CloseFunc when done with operations.
func New(ctx context.Context, config Config) (*Client, CloseFunc, error) { func New(ctx context.Context, config Config, log *logger.Logger) (*Client, CloseFunc, error) {
config.SetsDefault() config.SetsDefault()
cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(config.Region)) cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(config.Region))
if err != nil { if err != nil {
@ -121,6 +122,7 @@ func New(ctx context.Context, config Config) (*Client, CloseFunc, error) {
bucketID: config.Bucket, bucketID: config.Bucket,
logger: log, logger: log,
} }
return client, client.Flush, nil return client, client.Flush, nil
} }
@ -132,6 +134,7 @@ func (c *Client) Flush(ctx context.Context) error {
c.mux.Lock() c.mux.Lock()
defer c.mux.Unlock() defer c.mux.Unlock()
c.logger.Debugf("Invalidating keys: %s", c.dirtyKeys)
if len(c.dirtyKeys) == 0 { if len(c.dirtyKeys) == 0 {
return nil return nil
} }
@ -211,16 +214,17 @@ func (c *Client) invalidateCacheForKeys(ctx context.Context, keys []string) (str
// waitForInvalidations waits for all invalidations to finish. // waitForInvalidations waits for all invalidations to finish.
func (c *Client) waitForInvalidations(ctx context.Context) error { func (c *Client) waitForInvalidations(ctx context.Context) error {
if c.cacheInvalidationWaitTimeout == 0 { if c.cacheInvalidationWaitTimeout == 0 {
c.logger.Warnf("cacheInvalidationWaitTimeout set to 0, not waiting for invalidations to finish")
return nil return nil
} }
waiter := cloudfront.NewInvalidationCompletedWaiter(c.cdnClient) waiter := cloudfront.NewInvalidationCompletedWaiter(c.cdnClient)
c.logger.Debugf("Waiting for invalidations %s in distribution %s", c.invalidationIDs, c.distributionID)
for _, invalidationID := range c.invalidationIDs { for _, invalidationID := range c.invalidationIDs {
waitIn := &cloudfront.GetInvalidationInput{ waitIn := &cloudfront.GetInvalidationInput{
DistributionId: &c.distributionID, DistributionId: &c.distributionID,
Id: &invalidationID, Id: &invalidationID,
} }
c.logger.Debugf("Waiting for invalidation %s in distribution %s", invalidationID, c.distributionID)
if err := waiter.Wait(ctx, waitIn, c.cacheInvalidationWaitTimeout); err != nil { if err := waiter.Wait(ctx, waitIn, c.cacheInvalidationWaitTimeout); err != nil {
return NewInvalidationError(fmt.Errorf("waiting for invalidation to complete: %w", err)) return NewInvalidationError(fmt.Errorf("waiting for invalidation to complete: %w", err))
} }

View File

@ -21,6 +21,7 @@ import (
cftypes "github.com/aws/aws-sdk-go-v2/service/cloudfront/types" cftypes "github.com/aws/aws-sdk-go-v2/service/cloudfront/types"
"github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/s3"
s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" s3types "github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/goleak" "go.uber.org/goleak"
@ -105,6 +106,7 @@ func TestUpload(t *testing.T) {
distributionID: "test-distribution-id", distributionID: "test-distribution-id",
cacheInvalidationStrategy: tc.cacheInvalidationStrategy, cacheInvalidationStrategy: tc.cacheInvalidationStrategy,
cacheInvalidationWaitTimeout: tc.cacheInvalidationWaitTimeout, cacheInvalidationWaitTimeout: tc.cacheInvalidationWaitTimeout,
logger: logger.NewTest(t),
} }
_, err := client.Upload(context.Background(), tc.in) _, err := client.Upload(context.Background(), tc.in)
@ -216,6 +218,7 @@ func TestDeleteObject(t *testing.T) {
distributionID: "test-distribution-id", distributionID: "test-distribution-id",
cacheInvalidationStrategy: tc.cacheInvalidationStrategy, cacheInvalidationStrategy: tc.cacheInvalidationStrategy,
cacheInvalidationWaitTimeout: tc.cacheInvalidationWaitTimeout, cacheInvalidationWaitTimeout: tc.cacheInvalidationWaitTimeout,
logger: logger.NewTest(t),
} }
_, err := client.DeleteObject(context.Background(), newObjectInput(tc.nilInput, tc.nilKey)) _, err := client.DeleteObject(context.Background(), newObjectInput(tc.nilInput, tc.nilKey))
@ -254,6 +257,7 @@ func TestDeleteObject(t *testing.T) {
distributionID: "test-distribution-id", distributionID: "test-distribution-id",
cacheInvalidationStrategy: tc.cacheInvalidationStrategy, cacheInvalidationStrategy: tc.cacheInvalidationStrategy,
cacheInvalidationWaitTimeout: tc.cacheInvalidationWaitTimeout, cacheInvalidationWaitTimeout: tc.cacheInvalidationWaitTimeout,
logger: logger.NewTest(t),
} }
_, err := client.DeleteObjects(context.Background(), newObjectsInput(tc.nilInput, tc.nilKey)) _, err := client.DeleteObjects(context.Background(), newObjectsInput(tc.nilInput, tc.nilKey))
@ -395,6 +399,7 @@ func TestFlush(t *testing.T) {
cacheInvalidationWaitTimeout: tc.cacheInvalidationWaitTimeout, cacheInvalidationWaitTimeout: tc.cacheInvalidationWaitTimeout,
dirtyKeys: tc.dirtyKeys, dirtyKeys: tc.dirtyKeys,
invalidationIDs: tc.invalidationIDs, invalidationIDs: tc.invalidationIDs,
logger: logger.NewTest(t),
} }
err := client.Flush(context.Background()) err := client.Flush(context.Background())
@ -413,7 +418,7 @@ func TestFlush(t *testing.T) {
} }
} }
func TestConcurrency(_ *testing.T) { func TestConcurrency(t *testing.T) {
newInput := func() *s3.PutObjectInput { newInput := func() *s3.PutObjectInput {
return &s3.PutObjectInput{ return &s3.PutObjectInput{
Bucket: ptr("test-bucket"), Bucket: ptr("test-bucket"),
@ -432,6 +437,7 @@ func TestConcurrency(_ *testing.T) {
uploadClient: uploadClient, uploadClient: uploadClient,
distributionID: "test-distribution-id", distributionID: "test-distribution-id",
cacheInvalidationWaitTimeout: 50 * time.Millisecond, cacheInvalidationWaitTimeout: 50 * time.Millisecond,
logger: logger.NewTest(t),
} }
var wg sync.WaitGroup var wg sync.WaitGroup