From b3f251909c52eaa4893b8f8e36485d8d093bcd7e Mon Sep 17 00:00:00 2001 From: Paul Meyer <49727155+katexochen@users.noreply.github.com> Date: Thu, 29 Dec 2022 17:24:55 +0100 Subject: [PATCH] versionsapi: implement Go client Signed-off-by: Paul Meyer <49727155+katexochen@users.noreply.github.com> --- internal/versionsapi/client/client.go | 432 ++++++++++++++++++++++++++ 1 file changed, 432 insertions(+) create mode 100644 internal/versionsapi/client/client.go diff --git a/internal/versionsapi/client/client.go b/internal/versionsapi/client/client.go new file mode 100644 index 000000000..d43d3b005 --- /dev/null +++ b/internal/versionsapi/client/client.go @@ -0,0 +1,432 @@ +/* +Copyright (c) Edgeless Systems GmbH + +SPDX-License-Identifier: AGPL-3.0-only +*/ + +package client + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "path" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager" + "github.com/aws/aws-sdk-go-v2/service/cloudfront" + cftypes "github.com/aws/aws-sdk-go-v2/service/cloudfront/types" + "github.com/aws/aws-sdk-go-v2/service/s3" + s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" + "github.com/edgelesssys/constellation/v2/internal/constants" + "github.com/edgelesssys/constellation/v2/internal/logger" + "github.com/edgelesssys/constellation/v2/internal/versionsapi" + "go.uber.org/multierr" + "golang.org/x/mod/semver" +) + +// Client is the client for the versions API. +// +// The client needs to be authenticated with AWS. It is the interface that should +// be used in internal development and CI tools. +// +// Needed IAM permissions for read mode: +// - "s3:GetObject" +// - "s3:ListBucket" +// +// Additional needed IAM permissions for write mode: +// - "s3:PutObject" +// - "s3:DeleteObject" +// - "cloudfront:CreateInvalidation" +// +// Thread-safety of the bucket is not guaranteed. The client is not thread-safe. +type Client struct { + config aws.Config + cloudfrontClient *cloudfront.Client + s3Client *s3.Client + uploadClient *s3manager.Uploader + bucket string + distributionID string + cacheInvalidationWaitTimeout time.Duration + + dirtyPaths []string // written paths to be invalidated + dryRun bool // no write operations are performed + + log *logger.Logger +} + +// NewClient creates a new client for the versions API. +func NewClient(ctx context.Context, region, bucket, distributionID string, dryRun bool, + log *logger.Logger, +) (*Client, error) { + cfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(region)) + if err != nil { + return nil, err + } + + cloudfrontC := cloudfront.NewFromConfig(cfg) + s3C := s3.NewFromConfig(cfg) + uploadC := s3manager.NewUploader(s3C) + + return &Client{ + config: cfg, + cloudfrontClient: cloudfrontC, + s3Client: s3C, + uploadClient: uploadC, + bucket: bucket, + distributionID: distributionID, + dryRun: dryRun, + log: log, + cacheInvalidationWaitTimeout: 5 * time.Minute, + }, nil +} + +// FetchVersionList fetches the given version list from the versions API. +func (c *Client) FetchVersionList(ctx context.Context, list versionsapi.List) (versionsapi.List, error) { + return fetch(ctx, c, list) +} + +// UpdateVersionList updates the given version list in the versions API. +func (c *Client) UpdateVersionList(ctx context.Context, list versionsapi.List) error { + semver.Sort(list.Versions) + return update(ctx, c, list) +} + +// FetchVersionLatest fetches the latest version from the versions API. +func (c *Client) FetchVersionLatest(ctx context.Context, latest versionsapi.Latest) (versionsapi.Latest, error) { + return fetch(ctx, c, latest) +} + +// UpdateVersionLatest updates the latest version in the versions API. +func (c *Client) UpdateVersionLatest(ctx context.Context, latest versionsapi.Latest) error { + return update(ctx, c, latest) +} + +// FetchImageInfo fetches the given image info from the versions API. +func (c *Client) FetchImageInfo(ctx context.Context, imageInfo versionsapi.ImageInfo) (versionsapi.ImageInfo, error) { + return fetch(ctx, c, imageInfo) +} + +// UpdateImageInfo updates the given image info in the versions API. +func (c *Client) UpdateImageInfo(ctx context.Context, imageInfo versionsapi.ImageInfo) error { + return update(ctx, c, imageInfo) +} + +// DeleteRef deletes the given ref from the versions API. +func (c *Client) DeleteRef(ctx context.Context, ref string) error { + if err := versionsapi.ValidateRef(ref); err != nil { + return fmt.Errorf("validating ref: %w", err) + } + + refPath := path.Join(constants.CDNAPIPrefix, "ref", ref) + if err := c.deletePath(ctx, refPath); err != nil { + return fmt.Errorf("deleting ref path: %w", err) + } + + return nil +} + +// DeleteVersion deletes the given version from the versions API. +// The version will be removed from version lists and latest versions, and the versioned +// objects are deleted. +// Notice that the versions API can get into an inconsistent state if the version is the latest +// version but there is no older version of the same minor version available. +// Manual update of latest versions is required in this case. +func (c *Client) DeleteVersion(ctx context.Context, ver versionsapi.Version) error { + var retErr error + + c.log.Debugf("Deleting version %s from minor version list", ver.Version) + possibleNewLatest, err := c.deleteVersionFromMinorVersionList(ctx, ver) + if err != nil { + retErr = multierr.Append(retErr, fmt.Errorf("removing from minor version list: %w", err)) + } + + c.log.Debugf("Checking latest version for %s", ver.Version) + if err := c.deleteVersionFromLatest(ctx, ver, possibleNewLatest); err != nil { + retErr = multierr.Append(retErr, fmt.Errorf("updating latest version: %w", err)) + } + + return retErr +} + +// InvalidateCache invalidates the CDN cache for the paths that have been written. +// The function should be deferred after the client has been created. +func (c *Client) InvalidateCache(ctx context.Context) error { + if len(c.dirtyPaths) == 0 { + c.log.Debugf("No dirty paths, skipping cache invalidation.") + return nil + } + + if c.dryRun { + c.log.Debugf("DryRun: cloudfront create invalidation {DistributionID: %v, Paths: %v}", c.distributionID, c.dirtyPaths) + return nil + } + + c.log.Debugf("Paths to invalidate: %v", c.dirtyPaths) + + in := &cloudfront.CreateInvalidationInput{ + DistributionId: &c.distributionID, + InvalidationBatch: &cftypes.InvalidationBatch{ + CallerReference: ptr(fmt.Sprintf("%d", time.Now().Unix())), + Paths: &cftypes.Paths{ + Items: c.dirtyPaths, + Quantity: ptr(int32(len(c.dirtyPaths))), + }, + }, + } + invalidation, err := c.cloudfrontClient.CreateInvalidation(ctx, in) + if err != nil { + return fmt.Errorf("creating invalidation: %w", err) + } + + c.log.Debugf("Waiting for invalidation %s to complete.", *invalidation.Invalidation.Id) + waiter := cloudfront.NewInvalidationCompletedWaiter(c.cloudfrontClient) + waitIn := &cloudfront.GetInvalidationInput{ + DistributionId: &c.distributionID, + Id: invalidation.Invalidation.Id, + } + if err := waiter.Wait(ctx, waitIn, c.cacheInvalidationWaitTimeout); err != nil { + return fmt.Errorf("waiting for invalidation to complete: %w", err) + } + + return nil +} + +type apiObject interface { + ValidateRequest() error + Validate() error + JSONPath() string +} + +func fetch[T apiObject](ctx context.Context, c *Client, obj T) (T, error) { + if err := obj.ValidateRequest(); err != nil { + return *new(T), fmt.Errorf("validating request for %T: %w", obj, err) + } + + in := &s3.GetObjectInput{ + Bucket: &c.bucket, + Key: ptr(obj.JSONPath()), + } + + c.log.Debugf("Fetching %T from s3: %s", obj, obj.JSONPath()) + out, err := c.s3Client.GetObject(ctx, in) + var noSuchkey *s3types.NoSuchKey + if errors.As(err, &noSuchkey) { + return *new(T), &NotFoundError{err: err} + } else if err != nil { + return *new(T), fmt.Errorf("getting s3 object at %s: %w", obj.JSONPath(), err) + } + defer out.Body.Close() + + var newObj T + if err := json.NewDecoder(out.Body).Decode(&newObj); err != nil { + return *new(T), fmt.Errorf("decoding %T: %w", obj, err) + } + + if newObj.Validate() != nil { + return *new(T), fmt.Errorf("received invalid %T: %w", newObj, newObj.Validate()) + } + + return newObj, nil +} + +func update[T apiObject](ctx context.Context, c *Client, obj T) error { + if err := obj.Validate(); err != nil { + return fmt.Errorf("validating %T struct: %w", obj, err) + } + + rawJSON, err := json.Marshal(obj) + if err != nil { + return fmt.Errorf("marshaling %T struct: %w", obj, err) + } + + if c.dryRun { + c.log.Debugf("DryRun: s3 put object {Bucket: %v, Key: %v, Body: %v", c.bucket, obj.JSONPath(), string(rawJSON)) + return nil + } + + in := &s3.PutObjectInput{ + Bucket: &c.bucket, + Key: ptr(obj.JSONPath()), + Body: bytes.NewBuffer(rawJSON), + } + + c.dirtyPaths = append(c.dirtyPaths, "/"+obj.JSONPath()) + + c.log.Debugf("Uploading %T to s3: %v", obj, obj.JSONPath()) + if _, err := c.uploadClient.Upload(ctx, in); err != nil { + return fmt.Errorf("uploading %T: %w", obj, err) + } + + return nil +} + +func (c *Client) deleteVersionFromMinorVersionList(ctx context.Context, ver versionsapi.Version, +) (*versionsapi.Latest, error) { + minorList := versionsapi.List{ + Ref: ver.Ref, + Stream: ver.Stream, + Granularity: versionsapi.GranularityMinor, + Base: ver.WithGranularity(versionsapi.GranularityMinor), + Kind: versionsapi.VersionKindImage, + } + c.log.Debugf("Fetching minor version list for version %s", ver.Version) + minorList, err := c.FetchVersionList(ctx, minorList) + var notFoundErr *NotFoundError + if errors.As(err, ¬FoundErr) { + c.log.Warnf("Minor version list for version %s not found", ver.Version) + c.log.Warnf("Skipping update of minor version list") + return nil, nil + } else if err != nil { + return nil, fmt.Errorf("fetching minor version list for version %s: %w", ver.Version, err) + } + + if !minorList.Contains(ver.Version) { + c.log.Warnf("Version %s is not in minor version list %s", ver.Version, minorList.JSONPath()) + c.log.Warnf("Skipping update of minor version list") + return nil, nil + } + + semver.Sort(minorList.Versions) + for i, v := range minorList.Versions { + if v == ver.Version { + minorList.Versions = append(minorList.Versions[:i], minorList.Versions[i+1:]...) + break + } + } + + var latest *versionsapi.Latest + if len(minorList.Versions) != 0 { + latest = &versionsapi.Latest{ + Ref: ver.Ref, + Stream: ver.Stream, + Kind: versionsapi.VersionKindImage, + Version: minorList.Versions[len(minorList.Versions)-1], + } + c.log.Debugf("Possible latest version replacement %q", latest.Version) + } + + if c.dryRun { + c.log.Debugf("DryRun: Updating minor version list %s to %v", minorList.JSONPath(), minorList) + return latest, nil + } + + c.log.Debugf("Updating minor version list %s", minorList.JSONPath()) + if err := c.UpdateVersionList(ctx, minorList); err != nil { + return latest, fmt.Errorf("updating minor version list %s: %w", minorList.JSONPath(), err) + } + + c.log.Debugf("Removed version %s from minor version list %s", ver.Version, minorList.JSONPath()) + return latest, nil +} + +func (c *Client) deleteVersionFromLatest(ctx context.Context, ver versionsapi.Version, possibleNewLatest *versionsapi.Latest, +) error { + latest := versionsapi.Latest{ + Ref: ver.Ref, + Stream: ver.Stream, + Kind: versionsapi.VersionKindImage, + } + c.log.Debugf("Fetching latest version from %s.", latest.JSONPath()) + latest, err := c.FetchVersionLatest(ctx, latest) + var notFoundErr *NotFoundError + if errors.As(err, ¬FoundErr) { + c.log.Warnf("Latest version for %s not found.", latest.JSONPath()) + return nil + } else if err != nil { + return fmt.Errorf("fetching latest version: %w", err) + } + + if latest.Version != ver.Version { + c.log.Debugf("Latest version is %s, not the deleted version %s.", latest.Version, ver.Version) + return nil + } + + if possibleNewLatest == nil { + c.log.Errorf("Latest version is %s, but no new latest version was found.", latest.Version) + c.log.Errorf("A manual update of latest at %s might be needed", latest.JSONPath()) + return fmt.Errorf("latest version is %s, but no new latest version was found", latest.Version) + } + + if c.dryRun { + c.log.Debugf("Would update latest version from %s to %s", latest.Version, possibleNewLatest.Version) + return nil + } + + c.log.Infof("Updating latest version from %s to %s", latest.Version, possibleNewLatest.Version) + if err := c.UpdateVersionLatest(ctx, *possibleNewLatest); err != nil { + return fmt.Errorf("updating latest version: %w", err) + } + + return nil +} + +func (c *Client) deletePath(ctx context.Context, path string) error { + listIn := &s3.ListObjectsV2Input{ + Bucket: &c.bucket, + Prefix: &path, + } + c.log.Debugf("Listing objects in %s", path) + objs := []s3types.Object{} + out := &s3.ListObjectsV2Output{IsTruncated: true} + for out.IsTruncated { + var err error + out, err = c.s3Client.ListObjectsV2(ctx, listIn) + if err != nil { + return fmt.Errorf("listing objects in %s: %w", path, err) + } + objs = append(objs, out.Contents...) + } + c.log.Debugf("Found %d objects in %s", len(objs), path) + + if len(objs) == 0 { + c.log.Warnf("Path %s is already empty", path) + return nil + } + + objIDs := make([]s3types.ObjectIdentifier, len(objs)) + for i, obj := range objs { + objIDs[i] = s3types.ObjectIdentifier{Key: obj.Key} + } + + if c.dryRun { + c.log.Debugf("DryRun: Deleting %d objects with IDs %v", len(objs), objIDs) + return nil + } + + c.dirtyPaths = append(c.dirtyPaths, "/"+path) + + deleteIn := &s3.DeleteObjectsInput{ + Bucket: &c.bucket, + Delete: &s3types.Delete{ + Objects: objIDs, + }, + } + c.log.Debugf("Deleting %d objects in %s", len(objs), path) + if _, err := c.s3Client.DeleteObjects(ctx, deleteIn); err != nil { + return fmt.Errorf("deleting objects in %s: %w", path, err) + } + + return nil +} + +// NotFoundError is an error that is returned when a resource is not found. +type NotFoundError struct { + err error +} + +func (e *NotFoundError) Error() string { + return fmt.Sprintf("the requested resource was not found: %s", e.err.Error()) +} + +func (e *NotFoundError) Unwrap() error { + return e.err +} + +func ptr[T any](t T) *T { + return &t +}