constellation/cli/internal/cmd/apply_test.go
Moritz Sanft 4d6a7fa759
license: refactor license check to be agnostic of input (#2659)
* license: refactor license check to be agnostic of input

* license: remove unused code

* cli: only check license file in enterprise version

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* bazel: fix enterprise CLI build

* bazel: add keep directive

* Update internal/constellation/apply.go

Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com>

* license: check for return value

---------

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>
Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com>
2023-12-01 08:37:52 +01:00

496 lines
16 KiB
Go

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"bytes"
"context"
"errors"
"fmt"
"path/filepath"
"strings"
"testing"
"time"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/cloud/gcpshared"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/helm"
"github.com/edgelesssys/constellation/v2/internal/kms/uri"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/internal/state"
"github.com/spf13/afero"
"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// defaultStateFile returns a valid default state for testing.
func defaultStateFile(csp cloudprovider.Provider) *state.State {
stateFile := &state.State{
Version: "v1",
Infrastructure: state.Infrastructure{
UID: "123",
Name: "test-cluster",
ClusterEndpoint: "192.0.2.1",
InClusterEndpoint: "192.0.2.1",
InitSecret: []byte{0x41},
APIServerCertSANs: []string{
"127.0.0.1",
"www.example.com",
},
IPCidrNode: "0.0.0.0/24",
Azure: &state.Azure{
ResourceGroup: "test-rg",
SubscriptionID: "test-sub",
NetworkSecurityGroupName: "test-nsg",
LoadBalancerName: "test-lb",
UserAssignedIdentity: "test-uami",
AttestationURL: "test-maaUrl",
},
GCP: &state.GCP{
ProjectID: "test-project",
IPCidrPod: "0.0.0.0/24",
},
},
ClusterValues: state.ClusterValues{
ClusterID: "deadbeef",
OwnerID: "deadbeef",
MeasurementSalt: []byte{0x41},
},
}
switch csp {
case cloudprovider.GCP:
stateFile.Infrastructure.Azure = nil
case cloudprovider.Azure:
stateFile.Infrastructure.GCP = nil
default:
stateFile.Infrastructure.Azure = nil
stateFile.Infrastructure.GCP = nil
}
return stateFile
}
func TestParseApplyFlags(t *testing.T) {
require := require.New(t)
defaultFlags := func() *pflag.FlagSet {
flags := NewApplyCmd().Flags()
// Register persistent flags
flags.String("workspace", "", "")
flags.String("tf-log", "NONE", "")
flags.Bool("force", false, "")
flags.Bool("debug", false, "")
return flags
}
testCases := map[string]struct {
flags *pflag.FlagSet
wantFlags applyFlags
wantErr bool
}{
"default flags": {
flags: defaultFlags(),
wantFlags: applyFlags{
helmWaitMode: helm.WaitModeAtomic,
helmTimeout: 10 * time.Minute,
},
},
"skip phases": {
flags: func() *pflag.FlagSet {
flags := defaultFlags()
require.NoError(flags.Set("skip-phases", fmt.Sprintf("%s,%s", skipHelmPhase, skipK8sPhase)))
return flags
}(),
wantFlags: applyFlags{
skipPhases: newPhases(skipHelmPhase, skipK8sPhase),
helmWaitMode: helm.WaitModeAtomic,
helmTimeout: 10 * time.Minute,
},
},
"skip helm wait": {
flags: func() *pflag.FlagSet {
flags := defaultFlags()
require.NoError(flags.Set("skip-helm-wait", "true"))
return flags
}(),
wantFlags: applyFlags{
helmWaitMode: helm.WaitModeNone,
helmTimeout: 10 * time.Minute,
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
var flags applyFlags
err := flags.parse(tc.flags)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
assert.Equal(tc.wantFlags, flags)
})
}
}
func TestBackupHelmCharts(t *testing.T) {
testCases := map[string]struct {
helmApplier helm.Applier
backupClient *stubKubernetesUpgrader
includesUpgrades bool
wantErr bool
}{
"success, no upgrades": {
helmApplier: &stubRunner{},
backupClient: &stubKubernetesUpgrader{},
},
"success with upgrades": {
helmApplier: &stubRunner{},
backupClient: &stubKubernetesUpgrader{},
includesUpgrades: true,
},
"saving charts fails": {
helmApplier: &stubRunner{
saveChartsErr: assert.AnError,
},
backupClient: &stubKubernetesUpgrader{},
wantErr: true,
},
"backup CRDs fails": {
helmApplier: &stubRunner{},
backupClient: &stubKubernetesUpgrader{
backupCRDsErr: assert.AnError,
},
includesUpgrades: true,
wantErr: true,
},
"backup CRs fails": {
helmApplier: &stubRunner{},
backupClient: &stubKubernetesUpgrader{
backupCRsErr: assert.AnError,
},
includesUpgrades: true,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
a := applyCmd{
fileHandler: file.NewHandler(afero.NewMemMapFs()),
log: logger.NewTest(t),
}
err := a.backupHelmCharts(context.Background(), tc.backupClient, tc.helmApplier, tc.includesUpgrades, "")
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
if tc.includesUpgrades {
assert.True(tc.backupClient.backupCRDsCalled)
assert.True(tc.backupClient.backupCRsCalled)
}
})
}
}
func TestSkipPhases(t *testing.T) {
require := require.New(t)
assert := assert.New(t)
cmd := NewApplyCmd()
// register persistent flags manually
cmd.Flags().String("workspace", "", "")
cmd.Flags().Bool("force", true, "")
cmd.Flags().String("tf-log", "NONE", "")
cmd.Flags().Bool("debug", false, "")
require.NoError(cmd.Flags().Set("skip-phases", strings.Join(allPhases(), ",")))
wantPhases := newPhases(skipInfrastructurePhase, skipInitPhase, skipAttestationConfigPhase, skipCertSANsPhase, skipHelmPhase, skipK8sPhase, skipImagePhase)
var flags applyFlags
err := flags.parse(cmd.Flags())
require.NoError(err)
assert.Equal(wantPhases, flags.skipPhases)
phases := newPhases(skipAttestationConfigPhase, skipCertSANsPhase)
assert.True(phases.contains(skipAttestationConfigPhase, skipCertSANsPhase))
assert.False(phases.contains(skipAttestationConfigPhase, skipInitPhase))
assert.False(phases.contains(skipInitPhase, skipInfrastructurePhase))
}
func TestValidateInputs(t *testing.T) {
defaultConfig := func(csp cloudprovider.Provider) func(require *require.Assertions, fh file.Handler) {
return func(require *require.Assertions, fh file.Handler) {
cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), csp)
if csp == cloudprovider.GCP {
require.NoError(fh.WriteJSON("saKey.json", &gcpshared.ServiceAccountKey{
Type: "service_account",
ProjectID: "project_id",
PrivateKeyID: "key_id",
PrivateKey: "key",
ClientEmail: "client_email",
ClientID: "client_id",
AuthURI: "auth_uri",
TokenURI: "token_uri",
AuthProviderX509CertURL: "cert",
ClientX509CertURL: "client_cert",
}))
cfg.Provider.GCP.ServiceAccountKeyPath = "saKey.json"
}
require.NoError(fh.WriteYAML(constants.ConfigFilename, cfg))
}
}
preInitState := func(csp cloudprovider.Provider) func(require *require.Assertions, fh file.Handler) {
return func(require *require.Assertions, fh file.Handler) {
stateFile := defaultStateFile(csp)
stateFile.ClusterValues = state.ClusterValues{}
require.NoError(fh.WriteYAML(constants.StateFilename, stateFile))
}
}
postInitState := func(csp cloudprovider.Provider) func(require *require.Assertions, fh file.Handler) {
return func(require *require.Assertions, fh file.Handler) {
require.NoError(fh.WriteYAML(constants.StateFilename, defaultStateFile(csp)))
}
}
defaultMasterSecret := func(require *require.Assertions, fh file.Handler) {
require.NoError(fh.WriteJSON(constants.MasterSecretFilename, &uri.MasterSecret{}))
}
defaultAdminConfig := func(require *require.Assertions, fh file.Handler) {
require.NoError(fh.Write(constants.AdminConfFilename, []byte("admin config")))
}
defaultTfState := func(require *require.Assertions, fh file.Handler) {
require.NoError(fh.Write(filepath.Join(constants.TerraformWorkingDir, "tfvars"), []byte("tf state")))
}
testCases := map[string]struct {
createConfig func(require *require.Assertions, fh file.Handler)
createState func(require *require.Assertions, fh file.Handler)
createMasterSecret func(require *require.Assertions, fh file.Handler)
createAdminConfig func(require *require.Assertions, fh file.Handler)
createTfState func(require *require.Assertions, fh file.Handler)
stdin string
flags applyFlags
wantPhases skipPhases
wantErr bool
}{
"[upgrade] gcp: all files exist": {
createConfig: defaultConfig(cloudprovider.GCP),
createState: postInitState(cloudprovider.GCP),
createMasterSecret: defaultMasterSecret,
createAdminConfig: defaultAdminConfig,
createTfState: defaultTfState,
flags: applyFlags{},
wantPhases: newPhases(skipInitPhase),
},
"[upgrade] aws: all files exist": {
createConfig: defaultConfig(cloudprovider.AWS),
createState: postInitState(cloudprovider.AWS),
createMasterSecret: defaultMasterSecret,
createAdminConfig: defaultAdminConfig,
createTfState: defaultTfState,
flags: applyFlags{},
wantPhases: newPhases(skipInitPhase),
},
"[upgrade] azure: all files exist": {
createConfig: defaultConfig(cloudprovider.Azure),
createState: postInitState(cloudprovider.Azure),
createMasterSecret: defaultMasterSecret,
createAdminConfig: defaultAdminConfig,
createTfState: defaultTfState,
flags: applyFlags{},
wantPhases: newPhases(skipInitPhase),
},
"[upgrade] qemu: all files exist": {
createConfig: defaultConfig(cloudprovider.QEMU),
createState: postInitState(cloudprovider.QEMU),
createMasterSecret: defaultMasterSecret,
createAdminConfig: defaultAdminConfig,
createTfState: defaultTfState,
flags: applyFlags{},
wantPhases: newPhases(skipInitPhase, skipImagePhase), // No image upgrades on QEMU
},
"no config file errors": {
createConfig: func(require *require.Assertions, fh file.Handler) {},
createState: postInitState(cloudprovider.GCP),
createMasterSecret: defaultMasterSecret,
createAdminConfig: defaultAdminConfig,
createTfState: defaultTfState,
flags: applyFlags{},
wantErr: true,
},
"[init] no admin config file, but mastersecret file exists errors": {
createConfig: defaultConfig(cloudprovider.GCP),
createState: preInitState(cloudprovider.GCP),
createMasterSecret: defaultMasterSecret,
createAdminConfig: func(require *require.Assertions, fh file.Handler) {},
createTfState: defaultTfState,
flags: applyFlags{},
wantErr: true,
},
"[init] no admin config file, no master secret": {
createConfig: defaultConfig(cloudprovider.GCP),
createState: preInitState(cloudprovider.GCP),
createMasterSecret: func(require *require.Assertions, fh file.Handler) {},
createAdminConfig: func(require *require.Assertions, fh file.Handler) {},
createTfState: defaultTfState,
flags: applyFlags{},
wantPhases: newPhases(skipImagePhase, skipK8sPhase),
},
"[create] no tf state, but admin config exists errors": {
createConfig: defaultConfig(cloudprovider.GCP),
createState: preInitState(cloudprovider.GCP),
createMasterSecret: defaultMasterSecret,
createAdminConfig: defaultAdminConfig,
createTfState: func(require *require.Assertions, fh file.Handler) {},
flags: applyFlags{},
wantErr: true,
},
"[create] only config, skip everything but infrastructure": {
createConfig: defaultConfig(cloudprovider.GCP),
createState: func(require *require.Assertions, fh file.Handler) {},
createMasterSecret: func(require *require.Assertions, fh file.Handler) {},
createAdminConfig: func(require *require.Assertions, fh file.Handler) {},
createTfState: func(require *require.Assertions, fh file.Handler) {},
flags: applyFlags{
skipPhases: newPhases(skipInitPhase, skipAttestationConfigPhase, skipCertSANsPhase, skipHelmPhase, skipK8sPhase, skipImagePhase),
},
wantPhases: newPhases(skipInitPhase, skipAttestationConfigPhase, skipCertSANsPhase, skipHelmPhase, skipK8sPhase, skipImagePhase),
},
"[create + init] only config file": {
createConfig: defaultConfig(cloudprovider.GCP),
createState: func(require *require.Assertions, fh file.Handler) {},
createMasterSecret: func(require *require.Assertions, fh file.Handler) {},
createAdminConfig: func(require *require.Assertions, fh file.Handler) {},
createTfState: func(require *require.Assertions, fh file.Handler) {},
flags: applyFlags{},
wantPhases: newPhases(skipImagePhase, skipK8sPhase),
},
"[init] self-managed: config and state file exist, skip-phases=infrastructure": {
createConfig: defaultConfig(cloudprovider.GCP),
createState: preInitState(cloudprovider.GCP),
createMasterSecret: func(require *require.Assertions, fh file.Handler) {},
createAdminConfig: func(require *require.Assertions, fh file.Handler) {},
createTfState: func(require *require.Assertions, fh file.Handler) {},
flags: applyFlags{
skipPhases: newPhases(skipInfrastructurePhase),
},
wantPhases: newPhases(skipInfrastructurePhase, skipImagePhase, skipK8sPhase),
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fileHandler := file.NewHandler(afero.NewMemMapFs())
tc.createConfig(require, fileHandler)
tc.createState(require, fileHandler)
tc.createMasterSecret(require, fileHandler)
tc.createAdminConfig(require, fileHandler)
tc.createTfState(require, fileHandler)
cmd := NewApplyCmd()
var out bytes.Buffer
cmd.SetOut(&out)
var errOut bytes.Buffer
cmd.SetErr(&errOut)
cmd.SetIn(bytes.NewBufferString(tc.stdin))
a := applyCmd{
log: logger.NewTest(t),
fileHandler: fileHandler,
flags: tc.flags,
}
_, _, err := a.validateInputs(cmd, &stubAttestationFetcher{})
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
var cfgErr *config.ValidationError
if errors.As(err, &cfgErr) {
t.Log(cfgErr.LongMessage())
}
assert.Equal(tc.wantPhases, a.flags.skipPhases)
})
}
}
func TestSkipPhasesCompletion(t *testing.T) {
testCases := map[string]struct {
toComplete string
wantSuggestions []string
}{
"empty": {
toComplete: "",
wantSuggestions: allPhases(),
},
"partial": {
toComplete: "hel",
wantSuggestions: []string{string(skipHelmPhase)},
},
"one full word": {
toComplete: string(skipHelmPhase),
},
"one full word with comma": {
toComplete: string(skipHelmPhase) + ",",
wantSuggestions: func() []string {
allPhases := allPhases()
var suggestions []string
for _, phase := range allPhases {
if phase == string(skipHelmPhase) {
continue
}
suggestions = append(suggestions, fmt.Sprintf("%s,%s", skipHelmPhase, phase))
}
return suggestions
}(),
},
"one full word, one partial": {
toComplete: string(skipHelmPhase) + ",ima",
wantSuggestions: []string{fmt.Sprintf("%s,%s", skipHelmPhase, skipImagePhase)},
},
"all phases": {
toComplete: strings.Join(allPhases(), ","),
wantSuggestions: []string{},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
suggestions, _ := skipPhasesCompletion(nil, nil, tc.toComplete)
assert.ElementsMatch(tc.wantSuggestions, suggestions, "got: %v, want: %v", suggestions, tc.wantSuggestions)
})
}
}
func newPhases(phases ...skipPhase) skipPhases {
skipPhases := skipPhases{}
skipPhases.add(phases...)
return skipPhases
}
type stubConstellApplier struct{}
func (s *stubConstellApplier) CheckLicense(context.Context, cloudprovider.Provider, string) (int, error) {
return 0, nil
}