Terraform: Only rollback after we fully created the workspace

Nils Hanke 2022-11-15 14:00:44 +01:00 committed by Nils Hanke
parent 19fb6f1233
commit e1d8926395
8 changed files with 178 additions and 60 deletions
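
Summary of the change: creating a cluster is split into an explicit PrepareWorkspace step (write the provider's Terraform templates and variables to disk) followed by CreateCluster (terraform init/apply), and the rollback handler is only registered once the workspace actually exists. Previously, a failure while writing the workspace would already trigger a rollback against a workspace that was never created. A minimal sketch of the resulting flow, reusing the names from the diffs below (terraformClient, rollbackOnError, rollbackerTerraform, clusterid.File); `out` stands in for the Creator's output writer and the success-path ID-file handling is illustrative:

```go
func createCloud(ctx context.Context, cl terraformClient, out io.Writer, vars terraform.Variables) (idFile clusterid.File, retErr error) {
	// Step 1: materialize the Terraform workspace (templates + terraform.tfvars).
	// If this fails, there is nothing to destroy yet, so return without a rollback.
	if err := cl.PrepareWorkspace(cloudprovider.AWS, vars); err != nil {
		return clusterid.File{}, err
	}

	// Step 2: only now register the rollback; the workspace exists on disk.
	defer rollbackOnError(context.Background(), out, &retErr, &rollbackerTerraform{client: cl})

	// Step 3: terraform init + apply.
	ip, err := cl.CreateCluster(ctx)
	if err != nil {
		return clusterid.File{}, err
	}

	_ = ip // the real create* functions put ip into the returned ID file
	return idFile, nil
}
```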

View file

@@ -14,7 +14,8 @@ import (
 )

 type terraformClient interface {
-	CreateCluster(ctx context.Context, provider cloudprovider.Provider, input terraform.Variables) (string, error)
+	PrepareWorkspace(provider cloudprovider.Provider, input terraform.Variables) error
+	CreateCluster(ctx context.Context) (string, error)
 	DestroyCluster(ctx context.Context) error
 	CleanUpWorkspace() error
 	RemoveInstaller()
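
For reference, the interface reassembled after this change (only the methods visible in this hunk):

```go
type terraformClient interface {
	PrepareWorkspace(provider cloudprovider.Provider, input terraform.Variables) error
	CreateCluster(ctx context.Context) (string, error)
	DestroyCluster(ctx context.Context) error
	CleanUpWorkspace() error
	RemoveInstaller()
	// ... remaining methods (outside this hunk) unchanged
}
```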

View file

@@ -12,6 +12,7 @@ import (
 	"github.com/edgelesssys/constellation/v2/cli/internal/terraform"
 	"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
 	"go.uber.org/goleak"
 )
@@ -29,13 +30,18 @@ type stubTerraformClient struct {
 	destroyClusterCalled bool
 	createClusterErr     error
 	destroyClusterErr    error
+	prepareWorkspaceErr  error
 	cleanUpWorkspaceErr  error
 }

-func (c *stubTerraformClient) CreateCluster(ctx context.Context, provider cloudprovider.Provider, input terraform.Variables) (string, error) {
+func (c *stubTerraformClient) CreateCluster(ctx context.Context) (string, error) {
 	return c.ip, c.createClusterErr
 }

+func (c *stubTerraformClient) PrepareWorkspace(provider cloudprovider.Provider, input terraform.Variables) error {
+	return c.prepareWorkspaceErr
+}
+
 func (c *stubTerraformClient) DestroyCluster(ctx context.Context) error {
 	c.destroyClusterCalled = true
 	return c.destroyClusterErr

View file

@@ -88,8 +88,6 @@ func (c *Creator) Create(ctx context.Context, provider cloudprovider.Provider, c
 func (c *Creator) createAWS(ctx context.Context, cl terraformClient, config *config.Config,
 	name, insType string, controlPlaneCount, workerCount int,
 ) (idFile clusterid.File, retErr error) {
-	defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerTerraform{client: cl})
-
 	vars := terraform.AWSVariables{
 		CommonVariables: terraform.CommonVariables{
 			Name: name,
@@ -107,7 +105,12 @@ func (c *Creator) createAWS(ctx context.Context, cl terraformClient, config *con
 		Debug: config.IsDebugCluster(),
 	}

-	ip, err := cl.CreateCluster(ctx, cloudprovider.AWS, &vars)
+	if err := cl.PrepareWorkspace(cloudprovider.AWS, &vars); err != nil {
+		return clusterid.File{}, err
+	}
+
+	defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerTerraform{client: cl})
+	ip, err := cl.CreateCluster(ctx)
 	if err != nil {
 		return clusterid.File{}, err
 	}
@@ -121,8 +124,6 @@ func (c *Creator) createAWS(ctx context.Context, cl terraformClient, config *con
 func (c *Creator) createGCP(ctx context.Context, cl terraformClient, config *config.Config,
 	name, insType string, controlPlaneCount, workerCount int,
 ) (idFile clusterid.File, retErr error) {
-	defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerTerraform{client: cl})
-
 	vars := terraform.GCPVariables{
 		CommonVariables: terraform.CommonVariables{
 			Name: name,
@@ -140,7 +141,12 @@ func (c *Creator) createGCP(ctx context.Context, cl terraformClient, config *con
 		Debug: config.IsDebugCluster(),
 	}

-	ip, err := cl.CreateCluster(ctx, cloudprovider.GCP, &vars)
+	if err := cl.PrepareWorkspace(cloudprovider.GCP, &vars); err != nil {
+		return clusterid.File{}, err
+	}
+
+	defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerTerraform{client: cl})
+	ip, err := cl.CreateCluster(ctx)
 	if err != nil {
 		return clusterid.File{}, err
 	}
@@ -154,8 +160,6 @@ func (c *Creator) createGCP(ctx context.Context, cl terraformClient, config *con
 func (c *Creator) createAzure(ctx context.Context, cl terraformClient, config *config.Config,
 	name, insType string, controlPlaneCount, workerCount int,
 ) (idFile clusterid.File, retErr error) {
-	defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerTerraform{client: cl})
-
 	vars := terraform.AzureVariables{
 		CommonVariables: terraform.CommonVariables{
 			Name: name,
@@ -176,7 +180,12 @@ func (c *Creator) createAzure(ctx context.Context, cl terraformClient, config *c
 	vars = normalizeAzureURIs(vars)

-	ip, err := cl.CreateCluster(ctx, cloudprovider.Azure, &vars)
+	if err := cl.PrepareWorkspace(cloudprovider.Azure, &vars); err != nil {
+		return clusterid.File{}, err
+	}
+
+	defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerTerraform{client: cl})
+	ip, err := cl.CreateCluster(ctx)
 	if err != nil {
 		return clusterid.File{}, err
 	}
@@ -216,7 +225,8 @@ func normalizeAzureURIs(vars terraform.AzureVariables) terraform.AzureVariables
 func (c *Creator) createQEMU(ctx context.Context, cl terraformClient, lv libvirtRunner, name string, config *config.Config,
 	controlPlaneCount, workerCount int,
 ) (idFile clusterid.File, retErr error) {
-	defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerQEMU{client: cl, libvirt: lv})
+	qemuRollbacker := &rollbackerQEMU{client: cl, libvirt: lv, createdWorkspace: false}
+	defer rollbackOnError(context.Background(), c.out, &retErr, qemuRollbacker)

 	libvirtURI := config.Provider.QEMU.LibvirtURI
 	libvirtSocketPath := "."
@@ -273,7 +283,14 @@ func (c *Creator) createQEMU(ctx context.Context, cl terraformClient, lv libvirt
 		Firmware: config.Provider.QEMU.Firmware,
 	}

-	ip, err := cl.CreateCluster(ctx, cloudprovider.QEMU, &vars)
+	if err := cl.PrepareWorkspace(cloudprovider.QEMU, &vars); err != nil {
+		return clusterid.File{}, err
+	}
+
+	// Allow rollback of QEMU Terraform workspace from this point on
+	qemuRollbacker.createdWorkspace = true
+
+	ip, err := cl.CreateCluster(ctx)
 	if err != nil {
 		return clusterid.File{}, err
 	}
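
The QEMU path differs from the other providers: libvirt may already be running before the Terraform workspace is prepared, so its rollback has to stay registered from the start (to stop libvirt on any failure) and instead gates the Terraform destroy behind a createdWorkspace flag that is flipped only after PrepareWorkspace succeeds. Sketched with the names from the hunk above:

```go
qemuRollbacker := &rollbackerQEMU{client: cl, libvirt: lv, createdWorkspace: false}
defer rollbackOnError(context.Background(), c.out, &retErr, qemuRollbacker)

// ... start libvirt, assemble vars ...

if err := cl.PrepareWorkspace(cloudprovider.QEMU, &vars); err != nil {
	// Rollback stops libvirt and cleans up, but skips DestroyCluster.
	return clusterid.File{}, err
}
qemuRollbacker.createdWorkspace = true // from here on, rollback also runs DestroyCluster
```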

View file

@@ -25,13 +25,14 @@ func TestCreator(t *testing.T) {
 	someErr := errors.New("failed")

 	testCases := map[string]struct {
 		tfClient              terraformClient
 		newTfClientErr        error
 		libvirt               *stubLibvirtRunner
 		provider              cloudprovider.Provider
 		config                *config.Config
 		wantErr               bool
 		wantRollback          bool // Use only together with stubClients.
+		wantTerraformRollback bool // When libvirt fails, don't call into Terraform.
 	}{
 		"gcp": {
 			tfClient: &stubTerraformClient{ip: ip},
@@ -45,11 +46,12 @@ func TestCreator(t *testing.T) {
 			wantErr:  true,
 		},
 		"gcp create cluster error": {
 			tfClient:     &stubTerraformClient{createClusterErr: someErr},
 			provider:     cloudprovider.GCP,
 			config:       config.Default(),
 			wantErr:      true,
 			wantRollback: true,
+			wantTerraformRollback: true,
 		},
 		"qemu": {
 			tfClient: &stubTerraformClient{ip: ip},
@@ -66,20 +68,22 @@ func TestCreator(t *testing.T) {
 			wantErr: true,
 		},
 		"qemu create cluster error": {
 			tfClient:     &stubTerraformClient{createClusterErr: someErr},
 			libvirt:      &stubLibvirtRunner{},
 			provider:     cloudprovider.QEMU,
 			config:       config.Default(),
 			wantErr:      true,
 			wantRollback: !failOnNonAMD64, // if we run on non-AMD64/linux, we don't get to a point where rollback is needed
+			wantTerraformRollback: true,
 		},
 		"qemu start libvirt error": {
 			tfClient: &stubTerraformClient{ip: ip},
 			libvirt:  &stubLibvirtRunner{startErr: someErr},
 			provider: cloudprovider.QEMU,
 			config:   config.Default(),
-			wantErr:      true,
-			wantRollback: !failOnNonAMD64,
+			wantRollback:          !failOnNonAMD64,
+			wantTerraformRollback: false,
+			wantErr:               true,
 		},
 		"unknown provider": {
 			provider: cloudprovider.Unknown,
@@ -108,7 +112,9 @@ func TestCreator(t *testing.T) {
 				assert.Error(err)
 				if tc.wantRollback {
 					cl := tc.tfClient.(*stubTerraformClient)
-					assert.True(cl.destroyClusterCalled)
+					if tc.wantTerraformRollback {
+						assert.True(cl.destroyClusterCalled)
+					}
 					assert.True(cl.cleanUpWorkspaceCalled)
 					if tc.provider == cloudprovider.QEMU {
 						assert.True(tc.libvirt.stopCalled)

View file

@@ -48,13 +48,16 @@ func (r *rollbackerTerraform) rollback(ctx context.Context) error {
 }

 type rollbackerQEMU struct {
-	client  terraformClient
-	libvirt libvirtRunner
+	client           terraformClient
+	libvirt          libvirtRunner
+	createdWorkspace bool
 }

 func (r *rollbackerQEMU) rollback(ctx context.Context) error {
 	var err error
-	err = multierr.Append(err, r.client.DestroyCluster(ctx))
+	if r.createdWorkspace {
+		err = multierr.Append(err, r.client.DestroyCluster(ctx))
+	}
 	err = multierr.Append(err, r.libvirt.Stop(ctx))
 	if err == nil {
 		err = r.client.CleanUpWorkspace()

View file

@@ -61,13 +61,15 @@ func TestRollbackQEMU(t *testing.T) {
 	someErr := errors.New("failed")

 	testCases := map[string]struct {
-		libvirt  *stubLibvirtRunner
-		tfClient *stubTerraformClient
-		wantErr  bool
+		libvirt          *stubLibvirtRunner
+		tfClient         *stubTerraformClient
+		createdWorkspace bool
+		wantErr          bool
 	}{
 		"success": {
-			libvirt:  &stubLibvirtRunner{},
-			tfClient: &stubTerraformClient{},
+			libvirt:          &stubLibvirtRunner{},
+			tfClient:         &stubTerraformClient{},
+			createdWorkspace: true,
 		},
 		"stop libvirt error": {
 			libvirt: &stubLibvirtRunner{stopErr: someErr},
@@ -91,8 +93,9 @@ func TestRollbackQEMU(t *testing.T) {
 			assert := assert.New(t)

 			rollbacker := &rollbackerQEMU{
-				libvirt: tc.libvirt,
-				client:  tc.tfClient,
+				libvirt:          tc.libvirt,
+				client:           tc.tfClient,
+				createdWorkspace: tc.createdWorkspace,
 			}

 			err := rollbacker.rollback(context.Background())
@@ -105,7 +108,11 @@ func TestRollbackQEMU(t *testing.T) {
 			}
 			assert.NoError(err)
 			assert.True(tc.libvirt.stopCalled)
-			assert.True(tc.tfClient.destroyClusterCalled)
+			if tc.createdWorkspace {
+				assert.True(tc.tfClient.destroyClusterCalled)
+			} else {
+				assert.False(tc.tfClient.destroyClusterCalled)
+			}
 			assert.True(tc.tfClient.cleanUpWorkspaceCalled)
 		})
 	}

View file

@@ -57,16 +57,21 @@ func New(ctx context.Context, workingDir string) (*Client, error) {
 	}, nil
 }

-// CreateCluster creates a Constellation cluster using Terraform.
-func (c *Client) CreateCluster(ctx context.Context, provider cloudprovider.Provider, vars Variables) (string, error) {
+// PrepareWorkspace prepares a Terraform workspace for a Constellation cluster.
+func (c *Client) PrepareWorkspace(provider cloudprovider.Provider, vars Variables) error {
 	if err := prepareWorkspace(c.file, provider, c.workingDir); err != nil {
-		return "", err
+		return err
 	}

 	if err := c.writeVars(vars); err != nil {
-		return "", err
+		return err
 	}

+	return nil
+}
+
+// CreateCluster creates a Constellation cluster using Terraform.
+func (c *Client) CreateCluster(ctx context.Context) (string, error) {
 	if err := c.tf.Init(ctx); err != nil {
 		return "", err
 	}
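
With this split, callers of the terraform package drive the client in two steps. A hedged usage sketch (New's signature, constants.TerraformWorkingDir, and the destroy/cleanup methods are taken from the diffs; error handling abbreviated):

```go
cl, err := terraform.New(ctx, constants.TerraformWorkingDir)
if err != nil {
	return err
}

// Write provider templates and variables; nothing to destroy if this fails.
if err := cl.PrepareWorkspace(cloudprovider.QEMU, vars); err != nil {
	return err
}

// terraform init + apply; from here on a failure warrants destroy + cleanup.
ip, err := cl.CreateCluster(ctx)
if err != nil {
	_ = cl.DestroyCluster(ctx) // best-effort rollback
	_ = cl.CleanUpWorkspace()
	return err
}
_ = ip
```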

View file

@@ -24,6 +24,82 @@ import (
 	"go.uber.org/multierr"
 )

+func TestPrepareCluster(t *testing.T) {
+	qemuVars := &QEMUVariables{
+		CommonVariables: CommonVariables{
+			Name:               "name",
+			CountControlPlanes: 1,
+			CountWorkers:       2,
+			StateDiskSizeGB:    11,
+		},
+		CPUCount:         1,
+		MemorySizeMiB:    1024,
+		ImagePath:        "path",
+		ImageFormat:      "format",
+		MetadataAPIImage: "api",
+	}
+
+	testCases := map[string]struct {
+		provider           cloudprovider.Provider
+		vars               Variables
+		fs                 afero.Fs
+		partiallyExtracted bool
+		wantErr            bool
+	}{
+		"qemu": {
+			provider: cloudprovider.QEMU,
+			vars:     qemuVars,
+			fs:       afero.NewMemMapFs(),
+			wantErr:  false,
+		},
+		"no vars": {
+			provider: cloudprovider.QEMU,
+			fs:       afero.NewMemMapFs(),
+			wantErr:  true,
+		},
+		"continue on partially extracted": {
+			provider:           cloudprovider.QEMU,
+			vars:               qemuVars,
+			fs:                 afero.NewMemMapFs(),
+			partiallyExtracted: true,
+			wantErr:            false,
+		},
+		"prepare workspace fails": {
+			provider: cloudprovider.QEMU,
+			vars:     qemuVars,
+			fs:       afero.NewReadOnlyFs(afero.NewMemMapFs()),
+			wantErr:  true,
+		},
+	}
+
+	for name, tc := range testCases {
+		t.Run(name, func(t *testing.T) {
+			assert := assert.New(t)
+			require := require.New(t)
+
+			c := &Client{
+				tf:         &stubTerraform{},
+				file:       file.NewHandler(tc.fs),
+				workingDir: constants.TerraformWorkingDir,
+			}
+
+			err := c.PrepareWorkspace(tc.provider, tc.vars)
+
+			// Test case: Check if we can continue to create on an incomplete workspace.
+			if tc.partiallyExtracted {
+				require.NoError(c.file.Remove(filepath.Join(c.workingDir, "main.tf")))
+				err = c.PrepareWorkspace(tc.provider, tc.vars)
+			}
+
+			if tc.wantErr {
+				assert.Error(err)
+				return
+			}
+			assert.NoError(err)
+		})
+	}
+}
+
 func TestCreateCluster(t *testing.T) {
 	someErr := errors.New("failed")
 	newTestState := func() *tfjson.State {
@@ -67,6 +143,7 @@ func TestCreateCluster(t *testing.T) {
 		},
 		"init fails": {
 			provider: cloudprovider.QEMU,
+			vars:     qemuVars,
 			tf:       &stubTerraform{initErr: someErr},
 			fs:       afero.NewMemMapFs(),
 			wantErr:  true,
@@ -111,17 +188,12 @@ func TestCreateCluster(t *testing.T) {
 			fs:      afero.NewMemMapFs(),
 			wantErr: true,
 		},
-		"prepare workspace fails": {
-			provider: cloudprovider.QEMU,
-			tf:       &stubTerraform{showState: newTestState()},
-			fs:       afero.NewReadOnlyFs(afero.NewMemMapFs()),
-			wantErr:  true,
-		},
 	}

 	for name, tc := range testCases {
 		t.Run(name, func(t *testing.T) {
 			assert := assert.New(t)
+			require := require.New(t)

 			c := &Client{
 				tf:         tc.tf,
@@ -129,7 +201,8 @@ func TestCreateCluster(t *testing.T) {
 				workingDir: constants.TerraformWorkingDir,
 			}

-			ip, err := c.CreateCluster(context.Background(), tc.provider, tc.vars)
+			require.NoError(c.PrepareWorkspace(tc.provider, tc.vars))
+			ip, err := c.CreateCluster(context.Background())
 			if tc.wantErr {
 				assert.Error(err)