cli: image info (v2)

This commit is contained in:
Malte Poll 2023-05-23 09:17:27 +02:00 committed by Malte Poll
parent cd7b116794
commit d0e53cbb59
37 changed files with 429 additions and 461 deletions

View file

@ -17,7 +17,6 @@ go_library(
deps = [
"//cli/internal/clusterid",
"//cli/internal/iamid",
"//cli/internal/image",
"//cli/internal/libvirt",
"//cli/internal/terraform",
"//internal/atls",
@ -27,6 +26,7 @@ go_library(
"//internal/cloud/gcpshared",
"//internal/config",
"//internal/constants",
"//internal/imagefetcher",
"//internal/variant",
"@com_github_azure_azure_sdk_for_go//profiles/latest/attestation/attestation",
"@com_github_azure_azure_sdk_for_go_sdk_azcore//policy",
@ -54,6 +54,7 @@ go_test(
"//internal/cloud/cloudprovider",
"//internal/cloud/gcpshared",
"//internal/config",
"//internal/variant",
"@com_github_hashicorp_terraform_json//:terraform-json",
"@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//require",

View file

@ -12,13 +12,16 @@ import (
"github.com/edgelesssys/constellation/v2/cli/internal/terraform"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/variant"
tfjson "github.com/hashicorp/terraform-json"
)
// imageFetcher gets an image reference from the versionsapi.
type imageFetcher interface {
FetchReference(ctx context.Context, config *config.Config) (string, error)
FetchReference(ctx context.Context,
provider cloudprovider.Provider, attestationVariant variant.Variant,
image, region string,
) (string, error)
}
type terraformClient interface {

View file

@ -13,7 +13,7 @@ import (
"github.com/edgelesssys/constellation/v2/cli/internal/terraform"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/variant"
tfjson "github.com/hashicorp/terraform-json"
"go.uber.org/goleak"
@ -103,7 +103,10 @@ type stubImageFetcher struct {
fetchReferenceErr error
}
func (f *stubImageFetcher) FetchReference(_ context.Context, _ *config.Config) (string, error) {
func (f *stubImageFetcher) FetchReference(_ context.Context,
_ cloudprovider.Provider, _ variant.Variant,
_, _ string,
) (string, error) {
return f.reference, f.fetchReferenceErr
}

View file

@ -25,12 +25,12 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/edgelesssys/constellation/v2/cli/internal/clusterid"
"github.com/edgelesssys/constellation/v2/cli/internal/image"
"github.com/edgelesssys/constellation/v2/cli/internal/libvirt"
"github.com/edgelesssys/constellation/v2/cli/internal/terraform"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/imagefetcher"
"github.com/edgelesssys/constellation/v2/internal/variant"
)
@ -48,7 +48,7 @@ type Creator struct {
func NewCreator(out io.Writer) *Creator {
return &Creator{
out: out,
image: image.New(),
image: imagefetcher.New(),
newTerraformClient: func(ctx context.Context) (terraformClient, error) {
return terraform.New(ctx, constants.TerraformWorkingDir)
},
@ -56,7 +56,7 @@ func NewCreator(out io.Writer) *Creator {
return libvirt.New()
},
newRawDownloader: func() rawDownloader {
return image.NewDownloader()
return imagefetcher.NewDownloader()
},
policyPatcher: policyPatcher{},
}
@ -75,7 +75,10 @@ type CreateOptions struct {
// Create creates the handed amount of instances and all the needed resources.
func (c *Creator) Create(ctx context.Context, opts CreateOptions) (clusterid.File, error) {
image, err := c.image.FetchReference(ctx, opts.Config)
provider := opts.Config.GetProvider()
attestationVariant := opts.Config.GetAttestationConfig().GetVariant()
region := opts.Config.GetRegion()
image, err := c.image.FetchReference(ctx, provider, attestationVariant, opts.Config.Image, region)
if err != nil {
return clusterid.File{}, fmt.Errorf("fetching image reference: %w", err)
}

View file

@ -41,7 +41,6 @@ go_library(
"//cli/internal/clusterid",
"//cli/internal/helm",
"//cli/internal/iamid",
"//cli/internal/image",
"//cli/internal/kubernetes",
"//cli/internal/libvirt",
"//cli/internal/terraform",
@ -62,6 +61,7 @@ go_library(
"//internal/file",
"//internal/grpc/dialer",
"//internal/grpc/retry",
"//internal/imagefetcher",
"//internal/kms/uri",
"//internal/kubernetes/kubectl",
"//internal/license",

View file

@ -16,7 +16,6 @@ import (
"github.com/edgelesssys/constellation/v2/cli/internal/clusterid"
"github.com/edgelesssys/constellation/v2/cli/internal/helm"
"github.com/edgelesssys/constellation/v2/cli/internal/image"
"github.com/edgelesssys/constellation/v2/cli/internal/kubernetes"
"github.com/edgelesssys/constellation/v2/cli/internal/terraform"
"github.com/edgelesssys/constellation/v2/cli/internal/upgrade"
@ -25,6 +24,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/imagefetcher"
"github.com/edgelesssys/constellation/v2/internal/variant"
"github.com/spf13/afero"
"github.com/spf13/cobra"
@ -65,7 +65,7 @@ func runUpgradeApply(cmd *cobra.Command, _ []string) error {
return err
}
fetcher := image.New()
fetcher := imagefetcher.New()
applyCmd := upgradeApplyCmd{upgrader: upgrader, log: log, fetcher: fetcher}
return applyCmd.upgradeApply(cmd, fileHandler)
@ -194,7 +194,10 @@ func (u *upgradeApplyCmd) migrateTerraform(cmd *cobra.Command, file file.Handler
func (u *upgradeApplyCmd) parseUpgradeVars(cmd *cobra.Command, conf *config.Config, fetcher imageFetcher) ([]string, terraform.Variables, error) {
// Fetch variables to execute Terraform script with
imageRef, err := fetcher.FetchReference(cmd.Context(), conf)
provider := conf.GetProvider()
attestationVariant := conf.GetAttestationConfig().GetVariant()
region := conf.GetRegion()
imageRef, err := fetcher.FetchReference(cmd.Context(), provider, attestationVariant, conf.Image, region)
if err != nil {
return nil, nil, fmt.Errorf("fetching image reference: %w", err)
}
@ -264,7 +267,10 @@ func (u *upgradeApplyCmd) parseUpgradeVars(cmd *cobra.Command, conf *config.Conf
}
type imageFetcher interface {
FetchReference(ctx context.Context, conf *config.Config) (string, error)
FetchReference(ctx context.Context,
provider cloudprovider.Provider, attestationVariant variant.Variant,
image, region string,
) (string, error)
}
// upgradeAttestConfigIfDiff checks if the locally configured measurements are different from the cluster's measurements.

View file

@ -200,6 +200,9 @@ type stubImageFetcher struct {
fetchReferenceErr error
}
func (s stubImageFetcher) FetchReference(context.Context, *config.Config) (string, error) {
return "", s.fetchReferenceErr
func (f stubImageFetcher) FetchReference(_ context.Context,
_ cloudprovider.Provider, _ variant.Variant,
_, _ string,
) (string, error) {
return "", f.fetchReferenceErr
}

View file

@ -1,40 +0,0 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
load("//bazel/go:go_test.bzl", "go_test")
go_library(
name = "image",
srcs = [
"image.go",
"raw.go",
],
importpath = "github.com/edgelesssys/constellation/v2/cli/internal/image",
visibility = ["//cli:__subpackages__"],
deps = [
"//internal/cloud/cloudprovider",
"//internal/config",
"//internal/variant",
"//internal/versionsapi",
"//internal/versionsapi/fetcher",
"@com_github_schollz_progressbar_v3//:progressbar",
"@com_github_spf13_afero//:afero",
],
)
go_test(
name = "image_test",
srcs = [
"image_test.go",
"raw_test.go",
],
embed = [":image"],
deps = [
"//internal/cloud/cloudprovider",
"//internal/config",
"//internal/file",
"//internal/versionsapi",
"@com_github_spf13_afero//:afero",
"@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//require",
"@org_uber_go_goleak//:goleak",
],
)

View file

@ -1,167 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
/*
Package image provides helping wrappers around a versionsapi fetcher.
It also enables local image overrides and download of raw images.
*/
package image
import (
"context"
"encoding/json"
"errors"
"fmt"
"io/fs"
"regexp"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/variant"
"github.com/edgelesssys/constellation/v2/internal/versionsapi"
"github.com/edgelesssys/constellation/v2/internal/versionsapi/fetcher"
"github.com/spf13/afero"
)
// Fetcher fetches image references using a lookup table.
type Fetcher struct {
fetcher versionsAPIImageInfoFetcher
fs *afero.Afero
}
// New returns a new image fetcher.
func New() *Fetcher {
return &Fetcher{
fetcher: fetcher.NewFetcher(),
fs: &afero.Afero{Fs: afero.NewOsFs()},
}
}
// FetchReference fetches the image reference for a given image version uid, CSP and image variant.
func (f *Fetcher) FetchReference(ctx context.Context, config *config.Config) (string, error) {
provider := config.GetProvider()
variant, err := imageVariant(provider, config)
if err != nil {
return "", fmt.Errorf("determining variant: %w", err)
}
ver, err := versionsapi.NewVersionFromShortPath(config.Image, versionsapi.VersionKindImage)
if err != nil {
return "", fmt.Errorf("parsing config image short path: %w", err)
}
imgInfoReq := versionsapi.ImageInfo{
Ref: ver.Ref,
Stream: ver.Stream,
Version: ver.Version,
}
url, err := imgInfoReq.URL()
if err != nil {
return "", err
}
imgInfo, err := getFromFile(f.fs, imgInfoReq)
if err != nil && errors.Is(err, fs.ErrNotExist) {
imgInfo, err = f.fetcher.FetchImageInfo(ctx, imgInfoReq)
}
var notFoundErr *fetcher.NotFoundError
switch {
case errors.As(err, &notFoundErr):
overridePath := imageInfoFilename(imgInfoReq)
return "", fmt.Errorf("image info file not found locally at %q or remotely at %s", overridePath, url)
case err != nil:
return "", err
}
if err := imgInfo.Validate(); err != nil {
return "", fmt.Errorf("validating image info file: %w", err)
}
return getReferenceFromImageInfo(provider, variant, imgInfo)
}
// imageVariant returns the image variant for a given CSP and configuration.
func imageVariant(provider cloudprovider.Provider, config *config.Config) (string, error) {
switch provider {
case cloudprovider.AWS:
return config.Provider.AWS.Region, nil
case cloudprovider.Azure:
if config.GetAttestationConfig().GetVariant().Equal(variant.AzureTrustedLaunch{}) {
return "trustedlaunch", nil
}
return "cvm", nil
case cloudprovider.GCP:
return "sev-es", nil
case cloudprovider.OpenStack:
return "sev", nil
case cloudprovider.QEMU:
return "default", nil
default:
return "", fmt.Errorf("unsupported provider: %s", provider)
}
}
func getFromFile(fs *afero.Afero, imgInfo versionsapi.ImageInfo) (versionsapi.ImageInfo, error) {
fileName := imageInfoFilename(imgInfo)
raw, err := fs.ReadFile(fileName)
if err != nil {
return versionsapi.ImageInfo{}, err
}
var newInfo versionsapi.ImageInfo
if err := json.Unmarshal(raw, &newInfo); err != nil {
return versionsapi.ImageInfo{}, fmt.Errorf("decoding image info file: %w", err)
}
return newInfo, nil
}
var filenameReplaceRegexp = regexp.MustCompile(`([^a-zA-Z0-9.-])`)
func imageInfoFilename(imgInfo versionsapi.ImageInfo) string {
path := imgInfo.JSONPath()
return filenameReplaceRegexp.ReplaceAllString(path, "_")
}
// getReferenceFromImageInfo returns the image reference for a given CSP and image variant.
func getReferenceFromImageInfo(provider cloudprovider.Provider, variant string, imgInfo versionsapi.ImageInfo,
) (string, error) {
var providerList map[string]string
switch provider {
case cloudprovider.AWS:
providerList = imgInfo.AWS
case cloudprovider.Azure:
providerList = imgInfo.Azure
case cloudprovider.GCP:
providerList = imgInfo.GCP
case cloudprovider.OpenStack:
providerList = imgInfo.OpenStack
case cloudprovider.QEMU:
providerList = imgInfo.QEMU
default:
return "", fmt.Errorf("image not available in image info for CSP %q", provider.String())
}
if providerList == nil {
return "", fmt.Errorf("image not available in image info for CSP %q", provider.String())
}
ref, ok := providerList[variant]
if !ok {
return "", fmt.Errorf("image not available in image info for variant %q", variant)
}
return ref, nil
}
type versionsAPIImageInfoFetcher interface {
FetchImageInfo(ctx context.Context, imageInfo versionsapi.ImageInfo) (versionsapi.ImageInfo, error)
}

View file

@ -1,348 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package image
import (
"context"
"encoding/json"
"errors"
"net/http"
"testing"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/versionsapi"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestGetReference(t *testing.T) {
testCases := map[string]struct {
info versionsapi.ImageInfo
provider cloudprovider.Provider
variant string
wantReference string
wantErr bool
}{
"reference exists aws": {
info: versionsapi.ImageInfo{AWS: map[string]string{"someVariant": "someReference"}},
provider: cloudprovider.AWS,
variant: "someVariant",
wantReference: "someReference",
},
"reference exists azure": {
info: versionsapi.ImageInfo{Azure: map[string]string{"someVariant": "someReference"}},
provider: cloudprovider.Azure,
variant: "someVariant",
wantReference: "someReference",
},
"reference exists gcp": {
info: versionsapi.ImageInfo{GCP: map[string]string{"someVariant": "someReference"}},
provider: cloudprovider.GCP,
variant: "someVariant",
wantReference: "someReference",
},
"reference exists openstack": {
info: versionsapi.ImageInfo{OpenStack: map[string]string{"someVariant": "someReference"}},
provider: cloudprovider.OpenStack,
variant: "someVariant",
wantReference: "someReference",
},
"reference exists qemu": {
info: versionsapi.ImageInfo{QEMU: map[string]string{"someVariant": "someReference"}},
provider: cloudprovider.QEMU,
variant: "someVariant",
wantReference: "someReference",
},
"csp does not exist": {
info: versionsapi.ImageInfo{AWS: map[string]string{"someVariant": "someReference"}},
provider: cloudprovider.Unknown,
variant: "someVariant",
wantErr: true,
},
"variant does not exist": {
info: versionsapi.ImageInfo{AWS: map[string]string{"someVariant": "someReference"}},
provider: cloudprovider.AWS,
variant: "nonExistingVariant",
wantErr: true,
},
"info is empty": {
info: versionsapi.ImageInfo{},
provider: cloudprovider.AWS,
variant: "someVariant",
wantErr: true,
},
"csp is nil": {
info: versionsapi.ImageInfo{AWS: nil},
provider: cloudprovider.AWS,
variant: "someVariant",
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
reference, err := getReferenceFromImageInfo(tc.provider, tc.variant, tc.info)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantReference, reference)
})
}
}
func TestImageVariant(t *testing.T) {
testCases := map[string]struct {
csp cloudprovider.Provider
config *config.Config
wantVariant string
wantErr bool
}{
"AWS region": {
csp: cloudprovider.AWS,
config: &config.Config{Image: "someImage", Provider: config.ProviderConfig{
AWS: &config.AWSConfig{Region: "someRegion"},
}},
wantVariant: "someRegion",
},
"Azure cvm": {
csp: cloudprovider.Azure,
config: &config.Config{
Image: "someImage", Provider: config.ProviderConfig{Azure: &config.AzureConfig{}},
Attestation: config.AttestationConfig{AzureSEVSNP: &config.AzureSEVSNP{}},
},
wantVariant: "cvm",
},
"Azure trustedlaunch": {
csp: cloudprovider.Azure,
config: &config.Config{
Image: "someImage", Provider: config.ProviderConfig{Azure: &config.AzureConfig{}},
Attestation: config.AttestationConfig{AzureTrustedLaunch: &config.AzureTrustedLaunch{}},
},
wantVariant: "trustedlaunch",
},
"GCP": {
csp: cloudprovider.GCP,
config: &config.Config{Image: "someImage", Provider: config.ProviderConfig{
GCP: &config.GCPConfig{},
}},
wantVariant: "sev-es",
},
"OpenStack": {
csp: cloudprovider.OpenStack,
config: &config.Config{Image: "someImage", Provider: config.ProviderConfig{
OpenStack: &config.OpenStackConfig{},
}},
wantVariant: "sev",
},
"QEMU": {
csp: cloudprovider.QEMU,
config: &config.Config{Image: "someImage", Provider: config.ProviderConfig{
QEMU: &config.QEMUConfig{},
}},
wantVariant: "default",
},
"invalid": {
csp: cloudprovider.Provider(9999),
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
vari, err := imageVariant(tc.csp, tc.config)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantVariant, vari)
})
}
}
func TestFetchReference(t *testing.T) {
img := "ref/abc/stream/nightly/v1.2.3"
newImgInfo := func() versionsapi.ImageInfo {
return versionsapi.ImageInfo{
Ref: "abc",
Stream: "nightly",
Version: "v1.2.3",
QEMU: map[string]string{"default": "someReference"},
AWS: map[string]string{"foo": "bar"},
Azure: map[string]string{"foo": "bar"},
GCP: map[string]string{"foo": "bar"},
}
}
imgInfoPath := imageInfoFilename(newImgInfo())
testCases := map[string]struct {
config *config.Config
imageInfoFetcher versionsAPIImageInfoFetcher
localFile []byte
wantReference string
wantErr bool
}{
"reference fetched remotely": {
config: &config.Config{
Image: img,
Provider: config.ProviderConfig{QEMU: &config.QEMUConfig{}},
},
imageInfoFetcher: &stubVersionsAPIImageFetcher{
fetchImageInfoInfo: newImgInfo(),
},
wantReference: "someReference",
},
"reference fetched remotely fails": {
config: &config.Config{
Image: img,
Provider: config.ProviderConfig{QEMU: &config.QEMUConfig{}},
},
imageInfoFetcher: &stubVersionsAPIImageFetcher{
fetchImageInfoErr: errors.New("failed"),
},
wantErr: true,
},
"reference fetched locally": {
config: &config.Config{
Image: img,
Provider: config.ProviderConfig{QEMU: &config.QEMUConfig{}},
},
localFile: func() []byte {
info := newImgInfo()
info.QEMU["default"] = "localOverrideReference"
file, err := json.Marshal(info)
require.NoError(t, err)
return file
}(),
wantReference: "localOverrideReference",
},
"local file first": {
config: &config.Config{
Image: img,
Provider: config.ProviderConfig{QEMU: &config.QEMUConfig{}},
},
imageInfoFetcher: &stubVersionsAPIImageFetcher{
fetchImageInfoInfo: newImgInfo(),
},
localFile: func() []byte {
info := newImgInfo()
info.QEMU["default"] = "localOverrideReference"
file, err := json.Marshal(info)
require.NoError(t, err)
return file
}(),
wantReference: "localOverrideReference",
},
"local file is invalid": {
config: &config.Config{
Image: img,
Provider: config.ProviderConfig{QEMU: &config.QEMUConfig{}},
},
localFile: []byte("invalid"),
wantErr: true,
},
"local file has invalid image info": {
config: &config.Config{
Image: img,
Provider: config.ProviderConfig{QEMU: &config.QEMUConfig{}},
},
localFile: func() []byte {
info := newImgInfo()
info.Ref = ""
file, err := json.Marshal(info)
require.NoError(t, err)
return file
}(),
wantErr: true,
},
"image version does not exist": {
config: &config.Config{
Image: "nonExistingImageName",
Provider: config.ProviderConfig{QEMU: &config.QEMUConfig{}},
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fs := afero.NewMemMapFs()
af := &afero.Afero{Fs: fs}
if tc.localFile != nil {
fh := file.NewHandler(af)
require.NoError(fh.Write(imgInfoPath, tc.localFile))
}
fetcher := &Fetcher{
fetcher: tc.imageInfoFetcher,
fs: af,
}
reference, err := fetcher.FetchReference(context.Background(), tc.config)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantReference, reference)
})
}
}
type stubVersionsAPIImageFetcher struct {
fetchImageInfoInfo versionsapi.ImageInfo
fetchImageInfoErr error
}
func (f *stubVersionsAPIImageFetcher) FetchImageInfo(_ context.Context, _ versionsapi.ImageInfo) (
versionsapi.ImageInfo, error,
) {
return f.fetchImageInfoInfo, f.fetchImageInfoErr
}
func must(t *testing.T, err error) {
t.Helper()
if err != nil {
t.Fatal(err)
}
}
// roundTripFunc .
type roundTripFunc func(req *http.Request) *http.Response
// RoundTrip .
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req), nil
}
// newTestClient returns *http.Client with Transport replaced to avoid making real calls.
func newTestClient(fn roundTripFunc) *http.Client {
return &http.Client{
Transport: fn,
}
}

View file

@ -1,143 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package image
import (
"context"
"errors"
"fmt"
"io"
"io/fs"
"net/http"
"net/url"
"os"
"path/filepath"
"github.com/schollz/progressbar/v3"
"github.com/spf13/afero"
)
// Downloader downloads raw images.
type Downloader struct {
httpc httpc
fs *afero.Afero
}
// NewDownloader creates a new Downloader.
func NewDownloader() *Downloader {
return &Downloader{
httpc: http.DefaultClient,
fs: &afero.Afero{Fs: afero.NewOsFs()},
}
}
// Download downloads the raw image from source.
func (d *Downloader) Download(ctx context.Context, errWriter io.Writer, showBar bool, source, imageName string) (string, error) {
url, err := url.Parse(source)
if err != nil {
return "", fmt.Errorf("parsing image source URL: %w", err)
}
imageName = filepath.Base(imageName)
var partfile, destination string
switch url.Scheme {
case "http", "https":
cwd, err := os.Getwd()
if err != nil {
return "", fmt.Errorf("getting current working directory: %w", err)
}
partfile = filepath.Join(cwd, imageName+".raw.part")
destination = filepath.Join(cwd, imageName+".raw")
case "file":
return url.Path, nil
default:
return "", fmt.Errorf("unsupported image source URL scheme: %s", url.Scheme)
}
if !d.shouldDownload(destination) {
return destination, nil
}
if err := d.downloadWithProgress(ctx, errWriter, showBar, source, partfile); err != nil {
return "", err
}
return destination, d.fs.Rename(partfile, destination)
}
// shouldDownload checks if the image should be downloaded.
func (d *Downloader) shouldDownload(destination string) bool {
_, err := d.fs.Stat(destination)
return errors.Is(err, fs.ErrNotExist)
}
// downloadWithProgress downloads the raw image from source to the destination.
func (d *Downloader) downloadWithProgress(ctx context.Context, errWriter io.Writer, showBar bool, source, destination string) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, source, nil)
if err != nil {
return fmt.Errorf("creating request: %w", err)
}
resp, err := d.httpc.Do(req)
if err != nil {
return fmt.Errorf("doing request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("downloading from %q: %s", source, resp.Status)
}
f, err := d.fs.OpenFile(destination, os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
return err
}
defer f.Close()
var bar io.WriteCloser
if showBar {
bar = prepareBar(errWriter, resp.ContentLength)
} else {
bar = &nopWriteCloser{}
}
defer bar.Close()
_, err = io.Copy(io.MultiWriter(f, bar), resp.Body)
if err != nil {
return err
}
return nil
}
func prepareBar(writer io.Writer, total int64) io.WriteCloser {
return progressbar.NewOptions64(
total,
progressbar.OptionSetWriter(writer),
progressbar.OptionShowBytes(true),
progressbar.OptionSetPredictTime(true),
progressbar.OptionFullWidth(),
progressbar.OptionSetTheme(progressbar.Theme{
Saucer: "=",
SaucerHead: ">",
SaucerPadding: " ",
BarStart: "[",
BarEnd: "]",
}),
progressbar.OptionClearOnFinish(),
progressbar.OptionOnCompletion(func() { fmt.Fprintf(writer, "Done.\n\n") }),
)
}
type nopWriteCloser struct{}
func (*nopWriteCloser) Write(p []byte) (int, error) {
return len(p), nil
}
func (*nopWriteCloser) Close() error {
return nil
}
type httpc interface {
Do(req *http.Request) (*http.Response, error)
}

View file

@ -1,204 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package image
import (
"bytes"
"context"
"io"
"net/http"
"os"
"path"
"testing"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
)
func TestShouldDownload(t *testing.T) {
testCases := map[string]struct {
partfile, destination string
wantDownload bool
}{
"no files exist yet": {
wantDownload: true,
},
"partial download": {
partfile: "some data",
wantDownload: true,
},
"download succeeded": {
destination: "all of the data",
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
downloader := &Downloader{
fs: newDownloaderStubFs(t, "someVersion", tc.partfile, tc.destination),
}
gotDownload := downloader.shouldDownload("someVersion.raw")
assert.Equal(tc.wantDownload, gotDownload)
})
}
}
func TestDownloadWithProgress(t *testing.T) {
rawImage := "raw image"
client := newTestClient(func(req *http.Request) *http.Response {
if req.URL.String() == "https://cdn.example.com/image.raw" {
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString(rawImage)),
Header: make(http.Header),
}
}
return &http.Response{
StatusCode: http.StatusNotFound,
Body: io.NopCloser(bytes.NewBufferString("Not found.")),
Header: make(http.Header),
}
})
testCases := map[string]struct {
source string
wantErr bool
}{
"correct file requested": {
source: "https://cdn.example.com/image.raw",
},
"incorrect file requested": {
source: "https://cdn.example.com/incorrect.raw",
wantErr: true,
},
"invalid scheme": {
source: "xyz://",
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
fs := newDownloaderStubFs(t, "someVersion", "", "")
downloader := &Downloader{
httpc: client,
fs: fs,
}
var outBuffer bytes.Buffer
err := downloader.downloadWithProgress(context.Background(), &outBuffer, false, tc.source, "someVersion.raw")
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
out, err := fs.ReadFile("someVersion.raw")
assert.NoError(err)
assert.Equal(rawImage, string(out))
})
}
}
func TestDownload(t *testing.T) {
rawImage := "raw image"
cwd, err := os.Getwd()
assert.NoError(t, err)
wantDestination := path.Join(cwd, "someVersion.raw")
client := newTestClient(func(req *http.Request) *http.Response {
if req.URL.String() == "https://cdn.example.com/image.raw" {
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString(rawImage)),
Header: make(http.Header),
}
}
return &http.Response{
StatusCode: http.StatusNotFound,
Body: io.NopCloser(bytes.NewBufferString("Not found.")),
Header: make(http.Header),
}
})
testCases := map[string]struct {
source string
destination string
overrideFile string
wantErr bool
}{
"correct file requested": {
source: "https://cdn.example.com/image.raw",
},
"file url": {
source: "file:///override.raw",
overrideFile: "override image",
},
"file exists": {
source: "https://cdn.example.com/image.raw",
destination: "already exists",
},
"incorrect file requested": {
source: "https://cdn.example.com/incorrect.raw",
wantErr: true,
},
"invalid scheme": {
source: "xyz://",
wantErr: true,
},
"invalid URL": {
source: "\x00",
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
fs := newDownloaderStubFs(t, cwd+"/someVersion", "", tc.destination)
if tc.overrideFile != "" {
must(t, fs.WriteFile("/override.raw", []byte(tc.overrideFile), os.ModePerm))
}
downloader := &Downloader{
httpc: client,
fs: fs,
}
var outBuffer bytes.Buffer
gotDestination, err := downloader.Download(context.Background(), &outBuffer, false, tc.source, "someVersion")
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
if tc.overrideFile == "" {
assert.Equal(wantDestination, gotDestination)
} else {
assert.Equal("/override.raw", gotDestination)
}
out, err := fs.ReadFile(gotDestination)
assert.NoError(err)
switch {
case tc.overrideFile != "":
assert.Equal(tc.overrideFile, string(out))
case tc.destination != "":
assert.Equal(tc.destination, string(out))
default:
assert.Equal(rawImage, string(out))
}
})
}
}
func newDownloaderStubFs(t *testing.T, version, partfile, destination string) *afero.Afero {
fs := afero.NewMemMapFs()
if partfile != "" {
must(t, afero.WriteFile(fs, version+".raw.part", []byte(partfile), os.ModePerm))
}
if destination != "" {
must(t, afero.WriteFile(fs, version+".raw", []byte(destination), os.ModePerm))
}
return &afero.Afero{Fs: fs}
}

View file

@ -12,14 +12,15 @@ go_library(
visibility = ["//cli:__subpackages__"],
deps = [
"//cli/internal/helm",
"//cli/internal/image",
"//cli/internal/terraform",
"//cli/internal/upgrade",
"//internal/attestation/measurements",
"//internal/cloud/cloudprovider",
"//internal/compatibility",
"//internal/config",
"//internal/constants",
"//internal/file",
"//internal/imagefetcher",
"//internal/kubernetes",
"//internal/kubernetes/kubectl",
"//internal/variant",
@ -45,10 +46,12 @@ go_test(
embed = [":kubernetes"],
deps = [
"//internal/attestation/measurements",
"//internal/cloud/cloudprovider",
"//internal/compatibility",
"//internal/config",
"//internal/constants",
"//internal/logger",
"//internal/variant",
"//internal/versions",
"//internal/versions/components",
"//operators/constellation-node-operator/api/v1alpha1",

View file

@ -16,14 +16,15 @@ import (
"time"
"github.com/edgelesssys/constellation/v2/cli/internal/helm"
"github.com/edgelesssys/constellation/v2/cli/internal/image"
"github.com/edgelesssys/constellation/v2/cli/internal/terraform"
"github.com/edgelesssys/constellation/v2/cli/internal/upgrade"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/compatibility"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/imagefetcher"
internalk8s "github.com/edgelesssys/constellation/v2/internal/kubernetes"
"github.com/edgelesssys/constellation/v2/internal/kubernetes/kubectl"
"github.com/edgelesssys/constellation/v2/internal/variant"
@ -122,7 +123,7 @@ func NewUpgrader(ctx context.Context, outWriter io.Writer, log debugLog) (*Upgra
stableInterface: &stableClient{client: kubeClient},
dynamicInterface: &NodeVersionClient{client: unstructuredClient},
helmClient: helmClient,
imageFetcher: image.New(),
imageFetcher: imagefetcher.New(),
outWriter: outWriter,
tfUpgrader: tfUpgrader,
log: log,
@ -164,7 +165,10 @@ func (u *Upgrader) UpgradeHelmServices(ctx context.Context, config *config.Confi
// UpgradeNodeVersion upgrades the cluster's NodeVersion object and in turn triggers image & k8s version upgrades.
// The versions set in the config are validated against the versions running in the cluster.
func (u *Upgrader) UpgradeNodeVersion(ctx context.Context, conf *config.Config) error {
imageReference, err := u.imageFetcher.FetchReference(ctx, conf)
provider := conf.GetProvider()
attestationVariant := conf.GetAttestationConfig().GetVariant()
region := conf.GetRegion()
imageReference, err := u.imageFetcher.FetchReference(ctx, provider, attestationVariant, conf.Image, region)
if err != nil {
return fmt.Errorf("fetching image reference: %w", err)
}
@ -526,5 +530,8 @@ type debugLog interface {
// imageFetcher gets an image reference from the versionsapi.
type imageFetcher interface {
FetchReference(ctx context.Context, config *config.Config) (string, error)
FetchReference(ctx context.Context,
provider cloudprovider.Provider, attestationVariant variant.Variant,
image, region string,
) (string, error)
}

View file

@ -14,10 +14,12 @@ import (
"testing"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/compatibility"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/internal/variant"
"github.com/edgelesssys/constellation/v2/internal/versions"
"github.com/edgelesssys/constellation/v2/internal/versions/components"
updatev1alpha1 "github.com/edgelesssys/constellation/v2/operators/constellation-node-operator/v2/api/v1alpha1"
@ -550,6 +552,9 @@ type stubImageFetcher struct {
fetchReferenceErr error
}
func (f *stubImageFetcher) FetchReference(_ context.Context, _ *config.Config) (string, error) {
func (f *stubImageFetcher) FetchReference(_ context.Context,
_ cloudprovider.Provider, _ variant.Variant,
_, _ string,
) (string, error) {
return f.reference, f.fetchReferenceErr
}