mirror of
https://github.com/edgelesssys/constellation.git
synced 2024-10-01 01:36:09 -04:00
Terraform: Only rollback after we fully created the workspace
This commit is contained in:
parent
19fb6f1233
commit
e1d8926395
@ -14,7 +14,8 @@ import (
|
||||
)
|
||||
|
||||
type terraformClient interface {
|
||||
CreateCluster(ctx context.Context, provider cloudprovider.Provider, input terraform.Variables) (string, error)
|
||||
PrepareWorkspace(provider cloudprovider.Provider, input terraform.Variables) error
|
||||
CreateCluster(ctx context.Context) (string, error)
|
||||
DestroyCluster(ctx context.Context) error
|
||||
CleanUpWorkspace() error
|
||||
RemoveInstaller()
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/cli/internal/terraform"
|
||||
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
|
||||
|
||||
"go.uber.org/goleak"
|
||||
)
|
||||
|
||||
@ -29,13 +30,18 @@ type stubTerraformClient struct {
|
||||
destroyClusterCalled bool
|
||||
createClusterErr error
|
||||
destroyClusterErr error
|
||||
prepareWorkspaceErr error
|
||||
cleanUpWorkspaceErr error
|
||||
}
|
||||
|
||||
func (c *stubTerraformClient) CreateCluster(ctx context.Context, provider cloudprovider.Provider, input terraform.Variables) (string, error) {
|
||||
func (c *stubTerraformClient) CreateCluster(ctx context.Context) (string, error) {
|
||||
return c.ip, c.createClusterErr
|
||||
}
|
||||
|
||||
func (c *stubTerraformClient) PrepareWorkspace(provider cloudprovider.Provider, input terraform.Variables) error {
|
||||
return c.prepareWorkspaceErr
|
||||
}
|
||||
|
||||
func (c *stubTerraformClient) DestroyCluster(ctx context.Context) error {
|
||||
c.destroyClusterCalled = true
|
||||
return c.destroyClusterErr
|
||||
|
@ -88,8 +88,6 @@ func (c *Creator) Create(ctx context.Context, provider cloudprovider.Provider, c
|
||||
func (c *Creator) createAWS(ctx context.Context, cl terraformClient, config *config.Config,
|
||||
name, insType string, controlPlaneCount, workerCount int,
|
||||
) (idFile clusterid.File, retErr error) {
|
||||
defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerTerraform{client: cl})
|
||||
|
||||
vars := terraform.AWSVariables{
|
||||
CommonVariables: terraform.CommonVariables{
|
||||
Name: name,
|
||||
@ -107,7 +105,12 @@ func (c *Creator) createAWS(ctx context.Context, cl terraformClient, config *con
|
||||
Debug: config.IsDebugCluster(),
|
||||
}
|
||||
|
||||
ip, err := cl.CreateCluster(ctx, cloudprovider.AWS, &vars)
|
||||
if err := cl.PrepareWorkspace(cloudprovider.AWS, &vars); err != nil {
|
||||
return clusterid.File{}, err
|
||||
}
|
||||
|
||||
defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerTerraform{client: cl})
|
||||
ip, err := cl.CreateCluster(ctx)
|
||||
if err != nil {
|
||||
return clusterid.File{}, err
|
||||
}
|
||||
@ -121,8 +124,6 @@ func (c *Creator) createAWS(ctx context.Context, cl terraformClient, config *con
|
||||
func (c *Creator) createGCP(ctx context.Context, cl terraformClient, config *config.Config,
|
||||
name, insType string, controlPlaneCount, workerCount int,
|
||||
) (idFile clusterid.File, retErr error) {
|
||||
defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerTerraform{client: cl})
|
||||
|
||||
vars := terraform.GCPVariables{
|
||||
CommonVariables: terraform.CommonVariables{
|
||||
Name: name,
|
||||
@ -140,7 +141,12 @@ func (c *Creator) createGCP(ctx context.Context, cl terraformClient, config *con
|
||||
Debug: config.IsDebugCluster(),
|
||||
}
|
||||
|
||||
ip, err := cl.CreateCluster(ctx, cloudprovider.GCP, &vars)
|
||||
if err := cl.PrepareWorkspace(cloudprovider.GCP, &vars); err != nil {
|
||||
return clusterid.File{}, err
|
||||
}
|
||||
|
||||
defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerTerraform{client: cl})
|
||||
ip, err := cl.CreateCluster(ctx)
|
||||
if err != nil {
|
||||
return clusterid.File{}, err
|
||||
}
|
||||
@ -154,8 +160,6 @@ func (c *Creator) createGCP(ctx context.Context, cl terraformClient, config *con
|
||||
func (c *Creator) createAzure(ctx context.Context, cl terraformClient, config *config.Config,
|
||||
name, insType string, controlPlaneCount, workerCount int,
|
||||
) (idFile clusterid.File, retErr error) {
|
||||
defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerTerraform{client: cl})
|
||||
|
||||
vars := terraform.AzureVariables{
|
||||
CommonVariables: terraform.CommonVariables{
|
||||
Name: name,
|
||||
@ -176,7 +180,12 @@ func (c *Creator) createAzure(ctx context.Context, cl terraformClient, config *c
|
||||
|
||||
vars = normalizeAzureURIs(vars)
|
||||
|
||||
ip, err := cl.CreateCluster(ctx, cloudprovider.Azure, &vars)
|
||||
if err := cl.PrepareWorkspace(cloudprovider.Azure, &vars); err != nil {
|
||||
return clusterid.File{}, err
|
||||
}
|
||||
|
||||
defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerTerraform{client: cl})
|
||||
ip, err := cl.CreateCluster(ctx)
|
||||
if err != nil {
|
||||
return clusterid.File{}, err
|
||||
}
|
||||
@ -216,7 +225,8 @@ func normalizeAzureURIs(vars terraform.AzureVariables) terraform.AzureVariables
|
||||
func (c *Creator) createQEMU(ctx context.Context, cl terraformClient, lv libvirtRunner, name string, config *config.Config,
|
||||
controlPlaneCount, workerCount int,
|
||||
) (idFile clusterid.File, retErr error) {
|
||||
defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerQEMU{client: cl, libvirt: lv})
|
||||
qemuRollbacker := &rollbackerQEMU{client: cl, libvirt: lv, createdWorkspace: false}
|
||||
defer rollbackOnError(context.Background(), c.out, &retErr, qemuRollbacker)
|
||||
|
||||
libvirtURI := config.Provider.QEMU.LibvirtURI
|
||||
libvirtSocketPath := "."
|
||||
@ -273,7 +283,14 @@ func (c *Creator) createQEMU(ctx context.Context, cl terraformClient, lv libvirt
|
||||
Firmware: config.Provider.QEMU.Firmware,
|
||||
}
|
||||
|
||||
ip, err := cl.CreateCluster(ctx, cloudprovider.QEMU, &vars)
|
||||
if err := cl.PrepareWorkspace(cloudprovider.QEMU, &vars); err != nil {
|
||||
return clusterid.File{}, err
|
||||
}
|
||||
|
||||
// Allow rollback of QEMU Terraform workspace from this point on
|
||||
qemuRollbacker.createdWorkspace = true
|
||||
|
||||
ip, err := cl.CreateCluster(ctx)
|
||||
if err != nil {
|
||||
return clusterid.File{}, err
|
||||
}
|
||||
|
@ -25,13 +25,14 @@ func TestCreator(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
|
||||
testCases := map[string]struct {
|
||||
tfClient terraformClient
|
||||
newTfClientErr error
|
||||
libvirt *stubLibvirtRunner
|
||||
provider cloudprovider.Provider
|
||||
config *config.Config
|
||||
wantErr bool
|
||||
wantRollback bool // Use only together with stubClients.
|
||||
tfClient terraformClient
|
||||
newTfClientErr error
|
||||
libvirt *stubLibvirtRunner
|
||||
provider cloudprovider.Provider
|
||||
config *config.Config
|
||||
wantErr bool
|
||||
wantRollback bool // Use only together with stubClients.
|
||||
wantTerraformRollback bool // When libvirt fails, don't call into Terraform.
|
||||
}{
|
||||
"gcp": {
|
||||
tfClient: &stubTerraformClient{ip: ip},
|
||||
@ -45,11 +46,12 @@ func TestCreator(t *testing.T) {
|
||||
wantErr: true,
|
||||
},
|
||||
"gcp create cluster error": {
|
||||
tfClient: &stubTerraformClient{createClusterErr: someErr},
|
||||
provider: cloudprovider.GCP,
|
||||
config: config.Default(),
|
||||
wantErr: true,
|
||||
wantRollback: true,
|
||||
tfClient: &stubTerraformClient{createClusterErr: someErr},
|
||||
provider: cloudprovider.GCP,
|
||||
config: config.Default(),
|
||||
wantErr: true,
|
||||
wantRollback: true,
|
||||
wantTerraformRollback: true,
|
||||
},
|
||||
"qemu": {
|
||||
tfClient: &stubTerraformClient{ip: ip},
|
||||
@ -66,20 +68,22 @@ func TestCreator(t *testing.T) {
|
||||
wantErr: true,
|
||||
},
|
||||
"qemu create cluster error": {
|
||||
tfClient: &stubTerraformClient{createClusterErr: someErr},
|
||||
libvirt: &stubLibvirtRunner{},
|
||||
provider: cloudprovider.QEMU,
|
||||
config: config.Default(),
|
||||
wantErr: true,
|
||||
wantRollback: !failOnNonAMD64, // if we run on non-AMD64/linux, we don't get to a point where rollback is needed
|
||||
tfClient: &stubTerraformClient{createClusterErr: someErr},
|
||||
libvirt: &stubLibvirtRunner{},
|
||||
provider: cloudprovider.QEMU,
|
||||
config: config.Default(),
|
||||
wantErr: true,
|
||||
wantRollback: !failOnNonAMD64, // if we run on non-AMD64/linux, we don't get to a point where rollback is needed
|
||||
wantTerraformRollback: true,
|
||||
},
|
||||
"qemu start libvirt error": {
|
||||
tfClient: &stubTerraformClient{ip: ip},
|
||||
libvirt: &stubLibvirtRunner{startErr: someErr},
|
||||
provider: cloudprovider.QEMU,
|
||||
config: config.Default(),
|
||||
wantErr: true,
|
||||
wantRollback: !failOnNonAMD64,
|
||||
tfClient: &stubTerraformClient{ip: ip},
|
||||
libvirt: &stubLibvirtRunner{startErr: someErr},
|
||||
provider: cloudprovider.QEMU,
|
||||
config: config.Default(),
|
||||
wantRollback: !failOnNonAMD64,
|
||||
wantTerraformRollback: false,
|
||||
wantErr: true,
|
||||
},
|
||||
"unknown provider": {
|
||||
provider: cloudprovider.Unknown,
|
||||
@ -108,7 +112,9 @@ func TestCreator(t *testing.T) {
|
||||
assert.Error(err)
|
||||
if tc.wantRollback {
|
||||
cl := tc.tfClient.(*stubTerraformClient)
|
||||
assert.True(cl.destroyClusterCalled)
|
||||
if tc.wantTerraformRollback {
|
||||
assert.True(cl.destroyClusterCalled)
|
||||
}
|
||||
assert.True(cl.cleanUpWorkspaceCalled)
|
||||
if tc.provider == cloudprovider.QEMU {
|
||||
assert.True(tc.libvirt.stopCalled)
|
||||
|
@ -48,13 +48,16 @@ func (r *rollbackerTerraform) rollback(ctx context.Context) error {
|
||||
}
|
||||
|
||||
type rollbackerQEMU struct {
|
||||
client terraformClient
|
||||
libvirt libvirtRunner
|
||||
client terraformClient
|
||||
libvirt libvirtRunner
|
||||
createdWorkspace bool
|
||||
}
|
||||
|
||||
func (r *rollbackerQEMU) rollback(ctx context.Context) error {
|
||||
var err error
|
||||
err = multierr.Append(err, r.client.DestroyCluster(ctx))
|
||||
if r.createdWorkspace {
|
||||
err = multierr.Append(err, r.client.DestroyCluster(ctx))
|
||||
}
|
||||
err = multierr.Append(err, r.libvirt.Stop(ctx))
|
||||
if err == nil {
|
||||
err = r.client.CleanUpWorkspace()
|
||||
|
@ -61,13 +61,15 @@ func TestRollbackQEMU(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
|
||||
testCases := map[string]struct {
|
||||
libvirt *stubLibvirtRunner
|
||||
tfClient *stubTerraformClient
|
||||
wantErr bool
|
||||
libvirt *stubLibvirtRunner
|
||||
tfClient *stubTerraformClient
|
||||
createdWorkspace bool
|
||||
wantErr bool
|
||||
}{
|
||||
"success": {
|
||||
libvirt: &stubLibvirtRunner{},
|
||||
tfClient: &stubTerraformClient{},
|
||||
libvirt: &stubLibvirtRunner{},
|
||||
tfClient: &stubTerraformClient{},
|
||||
createdWorkspace: true,
|
||||
},
|
||||
"stop libvirt error": {
|
||||
libvirt: &stubLibvirtRunner{stopErr: someErr},
|
||||
@ -91,8 +93,9 @@ func TestRollbackQEMU(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
rollbacker := &rollbackerQEMU{
|
||||
libvirt: tc.libvirt,
|
||||
client: tc.tfClient,
|
||||
libvirt: tc.libvirt,
|
||||
client: tc.tfClient,
|
||||
createdWorkspace: tc.createdWorkspace,
|
||||
}
|
||||
|
||||
err := rollbacker.rollback(context.Background())
|
||||
@ -105,7 +108,11 @@ func TestRollbackQEMU(t *testing.T) {
|
||||
}
|
||||
assert.NoError(err)
|
||||
assert.True(tc.libvirt.stopCalled)
|
||||
assert.True(tc.tfClient.destroyClusterCalled)
|
||||
if tc.createdWorkspace {
|
||||
assert.True(tc.tfClient.destroyClusterCalled)
|
||||
} else {
|
||||
assert.False(tc.tfClient.destroyClusterCalled)
|
||||
}
|
||||
assert.True(tc.tfClient.cleanUpWorkspaceCalled)
|
||||
})
|
||||
}
|
||||
|
@ -57,16 +57,21 @@ func New(ctx context.Context, workingDir string) (*Client, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateCluster creates a Constellation cluster using Terraform.
|
||||
func (c *Client) CreateCluster(ctx context.Context, provider cloudprovider.Provider, vars Variables) (string, error) {
|
||||
// PrepareWorkspace prepares a Terraform workspace for a Constellation cluster.
|
||||
func (c *Client) PrepareWorkspace(provider cloudprovider.Provider, vars Variables) error {
|
||||
if err := prepareWorkspace(c.file, provider, c.workingDir); err != nil {
|
||||
return "", err
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.writeVars(vars); err != nil {
|
||||
return "", err
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateCluster creates a Constellation cluster using Terraform.
|
||||
func (c *Client) CreateCluster(ctx context.Context) (string, error) {
|
||||
if err := c.tf.Init(ctx); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -24,6 +24,82 @@ import (
|
||||
"go.uber.org/multierr"
|
||||
)
|
||||
|
||||
func TestPrepareCluster(t *testing.T) {
|
||||
qemuVars := &QEMUVariables{
|
||||
CommonVariables: CommonVariables{
|
||||
Name: "name",
|
||||
CountControlPlanes: 1,
|
||||
CountWorkers: 2,
|
||||
StateDiskSizeGB: 11,
|
||||
},
|
||||
CPUCount: 1,
|
||||
MemorySizeMiB: 1024,
|
||||
ImagePath: "path",
|
||||
ImageFormat: "format",
|
||||
MetadataAPIImage: "api",
|
||||
}
|
||||
|
||||
testCases := map[string]struct {
|
||||
provider cloudprovider.Provider
|
||||
vars Variables
|
||||
fs afero.Fs
|
||||
partiallyExtracted bool
|
||||
wantErr bool
|
||||
}{
|
||||
"qemu": {
|
||||
provider: cloudprovider.QEMU,
|
||||
vars: qemuVars,
|
||||
fs: afero.NewMemMapFs(),
|
||||
wantErr: false,
|
||||
},
|
||||
"no vars": {
|
||||
provider: cloudprovider.QEMU,
|
||||
fs: afero.NewMemMapFs(),
|
||||
wantErr: true,
|
||||
},
|
||||
"continue on partially extracted": {
|
||||
provider: cloudprovider.QEMU,
|
||||
vars: qemuVars,
|
||||
fs: afero.NewMemMapFs(),
|
||||
partiallyExtracted: true,
|
||||
wantErr: false,
|
||||
},
|
||||
"prepare workspace fails": {
|
||||
provider: cloudprovider.QEMU,
|
||||
vars: qemuVars,
|
||||
fs: afero.NewReadOnlyFs(afero.NewMemMapFs()),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
c := &Client{
|
||||
tf: &stubTerraform{},
|
||||
file: file.NewHandler(tc.fs),
|
||||
workingDir: constants.TerraformWorkingDir,
|
||||
}
|
||||
|
||||
err := c.PrepareWorkspace(tc.provider, tc.vars)
|
||||
|
||||
// Test case: Check if we can continue to create on an incomplete workspace.
|
||||
if tc.partiallyExtracted {
|
||||
require.NoError(c.file.Remove(filepath.Join(c.workingDir, "main.tf")))
|
||||
err = c.PrepareWorkspace(tc.provider, tc.vars)
|
||||
}
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
assert.NoError(err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateCluster(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
newTestState := func() *tfjson.State {
|
||||
@ -67,6 +143,7 @@ func TestCreateCluster(t *testing.T) {
|
||||
},
|
||||
"init fails": {
|
||||
provider: cloudprovider.QEMU,
|
||||
vars: qemuVars,
|
||||
tf: &stubTerraform{initErr: someErr},
|
||||
fs: afero.NewMemMapFs(),
|
||||
wantErr: true,
|
||||
@ -111,17 +188,12 @@ func TestCreateCluster(t *testing.T) {
|
||||
fs: afero.NewMemMapFs(),
|
||||
wantErr: true,
|
||||
},
|
||||
"prepare workspace fails": {
|
||||
provider: cloudprovider.QEMU,
|
||||
tf: &stubTerraform{showState: newTestState()},
|
||||
fs: afero.NewReadOnlyFs(afero.NewMemMapFs()),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
c := &Client{
|
||||
tf: tc.tf,
|
||||
@ -129,7 +201,8 @@ func TestCreateCluster(t *testing.T) {
|
||||
workingDir: constants.TerraformWorkingDir,
|
||||
}
|
||||
|
||||
ip, err := c.CreateCluster(context.Background(), tc.provider, tc.vars)
|
||||
require.NoError(c.PrepareWorkspace(tc.provider, tc.vars))
|
||||
ip, err := c.CreateCluster(context.Background())
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
|
Loading…
Reference in New Issue
Block a user