diff --git a/cli/internal/cmd/apply.go b/cli/internal/cmd/apply.go index 4b7fb51a2..e5c8caeac 100644 --- a/cli/internal/cmd/apply.go +++ b/cli/internal/cmd/apply.go @@ -41,6 +41,7 @@ import ( "github.com/spf13/afero" "github.com/spf13/cobra" "github.com/spf13/pflag" + xsemver "golang.org/x/mod/semver" apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" k8serrors "k8s.io/apimachinery/pkg/api/errors" ) @@ -545,9 +546,19 @@ func (a *applyCmd) validateInputs(cmd *cobra.Command, configFetcher attestationc return nil, nil, fmt.Errorf("aborted by user") } } + a.flags.skipPhases.add(skipK8sPhase) a.log.Debugf("Outdated Kubernetes version accepted, Kubernetes upgrade will be skipped") } + + validVersionString, err := versions.ResolveK8sPatchVersion(xsemver.MajorMinor(string(conf.KubernetesVersion))) + if err != nil { + return nil, nil, fmt.Errorf("resolving Kubernetes patch version: %w", err) + } + validVersion, err = versions.NewValidK8sVersion(validVersionString, true) + if err != nil { + return nil, nil, fmt.Errorf("parsing Kubernetes version: %w", err) + } } if versions.IsPreviewK8sVersion(validVersion) { cmd.PrintErrf("Warning: Constellation with Kubernetes %s is still in preview. Use only for evaluation purposes.\n", validVersion) diff --git a/cli/internal/cmd/apply_test.go b/cli/internal/cmd/apply_test.go index dfc0f0344..a8e010a03 100644 --- a/cli/internal/cmd/apply_test.go +++ b/cli/internal/cmd/apply_test.go @@ -28,6 +28,7 @@ import ( "github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/kms/uri" "github.com/edgelesssys/constellation/v2/internal/logger" + "github.com/edgelesssys/constellation/v2/internal/versions" "github.com/spf13/afero" "github.com/spf13/pflag" "github.com/stretchr/testify/assert" @@ -291,6 +292,7 @@ func TestValidateInputs(t *testing.T) { stdin string flags applyFlags wantPhases skipPhases + assert func(require *require.Assertions, assert *assert.Assertions, conf *config.Config, stateFile *state.State) wantErr bool }{ "[upgrade] gcp: all files exist": { @@ -396,6 +398,28 @@ func TestValidateInputs(t *testing.T) { }, wantPhases: newPhases(skipInfrastructurePhase, skipImagePhase, skipK8sPhase), }, + "[upgrade] k8s patch version no longer supported, user confirms to skip k8s and continue upgrade. Valid K8s patch version is used in config afterwards": { + createConfig: func(require *require.Assertions, fh file.Handler) { + cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.GCP) + + // use first version in list (oldest) as it should never have a patch version + versionParts := strings.Split(versions.SupportedK8sVersions()[0], ".") + versionParts[len(versionParts)-1] = "0" + cfg.KubernetesVersion = versions.ValidK8sVersion(strings.Join(versionParts, ".")) + require.NoError(fh.WriteYAML(constants.ConfigFilename, cfg)) + }, + createState: postInitState(cloudprovider.GCP), + createMasterSecret: defaultMasterSecret, + createAdminConfig: defaultAdminConfig, + createTfState: defaultTfState, + stdin: "y\n", + wantPhases: newPhases(skipInitPhase, skipK8sPhase), + assert: func(require *require.Assertions, assert *assert.Assertions, conf *config.Config, stateFile *state.State) { + assert.NotEmpty(conf.KubernetesVersion) + _, err := versions.NewValidK8sVersion(string(conf.KubernetesVersion), true) + assert.NoError(err) + }, + }, } for name, tc := range testCases { @@ -423,7 +447,7 @@ func TestValidateInputs(t *testing.T) { flags: tc.flags, } - _, _, err := a.validateInputs(cmd, &stubAttestationFetcher{}) + conf, state, err := a.validateInputs(cmd, &stubAttestationFetcher{}) if tc.wantErr { assert.Error(err) return @@ -434,6 +458,10 @@ func TestValidateInputs(t *testing.T) { t.Log(cfgErr.LongMessage()) } assert.Equal(tc.wantPhases, a.flags.skipPhases) + + if tc.assert != nil { + tc.assert(require, assert, conf, state) + } }) } } diff --git a/internal/versions/components/BUILD.bazel b/internal/versions/components/BUILD.bazel index 46f484e18..abbc389a2 100644 --- a/internal/versions/components/BUILD.bazel +++ b/internal/versions/components/BUILD.bazel @@ -1,6 +1,7 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") load("@rules_proto//proto:defs.bzl", "proto_library") +load("//bazel/go:go_test.bzl", "go_test") load("//bazel/proto:rules.bzl", "write_go_proto_srcs") go_library( @@ -30,3 +31,13 @@ write_go_proto_srcs( go_proto_library = ":components_go_proto", visibility = ["//visibility:public"], ) + +go_test( + name = "components_test", + srcs = ["components_test.go"], + embed = [":components"], + deps = [ + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", + ], +) diff --git a/internal/versions/components/components.go b/internal/versions/components/components.go index 4f22eb870..4520ee7f0 100644 --- a/internal/versions/components/components.go +++ b/internal/versions/components/components.go @@ -8,6 +8,7 @@ package components import ( "crypto/sha256" + "encoding/json" "errors" "fmt" "strings" @@ -16,6 +17,53 @@ import ( // Components is a list of Kubernetes components. type Components []*Component +type legacyComponent struct { + URL string `json:"URL,omitempty"` + Hash string `json:"Hash,omitempty"` + InstallPath string `json:"InstallPath,omitempty"` + Extract bool `json:"Extract,omitempty"` +} + +// UnmarshalJSON implements a custom JSON unmarshaler to ensure backwards compatibility +// with older components lists which had a different format for all keys. +func (c *Components) UnmarshalJSON(b []byte) error { + var legacyComponents []*legacyComponent + if err := json.Unmarshal(b, &legacyComponents); err != nil { + return err + } + var components []*Component + if err := json.Unmarshal(b, &components); err != nil { + return err + } + + if len(legacyComponents) != len(components) { + return errors.New("failed to unmarshal data: inconsistent number of components in list") // just a check, should never happen + } + + // If a value is not set in the new format, + // it might have been set in the old format. + // In this case, we copy the value from the old format. + comps := make(Components, len(components)) + for idx := 0; idx < len(components); idx++ { + comps[idx] = components[idx] + if comps[idx].Url == "" { + comps[idx].Url = legacyComponents[idx].URL + } + if comps[idx].Hash == "" { + comps[idx].Hash = legacyComponents[idx].Hash + } + if comps[idx].InstallPath == "" { + comps[idx].InstallPath = legacyComponents[idx].InstallPath + } + if !comps[idx].Extract { + comps[idx].Extract = legacyComponents[idx].Extract + } + } + + *c = comps + return nil +} + // GetHash returns the hash over all component hashes. func (c Components) GetHash() string { sha := sha256.New() diff --git a/internal/versions/components/components_test.go b/internal/versions/components/components_test.go new file mode 100644 index 000000000..e32516ec9 --- /dev/null +++ b/internal/versions/components/components_test.go @@ -0,0 +1,31 @@ +/* +Copyright (c) Edgeless Systems GmbH + +SPDX-License-Identifier: AGPL-3.0-only +*/ + +package components + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUnmarshalComponents(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + legacyFormat := `[{"URL":"https://example.com/foo.tar.gz","Hash":"1234567890","InstallPath":"/foo","Extract":true}]` + newFormat := `[{"url":"https://example.com/foo.tar.gz","hash":"1234567890","install_path":"/foo","extract":true}]` + + var fromLegacy Components + require.NoError(json.Unmarshal([]byte(legacyFormat), &fromLegacy)) + + var fromNew Components + require.NoError(json.Unmarshal([]byte(newFormat), &fromNew)) + + assert.Equal(fromLegacy, fromNew) +}