cli: common backend for init and upgrade apply commands (#2449)

* Use common 'apply' backend for init and upgrades
* Move unit tests to new apply backend
* Only perform Terraform migrations if state exists in cwd (#2457)
* Rework skipPhases logic

---------

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
Co-authored-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>
This commit is contained in:
Daniel Weiße 2023-10-24 15:39:18 +02:00 committed by GitHub
parent 15d249092c
commit 671cf36f0a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 1399 additions and 1018 deletions

View file

@ -21,7 +21,6 @@ go_library(
importpath = "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd", importpath = "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd",
visibility = ["//cli:__subpackages__"], visibility = ["//cli:__subpackages__"],
deps = [ deps = [
"//cli/internal/cmd/pathprefix",
"//cli/internal/libvirt", "//cli/internal/libvirt",
"//cli/internal/state", "//cli/internal/state",
"//cli/internal/terraform", "//cli/internal/terraform",

View file

@ -9,7 +9,6 @@ package cloudcmd
import ( import (
"fmt" "fmt"
"github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix"
"github.com/edgelesssys/constellation/v2/internal/cloud/azureshared" "github.com/edgelesssys/constellation/v2/internal/cloud/azureshared"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/cloud/gcpshared" "github.com/edgelesssys/constellation/v2/internal/cloud/gcpshared"
@ -19,27 +18,19 @@ import (
) )
// GetMarshaledServiceAccountURI returns the service account URI for the given cloud provider. // GetMarshaledServiceAccountURI returns the service account URI for the given cloud provider.
func GetMarshaledServiceAccountURI(provider cloudprovider.Provider, config *config.Config, pf pathprefix.PathPrefixer, log debugLog, fileHandler file.Handler, func GetMarshaledServiceAccountURI(config *config.Config, fileHandler file.Handler) (string, error) {
) (string, error) { switch config.GetProvider() {
log.Debugf("Getting service account URI")
switch provider {
case cloudprovider.GCP: case cloudprovider.GCP:
log.Debugf("Handling case for GCP")
log.Debugf("GCP service account key path %s", pf.PrefixPrintablePath(config.Provider.GCP.ServiceAccountKeyPath))
var key gcpshared.ServiceAccountKey var key gcpshared.ServiceAccountKey
if err := fileHandler.ReadJSON(config.Provider.GCP.ServiceAccountKeyPath, &key); err != nil { if err := fileHandler.ReadJSON(config.Provider.GCP.ServiceAccountKeyPath, &key); err != nil {
return "", fmt.Errorf("reading service account key from path %q: %w", pf.PrefixPrintablePath(config.Provider.GCP.ServiceAccountKeyPath), err) return "", fmt.Errorf("reading service account key: %w", err)
} }
log.Debugf("Read GCP service account key from path")
return key.ToCloudServiceAccountURI(), nil return key.ToCloudServiceAccountURI(), nil
case cloudprovider.AWS: case cloudprovider.AWS:
log.Debugf("Handling case for AWS")
return "", nil // AWS does not need a service account URI return "", nil // AWS does not need a service account URI
case cloudprovider.Azure:
log.Debugf("Handling case for Azure")
case cloudprovider.Azure:
authMethod := azureshared.AuthMethodUserAssignedIdentity authMethod := azureshared.AuthMethodUserAssignedIdentity
creds := azureshared.ApplicationCredentials{ creds := azureshared.ApplicationCredentials{
@ -64,10 +55,9 @@ func GetMarshaledServiceAccountURI(provider cloudprovider.Provider, config *conf
return creds.ToCloudServiceAccountURI(), nil return creds.ToCloudServiceAccountURI(), nil
case cloudprovider.QEMU: case cloudprovider.QEMU:
log.Debugf("Handling case for QEMU")
return "", nil // QEMU does not use service account keys return "", nil // QEMU does not use service account keys
default: default:
return "", fmt.Errorf("unsupported cloud provider %q", provider) return "", fmt.Errorf("unsupported cloud provider %q", config.GetProvider())
} }
} }

View file

@ -4,6 +4,10 @@ load("//bazel/go:go_test.bzl", "go_test")
go_library( go_library(
name = "cmd", name = "cmd",
srcs = [ srcs = [
"apply.go",
"applyhelm.go",
"applyinit.go",
"applyterraform.go",
"cloud.go", "cloud.go",
"cmd.go", "cmd.go",
"config.go", "config.go",
@ -94,6 +98,7 @@ go_library(
"@com_github_spf13_pflag//:pflag", "@com_github_spf13_pflag//:pflag",
"@in_gopkg_yaml_v3//:yaml_v3", "@in_gopkg_yaml_v3//:yaml_v3",
"@io_k8s_apiextensions_apiserver//pkg/apis/apiextensions/v1:apiextensions", "@io_k8s_apiextensions_apiserver//pkg/apis/apiextensions/v1:apiextensions",
"@io_k8s_apimachinery//pkg/api/errors",
"@io_k8s_apimachinery//pkg/runtime", "@io_k8s_apimachinery//pkg/runtime",
"@io_k8s_client_go//tools/clientcmd", "@io_k8s_client_go//tools/clientcmd",
"@io_k8s_client_go//tools/clientcmd/api/latest", "@io_k8s_client_go//tools/clientcmd/api/latest",
@ -115,6 +120,7 @@ go_library(
go_test( go_test(
name = "cmd_test", name = "cmd_test",
srcs = [ srcs = [
"apply_test.go",
"cloud_test.go", "cloud_test.go",
"configfetchmeasurements_test.go", "configfetchmeasurements_test.go",
"configgenerate_test.go", "configgenerate_test.go",
@ -171,12 +177,15 @@ go_test(
"@com_github_google_go_tpm_tools//proto/tpm", "@com_github_google_go_tpm_tools//proto/tpm",
"@com_github_spf13_afero//:afero", "@com_github_spf13_afero//:afero",
"@com_github_spf13_cobra//:cobra", "@com_github_spf13_cobra//:cobra",
"@com_github_spf13_pflag//:pflag",
"@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//mock", "@com_github_stretchr_testify//mock",
"@com_github_stretchr_testify//require", "@com_github_stretchr_testify//require",
"@io_k8s_api//core/v1:core", "@io_k8s_api//core/v1:core",
"@io_k8s_apiextensions_apiserver//pkg/apis/apiextensions/v1:apiextensions", "@io_k8s_apiextensions_apiserver//pkg/apis/apiextensions/v1:apiextensions",
"@io_k8s_apimachinery//pkg/api/errors",
"@io_k8s_apimachinery//pkg/apis/meta/v1:meta", "@io_k8s_apimachinery//pkg/apis/meta/v1:meta",
"@io_k8s_apimachinery//pkg/runtime/schema",
"@io_k8s_client_go//tools/clientcmd", "@io_k8s_client_go//tools/clientcmd",
"@io_k8s_client_go//tools/clientcmd/api", "@io_k8s_client_go//tools/clientcmd/api",
"@org_golang_google_grpc//:go_default_library", "@org_golang_google_grpc//:go_default_library",

515
cli/internal/cmd/apply.go Normal file
View file

@ -0,0 +1,515 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"io/fs"
"net"
"os"
"path/filepath"
"strings"
"time"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/helm"
"github.com/edgelesssys/constellation/v2/cli/internal/kubecmd"
"github.com/edgelesssys/constellation/v2/cli/internal/state"
"github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi"
"github.com/edgelesssys/constellation/v2/internal/atls"
"github.com/edgelesssys/constellation/v2/internal/attestation/variant"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/compatibility"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
"github.com/edgelesssys/constellation/v2/internal/license"
"github.com/edgelesssys/constellation/v2/internal/versions"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
k8serrors "k8s.io/apimachinery/pkg/api/errors"
)
// applyFlags defines the flags for the apply command.
type applyFlags struct {
rootFlags
yes bool
conformance bool
mergeConfigs bool
upgradeTimeout time.Duration
helmWaitMode helm.WaitMode
skipPhases skipPhases
}
// parse the apply command flags.
func (f *applyFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
}
rawSkipPhases, err := flags.GetStringSlice("skip-phases")
if err != nil {
return fmt.Errorf("getting 'skip-phases' flag: %w", err)
}
var skipPhases skipPhases
for _, phase := range rawSkipPhases {
switch skipPhase(strings.ToLower(phase)) {
case skipInfrastructurePhase, skipHelmPhase, skipImagePhase, skipK8sPhase:
skipPhases.add(skipPhase(phase))
default:
return fmt.Errorf("invalid phase %s", phase)
}
}
f.skipPhases = skipPhases
f.yes, err = flags.GetBool("yes")
if err != nil {
return fmt.Errorf("getting 'yes' flag: %w", err)
}
f.upgradeTimeout, err = flags.GetDuration("timeout")
if err != nil {
return fmt.Errorf("getting 'timeout' flag: %w", err)
}
f.conformance, err = flags.GetBool("conformance")
if err != nil {
return fmt.Errorf("getting 'conformance' flag: %w", err)
}
skipHelmWait, err := flags.GetBool("skip-helm-wait")
if err != nil {
return fmt.Errorf("getting 'skip-helm-wait' flag: %w", err)
}
f.helmWaitMode = helm.WaitModeAtomic
if skipHelmWait {
f.helmWaitMode = helm.WaitModeNone
}
f.mergeConfigs, err = flags.GetBool("merge-kubeconfig")
if err != nil {
return fmt.Errorf("getting 'merge-kubeconfig' flag: %w", err)
}
return nil
}
// runApply sets up the apply command and runs it.
func runApply(cmd *cobra.Command, _ []string) error {
log, err := newCLILogger(cmd)
if err != nil {
return fmt.Errorf("creating logger: %w", err)
}
defer log.Sync()
spinner, err := newSpinnerOrStderr(cmd)
if err != nil {
return err
}
defer spinner.Stop()
flags := applyFlags{}
if err := flags.parse(cmd.Flags()); err != nil {
return err
}
fileHandler := file.NewHandler(afero.NewOsFs())
newDialer := func(validator atls.Validator) *dialer.Dialer {
return dialer.New(nil, validator, &net.Dialer{})
}
newKubeUpgrader := func(w io.Writer, kubeConfigPath string, log debugLog) (kubernetesUpgrader, error) {
return kubecmd.New(w, kubeConfigPath, fileHandler, log)
}
newHelmClient := func(kubeConfigPath string, log debugLog) (helmApplier, error) {
return helm.NewClient(kubeConfigPath, log)
}
upgradeID := generateUpgradeID(upgradeCmdKindApply)
upgradeDir := filepath.Join(constants.UpgradeDir, upgradeID)
clusterUpgrader, err := cloudcmd.NewClusterUpgrader(
cmd.Context(),
constants.TerraformWorkingDir,
upgradeDir,
flags.tfLogLevel,
fileHandler,
)
if err != nil {
return fmt.Errorf("setting up cluster upgrader: %w", err)
}
apply := &applyCmd{
fileHandler: fileHandler,
flags: flags,
log: log,
spinner: spinner,
merger: &kubeconfigMerger{log: log},
quotaChecker: license.NewClient(),
newHelmClient: newHelmClient,
newDialer: newDialer,
newKubeUpgrader: newKubeUpgrader,
clusterUpgrader: clusterUpgrader,
}
ctx, cancel := context.WithTimeout(cmd.Context(), time.Hour)
defer cancel()
cmd.SetContext(ctx)
return apply.apply(cmd, attestationconfigapi.NewFetcher(), upgradeDir)
}
type applyCmd struct {
fileHandler file.Handler
flags applyFlags
log debugLog
spinner spinnerInterf
merger configMerger
quotaChecker license.QuotaChecker
newHelmClient func(kubeConfigPath string, log debugLog) (helmApplier, error)
newDialer func(validator atls.Validator) *dialer.Dialer
newKubeUpgrader func(io.Writer, string, debugLog) (kubernetesUpgrader, error)
clusterUpgrader clusterUpgrader
}
/*
apply updates a Constellation cluster by applying a user's config.
The control flow is as follows:
Parse Flags
Read Config
Read State-File
Validate input
Check if Terraform state is up to date
Not up to date
(Diff from Terraform plan)
Terraform
Phase
Apply Terraform updates
Check for constellation-admin.conf
File does not exist
Run Bootstrapper Init RPC
File does exist
Init
Write constellation-admin.conf Phase
Prepare "Init Success" Message
Apply Attestation Config
Extend API Server Cert SANs
Helm
Apply Helm Charts Phase
Can be skipped Upgrade NodeVersion object K8s/Image
if we ran Init RP (Image and K8s update) Phase
Write success output
*/
func (a *applyCmd) apply(cmd *cobra.Command, configFetcher attestationconfigapi.Fetcher, upgradeDir string) error {
// Validate inputs
conf, stateFile, err := a.validateInputs(cmd, configFetcher)
if err != nil {
return err
}
// Now start actually running the apply command
// Check current Terraform state, if it exists and infrastructure upgrades are not skipped,
// and apply migrations if necessary.
if !a.flags.skipPhases.contains(skipInfrastructurePhase) {
if err := a.runTerraformApply(cmd, conf, stateFile, upgradeDir); err != nil {
return fmt.Errorf("applying Terraform configuration : %w", err)
}
}
bufferedOutput := &bytes.Buffer{}
// Run init RPC if required
if !a.flags.skipPhases.contains(skipInitPhase) {
bufferedOutput, err = a.runInit(cmd, conf, stateFile)
if err != nil {
return err
}
}
// From now on we can assume a valid Kubernetes admin config file exists
kubeUpgrader, err := a.newKubeUpgrader(cmd.OutOrStdout(), constants.AdminConfFilename, a.log)
if err != nil {
return err
}
// Apply Attestation Config
a.log.Debugf("Creating Kubernetes client using %s", a.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename))
a.log.Debugf("Applying new attestation config to cluster")
if err := a.applyJoinConfig(cmd, kubeUpgrader, conf.GetAttestationConfig(), stateFile.ClusterValues.MeasurementSalt); err != nil {
return fmt.Errorf("applying attestation config: %w", err)
}
// Extend API Server Cert SANs
sans := append([]string{stateFile.Infrastructure.ClusterEndpoint, conf.CustomEndpoint}, stateFile.Infrastructure.APIServerCertSANs...)
if err := kubeUpgrader.ExtendClusterConfigCertSANs(cmd.Context(), sans); err != nil {
return fmt.Errorf("extending cert SANs: %w", err)
}
// Apply Helm Charts
if !a.flags.skipPhases.contains(skipHelmPhase) {
if err := a.runHelmApply(cmd, conf, stateFile, kubeUpgrader, upgradeDir); err != nil {
return err
}
}
// Upgrade NodeVersion object
// This can be skipped if we ran the init RPC, as the NodeVersion object is already up to date
if !(a.flags.skipPhases.contains(skipK8sPhase) && a.flags.skipPhases.contains(skipImagePhase)) &&
a.flags.skipPhases.contains(skipInitPhase) {
if err := a.runK8sUpgrade(cmd, conf, kubeUpgrader); err != nil {
return err
}
}
// Write success output
cmd.Print(bufferedOutput.String())
return nil
}
func (a *applyCmd) validateInputs(cmd *cobra.Command, configFetcher attestationconfigapi.Fetcher) (*config.Config, *state.State, error) {
// Read user's config and state file
a.log.Debugf("Reading config from %s", a.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
conf, err := config.New(a.fileHandler, constants.ConfigFilename, configFetcher, a.flags.force)
var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage())
}
if err != nil {
return nil, nil, err
}
a.log.Debugf("Reading state file from %s", a.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
stateFile, err := state.ReadFromFile(a.fileHandler, constants.StateFilename)
if err != nil {
return nil, nil, err
}
// Check license
a.log.Debugf("Running license check")
checker := license.NewChecker(a.quotaChecker, a.fileHandler)
if err := checker.CheckLicense(cmd.Context(), conf.GetProvider(), conf.Provider, cmd.Printf); err != nil {
cmd.PrintErrf("License check failed: %v", err)
}
a.log.Debugf("Checked license")
// Check if we already have a running Kubernetes cluster
// by checking if the Kubernetes admin config file exists
// If it exist, we skip the init phase
// If it does not exist, we need to run the init RPC first
// This may break things further down the line
// It is the user's responsibility to make sure the cluster is in a valid state
a.log.Debugf("Checking if %s exists", a.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename))
if _, err := a.fileHandler.Stat(constants.AdminConfFilename); err == nil {
a.flags.skipPhases.add(skipInitPhase)
} else if !errors.Is(err, os.ErrNotExist) {
return nil, nil, fmt.Errorf("checking for %s: %w", a.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename), err)
}
a.log.Debugf("Init RPC required: %t", !a.flags.skipPhases.contains(skipInitPhase))
// Validate input arguments
// Validate Kubernetes version as set in the user's config
// If we need to run the init RPC, the version has to be valid
// Otherwise, we are able to use an outdated version, meaning we skip the K8s upgrade
// We skip version validation if the user explicitly skips the Kubernetes phase
a.log.Debugf("Validating Kubernetes version %s", conf.KubernetesVersion)
validVersion, err := versions.NewValidK8sVersion(string(conf.KubernetesVersion), true)
if err != nil && !a.flags.skipPhases.contains(skipK8sPhase) {
a.log.Debugf("Kubernetes version not valid: %s", err)
if !a.flags.skipPhases.contains(skipInitPhase) {
return nil, nil, err
}
a.log.Debugf("Checking if user wants to continue anyway")
if !a.flags.yes {
confirmed, err := askToConfirm(cmd,
fmt.Sprintf(
"WARNING: The Kubernetes patch version %s is not supported. If you continue, Kubernetes upgrades will be skipped. Do you want to continue anyway?",
validVersion,
),
)
if err != nil {
return nil, nil, fmt.Errorf("asking for confirmation: %w", err)
}
if !confirmed {
return nil, nil, fmt.Errorf("aborted by user")
}
}
a.flags.skipPhases.add(skipK8sPhase)
a.log.Debugf("Outdated Kubernetes version accepted, Kubernetes upgrade will be skipped")
}
if versions.IsPreviewK8sVersion(validVersion) {
cmd.PrintErrf("Warning: Constellation with Kubernetes %s is still in preview. Use only for evaluation purposes.\n", validVersion)
}
conf.KubernetesVersion = validVersion
a.log.Debugf("Target Kubernetes version set to %s", conf.KubernetesVersion)
// Validate microservice version (helm versions) in the user's config matches the version of the CLI
// This makes sure we catch potential errors early, not just after we already ran Terraform migrations or the init RPC
if !a.flags.force && !a.flags.skipPhases.contains(skipHelmPhase) && !a.flags.skipPhases.contains(skipInitPhase) {
if err := validateCLIandConstellationVersionAreEqual(constants.BinaryVersion(), conf.Image, conf.MicroserviceVersion); err != nil {
return nil, nil, err
}
}
// Constellation on QEMU or OpenStack don't support upgrades
// If using one of those providers, make sure the command is only used to initialize a cluster
if !(conf.GetProvider() == cloudprovider.AWS || conf.GetProvider() == cloudprovider.Azure || conf.GetProvider() == cloudprovider.GCP) {
if a.flags.skipPhases.contains(skipInitPhase) {
return nil, nil, fmt.Errorf("upgrades are not supported for provider %s", conf.GetProvider())
}
// Skip Terraform phase
a.log.Debugf("Skipping Infrastructure upgrade")
a.flags.skipPhases.add(skipInfrastructurePhase)
}
// Check if Terraform state exists
if tfStateExists, err := a.tfStateExists(); err != nil {
return nil, nil, fmt.Errorf("checking Terraform state: %w", err)
} else if !tfStateExists {
a.flags.skipPhases.add(skipInfrastructurePhase)
a.log.Debugf("No Terraform state found in current working directory. Assuming self-managed infrastructure. Infrastructure upgrades will not be performed.")
}
// Print warning about AWS attestation
// TODO(derpsteb): remove once AWS fixes SEV-SNP attestation provisioning issues
if !a.flags.skipPhases.contains(skipInitPhase) && conf.GetAttestationConfig().GetVariant().Equal(variant.AWSSEVSNP{}) {
cmd.PrintErrln("WARNING: Attestation temporarily relies on AWS nitroTPM. See https://docs.edgeless.systems/constellation/workflows/config#choosing-a-vm-type for more information.")
}
return conf, stateFile, nil
}
// applyJoincConfig creates or updates the cluster's join config.
// If the config already exists, and is different from the new config, the user is asked to confirm the upgrade.
func (a *applyCmd) applyJoinConfig(
cmd *cobra.Command, kubeUpgrader kubernetesUpgrader, newConfig config.AttestationCfg, measurementSalt []byte,
) error {
clusterAttestationConfig, err := kubeUpgrader.GetClusterAttestationConfig(cmd.Context(), newConfig.GetVariant())
if err != nil {
a.log.Debugf("Getting cluster attestation config failed: %s", err)
if k8serrors.IsNotFound(err) {
a.log.Debugf("Creating new join config")
return kubeUpgrader.ApplyJoinConfig(cmd.Context(), newConfig, measurementSalt)
}
return fmt.Errorf("getting cluster attestation config: %w", err)
}
// If the current config is equal, or there is an error when comparing the configs, we skip the upgrade.
equal, err := newConfig.EqualTo(clusterAttestationConfig)
if err != nil {
return fmt.Errorf("comparing attestation configs: %w", err)
}
if equal {
a.log.Debugf("Current attestation config is equal to the new config, nothing to do")
return nil
}
cmd.Println("The configured attestation config is different from the attestation config in the cluster.")
diffStr, err := diffAttestationCfg(clusterAttestationConfig, newConfig)
if err != nil {
return fmt.Errorf("diffing attestation configs: %w", err)
}
cmd.Println("The following changes will be applied to the attestation config:")
cmd.Println(diffStr)
if !a.flags.yes {
ok, err := askToConfirm(cmd, "Are you sure you want to change your cluster's attestation config?")
if err != nil {
return fmt.Errorf("asking for confirmation: %w", err)
}
if !ok {
return errors.New("aborting upgrade since attestation config is different")
}
}
if err := kubeUpgrader.ApplyJoinConfig(cmd.Context(), newConfig, measurementSalt); err != nil {
return fmt.Errorf("updating attestation config: %w", err)
}
cmd.Println("Successfully updated the cluster's attestation config")
return nil
}
// runK8sUpgrade upgrades image and Kubernetes version of the Constellation cluster.
func (a *applyCmd) runK8sUpgrade(cmd *cobra.Command, conf *config.Config, kubeUpgrader kubernetesUpgrader,
) error {
err := kubeUpgrader.UpgradeNodeVersion(
cmd.Context(), conf, a.flags.force,
a.flags.skipPhases.contains(skipK8sPhase),
a.flags.skipPhases.contains(skipImagePhase),
)
var upgradeErr *compatibility.InvalidUpgradeError
switch {
case errors.Is(err, kubecmd.ErrInProgress):
cmd.PrintErrln("Skipping image and Kubernetes upgrades. Another upgrade is in progress.")
case errors.As(err, &upgradeErr):
cmd.PrintErrln(err)
case err != nil:
return fmt.Errorf("upgrading NodeVersion: %w", err)
}
return nil
}
// tfStateExists checks whether a Constellation Terraform state exists in the current working directory.
func (a *applyCmd) tfStateExists() (bool, error) {
_, err := a.fileHandler.Stat(constants.TerraformWorkingDir)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return false, nil
}
return false, fmt.Errorf("reading Terraform state: %w", err)
}
return true, nil
}

View file

@ -0,0 +1,153 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"context"
"fmt"
"testing"
"github.com/edgelesssys/constellation/v2/cli/internal/helm"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/spf13/afero"
"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseApplyFlags(t *testing.T) {
require := require.New(t)
// TODO: Use flags := applyCmd().Flags() once we have a separate apply command
defaultFlags := func() *pflag.FlagSet {
flags := pflag.NewFlagSet("test", pflag.ContinueOnError)
flags.String("workspace", "", "")
flags.String("tf-log", "NONE", "")
flags.Bool("force", false, "")
flags.Bool("debug", false, "")
flags.Bool("merge-kubeconfig", false, "")
flags.Bool("conformance", false, "")
flags.Bool("skip-helm-wait", false, "")
flags.Bool("yes", false, "")
flags.StringSlice("skip-phases", []string{}, "")
flags.Duration("timeout", 0, "")
return flags
}
testCases := map[string]struct {
flags *pflag.FlagSet
wantFlags applyFlags
wantErr bool
}{
"default flags": {
flags: defaultFlags(),
wantFlags: applyFlags{
helmWaitMode: helm.WaitModeAtomic,
},
},
"skip phases": {
flags: func() *pflag.FlagSet {
flags := defaultFlags()
require.NoError(flags.Set("skip-phases", fmt.Sprintf("%s,%s", skipHelmPhase, skipK8sPhase)))
return flags
}(),
wantFlags: applyFlags{
skipPhases: skipPhases{skipHelmPhase: struct{}{}, skipK8sPhase: struct{}{}},
helmWaitMode: helm.WaitModeAtomic,
},
},
"skip helm wait": {
flags: func() *pflag.FlagSet {
flags := defaultFlags()
require.NoError(flags.Set("skip-helm-wait", "true"))
return flags
}(),
wantFlags: applyFlags{
helmWaitMode: helm.WaitModeNone,
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
var flags applyFlags
err := flags.parse(tc.flags)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
assert.Equal(tc.wantFlags, flags)
})
}
}
func TestBackupHelmCharts(t *testing.T) {
testCases := map[string]struct {
helmApplier helm.Applier
backupClient *stubKubernetesUpgrader
includesUpgrades bool
wantErr bool
}{
"success, no upgrades": {
helmApplier: &stubRunner{},
backupClient: &stubKubernetesUpgrader{},
},
"success with upgrades": {
helmApplier: &stubRunner{},
backupClient: &stubKubernetesUpgrader{},
includesUpgrades: true,
},
"saving charts fails": {
helmApplier: &stubRunner{
saveChartsErr: assert.AnError,
},
backupClient: &stubKubernetesUpgrader{},
wantErr: true,
},
"backup CRDs fails": {
helmApplier: &stubRunner{},
backupClient: &stubKubernetesUpgrader{
backupCRDsErr: assert.AnError,
},
includesUpgrades: true,
wantErr: true,
},
"backup CRs fails": {
helmApplier: &stubRunner{},
backupClient: &stubKubernetesUpgrader{
backupCRsErr: assert.AnError,
},
includesUpgrades: true,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
a := applyCmd{
fileHandler: file.NewHandler(afero.NewMemMapFs()),
log: logger.NewTest(t),
}
err := a.backupHelmCharts(context.Background(), tc.backupClient, tc.helmApplier, tc.includesUpgrades, "")
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
if tc.includesUpgrades {
assert.True(tc.backupClient.backupCRDsCalled)
assert.True(tc.backupClient.backupCRsCalled)
}
})
}
}

View file

@ -0,0 +1,125 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"context"
"errors"
"fmt"
"path/filepath"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/helm"
"github.com/edgelesssys/constellation/v2/cli/internal/state"
"github.com/edgelesssys/constellation/v2/internal/compatibility"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/kms/uri"
"github.com/spf13/cobra"
)
// runHelmApply handles installing or upgrading helm charts for the cluster.
func (a *applyCmd) runHelmApply(
cmd *cobra.Command, conf *config.Config, stateFile *state.State,
kubeUpgrader kubernetesUpgrader, upgradeDir string,
) error {
a.log.Debugf("Installing or upgrading Helm charts")
var masterSecret uri.MasterSecret
if err := a.fileHandler.ReadJSON(constants.MasterSecretFilename, &masterSecret); err != nil {
return fmt.Errorf("reading master secret: %w", err)
}
options := helm.Options{
Force: a.flags.force,
Conformance: a.flags.conformance,
HelmWaitMode: a.flags.helmWaitMode,
AllowDestructive: helm.DenyDestructive,
}
helmApplier, err := a.newHelmClient(constants.AdminConfFilename, a.log)
if err != nil {
return fmt.Errorf("creating Helm client: %w", err)
}
a.log.Debugf("Getting service account URI")
serviceAccURI, err := cloudcmd.GetMarshaledServiceAccountURI(conf, a.fileHandler)
if err != nil {
return err
}
a.log.Debugf("Preparing Helm charts")
executor, includesUpgrades, err := helmApplier.PrepareApply(conf, stateFile, options, serviceAccURI, masterSecret)
if errors.Is(err, helm.ErrConfirmationMissing) {
if !a.flags.yes {
cmd.PrintErrln("WARNING: Upgrading cert-manager will destroy all custom resources you have manually created that are based on the current version of cert-manager.")
ok, askErr := askToConfirm(cmd, "Do you want to upgrade cert-manager anyway?")
if askErr != nil {
return fmt.Errorf("asking for confirmation: %w", err)
}
if !ok {
cmd.Println("Skipping upgrade.")
return nil
}
}
options.AllowDestructive = helm.AllowDestructive
executor, includesUpgrades, err = helmApplier.PrepareApply(conf, stateFile, options, serviceAccURI, masterSecret)
}
var upgradeErr *compatibility.InvalidUpgradeError
if err != nil {
if !errors.As(err, &upgradeErr) {
return fmt.Errorf("preparing Helm charts: %w", err)
}
cmd.PrintErrln(err)
}
a.log.Debugf("Backing up Helm charts")
if err := a.backupHelmCharts(cmd.Context(), kubeUpgrader, executor, includesUpgrades, upgradeDir); err != nil {
return err
}
a.log.Debugf("Applying Helm charts")
if !a.flags.skipPhases.contains(skipInitPhase) {
a.spinner.Start("Installing Kubernetes components ", false)
} else {
a.spinner.Start("Upgrading Kubernetes components ", false)
}
if err := executor.Apply(cmd.Context()); err != nil {
return fmt.Errorf("applying Helm charts: %w", err)
}
a.spinner.Stop()
if a.flags.skipPhases.contains(skipInitPhase) {
cmd.Println("Successfully upgraded Constellation services.")
}
return nil
}
// backupHelmCharts saves the Helm charts for the upgrade to disk and creates a backup of existing CRDs and CRs.
func (a *applyCmd) backupHelmCharts(
ctx context.Context, kubeUpgrader kubernetesUpgrader, executor helm.Applier, includesUpgrades bool, upgradeDir string,
) error {
// Save the Helm charts for the upgrade to disk
chartDir := filepath.Join(upgradeDir, "helm-charts")
if err := executor.SaveCharts(chartDir, a.fileHandler); err != nil {
return fmt.Errorf("saving Helm charts to disk: %w", err)
}
a.log.Debugf("Helm charts saved to %s", a.flags.pathPrefixer.PrefixPrintablePath(chartDir))
if includesUpgrades {
a.log.Debugf("Creating backup of CRDs and CRs")
crds, err := kubeUpgrader.BackupCRDs(ctx, upgradeDir)
if err != nil {
return fmt.Errorf("creating CRD backup: %w", err)
}
if err := kubeUpgrader.BackupCRs(ctx, crds, upgradeDir); err != nil {
return fmt.Errorf("creating CR backup: %w", err)
}
}
return nil
}

View file

@ -0,0 +1,238 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"bytes"
"context"
"encoding/hex"
"errors"
"fmt"
"io"
"net"
"net/url"
"path/filepath"
"strconv"
"text/tabwriter"
"time"
"github.com/edgelesssys/constellation/v2/bootstrapper/initproto"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/state"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/crypto"
"github.com/edgelesssys/constellation/v2/internal/file"
grpcRetry "github.com/edgelesssys/constellation/v2/internal/grpc/retry"
"github.com/edgelesssys/constellation/v2/internal/kms/uri"
"github.com/edgelesssys/constellation/v2/internal/retry"
"github.com/edgelesssys/constellation/v2/internal/versions"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"k8s.io/client-go/tools/clientcmd"
)
// runInit runs the init RPC to set up the Kubernetes cluster.
// This function only needs to be run once per cluster.
// On success, it writes the Kubernetes admin config file to disk.
// Therefore it is skipped if the Kubernetes admin config file already exists.
func (a *applyCmd) runInit(cmd *cobra.Command, conf *config.Config, stateFile *state.State) (*bytes.Buffer, error) {
a.log.Debugf("Running init RPC")
a.log.Debugf("Creating aTLS Validator for %s", conf.GetAttestationConfig().GetVariant())
validator, err := cloudcmd.NewValidator(cmd, conf.GetAttestationConfig(), a.log)
if err != nil {
return nil, fmt.Errorf("creating new validator: %w", err)
}
a.log.Debugf("Generating master secret")
masterSecret, err := a.generateAndPersistMasterSecret(cmd.OutOrStdout())
if err != nil {
return nil, fmt.Errorf("generating master secret: %w", err)
}
a.log.Debugf("Generated master secret key and salt values")
a.log.Debugf("Generating measurement salt")
measurementSalt, err := crypto.GenerateRandomBytes(crypto.RNGLengthDefault)
if err != nil {
return nil, fmt.Errorf("generating measurement salt: %w", err)
}
a.spinner.Start("Connecting ", false)
req := &initproto.InitRequest{
KmsUri: masterSecret.EncodeToURI(),
StorageUri: uri.NoStoreURI,
MeasurementSalt: measurementSalt,
KubernetesVersion: versions.VersionConfigs[conf.KubernetesVersion].ClusterVersion,
KubernetesComponents: versions.VersionConfigs[conf.KubernetesVersion].KubernetesComponents.ToInitProto(),
ConformanceMode: a.flags.conformance,
InitSecret: stateFile.Infrastructure.InitSecret,
ClusterName: stateFile.Infrastructure.Name,
ApiserverCertSans: stateFile.Infrastructure.APIServerCertSANs,
}
a.log.Debugf("Sending initialization request")
resp, err := a.initCall(cmd.Context(), a.newDialer(validator), stateFile.Infrastructure.ClusterEndpoint, req)
a.spinner.Stop()
a.log.Debugf("Initialization request finished")
if err != nil {
var nonRetriable *nonRetriableError
if errors.As(err, &nonRetriable) {
cmd.PrintErrln("Cluster initialization failed. This error is not recoverable.")
cmd.PrintErrln("Terminate your cluster and try again.")
if nonRetriable.logCollectionErr != nil {
cmd.PrintErrf("Failed to collect logs from bootstrapper: %s\n", nonRetriable.logCollectionErr)
} else {
cmd.PrintErrf("Fetched bootstrapper logs are stored in %q\n", a.flags.pathPrefixer.PrefixPrintablePath(constants.ErrorLog))
}
}
return nil, err
}
a.log.Debugf("Initialization request successful")
a.log.Debugf("Buffering init success message")
bufferedOutput := &bytes.Buffer{}
if err := a.writeInitOutput(stateFile, resp, a.flags.mergeConfigs, bufferedOutput, measurementSalt); err != nil {
return nil, err
}
return bufferedOutput, nil
}
// initCall performs the gRPC call to the bootstrapper to initialize the cluster.
func (a *applyCmd) initCall(ctx context.Context, dialer grpcDialer, ip string, req *initproto.InitRequest) (*initproto.InitSuccessResponse, error) {
doer := &initDoer{
dialer: dialer,
endpoint: net.JoinHostPort(ip, strconv.Itoa(constants.BootstrapperPort)),
req: req,
log: a.log,
spinner: a.spinner,
fh: file.NewHandler(afero.NewOsFs()),
}
// Create a wrapper function that allows logging any returned error from the retrier before checking if it's the expected retriable one.
serviceIsUnavailable := func(err error) bool {
isServiceUnavailable := grpcRetry.ServiceIsUnavailable(err)
a.log.Debugf("Encountered error (retriable: %t): %s", isServiceUnavailable, err)
return isServiceUnavailable
}
a.log.Debugf("Making initialization call, doer is %+v", doer)
retrier := retry.NewIntervalRetrier(doer, 30*time.Second, serviceIsUnavailable)
if err := retrier.Do(ctx); err != nil {
return nil, err
}
return doer.resp, nil
}
// generateAndPersistMasterSecret generates a 32 byte master secret and saves it to disk.
func (a *applyCmd) generateAndPersistMasterSecret(outWriter io.Writer) (uri.MasterSecret, error) {
// No file given, generate a new secret, and save it to disk
key, err := crypto.GenerateRandomBytes(crypto.MasterSecretLengthDefault)
if err != nil {
return uri.MasterSecret{}, err
}
salt, err := crypto.GenerateRandomBytes(crypto.RNGLengthDefault)
if err != nil {
return uri.MasterSecret{}, err
}
secret := uri.MasterSecret{
Key: key,
Salt: salt,
}
if err := a.fileHandler.WriteJSON(constants.MasterSecretFilename, secret, file.OptNone); err != nil {
return uri.MasterSecret{}, err
}
fmt.Fprintf(outWriter, "Your Constellation master secret was successfully written to %q\n", a.flags.pathPrefixer.PrefixPrintablePath(constants.MasterSecretFilename))
return secret, nil
}
// writeInitOutput writes the output of a cluster initialization to the
// state- / kubeconfig-file and saves it to disk.
func (a *applyCmd) writeInitOutput(
stateFile *state.State, initResp *initproto.InitSuccessResponse,
mergeConfig bool, wr io.Writer, measurementSalt []byte,
) error {
fmt.Fprint(wr, "Your Constellation cluster was successfully initialized.\n\n")
ownerID := hex.EncodeToString(initResp.GetOwnerId())
clusterID := hex.EncodeToString(initResp.GetClusterId())
stateFile.SetClusterValues(state.ClusterValues{
MeasurementSalt: measurementSalt,
OwnerID: ownerID,
ClusterID: clusterID,
})
tw := tabwriter.NewWriter(wr, 0, 0, 2, ' ', 0)
writeRow(tw, "Constellation cluster identifier", clusterID)
writeRow(tw, "Kubernetes configuration", a.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename))
tw.Flush()
fmt.Fprintln(wr)
a.log.Debugf("Rewriting cluster server address in kubeconfig to %s", stateFile.Infrastructure.ClusterEndpoint)
kubeconfig, err := clientcmd.Load(initResp.GetKubeconfig())
if err != nil {
return fmt.Errorf("loading kubeconfig: %w", err)
}
if len(kubeconfig.Clusters) != 1 {
return fmt.Errorf("expected exactly one cluster in kubeconfig, got %d", len(kubeconfig.Clusters))
}
for _, cluster := range kubeconfig.Clusters {
kubeEndpoint, err := url.Parse(cluster.Server)
if err != nil {
return fmt.Errorf("parsing kubeconfig server URL: %w", err)
}
kubeEndpoint.Host = net.JoinHostPort(stateFile.Infrastructure.ClusterEndpoint, kubeEndpoint.Port())
cluster.Server = kubeEndpoint.String()
}
kubeconfigBytes, err := clientcmd.Write(*kubeconfig)
if err != nil {
return fmt.Errorf("marshaling kubeconfig: %w", err)
}
if err := a.fileHandler.Write(constants.AdminConfFilename, kubeconfigBytes, file.OptNone); err != nil {
return fmt.Errorf("writing kubeconfig: %w", err)
}
a.log.Debugf("Kubeconfig written to %s", a.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename))
if mergeConfig {
if err := a.merger.mergeConfigs(constants.AdminConfFilename, a.fileHandler); err != nil {
writeRow(tw, "Failed to automatically merge kubeconfig", err.Error())
mergeConfig = false // Set to false so we don't print the wrong message below.
} else {
writeRow(tw, "Kubernetes configuration merged with default config", "")
}
}
if err := stateFile.WriteToFile(a.fileHandler, constants.StateFilename); err != nil {
return fmt.Errorf("writing Constellation state file: %w", err)
}
a.log.Debugf("Constellation state file written to %s", a.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
if !mergeConfig {
fmt.Fprintln(wr, "You can now connect to your cluster by executing:")
exportPath, err := filepath.Abs(constants.AdminConfFilename)
if err != nil {
return fmt.Errorf("getting absolute path to kubeconfig: %w", err)
}
fmt.Fprintf(wr, "\texport KUBECONFIG=%q\n", exportPath)
} else {
fmt.Fprintln(wr, "Constellation kubeconfig merged with default config.")
if a.merger.kubeconfigEnvVar() != "" {
fmt.Fprintln(wr, "Warning: KUBECONFIG environment variable is set.")
fmt.Fprintln(wr, "You may need to unset it to use the default config and connect to your cluster.")
} else {
fmt.Fprintln(wr, "You can now connect to your cluster.")
}
}
fmt.Fprintln(wr) // add final newline
return nil
}

View file

@ -0,0 +1,115 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"fmt"
"path/filepath"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/state"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/spf13/cobra"
)
// runTerraformApply checks if changes to Terraform are required and applies them.
func (a *applyCmd) runTerraformApply(cmd *cobra.Command, conf *config.Config, stateFile *state.State, upgradeDir string) error {
a.log.Debugf("Checking if Terraform migrations are required")
migrationRequired, err := a.planTerraformMigration(cmd, conf)
if err != nil {
return fmt.Errorf("planning Terraform migrations: %w", err)
}
if !migrationRequired {
a.log.Debugf("No changes to infrastructure required, skipping Terraform migrations")
return nil
}
a.log.Debugf("Migrating terraform resources for infrastructure changes")
postMigrationInfraState, err := a.migrateTerraform(cmd, conf, upgradeDir)
if err != nil {
return fmt.Errorf("performing Terraform migrations: %w", err)
}
// Merge the pre-upgrade state with the post-migration infrastructure values
a.log.Debugf("Updating state file with new infrastructure state")
if _, err := stateFile.Merge(
// temporary state with post-migration infrastructure values
state.New().SetInfrastructure(postMigrationInfraState),
); err != nil {
return fmt.Errorf("merging pre-upgrade state with post-migration infrastructure values: %w", err)
}
// Write the post-migration state to disk
if err := stateFile.WriteToFile(a.fileHandler, constants.StateFilename); err != nil {
return fmt.Errorf("writing state file: %w", err)
}
return nil
}
// planTerraformMigration checks if the Constellation version the cluster is being upgraded to requires a migration.
func (a *applyCmd) planTerraformMigration(cmd *cobra.Command, conf *config.Config) (bool, error) {
a.log.Debugf("Planning Terraform migrations")
vars, err := cloudcmd.TerraformUpgradeVars(conf)
if err != nil {
return false, fmt.Errorf("parsing upgrade variables: %w", err)
}
a.log.Debugf("Using Terraform variables:\n%+v", vars)
// Check if there are any Terraform migrations to apply
// Add manual migrations here if required
//
// var manualMigrations []terraform.StateMigration
// for _, migration := range manualMigrations {
// u.log.Debugf("Adding manual Terraform migration: %s", migration.DisplayName)
// u.upgrader.AddManualStateMigration(migration)
// }
return a.clusterUpgrader.PlanClusterUpgrade(cmd.Context(), cmd.OutOrStdout(), vars, conf.GetProvider())
}
// migrateTerraform migrates an existing Terraform state and the post-migration infrastructure state is returned.
func (a *applyCmd) migrateTerraform(cmd *cobra.Command, conf *config.Config, upgradeDir string) (state.Infrastructure, error) {
// Ask for confirmation first
fmt.Fprintln(cmd.OutOrStdout(), "The upgrade requires a migration of Constellation cloud resources by applying an updated Terraform template. Please manually review the suggested changes below.")
if !a.flags.yes {
ok, err := askToConfirm(cmd, "Do you want to apply the Terraform migrations?")
if err != nil {
return state.Infrastructure{}, fmt.Errorf("asking for confirmation: %w", err)
}
if !ok {
cmd.Println("Aborting upgrade.")
// User doesn't expect to see any changes in his workspace after aborting an "upgrade apply",
// therefore, roll back to the backed up state.
if err := a.clusterUpgrader.RestoreClusterWorkspace(); err != nil {
return state.Infrastructure{}, fmt.Errorf(
"restoring Terraform workspace: %w, restore the Terraform workspace manually from %s ",
err,
filepath.Join(upgradeDir, constants.TerraformUpgradeBackupDir),
)
}
return state.Infrastructure{}, fmt.Errorf("cluster upgrade aborted by user")
}
}
a.log.Debugf("Applying Terraform migrations")
a.spinner.Start("Migrating Terraform resources", false)
infraState, err := a.clusterUpgrader.ApplyClusterUpgrade(cmd.Context(), conf.GetProvider())
a.spinner.Stop()
if err != nil {
return state.Infrastructure{}, fmt.Errorf("applying terraform migrations: %w", err)
}
cmd.Printf("Infrastructure migrations applied successfully and output written to: %s\n"+
"A backup of the pre-upgrade state has been written to: %s\n",
a.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename),
a.flags.pathPrefixer.PrefixPrintablePath(filepath.Join(upgradeDir, constants.TerraformUpgradeBackupDir)),
)
return infraState, nil
}

View file

@ -7,28 +7,15 @@ SPDX-License-Identifier: AGPL-3.0-only
package cmd package cmd
import ( import (
"bytes"
"context" "context"
"encoding/hex"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"net/url"
"os" "os"
"path/filepath"
"strconv"
"sync" "sync"
"text/tabwriter"
"time" "time"
"github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi"
"github.com/edgelesssys/constellation/v2/internal/atls"
"github.com/edgelesssys/constellation/v2/internal/attestation/variant"
"github.com/spf13/afero"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag"
"google.golang.org/grpc" "google.golang.org/grpc"
"k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime"
"k8s.io/client-go/tools/clientcmd" "k8s.io/client-go/tools/clientcmd"
@ -36,21 +23,13 @@ import (
"sigs.k8s.io/yaml" "sigs.k8s.io/yaml"
"github.com/edgelesssys/constellation/v2/bootstrapper/initproto" "github.com/edgelesssys/constellation/v2/bootstrapper/initproto"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/helm" "github.com/edgelesssys/constellation/v2/cli/internal/helm"
"github.com/edgelesssys/constellation/v2/cli/internal/kubecmd"
"github.com/edgelesssys/constellation/v2/cli/internal/state" "github.com/edgelesssys/constellation/v2/cli/internal/state"
"github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/crypto"
"github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
"github.com/edgelesssys/constellation/v2/internal/grpc/grpclog" "github.com/edgelesssys/constellation/v2/internal/grpc/grpclog"
grpcRetry "github.com/edgelesssys/constellation/v2/internal/grpc/retry"
"github.com/edgelesssys/constellation/v2/internal/kms/uri" "github.com/edgelesssys/constellation/v2/internal/kms/uri"
"github.com/edgelesssys/constellation/v2/internal/license"
"github.com/edgelesssys/constellation/v2/internal/retry"
"github.com/edgelesssys/constellation/v2/internal/versions"
) )
// NewInitCmd returns a new cobra.Command for the init command. // NewInitCmd returns a new cobra.Command for the init command.
@ -61,7 +40,15 @@ func NewInitCmd() *cobra.Command {
Long: "Initialize the Constellation cluster.\n\n" + Long: "Initialize the Constellation cluster.\n\n" +
"Start your confidential Kubernetes.", "Start your confidential Kubernetes.",
Args: cobra.ExactArgs(0), Args: cobra.ExactArgs(0),
RunE: runInitialize, RunE: func(cmd *cobra.Command, args []string) error {
// Define flags for apply backend that are not set by init
cmd.Flags().Bool("yes", false, "")
// Don't skip any phases
// The apply backend should handle init calls correctly
cmd.Flags().StringSlice("skip-phases", []string{}, "")
cmd.Flags().Duration("timeout", time.Hour, "")
return runApply(cmd, args)
},
} }
cmd.Flags().Bool("conformance", false, "enable conformance mode") cmd.Flags().Bool("conformance", false, "enable conformance mode")
cmd.Flags().Bool("skip-helm-wait", false, "install helm charts without waiting for deployments to be ready") cmd.Flags().Bool("skip-helm-wait", false, "install helm charts without waiting for deployments to be ready")
@ -69,270 +56,6 @@ func NewInitCmd() *cobra.Command {
return cmd return cmd
} }
// initFlags are flags used by the init command.
type initFlags struct {
rootFlags
conformance bool
helmWaitMode helm.WaitMode
mergeConfigs bool
}
func (f *initFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
}
skipHelmWait, err := flags.GetBool("skip-helm-wait")
if err != nil {
return fmt.Errorf("getting 'skip-helm-wait' flag: %w", err)
}
f.helmWaitMode = helm.WaitModeAtomic
if skipHelmWait {
f.helmWaitMode = helm.WaitModeNone
}
f.conformance, err = flags.GetBool("conformance")
if err != nil {
return fmt.Errorf("getting 'conformance' flag: %w", err)
}
f.mergeConfigs, err = flags.GetBool("merge-kubeconfig")
if err != nil {
return fmt.Errorf("getting 'merge-kubeconfig' flag: %w", err)
}
return nil
}
type initCmd struct {
log debugLog
merger configMerger
spinner spinnerInterf
fileHandler file.Handler
flags initFlags
}
func newInitCmd(fileHandler file.Handler, spinner spinnerInterf, merger configMerger, log debugLog) *initCmd {
return &initCmd{
log: log,
merger: merger,
spinner: spinner,
fileHandler: fileHandler,
}
}
// runInitialize runs the initialize command.
func runInitialize(cmd *cobra.Command, _ []string) error {
log, err := newCLILogger(cmd)
if err != nil {
return fmt.Errorf("creating logger: %w", err)
}
defer log.Sync()
fileHandler := file.NewHandler(afero.NewOsFs())
newDialer := func(validator atls.Validator) *dialer.Dialer {
return dialer.New(nil, validator, &net.Dialer{})
}
spinner, err := newSpinnerOrStderr(cmd)
if err != nil {
return err
}
defer spinner.Stop()
ctx, cancel := context.WithTimeout(cmd.Context(), time.Hour)
defer cancel()
cmd.SetContext(ctx)
i := newInitCmd(fileHandler, spinner, &kubeconfigMerger{log: log}, log)
if err := i.flags.parse(cmd.Flags()); err != nil {
return err
}
i.log.Debugf("Using flags: %+v", i.flags)
fetcher := attestationconfigapi.NewFetcher()
newAttestationApplier := func(w io.Writer, kubeConfig string, log debugLog) (attestationConfigApplier, error) {
return kubecmd.New(w, kubeConfig, fileHandler, log)
}
newHelmClient := func(kubeConfigPath string, log debugLog) (helmApplier, error) {
return helm.NewClient(kubeConfigPath, log)
} // need to defer helm client instantiation until kubeconfig is available
return i.initialize(cmd, newDialer, license.NewClient(), fetcher, newAttestationApplier, newHelmClient)
}
// initialize initializes a Constellation.
func (i *initCmd) initialize(
cmd *cobra.Command, newDialer func(validator atls.Validator) *dialer.Dialer,
quotaChecker license.QuotaChecker, configFetcher attestationconfigapi.Fetcher,
newAttestationApplier func(io.Writer, string, debugLog) (attestationConfigApplier, error),
newHelmClient func(kubeConfigPath string, log debugLog) (helmApplier, error),
) error {
i.log.Debugf("Loading configuration file from %q", i.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))
conf, err := config.New(i.fileHandler, constants.ConfigFilename, configFetcher, i.flags.force)
var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage())
}
if err != nil {
return err
}
// cfg validation does not check k8s patch version since upgrade may accept an outdated patch version.
k8sVersion, err := versions.NewValidK8sVersion(string(conf.KubernetesVersion), true)
if err != nil {
return err
}
if !i.flags.force {
if err := validateCLIandConstellationVersionAreEqual(constants.BinaryVersion(), conf.Image, conf.MicroserviceVersion); err != nil {
return err
}
}
if conf.GetAttestationConfig().GetVariant().Equal(variant.AWSSEVSNP{}) {
cmd.PrintErrln("WARNING: Attestation temporarily relies on AWS nitroTPM. See https://docs.edgeless.systems/constellation/workflows/config#choosing-a-vm-type for more information.")
}
stateFile, err := state.ReadFromFile(i.fileHandler, constants.StateFilename)
if err != nil {
return fmt.Errorf("reading state file: %w", err)
}
i.log.Debugf("Validated k8s version as %s", k8sVersion)
if versions.IsPreviewK8sVersion(k8sVersion) {
cmd.PrintErrf("Warning: Constellation with Kubernetes %v is still in preview. Use only for evaluation purposes.\n", k8sVersion)
}
provider := conf.GetProvider()
i.log.Debugf("Got provider %s", provider.String())
checker := license.NewChecker(quotaChecker, i.fileHandler)
if err := checker.CheckLicense(cmd.Context(), provider, conf.Provider, cmd.Printf); err != nil {
cmd.PrintErrf("License check failed: %v", err)
}
i.log.Debugf("Checked license")
if stateFile.Infrastructure.Azure != nil {
conf.UpdateMAAURL(stateFile.Infrastructure.Azure.AttestationURL)
}
i.log.Debugf("Creating aTLS Validator for %s", conf.GetAttestationConfig().GetVariant())
validator, err := cloudcmd.NewValidator(cmd, conf.GetAttestationConfig(), i.log)
if err != nil {
return fmt.Errorf("creating new validator: %w", err)
}
i.log.Debugf("Created a new validator")
serviceAccURI, err := cloudcmd.GetMarshaledServiceAccountURI(provider, conf, i.flags.pathPrefixer, i.log, i.fileHandler)
if err != nil {
return err
}
i.log.Debugf("Successfully marshaled service account URI")
i.log.Debugf("Generating master secret")
masterSecret, err := i.generateMasterSecret(cmd.OutOrStdout())
if err != nil {
return fmt.Errorf("generating master secret: %w", err)
}
i.log.Debugf("Generating measurement salt")
measurementSalt, err := crypto.GenerateRandomBytes(crypto.RNGLengthDefault)
if err != nil {
return fmt.Errorf("generating measurement salt: %w", err)
}
i.log.Debugf("Setting cluster name to %s", stateFile.Infrastructure.Name)
cmd.PrintErrln("Note: If you just created the cluster, it can take a few minutes to connect.")
i.spinner.Start("Connecting ", false)
req := &initproto.InitRequest{
KmsUri: masterSecret.EncodeToURI(),
StorageUri: uri.NoStoreURI,
MeasurementSalt: measurementSalt,
KubernetesVersion: versions.VersionConfigs[k8sVersion].ClusterVersion,
KubernetesComponents: versions.VersionConfigs[k8sVersion].KubernetesComponents.ToInitProto(),
ConformanceMode: i.flags.conformance,
InitSecret: stateFile.Infrastructure.InitSecret,
ClusterName: stateFile.Infrastructure.Name,
ApiserverCertSans: stateFile.Infrastructure.APIServerCertSANs,
}
i.log.Debugf("Sending initialization request")
resp, err := i.initCall(cmd.Context(), newDialer(validator), stateFile.Infrastructure.ClusterEndpoint, req)
i.spinner.Stop()
if err != nil {
var nonRetriable *nonRetriableError
if errors.As(err, &nonRetriable) {
cmd.PrintErrln("Cluster initialization failed. This error is not recoverable.")
cmd.PrintErrln("Terminate your cluster and try again.")
if nonRetriable.logCollectionErr != nil {
cmd.PrintErrf("Failed to collect logs from bootstrapper: %s\n", nonRetriable.logCollectionErr)
} else {
cmd.PrintErrf("Fetched bootstrapper logs are stored in %q\n", i.flags.pathPrefixer.PrefixPrintablePath(constants.ErrorLog))
}
}
return err
}
i.log.Debugf("Initialization request succeeded")
bufferedOutput := &bytes.Buffer{}
if err := i.writeOutput(stateFile, resp, i.flags.mergeConfigs, bufferedOutput, measurementSalt); err != nil {
return err
}
attestationApplier, err := newAttestationApplier(cmd.OutOrStdout(), constants.AdminConfFilename, i.log)
if err != nil {
return err
}
if err := attestationApplier.ApplyJoinConfig(cmd.Context(), conf.GetAttestationConfig(), measurementSalt); err != nil {
return fmt.Errorf("applying attestation config: %w", err)
}
i.spinner.Start("Installing Kubernetes components ", false)
options := helm.Options{
Force: i.flags.force,
Conformance: i.flags.conformance,
HelmWaitMode: i.flags.helmWaitMode,
AllowDestructive: helm.DenyDestructive,
}
helmApplier, err := newHelmClient(constants.AdminConfFilename, i.log)
if err != nil {
return fmt.Errorf("creating Helm client: %w", err)
}
executor, includesUpgrades, err := helmApplier.PrepareApply(conf, stateFile, options, serviceAccURI, masterSecret)
if err != nil {
return fmt.Errorf("getting Helm chart executor: %w", err)
}
if includesUpgrades {
return errors.New("init: helm tried to upgrade charts instead of installing them")
}
if err := executor.Apply(cmd.Context()); err != nil {
return fmt.Errorf("applying Helm charts: %w", err)
}
i.spinner.Stop()
i.log.Debugf("Helm deployment installation succeeded")
cmd.Println(bufferedOutput.String())
return nil
}
func (i *initCmd) initCall(ctx context.Context, dialer grpcDialer, ip string, req *initproto.InitRequest) (*initproto.InitSuccessResponse, error) {
doer := &initDoer{
dialer: dialer,
endpoint: net.JoinHostPort(ip, strconv.Itoa(constants.BootstrapperPort)),
req: req,
log: i.log,
spinner: i.spinner,
fh: file.NewHandler(afero.NewOsFs()),
}
// Create a wrapper function that allows logging any returned error from the retrier before checking if it's the expected retriable one.
serviceIsUnavailable := func(err error) bool {
isServiceUnavailable := grpcRetry.ServiceIsUnavailable(err)
i.log.Debugf("Encountered error (retriable: %t): %s", isServiceUnavailable, err)
return isServiceUnavailable
}
i.log.Debugf("Making initialization call, doer is %+v", doer)
retrier := retry.NewIntervalRetrier(doer, 30*time.Second, serviceIsUnavailable)
if err := retrier.Do(ctx); err != nil {
return nil, err
}
return doer.resp, nil
}
type initDoer struct { type initDoer struct {
dialer grpcDialer dialer grpcDialer
endpoint string endpoint string
@ -469,122 +192,10 @@ func (d *initDoer) handleGRPCStateChanges(ctx context.Context, wg *sync.WaitGrou
}) })
} }
// writeOutput writes the output of a cluster initialization to the
// state- / id- / kubeconfig-file and saves it to disk.
func (i *initCmd) writeOutput(
stateFile *state.State,
initResp *initproto.InitSuccessResponse,
mergeConfig bool, wr io.Writer,
measurementSalt []byte,
) error {
fmt.Fprint(wr, "Your Constellation cluster was successfully initialized.\n\n")
ownerID := hex.EncodeToString(initResp.GetOwnerId())
clusterID := hex.EncodeToString(initResp.GetClusterId())
stateFile.SetClusterValues(state.ClusterValues{
MeasurementSalt: measurementSalt,
OwnerID: ownerID,
ClusterID: clusterID,
})
tw := tabwriter.NewWriter(wr, 0, 0, 2, ' ', 0)
writeRow(tw, "Constellation cluster identifier", clusterID)
writeRow(tw, "Kubernetes configuration", i.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename))
tw.Flush()
fmt.Fprintln(wr)
i.log.Debugf("Rewriting cluster server address in kubeconfig to %s", stateFile.Infrastructure.ClusterEndpoint)
kubeconfig, err := clientcmd.Load(initResp.GetKubeconfig())
if err != nil {
return fmt.Errorf("loading kubeconfig: %w", err)
}
if len(kubeconfig.Clusters) != 1 {
return fmt.Errorf("expected exactly one cluster in kubeconfig, got %d", len(kubeconfig.Clusters))
}
for _, cluster := range kubeconfig.Clusters {
kubeEndpoint, err := url.Parse(cluster.Server)
if err != nil {
return fmt.Errorf("parsing kubeconfig server URL: %w", err)
}
kubeEndpoint.Host = net.JoinHostPort(stateFile.Infrastructure.ClusterEndpoint, kubeEndpoint.Port())
cluster.Server = kubeEndpoint.String()
}
kubeconfigBytes, err := clientcmd.Write(*kubeconfig)
if err != nil {
return fmt.Errorf("marshaling kubeconfig: %w", err)
}
if err := i.fileHandler.Write(constants.AdminConfFilename, kubeconfigBytes, file.OptNone); err != nil {
return fmt.Errorf("writing kubeconfig: %w", err)
}
i.log.Debugf("Kubeconfig written to %s", i.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename))
if mergeConfig {
if err := i.merger.mergeConfigs(constants.AdminConfFilename, i.fileHandler); err != nil {
writeRow(tw, "Failed to automatically merge kubeconfig", err.Error())
mergeConfig = false // Set to false so we don't print the wrong message below.
} else {
writeRow(tw, "Kubernetes configuration merged with default config", "")
}
}
if err := stateFile.WriteToFile(i.fileHandler, constants.StateFilename); err != nil {
return fmt.Errorf("writing Constellation state file: %w", err)
}
i.log.Debugf("Constellation state file written to %s", i.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename))
if !mergeConfig {
fmt.Fprintln(wr, "You can now connect to your cluster by executing:")
exportPath, err := filepath.Abs(constants.AdminConfFilename)
if err != nil {
return fmt.Errorf("getting absolute path to kubeconfig: %w", err)
}
fmt.Fprintf(wr, "\texport KUBECONFIG=%q\n", exportPath)
} else {
fmt.Fprintln(wr, "Constellation kubeconfig merged with default config.")
if i.merger.kubeconfigEnvVar() != "" {
fmt.Fprintln(wr, "Warning: KUBECONFIG environment variable is set.")
fmt.Fprintln(wr, "You may need to unset it to use the default config and connect to your cluster.")
} else {
fmt.Fprintln(wr, "You can now connect to your cluster.")
}
}
return nil
}
func writeRow(wr io.Writer, col1 string, col2 string) { func writeRow(wr io.Writer, col1 string, col2 string) {
fmt.Fprint(wr, col1, "\t", col2, "\n") fmt.Fprint(wr, col1, "\t", col2, "\n")
} }
// generateMasterSecret reads a base64 encoded master secret from file or generates a new 32 byte secret.
func (i *initCmd) generateMasterSecret(outWriter io.Writer) (uri.MasterSecret, error) {
// No file given, generate a new secret, and save it to disk
i.log.Debugf("Generating new master secret")
key, err := crypto.GenerateRandomBytes(crypto.MasterSecretLengthDefault)
if err != nil {
return uri.MasterSecret{}, err
}
salt, err := crypto.GenerateRandomBytes(crypto.RNGLengthDefault)
if err != nil {
return uri.MasterSecret{}, err
}
secret := uri.MasterSecret{
Key: key,
Salt: salt,
}
i.log.Debugf("Generated master secret key and salt values")
if err := i.fileHandler.WriteJSON(constants.MasterSecretFilename, secret, file.OptNone); err != nil {
return uri.MasterSecret{}, err
}
fmt.Fprintf(outWriter, "Your Constellation master secret was successfully written to %q\n", i.flags.pathPrefixer.PrefixPrintablePath(constants.MasterSecretFilename))
return secret, nil
}
type configMerger interface { type configMerger interface {
mergeConfigs(configPath string, fileHandler file.Handler) error mergeConfigs(configPath string, fileHandler file.Handler) error
kubeconfigEnvVar() string kubeconfigEnvVar() string
@ -657,10 +268,6 @@ func (e *nonRetriableError) Unwrap() error {
return e.err return e.err
} }
type attestationConfigApplier interface {
ApplyJoinConfig(ctx context.Context, newAttestConfig config.AttestationCfg, measurementSalt []byte) error
}
type helmApplier interface { type helmApplier interface {
PrepareApply(conf *config.Config, stateFile *state.State, PrepareApply(conf *config.Config, stateFile *state.State,
flags helm.Options, serviceAccURI string, masterSecret uri.MasterSecret) ( flags helm.Options, serviceAccURI string, masterSecret uri.MasterSecret) (

View file

@ -44,6 +44,8 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"google.golang.org/grpc" "google.golang.org/grpc"
k8serrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/client-go/tools/clientcmd" "k8s.io/client-go/tools/clientcmd"
k8sclientapi "k8s.io/client-go/tools/clientcmd/api" k8sclientapi "k8s.io/client-go/tools/clientcmd/api"
) )
@ -58,8 +60,6 @@ func TestInitArgumentValidation(t *testing.T) {
} }
func TestInitialize(t *testing.T) { func TestInitialize(t *testing.T) {
require := require.New(t)
respKubeconfig := k8sclientapi.Config{ respKubeconfig := k8sclientapi.Config{
Clusters: map[string]*k8sclientapi.Cluster{ Clusters: map[string]*k8sclientapi.Cluster{
"cluster": { "cluster": {
@ -68,7 +68,7 @@ func TestInitialize(t *testing.T) {
}, },
} }
respKubeconfigBytes, err := clientcmd.Write(respKubeconfig) respKubeconfigBytes, err := clientcmd.Write(respKubeconfig)
require.NoError(err) require.NoError(t, err)
gcpServiceAccKey := &gcpshared.ServiceAccountKey{ gcpServiceAccKey := &gcpshared.ServiceAccountKey{
Type: "service_account", Type: "service_account",
@ -149,31 +149,47 @@ func TestInitialize(t *testing.T) {
masterSecretShouldExist: true, masterSecretShouldExist: true,
wantErr: true, wantErr: true,
}, },
"state file with only version": { /*
provider: cloudprovider.GCP, Tests currently disabled since we don't actually have validation for the state file yet
stateFile: &state.State{Version: state.Version1}, These tests cases only passed in the past because of unrelated errors in the test setup
initServerAPI: &stubInitServer{}, TODO(AB#3492): Re-enable tests once state file validation is implemented
retriable: true,
wantErr: true, "state file with only version": {
}, provider: cloudprovider.GCP,
"empty state file": { stateFile: &state.State{Version: state.Version1},
provider: cloudprovider.GCP, configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
stateFile: &state.State{}, serviceAccKey: gcpServiceAccKey,
initServerAPI: &stubInitServer{}, initServerAPI: &stubInitServer{},
retriable: true, retriable: true,
wantErr: true, wantErr: true,
}, },
"empty state file": {
provider: cloudprovider.GCP,
stateFile: &state.State{},
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
serviceAccKey: gcpServiceAccKey,
initServerAPI: &stubInitServer{},
retriable: true,
wantErr: true,
},
*/
"no state file": { "no state file": {
provider: cloudprovider.GCP, provider: cloudprovider.GCP,
retriable: true, configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
wantErr: true, serviceAccKey: gcpServiceAccKey,
retriable: true,
wantErr: true,
}, },
"init call fails": { "init call fails": {
provider: cloudprovider.GCP, provider: cloudprovider.GCP,
stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}}, configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
initServerAPI: &stubInitServer{initErr: assert.AnError}, stateFile: &state.State{Version: state.Version1, Infrastructure: state.Infrastructure{ClusterEndpoint: "192.0.2.1"}},
retriable: true, serviceAccKey: gcpServiceAccKey,
wantErr: true, initServerAPI: &stubInitServer{initErr: assert.AnError},
retriable: false,
masterSecretShouldExist: true,
wantErr: true,
}, },
"k8s version without v works": { "k8s version without v works": {
provider: cloudprovider.Azure, provider: cloudprovider.Azure,
@ -181,7 +197,7 @@ func TestInitialize(t *testing.T) {
initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}}, initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}},
configMutator: func(c *config.Config) { configMutator: func(c *config.Config) {
res, err := versions.NewValidK8sVersion(strings.TrimPrefix(string(versions.Default), "v"), true) res, err := versions.NewValidK8sVersion(strings.TrimPrefix(string(versions.Default), "v"), true)
require.NoError(err) require.NoError(t, err)
c.KubernetesVersion = res c.KubernetesVersion = res
}, },
}, },
@ -191,7 +207,7 @@ func TestInitialize(t *testing.T) {
initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}}, initServerAPI: &stubInitServer{res: []*initproto.InitResponse{{Kind: &initproto.InitResponse_InitSuccess{InitSuccess: testInitResp}}}},
configMutator: func(c *config.Config) { configMutator: func(c *config.Config) {
v, err := semver.New(versions.SupportedK8sVersions()[0]) v, err := semver.New(versions.SupportedK8sVersions()[0])
require.NoError(err) require.NoError(t, err)
outdatedPatchVer := semver.NewFromInt(v.Major(), v.Minor(), v.Patch()-1, "").String() outdatedPatchVer := semver.NewFromInt(v.Major(), v.Minor(), v.Patch()-1, "").String()
c.KubernetesVersion = versions.ValidK8sVersion(outdatedPatchVer) c.KubernetesVersion = versions.ValidK8sVersion(outdatedPatchVer)
}, },
@ -203,6 +219,7 @@ func TestInitialize(t *testing.T) {
for name, tc := range testCases { for name, tc := range testCases {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
require := require.New(t)
// Networking // Networking
netDialer := testdialer.NewBufconnDialer() netDialer := testdialer.NewBufconnDialer()
newDialer := func(atls.Validator) *dialer.Dialer { newDialer := func(atls.Validator) *dialer.Dialer {
@ -231,8 +248,6 @@ func TestInitialize(t *testing.T) {
tc.configMutator(config) tc.configMutator(config)
} }
require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, config, file.OptNone)) require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, config, file.OptNone))
stateFile := state.New()
require.NoError(stateFile.WriteToFile(fileHandler, constants.StateFilename))
if tc.stateFile != nil { if tc.stateFile != nil {
require.NoError(tc.stateFile.WriteToFile(fileHandler, constants.StateFilename)) require.NoError(tc.stateFile.WriteToFile(fileHandler, constants.StateFilename))
} }
@ -244,20 +259,28 @@ func TestInitialize(t *testing.T) {
ctx, cancel := context.WithTimeout(ctx, 4*time.Second) ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
defer cancel() defer cancel()
cmd.SetContext(ctx) cmd.SetContext(ctx)
i := newInitCmd(fileHandler, &nopSpinner{}, nil, logger.NewTest(t))
i.flags.force = true
err := i.initialize( i := &applyCmd{
cmd, fileHandler: fileHandler,
newDialer, flags: applyFlags{rootFlags: rootFlags{force: true}},
&stubLicenseClient{}, log: logger.NewTest(t),
stubAttestationFetcher{}, spinner: &nopSpinner{},
func(io.Writer, string, debugLog) (attestationConfigApplier, error) { merger: &stubMerger{},
return &stubAttestationApplier{}, nil quotaChecker: &stubLicenseClient{},
}, newHelmClient: func(string, debugLog) (helmApplier, error) {
func(_ string, _ debugLog) (helmApplier, error) {
return &stubApplier{}, nil return &stubApplier{}, nil
}) },
newDialer: newDialer,
newKubeUpgrader: func(io.Writer, string, debugLog) (kubernetesUpgrader, error) {
return &stubKubernetesUpgrader{
// On init, no attestation config exists yet
getClusterAttestationConfigErr: k8serrors.NewNotFound(schema.GroupResource{}, ""),
}, nil
},
clusterUpgrader: stubTerraformUpgrader{},
}
err := i.apply(cmd, stubAttestationFetcher{}, "test")
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
@ -291,14 +314,17 @@ func (s stubApplier) PrepareApply(_ *config.Config, _ *state.State, _ helm.Optio
return stubRunner{}, false, s.err return stubRunner{}, false, s.err
} }
type stubRunner struct{} type stubRunner struct {
applyErr error
saveChartsErr error
}
func (s stubRunner) Apply(_ context.Context) error { func (s stubRunner) Apply(_ context.Context) error {
return nil return s.applyErr
} }
func (s stubRunner) SaveCharts(_ string, _ file.Handler) error { func (s stubRunner) SaveCharts(_ string, _ file.Handler) error {
return nil return s.saveChartsErr
} }
func TestGetLogs(t *testing.T) { func TestGetLogs(t *testing.T) {
@ -420,8 +446,13 @@ func TestWriteOutput(t *testing.T) {
ClusterEndpoint: clusterEndpoint, ClusterEndpoint: clusterEndpoint,
}) })
i := newInitCmd(fileHandler, &nopSpinner{}, &stubMerger{}, logger.NewTest(t)) i := &applyCmd{
err = i.writeOutput(stateFile, resp.GetInitSuccess(), false, &out, measurementSalt) fileHandler: fileHandler,
spinner: &nopSpinner{},
merger: &stubMerger{},
log: logger.NewTest(t),
}
err = i.writeInitOutput(stateFile, resp.GetInitSuccess(), false, &out, measurementSalt)
require.NoError(err) require.NoError(err)
assert.Contains(out.String(), clusterID) assert.Contains(out.String(), clusterID)
assert.Contains(out.String(), constants.AdminConfFilename) assert.Contains(out.String(), constants.AdminConfFilename)
@ -441,7 +472,7 @@ func TestWriteOutput(t *testing.T) {
// test custom workspace // test custom workspace
i.flags.pathPrefixer = pathprefix.New("/some/path") i.flags.pathPrefixer = pathprefix.New("/some/path")
err = i.writeOutput(stateFile, resp.GetInitSuccess(), true, &out, measurementSalt) err = i.writeInitOutput(stateFile, resp.GetInitSuccess(), true, &out, measurementSalt)
require.NoError(err) require.NoError(err)
assert.Contains(out.String(), clusterID) assert.Contains(out.String(), clusterID)
assert.Contains(out.String(), i.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename)) assert.Contains(out.String(), i.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename))
@ -451,7 +482,7 @@ func TestWriteOutput(t *testing.T) {
i.flags.pathPrefixer = pathprefix.PathPrefixer{} i.flags.pathPrefixer = pathprefix.PathPrefixer{}
// test config merging // test config merging
err = i.writeOutput(stateFile, resp.GetInitSuccess(), true, &out, measurementSalt) err = i.writeInitOutput(stateFile, resp.GetInitSuccess(), true, &out, measurementSalt)
require.NoError(err) require.NoError(err)
assert.Contains(out.String(), clusterID) assert.Contains(out.String(), clusterID)
assert.Contains(out.String(), constants.AdminConfFilename) assert.Contains(out.String(), constants.AdminConfFilename)
@ -462,7 +493,7 @@ func TestWriteOutput(t *testing.T) {
// test config merging with env vars set // test config merging with env vars set
i.merger = &stubMerger{envVar: "/some/path/to/kubeconfig"} i.merger = &stubMerger{envVar: "/some/path/to/kubeconfig"}
err = i.writeOutput(stateFile, resp.GetInitSuccess(), true, &out, measurementSalt) err = i.writeInitOutput(stateFile, resp.GetInitSuccess(), true, &out, measurementSalt)
require.NoError(err) require.NoError(err)
assert.Contains(out.String(), clusterID) assert.Contains(out.String(), clusterID)
assert.Contains(out.String(), constants.AdminConfFilename) assert.Contains(out.String(), constants.AdminConfFilename)
@ -508,8 +539,11 @@ func TestGenerateMasterSecret(t *testing.T) {
require.NoError(tc.createFileFunc(fileHandler)) require.NoError(tc.createFileFunc(fileHandler))
var out bytes.Buffer var out bytes.Buffer
i := newInitCmd(fileHandler, nil, nil, logger.NewTest(t)) i := &applyCmd{
secret, err := i.generateMasterSecret(&out) fileHandler: fileHandler,
log: logger.NewTest(t),
}
secret, err := i.generateAndPersistMasterSecret(&out)
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
@ -601,13 +635,17 @@ func TestAttestation(t *testing.T) {
defer cancel() defer cancel()
cmd.SetContext(ctx) cmd.SetContext(ctx)
i := newInitCmd(fileHandler, &nopSpinner{}, nil, logger.NewTest(t)) i := &applyCmd{
err := i.initialize(cmd, newDialer, &stubLicenseClient{}, stubAttestationFetcher{}, fileHandler: fileHandler,
func(io.Writer, string, debugLog) (attestationConfigApplier, error) { spinner: &nopSpinner{},
return &stubAttestationApplier{}, nil merger: &stubMerger{},
}, func(_ string, _ debugLog) (helmApplier, error) { log: logger.NewTest(t),
return &stubApplier{}, nil newKubeUpgrader: func(io.Writer, string, debugLog) (kubernetesUpgrader, error) {
}) return &stubKubernetesUpgrader{}, nil
},
newDialer: newDialer,
}
_, err := i.runInit(cmd, cfg, existingStateFile)
assert.Error(err) assert.Error(err)
// make sure the error is actually a TLS handshake error // make sure the error is actually a TLS handshake error
assert.Contains(err.Error(), "transport: authentication handshake failed") assert.Contains(err.Error(), "transport: authentication handshake failed")
@ -773,11 +811,3 @@ func (c stubInitClient) Recv() (*initproto.InitResponse, error) {
return res, err return res, err
} }
type stubAttestationApplier struct {
applyErr error
}
func (a *stubAttestationApplier) ApplyJoinConfig(context.Context, config.AttestationCfg, []byte) error {
return a.applyErr
}

View file

@ -10,23 +10,17 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io" "time"
"net"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd" "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/featureset" "github.com/edgelesssys/constellation/v2/cli/internal/featureset"
"github.com/edgelesssys/constellation/v2/cli/internal/helm"
"github.com/edgelesssys/constellation/v2/cli/internal/kubecmd"
"github.com/edgelesssys/constellation/v2/cli/internal/libvirt" "github.com/edgelesssys/constellation/v2/cli/internal/libvirt"
"github.com/edgelesssys/constellation/v2/cli/internal/state" "github.com/edgelesssys/constellation/v2/cli/internal/state"
"github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi" "github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi"
"github.com/edgelesssys/constellation/v2/internal/atls"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
"github.com/edgelesssys/constellation/v2/internal/license"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -105,7 +99,7 @@ func (m *miniUpCmd) up(cmd *cobra.Command, creator cloudCreator, spinner spinner
cmd.Printf("\tvirsh -c %s\n\n", connectURI) cmd.Printf("\tvirsh -c %s\n\n", connectURI)
// initialize cluster // initialize cluster
if err := m.initializeMiniCluster(cmd, spinner); err != nil { if err := m.initializeMiniCluster(cmd); err != nil {
return fmt.Errorf("initializing cluster: %w", err) return fmt.Errorf("initializing cluster: %w", err)
} }
m.log.Debugf("Initialized cluster") m.log.Debugf("Initialized cluster")
@ -188,7 +182,7 @@ func (m *miniUpCmd) createMiniCluster(ctx context.Context, creator cloudCreator,
} }
// initializeMiniCluster initializes a QEMU cluster. // initializeMiniCluster initializes a QEMU cluster.
func (m *miniUpCmd) initializeMiniCluster(cmd *cobra.Command, spinner spinnerInterf) (retErr error) { func (m *miniUpCmd) initializeMiniCluster(cmd *cobra.Command) (retErr error) {
m.log.Debugf("Initializing mini cluster") m.log.Debugf("Initializing mini cluster")
// clean up cluster resources if initialization fails // clean up cluster resources if initialization fails
defer func() { defer func() {
@ -199,34 +193,19 @@ func (m *miniUpCmd) initializeMiniCluster(cmd *cobra.Command, spinner spinnerInt
cmd.PrintErrf("Rollback succeeded.\n\n") cmd.PrintErrf("Rollback succeeded.\n\n")
} }
}() }()
newDialer := func(validator atls.Validator) *dialer.Dialer {
return dialer.New(nil, validator, &net.Dialer{}) // Define flags for apply backend that are not set by mini up
} cmd.Flags().StringSlice(
m.log.Debugf("Created new dialer") "skip-phases",
cmd.Flags().String("endpoint", "", "") []string{string(skipInfrastructurePhase), string(skipK8sPhase), string(skipImagePhase)},
"",
)
cmd.Flags().Bool("yes", false, "")
cmd.Flags().Bool("skip-helm-wait", false, "")
cmd.Flags().Bool("conformance", false, "") cmd.Flags().Bool("conformance", false, "")
cmd.Flags().Bool("skip-helm-wait", false, "install helm charts without waiting for deployments to be ready") cmd.Flags().Duration("timeout", time.Hour, "")
log, err := newCLILogger(cmd)
if err != nil {
return fmt.Errorf("creating logger: %w", err)
}
m.log.Debugf("Created new logger")
defer log.Sync()
newAttestationApplier := func(w io.Writer, kubeConfig string, log debugLog) (attestationConfigApplier, error) { if err := runApply(cmd, nil); err != nil {
return kubecmd.New(w, kubeConfig, m.fileHandler, log)
}
newHelmClient := func(kubeConfigPath string, log debugLog) (helmApplier, error) {
return helm.NewClient(kubeConfigPath, log)
} // need to defer helm client instantiation until kubeconfig is available
i := newInitCmd(m.fileHandler, spinner, &kubeconfigMerger{log: log}, log)
if err := i.flags.parse(cmd.Flags()); err != nil {
return err
}
if err := i.initialize(cmd, newDialer, license.NewClient(), m.configFetcher,
newAttestationApplier, newHelmClient); err != nil {
return err return err
} }
m.log.Debugf("Initialized mini cluster") m.log.Debugf("Initialized mini cluster")

View file

@ -8,36 +8,25 @@ package cmd
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"path/filepath"
"strings" "strings"
"time" "time"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/helm"
"github.com/edgelesssys/constellation/v2/cli/internal/kubecmd"
"github.com/edgelesssys/constellation/v2/cli/internal/state" "github.com/edgelesssys/constellation/v2/cli/internal/state"
"github.com/edgelesssys/constellation/v2/cli/internal/terraform" "github.com/edgelesssys/constellation/v2/cli/internal/terraform"
"github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi"
"github.com/edgelesssys/constellation/v2/internal/attestation/variant" "github.com/edgelesssys/constellation/v2/internal/attestation/variant"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/compatibility"
"github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/kms/uri"
"github.com/edgelesssys/constellation/v2/internal/versions"
"github.com/rogpeppe/go-internal/diff" "github.com/rogpeppe/go-internal/diff"
"github.com/spf13/afero"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/pflag"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1"
) )
const ( const (
// skipInitPhase skips the init RPC of the apply process.
skipInitPhase skipPhase = "init"
// skipInfrastructurePhase skips the terraform apply of the upgrade process. // skipInfrastructurePhase skips the terraform apply of the upgrade process.
skipInfrastructurePhase skipPhase = "infrastructure" skipInfrastructurePhase skipPhase = "infrastructure"
// skipHelmPhase skips the helm upgrade of the upgrade process. // skipHelmPhase skips the helm upgrade of the upgrade process.
@ -57,7 +46,11 @@ func newUpgradeApplyCmd() *cobra.Command {
Short: "Apply an upgrade to a Constellation cluster", Short: "Apply an upgrade to a Constellation cluster",
Long: "Apply an upgrade to a Constellation cluster by applying the chosen configuration.", Long: "Apply an upgrade to a Constellation cluster by applying the chosen configuration.",
Args: cobra.NoArgs, Args: cobra.NoArgs,
RunE: runUpgradeApply, RunE: func(cmd *cobra.Command, args []string) error {
// Define flags for apply backend that are not set by upgrade-apply
cmd.Flags().Bool("merge-kubeconfig", false, "")
return runApply(cmd, args)
},
} }
cmd.Flags().BoolP("yes", "y", false, "run upgrades without further confirmation\n"+ cmd.Flags().BoolP("yes", "y", false, "run upgrades without further confirmation\n"+
@ -69,238 +62,11 @@ func newUpgradeApplyCmd() *cobra.Command {
cmd.Flags().Bool("skip-helm-wait", false, "install helm charts without waiting for deployments to be ready") cmd.Flags().Bool("skip-helm-wait", false, "install helm charts without waiting for deployments to be ready")
cmd.Flags().StringSlice("skip-phases", nil, "comma-separated list of upgrade phases to skip\n"+ cmd.Flags().StringSlice("skip-phases", nil, "comma-separated list of upgrade phases to skip\n"+
"one or multiple of { infrastructure | helm | image | k8s }") "one or multiple of { infrastructure | helm | image | k8s }")
if err := cmd.Flags().MarkHidden("timeout"); err != nil { must(cmd.Flags().MarkHidden("timeout"))
panic(err)
}
return cmd return cmd
} }
type upgradeApplyFlags struct {
rootFlags
yes bool
upgradeTimeout time.Duration
conformance bool
helmWaitMode helm.WaitMode
skipPhases skipPhases
}
func (f *upgradeApplyFlags) parse(flags *pflag.FlagSet) error {
if err := f.rootFlags.parse(flags); err != nil {
return err
}
rawSkipPhases, err := flags.GetStringSlice("skip-phases")
if err != nil {
return fmt.Errorf("parsing skip-phases flag: %w", err)
}
var skipPhases []skipPhase
for _, phase := range rawSkipPhases {
switch skipPhase(phase) {
case skipInfrastructurePhase, skipHelmPhase, skipImagePhase, skipK8sPhase:
skipPhases = append(skipPhases, skipPhase(phase))
default:
return fmt.Errorf("invalid phase %s", phase)
}
}
f.skipPhases = skipPhases
f.yes, err = flags.GetBool("yes")
if err != nil {
return fmt.Errorf("getting 'yes' flag: %w", err)
}
f.upgradeTimeout, err = flags.GetDuration("timeout")
if err != nil {
return fmt.Errorf("getting 'timeout' flag: %w", err)
}
f.conformance, err = flags.GetBool("conformance")
if err != nil {
return fmt.Errorf("getting 'conformance' flag: %w", err)
}
skipHelmWait, err := flags.GetBool("skip-helm-wait")
if err != nil {
return fmt.Errorf("getting 'skip-helm-wait' flag: %w", err)
}
f.helmWaitMode = helm.WaitModeAtomic
if skipHelmWait {
f.helmWaitMode = helm.WaitModeNone
}
return nil
}
func runUpgradeApply(cmd *cobra.Command, _ []string) error {
log, err := newCLILogger(cmd)
if err != nil {
return fmt.Errorf("creating logger: %w", err)
}
defer log.Sync()
fileHandler := file.NewHandler(afero.NewOsFs())
upgradeID := generateUpgradeID(upgradeCmdKindApply)
kubeUpgrader, err := kubecmd.New(cmd.OutOrStdout(), constants.AdminConfFilename, fileHandler, log)
if err != nil {
return err
}
configFetcher := attestationconfigapi.NewFetcher()
var flags upgradeApplyFlags
if err := flags.parse(cmd.Flags()); err != nil {
return err
}
// Set up terraform upgrader
upgradeDir := filepath.Join(constants.UpgradeDir, upgradeID)
clusterUpgrader, err := cloudcmd.NewClusterUpgrader(
cmd.Context(),
constants.TerraformWorkingDir,
upgradeDir,
flags.tfLogLevel,
fileHandler,
)
if err != nil {
return fmt.Errorf("setting up cluster upgrader: %w", err)
}
helmClient, err := helm.NewClient(constants.AdminConfFilename, log)
if err != nil {
return fmt.Errorf("creating Helm client: %w", err)
}
applyCmd := upgradeApplyCmd{
kubeUpgrader: kubeUpgrader,
helmApplier: helmClient,
clusterUpgrader: clusterUpgrader,
configFetcher: configFetcher,
fileHandler: fileHandler,
flags: flags,
log: log,
}
return applyCmd.upgradeApply(cmd, upgradeDir)
}
type upgradeApplyCmd struct {
helmApplier helmApplier
kubeUpgrader kubernetesUpgrader
clusterUpgrader clusterUpgrader
configFetcher attestationconfigapi.Fetcher
fileHandler file.Handler
flags upgradeApplyFlags
log debugLog
}
func (u *upgradeApplyCmd) upgradeApply(cmd *cobra.Command, upgradeDir string) error {
conf, err := config.New(u.fileHandler, constants.ConfigFilename, u.configFetcher, u.flags.force)
var configValidationErr *config.ValidationError
if errors.As(err, &configValidationErr) {
cmd.PrintErrln(configValidationErr.LongMessage())
}
if err != nil {
return err
}
if cloudcmd.UpgradeRequiresIAMMigration(conf.GetProvider()) {
cmd.Println("WARNING: This upgrade requires an IAM migration. Please make sure you have applied the IAM migration using `iam upgrade apply` before continuing.")
if !u.flags.yes {
yes, err := askToConfirm(cmd, "Did you upgrade the IAM resources?")
if err != nil {
return fmt.Errorf("asking for confirmation: %w", err)
}
if !yes {
cmd.Println("Skipping upgrade.")
return nil
}
}
}
conf.KubernetesVersion, err = validK8sVersion(cmd, string(conf.KubernetesVersion), u.flags.yes)
if err != nil {
return err
}
stateFile, err := state.ReadFromFile(u.fileHandler, constants.StateFilename)
if err != nil {
return fmt.Errorf("reading state file: %w", err)
}
if err := u.confirmAndUpgradeAttestationConfig(cmd, conf.GetAttestationConfig(), stateFile.ClusterValues.MeasurementSalt); err != nil {
return fmt.Errorf("upgrading measurements: %w", err)
}
// If infrastructure phase is skipped, we expect the new infrastructure
// to be in the Terraform configuration already. Otherwise, perform
// the Terraform migrations.
if !u.flags.skipPhases.contains(skipInfrastructurePhase) {
migrationRequired, err := u.planTerraformMigration(cmd, conf)
if err != nil {
return fmt.Errorf("planning Terraform migrations: %w", err)
}
if migrationRequired {
postMigrationInfraState, err := u.migrateTerraform(cmd, conf, upgradeDir)
if err != nil {
return fmt.Errorf("performing Terraform migrations: %w", err)
}
// Merge the pre-upgrade state with the post-migration infrastructure values
if _, err := stateFile.Merge(
// temporary state with post-migration infrastructure values
state.New().SetInfrastructure(postMigrationInfraState),
); err != nil {
return fmt.Errorf("merging pre-upgrade state with post-migration infrastructure values: %w", err)
}
// Write the post-migration state to disk
if err := stateFile.WriteToFile(u.fileHandler, constants.StateFilename); err != nil {
return fmt.Errorf("writing state file: %w", err)
}
}
}
// extend the clusterConfig cert SANs with any of the supported endpoints:
// - (legacy) public IP
// - fallback endpoint
// - custom (user-provided) endpoint
sans := append([]string{stateFile.Infrastructure.ClusterEndpoint, conf.CustomEndpoint}, stateFile.Infrastructure.APIServerCertSANs...)
if err := u.kubeUpgrader.ExtendClusterConfigCertSANs(cmd.Context(), sans); err != nil {
return fmt.Errorf("extending cert SANs: %w", err)
}
if conf.GetProvider() != cloudprovider.Azure && conf.GetProvider() != cloudprovider.GCP && conf.GetProvider() != cloudprovider.AWS {
cmd.PrintErrln("WARNING: Skipping service and image upgrades, which are currently only supported for AWS, Azure, and GCP.")
return nil
}
var upgradeErr *compatibility.InvalidUpgradeError
if !u.flags.skipPhases.contains(skipHelmPhase) {
err = u.handleServiceUpgrade(cmd, conf, stateFile, upgradeDir)
switch {
case errors.As(err, &upgradeErr):
cmd.PrintErrln(err)
case err == nil:
cmd.Println("Successfully upgraded Constellation services.")
case err != nil:
return fmt.Errorf("upgrading services: %w", err)
}
}
skipImageUpgrade := u.flags.skipPhases.contains(skipImagePhase)
skipK8sUpgrade := u.flags.skipPhases.contains(skipK8sPhase)
if !(skipImageUpgrade && skipK8sUpgrade) {
err = u.kubeUpgrader.UpgradeNodeVersion(cmd.Context(), conf, u.flags.force, skipImageUpgrade, skipK8sUpgrade)
switch {
case errors.Is(err, kubecmd.ErrInProgress):
cmd.PrintErrln("Skipping image and Kubernetes upgrades. Another upgrade is in progress.")
case errors.As(err, &upgradeErr):
cmd.PrintErrln(err)
case err != nil:
return fmt.Errorf("upgrading NodeVersion: %w", err)
}
}
return nil
}
func diffAttestationCfg(currentAttestationCfg config.AttestationCfg, newAttestationCfg config.AttestationCfg) (string, error) { func diffAttestationCfg(currentAttestationCfg config.AttestationCfg, newAttestationCfg config.AttestationCfg) (string, error) {
// cannot compare structs directly with go-cmp because of unexported fields in the attestation config // cannot compare structs directly with go-cmp because of unexported fields in the attestation config
currentYml, err := yaml.Marshal(currentAttestationCfg) currentYml, err := yaml.Marshal(currentAttestationCfg)
@ -315,220 +81,23 @@ func diffAttestationCfg(currentAttestationCfg config.AttestationCfg, newAttestat
return diff, nil return diff, nil
} }
// planTerraformMigration checks if the Constellation version the cluster is being upgraded to requires a migration.
func (u *upgradeApplyCmd) planTerraformMigration(cmd *cobra.Command, conf *config.Config) (bool, error) {
u.log.Debugf("Planning Terraform migrations")
vars, err := cloudcmd.TerraformUpgradeVars(conf)
if err != nil {
return false, fmt.Errorf("parsing upgrade variables: %w", err)
}
u.log.Debugf("Using Terraform variables:\n%v", vars)
// Check if there are any Terraform migrations to apply
// Add manual migrations here if required
//
// var manualMigrations []terraform.StateMigration
// for _, migration := range manualMigrations {
// u.log.Debugf("Adding manual Terraform migration: %s", migration.DisplayName)
// u.upgrader.AddManualStateMigration(migration)
// }
return u.clusterUpgrader.PlanClusterUpgrade(cmd.Context(), cmd.OutOrStdout(), vars, conf.GetProvider())
}
// migrateTerraform checks if the Constellation version the cluster is being upgraded to requires a migration
// of cloud resources with Terraform. If so, the migration is performed and the post-migration infrastructure state is returned.
// If no migration is required, the current (pre-upgrade) infrastructure state is returned.
func (u *upgradeApplyCmd) migrateTerraform(cmd *cobra.Command, conf *config.Config, upgradeDir string,
) (state.Infrastructure, error) {
// If there are any Terraform migrations to apply, ask for confirmation
fmt.Fprintln(cmd.OutOrStdout(), "The upgrade requires a migration of Constellation cloud resources by applying an updated Terraform template. Please manually review the suggested changes below.")
if !u.flags.yes {
ok, err := askToConfirm(cmd, "Do you want to apply the Terraform migrations?")
if err != nil {
return state.Infrastructure{}, fmt.Errorf("asking for confirmation: %w", err)
}
if !ok {
cmd.Println("Aborting upgrade.")
// User doesn't expect to see any changes in his workspace after aborting an "upgrade apply",
// therefore, roll back to the backed up state.
if err := u.clusterUpgrader.RestoreClusterWorkspace(); err != nil {
return state.Infrastructure{}, fmt.Errorf(
"restoring Terraform workspace: %w, restore the Terraform workspace manually from %s ",
err,
filepath.Join(upgradeDir, constants.TerraformUpgradeBackupDir),
)
}
return state.Infrastructure{}, fmt.Errorf("cluster upgrade aborted by user")
}
}
u.log.Debugf("Applying Terraform migrations")
infraState, err := u.clusterUpgrader.ApplyClusterUpgrade(cmd.Context(), conf.GetProvider())
if err != nil {
return state.Infrastructure{}, fmt.Errorf("applying terraform migrations: %w", err)
}
cmd.Printf("Infrastructure migrations applied successfully and output written to: %s\n"+
"A backup of the pre-upgrade state has been written to: %s\n",
u.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename),
u.flags.pathPrefixer.PrefixPrintablePath(filepath.Join(upgradeDir, constants.TerraformUpgradeBackupDir)),
)
return infraState, nil
}
// validK8sVersion checks if the Kubernetes patch version is supported and asks for confirmation if not.
func validK8sVersion(cmd *cobra.Command, version string, yes bool) (validVersion versions.ValidK8sVersion, err error) {
validVersion, err = versions.NewValidK8sVersion(version, true)
if versions.IsPreviewK8sVersion(validVersion) {
cmd.PrintErrf("Warning: Constellation with Kubernetes %v is still in preview. Use only for evaluation purposes.\n", validVersion)
}
valid := err == nil
if !valid && !yes {
confirmed, err := askToConfirm(cmd, fmt.Sprintf("WARNING: The Kubernetes patch version %s is not supported. If you continue, Kubernetes upgrades will be skipped. Do you want to continue anyway?", version))
if err != nil {
return validVersion, fmt.Errorf("asking for confirmation: %w", err)
}
if !confirmed {
return validVersion, fmt.Errorf("aborted by user")
}
}
return validVersion, nil
}
// confirmAndUpgradeAttestationConfig checks if the locally configured measurements are different from the cluster's measurements.
// If so the function will ask the user to confirm (if --yes is not set) and upgrade the cluster's config.
func (u *upgradeApplyCmd) confirmAndUpgradeAttestationConfig(
cmd *cobra.Command, newConfig config.AttestationCfg, measurementSalt []byte,
) error {
clusterAttestationConfig, err := u.kubeUpgrader.GetClusterAttestationConfig(cmd.Context(), newConfig.GetVariant())
if err != nil {
return fmt.Errorf("getting cluster attestation config: %w", err)
}
// If the current config is equal, or there is an error when comparing the configs, we skip the upgrade.
equal, err := newConfig.EqualTo(clusterAttestationConfig)
if err != nil {
return fmt.Errorf("comparing attestation configs: %w", err)
}
if equal {
return nil
}
cmd.Println("The configured attestation config is different from the attestation config in the cluster.")
diffStr, err := diffAttestationCfg(clusterAttestationConfig, newConfig)
if err != nil {
return fmt.Errorf("diffing attestation configs: %w", err)
}
cmd.Println("The following changes will be applied to the attestation config:")
cmd.Println(diffStr)
if !u.flags.yes {
ok, err := askToConfirm(cmd, "Are you sure you want to change your cluster's attestation config?")
if err != nil {
return fmt.Errorf("asking for confirmation: %w", err)
}
if !ok {
return errors.New("aborting upgrade since attestation config is different")
}
}
if err := u.kubeUpgrader.ApplyJoinConfig(cmd.Context(), newConfig, measurementSalt); err != nil {
return fmt.Errorf("updating attestation config: %w", err)
}
cmd.Println("Successfully updated the cluster's attestation config")
return nil
}
func (u *upgradeApplyCmd) handleServiceUpgrade(
cmd *cobra.Command, conf *config.Config, stateFile *state.State, upgradeDir string,
) error {
var secret uri.MasterSecret
if err := u.fileHandler.ReadJSON(constants.MasterSecretFilename, &secret); err != nil {
return fmt.Errorf("reading master secret: %w", err)
}
serviceAccURI, err := cloudcmd.GetMarshaledServiceAccountURI(conf.GetProvider(), conf, u.flags.pathPrefixer, u.log, u.fileHandler)
if err != nil {
return fmt.Errorf("getting service account URI: %w", err)
}
options := helm.Options{
Force: u.flags.force,
Conformance: u.flags.conformance,
HelmWaitMode: u.flags.helmWaitMode,
}
prepareApply := func(allowDestructive bool) (helm.Applier, bool, error) {
options.AllowDestructive = allowDestructive
executor, includesUpgrades, err := u.helmApplier.PrepareApply(conf, stateFile, options, serviceAccURI, secret)
var upgradeErr *compatibility.InvalidUpgradeError
switch {
case errors.As(err, &upgradeErr):
cmd.PrintErrln(err)
case err != nil:
return nil, false, fmt.Errorf("getting chart executor: %w", err)
}
return executor, includesUpgrades, nil
}
executor, includesUpgrades, err := prepareApply(helm.DenyDestructive)
if err != nil {
if !errors.Is(err, helm.ErrConfirmationMissing) {
return fmt.Errorf("upgrading charts with deny destructive mode: %w", err)
}
if !u.flags.yes {
cmd.PrintErrln("WARNING: Upgrading cert-manager will destroy all custom resources you have manually created that are based on the current version of cert-manager.")
ok, askErr := askToConfirm(cmd, "Do you want to upgrade cert-manager anyway?")
if askErr != nil {
return fmt.Errorf("asking for confirmation: %w", err)
}
if !ok {
cmd.Println("Skipping upgrade.")
return nil
}
}
executor, includesUpgrades, err = prepareApply(helm.AllowDestructive)
if err != nil {
return fmt.Errorf("upgrading charts with allow destructive mode: %w", err)
}
}
// Save the Helm charts for the upgrade to disk
chartDir := filepath.Join(upgradeDir, "helm-charts")
if err := executor.SaveCharts(chartDir, u.fileHandler); err != nil {
return fmt.Errorf("saving Helm charts to disk: %w", err)
}
u.log.Debugf("Helm charts saved to %s", chartDir)
if includesUpgrades {
u.log.Debugf("Creating backup of CRDs and CRs")
crds, err := u.kubeUpgrader.BackupCRDs(cmd.Context(), upgradeDir)
if err != nil {
return fmt.Errorf("creating CRD backup: %w", err)
}
if err := u.kubeUpgrader.BackupCRs(cmd.Context(), crds, upgradeDir); err != nil {
return fmt.Errorf("creating CR backup: %w", err)
}
}
if err := executor.Apply(cmd.Context()); err != nil {
return fmt.Errorf("applying Helm charts: %w", err)
}
return nil
}
// skipPhases is a list of phases that can be skipped during the upgrade process. // skipPhases is a list of phases that can be skipped during the upgrade process.
type skipPhases []skipPhase type skipPhases map[skipPhase]struct{}
// contains returns true if the list of phases contains the given phase. // contains returns true if the list of phases contains the given phase.
func (s skipPhases) contains(phase skipPhase) bool { func (s skipPhases) contains(phase skipPhase) bool {
for _, p := range s { _, ok := s[skipPhase(strings.ToLower(string(phase)))]
if strings.EqualFold(string(p), string(phase)) { return ok
return true }
}
// add a phase to the list of phases.
func (s *skipPhases) add(phases ...skipPhase) {
if *s == nil {
*s = make(skipPhases)
}
for _, phase := range phases {
(*s)[skipPhase(strings.ToLower(string(phase)))] = struct{}{}
} }
return false
} }
type kubernetesUpgrader interface { type kubernetesUpgrader interface {

View file

@ -41,8 +41,9 @@ func TestUpgradeApply(t *testing.T) {
InitSecret: []byte{0x42}, InitSecret: []byte{0x42},
}). }).
SetClusterValues(state.ClusterValues{MeasurementSalt: []byte{0x41}}) SetClusterValues(state.ClusterValues{MeasurementSalt: []byte{0x41}})
fsWithStateFile := func() file.Handler { fsWithStateFileAndTfState := func() file.Handler {
fh := file.NewHandler(afero.NewMemMapFs()) fh := file.NewHandler(afero.NewMemMapFs())
require.NoError(t, fh.MkdirAll(constants.TerraformWorkingDir))
require.NoError(t, fh.WriteYAML(constants.StateFilename, defaultState)) require.NoError(t, fh.WriteYAML(constants.StateFilename, defaultState))
return fh return fh
} }
@ -55,15 +56,15 @@ func TestUpgradeApply(t *testing.T) {
terraformUpgrader clusterUpgrader terraformUpgrader clusterUpgrader
wantErr bool wantErr bool
customK8sVersion string customK8sVersion string
flags upgradeApplyFlags flags applyFlags
stdin string stdin string
}{ }{
"success": { "success": {
kubeUpgrader: &stubKubernetesUpgrader{currentConfig: config.DefaultForAzureSEVSNP()}, kubeUpgrader: &stubKubernetesUpgrader{currentConfig: config.DefaultForAzureSEVSNP()},
helmUpgrader: stubApplier{}, helmUpgrader: stubApplier{},
terraformUpgrader: &stubTerraformUpgrader{}, terraformUpgrader: &stubTerraformUpgrader{},
flags: upgradeApplyFlags{yes: true}, flags: applyFlags{yes: true},
fh: fsWithStateFile, fh: fsWithStateFileAndTfState,
fhAssertions: func(require *require.Assertions, assert *assert.Assertions, fh file.Handler) { fhAssertions: func(require *require.Assertions, assert *assert.Assertions, fh file.Handler) {
gotState, err := state.ReadFromFile(fh, constants.StateFilename) gotState, err := state.ReadFromFile(fh, constants.StateFilename)
require.NoError(err) require.NoError(err)
@ -71,11 +72,11 @@ func TestUpgradeApply(t *testing.T) {
assert.Equal(defaultState, gotState) assert.Equal(defaultState, gotState)
}, },
}, },
"state file does not exist": { "id file and state file do not exist": {
kubeUpgrader: &stubKubernetesUpgrader{currentConfig: config.DefaultForAzureSEVSNP()}, kubeUpgrader: &stubKubernetesUpgrader{currentConfig: config.DefaultForAzureSEVSNP()},
helmUpgrader: stubApplier{}, helmUpgrader: stubApplier{},
terraformUpgrader: &stubTerraformUpgrader{}, terraformUpgrader: &stubTerraformUpgrader{},
flags: upgradeApplyFlags{yes: true}, flags: applyFlags{yes: true},
fh: func() file.Handler { fh: func() file.Handler {
return file.NewHandler(afero.NewMemMapFs()) return file.NewHandler(afero.NewMemMapFs())
}, },
@ -89,8 +90,8 @@ func TestUpgradeApply(t *testing.T) {
helmUpgrader: stubApplier{}, helmUpgrader: stubApplier{},
terraformUpgrader: &stubTerraformUpgrader{}, terraformUpgrader: &stubTerraformUpgrader{},
wantErr: true, wantErr: true,
flags: upgradeApplyFlags{yes: true}, flags: applyFlags{yes: true},
fh: fsWithStateFile, fh: fsWithStateFileAndTfState,
}, },
"nodeVersion in progress error": { "nodeVersion in progress error": {
kubeUpgrader: &stubKubernetesUpgrader{ kubeUpgrader: &stubKubernetesUpgrader{
@ -99,8 +100,8 @@ func TestUpgradeApply(t *testing.T) {
}, },
helmUpgrader: stubApplier{}, helmUpgrader: stubApplier{},
terraformUpgrader: &stubTerraformUpgrader{}, terraformUpgrader: &stubTerraformUpgrader{},
flags: upgradeApplyFlags{yes: true}, flags: applyFlags{yes: true},
fh: fsWithStateFile, fh: fsWithStateFileAndTfState,
}, },
"helm other error": { "helm other error": {
kubeUpgrader: &stubKubernetesUpgrader{ kubeUpgrader: &stubKubernetesUpgrader{
@ -109,8 +110,8 @@ func TestUpgradeApply(t *testing.T) {
helmUpgrader: stubApplier{err: assert.AnError}, helmUpgrader: stubApplier{err: assert.AnError},
terraformUpgrader: &stubTerraformUpgrader{}, terraformUpgrader: &stubTerraformUpgrader{},
wantErr: true, wantErr: true,
flags: upgradeApplyFlags{yes: true}, flags: applyFlags{yes: true},
fh: fsWithStateFile, fh: fsWithStateFileAndTfState,
}, },
"abort": { "abort": {
kubeUpgrader: &stubKubernetesUpgrader{ kubeUpgrader: &stubKubernetesUpgrader{
@ -120,7 +121,7 @@ func TestUpgradeApply(t *testing.T) {
terraformUpgrader: &stubTerraformUpgrader{terraformDiff: true}, terraformUpgrader: &stubTerraformUpgrader{terraformDiff: true},
wantErr: true, wantErr: true,
stdin: "no\n", stdin: "no\n",
fh: fsWithStateFile, fh: fsWithStateFileAndTfState,
}, },
"abort, restore terraform err": { "abort, restore terraform err": {
kubeUpgrader: &stubKubernetesUpgrader{ kubeUpgrader: &stubKubernetesUpgrader{
@ -130,7 +131,7 @@ func TestUpgradeApply(t *testing.T) {
terraformUpgrader: &stubTerraformUpgrader{terraformDiff: true, rollbackWorkspaceErr: assert.AnError}, terraformUpgrader: &stubTerraformUpgrader{terraformDiff: true, rollbackWorkspaceErr: assert.AnError},
wantErr: true, wantErr: true,
stdin: "no\n", stdin: "no\n",
fh: fsWithStateFile, fh: fsWithStateFileAndTfState,
}, },
"plan terraform error": { "plan terraform error": {
kubeUpgrader: &stubKubernetesUpgrader{ kubeUpgrader: &stubKubernetesUpgrader{
@ -139,8 +140,8 @@ func TestUpgradeApply(t *testing.T) {
helmUpgrader: stubApplier{}, helmUpgrader: stubApplier{},
terraformUpgrader: &stubTerraformUpgrader{planTerraformErr: assert.AnError}, terraformUpgrader: &stubTerraformUpgrader{planTerraformErr: assert.AnError},
wantErr: true, wantErr: true,
flags: upgradeApplyFlags{yes: true}, flags: applyFlags{yes: true},
fh: fsWithStateFile, fh: fsWithStateFileAndTfState,
}, },
"apply terraform error": { "apply terraform error": {
kubeUpgrader: &stubKubernetesUpgrader{ kubeUpgrader: &stubKubernetesUpgrader{
@ -152,8 +153,8 @@ func TestUpgradeApply(t *testing.T) {
terraformDiff: true, terraformDiff: true,
}, },
wantErr: true, wantErr: true,
flags: upgradeApplyFlags{yes: true}, flags: applyFlags{yes: true},
fh: fsWithStateFile, fh: fsWithStateFileAndTfState,
}, },
"outdated K8s patch version": { "outdated K8s patch version": {
kubeUpgrader: &stubKubernetesUpgrader{ kubeUpgrader: &stubKubernetesUpgrader{
@ -166,8 +167,8 @@ func TestUpgradeApply(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
return semver.NewFromInt(v.Major(), v.Minor(), v.Patch()-1, "").String() return semver.NewFromInt(v.Major(), v.Minor(), v.Patch()-1, "").String()
}(), }(),
flags: upgradeApplyFlags{yes: true}, flags: applyFlags{yes: true},
fh: fsWithStateFile, fh: fsWithStateFileAndTfState,
}, },
"outdated K8s version": { "outdated K8s version": {
kubeUpgrader: &stubKubernetesUpgrader{ kubeUpgrader: &stubKubernetesUpgrader{
@ -176,9 +177,9 @@ func TestUpgradeApply(t *testing.T) {
helmUpgrader: stubApplier{}, helmUpgrader: stubApplier{},
terraformUpgrader: &stubTerraformUpgrader{}, terraformUpgrader: &stubTerraformUpgrader{},
customK8sVersion: "v1.20.0", customK8sVersion: "v1.20.0",
flags: upgradeApplyFlags{yes: true}, flags: applyFlags{yes: true},
wantErr: true, wantErr: true,
fh: fsWithStateFile, fh: fsWithStateFileAndTfState,
}, },
"skip all upgrade phases": { "skip all upgrade phases": {
kubeUpgrader: &stubKubernetesUpgrader{ kubeUpgrader: &stubKubernetesUpgrader{
@ -186,11 +187,14 @@ func TestUpgradeApply(t *testing.T) {
}, },
helmUpgrader: &mockApplier{}, // mocks ensure that no methods are called helmUpgrader: &mockApplier{}, // mocks ensure that no methods are called
terraformUpgrader: &mockTerraformUpgrader{}, terraformUpgrader: &mockTerraformUpgrader{},
flags: upgradeApplyFlags{ flags: applyFlags{
skipPhases: []skipPhase{skipInfrastructurePhase, skipHelmPhase, skipK8sPhase, skipImagePhase}, skipPhases: skipPhases{
yes: true, skipInfrastructurePhase: struct{}{}, skipHelmPhase: struct{}{},
skipK8sPhase: struct{}{}, skipImagePhase: struct{}{},
},
yes: true,
}, },
fh: fsWithStateFile, fh: fsWithStateFileAndTfState,
}, },
"skip all phases except node upgrade": { "skip all phases except node upgrade": {
kubeUpgrader: &stubKubernetesUpgrader{ kubeUpgrader: &stubKubernetesUpgrader{
@ -198,11 +202,37 @@ func TestUpgradeApply(t *testing.T) {
}, },
helmUpgrader: &mockApplier{}, // mocks ensure that no methods are called helmUpgrader: &mockApplier{}, // mocks ensure that no methods are called
terraformUpgrader: &mockTerraformUpgrader{}, terraformUpgrader: &mockTerraformUpgrader{},
flags: upgradeApplyFlags{ flags: applyFlags{
skipPhases: []skipPhase{skipInfrastructurePhase, skipHelmPhase, skipK8sPhase}, skipPhases: skipPhases{
yes: true, skipInfrastructurePhase: struct{}{}, skipHelmPhase: struct{}{},
skipK8sPhase: struct{}{},
},
yes: true,
}, },
fh: fsWithStateFile, fh: fsWithStateFileAndTfState,
},
"no tf state, skip infrastructure upgrade": {
kubeUpgrader: &stubKubernetesUpgrader{
currentConfig: config.DefaultForAzureSEVSNP(),
},
helmUpgrader: &stubApplier{},
terraformUpgrader: &mockTerraformUpgrader{},
flags: applyFlags{
yes: true,
},
fh: func() file.Handler {
fh := file.NewHandler(afero.NewMemMapFs())
require.NoError(t, fh.WriteYAML(constants.StateFilename, defaultState))
return fh
},
},
"attempt to change attestation variant": {
kubeUpgrader: &stubKubernetesUpgrader{currentConfig: &config.AzureTrustedLaunch{}},
helmUpgrader: stubApplier{},
terraformUpgrader: &stubTerraformUpgrader{},
flags: applyFlags{yes: true},
fh: fsWithStateFileAndTfState,
wantErr: true,
}, },
} }
@ -218,20 +248,26 @@ func TestUpgradeApply(t *testing.T) {
cfg.KubernetesVersion = versions.ValidK8sVersion(tc.customK8sVersion) cfg.KubernetesVersion = versions.ValidK8sVersion(tc.customK8sVersion)
} }
fh := tc.fh() fh := tc.fh()
require.NoError(fh.Write(constants.AdminConfFilename, []byte{}))
require.NoError(fh.WriteYAML(constants.ConfigFilename, cfg)) require.NoError(fh.WriteYAML(constants.ConfigFilename, cfg))
require.NoError(fh.WriteJSON(constants.MasterSecretFilename, uri.MasterSecret{})) require.NoError(fh.WriteJSON(constants.MasterSecretFilename, uri.MasterSecret{}))
upgrader := upgradeApplyCmd{ upgrader := &applyCmd{
kubeUpgrader: tc.kubeUpgrader, fileHandler: fh,
helmApplier: tc.helmUpgrader, flags: tc.flags,
log: logger.NewTest(t),
spinner: &nopSpinner{},
merger: &stubMerger{},
quotaChecker: &stubLicenseClient{},
newHelmClient: func(string, debugLog) (helmApplier, error) {
return tc.helmUpgrader, nil
},
newKubeUpgrader: func(_ io.Writer, _ string, _ debugLog) (kubernetesUpgrader, error) {
return tc.kubeUpgrader, nil
},
clusterUpgrader: tc.terraformUpgrader, clusterUpgrader: tc.terraformUpgrader,
log: logger.NewTest(t),
configFetcher: stubAttestationFetcher{},
flags: tc.flags,
fileHandler: fh,
} }
err := upgrader.apply(cmd, stubAttestationFetcher{}, "test")
err := upgrader.upgradeApply(cmd, "test")
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
return return
@ -255,27 +291,37 @@ func TestUpgradeApplyFlagsForSkipPhases(t *testing.T) {
cmd.Flags().Bool("force", true, "") cmd.Flags().Bool("force", true, "")
cmd.Flags().String("tf-log", "NONE", "") cmd.Flags().String("tf-log", "NONE", "")
cmd.Flags().Bool("debug", false, "") cmd.Flags().Bool("debug", false, "")
cmd.Flags().Bool("merge-kubeconfig", false, "")
require.NoError(cmd.Flags().Set("skip-phases", "infrastructure,helm,k8s,image")) require.NoError(cmd.Flags().Set("skip-phases", "infrastructure,helm,k8s,image"))
wantPhases := skipPhases{}
wantPhases.add(skipInfrastructurePhase, skipHelmPhase, skipK8sPhase, skipImagePhase)
var flags upgradeApplyFlags var flags applyFlags
err := flags.parse(cmd.Flags()) err := flags.parse(cmd.Flags())
require.NoError(err) require.NoError(err)
assert.ElementsMatch(t, []skipPhase{skipInfrastructurePhase, skipHelmPhase, skipK8sPhase, skipImagePhase}, flags.skipPhases) assert.Equal(t, wantPhases, flags.skipPhases)
} }
type stubKubernetesUpgrader struct { type stubKubernetesUpgrader struct {
nodeVersionErr error nodeVersionErr error
currentConfig config.AttestationCfg currentConfig config.AttestationCfg
calledNodeUpgrade bool getClusterAttestationConfigErr error
calledNodeUpgrade bool
backupCRDsErr error
backupCRDsCalled bool
backupCRsErr error
backupCRsCalled bool
} }
func (u *stubKubernetesUpgrader) BackupCRDs(_ context.Context, _ string) ([]apiextensionsv1.CustomResourceDefinition, error) { func (u *stubKubernetesUpgrader) BackupCRDs(_ context.Context, _ string) ([]apiextensionsv1.CustomResourceDefinition, error) {
return []apiextensionsv1.CustomResourceDefinition{}, nil u.backupCRDsCalled = true
return []apiextensionsv1.CustomResourceDefinition{}, u.backupCRDsErr
} }
func (u *stubKubernetesUpgrader) BackupCRs(_ context.Context, _ []apiextensionsv1.CustomResourceDefinition, _ string) error { func (u *stubKubernetesUpgrader) BackupCRs(_ context.Context, _ []apiextensionsv1.CustomResourceDefinition, _ string) error {
return nil u.backupCRsCalled = true
return u.backupCRsErr
} }
func (u *stubKubernetesUpgrader) UpgradeNodeVersion(_ context.Context, _ *config.Config, _, _, _ bool) error { func (u *stubKubernetesUpgrader) UpgradeNodeVersion(_ context.Context, _ *config.Config, _, _, _ bool) error {
@ -288,7 +334,7 @@ func (u *stubKubernetesUpgrader) ApplyJoinConfig(_ context.Context, _ config.Att
} }
func (u *stubKubernetesUpgrader) GetClusterAttestationConfig(_ context.Context, _ variant.Variant) (config.AttestationCfg, error) { func (u *stubKubernetesUpgrader) GetClusterAttestationConfig(_ context.Context, _ variant.Variant) (config.AttestationCfg, error) {
return u.currentConfig, nil return u.currentConfig, u.getClusterAttestationConfigErr
} }
func (u *stubKubernetesUpgrader) ExtendClusterConfigCertSANs(_ context.Context, _ []string) error { func (u *stubKubernetesUpgrader) ExtendClusterConfigCertSANs(_ context.Context, _ []string) error {

View file

@ -287,12 +287,17 @@ func (k *KubeCmd) ExtendClusterConfigCertSANs(ctx context.Context, alternativeNa
var missingSANs []string var missingSANs []string
for _, san := range alternativeNames { for _, san := range alternativeNames {
if san == "" {
continue // skip empty SANs
}
if _, ok := existingSANs[san]; !ok { if _, ok := existingSANs[san]; !ok {
missingSANs = append(missingSANs, san) missingSANs = append(missingSANs, san)
existingSANs[san] = struct{}{} // make sure we don't add the same SAN twice
} }
} }
if len(missingSANs) == 0 { if len(missingSANs) == 0 {
k.log.Debugf("No new SANs to add to the cluster's apiserver SAN field")
return nil return nil
} }
k.log.Debugf("Extending the cluster's apiserver SAN field with the following SANs: %s\n", strings.Join(missingSANs, ", ")) k.log.Debugf("Extending the cluster's apiserver SAN field with the following SANs: %s\n", strings.Join(missingSANs, ", "))

View file

@ -24,7 +24,8 @@ type AttestationCfg interface {
SetMeasurements(m measurements.M) SetMeasurements(m measurements.M)
// GetVariant returns the variant of the attestation config. // GetVariant returns the variant of the attestation config.
GetVariant() variant.Variant GetVariant() variant.Variant
// NewerThan returns true if the config is equal to the given config. // EqualTo returns true if the config is equal to the given config.
// If the variant differs, an error must be returned.
EqualTo(AttestationCfg) (bool, error) EqualTo(AttestationCfg) (bool, error)
} }