/* Copyright (c) Edgeless Systems GmbH SPDX-License-Identifier: AGPL-3.0-only */ package cloudcmd import ( "bytes" "context" "io" "path/filepath" "runtime" "testing" "github.com/spf13/afero" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/edgelesssys/constellation/v2/cli/internal/terraform" "github.com/edgelesssys/constellation/v2/internal/attestation/variant" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/file" ) func TestApplier(t *testing.T) { failOnNonAMD64 := (runtime.GOARCH != "amd64") || (runtime.GOOS != "linux") ip := "192.0.2.1" configWithProvider := func(provider cloudprovider.Provider) *config.Config { cfg := config.Default() cfg.RemoveProviderAndAttestationExcept(provider) return cfg } testCases := map[string]struct { tfClient tfResourceClient newTfClientErr error libvirt *stubLibvirtRunner provider cloudprovider.Provider config *config.Config policyPatcher *stubPolicyPatcher wantErr bool wantRollback bool // Use only together with stubClients. wantTerraformRollback bool // When libvirt fails, don't call into Terraform. }{ "gcp": { tfClient: &stubTerraformClient{ip: ip}, provider: cloudprovider.GCP, config: configWithProvider(cloudprovider.GCP), }, "gcp create cluster error": { tfClient: &stubTerraformClient{applyClusterErr: assert.AnError}, provider: cloudprovider.GCP, config: configWithProvider(cloudprovider.GCP), wantErr: true, wantRollback: true, wantTerraformRollback: true, }, "azure": { tfClient: &stubTerraformClient{ip: ip}, provider: cloudprovider.Azure, config: configWithProvider(cloudprovider.Azure), policyPatcher: &stubPolicyPatcher{}, }, "azure trusted launch": { tfClient: &stubTerraformClient{ip: ip}, provider: cloudprovider.Azure, config: func() *config.Config { cfg := config.Default() cfg.RemoveProviderAndAttestationExcept(cloudprovider.Azure) cfg.Attestation = config.AttestationConfig{ AzureTrustedLaunch: &config.AzureTrustedLaunch{}, } return cfg }(), policyPatcher: &stubPolicyPatcher{}, }, "azure new policy patch error": { tfClient: &stubTerraformClient{ip: ip}, provider: cloudprovider.Azure, config: configWithProvider(cloudprovider.Azure), policyPatcher: &stubPolicyPatcher{assert.AnError}, wantErr: true, }, "azure create cluster error": { tfClient: &stubTerraformClient{applyClusterErr: assert.AnError}, provider: cloudprovider.Azure, config: configWithProvider(cloudprovider.Azure), policyPatcher: &stubPolicyPatcher{}, wantErr: true, wantRollback: true, wantTerraformRollback: true, }, "openstack": { tfClient: &stubTerraformClient{ip: ip}, libvirt: &stubLibvirtRunner{}, provider: cloudprovider.OpenStack, config: func() *config.Config { cfg := config.Default() cfg.RemoveProviderAndAttestationExcept(cloudprovider.OpenStack) cfg.Provider.OpenStack.Cloud = "testcloud" return cfg }(), }, "openstack without clouds.yaml": { tfClient: &stubTerraformClient{ip: ip}, libvirt: &stubLibvirtRunner{}, provider: cloudprovider.OpenStack, config: configWithProvider(cloudprovider.OpenStack), wantErr: true, }, "openstack create cluster error": { tfClient: &stubTerraformClient{applyClusterErr: assert.AnError}, libvirt: &stubLibvirtRunner{}, provider: cloudprovider.OpenStack, config: func() *config.Config { cfg := config.Default() cfg.RemoveProviderAndAttestationExcept(cloudprovider.OpenStack) cfg.Provider.OpenStack.Cloud = "testcloud" return cfg }(), wantErr: true, wantRollback: true, wantTerraformRollback: true, }, "qemu": { tfClient: &stubTerraformClient{ip: ip}, libvirt: &stubLibvirtRunner{}, provider: cloudprovider.QEMU, config: configWithProvider(cloudprovider.QEMU), wantErr: failOnNonAMD64, }, "qemu create cluster error": { tfClient: &stubTerraformClient{applyClusterErr: assert.AnError}, libvirt: &stubLibvirtRunner{}, provider: cloudprovider.QEMU, config: configWithProvider(cloudprovider.QEMU), wantErr: true, wantRollback: !failOnNonAMD64, // if we run on non-AMD64/linux, we don't get to a point where rollback is needed wantTerraformRollback: true, }, "qemu start libvirt error": { tfClient: &stubTerraformClient{ip: ip}, libvirt: &stubLibvirtRunner{startErr: assert.AnError}, provider: cloudprovider.QEMU, config: configWithProvider(cloudprovider.QEMU), wantRollback: !failOnNonAMD64, wantTerraformRollback: false, wantErr: true, }, "unknown provider": { tfClient: &stubTerraformClient{}, provider: cloudprovider.Unknown, config: func() *config.Config { cfg := config.Default() cfg.RemoveProviderAndAttestationExcept(cloudprovider.AWS) cfg.Provider.AWS = nil return cfg }(), wantErr: true, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { assert := assert.New(t) applier := &Applier{ fileHandler: file.NewHandler(afero.NewMemMapFs()), imageFetcher: &stubImageFetcher{ reference: "some-image", }, terraformClient: tc.tfClient, libvirtRunner: tc.libvirt, rawDownloader: &stubRawDownloader{ destination: "some-destination", }, policyPatcher: tc.policyPatcher, logLevel: terraform.LogLevelNone, workingDir: "test", backupDir: "test-backup", out: &bytes.Buffer{}, } diff, err := applier.Plan(context.Background(), tc.config) if err != nil { assert.True(tc.wantErr, "unexpected error: %s", err) return } assert.False(diff) idFile, err := applier.Apply(context.Background(), tc.provider, tc.config.GetAttestationConfig().GetVariant(), true) if tc.wantErr { assert.Error(err) if tc.wantRollback { cl := tc.tfClient.(*stubTerraformClient) if tc.wantTerraformRollback { assert.True(cl.destroyCalled) } assert.True(cl.cleanUpWorkspaceCalled) if tc.provider == cloudprovider.QEMU { assert.True(tc.libvirt.stopCalled) } } } else { assert.NoError(err) assert.Equal(ip, idFile.ClusterEndpoint) } }) } } func TestPlan(t *testing.T) { setUpFilesystem := func(existingFiles []string) file.Handler { fs := file.NewHandler(afero.NewMemMapFs()) require.NoError(t, fs.Write("test/terraform.tfstate", []byte{}, file.OptMkdirAll)) for _, f := range existingFiles { require.NoError(t, fs.Write(f, []byte{})) } return fs } testCases := map[string]struct { upgradeID string tf *stubTerraformClient fs file.Handler want bool wantErr bool }{ "success no diff": { upgradeID: "1234", tf: &stubTerraformClient{}, fs: setUpFilesystem([]string{}), }, "success diff": { upgradeID: "1234", tf: &stubTerraformClient{ planDiff: true, }, fs: setUpFilesystem([]string{}), want: true, }, "prepare workspace error": { upgradeID: "1234", tf: &stubTerraformClient{ prepareWorkspaceErr: assert.AnError, }, fs: setUpFilesystem([]string{}), wantErr: true, }, "plan error": { tf: &stubTerraformClient{ planErr: assert.AnError, }, fs: setUpFilesystem([]string{}), wantErr: true, }, "show plan error no diff": { upgradeID: "1234", tf: &stubTerraformClient{ showPlanErr: assert.AnError, }, fs: setUpFilesystem([]string{}), }, "show plan error diff": { upgradeID: "1234", tf: &stubTerraformClient{ showPlanErr: assert.AnError, planDiff: true, }, fs: setUpFilesystem([]string{}), wantErr: true, }, "workspace not clean": { upgradeID: "1234", tf: &stubTerraformClient{}, fs: setUpFilesystem([]string{filepath.Join(constants.UpgradeDir, "1234", constants.TerraformUpgradeBackupDir)}), wantErr: true, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { require := require.New(t) u := &Applier{ terraformClient: tc.tf, policyPatcher: stubPolicyPatcher{}, fileHandler: tc.fs, imageFetcher: &stubImageFetcher{reference: "some-image"}, rawDownloader: &stubRawDownloader{destination: "some-destination"}, libvirtRunner: &stubLibvirtRunner{}, logLevel: terraform.LogLevelDebug, backupDir: filepath.Join(constants.UpgradeDir, tc.upgradeID), workingDir: "test", out: io.Discard, } cfg := config.Default() cfg.RemoveProviderAndAttestationExcept(cloudprovider.Azure) diff, err := u.Plan(context.Background(), cfg) if tc.wantErr { require.Error(err) } else { require.NoError(err) require.Equal(tc.want, diff) } }) } } func TestApply(t *testing.T) { testCases := map[string]struct { upgradeID string tf *stubTerraformClient policyPatcher stubPolicyPatcher fs file.Handler wantErr bool }{ "success": { upgradeID: "1234", tf: &stubTerraformClient{}, policyPatcher: stubPolicyPatcher{}, }, "apply error": { upgradeID: "1234", tf: &stubTerraformClient{ applyClusterErr: assert.AnError, }, policyPatcher: stubPolicyPatcher{}, wantErr: true, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { assert := require.New(t) u := &Applier{ terraformClient: tc.tf, logLevel: terraform.LogLevelDebug, libvirtRunner: &stubLibvirtRunner{}, policyPatcher: stubPolicyPatcher{}, fileHandler: tc.fs, backupDir: filepath.Join(constants.UpgradeDir, tc.upgradeID), workingDir: "test", out: io.Discard, } _, err := u.Apply(context.Background(), cloudprovider.QEMU, variant.QEMUVTPM{}, WithoutRollbackOnError) if tc.wantErr { assert.Error(err) } else { assert.NoError(err) } }) } } type stubPolicyPatcher struct { patchErr error } func (s stubPolicyPatcher) Patch(_ context.Context, _ string) error { return s.patchErr }