Terraform: Only rollback after we fully created the workspace

This commit is contained in:
Nils Hanke 2022-11-15 14:00:44 +01:00 committed by Nils Hanke
parent 19fb6f1233
commit e1d8926395
8 changed files with 178 additions and 60 deletions

View File

@ -14,7 +14,8 @@ import (
)
type terraformClient interface {
CreateCluster(ctx context.Context, provider cloudprovider.Provider, input terraform.Variables) (string, error)
PrepareWorkspace(provider cloudprovider.Provider, input terraform.Variables) error
CreateCluster(ctx context.Context) (string, error)
DestroyCluster(ctx context.Context) error
CleanUpWorkspace() error
RemoveInstaller()

View File

@ -12,6 +12,7 @@ import (
"github.com/edgelesssys/constellation/v2/cli/internal/terraform"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"go.uber.org/goleak"
)
@ -29,13 +30,18 @@ type stubTerraformClient struct {
destroyClusterCalled bool
createClusterErr error
destroyClusterErr error
prepareWorkspaceErr error
cleanUpWorkspaceErr error
}
func (c *stubTerraformClient) CreateCluster(ctx context.Context, provider cloudprovider.Provider, input terraform.Variables) (string, error) {
func (c *stubTerraformClient) CreateCluster(ctx context.Context) (string, error) {
return c.ip, c.createClusterErr
}
func (c *stubTerraformClient) PrepareWorkspace(provider cloudprovider.Provider, input terraform.Variables) error {
return c.prepareWorkspaceErr
}
func (c *stubTerraformClient) DestroyCluster(ctx context.Context) error {
c.destroyClusterCalled = true
return c.destroyClusterErr

View File

@ -88,8 +88,6 @@ func (c *Creator) Create(ctx context.Context, provider cloudprovider.Provider, c
func (c *Creator) createAWS(ctx context.Context, cl terraformClient, config *config.Config,
name, insType string, controlPlaneCount, workerCount int,
) (idFile clusterid.File, retErr error) {
defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerTerraform{client: cl})
vars := terraform.AWSVariables{
CommonVariables: terraform.CommonVariables{
Name: name,
@ -107,7 +105,12 @@ func (c *Creator) createAWS(ctx context.Context, cl terraformClient, config *con
Debug: config.IsDebugCluster(),
}
ip, err := cl.CreateCluster(ctx, cloudprovider.AWS, &vars)
if err := cl.PrepareWorkspace(cloudprovider.AWS, &vars); err != nil {
return clusterid.File{}, err
}
defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerTerraform{client: cl})
ip, err := cl.CreateCluster(ctx)
if err != nil {
return clusterid.File{}, err
}
@ -121,8 +124,6 @@ func (c *Creator) createAWS(ctx context.Context, cl terraformClient, config *con
func (c *Creator) createGCP(ctx context.Context, cl terraformClient, config *config.Config,
name, insType string, controlPlaneCount, workerCount int,
) (idFile clusterid.File, retErr error) {
defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerTerraform{client: cl})
vars := terraform.GCPVariables{
CommonVariables: terraform.CommonVariables{
Name: name,
@ -140,7 +141,12 @@ func (c *Creator) createGCP(ctx context.Context, cl terraformClient, config *con
Debug: config.IsDebugCluster(),
}
ip, err := cl.CreateCluster(ctx, cloudprovider.GCP, &vars)
if err := cl.PrepareWorkspace(cloudprovider.GCP, &vars); err != nil {
return clusterid.File{}, err
}
defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerTerraform{client: cl})
ip, err := cl.CreateCluster(ctx)
if err != nil {
return clusterid.File{}, err
}
@ -154,8 +160,6 @@ func (c *Creator) createGCP(ctx context.Context, cl terraformClient, config *con
func (c *Creator) createAzure(ctx context.Context, cl terraformClient, config *config.Config,
name, insType string, controlPlaneCount, workerCount int,
) (idFile clusterid.File, retErr error) {
defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerTerraform{client: cl})
vars := terraform.AzureVariables{
CommonVariables: terraform.CommonVariables{
Name: name,
@ -176,7 +180,12 @@ func (c *Creator) createAzure(ctx context.Context, cl terraformClient, config *c
vars = normalizeAzureURIs(vars)
ip, err := cl.CreateCluster(ctx, cloudprovider.Azure, &vars)
if err := cl.PrepareWorkspace(cloudprovider.Azure, &vars); err != nil {
return clusterid.File{}, err
}
defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerTerraform{client: cl})
ip, err := cl.CreateCluster(ctx)
if err != nil {
return clusterid.File{}, err
}
@ -216,7 +225,8 @@ func normalizeAzureURIs(vars terraform.AzureVariables) terraform.AzureVariables
func (c *Creator) createQEMU(ctx context.Context, cl terraformClient, lv libvirtRunner, name string, config *config.Config,
controlPlaneCount, workerCount int,
) (idFile clusterid.File, retErr error) {
defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerQEMU{client: cl, libvirt: lv})
qemuRollbacker := &rollbackerQEMU{client: cl, libvirt: lv, createdWorkspace: false}
defer rollbackOnError(context.Background(), c.out, &retErr, qemuRollbacker)
libvirtURI := config.Provider.QEMU.LibvirtURI
libvirtSocketPath := "."
@ -273,7 +283,14 @@ func (c *Creator) createQEMU(ctx context.Context, cl terraformClient, lv libvirt
Firmware: config.Provider.QEMU.Firmware,
}
ip, err := cl.CreateCluster(ctx, cloudprovider.QEMU, &vars)
if err := cl.PrepareWorkspace(cloudprovider.QEMU, &vars); err != nil {
return clusterid.File{}, err
}
// Allow rollback of QEMU Terraform workspace from this point on
qemuRollbacker.createdWorkspace = true
ip, err := cl.CreateCluster(ctx)
if err != nil {
return clusterid.File{}, err
}

View File

@ -25,13 +25,14 @@ func TestCreator(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
tfClient terraformClient
newTfClientErr error
libvirt *stubLibvirtRunner
provider cloudprovider.Provider
config *config.Config
wantErr bool
wantRollback bool // Use only together with stubClients.
tfClient terraformClient
newTfClientErr error
libvirt *stubLibvirtRunner
provider cloudprovider.Provider
config *config.Config
wantErr bool
wantRollback bool // Use only together with stubClients.
wantTerraformRollback bool // When libvirt fails, don't call into Terraform.
}{
"gcp": {
tfClient: &stubTerraformClient{ip: ip},
@ -45,11 +46,12 @@ func TestCreator(t *testing.T) {
wantErr: true,
},
"gcp create cluster error": {
tfClient: &stubTerraformClient{createClusterErr: someErr},
provider: cloudprovider.GCP,
config: config.Default(),
wantErr: true,
wantRollback: true,
tfClient: &stubTerraformClient{createClusterErr: someErr},
provider: cloudprovider.GCP,
config: config.Default(),
wantErr: true,
wantRollback: true,
wantTerraformRollback: true,
},
"qemu": {
tfClient: &stubTerraformClient{ip: ip},
@ -66,20 +68,22 @@ func TestCreator(t *testing.T) {
wantErr: true,
},
"qemu create cluster error": {
tfClient: &stubTerraformClient{createClusterErr: someErr},
libvirt: &stubLibvirtRunner{},
provider: cloudprovider.QEMU,
config: config.Default(),
wantErr: true,
wantRollback: !failOnNonAMD64, // if we run on non-AMD64/linux, we don't get to a point where rollback is needed
tfClient: &stubTerraformClient{createClusterErr: someErr},
libvirt: &stubLibvirtRunner{},
provider: cloudprovider.QEMU,
config: config.Default(),
wantErr: true,
wantRollback: !failOnNonAMD64, // if we run on non-AMD64/linux, we don't get to a point where rollback is needed
wantTerraformRollback: true,
},
"qemu start libvirt error": {
tfClient: &stubTerraformClient{ip: ip},
libvirt: &stubLibvirtRunner{startErr: someErr},
provider: cloudprovider.QEMU,
config: config.Default(),
wantErr: true,
wantRollback: !failOnNonAMD64,
tfClient: &stubTerraformClient{ip: ip},
libvirt: &stubLibvirtRunner{startErr: someErr},
provider: cloudprovider.QEMU,
config: config.Default(),
wantRollback: !failOnNonAMD64,
wantTerraformRollback: false,
wantErr: true,
},
"unknown provider": {
provider: cloudprovider.Unknown,
@ -108,7 +112,9 @@ func TestCreator(t *testing.T) {
assert.Error(err)
if tc.wantRollback {
cl := tc.tfClient.(*stubTerraformClient)
assert.True(cl.destroyClusterCalled)
if tc.wantTerraformRollback {
assert.True(cl.destroyClusterCalled)
}
assert.True(cl.cleanUpWorkspaceCalled)
if tc.provider == cloudprovider.QEMU {
assert.True(tc.libvirt.stopCalled)

View File

@ -48,13 +48,16 @@ func (r *rollbackerTerraform) rollback(ctx context.Context) error {
}
type rollbackerQEMU struct {
client terraformClient
libvirt libvirtRunner
client terraformClient
libvirt libvirtRunner
createdWorkspace bool
}
func (r *rollbackerQEMU) rollback(ctx context.Context) error {
var err error
err = multierr.Append(err, r.client.DestroyCluster(ctx))
if r.createdWorkspace {
err = multierr.Append(err, r.client.DestroyCluster(ctx))
}
err = multierr.Append(err, r.libvirt.Stop(ctx))
if err == nil {
err = r.client.CleanUpWorkspace()

View File

@ -61,13 +61,15 @@ func TestRollbackQEMU(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
libvirt *stubLibvirtRunner
tfClient *stubTerraformClient
wantErr bool
libvirt *stubLibvirtRunner
tfClient *stubTerraformClient
createdWorkspace bool
wantErr bool
}{
"success": {
libvirt: &stubLibvirtRunner{},
tfClient: &stubTerraformClient{},
libvirt: &stubLibvirtRunner{},
tfClient: &stubTerraformClient{},
createdWorkspace: true,
},
"stop libvirt error": {
libvirt: &stubLibvirtRunner{stopErr: someErr},
@ -91,8 +93,9 @@ func TestRollbackQEMU(t *testing.T) {
assert := assert.New(t)
rollbacker := &rollbackerQEMU{
libvirt: tc.libvirt,
client: tc.tfClient,
libvirt: tc.libvirt,
client: tc.tfClient,
createdWorkspace: tc.createdWorkspace,
}
err := rollbacker.rollback(context.Background())
@ -105,7 +108,11 @@ func TestRollbackQEMU(t *testing.T) {
}
assert.NoError(err)
assert.True(tc.libvirt.stopCalled)
assert.True(tc.tfClient.destroyClusterCalled)
if tc.createdWorkspace {
assert.True(tc.tfClient.destroyClusterCalled)
} else {
assert.False(tc.tfClient.destroyClusterCalled)
}
assert.True(tc.tfClient.cleanUpWorkspaceCalled)
})
}

View File

@ -57,16 +57,21 @@ func New(ctx context.Context, workingDir string) (*Client, error) {
}, nil
}
// CreateCluster creates a Constellation cluster using Terraform.
func (c *Client) CreateCluster(ctx context.Context, provider cloudprovider.Provider, vars Variables) (string, error) {
// PrepareWorkspace prepares a Terraform workspace for a Constellation cluster.
func (c *Client) PrepareWorkspace(provider cloudprovider.Provider, vars Variables) error {
if err := prepareWorkspace(c.file, provider, c.workingDir); err != nil {
return "", err
return err
}
if err := c.writeVars(vars); err != nil {
return "", err
return err
}
return nil
}
// CreateCluster creates a Constellation cluster using Terraform.
func (c *Client) CreateCluster(ctx context.Context) (string, error) {
if err := c.tf.Init(ctx); err != nil {
return "", err
}

View File

@ -24,6 +24,82 @@ import (
"go.uber.org/multierr"
)
func TestPrepareCluster(t *testing.T) {
qemuVars := &QEMUVariables{
CommonVariables: CommonVariables{
Name: "name",
CountControlPlanes: 1,
CountWorkers: 2,
StateDiskSizeGB: 11,
},
CPUCount: 1,
MemorySizeMiB: 1024,
ImagePath: "path",
ImageFormat: "format",
MetadataAPIImage: "api",
}
testCases := map[string]struct {
provider cloudprovider.Provider
vars Variables
fs afero.Fs
partiallyExtracted bool
wantErr bool
}{
"qemu": {
provider: cloudprovider.QEMU,
vars: qemuVars,
fs: afero.NewMemMapFs(),
wantErr: false,
},
"no vars": {
provider: cloudprovider.QEMU,
fs: afero.NewMemMapFs(),
wantErr: true,
},
"continue on partially extracted": {
provider: cloudprovider.QEMU,
vars: qemuVars,
fs: afero.NewMemMapFs(),
partiallyExtracted: true,
wantErr: false,
},
"prepare workspace fails": {
provider: cloudprovider.QEMU,
vars: qemuVars,
fs: afero.NewReadOnlyFs(afero.NewMemMapFs()),
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
c := &Client{
tf: &stubTerraform{},
file: file.NewHandler(tc.fs),
workingDir: constants.TerraformWorkingDir,
}
err := c.PrepareWorkspace(tc.provider, tc.vars)
// Test case: Check if we can continue to create on an incomplete workspace.
if tc.partiallyExtracted {
require.NoError(c.file.Remove(filepath.Join(c.workingDir, "main.tf")))
err = c.PrepareWorkspace(tc.provider, tc.vars)
}
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
})
}
}
func TestCreateCluster(t *testing.T) {
someErr := errors.New("failed")
newTestState := func() *tfjson.State {
@ -67,6 +143,7 @@ func TestCreateCluster(t *testing.T) {
},
"init fails": {
provider: cloudprovider.QEMU,
vars: qemuVars,
tf: &stubTerraform{initErr: someErr},
fs: afero.NewMemMapFs(),
wantErr: true,
@ -111,17 +188,12 @@ func TestCreateCluster(t *testing.T) {
fs: afero.NewMemMapFs(),
wantErr: true,
},
"prepare workspace fails": {
provider: cloudprovider.QEMU,
tf: &stubTerraform{showState: newTestState()},
fs: afero.NewReadOnlyFs(afero.NewMemMapFs()),
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
c := &Client{
tf: tc.tf,
@ -129,7 +201,8 @@ func TestCreateCluster(t *testing.T) {
workingDir: constants.TerraformWorkingDir,
}
ip, err := c.CreateCluster(context.Background(), tc.provider, tc.vars)
require.NoError(c.PrepareWorkspace(tc.provider, tc.vars))
ip, err := c.CreateCluster(context.Background())
if tc.wantErr {
assert.Error(err)