cli: create or read state file during constellation create (#2470)

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2023-10-17 14:37:09 +02:00 committed by GitHub
parent 1a141c3972
commit fe7e16e1cc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 128 additions and 88 deletions

View file

@ -191,7 +191,7 @@ func (c *createCmd) create(cmd *cobra.Command, creator cloudCreator, fileHandler
} }
c.log.Debugf("Successfully created the cloud resources for the cluster") c.log.Debugf("Successfully created the cloud resources for the cluster")
stateFile, err := state.ReadFromFile(fileHandler, constants.StateFilename) stateFile, err := state.CreateOrRead(fileHandler, constants.StateFilename)
if err != nil { if err != nil {
return fmt.Errorf("reading state file: %w", err) return fmt.Errorf("reading state file: %w", err)
} }

View file

@ -72,14 +72,14 @@ func TestCreate(t *testing.T) {
}, },
"interactive abort": { "interactive abort": {
setupFs: fsWithDefaultConfigAndState, setupFs: fsWithDefaultConfigAndState,
creator: &stubCloudCreator{}, creator: &stubCloudCreator{state: infraState},
provider: cloudprovider.GCP, provider: cloudprovider.GCP,
stdin: "no\n", stdin: "no\n",
wantAbort: true, wantAbort: true,
}, },
"interactive error": { "interactive error": {
setupFs: fsWithDefaultConfigAndState, setupFs: fsWithDefaultConfigAndState,
creator: &stubCloudCreator{}, creator: &stubCloudCreator{state: infraState},
provider: cloudprovider.GCP, provider: cloudprovider.GCP,
stdin: "foo\nfoo\nfoo\n", stdin: "foo\nfoo\nfoo\n",
wantErr: true, wantErr: true,
@ -92,7 +92,7 @@ func TestCreate(t *testing.T) {
require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, defaultConfigWithExpectedMeasurements(t, config.Default(), csp))) require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, defaultConfigWithExpectedMeasurements(t, config.Default(), csp)))
return fs return fs
}, },
creator: &stubCloudCreator{}, creator: &stubCloudCreator{state: infraState},
provider: cloudprovider.GCP, provider: cloudprovider.GCP,
yesFlag: true, yesFlag: true,
wantErr: true, wantErr: true,
@ -105,24 +105,23 @@ func TestCreate(t *testing.T) {
require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, defaultConfigWithExpectedMeasurements(t, config.Default(), csp))) require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, defaultConfigWithExpectedMeasurements(t, config.Default(), csp)))
return fs return fs
}, },
creator: &stubCloudCreator{}, creator: &stubCloudCreator{state: infraState},
provider: cloudprovider.GCP, provider: cloudprovider.GCP,
yesFlag: true, yesFlag: true,
wantErr: true, wantErr: true,
}, },
"config does not exist": { "config does not exist": {
setupFs: func(a *require.Assertions, p cloudprovider.Provider) afero.Fs { return afero.NewMemMapFs() }, setupFs: func(a *require.Assertions, p cloudprovider.Provider) afero.Fs { return afero.NewMemMapFs() },
creator: &stubCloudCreator{}, creator: &stubCloudCreator{state: infraState},
provider: cloudprovider.GCP, provider: cloudprovider.GCP,
yesFlag: true, yesFlag: true,
wantErr: true, wantErr: true,
}, },
"state file does not exist": { "state file does not exist": {
setupFs: fsWithoutState, setupFs: fsWithoutState,
creator: &stubCloudCreator{}, creator: &stubCloudCreator{state: infraState},
provider: cloudprovider.GCP, provider: cloudprovider.GCP,
yesFlag: true, yesFlag: true,
wantErr: true,
}, },
"create error": { "create error": {
setupFs: fsWithDefaultConfigAndState, setupFs: fsWithDefaultConfigAndState,
@ -131,14 +130,14 @@ func TestCreate(t *testing.T) {
yesFlag: true, yesFlag: true,
wantErr: true, wantErr: true,
}, },
"write id file error": { "write state file error": {
setupFs: func(require *require.Assertions, csp cloudprovider.Provider) afero.Fs { setupFs: func(require *require.Assertions, csp cloudprovider.Provider) afero.Fs {
fs := afero.NewMemMapFs() fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs) fileHandler := file.NewHandler(fs)
require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, defaultConfigWithExpectedMeasurements(t, config.Default(), csp))) require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, defaultConfigWithExpectedMeasurements(t, config.Default(), csp)))
return afero.NewReadOnlyFs(fs) return afero.NewReadOnlyFs(fs)
}, },
creator: &stubCloudCreator{}, creator: &stubCloudCreator{state: infraState},
provider: cloudprovider.GCP, provider: cloudprovider.GCP,
yesFlag: true, yesFlag: true,
wantErr: true, wantErr: true,

View file

@ -14,7 +14,9 @@ package state
import ( import (
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"os"
"dario.cat/mergo" "dario.cat/mergo"
"github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/file"
@ -34,6 +36,20 @@ func ReadFromFile(fileHandler file.Handler, path string) (*State, error) {
return state, nil return state, nil
} }
// CreateOrRead reads the state file at the given path, if it exists, and returns the state.
// If the file does not exist, a new state is created and written to disk.
func CreateOrRead(fileHandler file.Handler, path string) (*State, error) {
state := &State{}
if err := fileHandler.ReadYAML(path, &state); err != nil {
if !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("reading state file: %w", err)
}
state = New()
return state, state.WriteToFile(fileHandler, path)
}
return state, nil
}
// State describe the entire state to describe a Constellation cluster. // State describe the entire state to describe a Constellation cluster.
type State struct { type State struct {
// description: | // description: |

View file

@ -18,68 +18,63 @@ import (
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
var defaultState = &State{ func defaultState() *State {
Version: "v1", return &State{
Infrastructure: Infrastructure{ Version: "v1",
UID: "123", Infrastructure: Infrastructure{
ClusterEndpoint: "test-cluster-endpoint", UID: "123",
InitSecret: []byte{0x41}, ClusterEndpoint: "test-cluster-endpoint",
APIServerCertSANs: []string{ InitSecret: []byte{0x41},
"api-server-cert-san-test", APIServerCertSANs: []string{
"api-server-cert-san-test-2", "api-server-cert-san-test",
"api-server-cert-san-test-2",
},
Azure: &Azure{
ResourceGroup: "test-rg",
SubscriptionID: "test-sub",
NetworkSecurityGroupName: "test-nsg",
LoadBalancerName: "test-lb",
UserAssignedIdentity: "test-uami",
AttestationURL: "test-maaUrl",
},
GCP: &GCP{
ProjectID: "test-project",
IPCidrNode: "test-cidr-node",
IPCidrPod: "test-cidr-pod",
},
}, },
Azure: &Azure{ ClusterValues: ClusterValues{
ResourceGroup: "test-rg", ClusterID: "test-cluster-id",
SubscriptionID: "test-sub", OwnerID: "test-owner-id",
NetworkSecurityGroupName: "test-nsg", MeasurementSalt: []byte{0x41},
LoadBalancerName: "test-lb",
UserAssignedIdentity: "test-uami",
AttestationURL: "test-maaUrl",
}, },
GCP: &GCP{ }
ProjectID: "test-project",
IPCidrNode: "test-cidr-node",
IPCidrPod: "test-cidr-pod",
},
},
ClusterValues: ClusterValues{
ClusterID: "test-cluster-id",
OwnerID: "test-owner-id",
MeasurementSalt: []byte{0x41},
},
} }
func TestWriteToFile(t *testing.T) { func TestWriteToFile(t *testing.T) {
prepareFs := func(existingFiles ...string) file.Handler {
fs := afero.NewMemMapFs()
fh := file.NewHandler(fs)
for _, name := range existingFiles {
if err := fh.Write(name, []byte{0x41}); err != nil {
t.Fatalf("failed to create file %s: %v", name, err)
}
}
return fh
}
testCases := map[string]struct { testCases := map[string]struct {
state *State state *State
fh file.Handler fh file.Handler
wantErr bool wantErr bool
}{ }{
"success": { "success": {
state: defaultState, state: defaultState(),
fh: prepareFs(), fh: file.NewHandler(afero.NewMemMapFs()),
}, },
"overwrite": { "overwrite": {
state: defaultState, state: defaultState(),
fh: prepareFs(constants.StateFilename), fh: func() file.Handler {
fs := file.NewHandler(afero.NewMemMapFs())
require.NoError(t, fs.Write(constants.StateFilename, []byte{0x41}))
return fs
}(),
}, },
"empty state": { "empty state": {
state: &State{}, state: &State{},
fh: prepareFs(), fh: file.NewHandler(afero.NewMemMapFs()),
}, },
"rofs": { "rofs": {
state: defaultState, state: defaultState(),
fh: file.NewHandler(afero.NewReadOnlyFs(afero.NewMemMapFs())), fh: file.NewHandler(afero.NewReadOnlyFs(afero.NewMemMapFs())),
wantErr: true, wantErr: true,
}, },
@ -88,6 +83,7 @@ func TestWriteToFile(t *testing.T) {
for name, tc := range testCases { for name, tc := range testCases {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
require := require.New(t)
err := tc.state.WriteToFile(tc.fh, constants.StateFilename) err := tc.state.WriteToFile(tc.fh, constants.StateFilename)
@ -95,72 +91,59 @@ func TestWriteToFile(t *testing.T) {
assert.Error(err) assert.Error(err)
} else { } else {
assert.NoError(err) assert.NoError(err)
assert.Equal(mustMarshalYaml(t, tc.state), mustReadFromFile(t, tc.fh)) assert.YAMLEq(mustMarshalYaml(require, tc.state), mustReadFromFile(require, tc.fh))
} }
}) })
} }
} }
func TestReadFromFile(t *testing.T) { func TestReadFromFile(t *testing.T) {
prepareFs := func(existingFiles map[string][]byte) file.Handler {
fs := afero.NewMemMapFs()
fh := file.NewHandler(fs)
for name, content := range existingFiles {
if err := fh.Write(name, content); err != nil {
t.Fatalf("failed to create file %s: %v", name, err)
}
}
return fh
}
testCases := map[string]struct { testCases := map[string]struct {
existingFiles map[string][]byte fs file.Handler
wantErr bool wantState *State
wantErr bool
}{ }{
"success": { "success": {
existingFiles: map[string][]byte{ fs: func() file.Handler {
constants.StateFilename: mustMarshalYaml(t, defaultState), fh := file.NewHandler(afero.NewMemMapFs())
}, require.NoError(t, fh.WriteYAML(constants.StateFilename, defaultState()))
return fh
}(),
wantState: defaultState(),
}, },
"no state file present": { "no state file present": {
existingFiles: map[string][]byte{}, fs: file.NewHandler(afero.NewMemMapFs()),
wantErr: true, wantErr: true,
}, },
} }
for name, tc := range testCases { for name, tc := range testCases {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
fh := prepareFs(tc.existingFiles) require := require.New(t)
state, err := ReadFromFile(fh, constants.StateFilename) state, err := ReadFromFile(tc.fs, constants.StateFilename)
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
} else { } else {
assert.NoError(err) assert.NoError(err)
assert.Equal(tc.existingFiles[constants.StateFilename], mustMarshalYaml(t, state)) assert.YAMLEq(mustMarshalYaml(require, tc.wantState), mustMarshalYaml(require, state))
} }
}) })
} }
} }
func mustMarshalYaml(t *testing.T, v any) []byte { func mustMarshalYaml(require *require.Assertions, v any) string {
t.Helper()
b, err := encoder.NewEncoder(v).Encode() b, err := encoder.NewEncoder(v).Encode()
if err != nil { require.NoError(err)
t.Fatalf("failed to marshal yaml: %v", err) return string(b)
}
return b
} }
func mustReadFromFile(t *testing.T, fh file.Handler) []byte { func mustReadFromFile(require *require.Assertions, fh file.Handler) string {
t.Helper()
b, err := fh.Read(constants.StateFilename) b, err := fh.Read(constants.StateFilename)
if err != nil { require.NoError(err)
t.Fatalf("failed to read file: %v", err) return string(b)
}
return b
} }
func TestMerge(t *testing.T) { func TestMerge(t *testing.T) {
@ -419,3 +402,45 @@ func TestMarshalUnmarshalHexBytes(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, in, actual2) assert.Equal(t, in, actual2)
} }
func TestCreateOrRead(t *testing.T) {
testCases := map[string]struct {
fs file.Handler
wantState *State
wantErr bool
}{
"file exists": {
fs: func() file.Handler {
fh := file.NewHandler(afero.NewMemMapFs())
require.NoError(t, fh.WriteYAML(constants.StateFilename, defaultState()))
return fh
}(),
wantState: defaultState(),
},
"file does not exist": {
fs: file.NewHandler(afero.NewMemMapFs()),
wantState: New(),
},
"unable to write file": {
fs: file.NewHandler(afero.NewReadOnlyFs(afero.NewMemMapFs())),
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
state, err := CreateOrRead(tc.fs, constants.StateFilename)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
assert.YAMLEq(mustMarshalYaml(require, tc.wantState), mustMarshalYaml(require, state))
assert.YAMLEq(mustMarshalYaml(require, tc.wantState), mustReadFromFile(require, tc.fs))
})
}
}