AB#2532 Dont clean up workspace if rollback fails (#360)

* Dont clean up workspace if rollback fails

* Remove dependency on CSP from terminate

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2022-10-26 15:57:00 +02:00 committed by GitHub
parent 1f8eba37c8
commit e66cb84d6e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 190 additions and 111 deletions

View file

@ -10,10 +10,11 @@ import (
"context"
"github.com/edgelesssys/constellation/v2/cli/internal/terraform"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
)
type terraformClient interface {
CreateCluster(ctx context.Context, name string, input terraform.Variables) (string, error)
CreateCluster(ctx context.Context, provider cloudprovider.Provider, name string, input terraform.Variables) (string, error)
DestroyCluster(ctx context.Context) error
CleanUpWorkspace() error
RemoveInstaller()

View file

@ -11,6 +11,7 @@ import (
"testing"
"github.com/edgelesssys/constellation/v2/cli/internal/terraform"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"go.uber.org/goleak"
)
@ -31,7 +32,7 @@ type stubTerraformClient struct {
cleanUpWorkspaceErr error
}
func (c *stubTerraformClient) CreateCluster(ctx context.Context, name string, input terraform.Variables) (string, error) {
func (c *stubTerraformClient) CreateCluster(ctx context.Context, provider cloudprovider.Provider, name string, input terraform.Variables) (string, error) {
return c.ip, c.createClusterErr
}

View file

@ -25,7 +25,7 @@ import (
// Creator creates cloud resources.
type Creator struct {
out io.Writer
newTerraformClient func(ctx context.Context, provider cloudprovider.Provider) (terraformClient, error)
newTerraformClient func(ctx context.Context) (terraformClient, error)
newLibvirtRunner func() libvirtRunner
}
@ -33,8 +33,8 @@ type Creator struct {
func NewCreator(out io.Writer) *Creator {
return &Creator{
out: out,
newTerraformClient: func(ctx context.Context, provider cloudprovider.Provider) (terraformClient, error) {
return terraform.New(ctx, provider)
newTerraformClient: func(ctx context.Context) (terraformClient, error) {
return terraform.New(ctx)
},
newLibvirtRunner: func() libvirtRunner {
return libvirt.New()
@ -51,21 +51,21 @@ func (c *Creator) Create(ctx context.Context, provider cloudprovider.Provider, c
if os.Getenv("CONSTELLATION_AWS_DEV") != "1" {
return clusterid.File{}, fmt.Errorf("AWS isn't supported yet")
}
cl, err := c.newTerraformClient(ctx, provider)
cl, err := c.newTerraformClient(ctx)
if err != nil {
return clusterid.File{}, err
}
defer cl.RemoveInstaller()
return c.createAWS(ctx, cl, config, name, insType, controlPlaneCount, workerCount)
case cloudprovider.GCP:
cl, err := c.newTerraformClient(ctx, provider)
cl, err := c.newTerraformClient(ctx)
if err != nil {
return clusterid.File{}, err
}
defer cl.RemoveInstaller()
return c.createGCP(ctx, cl, config, name, insType, controlPlaneCount, workerCount)
case cloudprovider.Azure:
cl, err := c.newTerraformClient(ctx, provider)
cl, err := c.newTerraformClient(ctx)
if err != nil {
return clusterid.File{}, err
}
@ -75,7 +75,7 @@ func (c *Creator) Create(ctx context.Context, provider cloudprovider.Provider, c
if runtime.GOARCH != "amd64" || runtime.GOOS != "linux" {
return clusterid.File{}, fmt.Errorf("creation of a QEMU based Constellation is not supported for %s/%s", runtime.GOOS, runtime.GOARCH)
}
cl, err := c.newTerraformClient(ctx, provider)
cl, err := c.newTerraformClient(ctx)
if err != nil {
return clusterid.File{}, err
}
@ -108,7 +108,7 @@ func (c *Creator) createAWS(ctx context.Context, cl terraformClient, config *con
Debug: config.IsDebugCluster(),
}
ip, err := cl.CreateCluster(ctx, name, vars)
ip, err := cl.CreateCluster(ctx, cloudprovider.AWS, name, vars)
if err != nil {
return clusterid.File{}, err
}
@ -141,7 +141,7 @@ func (c *Creator) createGCP(ctx context.Context, cl terraformClient, config *con
Debug: config.IsDebugCluster(),
}
ip, err := cl.CreateCluster(ctx, name, &vars)
ip, err := cl.CreateCluster(ctx, cloudprovider.GCP, name, &vars)
if err != nil {
return clusterid.File{}, err
}
@ -177,7 +177,7 @@ func (c *Creator) createAzure(ctx context.Context, cl terraformClient, config *c
vars = normalizeAzureURIs(vars)
ip, err := cl.CreateCluster(ctx, name, &vars)
ip, err := cl.CreateCluster(ctx, cloudprovider.Azure, name, &vars)
if err != nil {
return clusterid.File{}, err
}
@ -258,7 +258,7 @@ func (c *Creator) createQEMU(ctx context.Context, cl terraformClient, lv libvirt
Firmware: config.Provider.QEMU.Firmware,
}
ip, err := cl.CreateCluster(ctx, name, &vars)
ip, err := cl.CreateCluster(ctx, cloudprovider.QEMU, name, &vars)
if err != nil {
return clusterid.File{}, err
}

View file

@ -94,7 +94,7 @@ func TestCreator(t *testing.T) {
creator := &Creator{
out: &bytes.Buffer{},
newTerraformClient: func(ctx context.Context, provider cloudprovider.Provider) (terraformClient, error) {
newTerraformClient: func(ctx context.Context) (terraformClient, error) {
return tc.tfClient, tc.newTfClientErr
},
newLibvirtRunner: func() libvirtRunner {

View file

@ -41,7 +41,9 @@ type rollbackerTerraform struct {
func (r *rollbackerTerraform) rollback(ctx context.Context) error {
var err error
err = multierr.Append(err, r.client.DestroyCluster(ctx))
err = multierr.Append(err, r.client.CleanUpWorkspace())
if err == nil {
err = multierr.Append(err, r.client.CleanUpWorkspace())
}
return err
}
@ -54,6 +56,8 @@ func (r *rollbackerQEMU) rollback(ctx context.Context) error {
var err error
err = multierr.Append(err, r.client.DestroyCluster(ctx))
err = multierr.Append(err, r.libvirt.Stop(ctx))
err = multierr.Append(err, r.client.CleanUpWorkspace())
if err == nil {
err = r.client.CleanUpWorkspace()
}
return err
}

View file

@ -0,0 +1,112 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cloudcmd
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
func TestRollbackTerraform(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
tfClient *stubTerraformClient
wantErr bool
}{
"success": {
tfClient: &stubTerraformClient{},
},
"destroy cluster error": {
tfClient: &stubTerraformClient{destroyClusterErr: someErr},
wantErr: true,
},
"clean up workspace error": {
tfClient: &stubTerraformClient{cleanUpWorkspaceErr: someErr},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
rollbacker := &rollbackerTerraform{
client: tc.tfClient,
}
err := rollbacker.rollback(context.Background())
if tc.wantErr {
assert.Error(err)
if tc.tfClient.cleanUpWorkspaceErr == nil {
assert.False(tc.tfClient.cleanUpWorkspaceCalled)
}
return
}
assert.NoError(err)
assert.True(tc.tfClient.destroyClusterCalled)
assert.True(tc.tfClient.cleanUpWorkspaceCalled)
})
}
}
func TestRollbackQEMU(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
libvirt *stubLibvirtRunner
tfClient *stubTerraformClient
wantErr bool
}{
"success": {
libvirt: &stubLibvirtRunner{},
tfClient: &stubTerraformClient{},
},
"stop libvirt error": {
libvirt: &stubLibvirtRunner{stopErr: someErr},
tfClient: &stubTerraformClient{},
wantErr: true,
},
"destroy cluster error": {
libvirt: &stubLibvirtRunner{stopErr: someErr},
tfClient: &stubTerraformClient{destroyClusterErr: someErr},
wantErr: true,
},
"clean up workspace error": {
libvirt: &stubLibvirtRunner{},
tfClient: &stubTerraformClient{cleanUpWorkspaceErr: someErr},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
rollbacker := &rollbackerQEMU{
libvirt: tc.libvirt,
client: tc.tfClient,
}
err := rollbacker.rollback(context.Background())
if tc.wantErr {
assert.Error(err)
if tc.tfClient.cleanUpWorkspaceErr == nil {
assert.False(tc.tfClient.cleanUpWorkspaceCalled)
}
return
}
assert.NoError(err)
assert.True(tc.libvirt.stopCalled)
assert.True(tc.tfClient.destroyClusterCalled)
assert.True(tc.tfClient.cleanUpWorkspaceCalled)
})
}
}

View file

@ -8,24 +8,22 @@ package cloudcmd
import (
"context"
"errors"
"github.com/edgelesssys/constellation/v2/cli/internal/libvirt"
"github.com/edgelesssys/constellation/v2/cli/internal/terraform"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
)
// Terminator deletes cloud provider resources.
type Terminator struct {
newTerraformClient func(ctx context.Context, provider cloudprovider.Provider) (terraformClient, error)
newTerraformClient func(ctx context.Context) (terraformClient, error)
newLibvirtRunner func() libvirtRunner
}
// NewTerminator create a new cloud terminator.
func NewTerminator() *Terminator {
return &Terminator{
newTerraformClient: func(ctx context.Context, provider cloudprovider.Provider) (terraformClient, error) {
return terraform.New(ctx, provider)
newTerraformClient: func(ctx context.Context) (terraformClient, error) {
return terraform.New(ctx)
},
newLibvirtRunner: func() libvirtRunner {
return libvirt.New()
@ -34,21 +32,14 @@ func NewTerminator() *Terminator {
}
// Terminate deletes the could provider resources.
func (t *Terminator) Terminate(ctx context.Context, provider cloudprovider.Provider) (retErr error) {
if provider == cloudprovider.Unknown {
return errors.New("unknown cloud provider")
}
func (t *Terminator) Terminate(ctx context.Context) (retErr error) {
defer func() {
if retErr == nil {
retErr = t.newLibvirtRunner().Stop(ctx)
}
}()
if provider == cloudprovider.QEMU {
libvirt := t.newLibvirtRunner()
defer func() {
if retErr == nil {
retErr = libvirt.Stop(ctx)
}
}()
}
cl, err := t.newTerraformClient(ctx, provider)
cl, err := t.newTerraformClient(ctx)
if err != nil {
return err
}

View file

@ -11,7 +11,6 @@ import (
"errors"
"testing"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/stretchr/testify/assert"
)
@ -22,54 +21,32 @@ func TestTerminator(t *testing.T) {
tfClient terraformClient
newTfClientErr error
libvirt *stubLibvirtRunner
provider cloudprovider.Provider
wantErr bool
}{
"gcp": {
libvirt: &stubLibvirtRunner{},
tfClient: &stubTerraformClient{},
provider: cloudprovider.GCP,
},
"gcp newTfClientErr": {
"newTfClientErr": {
libvirt: &stubLibvirtRunner{},
newTfClientErr: someErr,
provider: cloudprovider.GCP,
wantErr: true,
},
"gcp destroy cluster error": {
tfClient: &stubTerraformClient{destroyClusterErr: someErr},
provider: cloudprovider.GCP,
wantErr: true,
},
"gcp clean up workspace error": {
tfClient: &stubTerraformClient{cleanUpWorkspaceErr: someErr},
provider: cloudprovider.GCP,
wantErr: true,
},
"qemu": {
tfClient: &stubTerraformClient{},
libvirt: &stubLibvirtRunner{},
provider: cloudprovider.QEMU,
},
"qemu destroy cluster error": {
"destroy cluster error": {
tfClient: &stubTerraformClient{destroyClusterErr: someErr},
libvirt: &stubLibvirtRunner{},
provider: cloudprovider.QEMU,
wantErr: true,
},
"qemu clean up workspace error": {
"clean up workspace error": {
tfClient: &stubTerraformClient{cleanUpWorkspaceErr: someErr},
libvirt: &stubLibvirtRunner{},
provider: cloudprovider.QEMU,
wantErr: true,
},
"qemu stop libvirt error": {
tfClient: &stubTerraformClient{},
libvirt: &stubLibvirtRunner{stopErr: someErr},
provider: cloudprovider.QEMU,
wantErr: true,
},
"unknown cloud provider": {
wantErr: true,
},
}
for name, tc := range testCases {
@ -77,7 +54,7 @@ func TestTerminator(t *testing.T) {
assert := assert.New(t)
terminator := &Terminator{
newTerraformClient: func(ctx context.Context, provider cloudprovider.Provider) (terraformClient, error) {
newTerraformClient: func(ctx context.Context) (terraformClient, error) {
return tc.tfClient, tc.newTfClientErr
},
newLibvirtRunner: func() libvirtRunner {
@ -85,19 +62,17 @@ func TestTerminator(t *testing.T) {
},
}
err := terminator.Terminate(context.Background(), tc.provider)
err := terminator.Terminate(context.Background())
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
cl := tc.tfClient.(*stubTerraformClient)
assert.True(cl.destroyClusterCalled)
assert.True(cl.removeInstallerCalled)
if tc.provider == cloudprovider.QEMU {
assert.True(tc.libvirt.stopCalled)
}
return
}
assert.NoError(err)
cl := tc.tfClient.(*stubTerraformClient)
assert.True(cl.destroyClusterCalled)
assert.True(cl.removeInstallerCalled)
assert.True(tc.libvirt.stopCalled)
})
}
}