From fe7e16e1cc60f03fd6ddd256309478c6dfb94603 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Wei=C3=9Fe?= <66256922+daniel-weisse@users.noreply.github.com> Date: Tue, 17 Oct 2023 14:37:09 +0200 Subject: [PATCH] cli: create or read state file during `constellation create` (#2470) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Daniel Weiße --- cli/internal/cmd/create.go | 2 +- cli/internal/cmd/create_test.go | 17 ++- cli/internal/state/state.go | 16 +++ cli/internal/state/state_test.go | 181 ++++++++++++++++++------------- 4 files changed, 128 insertions(+), 88 deletions(-) diff --git a/cli/internal/cmd/create.go b/cli/internal/cmd/create.go index c9be99d3c..6f92c18de 100644 --- a/cli/internal/cmd/create.go +++ b/cli/internal/cmd/create.go @@ -191,7 +191,7 @@ func (c *createCmd) create(cmd *cobra.Command, creator cloudCreator, fileHandler } c.log.Debugf("Successfully created the cloud resources for the cluster") - stateFile, err := state.ReadFromFile(fileHandler, constants.StateFilename) + stateFile, err := state.CreateOrRead(fileHandler, constants.StateFilename) if err != nil { return fmt.Errorf("reading state file: %w", err) } diff --git a/cli/internal/cmd/create_test.go b/cli/internal/cmd/create_test.go index 5df2f7fa7..22216c7d6 100644 --- a/cli/internal/cmd/create_test.go +++ b/cli/internal/cmd/create_test.go @@ -72,14 +72,14 @@ func TestCreate(t *testing.T) { }, "interactive abort": { setupFs: fsWithDefaultConfigAndState, - creator: &stubCloudCreator{}, + creator: &stubCloudCreator{state: infraState}, provider: cloudprovider.GCP, stdin: "no\n", wantAbort: true, }, "interactive error": { setupFs: fsWithDefaultConfigAndState, - creator: &stubCloudCreator{}, + creator: &stubCloudCreator{state: infraState}, provider: cloudprovider.GCP, stdin: "foo\nfoo\nfoo\n", wantErr: true, @@ -92,7 +92,7 @@ func TestCreate(t *testing.T) { require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, defaultConfigWithExpectedMeasurements(t, config.Default(), csp))) return fs }, - creator: &stubCloudCreator{}, + creator: &stubCloudCreator{state: infraState}, provider: cloudprovider.GCP, yesFlag: true, wantErr: true, @@ -105,24 +105,23 @@ func TestCreate(t *testing.T) { require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, defaultConfigWithExpectedMeasurements(t, config.Default(), csp))) return fs }, - creator: &stubCloudCreator{}, + creator: &stubCloudCreator{state: infraState}, provider: cloudprovider.GCP, yesFlag: true, wantErr: true, }, "config does not exist": { setupFs: func(a *require.Assertions, p cloudprovider.Provider) afero.Fs { return afero.NewMemMapFs() }, - creator: &stubCloudCreator{}, + creator: &stubCloudCreator{state: infraState}, provider: cloudprovider.GCP, yesFlag: true, wantErr: true, }, "state file does not exist": { setupFs: fsWithoutState, - creator: &stubCloudCreator{}, + creator: &stubCloudCreator{state: infraState}, provider: cloudprovider.GCP, yesFlag: true, - wantErr: true, }, "create error": { setupFs: fsWithDefaultConfigAndState, @@ -131,14 +130,14 @@ func TestCreate(t *testing.T) { yesFlag: true, wantErr: true, }, - "write id file error": { + "write state file error": { setupFs: func(require *require.Assertions, csp cloudprovider.Provider) afero.Fs { fs := afero.NewMemMapFs() fileHandler := file.NewHandler(fs) require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, defaultConfigWithExpectedMeasurements(t, config.Default(), csp))) return afero.NewReadOnlyFs(fs) }, - creator: &stubCloudCreator{}, + creator: &stubCloudCreator{state: infraState}, provider: cloudprovider.GCP, yesFlag: true, wantErr: true, diff --git a/cli/internal/state/state.go b/cli/internal/state/state.go index 14bbb59cc..b4409eeee 100644 --- a/cli/internal/state/state.go +++ b/cli/internal/state/state.go @@ -14,7 +14,9 @@ package state import ( "encoding/hex" + "errors" "fmt" + "os" "dario.cat/mergo" "github.com/edgelesssys/constellation/v2/internal/file" @@ -34,6 +36,20 @@ func ReadFromFile(fileHandler file.Handler, path string) (*State, error) { return state, nil } +// CreateOrRead reads the state file at the given path, if it exists, and returns the state. +// If the file does not exist, a new state is created and written to disk. +func CreateOrRead(fileHandler file.Handler, path string) (*State, error) { + state := &State{} + if err := fileHandler.ReadYAML(path, &state); err != nil { + if !errors.Is(err, os.ErrNotExist) { + return nil, fmt.Errorf("reading state file: %w", err) + } + state = New() + return state, state.WriteToFile(fileHandler, path) + } + return state, nil +} + // State describe the entire state to describe a Constellation cluster. type State struct { // description: | diff --git a/cli/internal/state/state_test.go b/cli/internal/state/state_test.go index 891b031eb..0eb567ef2 100644 --- a/cli/internal/state/state_test.go +++ b/cli/internal/state/state_test.go @@ -18,68 +18,63 @@ import ( "gopkg.in/yaml.v3" ) -var defaultState = &State{ - Version: "v1", - Infrastructure: Infrastructure{ - UID: "123", - ClusterEndpoint: "test-cluster-endpoint", - InitSecret: []byte{0x41}, - APIServerCertSANs: []string{ - "api-server-cert-san-test", - "api-server-cert-san-test-2", +func defaultState() *State { + return &State{ + Version: "v1", + Infrastructure: Infrastructure{ + UID: "123", + ClusterEndpoint: "test-cluster-endpoint", + InitSecret: []byte{0x41}, + APIServerCertSANs: []string{ + "api-server-cert-san-test", + "api-server-cert-san-test-2", + }, + Azure: &Azure{ + ResourceGroup: "test-rg", + SubscriptionID: "test-sub", + NetworkSecurityGroupName: "test-nsg", + LoadBalancerName: "test-lb", + UserAssignedIdentity: "test-uami", + AttestationURL: "test-maaUrl", + }, + GCP: &GCP{ + ProjectID: "test-project", + IPCidrNode: "test-cidr-node", + IPCidrPod: "test-cidr-pod", + }, }, - Azure: &Azure{ - ResourceGroup: "test-rg", - SubscriptionID: "test-sub", - NetworkSecurityGroupName: "test-nsg", - LoadBalancerName: "test-lb", - UserAssignedIdentity: "test-uami", - AttestationURL: "test-maaUrl", + ClusterValues: ClusterValues{ + ClusterID: "test-cluster-id", + OwnerID: "test-owner-id", + MeasurementSalt: []byte{0x41}, }, - GCP: &GCP{ - ProjectID: "test-project", - IPCidrNode: "test-cidr-node", - IPCidrPod: "test-cidr-pod", - }, - }, - ClusterValues: ClusterValues{ - ClusterID: "test-cluster-id", - OwnerID: "test-owner-id", - MeasurementSalt: []byte{0x41}, - }, + } } func TestWriteToFile(t *testing.T) { - prepareFs := func(existingFiles ...string) file.Handler { - fs := afero.NewMemMapFs() - fh := file.NewHandler(fs) - for _, name := range existingFiles { - if err := fh.Write(name, []byte{0x41}); err != nil { - t.Fatalf("failed to create file %s: %v", name, err) - } - } - return fh - } - testCases := map[string]struct { state *State fh file.Handler wantErr bool }{ "success": { - state: defaultState, - fh: prepareFs(), + state: defaultState(), + fh: file.NewHandler(afero.NewMemMapFs()), }, "overwrite": { - state: defaultState, - fh: prepareFs(constants.StateFilename), + state: defaultState(), + fh: func() file.Handler { + fs := file.NewHandler(afero.NewMemMapFs()) + require.NoError(t, fs.Write(constants.StateFilename, []byte{0x41})) + return fs + }(), }, "empty state": { state: &State{}, - fh: prepareFs(), + fh: file.NewHandler(afero.NewMemMapFs()), }, "rofs": { - state: defaultState, + state: defaultState(), fh: file.NewHandler(afero.NewReadOnlyFs(afero.NewMemMapFs())), wantErr: true, }, @@ -88,6 +83,7 @@ func TestWriteToFile(t *testing.T) { for name, tc := range testCases { t.Run(name, func(t *testing.T) { assert := assert.New(t) + require := require.New(t) err := tc.state.WriteToFile(tc.fh, constants.StateFilename) @@ -95,72 +91,59 @@ func TestWriteToFile(t *testing.T) { assert.Error(err) } else { assert.NoError(err) - assert.Equal(mustMarshalYaml(t, tc.state), mustReadFromFile(t, tc.fh)) + assert.YAMLEq(mustMarshalYaml(require, tc.state), mustReadFromFile(require, tc.fh)) } }) } } func TestReadFromFile(t *testing.T) { - prepareFs := func(existingFiles map[string][]byte) file.Handler { - fs := afero.NewMemMapFs() - fh := file.NewHandler(fs) - for name, content := range existingFiles { - if err := fh.Write(name, content); err != nil { - t.Fatalf("failed to create file %s: %v", name, err) - } - } - return fh - } - testCases := map[string]struct { - existingFiles map[string][]byte - wantErr bool + fs file.Handler + wantState *State + wantErr bool }{ "success": { - existingFiles: map[string][]byte{ - constants.StateFilename: mustMarshalYaml(t, defaultState), - }, + fs: func() file.Handler { + fh := file.NewHandler(afero.NewMemMapFs()) + require.NoError(t, fh.WriteYAML(constants.StateFilename, defaultState())) + return fh + }(), + wantState: defaultState(), }, "no state file present": { - existingFiles: map[string][]byte{}, - wantErr: true, + fs: file.NewHandler(afero.NewMemMapFs()), + wantErr: true, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { assert := assert.New(t) - fh := prepareFs(tc.existingFiles) + require := require.New(t) - state, err := ReadFromFile(fh, constants.StateFilename) + state, err := ReadFromFile(tc.fs, constants.StateFilename) if tc.wantErr { assert.Error(err) } else { assert.NoError(err) - assert.Equal(tc.existingFiles[constants.StateFilename], mustMarshalYaml(t, state)) + assert.YAMLEq(mustMarshalYaml(require, tc.wantState), mustMarshalYaml(require, state)) } }) } } -func mustMarshalYaml(t *testing.T, v any) []byte { - t.Helper() +func mustMarshalYaml(require *require.Assertions, v any) string { b, err := encoder.NewEncoder(v).Encode() - if err != nil { - t.Fatalf("failed to marshal yaml: %v", err) - } - return b + require.NoError(err) + return string(b) } -func mustReadFromFile(t *testing.T, fh file.Handler) []byte { - t.Helper() +func mustReadFromFile(require *require.Assertions, fh file.Handler) string { b, err := fh.Read(constants.StateFilename) - if err != nil { - t.Fatalf("failed to read file: %v", err) - } - return b + require.NoError(err) + return string(b) } func TestMerge(t *testing.T) { @@ -419,3 +402,45 @@ func TestMarshalUnmarshalHexBytes(t *testing.T) { require.NoError(t, err) assert.Equal(t, in, actual2) } + +func TestCreateOrRead(t *testing.T) { + testCases := map[string]struct { + fs file.Handler + wantState *State + wantErr bool + }{ + "file exists": { + fs: func() file.Handler { + fh := file.NewHandler(afero.NewMemMapFs()) + require.NoError(t, fh.WriteYAML(constants.StateFilename, defaultState())) + return fh + }(), + wantState: defaultState(), + }, + "file does not exist": { + fs: file.NewHandler(afero.NewMemMapFs()), + wantState: New(), + }, + "unable to write file": { + fs: file.NewHandler(afero.NewReadOnlyFs(afero.NewMemMapFs())), + wantErr: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + + require := require.New(t) + state, err := CreateOrRead(tc.fs, constants.StateFilename) + + if tc.wantErr { + assert.Error(err) + return + } + assert.NoError(err) + assert.YAMLEq(mustMarshalYaml(require, tc.wantState), mustMarshalYaml(require, state)) + assert.YAMLEq(mustMarshalYaml(require, tc.wantState), mustReadFromFile(require, tc.fs)) + }) + } +}