diff --git a/cli/internal/cmd/apply.go b/cli/internal/cmd/apply.go index 44a394d76..d5ca8fe6e 100644 --- a/cli/internal/cmd/apply.go +++ b/cli/internal/cmd/apply.go @@ -16,6 +16,7 @@ import ( "net" "os" "path/filepath" + "slices" "strings" "time" @@ -41,7 +42,7 @@ import ( ) // phases that can be skipped during apply. -// New phases should also be added to [formatSkipPhases]. +// New phases should also be added to [allPhases]. const ( // skipInfrastructurePhase skips the Terraform apply of the apply process. skipInfrastructurePhase skipPhase = "infrastructure" @@ -59,9 +60,9 @@ const ( skipK8sPhase skipPhase = "k8s" ) -// formatSkipPhases returns a formatted string of all phases that can be skipped. -func formatSkipPhases() string { - return fmt.Sprintf("{ %s }", strings.Join([]string{ +// allPhases returns a list of all phases that can be skipped as strings. +func allPhases() []string { + return []string{ string(skipInfrastructurePhase), string(skipInitPhase), string(skipAttestationConfigPhase), @@ -69,7 +70,12 @@ func formatSkipPhases() string { string(skipHelmPhase), string(skipImagePhase), string(skipK8sPhase), - }, " | ")) + } +} + +// formatSkipPhases returns a formatted string of all phases that can be skipped. +func formatSkipPhases() string { + return fmt.Sprintf("{ %s }", strings.Join(allPhases(), " | ")) } // skipPhase is a phase of the upgrade process that can be skipped. @@ -142,10 +148,10 @@ func (f *applyFlags) parse(flags *pflag.FlagSet) error { } var skipPhases skipPhases for _, phase := range rawSkipPhases { - switch skipPhase(strings.ToLower(phase)) { - case skipInfrastructurePhase, skipHelmPhase, skipImagePhase, skipK8sPhase: + phase = strings.ToLower(phase) + if slices.Contains(allPhases(), phase) { skipPhases.add(skipPhase(phase)) - default: + } else { return fmt.Errorf("invalid phase %s", phase) } } @@ -568,8 +574,8 @@ func (a *applyCmd) runK8sUpgrade(cmd *cobra.Command, conf *config.Config, kubeUp ) error { err := kubeUpgrader.UpgradeNodeVersion( cmd.Context(), conf, a.flags.force, - a.flags.skipPhases.contains(skipK8sPhase), a.flags.skipPhases.contains(skipImagePhase), + a.flags.skipPhases.contains(skipK8sPhase), ) var upgradeErr *compatibility.InvalidUpgradeError diff --git a/cli/internal/cmd/apply_test.go b/cli/internal/cmd/apply_test.go index 3942ad5a9..de346deaa 100644 --- a/cli/internal/cmd/apply_test.go +++ b/cli/internal/cmd/apply_test.go @@ -9,6 +9,7 @@ package cmd import ( "context" "fmt" + "strings" "testing" "time" @@ -149,3 +150,22 @@ func TestBackupHelmCharts(t *testing.T) { }) } } + +func TestSkipPhases(t *testing.T) { + require := require.New(t) + cmd := NewApplyCmd() + // register persistent flags manually + cmd.Flags().String("workspace", "", "") + cmd.Flags().Bool("force", true, "") + cmd.Flags().String("tf-log", "NONE", "") + cmd.Flags().Bool("debug", false, "") + + require.NoError(cmd.Flags().Set("skip-phases", strings.Join(allPhases(), ","))) + wantPhases := skipPhases{} + wantPhases.add(skipInfrastructurePhase, skipInitPhase, skipAttestationConfigPhase, skipCertSANsPhase, skipHelmPhase, skipK8sPhase, skipImagePhase) + + var flags applyFlags + err := flags.parse(cmd.Flags()) + require.NoError(err) + assert.Equal(t, wantPhases, flags.skipPhases) +} diff --git a/cli/internal/cmd/upgradeapply_test.go b/cli/internal/cmd/upgradeapply_test.go index 227c14187..2eb72a593 100644 --- a/cli/internal/cmd/upgradeapply_test.go +++ b/cli/internal/cmd/upgradeapply_test.go @@ -283,26 +283,6 @@ func TestUpgradeApply(t *testing.T) { } } -func TestUpgradeApplyFlagsForSkipPhases(t *testing.T) { - require := require.New(t) - cmd := newUpgradeApplyCmd() - // register persistent flags manually - cmd.Flags().String("workspace", "", "") - cmd.Flags().Bool("force", true, "") - cmd.Flags().String("tf-log", "NONE", "") - cmd.Flags().Bool("debug", false, "") - cmd.Flags().Bool("merge-kubeconfig", false, "") - - require.NoError(cmd.Flags().Set("skip-phases", "infrastructure,helm,k8s,image")) - wantPhases := skipPhases{} - wantPhases.add(skipInfrastructurePhase, skipHelmPhase, skipK8sPhase, skipImagePhase) - - var flags applyFlags - err := flags.parse(cmd.Flags()) - require.NoError(err) - assert.Equal(t, wantPhases, flags.skipPhases) -} - type stubKubernetesUpgrader struct { nodeVersionErr error currentConfig config.AttestationCfg