diff --git a/cli/internal/cloudcmd/clients.go b/cli/internal/cloudcmd/clients.go index 4959bbc3d..78978b6d8 100644 --- a/cli/internal/cloudcmd/clients.go +++ b/cli/internal/cloudcmd/clients.go @@ -10,10 +10,11 @@ import ( "context" "github.com/edgelesssys/constellation/v2/cli/internal/terraform" + "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" ) type terraformClient interface { - CreateCluster(ctx context.Context, name string, input terraform.Variables) (string, error) + CreateCluster(ctx context.Context, provider cloudprovider.Provider, name string, input terraform.Variables) (string, error) DestroyCluster(ctx context.Context) error CleanUpWorkspace() error RemoveInstaller() diff --git a/cli/internal/cloudcmd/clients_test.go b/cli/internal/cloudcmd/clients_test.go index 6a88a2121..effcb4d52 100644 --- a/cli/internal/cloudcmd/clients_test.go +++ b/cli/internal/cloudcmd/clients_test.go @@ -11,6 +11,7 @@ import ( "testing" "github.com/edgelesssys/constellation/v2/cli/internal/terraform" + "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "go.uber.org/goleak" ) @@ -31,7 +32,7 @@ type stubTerraformClient struct { cleanUpWorkspaceErr error } -func (c *stubTerraformClient) CreateCluster(ctx context.Context, name string, input terraform.Variables) (string, error) { +func (c *stubTerraformClient) CreateCluster(ctx context.Context, provider cloudprovider.Provider, name string, input terraform.Variables) (string, error) { return c.ip, c.createClusterErr } diff --git a/cli/internal/cloudcmd/create.go b/cli/internal/cloudcmd/create.go index 490cc84a5..61ea85e13 100644 --- a/cli/internal/cloudcmd/create.go +++ b/cli/internal/cloudcmd/create.go @@ -25,7 +25,7 @@ import ( // Creator creates cloud resources. type Creator struct { out io.Writer - newTerraformClient func(ctx context.Context, provider cloudprovider.Provider) (terraformClient, error) + newTerraformClient func(ctx context.Context) (terraformClient, error) newLibvirtRunner func() libvirtRunner } @@ -33,8 +33,8 @@ type Creator struct { func NewCreator(out io.Writer) *Creator { return &Creator{ out: out, - newTerraformClient: func(ctx context.Context, provider cloudprovider.Provider) (terraformClient, error) { - return terraform.New(ctx, provider) + newTerraformClient: func(ctx context.Context) (terraformClient, error) { + return terraform.New(ctx) }, newLibvirtRunner: func() libvirtRunner { return libvirt.New() @@ -51,21 +51,21 @@ func (c *Creator) Create(ctx context.Context, provider cloudprovider.Provider, c if os.Getenv("CONSTELLATION_AWS_DEV") != "1" { return clusterid.File{}, fmt.Errorf("AWS isn't supported yet") } - cl, err := c.newTerraformClient(ctx, provider) + cl, err := c.newTerraformClient(ctx) if err != nil { return clusterid.File{}, err } defer cl.RemoveInstaller() return c.createAWS(ctx, cl, config, name, insType, controlPlaneCount, workerCount) case cloudprovider.GCP: - cl, err := c.newTerraformClient(ctx, provider) + cl, err := c.newTerraformClient(ctx) if err != nil { return clusterid.File{}, err } defer cl.RemoveInstaller() return c.createGCP(ctx, cl, config, name, insType, controlPlaneCount, workerCount) case cloudprovider.Azure: - cl, err := c.newTerraformClient(ctx, provider) + cl, err := c.newTerraformClient(ctx) if err != nil { return clusterid.File{}, err } @@ -75,7 +75,7 @@ func (c *Creator) Create(ctx context.Context, provider cloudprovider.Provider, c if runtime.GOARCH != "amd64" || runtime.GOOS != "linux" { return clusterid.File{}, fmt.Errorf("creation of a QEMU based Constellation is not supported for %s/%s", runtime.GOOS, runtime.GOARCH) } - cl, err := c.newTerraformClient(ctx, provider) + cl, err := c.newTerraformClient(ctx) if err != nil { return clusterid.File{}, err } @@ -108,7 +108,7 @@ func (c *Creator) createAWS(ctx context.Context, cl terraformClient, config *con Debug: config.IsDebugCluster(), } - ip, err := cl.CreateCluster(ctx, name, vars) + ip, err := cl.CreateCluster(ctx, cloudprovider.AWS, name, vars) if err != nil { return clusterid.File{}, err } @@ -141,7 +141,7 @@ func (c *Creator) createGCP(ctx context.Context, cl terraformClient, config *con Debug: config.IsDebugCluster(), } - ip, err := cl.CreateCluster(ctx, name, &vars) + ip, err := cl.CreateCluster(ctx, cloudprovider.GCP, name, &vars) if err != nil { return clusterid.File{}, err } @@ -177,7 +177,7 @@ func (c *Creator) createAzure(ctx context.Context, cl terraformClient, config *c vars = normalizeAzureURIs(vars) - ip, err := cl.CreateCluster(ctx, name, &vars) + ip, err := cl.CreateCluster(ctx, cloudprovider.Azure, name, &vars) if err != nil { return clusterid.File{}, err } @@ -258,7 +258,7 @@ func (c *Creator) createQEMU(ctx context.Context, cl terraformClient, lv libvirt Firmware: config.Provider.QEMU.Firmware, } - ip, err := cl.CreateCluster(ctx, name, &vars) + ip, err := cl.CreateCluster(ctx, cloudprovider.QEMU, name, &vars) if err != nil { return clusterid.File{}, err } diff --git a/cli/internal/cloudcmd/create_test.go b/cli/internal/cloudcmd/create_test.go index 16be13c55..388d0d228 100644 --- a/cli/internal/cloudcmd/create_test.go +++ b/cli/internal/cloudcmd/create_test.go @@ -94,7 +94,7 @@ func TestCreator(t *testing.T) { creator := &Creator{ out: &bytes.Buffer{}, - newTerraformClient: func(ctx context.Context, provider cloudprovider.Provider) (terraformClient, error) { + newTerraformClient: func(ctx context.Context) (terraformClient, error) { return tc.tfClient, tc.newTfClientErr }, newLibvirtRunner: func() libvirtRunner { diff --git a/cli/internal/cloudcmd/rollback.go b/cli/internal/cloudcmd/rollback.go index 1e28fc949..0154f8f39 100644 --- a/cli/internal/cloudcmd/rollback.go +++ b/cli/internal/cloudcmd/rollback.go @@ -41,7 +41,9 @@ type rollbackerTerraform struct { func (r *rollbackerTerraform) rollback(ctx context.Context) error { var err error err = multierr.Append(err, r.client.DestroyCluster(ctx)) - err = multierr.Append(err, r.client.CleanUpWorkspace()) + if err == nil { + err = multierr.Append(err, r.client.CleanUpWorkspace()) + } return err } @@ -54,6 +56,8 @@ func (r *rollbackerQEMU) rollback(ctx context.Context) error { var err error err = multierr.Append(err, r.client.DestroyCluster(ctx)) err = multierr.Append(err, r.libvirt.Stop(ctx)) - err = multierr.Append(err, r.client.CleanUpWorkspace()) + if err == nil { + err = r.client.CleanUpWorkspace() + } return err } diff --git a/cli/internal/cloudcmd/rollback_test.go b/cli/internal/cloudcmd/rollback_test.go new file mode 100644 index 000000000..b92573dae --- /dev/null +++ b/cli/internal/cloudcmd/rollback_test.go @@ -0,0 +1,112 @@ +/* +Copyright (c) Edgeless Systems GmbH + +SPDX-License-Identifier: AGPL-3.0-only +*/ + +package cloudcmd + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRollbackTerraform(t *testing.T) { + someErr := errors.New("failed") + + testCases := map[string]struct { + tfClient *stubTerraformClient + wantErr bool + }{ + "success": { + tfClient: &stubTerraformClient{}, + }, + "destroy cluster error": { + tfClient: &stubTerraformClient{destroyClusterErr: someErr}, + wantErr: true, + }, + "clean up workspace error": { + tfClient: &stubTerraformClient{cleanUpWorkspaceErr: someErr}, + wantErr: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + + rollbacker := &rollbackerTerraform{ + client: tc.tfClient, + } + + err := rollbacker.rollback(context.Background()) + if tc.wantErr { + assert.Error(err) + if tc.tfClient.cleanUpWorkspaceErr == nil { + assert.False(tc.tfClient.cleanUpWorkspaceCalled) + } + return + } + assert.NoError(err) + assert.True(tc.tfClient.destroyClusterCalled) + assert.True(tc.tfClient.cleanUpWorkspaceCalled) + }) + } +} + +func TestRollbackQEMU(t *testing.T) { + someErr := errors.New("failed") + + testCases := map[string]struct { + libvirt *stubLibvirtRunner + tfClient *stubTerraformClient + wantErr bool + }{ + "success": { + libvirt: &stubLibvirtRunner{}, + tfClient: &stubTerraformClient{}, + }, + "stop libvirt error": { + libvirt: &stubLibvirtRunner{stopErr: someErr}, + tfClient: &stubTerraformClient{}, + wantErr: true, + }, + "destroy cluster error": { + libvirt: &stubLibvirtRunner{stopErr: someErr}, + tfClient: &stubTerraformClient{destroyClusterErr: someErr}, + wantErr: true, + }, + "clean up workspace error": { + libvirt: &stubLibvirtRunner{}, + tfClient: &stubTerraformClient{cleanUpWorkspaceErr: someErr}, + wantErr: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + + rollbacker := &rollbackerQEMU{ + libvirt: tc.libvirt, + client: tc.tfClient, + } + + err := rollbacker.rollback(context.Background()) + if tc.wantErr { + assert.Error(err) + if tc.tfClient.cleanUpWorkspaceErr == nil { + assert.False(tc.tfClient.cleanUpWorkspaceCalled) + } + return + } + assert.NoError(err) + assert.True(tc.libvirt.stopCalled) + assert.True(tc.tfClient.destroyClusterCalled) + assert.True(tc.tfClient.cleanUpWorkspaceCalled) + }) + } +} diff --git a/cli/internal/cloudcmd/terminate.go b/cli/internal/cloudcmd/terminate.go index 5bc582940..752bf8f57 100644 --- a/cli/internal/cloudcmd/terminate.go +++ b/cli/internal/cloudcmd/terminate.go @@ -8,24 +8,22 @@ package cloudcmd import ( "context" - "errors" "github.com/edgelesssys/constellation/v2/cli/internal/libvirt" "github.com/edgelesssys/constellation/v2/cli/internal/terraform" - "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" ) // Terminator deletes cloud provider resources. type Terminator struct { - newTerraformClient func(ctx context.Context, provider cloudprovider.Provider) (terraformClient, error) + newTerraformClient func(ctx context.Context) (terraformClient, error) newLibvirtRunner func() libvirtRunner } // NewTerminator create a new cloud terminator. func NewTerminator() *Terminator { return &Terminator{ - newTerraformClient: func(ctx context.Context, provider cloudprovider.Provider) (terraformClient, error) { - return terraform.New(ctx, provider) + newTerraformClient: func(ctx context.Context) (terraformClient, error) { + return terraform.New(ctx) }, newLibvirtRunner: func() libvirtRunner { return libvirt.New() @@ -34,21 +32,14 @@ func NewTerminator() *Terminator { } // Terminate deletes the could provider resources. -func (t *Terminator) Terminate(ctx context.Context, provider cloudprovider.Provider) (retErr error) { - if provider == cloudprovider.Unknown { - return errors.New("unknown cloud provider") - } +func (t *Terminator) Terminate(ctx context.Context) (retErr error) { + defer func() { + if retErr == nil { + retErr = t.newLibvirtRunner().Stop(ctx) + } + }() - if provider == cloudprovider.QEMU { - libvirt := t.newLibvirtRunner() - defer func() { - if retErr == nil { - retErr = libvirt.Stop(ctx) - } - }() - } - - cl, err := t.newTerraformClient(ctx, provider) + cl, err := t.newTerraformClient(ctx) if err != nil { return err } diff --git a/cli/internal/cloudcmd/terminate_test.go b/cli/internal/cloudcmd/terminate_test.go index f3e7b99f4..711d4881e 100644 --- a/cli/internal/cloudcmd/terminate_test.go +++ b/cli/internal/cloudcmd/terminate_test.go @@ -11,7 +11,6 @@ import ( "errors" "testing" - "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/stretchr/testify/assert" ) @@ -22,54 +21,32 @@ func TestTerminator(t *testing.T) { tfClient terraformClient newTfClientErr error libvirt *stubLibvirtRunner - provider cloudprovider.Provider wantErr bool }{ "gcp": { + libvirt: &stubLibvirtRunner{}, tfClient: &stubTerraformClient{}, - provider: cloudprovider.GCP, }, - "gcp newTfClientErr": { + "newTfClientErr": { + libvirt: &stubLibvirtRunner{}, newTfClientErr: someErr, - provider: cloudprovider.GCP, wantErr: true, }, - "gcp destroy cluster error": { - tfClient: &stubTerraformClient{destroyClusterErr: someErr}, - provider: cloudprovider.GCP, - wantErr: true, - }, - "gcp clean up workspace error": { - tfClient: &stubTerraformClient{cleanUpWorkspaceErr: someErr}, - provider: cloudprovider.GCP, - wantErr: true, - }, - "qemu": { - tfClient: &stubTerraformClient{}, - libvirt: &stubLibvirtRunner{}, - provider: cloudprovider.QEMU, - }, - "qemu destroy cluster error": { + "destroy cluster error": { tfClient: &stubTerraformClient{destroyClusterErr: someErr}, libvirt: &stubLibvirtRunner{}, - provider: cloudprovider.QEMU, wantErr: true, }, - "qemu clean up workspace error": { + "clean up workspace error": { tfClient: &stubTerraformClient{cleanUpWorkspaceErr: someErr}, libvirt: &stubLibvirtRunner{}, - provider: cloudprovider.QEMU, wantErr: true, }, "qemu stop libvirt error": { tfClient: &stubTerraformClient{}, libvirt: &stubLibvirtRunner{stopErr: someErr}, - provider: cloudprovider.QEMU, wantErr: true, }, - "unknown cloud provider": { - wantErr: true, - }, } for name, tc := range testCases { @@ -77,7 +54,7 @@ func TestTerminator(t *testing.T) { assert := assert.New(t) terminator := &Terminator{ - newTerraformClient: func(ctx context.Context, provider cloudprovider.Provider) (terraformClient, error) { + newTerraformClient: func(ctx context.Context) (terraformClient, error) { return tc.tfClient, tc.newTfClientErr }, newLibvirtRunner: func() libvirtRunner { @@ -85,19 +62,17 @@ func TestTerminator(t *testing.T) { }, } - err := terminator.Terminate(context.Background(), tc.provider) + err := terminator.Terminate(context.Background()) if tc.wantErr { assert.Error(err) - } else { - assert.NoError(err) - cl := tc.tfClient.(*stubTerraformClient) - assert.True(cl.destroyClusterCalled) - assert.True(cl.removeInstallerCalled) - if tc.provider == cloudprovider.QEMU { - assert.True(tc.libvirt.stopCalled) - } + return } + assert.NoError(err) + cl := tc.tfClient.(*stubTerraformClient) + assert.True(cl.destroyClusterCalled) + assert.True(cl.removeInstallerCalled) + assert.True(tc.libvirt.stopCalled) }) } } diff --git a/cli/internal/cmd/cloud.go b/cli/internal/cmd/cloud.go index aad1be6ea..dbe0330eb 100644 --- a/cli/internal/cmd/cloud.go +++ b/cli/internal/cmd/cloud.go @@ -25,5 +25,5 @@ type cloudCreator interface { } type cloudTerminator interface { - Terminate(context.Context, cloudprovider.Provider) error + Terminate(context.Context) error } diff --git a/cli/internal/cmd/cloud_test.go b/cli/internal/cmd/cloud_test.go index 4269f8883..9791dffb0 100644 --- a/cli/internal/cmd/cloud_test.go +++ b/cli/internal/cmd/cloud_test.go @@ -46,7 +46,7 @@ type stubCloudTerminator struct { terminateErr error } -func (c *stubCloudTerminator) Terminate(context.Context, cloudprovider.Provider) error { +func (c *stubCloudTerminator) Terminate(context.Context) error { c.called = true return c.terminateErr } diff --git a/cli/internal/cmd/terminate.go b/cli/internal/cmd/terminate.go index 07647f909..611275d6e 100644 --- a/cli/internal/cmd/terminate.go +++ b/cli/internal/cmd/terminate.go @@ -16,7 +16,6 @@ import ( "go.uber.org/multierr" "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd" - "github.com/edgelesssys/constellation/v2/cli/internal/clusterid" "github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/file" ) @@ -45,13 +44,8 @@ func runTerminate(cmd *cobra.Command, args []string) error { func terminate(cmd *cobra.Command, terminator cloudTerminator, fileHandler file.Handler, spinner spinnerInterf, ) error { - var idFile clusterid.File - if err := fileHandler.ReadJSON(constants.ClusterIDsFileName, &idFile); err != nil { - return err - } - spinner.Start("Terminating", false) - err := terminator.Terminate(cmd.Context(), idFile.CloudProvider) + err := terminator.Terminate(cmd.Context()) spinner.Stop() if err != nil { return fmt.Errorf("terminating Constellation cluster: %w", err) diff --git a/cli/internal/cmd/terminate_test.go b/cli/internal/cmd/terminate_test.go index 14d79d4da..0a06ede74 100644 --- a/cli/internal/cmd/terminate_test.go +++ b/cli/internal/cmd/terminate_test.go @@ -83,7 +83,7 @@ func TestTerminate(t *testing.T) { terminator: &stubCloudTerminator{terminateErr: someErr}, wantErr: true, }, - "missing id file": { + "missing id file does not error": { idFile: clusterid.File{CloudProvider: cloudprovider.GCP}, setupFs: func(require *require.Assertions, idFile clusterid.File) afero.Fs { fs := afero.NewMemMapFs() @@ -92,7 +92,6 @@ func TestTerminate(t *testing.T) { return fs }, terminator: &stubCloudTerminator{}, - wantErr: true, }, "remove file fails": { idFile: clusterid.File{CloudProvider: cloudprovider.GCP}, diff --git a/cli/internal/terraform/loader.go b/cli/internal/terraform/loader.go index e662bbbfb..6bfca0711 100644 --- a/cli/internal/terraform/loader.go +++ b/cli/internal/terraform/loader.go @@ -45,15 +45,21 @@ func prepareWorkspace(fileHandler file.Handler, provider cloudprovider.Provider) } // cleanUpWorkspace removes files that were loaded into the workspace. -func cleanUpWorkspace(fileHandler file.Handler, provider cloudprovider.Provider) error { - rootDir := path.Join("terraform", strings.ToLower(provider.String())) - return fs.WalkDir(terraformFS, rootDir, func(path string, d fs.DirEntry, err error) error { - if err != nil { +func cleanUpWorkspace(fileHandler file.Handler) error { + // try to remove any terraform files in the workspace + for _, csp := range []string{"aws", "azure", "gcp", "qemu"} { + rootDir := path.Join("terraform", csp) + if err := fs.WalkDir(terraformFS, rootDir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + fileName := strings.TrimPrefix(path, rootDir+"/") + return ignoreFileNotFoundErr(fileHandler.RemoveAll(fileName)) + }); err != nil { return err } - fileName := strings.TrimPrefix(path, rootDir+"/") - return ignoreFileNotFoundErr(fileHandler.RemoveAll(fileName)) - }) + } + return nil } // ignoreFileNotFoundErr ignores the error if it is a file not found error. diff --git a/cli/internal/terraform/loader_test.go b/cli/internal/terraform/loader_test.go index 64b00a6e6..2c3c01593 100644 --- a/cli/internal/terraform/loader_test.go +++ b/cli/internal/terraform/loader_test.go @@ -63,7 +63,7 @@ func TestLoader(t *testing.T) { checkFiles(t, file, func(err error) { assert.NoError(err) }, tc.fileList) - err = cleanUpWorkspace(file, tc.provider) + err = cleanUpWorkspace(file) require.NoError(err) checkFiles(t, file, func(err error) { assert.ErrorIs(err, fs.ErrNotExist) }, tc.fileList) diff --git a/cli/internal/terraform/terraform.go b/cli/internal/terraform/terraform.go index 72eecea6a..f666b268f 100644 --- a/cli/internal/terraform/terraform.go +++ b/cli/internal/terraform/terraform.go @@ -32,14 +32,12 @@ const ( type Client struct { tf tfInterface - provider cloudprovider.Provider - file file.Handler remove func() } // New sets up a new Client for Terraform. -func New(ctx context.Context, provider cloudprovider.Provider) (*Client, error) { +func New(ctx context.Context) (*Client, error) { tf, remove, err := GetExecutable(ctx, ".") if err != nil { return nil, err @@ -48,16 +46,17 @@ func New(ctx context.Context, provider cloudprovider.Provider) (*Client, error) file := file.NewHandler(afero.NewOsFs()) return &Client{ - tf: tf, - provider: provider, - remove: remove, - file: file, + tf: tf, + remove: remove, + file: file, }, nil } // CreateCluster creates a Constellation cluster using Terraform. -func (c *Client) CreateCluster(ctx context.Context, name string, vars Variables) (string, error) { - if err := prepareWorkspace(c.file, c.provider); err != nil { +func (c *Client) CreateCluster( + ctx context.Context, provider cloudprovider.Provider, name string, vars Variables, +) (string, error) { + if err := prepareWorkspace(c.file, provider); err != nil { return "", err } @@ -102,7 +101,7 @@ func (c *Client) RemoveInstaller() { // CleanUpWorkspace removes terraform files from the current directory. func (c *Client) CleanUpWorkspace() error { - if err := cleanUpWorkspace(c.file, c.provider); err != nil { + if err := cleanUpWorkspace(c.file); err != nil { return err } diff --git a/cli/internal/terraform/terraform_test.go b/cli/internal/terraform/terraform_test.go index 0193ac30a..e0a4b659e 100644 --- a/cli/internal/terraform/terraform_test.go +++ b/cli/internal/terraform/terraform_test.go @@ -122,12 +122,11 @@ func TestCreateCluster(t *testing.T) { assert := assert.New(t) c := &Client{ - provider: tc.provider, - tf: tc.tf, - file: file.NewHandler(tc.fs), + tf: tc.tf, + file: file.NewHandler(tc.fs), } - ip, err := c.CreateCluster(context.Background(), "test", tc.vars) + ip, err := c.CreateCluster(context.Background(), tc.provider, "test", tc.vars) if tc.wantErr { assert.Error(err) @@ -160,8 +159,7 @@ func TestDestroyInstances(t *testing.T) { assert := assert.New(t) c := &Client{ - provider: cloudprovider.QEMU, - tf: tc.tf, + tf: tc.tf, } err := c.DestroyCluster(context.Background()) @@ -207,9 +205,8 @@ func TestCleanupWorkspace(t *testing.T) { require.NoError(tc.prepareFS(file)) c := &Client{ - provider: tc.provider, - file: file, - tf: &stubTerraform{}, + file: file, + tf: &stubTerraform{}, } err := c.CleanUpWorkspace()