constellation/cli/internal/cmd/apply_test.go
Daniel Weiße a0863bafe7
cli: fix apply flag issues (#2526)
* Fix flag order
* Fix missing phases in flag parsing

---------

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
2023-10-30 09:30:35 +01:00

172 lines
4.2 KiB
Go

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"context"
"fmt"
"strings"
"testing"
"time"
"github.com/edgelesssys/constellation/v2/cli/internal/helm"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/spf13/afero"
"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseApplyFlags(t *testing.T) {
require := require.New(t)
defaultFlags := func() *pflag.FlagSet {
flags := NewApplyCmd().Flags()
// Register persistent flags
flags.String("workspace", "", "")
flags.String("tf-log", "NONE", "")
flags.Bool("force", false, "")
flags.Bool("debug", false, "")
return flags
}
testCases := map[string]struct {
flags *pflag.FlagSet
wantFlags applyFlags
wantErr bool
}{
"default flags": {
flags: defaultFlags(),
wantFlags: applyFlags{
helmWaitMode: helm.WaitModeAtomic,
upgradeTimeout: 5 * time.Minute,
},
},
"skip phases": {
flags: func() *pflag.FlagSet {
flags := defaultFlags()
require.NoError(flags.Set("skip-phases", fmt.Sprintf("%s,%s", skipHelmPhase, skipK8sPhase)))
return flags
}(),
wantFlags: applyFlags{
skipPhases: skipPhases{skipHelmPhase: struct{}{}, skipK8sPhase: struct{}{}},
helmWaitMode: helm.WaitModeAtomic,
upgradeTimeout: 5 * time.Minute,
},
},
"skip helm wait": {
flags: func() *pflag.FlagSet {
flags := defaultFlags()
require.NoError(flags.Set("skip-helm-wait", "true"))
return flags
}(),
wantFlags: applyFlags{
helmWaitMode: helm.WaitModeNone,
upgradeTimeout: 5 * time.Minute,
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
var flags applyFlags
err := flags.parse(tc.flags)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
assert.Equal(tc.wantFlags, flags)
})
}
}
// TestBackupHelmCharts runs applyCmd.backupHelmCharts against a stubbed helm
// applier and a stubbed backup client. It checks that configured stub errors
// surface as an error and that, for successful runs that include upgrades,
// the stub's backupCRDsCalled and backupCRsCalled flags are set.
func TestBackupHelmCharts(t *testing.T) {
	type testCase struct {
		helmApplier      helm.Applier
		backupClient     *stubKubernetesUpgrader
		includesUpgrades bool
		wantErr          bool
	}

	testCases := map[string]testCase{
		"success, no upgrades": {
			helmApplier:  &stubRunner{},
			backupClient: &stubKubernetesUpgrader{},
		},
		"success with upgrades": {
			helmApplier:      &stubRunner{},
			backupClient:     &stubKubernetesUpgrader{},
			includesUpgrades: true,
		},
		"saving charts fails": {
			helmApplier:  &stubRunner{saveChartsErr: assert.AnError},
			backupClient: &stubKubernetesUpgrader{},
			wantErr:      true,
		},
		"backup CRDs fails": {
			helmApplier:      &stubRunner{},
			backupClient:     &stubKubernetesUpgrader{backupCRDsErr: assert.AnError},
			includesUpgrades: true,
			wantErr:          true,
		},
		"backup CRs fails": {
			helmApplier:      &stubRunner{},
			backupClient:     &stubKubernetesUpgrader{backupCRsErr: assert.AnError},
			includesUpgrades: true,
			wantErr:          true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			// Each subtest gets its own in-memory filesystem and test logger.
			cmdUnderTest := applyCmd{
				fileHandler: file.NewHandler(afero.NewMemMapFs()),
				log:         logger.NewTest(t),
			}
			ctx := context.Background()

			backupErr := cmdUnderTest.backupHelmCharts(ctx, tc.backupClient, tc.helmApplier, tc.includesUpgrades, "")
			if tc.wantErr {
				assert.Error(t, backupErr)
				return
			}
			assert.NoError(t, backupErr)

			// The call flags are only checked when upgrades are included.
			if !tc.includesUpgrades {
				return
			}
			assert.True(t, tc.backupClient.backupCRDsCalled)
			assert.True(t, tc.backupClient.backupCRsCalled)
		})
	}
}
func TestSkipPhases(t *testing.T) {
require := require.New(t)
cmd := NewApplyCmd()
// register persistent flags manually
cmd.Flags().String("workspace", "", "")
cmd.Flags().Bool("force", true, "")
cmd.Flags().String("tf-log", "NONE", "")
cmd.Flags().Bool("debug", false, "")
require.NoError(cmd.Flags().Set("skip-phases", strings.Join(allPhases(), ",")))
wantPhases := skipPhases{}
wantPhases.add(skipInfrastructurePhase, skipInitPhase, skipAttestationConfigPhase, skipCertSANsPhase, skipHelmPhase, skipK8sPhase, skipImagePhase)
var flags applyFlags
err := flags.parse(cmd.Flags())
require.NoError(err)
assert.Equal(t, wantPhases, flags.skipPhases)
}