cli: common backend for init and upgrade apply commands (#2449)

* Use common 'apply' backend for init and upgrades
* Move unit tests to new apply backend
* Only perform Terraform migrations if state exists in cwd (#2457)
* Rework skipPhases logic

---------

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
Co-authored-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>
Daniel Weiße 2023-10-24 15:39:18 +02:00 committed by GitHub
parent 15d249092c
commit 671cf36f0a
15 changed files with 1399 additions and 1018 deletions
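The reworked skip-phases logic (last commit bullet) builds on a small set type whose definition is not part of this diff. A minimal sketch of that type, reconstructed from the call sites in apply.go and the map literal in apply_test.go below; the concrete phase constant values and receiver shapes are assumptions:

package cmd

// Sketch only: reconstructed from usage in this diff, not the authoritative definition.
type skipPhase string

const (
	skipInfrastructurePhase skipPhase = "infrastructure"
	skipInitPhase           skipPhase = "init"
	skipHelmPhase           skipPhase = "helm"
	skipK8sPhase            skipPhase = "k8s"
	skipImagePhase          skipPhase = "image"
)

// skipPhases is a set of phases to skip during apply.
type skipPhases map[skipPhase]struct{}

// add marks the given phases as skipped, lazily allocating the map.
func (s *skipPhases) add(phases ...skipPhase) {
	if *s == nil {
		*s = make(skipPhases)
	}
	for _, phase := range phases {
		(*s)[phase] = struct{}{}
	}
}

// contains reports whether a phase is marked as skipped.
func (s skipPhases) contains(phase skipPhase) bool {
	_, ok := s[phase]
	return ok
}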

cli/internal/cloudcmd/BUILD.bazel

@@ -21,7 +21,6 @@ go_library(
importpath = "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd",
visibility = ["//cli:__subpackages__"],
deps = [
"//cli/internal/cmd/pathprefix",
"//cli/internal/libvirt",
"//cli/internal/state",
"//cli/internal/terraform",

cli/internal/cloudcmd/serviceaccount.go

@@ -9,7 +9,6 @@ package cloudcmd
import (
"fmt"
"github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix"
"github.com/edgelesssys/constellation/v2/internal/cloud/azureshared"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/cloud/gcpshared"
@@ -19,27 +18,19 @@ import (
)
// GetMarshaledServiceAccountURI returns the service account URI for the given cloud provider.
func GetMarshaledServiceAccountURI(provider cloudprovider.Provider, config *config.Config, pf pathprefix.PathPrefixer, log debugLog, fileHandler file.Handler,
) (string, error) {
log.Debugf("Getting service account URI")
switch provider {
func GetMarshaledServiceAccountURI(config *config.Config, fileHandler file.Handler) (string, error) {
switch config.GetProvider() {
case cloudprovider.GCP:
log.Debugf("Handling case for GCP")
log.Debugf("GCP service account key path %s", pf.PrefixPrintablePath(config.Provider.GCP.ServiceAccountKeyPath))
var key gcpshared.ServiceAccountKey
if err := fileHandler.ReadJSON(config.Provider.GCP.ServiceAccountKeyPath, &key); err != nil {
return "", fmt.Errorf("reading service account key from path %q: %w", pf.PrefixPrintablePath(config.Provider.GCP.ServiceAccountKeyPath), err)
return "", fmt.Errorf("reading service account key: %w", err)
}
log.Debugf("Read GCP service account key from path")
return key.ToCloudServiceAccountURI(), nil
case cloudprovider.AWS:
log.Debugf("Handling case for AWS")
return "", nil // AWS does not need a service account URI
case cloudprovider.Azure:
log.Debugf("Handling case for Azure")
case cloudprovider.Azure:
authMethod := azureshared.AuthMethodUserAssignedIdentity
creds := azureshared.ApplicationCredentials{
@@ -64,10 +55,9 @@ func GetMarshaledServiceAccountURI(provider cloudprovider.Provider, config *conf
return creds.ToCloudServiceAccountURI(), nil
case cloudprovider.QEMU:
log.Debugf("Handling case for QEMU")
return "", nil // QEMU does not use service account keys
default:
return "", fmt.Errorf("unsupported cloud provider %q", provider)
return "", fmt.Errorf("unsupported cloud provider %q", config.GetProvider())
}
}

cli/internal/cmd/BUILD.bazel

@@ -4,6 +4,10 @@ load("//bazel/go:go_test.bzl", "go_test")
go_library(
name = "cmd",
srcs = [
"apply.go",
"applyhelm.go",
"applyinit.go",
"applyterraform.go",
"cloud.go",
"cmd.go",
"config.go",
@@ -94,6 +98,7 @@ go_library(
"@com_github_spf13_pflag//:pflag",
"@in_gopkg_yaml_v3//:yaml_v3",
"@io_k8s_apiextensions_apiserver//pkg/apis/apiextensions/v1:apiextensions",
"@io_k8s_apimachinery//pkg/api/errors",
"@io_k8s_apimachinery//pkg/runtime",
"@io_k8s_client_go//tools/clientcmd",
"@io_k8s_client_go//tools/clientcmd/api/latest",
@@ -115,6 +120,7 @@ go_library(
go_test(
name = "cmd_test",
srcs = [
"apply_test.go",
"cloud_test.go",
"configfetchmeasurements_test.go",
"configgenerate_test.go",
@@ -171,12 +177,15 @@ go_test(
"@com_github_google_go_tpm_tools//proto/tpm",
"@com_github_spf13_afero//:afero",
"@com_github_spf13_cobra//:cobra",
"@com_github_spf13_pflag//:pflag",
"@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//mock",
"@com_github_stretchr_testify//require",
"@io_k8s_api//core/v1:core",
"@io_k8s_apiextensions_apiserver//pkg/apis/apiextensions/v1:apiextensions",
"@io_k8s_apimachinery//pkg/api/errors",
"@io_k8s_apimachinery//pkg/apis/meta/v1:meta",
"@io_k8s_apimachinery//pkg/runtime/schema",
"@io_k8s_client_go//tools/clientcmd",
"@io_k8s_client_go//tools/clientcmd/api",
"@org_golang_google_grpc//:go_default_library",

cli/internal/cmd/apply.go (new file, 515 lines)

@@ -0,0 +1,515 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"io/fs"
"net"
"os"
"path/filepath"
"strings"
"time"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/helm"
"github.com/edgelesssys/constellation/v2/cli/internal/kubecmd"
"github.com/edgelesssys/constellation/v2/cli/internal/state"
"github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi"
"github.com/edgelesssys/constellation/v2/internal/atls"
"github.com/edgelesssys/constellation/v2/internal/attestation/variant"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/compatibility"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
"github.com/edgelesssys/constellation/v2/internal/license"
"github.com/edgelesssys/constellation/v2/internal/versions"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
k8serrors "k8s.io/apimachinery/pkg/api/errors"
)
// applyFlags defines the flags for the apply command.
type applyFlags struct {
rootFlags
yes bool
conformance bool
mergeConfigs bool
upgradeTimeout time.Duration
helmWaitMode helm.WaitMode
skipPhases skipPhases
}
// parse the apply command flags.
func (f *applyFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
}
rawSkipPhases, err := flags.GetStringSlice("skip-phases")
if err != nil {
return fmt.Errorf("getting 'skip-phases' flag: %w", err)
}
var skipPhases skipPhases
for _, phase := range rawSkipPhases {
switch skipPhase(strings.ToLower(phase)) {
case skipInfrastructurePhase, skipHelmPhase, skipImagePhase, skipK8sPhase:
skipPhases.add(skipPhase(strings.ToLower(phase))) // store the normalized (lowercased) phase that was validated above
default:
return fmt.Errorf("invalid phase %s", phase)
}
}
f.skipPhases = skipPhases
f.yes, err = flags.GetBool("yes")
if err != nil {
return fmt.Errorf("getting 'yes' flag: %w", err)
}
f.upgradeTimeout, err = flags.GetDuration("timeout")
if err != nil {
return fmt.Errorf("getting 'timeout' flag: %w", err)
}
f.conformance, err = flags.GetBool("conformance")
if err != nil {
return fmt.Errorf("getting 'conformance' flag: %w", err)
}
skipHelmWait, err := flags.GetBool("skip-helm-wait")
if err != nil {
return fmt.Errorf("getting 'skip-helm-wait' flag: %w", err)
}
f.helmWaitMode = helm.WaitModeAtomic
if skipHelmWait {
f.helmWaitMode = helm.WaitModeNone
}
f.mergeConfigs, err = flags.GetBool("merge-kubeconfig")
if err != nil {
return fmt.Errorf("getting 'merge-kubeconfig' flag: %w", err)
}
return nil
}
// runApply sets up the apply command and runs it.
func runApply(cmd *cobra.Command, _ []string) error {
log, err := newCLILogger(cmd)
if err != nil {
return fmt.Errorf("creating logger: %w", err)
}
defer log.Sync()
spinner, err := newSpinnerOrStderr(cmd)
if err != nil {
return err
}
defer spinner.Stop()
flags := applyFlags{}
if err := flags.parse(cmd.Flags()); err != nil {
return err
}
fileHandler := file.NewHandler(afero.NewOsFs())
newDialer := func(validator atls.Validator) *dialer.Dialer {
return dialer.New(nil, validator, &net.Dialer{})
}
newKubeUpgrader := func(w io.Writer, kubeConfigPath string, log debugLog) (kubernetesUpgrader, error) {
return kubecmd.New(w, kubeConfigPath, fileHandler, log)
}
newHelmClient := func(kubeConfigPath string, log debugLog) (helmApplier, error) {
return helm.NewClient(kubeConfigPath, log)
}
upgradeID := generateUpgradeID(upgradeCmdKindApply)
upgradeDir := filepath.Join(constants.UpgradeDir, upgradeID)
clusterUpgrader, err := cloudcmd.NewClusterUpgrader(
cmd.Context(),
constants.TerraformWorkingDir,
upgradeDir,
flags.tfLogLevel,
fileHandler,
)
if err != nil {
return fmt.Errorf("setting up cluster upgrader: %w", err)
}
apply := &applyCmd{
fileHandler: fileHandler,
flags: flags,
log: log,
spinner: spinner,
merger: &kubeconfigMerger{log: log},
quotaChecker: license.NewClient(),
newHelmClient: newHelmClient,
newDialer: newDialer,
newKubeUpgrader: newKubeUpgrader,
clusterUpgrader: clusterUpgrader,
}
ctx, cancel := context.WithTimeout(cmd.Context(), time.Hour)
defer cancel()
cmd.SetContext(ctx)
return apply.apply(cmd, attestationconfigapi.NewFetcher(), upgradeDir)
}
type applyCmd struct {
fileHandler file.Handler
flags applyFlags
log debugLog
spinner spinnerInterf
merger configMerger
quotaChecker license.QuotaChecker
newHelmClient func(kubeConfigPath string, log debugLog) (helmApplier, error)
newDialer func(validator atls.Validator) *dialer.Dialer
newKubeUpgrader func(io.Writer, string, debugLog) (kubernetesUpgrader, error)
clusterUpgrader clusterUpgrader
}
/*
apply updates a Constellation cluster by applying a user's config.
The control flow is as follows:
Parse flags
Read config
Read state file
Validate input
Check if the Terraform state is up to date; if not (diff from Terraform plan), apply Terraform updates [Terraform phase]
Check for constellation-admin.conf; if the file does not exist, run the Bootstrapper Init RPC and write constellation-admin.conf [Init phase], then prepare the "init success" message
Apply the attestation config
Extend the API server cert SANs
Apply the Helm charts [Helm phase]
Upgrade the NodeVersion object [K8s/Image phase]; skipped if we ran the Init RPC, since the image and Kubernetes updates are then already applied
Write success output

Individual phases can be skipped via the skip-phases flag.
*/
func (a *applyCmd) apply(cmd *cobra.Command, configFetcher attestationconfigapi.Fetcher, upgradeDir string) error {
// Validate inputs
conf, stateFile, err := a.validateInputs(cmd, configFetcher)
if err != nil {
return err
}
// Now start actually running the apply command
// Check current Terraform state, if it exists and infrastructure upgrades are not skipped,
// and apply migrations if necessary.
if !a.flags.skipPhases.contains(skipInfrastructurePhase) {
if err := a.runTerraformApply(cmd, conf, stateFile, upgradeDir); err != nil {
return fmt.Errorf("applying Terraform configuration : %w", err)
}
}
bufferedOutput := &bytes.Buffer{}
// Run init RPC if required
if !a.flags.skipPhases.contains(skipInitPhase) {
bufferedOutput, err = a.runInit(cmd, conf, stateFile)
if err != nil {
return err
}
}
// From now on we can assume a valid Kubernetes admin config file exists
kubeUpgrader, err := a.newKubeUpgrader(cmd.OutOrStdout(), constants.AdminConfFilename, a.log)
if err != nil {
return err
}
// Apply Attestation Config
a.log.Debugf("Creating Kubernetes client using %s", a.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename))
a.log.Debugf("Applying new attestation config to cluster")
if err := a.applyJoinConfig(cmd, kubeUpgrader, conf.GetAttestationConfig(), stateFile.ClusterValues.MeasurementSalt); err != nil {
return fmt.Errorf("applying attestation config: %w", err)
}
// Extend API Server Cert SANs
sans := append([]string{stateFile.Infrastructure.ClusterEndpoint, conf.CustomEndpoint}, stateFile.Infrastructure.APIServerCertSANs...)
if err := kubeUpgrader.ExtendClusterConfigCertSANs(cmd.Context(), sans); err != nil {
return fmt.Errorf("extending cert SANs: %w", err)
}
// Apply Helm Charts
if !a.flags.skipPhases.contains(skipHelmPhase) {
if err := a.runHelmApply(cmd, conf, stateFile, kubeUpgrader, upgradeDir); err != nil {
return err
}
}
// Upgrade NodeVersion object
// This can be skipped if we ran the init RPC, as the NodeVersion object is already up to date
if !(a.flags.skipPhases.contains(skipK8sPhase) && a.flags.skipPhases.contains(skipImagePhase)) &&
a.flags.skipPhases.contains(skipInitPhase) {
if err := a.runK8sUpgrade(cmd, conf, kubeUpgrader); err != nil {
return err
}
}
// Write success output
cmd.Print(bufferedOutput.String())
return nil
}
func (a *applyCmd) validateInputs(cmd *cobra.Command, configFetcher attestationconfigapi.Fetcher) (*config.Config, *state.State, error) {
// Read user's config and state file
a.log.Debugf("Reading config from %s", a.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
conf, err := config.New(a.fileHandler, constants.ConfigFilename, configFetcher, a.flags.force)
var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage())
}
if err != nil {
return nil, nil, err
}
a.log.Debugf("Reading state file from %s", a.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
stateFile, err := state.ReadFromFile(a.fileHandler, constants.StateFilename)
if err != nil {
return nil, nil, err
}
// Check license
a.log.Debugf("Running license check")
checker := license.NewChecker(a.quotaChecker, a.fileHandler)
if err := checker.CheckLicense(cmd.Context(), conf.GetProvider(), conf.Provider, cmd.Printf); err != nil {
cmd.PrintErrf("License check failed: %v", err)
}
a.log.Debugf("Checked license")
// Check if we already have a running Kubernetes cluster
// by checking if the Kubernetes admin config file exists
// If it exists, we skip the init phase
// If it does not exist, we need to run the init RPC first
// This may break things further down the line
// It is the user's responsibility to make sure the cluster is in a valid state
a.log.Debugf("Checking if %s exists", a.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename))
if _, err := a.fileHandler.Stat(constants.AdminConfFilename); err == nil {
a.flags.skipPhases.add(skipInitPhase)
} else if !errors.Is(err, os.ErrNotExist) {
return nil, nil, fmt.Errorf("checking for %s: %w", a.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename), err)
}
a.log.Debugf("Init RPC required: %t", !a.flags.skipPhases.contains(skipInitPhase))
// Validate input arguments
// Validate Kubernetes version as set in the user's config
// If we need to run the init RPC, the version has to be valid
// Otherwise, we are able to use an outdated version, meaning we skip the K8s upgrade
// We skip version validation if the user explicitly skips the Kubernetes phase
a.log.Debugf("Validating Kubernetes version %s", conf.KubernetesVersion)
validVersion, err := versions.NewValidK8sVersion(string(conf.KubernetesVersion), true)
if err != nil && !a.flags.skipPhases.contains(skipK8sPhase) {
a.log.Debugf("Kubernetes version not valid: %s", err)
if !a.flags.skipPhases.contains(skipInitPhase) {
return nil, nil, err
}
a.log.Debugf("Checking if user wants to continue anyway")
if !a.flags.yes {
confirmed, err := askToConfirm(cmd,
fmt.Sprintf(
"WARNING: The Kubernetes patch version %s is not supported. If you continue, Kubernetes upgrades will be skipped. Do you want to continue anyway?",
validVersion,
),
)
if err != nil {
return nil, nil, fmt.Errorf("asking for confirmation: %w", err)
}
if !confirmed {
return nil, nil, fmt.Errorf("aborted by user")
}
}
a.flags.skipPhases.add(skipK8sPhase)
a.log.Debugf("Outdated Kubernetes version accepted, Kubernetes upgrade will be skipped")
}
if versions.IsPreviewK8sVersion(validVersion) {
cmd.PrintErrf("Warning: Constellation with Kubernetes %s is still in preview. Use only for evaluation purposes.\n", validVersion)
}
conf.KubernetesVersion = validVersion
a.log.Debugf("Target Kubernetes version set to %s", conf.KubernetesVersion)
// Validate microservice version (helm versions) in the user's config matches the version of the CLI
// This makes sure we catch potential errors early, not just after we already ran Terraform migrations or the init RPC
if !a.flags.force && !a.flags.skipPhases.contains(skipHelmPhase) && !a.flags.skipPhases.contains(skipInitPhase) {
if err := validateCLIandConstellationVersionAreEqual(constants.BinaryVersion(), conf.Image, conf.MicroserviceVersion); err != nil {
return nil, nil, err
}
}
// Constellation on QEMU or OpenStack doesn't support upgrades
// If using one of those providers, make sure the command is only used to initialize a cluster
if !(conf.GetProvider() == cloudprovider.AWS || conf.GetProvider() == cloudprovider.Azure || conf.GetProvider() == cloudprovider.GCP) {
if a.flags.skipPhases.contains(skipInitPhase) {
return nil, nil, fmt.Errorf("upgrades are not supported for provider %s", conf.GetProvider())
}
// Skip Terraform phase
a.log.Debugf("Skipping Infrastructure upgrade")
a.flags.skipPhases.add(skipInfrastructurePhase)
}
// Check if Terraform state exists
if tfStateExists, err := a.tfStateExists(); err != nil {
return nil, nil, fmt.Errorf("checking Terraform state: %w", err)
} else if !tfStateExists {
a.flags.skipPhases.add(skipInfrastructurePhase)
a.log.Debugf("No Terraform state found in current working directory. Assuming self-managed infrastructure. Infrastructure upgrades will not be performed.")
}
// Print warning about AWS attestation
// TODO(derpsteb): remove once AWS fixes SEV-SNP attestation provisioning issues
if !a.flags.skipPhases.contains(skipInitPhase) && conf.GetAttestationConfig().GetVariant().Equal(variant.AWSSEVSNP{}) {
cmd.PrintErrln("WARNING: Attestation temporarily relies on AWS nitroTPM. See https://docs.edgeless.systems/constellation/workflows/config#choosing-a-vm-type for more information.")
}
return conf, stateFile, nil
}
// applyJoinConfig creates or updates the cluster's join config.
// If the config already exists, and is different from the new config, the user is asked to confirm the upgrade.
func (a *applyCmd) applyJoinConfig(
cmd *cobra.Command, kubeUpgrader kubernetesUpgrader, newConfig config.AttestationCfg, measurementSalt []byte,
) error {
clusterAttestationConfig, err := kubeUpgrader.GetClusterAttestationConfig(cmd.Context(), newConfig.GetVariant())
if err != nil {
a.log.Debugf("Getting cluster attestation config failed: %s", err)
if k8serrors.IsNotFound(err) {
a.log.Debugf("Creating new join config")
return kubeUpgrader.ApplyJoinConfig(cmd.Context(), newConfig, measurementSalt)
}
return fmt.Errorf("getting cluster attestation config: %w", err)
}
// If the current config is equal, or there is an error when comparing the configs, we skip the upgrade.
equal, err := newConfig.EqualTo(clusterAttestationConfig)
if err != nil {
return fmt.Errorf("comparing attestation configs: %w", err)
}
if equal {
a.log.Debugf("Current attestation config is equal to the new config, nothing to do")
return nil
}
cmd.Println("The configured attestation config is different from the attestation config in the cluster.")
diffStr, err := diffAttestationCfg(clusterAttestationConfig, newConfig)
if err != nil {
return fmt.Errorf("diffing attestation configs: %w", err)
}
cmd.Println("The following changes will be applied to the attestation config:")
cmd.Println(diffStr)
if !a.flags.yes {
ok, err := askToConfirm(cmd, "Are you sure you want to change your cluster's attestation config?")
if err != nil {
return fmt.Errorf("asking for confirmation: %w", err)
}
if !ok {
return errors.New("aborting upgrade since attestation config is different")
}
}
if err := kubeUpgrader.ApplyJoinConfig(cmd.Context(), newConfig, measurementSalt); err != nil {
return fmt.Errorf("updating attestation config: %w", err)
}
cmd.Println("Successfully updated the cluster's attestation config")
return nil
}
// runK8sUpgrade upgrades image and Kubernetes version of the Constellation cluster.
func (a *applyCmd) runK8sUpgrade(cmd *cobra.Command, conf *config.Config, kubeUpgrader kubernetesUpgrader,
) error {
err := kubeUpgrader.UpgradeNodeVersion(
cmd.Context(), conf, a.flags.force,
a.flags.skipPhases.contains(skipK8sPhase),
a.flags.skipPhases.contains(skipImagePhase),
)
var upgradeErr *compatibility.InvalidUpgradeError
switch {
case errors.Is(err, kubecmd.ErrInProgress):
cmd.PrintErrln("Skipping image and Kubernetes upgrades. Another upgrade is in progress.")
case errors.As(err, &upgradeErr):
cmd.PrintErrln(err)
case err != nil:
return fmt.Errorf("upgrading NodeVersion: %w", err)
}
return nil
}
// tfStateExists checks whether a Constellation Terraform state exists in the current working directory.
func (a *applyCmd) tfStateExists() (bool, error) {
_, err := a.fileHandler.Stat(constants.TerraformWorkingDir)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return false, nil
}
return false, fmt.Errorf("reading Terraform state: %w", err)
}
return true, nil
}
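apply.go wires in two upgrader dependencies whose interface definitions live elsewhere in the package. A condensed sketch of the contracts implied by the call sites above; the method sets are trimmed to what this file actually calls, and the signatures are inferred, so treat them as assumptions rather than the authoritative definitions:

package cmd

import (
	"context"
	"io"

	"github.com/edgelesssys/constellation/v2/cli/internal/state"
	"github.com/edgelesssys/constellation/v2/cli/internal/terraform"
	"github.com/edgelesssys/constellation/v2/internal/attestation/variant"
	"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
	"github.com/edgelesssys/constellation/v2/internal/config"
	apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1"
)

// Sketch: inferred from call sites in apply.go, applyhelm.go, and applyterraform.go.
type kubernetesUpgrader interface {
	GetClusterAttestationConfig(ctx context.Context, attestationVariant variant.Variant) (config.AttestationCfg, error)
	ApplyJoinConfig(ctx context.Context, newAttestConfig config.AttestationCfg, measurementSalt []byte) error
	ExtendClusterConfigCertSANs(ctx context.Context, alternativeNames []string) error
	UpgradeNodeVersion(ctx context.Context, conf *config.Config, force, skipK8s, skipImage bool) error
	BackupCRDs(ctx context.Context, upgradeDir string) ([]apiextensionsv1.CustomResourceDefinition, error)
	BackupCRs(ctx context.Context, crds []apiextensionsv1.CustomResourceDefinition, upgradeDir string) error
}

type clusterUpgrader interface {
	PlanClusterUpgrade(ctx context.Context, out io.Writer, vars terraform.Variables, csp cloudprovider.Provider) (bool, error)
	ApplyClusterUpgrade(ctx context.Context, csp cloudprovider.Provider) (state.Infrastructure, error)
	RestoreClusterWorkspace() error
}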

cli/internal/cmd/apply_test.go (new file)

@@ -0,0 +1,153 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"context"
"fmt"
"testing"
"github.com/edgelesssys/constellation/v2/cli/internal/helm"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/spf13/afero"
"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseApplyFlags(t *testing.T) {
require := require.New(t)
// TODO: Use flags := applyCmd().Flags() once we have a separate apply command
defaultFlags := func() *pflag.FlagSet {
flags := pflag.NewFlagSet("test", pflag.ContinueOnError)
flags.String("workspace", "", "")
flags.String("tf-log", "NONE", "")
flags.Bool("force", false, "")
flags.Bool("debug", false, "")
flags.Bool("merge-kubeconfig", false, "")
flags.Bool("conformance", false, "")
flags.Bool("skip-helm-wait", false, "")
flags.Bool("yes", false, "")
flags.StringSlice("skip-phases", []string{}, "")
flags.Duration("timeout", 0, "")
return flags
}
testCases := map[string]struct {
flags *pflag.FlagSet
wantFlags applyFlags
wantErr bool
}{
"default flags": {
flags: defaultFlags(),
wantFlags: applyFlags{
helmWaitMode: helm.WaitModeAtomic,
},
},
"skip phases": {
flags: func() *pflag.FlagSet {
flags := defaultFlags()
require.NoError(flags.Set("skip-phases", fmt.Sprintf("%s,%s", skipHelmPhase, skipK8sPhase)))
return flags
}(),
wantFlags: applyFlags{
skipPhases: skipPhases{skipHelmPhase: struct{}{}, skipK8sPhase: struct{}{}},
helmWaitMode: helm.WaitModeAtomic,
},
},
"skip helm wait": {
flags: func() *pflag.FlagSet {
flags := defaultFlags()
require.NoError(flags.Set("skip-helm-wait", "true"))
return flags
}(),
wantFlags: applyFlags{
helmWaitMode: helm.WaitModeNone,
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
var flags applyFlags
err := flags.parse(tc.flags)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
assert.Equal(tc.wantFlags, flags)
})
}
}
func TestBackupHelmCharts(t *testing.T) {
testCases := map[string]struct {
helmApplier helm.Applier
backupClient *stubKubernetesUpgrader
includesUpgrades bool
wantErr bool
}{
"success, no upgrades": {
helmApplier: &stubRunner{},
backupClient: &stubKubernetesUpgrader{},
},
"success with upgrades": {
helmApplier: &stubRunner{},
backupClient: &stubKubernetesUpgrader{},
includesUpgrades: true,
},
"saving charts fails": {
helmApplier: &stubRunner{
saveChartsErr: assert.AnError,
},
backupClient: &stubKubernetesUpgrader{},
wantErr: true,
},
"backup CRDs fails": {
helmApplier: &stubRunner{},
backupClient: &stubKubernetesUpgrader{
backupCRDsErr: assert.AnError,
},
includesUpgrades: true,
wantErr: true,
},
"backup CRs fails": {
helmApplier: &stubRunner{},
backupClient: &stubKubernetesUpgrader{
backupCRsErr: assert.AnError,
},
includesUpgrades: true,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
a := applyCmd{
fileHandler: file.NewHandler(afero.NewMemMapFs()),
log: logger.NewTest(t),
}
err := a.backupHelmCharts(context.Background(), tc.backupClient, tc.helmApplier, tc.includesUpgrades, "")
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
if tc.includesUpgrades {
assert.True(tc.backupClient.backupCRDsCalled)
assert.True(tc.backupClient.backupCRsCalled)
}
})
}
}

cli/internal/cmd/applyhelm.go (new file)

@@ -0,0 +1,125 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"context"
"errors"
"fmt"
"path/filepath"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/helm"
"github.com/edgelesssys/constellation/v2/cli/internal/state"
"github.com/edgelesssys/constellation/v2/internal/compatibility"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/kms/uri"
"github.com/spf13/cobra"
)
// runHelmApply handles installing or upgrading helm charts for the cluster.
func (a *applyCmd) runHelmApply(
cmd *cobra.Command, conf *config.Config, stateFile *state.State,
kubeUpgrader kubernetesUpgrader, upgradeDir string,
) error {
a.log.Debugf("Installing or upgrading Helm charts")
var masterSecret uri.MasterSecret
if err := a.fileHandler.ReadJSON(constants.MasterSecretFilename, &masterSecret); err != nil {
return fmt.Errorf("reading master secret: %w", err)
}
options := helm.Options{
Force: a.flags.force,
Conformance: a.flags.conformance,
HelmWaitMode: a.flags.helmWaitMode,
AllowDestructive: helm.DenyDestructive,
}
helmApplier, err := a.newHelmClient(constants.AdminConfFilename, a.log)
if err != nil {
return fmt.Errorf("creating Helm client: %w", err)
}
a.log.Debugf("Getting service account URI")
serviceAccURI, err := cloudcmd.GetMarshaledServiceAccountURI(conf, a.fileHandler)
if err != nil {
return err
}
a.log.Debugf("Preparing Helm charts")
executor, includesUpgrades, err := helmApplier.PrepareApply(conf, stateFile, options, serviceAccURI, masterSecret)
if errors.Is(err, helm.ErrConfirmationMissing) {
if !a.flags.yes {
cmd.PrintErrln("WARNING: Upgrading cert-manager will destroy all custom resources you have manually created that are based on the current version of cert-manager.")
ok, askErr := askToConfirm(cmd, "Do you want to upgrade cert-manager anyway?")
if askErr != nil {
return fmt.Errorf("asking for confirmation: %w", err)
}
if !ok {
cmd.Println("Skipping upgrade.")
return nil
}
}
options.AllowDestructive = helm.AllowDestructive
executor, includesUpgrades, err = helmApplier.PrepareApply(conf, stateFile, options, serviceAccURI, masterSecret)
}
var upgradeErr *compatibility.InvalidUpgradeError
if err != nil {
if !errors.As(err, &upgradeErr) {
return fmt.Errorf("preparing Helm charts: %w", err)
}
cmd.PrintErrln(err)
}
a.log.Debugf("Backing up Helm charts")
if err := a.backupHelmCharts(cmd.Context(), kubeUpgrader, executor, includesUpgrades, upgradeDir); err != nil {
return err
}
a.log.Debugf("Applying Helm charts")
if !a.flags.skipPhases.contains(skipInitPhase) {
a.spinner.Start("Installing Kubernetes components ", false)
} else {
a.spinner.Start("Upgrading Kubernetes components ", false)
}
if err := executor.Apply(cmd.Context()); err != nil {
return fmt.Errorf("applying Helm charts: %w", err)
}
a.spinner.Stop()
if a.flags.skipPhases.contains(skipInitPhase) {
cmd.Println("Successfully upgraded Constellation services.")
}
return nil
}
// backupHelmCharts saves the Helm charts for the upgrade to disk and creates a backup of existing CRDs and CRs.
func (a *applyCmd) backupHelmCharts(
ctx context.Context, kubeUpgrader kubernetesUpgrader, executor helm.Applier, includesUpgrades bool, upgradeDir string,
) error {
// Save the Helm charts for the upgrade to disk
chartDir := filepath.Join(upgradeDir, "helm-charts")
if err := executor.SaveCharts(chartDir, a.fileHandler); err != nil {
return fmt.Errorf("saving Helm charts to disk: %w", err)
}
a.log.Debugf("Helm charts saved to %s", a.flags.pathPrefixer.PrefixPrintablePath(chartDir))
if includesUpgrades {
a.log.Debugf("Creating backup of CRDs and CRs")
crds, err := kubeUpgrader.BackupCRDs(ctx, upgradeDir)
if err != nil {
return fmt.Errorf("creating CRD backup: %w", err)
}
if err := kubeUpgrader.BackupCRs(ctx, crds, upgradeDir); err != nil {
return fmt.Errorf("creating CR backup: %w", err)
}
}
return nil
}
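backupHelmCharts needs only two methods from the helm executor handed to it. A minimal condensation of that contract, inferred from the calls above; the real helm.Applier interface may carry more methods:

package cmd

import (
	"context"

	"github.com/edgelesssys/constellation/v2/internal/file"
)

// Sketch: the subset of helm.Applier that backupHelmCharts relies on.
type chartExecutor interface {
	// Apply installs or upgrades the prepared charts in the cluster.
	Apply(ctx context.Context) error
	// SaveCharts writes the prepared charts to chartsDir for backup and inspection.
	SaveCharts(chartsDir string, fileHandler file.Handler) error
}

The companion design choice is runHelmApply's two-pass use of PrepareApply: the first pass runs with DenyDestructive, and AllowDestructive is only set for the second pass after the user explicitly confirms the destructive cert-manager upgrade.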

cli/internal/cmd/applyinit.go (new file)

@@ -0,0 +1,238 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"bytes"
"context"
"encoding/hex"
"errors"
"fmt"
"io"
"net"
"net/url"
"path/filepath"
"strconv"
"text/tabwriter"
"time"
"github.com/edgelesssys/constellation/v2/bootstrapper/initproto"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/state"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/crypto"
"github.com/edgelesssys/constellation/v2/internal/file"
grpcRetry "github.com/edgelesssys/constellation/v2/internal/grpc/retry"
"github.com/edgelesssys/constellation/v2/internal/kms/uri"
"github.com/edgelesssys/constellation/v2/internal/retry"
"github.com/edgelesssys/constellation/v2/internal/versions"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"k8s.io/client-go/tools/clientcmd"
)
// runInit runs the init RPC to set up the Kubernetes cluster.
// This function only needs to be run once per cluster.
// On success, it writes the Kubernetes admin config file to disk.
// Therefore it is skipped if the Kubernetes admin config file already exists.
func (a *applyCmd) runInit(cmd *cobra.Command, conf *config.Config, stateFile *state.State) (*bytes.Buffer, error) {
a.log.Debugf("Running init RPC")
a.log.Debugf("Creating aTLS Validator for %s", conf.GetAttestationConfig().GetVariant())
validator, err := cloudcmd.NewValidator(cmd, conf.GetAttestationConfig(), a.log)
if err != nil {
return nil, fmt.Errorf("creating new validator: %w", err)
}
a.log.Debugf("Generating master secret")
masterSecret, err := a.generateAndPersistMasterSecret(cmd.OutOrStdout())
if err != nil {
return nil, fmt.Errorf("generating master secret: %w", err)
}
a.log.Debugf("Generated master secret key and salt values")
a.log.Debugf("Generating measurement salt")
measurementSalt, err := crypto.GenerateRandomBytes(crypto.RNGLengthDefault)
if err != nil {
return nil, fmt.Errorf("generating measurement salt: %w", err)
}
a.spinner.Start("Connecting ", false)
req := &initproto.InitRequest{
KmsUri: masterSecret.EncodeToURI(),
StorageUri: uri.NoStoreURI,
MeasurementSalt: measurementSalt,
KubernetesVersion: versions.VersionConfigs[conf.KubernetesVersion].ClusterVersion,
KubernetesComponents: versions.VersionConfigs[conf.KubernetesVersion].KubernetesComponents.ToInitProto(),
ConformanceMode: a.flags.conformance,
InitSecret: stateFile.Infrastructure.InitSecret,
ClusterName: stateFile.Infrastructure.Name,
ApiserverCertSans: stateFile.Infrastructure.APIServerCertSANs,
}
a.log.Debugf("Sending initialization request")
resp, err := a.initCall(cmd.Context(), a.newDialer(validator), stateFile.Infrastructure.ClusterEndpoint, req)
a.spinner.Stop()
a.log.Debugf("Initialization request finished")
if err != nil {
var nonRetriable *nonRetriableError
if errors.As(err, &nonRetriable) {
cmd.PrintErrln("Cluster initialization failed. This error is not recoverable.")
cmd.PrintErrln("Terminate your cluster and try again.")
if nonRetriable.logCollectionErr != nil {
cmd.PrintErrf("Failed to collect logs from bootstrapper: %s\n", nonRetriable.logCollectionErr)
} else {
cmd.PrintErrf("Fetched bootstrapper logs are stored in %q\n", a.flags.pathPrefixer.PrefixPrintablePath(constants.ErrorLog))
}
}
return nil, err
}
a.log.Debugf("Initialization request successful")
a.log.Debugf("Buffering init success message")
bufferedOutput := &bytes.Buffer{}
if err := a.writeInitOutput(stateFile, resp, a.flags.mergeConfigs, bufferedOutput, measurementSalt); err != nil {
return nil, err
}
return bufferedOutput, nil
}
// initCall performs the gRPC call to the bootstrapper to initialize the cluster.
func (a *applyCmd) initCall(ctx context.Context, dialer grpcDialer, ip string, req *initproto.InitRequest) (*initproto.InitSuccessResponse, error) {
doer := &initDoer{
dialer: dialer,
endpoint: net.JoinHostPort(ip, strconv.Itoa(constants.BootstrapperPort)),
req: req,
log: a.log,
spinner: a.spinner,
fh: file.NewHandler(afero.NewOsFs()),
}
// Create a wrapper function that allows logging any returned error from the retrier before checking if it's the expected retriable one.
serviceIsUnavailable := func(err error) bool {
isServiceUnavailable := grpcRetry.ServiceIsUnavailable(err)
a.log.Debugf("Encountered error (retriable: %t): %s", isServiceUnavailable, err)
return isServiceUnavailable
}
a.log.Debugf("Making initialization call, doer is %+v", doer)
retrier := retry.NewIntervalRetrier(doer, 30*time.Second, serviceIsUnavailable)
if err := retrier.Do(ctx); err != nil {
return nil, err
}
return doer.resp, nil
}
// generateAndPersistMasterSecret generates a 32 byte master secret and saves it to disk.
func (a *applyCmd) generateAndPersistMasterSecret(outWriter io.Writer) (uri.MasterSecret, error) {
// No file given, generate a new secret, and save it to disk
key, err := crypto.GenerateRandomBytes(crypto.MasterSecretLengthDefault)
if err != nil {
return uri.MasterSecret{}, err
}
salt, err := crypto.GenerateRandomBytes(crypto.RNGLengthDefault)
if err != nil {
return uri.MasterSecret{}, err
}
secret := uri.MasterSecret{
Key: key,
Salt: salt,
}
if err := a.fileHandler.WriteJSON(constants.MasterSecretFilename, secret, file.OptNone); err != nil {
return uri.MasterSecret{}, err
}
fmt.Fprintf(outWriter, "Your Constellation master secret was successfully written to %q\n", a.flags.pathPrefixer.PrefixPrintablePath(constants.MasterSecretFilename))
return secret, nil
}
// writeInitOutput writes the output of a cluster initialization to the
// state- / kubeconfig-file and saves it to disk.
func (a *applyCmd) writeInitOutput(
stateFile *state.State, initResp *initproto.InitSuccessResponse,
mergeConfig bool, wr io.Writer, measurementSalt []byte,
) error {
fmt.Fprint(wr, "Your Constellation cluster was successfully initialized.\n\n")
ownerID := hex.EncodeToString(initResp.GetOwnerId())
clusterID := hex.EncodeToString(initResp.GetClusterId())
stateFile.SetClusterValues(state.ClusterValues{
MeasurementSalt: measurementSalt,
OwnerID: ownerID,
ClusterID: clusterID,
})
tw := tabwriter.NewWriter(wr, 0, 0, 2, ' ', 0)
writeRow(tw, "Constellation cluster identifier", clusterID)
writeRow(tw, "Kubernetes configuration", a.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename))
tw.Flush()
fmt.Fprintln(wr)
a.log.Debugf("Rewriting cluster server address in kubeconfig to %s", stateFile.Infrastructure.ClusterEndpoint)
kubeconfig, err := clientcmd.Load(initResp.GetKubeconfig())
if err != nil {
return fmt.Errorf("loading kubeconfig: %w", err)
}
if len(kubeconfig.Clusters) != 1 {
return fmt.Errorf("expected exactly one cluster in kubeconfig, got %d", len(kubeconfig.Clusters))
}
for _, cluster := range kubeconfig.Clusters {
kubeEndpoint, err := url.Parse(cluster.Server)
if err != nil {
return fmt.Errorf("parsing kubeconfig server URL: %w", err)
}
kubeEndpoint.Host = net.JoinHostPort(stateFile.Infrastructure.ClusterEndpoint, kubeEndpoint.Port())
cluster.Server = kubeEndpoint.String()
}
kubeconfigBytes, err := clientcmd.Write(*kubeconfig)
if err != nil {
return fmt.Errorf("marshaling kubeconfig: %w", err)
}
if err := a.fileHandler.Write(constants.AdminConfFilename, kubeconfigBytes, file.OptNone); err != nil {
return fmt.Errorf("writing kubeconfig: %w", err)
}
a.log.Debugf("Kubeconfig written to %s", a.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename))
if mergeConfig {
if err := a.merger.mergeConfigs(constants.AdminConfFilename, a.fileHandler); err != nil {
writeRow(tw, "Failed to automatically merge kubeconfig", err.Error())
mergeConfig = false // Set to false so we don't print the wrong message below.
} else {
writeRow(tw, "Kubernetes configuration merged with default config", "")
}
}
if err := stateFile.WriteToFile(a.fileHandler, constants.StateFilename); err != nil {
return fmt.Errorf("writing Constellation state file: %w", err)
}
a.log.Debugf("Constellation state file written to %s", a.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
if !mergeConfig {
fmt.Fprintln(wr, "You can now connect to your cluster by executing:")
exportPath, err := filepath.Abs(constants.AdminConfFilename)
if err != nil {
return fmt.Errorf("getting absolute path to kubeconfig: %w", err)
}
fmt.Fprintf(wr, "\texport KUBECONFIG=%q\n", exportPath)
} else {
fmt.Fprintln(wr, "Constellation kubeconfig merged with default config.")
if a.merger.kubeconfigEnvVar() != "" {
fmt.Fprintln(wr, "Warning: KUBECONFIG environment variable is set.")
fmt.Fprintln(wr, "You may need to unset it to use the default config and connect to your cluster.")
} else {
fmt.Fprintln(wr, "You can now connect to your cluster.")
}
}
fmt.Fprintln(wr) // add final newline
return nil
}
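initCall hands the initDoer to retry.NewIntervalRetrier together with a predicate that classifies errors as retriable. A minimal sketch of the retry contract assumed here; only the Do signature and the predicate shape are visible in this diff, so the interface name and polling details are assumptions:

package cmd

import (
	"context"
	"time"
)

// Sketch: poll d.Do until it succeeds, the context ends, or the
// predicate classifies the returned error as non-retriable.
type doer interface {
	Do(ctx context.Context) error
}

func retryWithInterval(ctx context.Context, d doer, interval time.Duration, retriable func(error) bool) error {
	for {
		err := d.Do(ctx)
		if err == nil {
			return nil
		}
		if !retriable(err) {
			return err
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(interval):
		}
	}
}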

cli/internal/cmd/applyterraform.go (new file)

@@ -0,0 +1,115 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"fmt"
"path/filepath"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/state"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/spf13/cobra"
)
// runTerraformApply checks if changes to Terraform are required and applies them.
func (a *applyCmd) runTerraformApply(cmd *cobra.Command, conf *config.Config, stateFile *state.State, upgradeDir string) error {
a.log.Debugf("Checking if Terraform migrations are required")
migrationRequired, err := a.planTerraformMigration(cmd, conf)
if err != nil {
return fmt.Errorf("planning Terraform migrations: %w", err)
}
if !migrationRequired {
a.log.Debugf("No changes to infrastructure required, skipping Terraform migrations")
return nil
}
a.log.Debugf("Migrating terraform resources for infrastructure changes")
postMigrationInfraState, err := a.migrateTerraform(cmd, conf, upgradeDir)
if err != nil {
return fmt.Errorf("performing Terraform migrations: %w", err)
}
// Merge the pre-upgrade state with the post-migration infrastructure values
a.log.Debugf("Updating state file with new infrastructure state")
if _, err := stateFile.Merge(
// temporary state with post-migration infrastructure values
state.New().SetInfrastructure(postMigrationInfraState),
); err != nil {
return fmt.Errorf("merging pre-upgrade state with post-migration infrastructure values: %w", err)
}
// Write the post-migration state to disk
if err := stateFile.WriteToFile(a.fileHandler, constants.StateFilename); err != nil {
return fmt.Errorf("writing state file: %w", err)
}
return nil
}
// planTerraformMigration checks if the Constellation version the cluster is being upgraded to requires a migration.
func (a *applyCmd) planTerraformMigration(cmd *cobra.Command, conf *config.Config) (bool, error) {
a.log.Debugf("Planning Terraform migrations")
vars, err := cloudcmd.TerraformUpgradeVars(conf)
if err != nil {
return false, fmt.Errorf("parsing upgrade variables: %w", err)
}
a.log.Debugf("Using Terraform variables:\n%+v", vars)
// Check if there are any Terraform migrations to apply
// Add manual migrations here if required
//
// var manualMigrations []terraform.StateMigration
// for _, migration := range manualMigrations {
// u.log.Debugf("Adding manual Terraform migration: %s", migration.DisplayName)
// u.upgrader.AddManualStateMigration(migration)
// }
return a.clusterUpgrader.PlanClusterUpgrade(cmd.Context(), cmd.OutOrStdout(), vars, conf.GetProvider())
}
// migrateTerraform migrates an existing Terraform state and returns the post-migration infrastructure state.
func (a *applyCmd) migrateTerraform(cmd *cobra.Command, conf *config.Config, upgradeDir string) (state.Infrastructure, error) {
// Ask for confirmation first
fmt.Fprintln(cmd.OutOrStdout(), "The upgrade requires a migration of Constellation cloud resources by applying an updated Terraform template. Please manually review the suggested changes below.")
if !a.flags.yes {
ok, err := askToConfirm(cmd, "Do you want to apply the Terraform migrations?")
if err != nil {
return state.Infrastructure{}, fmt.Errorf("asking for confirmation: %w", err)
}
if !ok {
cmd.Println("Aborting upgrade.")
// The user doesn't expect to see any changes in their workspace after aborting an "upgrade apply",
// therefore, roll back to the backed-up state.
if err := a.clusterUpgrader.RestoreClusterWorkspace(); err != nil {
return state.Infrastructure{}, fmt.Errorf(
"restoring Terraform workspace: %w, restore the Terraform workspace manually from %s ",
err,
filepath.Join(upgradeDir, constants.TerraformUpgradeBackupDir),
)
}
return state.Infrastructure{}, fmt.Errorf("cluster upgrade aborted by user")
}
}
a.log.Debugf("Applying Terraform migrations")
a.spinner.Start("Migrating Terraform resources", false)
infraState, err := a.clusterUpgrader.ApplyClusterUpgrade(cmd.Context(), conf.GetProvider())
a.spinner.Stop()
if err != nil {
return state.Infrastructure{}, fmt.Errorf("applying terraform migrations: %w", err)
}
cmd.Printf("Infrastructure migrations applied successfully and output written to: %s\n"+
"A backup of the pre-upgrade state has been written to: %s\n",
a.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename),
a.flags.pathPrefixer.PrefixPrintablePath(filepath.Join(upgradeDir, constants.TerraformUpgradeBackupDir)),
)
return infraState, nil
}

cli/internal/cmd/init.go

@@ -7,28 +7,15 @@ SPDX-License-Identifier: AGPL-3.0-only
package cmd
import (
"bytes"
"context"
"encoding/hex"
"errors"
"fmt"
"io"
"net"
"net/url"
"os"
"path/filepath"
"strconv"
"sync"
"text/tabwriter"
"time"
"github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi"
"github.com/edgelesssys/constellation/v2/internal/atls"
"github.com/edgelesssys/constellation/v2/internal/attestation/variant"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"google.golang.org/grpc"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/tools/clientcmd"
@@ -36,21 +23,13 @@ import (
"sigs.k8s.io/yaml"
"github.com/edgelesssys/constellation/v2/bootstrapper/initproto"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/helm"
"github.com/edgelesssys/constellation/v2/cli/internal/kubecmd"
"github.com/edgelesssys/constellation/v2/cli/internal/state"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/crypto"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
"github.com/edgelesssys/constellation/v2/internal/grpc/grpclog"
grpcRetry "github.com/edgelesssys/constellation/v2/internal/grpc/retry"
"github.com/edgelesssys/constellation/v2/internal/kms/uri"
"github.com/edgelesssys/constellation/v2/internal/license"
"github.com/edgelesssys/constellation/v2/internal/retry"
"github.com/edgelesssys/constellation/v2/internal/versions"
)
// NewInitCmd returns a new cobra.Command for the init command.
@@ -61,7 +40,15 @@ func NewInitCmd() *cobra.Command {
Long: "Initialize the Constellation cluster.\n\n" +
"Start your confidential Kubernetes.",
Args: cobra.ExactArgs(0),
RunE: runInitialize,
RunE: func(cmd *cobra.Command, args []string) error {
// Define flags for apply backend that are not set by init
cmd.Flags().Bool("yes", false, "")
// Don't skip any phases
// The apply backend should handle init calls correctly
cmd.Flags().StringSlice("skip-phases", []string{}, "")
cmd.Flags().Duration("timeout", time.Hour, "")
return runApply(cmd, args)
},
}
cmd.Flags().Bool("conformance", false, "enable conformance mode")
cmd.Flags().Bool("skip-helm-wait", false, "install helm charts without waiting for deployments to be ready")
@@ -69,270 +56,6 @@ func NewInitCmd() *cobra.Command {
return cmd
}
// initFlags are flags used by the init command.
type initFlags struct {
rootFlags
conformance bool
helmWaitMode helm.WaitMode
mergeConfigs bool
}
func (f *initFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
}
skipHelmWait, err := flags.GetBool("skip-helm-wait")
if err != nil {
return fmt.Errorf("getting 'skip-helm-wait' flag: %w", err)
}
f.helmWaitMode = helm.WaitModeAtomic
if skipHelmWait {
f.helmWaitMode = helm.WaitModeNone
}
f.conformance, err = flags.GetBool("conformance")
if err != nil {
return fmt.Errorf("getting 'conformance' flag: %w", err)
}
f.mergeConfigs, err = flags.GetBool("merge-kubeconfig")
if err != nil {
return fmt.Errorf("getting 'merge-kubeconfig' flag: %w", err)
}
return nil
}
type initCmd struct {
log debugLog
merger configMerger
spinner spinnerInterf
fileHandler file.Handler
flags initFlags
}
func newInitCmd(fileHandler file.Handler, spinner spinnerInterf, merger configMerger, log debugLog) *initCmd {
return &initCmd{
log: log,
merger: merger,
spinner: spinner,
fileHandler: fileHandler,
}
}
// runInitialize runs the initialize command.
func runInitialize(cmd *cobra.Command, _ []string) error {
log, err := newCLILogger(cmd)
if err != nil {
return fmt.Errorf("creating logger: %w", err)
}
defer log.Sync()
fileHandler := file.NewHandler(afero.NewOsFs())
newDialer := func(validator atls.Validator) *dialer.Dialer {
return dialer.New(nil, validator, &net.Dialer{})
}
spinner, err := newSpinnerOrStderr(cmd)
if err != nil {
return err
}
defer spinner.Stop()
ctx, cancel := context.WithTimeout(cmd.Context(), time.Hour)
defer cancel()
cmd.SetContext(ctx)
i := newInitCmd(fileHandler, spinner, &kubeconfigMerger{log: log}, log)
if err := i.flags.parse(cmd.Flags()); err != nil {
return err
}
i.log.Debugf("Using flags: %+v", i.flags)
fetcher := attestationconfigapi.NewFetcher()
newAttestationApplier := func(w io.Writer, kubeConfig string, log debugLog) (attestationConfigApplier, error) {
return kubecmd.New(w, kubeConfig, fileHandler, log)
}
newHelmClient := func(kubeConfigPath string, log debugLog) (helmApplier, error) {
return helm.NewClient(kubeConfigPath, log)
} // need to defer helm client instantiation until kubeconfig is available
return i.initialize(cmd, newDialer, license.NewClient(), fetcher, newAttestationApplier, newHelmClient)
}
// initialize initializes a Constellation.
func (i *initCmd) initialize(
cmd *cobra.Command, newDialer func(validator atls.Validator) *dialer.Dialer,
quotaChecker license.QuotaChecker, configFetcher attestationconfigapi.Fetcher,
newAttestationApplier func(io.Writer, string, debugLog) (attestationConfigApplier, error),
newHelmClient func(kubeConfigPath string, log debugLog) (helmApplier, error),
) error {
i.log.Debugf("Loading configuration file from %q", i.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
conf, err := config.New(i.fileHandler, constants.ConfigFilename, configFetcher, i.flags.force)
var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage())
}
if err != nil {
return err
}
// cfg validation does not check k8s patch version since upgrade may accept an outdated patch version.
k8sVersion, err := versions.NewValidK8sVersion(string(conf.KubernetesVersion), true)
if err != nil {
return err
}
if !i.flags.force {
if err := validateCLIandConstellationVersionAreEqual(constants.BinaryVersion(), conf.Image, conf.MicroserviceVersion); err != nil {
return err
}
}
if conf.GetAttestationConfig().GetVariant().Equal(variant.AWSSEVSNP{}) {
cmd.PrintErrln("WARNING: Attestation temporarily relies on AWS nitroTPM. See https://docs.edgeless.systems/constellation/workflows/config#choosing-a-vm-type for more information.")
}
stateFile, err := state.ReadFromFile(i.fileHandler, constants.StateFilename)
if err != nil {
return fmt.Errorf("reading state file: %w", err)
}
i.log.Debugf("Validated k8s version as %s", k8sVersion)
if versions.IsPreviewK8sVersion(k8sVersion) {
cmd.PrintErrf("Warning: Constellation with Kubernetes %v is still in preview. Use only for evaluation purposes.\n", k8sVersion)
}
provider := conf.GetProvider()
i.log.Debugf("Got provider %s", provider.String())
checker := license.NewChecker(quotaChecker, i.fileHandler)
if err := checker.CheckLicense(cmd.Context(), provider, conf.Provider, cmd.Printf); err != nil {
cmd.PrintErrf("License check failed: %v", err)
}
i.log.Debugf("Checked license")
if stateFile.Infrastructure.Azure != nil {
conf.UpdateMAAURL(stateFile.Infrastructure.Azure.AttestationURL)
}
i.log.Debugf("Creating aTLS Validator for %s", conf.GetAttestationConfig().GetVariant())
validator, err := cloudcmd.NewValidator(cmd, conf.GetAttestationConfig(), i.log)
if err != nil {
return fmt.Errorf("creating new validator: %w", err)
}
i.log.Debugf("Created a new validator")
serviceAccURI, err := cloudcmd.GetMarshaledServiceAccountURI(provider, conf, i.flags.pathPrefixer, i.log, i.fileHandler)
if err != nil {
return err
}
i.log.Debugf("Successfully marshaled service account URI")
i.log.Debugf("Generating master secret")
masterSecret, err := i.generateMasterSecret(cmd.OutOrStdout())
if err != nil {
return fmt.Errorf("generating master secret: %w", err)
}
i.log.Debugf("Generating measurement salt")
measurementSalt, err := crypto.GenerateRandomBytes(crypto.RNGLengthDefault)
if err != nil {
return fmt.Errorf("generating measurement salt: %w", err)
}
i.log.Debugf("Setting cluster name to %s", stateFile.Infrastructure.Name)
cmd.PrintErrln("Note: If you just created the cluster, it can take a few minutes to connect.")
i.spinner.Start("Connecting ", false)
req := &initproto.InitRequest{
KmsUri: masterSecret.EncodeToURI(),
StorageUri: uri.NoStoreURI,
MeasurementSalt: measurementSalt,
KubernetesVersion: versions.VersionConfigs[k8sVersion].ClusterVersion,
KubernetesComponents: versions.VersionConfigs[k8sVersion].KubernetesComponents.ToInitProto(),
ConformanceMode: i.flags.conformance,
InitSecret: stateFile.Infrastructure.InitSecret,
ClusterName: stateFile.Infrastructure.Name,
ApiserverCertSans: stateFile.Infrastructure.APIServerCertSANs,
}
i.log.Debugf("Sending initialization request")
resp, err := i.initCall(cmd.Context(), newDialer(validator), stateFile.Infrastructure.ClusterEndpoint, req)
i.spinner.Stop()
if err != nil {
var nonRetriable *nonRetriableError
if errors.As(err, &nonRetriable) {
cmd.PrintErrln("Cluster initialization failed. This error is not recoverable.")
cmd.PrintErrln("Terminate your cluster and try again.")
if nonRetriable.logCollectionErr != nil {
cmd.PrintErrf("Failed to collect logs from bootstrapper: %s\n", nonRetriable.logCollectionErr)
} else {
cmd.PrintErrf("Fetched bootstrapper logs are stored in %q\n", i.flags.pathPrefixer.PrefixPrintablePath(constants.ErrorLog))
}
}
return err
}
i.log.Debugf("Initialization request succeeded")
bufferedOutput := &bytes.Buffer{}
if err := i.writeOutput(stateFile, resp, i.flags.mergeConfigs, bufferedOutput, measurementSalt); err != nil {
return err
}
attestationApplier, err := newAttestationApplier(cmd.OutOrStdout(), constants.AdminConfFilename, i.log)
if err != nil {
return err
}
if err := attestationApplier.ApplyJoinConfig(cmd.Context(), conf.GetAttestationConfig(), measurementSalt); err != nil {
return fmt.Errorf("applying attestation config: %w", err)
}
i.spinner.Start("Installing Kubernetes components ", false)
options := helm.Options{
Force: i.flags.force,
Conformance: i.flags.conformance,
HelmWaitMode: i.flags.helmWaitMode,
AllowDestructive: helm.DenyDestructive,
}
helmApplier, err := newHelmClient(constants.AdminConfFilename, i.log)
if err != nil {
return fmt.Errorf("creating Helm client: %w", err)
}
executor, includesUpgrades, err := helmApplier.PrepareApply(conf, stateFile, options, serviceAccURI, masterSecret)
if err != nil {
return fmt.Errorf("getting Helm chart executor: %w", err)
}
if includesUpgrades {
return errors.New("init: helm tried to upgrade charts instead of installing them")
}
if err := executor.Apply(cmd.Context()); err != nil {
return fmt.Errorf("applying Helm charts: %w", err)
}
i.spinner.Stop()
i.log.Debugf("Helm deployment installation succeeded")
cmd.Println(bufferedOutput.String())
return nil
}
func (i *initCmd) initCall(ctx context.Context, dialer grpcDialer, ip string, req *initproto.InitRequest) (*initproto.InitSuccessResponse, error) {
doer := &initDoer{
dialer: dialer,
endpoint: net.JoinHostPort(ip, strconv.Itoa(constants.BootstrapperPort)),
req: req,
log: i.log,
spinner: i.spinner,
fh: file.NewHandler(afero.NewOsFs()),
}
// Create a wrapper function that allows logging any returned error from the retrier before checking if it's the expected retriable one.
serviceIsUnavailable := func(err error) bool {
isServiceUnavailable := grpcRetry.ServiceIsUnavailable(err)
i.log.Debugf("Encountered error (retriable: %t): %s", isServiceUnavailable, err)
return isServiceUnavailable
}
i.log.Debugf("Making initialization call, doer is %+v", doer)
retrier := retry.NewIntervalRetrier(doer, 30*time.Second, serviceIsUnavailable)
if err := retrier.Do(ctx); err != nil {
return nil, err
}
return doer.resp, nil
}
type initDoer struct {
dialer grpcDialer
endpoint string
@@ -469,122 +192,10 @@ func (d *initDoer) handleGRPCStateChanges(ctx context.Context, wg *sync.WaitGrou
})
}
// writeOutput writes the output of a cluster initialization to the
// state- / id- / kubeconfig-file and saves it to disk.
func (i *initCmd) writeOutput(
stateFile *state.State,
initResp *initproto.InitSuccessResponse,
mergeConfig bool, wr io.Writer,
measurementSalt []byte,
) error {
fmt.Fprint(wr, "Your Constellation cluster was successfully initialized.\n\n")
ownerID := hex.EncodeToString(initResp.GetOwnerId())
clusterID := hex.EncodeToString(initResp.GetClusterId())
stateFile.SetClusterValues(state.ClusterValues{
MeasurementSalt: measurementSalt,
OwnerID: ownerID,
ClusterID: clusterID,
})
tw := tabwriter.NewWriter(wr, 0, 0, 2, ' ', 0)
writeRow(tw, "Constellation cluster identifier", clusterID)
writeRow(tw, "Kubernetes configuration", i.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename))
tw.Flush()
fmt.Fprintln(wr)
i.log.Debugf("Rewriting cluster server address in kubeconfig to %s", stateFile.Infrastructure.ClusterEndpoint)
kubeconfig, err := clientcmd.Load(initResp.GetKubeconfig())
if err != nil {
return fmt.Errorf("loading kubeconfig: %w", err)
}
if len(kubeconfig.Clusters) != 1 {
return fmt.Errorf("expected exactly one cluster in kubeconfig, got %d", len(kubeconfig.Clusters))
}
for _, cluster := range kubeconfig.Clusters {
kubeEndpoint, err := url.Parse(cluster.Server)
if err != nil {
return fmt.Errorf("parsing kubeconfig server URL: %w", err)
}
kubeEndpoint.Host = net.JoinHostPort(stateFile.Infrastructure.ClusterEndpoint, kubeEndpoint.Port())
cluster.Server = kubeEndpoint.String()
}
kubeconfigBytes, err := clientcmd.Write(*kubeconfig)
if err != nil {
return fmt.Errorf("marshaling kubeconfig: %w", err)
}
if err := i.fileHandler.Write(constants.AdminConfFilename, kubeconfigBytes, file.OptNone); err != nil {
return fmt.Errorf("writing kubeconfig: %w", err)
}
i.log.Debugf("Kubeconfig written to %s", i.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename))
if mergeConfig {
if err := i.merger.mergeConfigs(constants.AdminConfFilename, i.fileHandler); err != nil {
writeRow(tw, "Failed to automatically merge kubeconfig", err.Error())
mergeConfig = false // Set to false so we don't print the wrong message below.
} else {
writeRow(tw, "Kubernetes configuration merged with default config", "")
}
}
if err := stateFile.WriteToFile(i.fileHandler, constants.StateFilename); err != nil {
return fmt.Errorf("writing Constellation state file: %w", err)
}
i.log.Debugf("Constellation state file written to %s", i.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
if !mergeConfig {
fmt.Fprintln(wr, "You can now connect to your cluster by executing:")
exportPath, err := filepath.Abs(constants.AdminConfFilename)
if err != nil {
return fmt.Errorf("getting absolute path to kubeconfig: %w", err)
}
fmt.Fprintf(wr, "\texport KUBECONFIG=%q\n", exportPath)
} else {
fmt.Fprintln(wr, "Constellation kubeconfig merged with default config.")
if i.merger.kubeconfigEnvVar() != "" {
fmt.Fprintln(wr, "Warning: KUBECONFIG environment variable is set.")
fmt.Fprintln(wr, "You may need to unset it to use the default config and connect to your cluster.")
} else {
fmt.Fprintln(wr, "You can now connect to your cluster.")
}
}
return nil
}
func writeRow(wr io.Writer, col1 string, col2 string) {
fmt.Fprint(wr, col1, "\t", col2, "\n")
}
// generateMasterSecret reads a base64 encoded master secret from file or generates a new 32 byte secret.
func (i *initCmd) generateMasterSecret(outWriter io.Writer) (uri.MasterSecret, error) {
// No file given, generate a new secret, and save it to disk
i.log.Debugf("Generating new master secret")
key, err := crypto.GenerateRandomBytes(crypto.MasterSecretLengthDefault)
if err != nil {
return uri.MasterSecret{}, err
}
salt, err := crypto.GenerateRandomBytes(crypto.RNGLengthDefault)
if err != nil {
return uri.MasterSecret{}, err
}
secret := uri.MasterSecret{
Key: key,
Salt: salt,
}
i.log.Debugf("Generated master secret key and salt values")
if err := i.fileHandler.WriteJSON(constants.MasterSecretFilename, secret, file.OptNone); err != nil {
return uri.MasterSecret{}, err
}
fmt.Fprintf(outWriter, "Your Constellation master secret was successfully written to %q\n", i.flags.pathPrefixer.PrefixPrintablePath(constants.MasterSecretFilename))
return secret, nil
}
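
The generate-and-persist flow above boils down to two reads from a CSPRNG and a JSON write; encoding/json base64-encodes []byte fields, which is where the base64 encoding on disk comes from. A self-contained sketch using only the standard library (file name and field tags are stand-ins for the project's constants and types):

package main

import (
	"crypto/rand"
	"encoding/json"
	"fmt"
	"os"
)

// masterSecret mirrors the persisted shape; []byte fields are
// automatically base64 encoded by encoding/json.
type masterSecret struct {
	Key  []byte `json:"key"`
	Salt []byte `json:"salt"`
}

func main() {
	key := make([]byte, 32)  // stand-in for crypto.MasterSecretLengthDefault
	salt := make([]byte, 32) // stand-in for crypto.RNGLengthDefault
	if _, err := rand.Read(key); err != nil {
		panic(err)
	}
	if _, err := rand.Read(salt); err != nil {
		panic(err)
	}
	out, err := json.Marshal(masterSecret{Key: key, Salt: salt})
	if err != nil {
		panic(err)
	}
	// 0o600: the secret must not be readable by other users.
	if err := os.WriteFile("mastersecret.json", out, 0o600); err != nil {
		panic(err)
	}
	fmt.Println("master secret written")
}
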
type configMerger interface {
mergeConfigs(configPath string, fileHandler file.Handler) error
kubeconfigEnvVar() string
@ -657,10 +268,6 @@ func (e *nonRetriableError) Unwrap() error {
return e.err
}
type attestationConfigApplier interface {
ApplyJoinConfig(ctx context.Context, newAttestConfig config.AttestationCfg, measurementSalt []byte) error
}
type helmApplier interface {
PrepareApply(conf *config.Config, stateFile *state.State,
flags helm.Options, serviceAccURI string, masterSecret uri.MasterSecret) (

View File

@ -44,6 +44,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
k8serrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/client-go/tools/clientcmd"
k8sclientapi "k8s.io/client-go/tools/clientcmd/api"
)
@ -58,8 +60,6 @@ func TestInitArgumentValidation(t *testing.T) {
}
func TestInitialize(t *testing.T) {
require := require.New(t)
respKubeconfig := k8sclientapi.Config{
Clusters: map[string]*k8sclientapi.Cluster{
"cluster": {
@ -68,7 +68,7 @@ func TestInitialize(t *testing.T) {
},
}
respKubeconfigBytes, err := clientcmd.Write(respKubeconfig)
require.NoError(err)
require.NoError(t, err)
gcpServiceAccKey := &gcpshared.ServiceAccountKey{
Type: "service_account",
@ -149,31 +149,47 @@ func TestInitialize(t *testing.T) {
masterSecretShouldExist: true,
wantErr: true,
},
"state file with only version": {
provider: cloudprovider.GCP,
stateFile: &state.State{Version: state.Version1},
initServerAPI: &stubInitServer{},
retriable: true,
wantErr: true,
},
"empty state file": {
provider: cloudprovider.GCP,
stateFile: &state.State{},
initServerAPI: &stubInitServer{},
retriable: true,
wantErr: true,
},
/*
Tests currently disabled since we don't actually have validation for the state file yet
These test cases only passed in the past because of unrelated errors in the test setup
TODO(AB#3492): Re-enable tests once state file validation is implemented
"state file with only version": {
provider: cloudprovider.GCP,
stateFile: &state.State{Version: state.Version1},
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
serviceAccKey: gcpServiceAccKey,
initServerAPI: &stubInitServer{},
retriable: true,
wantErr: true,
},
"empty state file": {
provider: cloudprovider.GCP,
stateFile: &state.State{},
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
serviceAccKey: gcpServiceAccKey,
initServerAPI: &stubInitServer{},
retriable: true,
wantErr: true,
},
*/
"no state file": {
provider: cloudprovider.GCP,
retriable: true,
wantErr: true,
provider: cloudprovider.GCP,
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
serviceAccKey: gcpServiceAccKey,
retriable: true,
wantErr: true,
},
"init call fails": {
provider: cloudprovider.GCP,
stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
initServerAPI: &stubInitServer{initErr: assert.AnError},
retriable: true,
wantErr: true,
provider: cloudprovider.GCP,
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
serviceAccKey: gcpServiceAccKey,
initServerAPI: &stubInitServer{initErr: assert.AnError},
retriable: false,
masterSecretShouldExist: true,
wantErr: true,
},
"k8s version without v works": {
provider: cloudprovider.Azure,
@ -181,7 +197,7 @@ func TestInitialize(t *testing.T) {
initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}},
configMutator: func(c *config.Config) {
res, err := versions.NewValidK8sVersion(strings.TrimPrefix(string(versions.Default), "v"), true)
require.NoError(err)
require.NoError(t, err)
c.KubernetesVersion = res
},
},
@ -191,7 +207,7 @@ func TestInitialize(t *testing.T) {
initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}},
configMutator: func(c *config.Config) {
v, err := semver.New(versions.SupportedK8sVersions()[0])
require.NoError(err)
require.NoError(t, err)
outdatedPatchVer := semver.NewFromInt(v.Major(), v.Minor(), v.Patch()-1, "").String()
c.KubernetesVersion = versions.ValidK8sVersion(outdatedPatchVer)
},
@ -203,6 +219,7 @@ func TestInitialize(t *testing.T) {
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
// Networking
netDialer := testdialer.NewBufconnDialer()
newDialer := func(atls.Validator) *dialer.Dialer {
@ -231,8 +248,6 @@ func TestInitialize(t *testing.T) {
tc.configMutator(config)
}
require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, config, file.OptNone))
stateFile := state.New()
require.NoError(stateFile.WriteToFile(fileHandler, constants.StateFilename))
if tc.stateFile != nil {
require.NoError(tc.stateFile.WriteToFile(fileHandler, constants.StateFilename))
}
@ -244,20 +259,28 @@ func TestInitialize(t *testing.T) {
ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
defer cancel()
cmd.SetContext(ctx)
i := newInitCmd(fileHandler, &nopSpinner{}, nil, logger.NewTest(t))
i.flags.force = true
err := i.initialize(
cmd,
newDialer,
&stubLicenseClient{},
stubAttestationFetcher{},
func(io.Writer, string, debugLog) (attestationConfigApplier, error) {
return &stubAttestationApplier{}, nil
},
func(_ string, _ debugLog) (helmApplier, error) {
i := &applyCmd{
fileHandler: fileHandler,
flags: applyFlags{rootFlags: rootFlags{force: true}},
log: logger.NewTest(t),
spinner: &nopSpinner{},
merger: &stubMerger{},
quotaChecker: &stubLicenseClient{},
newHelmClient: func(string, debugLog) (helmApplier, error) {
return &stubApplier{}, nil
})
},
newDialer: newDialer,
newKubeUpgrader: func(io.Writer, string, debugLog) (kubernetesUpgrader, error) {
return &stubKubernetesUpgrader{
// On init, no attestation config exists yet
getClusterAttestationConfigErr: k8serrors.NewNotFound(schema.GroupResource{}, ""),
}, nil
},
clusterUpgrader: stubTerraformUpgrader{},
}
err := i.apply(cmd, stubAttestationFetcher{}, "test")
if tc.wantErr {
assert.Error(err)
@ -291,14 +314,17 @@ func (s stubApplier) PrepareApply(_ *config.Config, _ *state.State, _ helm.Optio
return stubRunner{}, false, s.err
}
type stubRunner struct{}
type stubRunner struct {
applyErr error
saveChartsErr error
}
func (s stubRunner) Apply(_ context.Context) error {
return nil
return s.applyErr
}
func (s stubRunner) SaveCharts(_ string, _ file.Handler) error {
return nil
return s.saveChartsErr
}
func TestGetLogs(t *testing.T) {
@ -420,8 +446,13 @@ func TestWriteOutput(t *testing.T) {
ClusterEndpoint: clusterEndpoint,
})
i := newInitCmd(fileHandler, &nopSpinner{}, &stubMerger{}, logger.NewTest(t))
err = i.writeOutput(stateFile, resp.GetInitSuccess(), false, &out, measurementSalt)
i := &applyCmd{
fileHandler: fileHandler,
spinner: &nopSpinner{},
merger: &stubMerger{},
log: logger.NewTest(t),
}
err = i.writeInitOutput(stateFile, resp.GetInitSuccess(), false, &out, measurementSalt)
require.NoError(err)
assert.Contains(out.String(), clusterID)
assert.Contains(out.String(), constants.AdminConfFilename)
@ -441,7 +472,7 @@ func TestWriteOutput(t *testing.T) {
// test custom workspace
i.flags.pathPrefixer = pathprefix.New("/some/path")
err = i.writeOutput(stateFile, resp.GetInitSuccess(), true, &out, measurementSalt)
err = i.writeInitOutput(stateFile, resp.GetInitSuccess(), true, &out, measurementSalt)
require.NoError(err)
assert.Contains(out.String(), clusterID)
assert.Contains(out.String(), i.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename))
@ -451,7 +482,7 @@ func TestWriteOutput(t *testing.T) {
i.flags.pathPrefixer = pathprefix.PathPrefixer{}
// test config merging
err = i.writeOutput(stateFile, resp.GetInitSuccess(), true, &out, measurementSalt)
err = i.writeInitOutput(stateFile, resp.GetInitSuccess(), true, &out, measurementSalt)
require.NoError(err)
assert.Contains(out.String(), clusterID)
assert.Contains(out.String(), constants.AdminConfFilename)
@ -462,7 +493,7 @@ func TestWriteOutput(t *testing.T) {
// test config merging with env vars set
i.merger = &stubMerger{envVar: "/some/path/to/kubeconfig"}
err = i.writeOutput(stateFile, resp.GetInitSuccess(), true, &out, measurementSalt)
err = i.writeInitOutput(stateFile, resp.GetInitSuccess(), true, &out, measurementSalt)
require.NoError(err)
assert.Contains(out.String(), clusterID)
assert.Contains(out.String(), constants.AdminConfFilename)
@ -508,8 +539,11 @@ func TestGenerateMasterSecret(t *testing.T) {
require.NoError(tc.createFileFunc(fileHandler))
var out bytes.Buffer
i := newInitCmd(fileHandler, nil, nil, logger.NewTest(t))
secret, err := i.generateMasterSecret(&out)
i := &applyCmd{
fileHandler: fileHandler,
log: logger.NewTest(t),
}
secret, err := i.generateAndPersistMasterSecret(&out)
if tc.wantErr {
assert.Error(err)
@ -601,13 +635,17 @@ func TestAttestation(t *testing.T) {
defer cancel()
cmd.SetContext(ctx)
i := newInitCmd(fileHandler, &nopSpinner{}, nil, logger.NewTest(t))
err := i.initialize(cmd, newDialer, &stubLicenseClient{}, stubAttestationFetcher{},
func(io.Writer, string, debugLog) (attestationConfigApplier, error) {
return &stubAttestationApplier{}, nil
}, func(_ string, _ debugLog) (helmApplier, error) {
return &stubApplier{}, nil
})
i := &applyCmd{
fileHandler: fileHandler,
spinner: &nopSpinner{},
merger: &stubMerger{},
log: logger.NewTest(t),
newKubeUpgrader: func(io.Writer, string, debugLog) (kubernetesUpgrader, error) {
return &stubKubernetesUpgrader{}, nil
},
newDialer: newDialer,
}
_, err := i.runInit(cmd, cfg, existingStateFile)
assert.Error(err)
// make sure the error is actually a TLS handshake error
assert.Contains(err.Error(), "transport: authentication handshake failed")
@ -773,11 +811,3 @@ func (c stubInitClient) Recv() (*initproto.InitResponse, error) {
return res, err
}
type stubAttestationApplier struct {
applyErr error
}
func (a *stubAttestationApplier) ApplyJoinConfig(context.Context, config.AttestationCfg, []byte) error {
return a.applyErr
}

View File

@ -10,23 +10,17 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"time"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/featureset"
"github.com/edgelesssys/constellation/v2/cli/internal/helm"
"github.com/edgelesssys/constellation/v2/cli/internal/kubecmd"
"github.com/edgelesssys/constellation/v2/cli/internal/libvirt"
"github.com/edgelesssys/constellation/v2/cli/internal/state"
"github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi"
"github.com/edgelesssys/constellation/v2/internal/atls"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
"github.com/edgelesssys/constellation/v2/internal/license"
"github.com/spf13/afero"
"github.com/spf13/cobra"
)
@ -105,7 +99,7 @@ func (m *miniUpCmd) up(cmd *cobra.Command, creator cloudCreator, spinner spinner
cmd.Printf("\tvirsh -c %s\n\n", connectURI)
// initialize cluster
if err := m.initializeMiniCluster(cmd, spinner); err != nil {
if err := m.initializeMiniCluster(cmd); err != nil {
return fmt.Errorf("initializing cluster: %w", err)
}
m.log.Debugf("Initialized cluster")
@ -188,7 +182,7 @@ func (m *miniUpCmd) createMiniCluster(ctx context.Context, creator cloudCreator,
}
// initializeMiniCluster initializes a QEMU cluster.
func (m *miniUpCmd) initializeMiniCluster(cmd *cobra.Command, spinner spinnerInterf) (retErr error) {
func (m *miniUpCmd) initializeMiniCluster(cmd *cobra.Command) (retErr error) {
m.log.Debugf("Initializing mini cluster")
// clean up cluster resources if initialization fails
defer func() {
@ -199,34 +193,19 @@ func (m *miniUpCmd) initializeMiniCluster(cmd *cobra.Command, spinner spinnerInt
cmd.PrintErrf("Rollback succeeded.\n\n")
}
}()
newDialer := func(validator atls.Validator) *dialer.Dialer {
return dialer.New(nil, validator, &net.Dialer{})
}
m.log.Debugf("Created new dialer")
cmd.Flags().String("endpoint", "", "")
// Define flags for apply backend that are not set by mini up
cmd.Flags().StringSlice(
"skip-phases",
[]string{string(skipInfrastructurePhase), string(skipK8sPhase), string(skipImagePhase)},
"",
)
cmd.Flags().Bool("yes", false, "")
cmd.Flags().Bool("skip-helm-wait", false, "")
cmd.Flags().Bool("conformance", false, "")
cmd.Flags().Bool("skip-helm-wait", false, "install helm charts without waiting for deployments to be ready")
log, err := newCLILogger(cmd)
if err != nil {
return fmt.Errorf("creating logger: %w", err)
}
m.log.Debugf("Created new logger")
defer log.Sync()
cmd.Flags().Duration("timeout", time.Hour, "")
newAttestationApplier := func(w io.Writer, kubeConfig string, log debugLog) (attestationConfigApplier, error) {
return kubecmd.New(w, kubeConfig, m.fileHandler, log)
}
newHelmClient := func(kubeConfigPath string, log debugLog) (helmApplier, error) {
return helm.NewClient(kubeConfigPath, log)
} // need to defer helm client instantiation until kubeconfig is available
i := newInitCmd(m.fileHandler, spinner, &kubeconfigMerger{log: log}, log)
if err := i.flags.parse(cmd.Flags()); err != nil {
return err
}
if err := i.initialize(cmd, newDialer, license.NewClient(), m.configFetcher,
newAttestationApplier, newHelmClient); err != nil {
if err := runApply(cmd, nil); err != nil {
return err
}
m.log.Debugf("Initialized mini cluster")

View File

@ -8,36 +8,25 @@ package cmd
import (
"context"
"errors"
"fmt"
"io"
"path/filepath"
"strings"
"time"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/helm"
"github.com/edgelesssys/constellation/v2/cli/internal/kubecmd"
"github.com/edgelesssys/constellation/v2/cli/internal/state"
"github.com/edgelesssys/constellation/v2/cli/internal/terraform"
"github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi"
"github.com/edgelesssys/constellation/v2/internal/attestation/variant"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/compatibility"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/kms/uri"
"github.com/edgelesssys/constellation/v2/internal/versions"
"github.com/rogpeppe/go-internal/diff"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"gopkg.in/yaml.v3"
apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1"
)
const (
// skipInitPhase skips the init RPC of the apply process.
skipInitPhase skipPhase = "init"
// skipInfrastructurePhase skips the terraform apply of the upgrade process.
skipInfrastructurePhase skipPhase = "infrastructure"
// skipHelmPhase skips the helm upgrade of the upgrade process.
@ -57,7 +46,11 @@ func newUpgradeApplyCmd() *cobra.Command {
Short: "Apply an upgrade to a Constellation cluster",
Long: "Apply an upgrade to a Constellation cluster by applying the chosen configuration.",
Args: cobra.NoArgs,
RunE: runUpgradeApply,
RunE: func(cmd *cobra.Command, args []string) error {
// Define flags for apply backend that are not set by upgrade-apply
cmd.Flags().Bool("merge-kubeconfig", false, "")
return runApply(cmd, args)
},
}
cmd.Flags().BoolP("yes", "y", false, "run upgrades without further confirmation\n"+
@ -69,238 +62,11 @@ func newUpgradeApplyCmd() *cobra.Command {
cmd.Flags().Bool("skip-helm-wait", false, "install helm charts without waiting for deployments to be ready")
cmd.Flags().StringSlice("skip-phases", nil, "comma-separated list of upgrade phases to skip\n"+
"one or multiple of { infrastructure | helm | image | k8s }")
if err := cmd.Flags().MarkHidden("timeout"); err != nil {
panic(err)
}
must(cmd.Flags().MarkHidden("timeout"))
return cmd
}
type upgradeApplyFlags struct {
rootFlags
yes bool
upgradeTimeout time.Duration
conformance bool
helmWaitMode helm.WaitMode
skipPhases skipPhases
}
func (f *upgradeApplyFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
}
rawSkipPhases, err := flags.GetStringSlice("skip-phases")
if err != nil {
return fmt.Errorf("parsing skip-phases flag: %w", err)
}
var skipPhases []skipPhase
for _, phase := range rawSkipPhases {
switch skipPhase(phase) {
case skipInfrastructurePhase, skipHelmPhase, skipImagePhase, skipK8sPhase:
skipPhases = append(skipPhases, skipPhase(phase))
default:
return fmt.Errorf("invalid phase %s", phase)
}
}
f.skipPhases = skipPhases
f.yes, err = flags.GetBool("yes")
if err != nil {
return fmt.Errorf("getting 'yes' flag: %w", err)
}
f.upgradeTimeout, err = flags.GetDuration("timeout")
if err != nil {
return fmt.Errorf("getting 'timeout' flag: %w", err)
}
f.conformance, err = flags.GetBool("conformance")
if err != nil {
return fmt.Errorf("getting 'conformance' flag: %w", err)
}
skipHelmWait, err := flags.GetBool("skip-helm-wait")
if err != nil {
return fmt.Errorf("getting 'skip-helm-wait' flag: %w", err)
}
f.helmWaitMode = helm.WaitModeAtomic
if skipHelmWait {
f.helmWaitMode = helm.WaitModeNone
}
return nil
}
func runUpgradeApply(cmd *cobra.Command, _ []string) error {
log, err := newCLILogger(cmd)
if err != nil {
return fmt.Errorf("creating logger: %w", err)
}
defer log.Sync()
fileHandler := file.NewHandler(afero.NewOsFs())
upgradeID := generateUpgradeID(upgradeCmdKindApply)
kubeUpgrader, err := kubecmd.New(cmd.OutOrStdout(), constants.AdminConfFilename, fileHandler, log)
if err != nil {
return err
}
configFetcher := attestationconfigapi.NewFetcher()
var flags upgradeApplyFlags
if err := flags.parse(cmd.Flags()); err != nil {
return err
}
// Set up terraform upgrader
upgradeDir := filepath.Join(constants.UpgradeDir, upgradeID)
clusterUpgrader, err := cloudcmd.NewClusterUpgrader(
cmd.Context(),
constants.TerraformWorkingDir,
upgradeDir,
flags.tfLogLevel,
fileHandler,
)
if err != nil {
return fmt.Errorf("setting up cluster upgrader: %w", err)
}
helmClient, err := helm.NewClient(constants.AdminConfFilename, log)
if err != nil {
return fmt.Errorf("creating Helm client: %w", err)
}
applyCmd := upgradeApplyCmd{
kubeUpgrader: kubeUpgrader,
helmApplier: helmClient,
clusterUpgrader: clusterUpgrader,
configFetcher: configFetcher,
fileHandler: fileHandler,
flags: flags,
log: log,
}
return applyCmd.upgradeApply(cmd, upgradeDir)
}
type upgradeApplyCmd struct {
helmApplier helmApplier
kubeUpgrader kubernetesUpgrader
clusterUpgrader clusterUpgrader
configFetcher attestationconfigapi.Fetcher
fileHandler file.Handler
flags upgradeApplyFlags
log debugLog
}
func (u *upgradeApplyCmd) upgradeApply(cmd *cobra.Command, upgradeDir string) error {
conf, err := config.New(u.fileHandler, constants.ConfigFilename, u.configFetcher, u.flags.force)
var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage())
}
if err != nil {
return err
}
if cloudcmd.UpgradeRequiresIAMMigration(conf.GetProvider()) {
cmd.Println("WARNING: This upgrade requires an IAM migration. Please make sure you have applied the IAM migration using `iam upgrade apply` before continuing.")
if !u.flags.yes {
yes, err := askToConfirm(cmd, "Did you upgrade the IAM resources?")
if err != nil {
return fmt.Errorf("asking for confirmation: %w", err)
}
if !yes {
cmd.Println("Skipping upgrade.")
return nil
}
}
}
conf.KubernetesVersion, err = validK8sVersion(cmd, string(conf.KubernetesVersion), u.flags.yes)
if err != nil {
return err
}
stateFile, err := state.ReadFromFile(u.fileHandler, constants.StateFilename)
if err != nil {
return fmt.Errorf("reading state file: %w", err)
}
if err := u.confirmAndUpgradeAttestationConfig(cmd, conf.GetAttestationConfig(), stateFile.ClusterValues.MeasurementSalt); err != nil {
return fmt.Errorf("upgrading measurements: %w", err)
}
// If infrastructure phase is skipped, we expect the new infrastructure
// to be in the Terraform configuration already. Otherwise, perform
// the Terraform migrations.
if !u.flags.skipPhases.contains(skipInfrastructurePhase) {
migrationRequired, err := u.planTerraformMigration(cmd, conf)
if err != nil {
return fmt.Errorf("planning Terraform migrations: %w", err)
}
if migrationRequired {
postMigrationInfraState, err := u.migrateTerraform(cmd, conf, upgradeDir)
if err != nil {
return fmt.Errorf("performing Terraform migrations: %w", err)
}
// Merge the pre-upgrade state with the post-migration infrastructure values
if _, err := stateFile.Merge(
// temporary state with post-migration infrastructure values
state.New().SetInfrastructure(postMigrationInfraState),
); err != nil {
return fmt.Errorf("merging pre-upgrade state with post-migration infrastructure values: %w", err)
}
// Write the post-migration state to disk
if err := stateFile.WriteToFile(u.fileHandler, constants.StateFilename); err != nil {
return fmt.Errorf("writing state file: %w", err)
}
}
}
// extend the clusterConfig cert SANs with any of the supported endpoints:
// - (legacy) public IP
// - fallback endpoint
// - custom (user-provided) endpoint
sans := append([]string{stateFile.Infrastructure.ClusterEndpoint, conf.CustomEndpoint}, stateFile.Infrastructure.APIServerCertSANs...)
if err := u.kubeUpgrader.ExtendClusterConfigCertSANs(cmd.Context(), sans); err != nil {
return fmt.Errorf("extending cert SANs: %w", err)
}
if conf.GetProvider() != cloudprovider.Azure && conf.GetProvider() != cloudprovider.GCP && conf.GetProvider() != cloudprovider.AWS {
cmd.PrintErrln("WARNING: Skipping service and image upgrades, which are currently only supported for AWS, Azure, and GCP.")
return nil
}
var upgradeErr *compatibility.InvalidUpgradeError
if !u.flags.skipPhases.contains(skipHelmPhase) {
err = u.handleServiceUpgrade(cmd, conf, stateFile, upgradeDir)
switch {
case errors.As(err, &upgradeErr):
cmd.PrintErrln(err)
case err == nil:
cmd.Println("Successfully upgraded Constellation services.")
case err != nil:
return fmt.Errorf("upgrading services: %w", err)
}
}
skipImageUpgrade := u.flags.skipPhases.contains(skipImagePhase)
skipK8sUpgrade := u.flags.skipPhases.contains(skipK8sPhase)
if !(skipImageUpgrade && skipK8sUpgrade) {
err = u.kubeUpgrader.UpgradeNodeVersion(cmd.Context(), conf, u.flags.force, skipImageUpgrade, skipK8sUpgrade)
switch {
case errors.Is(err, kubecmd.ErrInProgress):
cmd.PrintErrln("Skipping image and Kubernetes upgrades. Another upgrade is in progress.")
case errors.As(err, &upgradeErr):
cmd.PrintErrln(err)
case err != nil:
return fmt.Errorf("upgrading NodeVersion: %w", err)
}
}
return nil
}
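
The merge step above keeps the full pre-upgrade state and overlays only the freshly migrated infrastructure values. A rough sketch of such overlay semantics, under the assumption that non-zero fields win (types and fields are stand-ins, not the state package's actual API):

package main

import "fmt"

type infrastructure struct {
	ClusterEndpoint string
	UID             string
}

type clusterState struct {
	Version        string
	Infrastructure infrastructure
}

// overlayInfrastructure copies non-zero infrastructure fields from src onto dst.
func overlayInfrastructure(dst *clusterState, src infrastructure) {
	if src.ClusterEndpoint != "" {
		dst.Infrastructure.ClusterEndpoint = src.ClusterEndpoint
	}
	if src.UID != "" {
		dst.Infrastructure.UID = src.UID
	}
}

func main() {
	pre := clusterState{Version: "v1", Infrastructure: infrastructure{ClusterEndpoint: "192.0.2.1", UID: "abc"}}
	post := infrastructure{ClusterEndpoint: "198.51.100.7"} // post-migration values
	overlayInfrastructure(&pre, post)
	fmt.Printf("%+v\n", pre) // {Version:v1 Infrastructure:{ClusterEndpoint:198.51.100.7 UID:abc}}
}
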
func diffAttestationCfg(currentAttestationCfg config.AttestationCfg, newAttestationCfg config.AttestationCfg) (string, error) {
// cannot compare structs directly with go-cmp because of unexported fields in the attestation config
currentYml, err := yaml.Marshal(currentAttestationCfg)
@ -315,220 +81,23 @@ func diffAttestationCfg(currentAttestationCfg config.AttestationCfg, newAttestat
return diff, nil
}
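
Since the attestation structs can't be compared field by field from outside their package, diffAttestationCfg normalizes both sides to YAML and diffs the text. The same marshal-then-diff approach on plain maps (values hypothetical):

package main

import (
	"fmt"

	"github.com/rogpeppe/go-internal/diff"
	"gopkg.in/yaml.v3"
)

func main() {
	current := map[string]string{"bootloader": "aaaa", "kernel": "cccc"}
	updated := map[string]string{"bootloader": "bbbb", "kernel": "cccc"}

	currentYml, err := yaml.Marshal(current)
	if err != nil {
		panic(err)
	}
	newYml, err := yaml.Marshal(updated)
	if err != nil {
		panic(err)
	}
	// diff.Diff renders a unified diff between the two YAML documents.
	fmt.Printf("%s", diff.Diff("current", currentYml, "new", newYml))
}
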
// planTerraformMigration checks if the Constellation version the cluster is being upgraded to requires a migration.
func (u *upgradeApplyCmd) planTerraformMigration(cmd *cobra.Command, conf *config.Config) (bool, error) {
u.log.Debugf("Planning Terraform migrations")
vars, err := cloudcmd.TerraformUpgradeVars(conf)
if err != nil {
return false, fmt.Errorf("parsing upgrade variables: %w", err)
}
u.log.Debugf("Using Terraform variables:\n%v", vars)
// Check if there are any Terraform migrations to apply
// Add manual migrations here if required
//
// var manualMigrations []terraform.StateMigration
// for _, migration := range manualMigrations {
// u.log.Debugf("Adding manual Terraform migration: %s", migration.DisplayName)
// u.upgrader.AddManualStateMigration(migration)
// }
return u.clusterUpgrader.PlanClusterUpgrade(cmd.Context(), cmd.OutOrStdout(), vars, conf.GetProvider())
}
// migrateTerraform checks if the Constellation version the cluster is being upgraded to requires a migration
// of cloud resources with Terraform. If so, the migration is performed and the post-migration infrastructure state is returned.
// If no migration is required, the current (pre-upgrade) infrastructure state is returned.
func (u *upgradeApplyCmd) migrateTerraform(cmd *cobra.Command, conf *config.Config, upgradeDir string,
) (state.Infrastructure, error) {
// If there are any Terraform migrations to apply, ask for confirmation
fmt.Fprintln(cmd.OutOrStdout(), "The upgrade requires a migration of Constellation cloud resources by applying an updated Terraform template. Please manually review the suggested changes below.")
if !u.flags.yes {
ok, err := askToConfirm(cmd, "Do you want to apply the Terraform migrations?")
if err != nil {
return state.Infrastructure{}, fmt.Errorf("asking for confirmation: %w", err)
}
if !ok {
cmd.Println("Aborting upgrade.")
// Users don't expect to see any changes in their workspace after aborting an "upgrade apply",
// therefore, roll back to the backed up state.
if err := u.clusterUpgrader.RestoreClusterWorkspace(); err != nil {
return state.Infrastructure{}, fmt.Errorf(
"restoring Terraform workspace: %w, restore the Terraform workspace manually from %s ",
err,
filepath.Join(upgradeDir, constants.TerraformUpgradeBackupDir),
)
}
return state.Infrastructure{}, fmt.Errorf("cluster upgrade aborted by user")
}
}
u.log.Debugf("Applying Terraform migrations")
infraState, err := u.clusterUpgrader.ApplyClusterUpgrade(cmd.Context(), conf.GetProvider())
if err != nil {
return state.Infrastructure{}, fmt.Errorf("applying terraform migrations: %w", err)
}
cmd.Printf("Infrastructure migrations applied successfully and output written to: %s\n"+
"A backup of the pre-upgrade state has been written to: %s\n",
u.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename),
u.flags.pathPrefixer.PrefixPrintablePath(filepath.Join(upgradeDir, constants.TerraformUpgradeBackupDir)),
)
return infraState, nil
}
// validK8sVersion checks if the Kubernetes patch version is supported and asks for confirmation if not.
func validK8sVersion(cmd *cobra.Command, version string, yes bool) (validVersion versions.ValidK8sVersion, err error) {
validVersion, err = versions.NewValidK8sVersion(version, true)
if versions.IsPreviewK8sVersion(validVersion) {
cmd.PrintErrf("Warning: Constellation with Kubernetes %v is still in preview. Use only for evaluation purposes.\n", validVersion)
}
valid := err == nil
if !valid && !yes {
confirmed, err := askToConfirm(cmd, fmt.Sprintf("WARNING: The Kubernetes patch version %s is not supported. If you continue, Kubernetes upgrades will be skipped. Do you want to continue anyway?", version))
if err != nil {
return validVersion, fmt.Errorf("asking for confirmation: %w", err)
}
if !confirmed {
return validVersion, fmt.Errorf("aborted by user")
}
}
return validVersion, nil
}
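
Note that validK8sVersion fails soft: an unsupported patch version only skips Kubernetes upgrades once the user confirms, instead of aborting outright. A toy version of the underlying support check (the supported list is made up; the real one lives in the versions package):

package main

import (
	"fmt"
	"strings"
)

// isSupported normalizes a missing "v" prefix and checks membership
// in a list of supported versions.
func isSupported(version string, supported []string) bool {
	version = "v" + strings.TrimPrefix(version, "v")
	for _, s := range supported {
		if version == s {
			return true
		}
	}
	return false
}

func main() {
	supported := []string{"v1.27.4", "v1.28.2"} // hypothetical
	fmt.Println(isSupported("1.28.2", supported))  // true
	fmt.Println(isSupported("v1.20.0", supported)) // false
}
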
// confirmAndUpgradeAttestationConfig checks if the locally configured measurements are different from the cluster's measurements.
// If so, the function will ask the user to confirm (if --yes is not set) and upgrade the cluster's config.
func (u *upgradeApplyCmd) confirmAndUpgradeAttestationConfig(
cmd *cobra.Command, newConfig config.AttestationCfg, measurementSalt []byte,
) error {
clusterAttestationConfig, err := u.kubeUpgrader.GetClusterAttestationConfig(cmd.Context(), newConfig.GetVariant())
if err != nil {
return fmt.Errorf("getting cluster attestation config: %w", err)
}
// If the current config is equal, or there is an error when comparing the configs, we skip the upgrade.
equal, err := newConfig.EqualTo(clusterAttestationConfig)
if err != nil {
return fmt.Errorf("comparing attestation configs: %w", err)
}
if equal {
return nil
}
cmd.Println("The configured attestation config is different from the attestation config in the cluster.")
diffStr, err := diffAttestationCfg(clusterAttestationConfig, newConfig)
if err != nil {
return fmt.Errorf("diffing attestation configs: %w", err)
}
cmd.Println("The following changes will be applied to the attestation config:")
cmd.Println(diffStr)
if !u.flags.yes {
ok, err := askToConfirm(cmd, "Are you sure you want to change your cluster's attestation config?")
if err != nil {
return fmt.Errorf("asking for confirmation: %w", err)
}
if !ok {
return errors.New("aborting upgrade since attestation config is different")
}
}
if err := u.kubeUpgrader.ApplyJoinConfig(cmd.Context(), newConfig, measurementSalt); err != nil {
return fmt.Errorf("updating attestation config: %w", err)
}
cmd.Println("Successfully updated the cluster's attestation config")
return nil
}
func (u *upgradeApplyCmd) handleServiceUpgrade(
cmd *cobra.Command, conf *config.Config, stateFile *state.State, upgradeDir string,
) error {
var secret uri.MasterSecret
if err := u.fileHandler.ReadJSON(constants.MasterSecretFilename, &secret); err != nil {
return fmt.Errorf("reading master secret: %w", err)
}
serviceAccURI, err := cloudcmd.GetMarshaledServiceAccountURI(conf.GetProvider(), conf, u.flags.pathPrefixer, u.log, u.fileHandler)
if err != nil {
return fmt.Errorf("getting service account URI: %w", err)
}
options := helm.Options{
Force: u.flags.force,
Conformance: u.flags.conformance,
HelmWaitMode: u.flags.helmWaitMode,
}
prepareApply := func(allowDestructive bool) (helm.Applier, bool, error) {
options.AllowDestructive = allowDestructive
executor, includesUpgrades, err := u.helmApplier.PrepareApply(conf, stateFile, options, serviceAccURI, secret)
var upgradeErr *compatibility.InvalidUpgradeError
switch {
case errors.As(err, &upgradeErr):
cmd.PrintErrln(err)
case err != nil:
return nil, false, fmt.Errorf("getting chart executor: %w", err)
}
return executor, includesUpgrades, nil
}
executor, includesUpgrades, err := prepareApply(helm.DenyDestructive)
if err != nil {
if !errors.Is(err, helm.ErrConfirmationMissing) {
return fmt.Errorf("upgrading charts with deny destructive mode: %w", err)
}
if !u.flags.yes {
cmd.PrintErrln("WARNING: Upgrading cert-manager will destroy all custom resources you have manually created that are based on the current version of cert-manager.")
ok, askErr := askToConfirm(cmd, "Do you want to upgrade cert-manager anyway?")
if askErr != nil {
return fmt.Errorf("asking for confirmation: %w", err)
}
if !ok {
cmd.Println("Skipping upgrade.")
return nil
}
}
executor, includesUpgrades, err = prepareApply(helm.AllowDestructive)
if err != nil {
return fmt.Errorf("upgrading charts with allow destructive mode: %w", err)
}
}
// Save the Helm charts for the upgrade to disk
chartDir := filepath.Join(upgradeDir, "helm-charts")
if err := executor.SaveCharts(chartDir, u.fileHandler); err != nil {
return fmt.Errorf("saving Helm charts to disk: %w", err)
}
u.log.Debugf("Helm charts saved to %s", chartDir)
if includesUpgrades {
u.log.Debugf("Creating backup of CRDs and CRs")
crds, err := u.kubeUpgrader.BackupCRDs(cmd.Context(), upgradeDir)
if err != nil {
return fmt.Errorf("creating CRD backup: %w", err)
}
if err := u.kubeUpgrader.BackupCRs(cmd.Context(), crds, upgradeDir); err != nil {
return fmt.Errorf("creating CR backup: %w", err)
}
}
if err := executor.Apply(cmd.Context()); err != nil {
return fmt.Errorf("applying Helm charts: %w", err)
}
return nil
}
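
The deny-then-allow structure of handleServiceUpgrade is the key control flow here: the non-destructive pass runs first, and destructive changes are retried only after explicit confirmation. Stripped to its skeleton (the sentinel error is a stand-in for helm.ErrConfirmationMissing):

package main

import (
	"errors"
	"fmt"
)

// errConfirmationMissing stands in for helm.ErrConfirmationMissing.
var errConfirmationMissing = errors.New("destructive action requires confirmation")

// prepareApply pretends a destructive change (e.g. a cert-manager
// upgrade) is pending, so the non-destructive pass must bail out.
func prepareApply(allowDestructive bool) error {
	if !allowDestructive {
		return errConfirmationMissing
	}
	return nil
}

func confirmed() bool { return true } // stand-in for the prompt / --yes flag

func main() {
	if err := prepareApply(false); err != nil { // first pass: deny destructive changes
		if !errors.Is(err, errConfirmationMissing) {
			panic(err) // unrelated failure
		}
		if !confirmed() {
			fmt.Println("Skipping upgrade.")
			return
		}
		if err := prepareApply(true); err != nil { // second pass: allow destructive changes
			panic(err)
		}
	}
	fmt.Println("charts applied")
}
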
// skipPhases is a set of phases that can be skipped during the upgrade process.
type skipPhases []skipPhase
type skipPhases map[skipPhase]struct{}
// contains returns true if the set of phases contains the given phase.
func (s skipPhases) contains(phase skipPhase) bool {
for _, p := range s {
if strings.EqualFold(string(p), string(phase)) {
return true
}
}
return false
}
func (s skipPhases) contains(phase skipPhase) bool {
_, ok := s[skipPhase(strings.ToLower(string(phase)))]
return ok
}
// add adds phases to the set of phases to skip.
func (s *skipPhases) add(phases ...skipPhase) {
if *s == nil {
*s = make(skipPhases)
}
for _, phase := range phases {
(*s)[skipPhase(strings.ToLower(string(phase)))] = struct{}{}
}
}
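
Switching skipPhases from a slice to a map turns membership checks into O(1) lookups and centralizes case normalization in add and contains. Usage, for illustration (types and methods as defined above):

package main

import (
	"fmt"
	"strings"
)

type skipPhase string

type skipPhases map[skipPhase]struct{}

func (s skipPhases) contains(phase skipPhase) bool {
	_, ok := s[skipPhase(strings.ToLower(string(phase)))]
	return ok
}

func (s *skipPhases) add(phases ...skipPhase) {
	if *s == nil {
		*s = make(skipPhases)
	}
	for _, phase := range phases {
		(*s)[skipPhase(strings.ToLower(string(phase)))] = struct{}{}
	}
}

func main() {
	var phases skipPhases // nil map is fine; add allocates on first use
	phases.add("Helm", "IMAGE")
	fmt.Println(phases.contains("helm"))  // true
	fmt.Println(phases.contains("image")) // true
	fmt.Println(phases.contains("k8s"))   // false
}
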
type kubernetesUpgrader interface {

View File

@ -41,8 +41,9 @@ func TestUpgradeApply(t *testing.T) {
InitSecret: []byte{0x42},
}).
SetClusterValues(state.ClusterValues{MeasurementSalt: []byte{0x41}})
fsWithStateFile := func() file.Handler {
fsWithStateFileAndTfState := func() file.Handler {
fh := file.NewHandler(afero.NewMemMapFs())
require.NoError(t, fh.MkdirAll(constants.TerraformWorkingDir))
require.NoError(t, fh.WriteYAML(constants.StateFilename, defaultState))
return fh
}
@ -55,15 +56,15 @@ func TestUpgradeApply(t *testing.T) {
terraformUpgrader clusterUpgrader
wantErr bool
customK8sVersion string
flags upgradeApplyFlags
flags applyFlags
stdin string
}{
"success": {
kubeUpgrader: &stubKubernetesUpgrader{currentConfig: config.DefaultForAzureSEVSNP()},
helmUpgrader: stubApplier{},
terraformUpgrader: &stubTerraformUpgrader{},
flags: upgradeApplyFlags{yes: true},
fh: fsWithStateFile,
flags: applyFlags{yes: true},
fh: fsWithStateFileAndTfState,
fhAssertions: func(require *require.Assertions, assert *assert.Assertions, fh file.Handler) {
gotState, err := state.ReadFromFile(fh, constants.StateFilename)
require.NoError(err)
@ -71,11 +72,11 @@ func TestUpgradeApply(t *testing.T) {
assert.Equal(defaultState, gotState)
},
},
"state file does not exist": {
"id file and state file do not exist": {
kubeUpgrader: &stubKubernetesUpgrader{currentConfig: config.DefaultForAzureSEVSNP()},
helmUpgrader: stubApplier{},
terraformUpgrader: &stubTerraformUpgrader{},
flags: upgradeApplyFlags{yes: true},
flags: applyFlags{yes: true},
fh: func() file.Handler {
return file.NewHandler(afero.NewMemMapFs())
},
@ -89,8 +90,8 @@ func TestUpgradeApply(t *testing.T) {
helmUpgrader: stubApplier{},
terraformUpgrader: &stubTerraformUpgrader{},
wantErr: true,
flags: upgradeApplyFlags{yes: true},
fh: fsWithStateFile,
flags: applyFlags{yes: true},
fh: fsWithStateFileAndTfState,
},
"nodeVersion in progress error": {
kubeUpgrader: &stubKubernetesUpgrader{
@ -99,8 +100,8 @@ func TestUpgradeApply(t *testing.T) {
},
helmUpgrader: stubApplier{},
terraformUpgrader: &stubTerraformUpgrader{},
flags: upgradeApplyFlags{yes: true},
fh: fsWithStateFile,
flags: applyFlags{yes: true},
fh: fsWithStateFileAndTfState,
},
"helm other error": {
kubeUpgrader: &stubKubernetesUpgrader{
@ -109,8 +110,8 @@ func TestUpgradeApply(t *testing.T) {
helmUpgrader: stubApplier{err: assert.AnError},
terraformUpgrader: &stubTerraformUpgrader{},
wantErr: true,
flags: upgradeApplyFlags{yes: true},
fh: fsWithStateFile,
flags: applyFlags{yes: true},
fh: fsWithStateFileAndTfState,
},
"abort": {
kubeUpgrader: &stubKubernetesUpgrader{
@ -120,7 +121,7 @@ func TestUpgradeApply(t *testing.T) {
terraformUpgrader: &stubTerraformUpgrader{terraformDiff: true},
wantErr: true,
stdin: "no\n",
fh: fsWithStateFile,
fh: fsWithStateFileAndTfState,
},
"abort, restore terraform err": {
kubeUpgrader: &stubKubernetesUpgrader{
@ -130,7 +131,7 @@ func TestUpgradeApply(t *testing.T) {
terraformUpgrader: &stubTerraformUpgrader{terraformDiff: true, rollbackWorkspaceErr: assert.AnError},
wantErr: true,
stdin: "no\n",
fh: fsWithStateFile,
fh: fsWithStateFileAndTfState,
},
"plan terraform error": {
kubeUpgrader: &stubKubernetesUpgrader{
@ -139,8 +140,8 @@ func TestUpgradeApply(t *testing.T) {
helmUpgrader: stubApplier{},
terraformUpgrader: &stubTerraformUpgrader{planTerraformErr: assert.AnError},
wantErr: true,
flags: upgradeApplyFlags{yes: true},
fh: fsWithStateFile,
flags: applyFlags{yes: true},
fh: fsWithStateFileAndTfState,
},
"apply terraform error": {
kubeUpgrader: &stubKubernetesUpgrader{
@ -152,8 +153,8 @@ func TestUpgradeApply(t *testing.T) {
terraformDiff: true,
},
wantErr: true,
flags: upgradeApplyFlags{yes: true},
fh: fsWithStateFile,
flags: applyFlags{yes: true},
fh: fsWithStateFileAndTfState,
},
"outdated K8s patch version": {
kubeUpgrader: &stubKubernetesUpgrader{
@ -166,8 +167,8 @@ func TestUpgradeApply(t *testing.T) {
require.NoError(t, err)
return semver.NewFromInt(v.Major(), v.Minor(), v.Patch()-1, "").String()
}(),
flags: upgradeApplyFlags{yes: true},
fh: fsWithStateFile,
flags: applyFlags{yes: true},
fh: fsWithStateFileAndTfState,
},
"outdated K8s version": {
kubeUpgrader: &stubKubernetesUpgrader{
@ -176,9 +177,9 @@ func TestUpgradeApply(t *testing.T) {
helmUpgrader: stubApplier{},
terraformUpgrader: &stubTerraformUpgrader{},
customK8sVersion: "v1.20.0",
flags: upgradeApplyFlags{yes: true},
flags: applyFlags{yes: true},
wantErr: true,
fh: fsWithStateFile,
fh: fsWithStateFileAndTfState,
},
"skip all upgrade phases": {
kubeUpgrader: &stubKubernetesUpgrader{
@ -186,11 +187,14 @@ func TestUpgradeApply(t *testing.T) {
},
helmUpgrader: &mockApplier{}, // mocks ensure that no methods are called
terraformUpgrader: &mockTerraformUpgrader{},
flags: upgradeApplyFlags{
skipPhases: []skipPhase{skipInfrastructurePhase, skipHelmPhase, skipK8sPhase, skipImagePhase},
yes: true,
flags: applyFlags{
skipPhases: skipPhases{
skipInfrastructurePhase: struct{}{}, skipHelmPhase: struct{}{},
skipK8sPhase: struct{}{}, skipImagePhase: struct{}{},
},
yes: true,
},
fh: fsWithStateFile,
fh: fsWithStateFileAndTfState,
},
"skip all phases except node upgrade": {
kubeUpgrader: &stubKubernetesUpgrader{
@ -198,11 +202,37 @@ func TestUpgradeApply(t *testing.T) {
},
helmUpgrader: &mockApplier{}, // mocks ensure that no methods are called
terraformUpgrader: &mockTerraformUpgrader{},
flags: upgradeApplyFlags{
skipPhases: []skipPhase{skipInfrastructurePhase, skipHelmPhase, skipK8sPhase},
yes: true,
flags: applyFlags{
skipPhases: skipPhases{
skipInfrastructurePhase: struct{}{}, skipHelmPhase: struct{}{},
skipK8sPhase: struct{}{},
},
yes: true,
},
fh: fsWithStateFile,
fh: fsWithStateFileAndTfState,
},
"no tf state, skip infrastructure upgrade": {
kubeUpgrader: &stubKubernetesUpgrader{
currentConfig: config.DefaultForAzureSEVSNP(),
},
helmUpgrader: &stubApplier{},
terraformUpgrader: &mockTerraformUpgrader{},
flags: applyFlags{
yes: true,
},
fh: func() file.Handler {
fh := file.NewHandler(afero.NewMemMapFs())
require.NoError(t, fh.WriteYAML(constants.StateFilename, defaultState))
return fh
},
},
"attempt to change attestation variant": {
kubeUpgrader: &stubKubernetesUpgrader{currentConfig: &config.AzureTrustedLaunch{}},
helmUpgrader: stubApplier{},
terraformUpgrader: &stubTerraformUpgrader{},
flags: applyFlags{yes: true},
fh: fsWithStateFileAndTfState,
wantErr: true,
},
}
@ -218,20 +248,26 @@ func TestUpgradeApply(t *testing.T) {
cfg.KubernetesVersion = versions.ValidK8sVersion(tc.customK8sVersion)
}
fh := tc.fh()
require.NoError(fh.Write(constants.AdminConfFilename, []byte{}))
require.NoError(fh.WriteYAML(constants.ConfigFilename, cfg))
require.NoError(fh.WriteJSON(constants.MasterSecretFilename, uri.MasterSecret{}))
upgrader := upgradeApplyCmd{
kubeUpgrader: tc.kubeUpgrader,
helmApplier: tc.helmUpgrader,
upgrader := &applyCmd{
fileHandler: fh,
flags: tc.flags,
log: logger.NewTest(t),
spinner: &nopSpinner{},
merger: &stubMerger{},
quotaChecker: &stubLicenseClient{},
newHelmClient: func(string, debugLog) (helmApplier, error) {
return tc.helmUpgrader, nil
},
newKubeUpgrader: func(_ io.Writer, _ string, _ debugLog) (kubernetesUpgrader, error) {
return tc.kubeUpgrader, nil
},
clusterUpgrader: tc.terraformUpgrader,
log: logger.NewTest(t),
configFetcher: stubAttestationFetcher{},
flags: tc.flags,
fileHandler: fh,
}
err := upgrader.upgradeApply(cmd, "test")
err := upgrader.apply(cmd, stubAttestationFetcher{}, "test")
if tc.wantErr {
assert.Error(err)
return
@ -255,27 +291,37 @@ func TestUpgradeApplyFlagsForSkipPhases(t *testing.T) {
cmd.Flags().Bool("force", true, "")
cmd.Flags().String("tf-log", "NONE", "")
cmd.Flags().Bool("debug", false, "")
cmd.Flags().Bool("merge-kubeconfig", false, "")
require.NoError(cmd.Flags().Set("skip-phases", "infrastructure,helm,k8s,image"))
wantPhases := skipPhases{}
wantPhases.add(skipInfrastructurePhase, skipHelmPhase, skipK8sPhase, skipImagePhase)
var flags upgradeApplyFlags
var flags applyFlags
err := flags.parse(cmd.Flags())
require.NoError(err)
assert.ElementsMatch(t, []skipPhase{skipInfrastructurePhase, skipHelmPhase, skipK8sPhase, skipImagePhase}, flags.skipPhases)
assert.Equal(t, wantPhases, flags.skipPhases)
}
type stubKubernetesUpgrader struct {
nodeVersionErr error
currentConfig config.AttestationCfg
calledNodeUpgrade bool
nodeVersionErr error
currentConfig config.AttestationCfg
getClusterAttestationConfigErr error
calledNodeUpgrade bool
backupCRDsErr error
backupCRDsCalled bool
backupCRsErr error
backupCRsCalled bool
}
func (u *stubKubernetesUpgrader) BackupCRDs(_ context.Context, _ string) ([]apiextensionsv1.CustomResourceDefinition, error) {
return []apiextensionsv1.CustomResourceDefinition{}, nil
u.backupCRDsCalled = true
return []apiextensionsv1.CustomResourceDefinition{}, u.backupCRDsErr
}
func (u *stubKubernetesUpgrader) BackupCRs(_ context.Context, _ []apiextensionsv1.CustomResourceDefinition, _ string) error {
return nil
u.backupCRsCalled = true
return u.backupCRsErr
}
func (u *stubKubernetesUpgrader) UpgradeNodeVersion(_ context.Context, _ *config.Config, _, _, _ bool) error {
@ -288,7 +334,7 @@ func (u *stubKubernetesUpgrader) ApplyJoinConfig(_ context.Context, _ config.Att
}
func (u *stubKubernetesUpgrader) GetClusterAttestationConfig(_ context.Context, _ variant.Variant) (config.AttestationCfg, error) {
return u.currentConfig, nil
return u.currentConfig, u.getClusterAttestationConfigErr
}
func (u *stubKubernetesUpgrader) ExtendClusterConfigCertSANs(_ context.Context, _ []string) error {

View File

@ -287,12 +287,17 @@ func (k *KubeCmd) ExtendClusterConfigCertSANs(ctx context.Context, alternativeNa
var missingSANs []string
for _, san := range alternativeNames {
if san == "" {
continue // skip empty SANs
}
if _, ok := existingSANs[san]; !ok {
missingSANs = append(missingSANs, san)
existingSANs[san] = struct{}{} // make sure we don't add the same SAN twice
}
}
if len(missingSANs) == 0 {
k.log.Debugf("No new SANs to add to the cluster's apiserver SAN field")
return nil
}
k.log.Debugf("Extending the cluster's apiserver SAN field with the following SANs: %s\n", strings.Join(missingSANs, ", "))

View File

@ -24,7 +24,8 @@ type AttestationCfg interface {
SetMeasurements(m measurements.M)
// GetVariant returns the variant of the attestation config.
GetVariant() variant.Variant
// NewerThan returns true if the config is equal to the given config.
// EqualTo returns true if the config is equal to the given config.
// If the variant differs, an error must be returned.
EqualTo(AttestationCfg) (bool, error)
}
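
The sharpened EqualTo contract (an error on variant mismatch rather than a bare false) lets callers distinguish "same variant, different values" from "not comparable at all". A minimal sketch of an implementation honoring that contract (types are illustrative, not the real config structs):

package main

import (
	"errors"
	"fmt"
)

type attestationCfg struct {
	variant      string
	measurements map[int]string
}

// equalTo compares values only within the same variant;
// differing variants are incomparable and yield an error.
func (c attestationCfg) equalTo(other attestationCfg) (bool, error) {
	if c.variant != other.variant {
		return false, errors.New("cannot compare configs of different attestation variants")
	}
	if len(c.measurements) != len(other.measurements) {
		return false, nil
	}
	for k, v := range c.measurements {
		if other.measurements[k] != v {
			return false, nil
		}
	}
	return true, nil
}

func main() {
	a := attestationCfg{variant: "azure-sev-snp", measurements: map[int]string{4: "aa"}}
	b := attestationCfg{variant: "azure-trustedlaunch"}
	_, err := a.equalTo(b)
	fmt.Println(err) // cannot compare configs of different attestation variants
}
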