diff --git a/cli/cmd/create_aws.go b/cli/cmd/create_aws.go index 0c1a83cf9..cbc165f1e 100644 --- a/cli/cmd/create_aws.go +++ b/cli/cmd/create_aws.go @@ -112,7 +112,7 @@ func createAWS(cmd *cobra.Command, cl ec2client, fileHandler file.Handler, confi if err != nil { return err } - if err := fileHandler.WriteJSON(*config.StatePath, stat, false); err != nil { + if err := fileHandler.WriteJSON(*config.StatePath, stat, file.OptNone); err != nil { return err } diff --git a/cli/cmd/create_aws_test.go b/cli/cmd/create_aws_test.go index 48f47078f..b71aadcaa 100644 --- a/cli/cmd/create_aws_test.go +++ b/cli/cmd/create_aws_test.go @@ -138,7 +138,7 @@ func TestCreateAWS(t *testing.T) { fs := afero.NewMemMapFs() fileHandler := file.NewHandler(fs) if tc.existingState != nil { - require.NoError(fileHandler.WriteJSON(*config.StatePath, *tc.existingState, false)) + require.NoError(fileHandler.WriteJSON(*config.StatePath, *tc.existingState, file.OptNone)) } err := createAWS(cmd, tc.client, fileHandler, config, "xlarge", "name", 3) diff --git a/cli/cmd/create_azure.go b/cli/cmd/create_azure.go index e45c653a3..6296ec75a 100644 --- a/cli/cmd/create_azure.go +++ b/cli/cmd/create_azure.go @@ -120,7 +120,7 @@ func createAzure(cmd *cobra.Command, cl azureclient, fileHandler file.Handler, c if err != nil { return err } - if err := fileHandler.WriteJSON(*config.StatePath, stat, false); err != nil { + if err := fileHandler.WriteJSON(*config.StatePath, stat, file.OptNone); err != nil { return err } diff --git a/cli/cmd/create_azure_test.go b/cli/cmd/create_azure_test.go index a558a208a..38921c947 100644 --- a/cli/cmd/create_azure_test.go +++ b/cli/cmd/create_azure_test.go @@ -155,7 +155,7 @@ func TestCreateAzure(t *testing.T) { fs := afero.NewMemMapFs() fileHandler := file.NewHandler(fs) if tc.existingState != nil { - require.NoError(fileHandler.WriteJSON(*config.StatePath, *tc.existingState, false)) + require.NoError(fileHandler.WriteJSON(*config.StatePath, *tc.existingState, file.OptNone)) } err := createAzure(cmd, tc.client, fileHandler, config, "Standard_D2s_v3", 3, 2) diff --git a/cli/cmd/create_gcp.go b/cli/cmd/create_gcp.go index 23a93f3c4..a971a0fb0 100644 --- a/cli/cmd/create_gcp.go +++ b/cli/cmd/create_gcp.go @@ -114,7 +114,7 @@ func createGCP(cmd *cobra.Command, cl gcpclient, fileHandler file.Handler, confi return err } - if err := fileHandler.WriteJSON(*config.StatePath, stat, false); err != nil { + if err := fileHandler.WriteJSON(*config.StatePath, stat, file.OptNone); err != nil { return err } diff --git a/cli/cmd/create_gcp_test.go b/cli/cmd/create_gcp_test.go index eaa112611..0c8f6ece0 100644 --- a/cli/cmd/create_gcp_test.go +++ b/cli/cmd/create_gcp_test.go @@ -153,7 +153,7 @@ func TestCreateGCP(t *testing.T) { fs := afero.NewMemMapFs() fileHandler := file.NewHandler(fs) if tc.existingState != nil { - require.NoError(fileHandler.WriteJSON(*config.StatePath, *tc.existingState, false)) + require.NoError(fileHandler.WriteJSON(*config.StatePath, *tc.existingState, file.OptNone)) } err := createGCP(cmd, tc.client, fileHandler, config, "n2d-standard-2", 3, 2) diff --git a/cli/cmd/create_test.go b/cli/cmd/create_test.go index dbf495cc4..2baef702f 100644 --- a/cli/cmd/create_test.go +++ b/cli/cmd/create_test.go @@ -49,7 +49,7 @@ func TestCheckDirClean(t *testing.T) { require := require.New(t) for _, f := range tc.existingFiles { - require.NoError(tc.fileHandler.Write(f, []byte{1, 2, 3}, false)) + require.NoError(tc.fileHandler.Write(f, []byte{1, 2, 3}, file.OptNone)) } err := checkDirClean(tc.fileHandler, config) diff --git a/cli/cmd/init.go b/cli/cmd/init.go index 211cb7da2..c26739de1 100644 --- a/cli/cmd/init.go +++ b/cli/cmd/init.go @@ -101,7 +101,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser if err != nil { return err } - if err := fileHandler.WriteJSON(*config.StatePath, stat, true); err != nil { + if err := fileHandler.WriteJSON(*config.StatePath, stat, file.OptOverwrite); err != nil { return err } @@ -228,7 +228,7 @@ func writeWGQuickFile(fileHandler file.Handler, config *config.Config, vpnHandle if err != nil { return err } - return fileHandler.Write(*config.WGQuickConfigPath, data, false) + return fileHandler.Write(*config.WGQuickConfigPath, data, file.OptNone) } func (r activationResult) writeOutput(wr io.Writer, fileHandler file.Handler, config *config.Config) error { @@ -245,7 +245,7 @@ func (r activationResult) writeOutput(wr io.Writer, fileHandler file.Handler, co tw.Flush() fmt.Fprintln(wr) - if err := fileHandler.Write(*config.AdminConfPath, []byte(r.kubeconfig), false); err != nil { + if err := fileHandler.Write(*config.AdminConfPath, []byte(r.kubeconfig), file.OptNone); err != nil { return fmt.Errorf("write kubeconfig: %w", err) } @@ -360,7 +360,7 @@ func readOrGeneratedMasterSecret(w io.Writer, fileHandler file.Handler, filename if err != nil { return nil, err } - if err := fileHandler.Write(*config.MasterSecretPath, []byte(base64.StdEncoding.EncodeToString(masterSecret)), false); err != nil { + if err := fileHandler.Write(*config.MasterSecretPath, []byte(base64.StdEncoding.EncodeToString(masterSecret)), file.OptNone); err != nil { return nil, err } fmt.Fprintf(w, "Your Constellation master secret was successfully written to ./%s\n", *config.MasterSecretPath) diff --git a/cli/cmd/init_test.go b/cli/cmd/init_test.go index 3394e375e..b21a5bd93 100644 --- a/cli/cmd/init_test.go +++ b/cli/cmd/init_test.go @@ -336,7 +336,7 @@ func TestInitialize(t *testing.T) { cmd.SetErr(&errOut) fs := afero.NewMemMapFs() fileHandler := file.NewHandler(fs) - require.NoError(fileHandler.WriteJSON(*config.StatePath, tc.existingState, false)) + require.NoError(fileHandler.WriteJSON(*config.StatePath, tc.existingState, file.OptNone)) // Write key file to filesystem and set path in flag. require.NoError(afero.Afero{Fs: fs}.WriteFile("privK", []byte(tc.privKey), 0o600)) @@ -444,7 +444,7 @@ func TestReadOrGenerateVPNKey(t *testing.T) { testKey := []byte(base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting"))) fileHandler := file.NewHandler(afero.NewMemMapFs()) - require.NoError(fileHandler.Write("testKey", testKey, false)) + require.NoError(fileHandler.Write("testKey", testKey, file.OptNone)) privK, pubK, err := readOrGenerateVPNKey(fileHandler, "testKey") assert.NoError(err) @@ -525,7 +525,7 @@ func TestReadOrGeneratedMasterSecret(t *testing.T) { config := config.Default() if tc.createFile { - require.NoError(fileHandler.Write(tc.filename, []byte(tc.filecontent), false)) + require.NoError(fileHandler.Write(tc.filename, []byte(tc.filecontent), file.OptNone)) } var out bytes.Buffer @@ -697,7 +697,7 @@ func TestAutoscaleFlag(t *testing.T) { fs := afero.NewMemMapFs() fileHandler := file.NewHandler(fs) vpnHandler := stubVPNHandler{} - require.NoError(fileHandler.WriteJSON(*config.StatePath, tc.existingState, false)) + require.NoError(fileHandler.WriteJSON(*config.StatePath, tc.existingState, file.OptNone)) // Write key file to filesystem and set path in flag. require.NoError(afero.Afero{Fs: fs}.WriteFile("privK", []byte(tc.privKey), 0o600)) diff --git a/cli/file/file.go b/cli/file/file.go index fe90149f0..d9abbd3c7 100644 --- a/cli/file/file.go +++ b/cli/file/file.go @@ -9,10 +9,28 @@ import ( "io" "io/fs" "os" + "path" "github.com/spf13/afero" ) +// Option is a bitmask of options for file operations. +type Option uint + +// Has determines if a set of options contains the given options. +func (o Option) Has(op Option) bool { + return o&op == op +} + +const ( + // OptNone is a no-op. + OptNone Option = 1 << iota / 2 + // OptOverwrite overwrites an existing file. + OptOverwrite + // OptMkdirAll creates the path to the file. + OptMkdirAll +) + // Handler handles file interaction. type Handler struct { fs *afero.Afero @@ -35,11 +53,14 @@ func (h *Handler) Read(name string) ([]byte, error) { } // Write writes the data bytes into the file with the given name. -// If a file already exists at path and overwrite is true, the file will be -// overwritten. Otherwise, an error is returned. -func (h *Handler) Write(name string, data []byte, overwrite bool) error { +func (h *Handler) Write(name string, data []byte, options Option) error { + if options.Has(OptMkdirAll) { + if err := h.fs.MkdirAll(path.Dir(name), os.ModePerm); err != nil { + return err + } + } flags := os.O_WRONLY | os.O_CREATE | os.O_EXCL - if overwrite { + if options.Has(OptOverwrite) { flags = os.O_WRONLY | os.O_CREATE | os.O_TRUNC } file, err := h.fs.OpenFile(name, flags, 0o644) @@ -64,14 +85,12 @@ func (h *Handler) ReadJSON(name string, content interface{}) error { } // WriteJSON marshals the content interface to JSON and writes it to the path with the given name. -// If a file already exists and overwrite is true, the file will be -// overwritten. Otherwise, an error is returned. -func (h *Handler) WriteJSON(name string, content interface{}, overwrite bool) error { +func (h *Handler) WriteJSON(name string, content interface{}, options Option) error { jsonData, err := json.MarshalIndent(content, "", "\t") if err != nil { return err } - return h.Write(name, jsonData, overwrite) + return h.Write(name, jsonData, options) } // Remove deletes the file with the given name. diff --git a/cli/file/file_test.go b/cli/file/file_test.go index 8fc036494..bac3ddb02 100644 --- a/cli/file/file_test.go +++ b/cli/file/file_test.go @@ -80,12 +80,12 @@ func TestWriteJSON(t *testing.T) { notMarshalableContent := struct{ Foo chan int }{Foo: make(chan int)} testCases := map[string]struct { - fs afero.Fs - setupFs func(af afero.Afero) error - name string - content interface{} - overwrite bool - wantErr bool + fs afero.Fs + setupFs func(af afero.Afero) error + name string + content interface{} + options Option + wantErr bool }{ "successful write": { fs: afero.NewMemMapFs(), @@ -93,11 +93,11 @@ func TestWriteJSON(t *testing.T) { content: someContent, }, "successful overwrite": { - fs: afero.NewMemMapFs(), - setupFs: func(af afero.Afero) error { return af.WriteFile("test/statefile", []byte{}, 0o644) }, - name: "test/statefile", - content: someContent, - overwrite: true, + fs: afero.NewMemMapFs(), + setupFs: func(af afero.Afero) error { return af.WriteFile("test/statefile", []byte{}, 0o644) }, + name: "test/statefile", + content: someContent, + options: OptOverwrite, }, "read only fs": { fs: afero.NewReadOnlyFs(afero.NewMemMapFs()), @@ -118,6 +118,15 @@ func TestWriteJSON(t *testing.T) { content: notMarshalableContent, wantErr: true, }, + "mkdirAll works": { + fs: afero.NewMemMapFs(), + name: "test/statefile", + content: someContent, + options: OptMkdirAll, + }, + // TODO: add tests for mkdirAll actually creating the necessary folders when https://github.com/spf13/afero/issues/270 is fixed. + // Currently, MemMapFs will create files in nonexistent directories due to a bug in afero, + // making it impossible to test the actual behavior of the mkdirAll parameter. } for name, tc := range testCases { @@ -131,9 +140,9 @@ func TestWriteJSON(t *testing.T) { } if tc.wantErr { - assert.Error(handler.WriteJSON(tc.name, tc.content, tc.overwrite)) + assert.Error(handler.WriteJSON(tc.name, tc.content, tc.options)) } else { - assert.NoError(handler.WriteJSON(tc.name, tc.content, tc.overwrite)) + assert.NoError(handler.WriteJSON(tc.name, tc.content, tc.options)) resultContent := &testContent{} assert.NoError(handler.ReadJSON(tc.name, resultContent)) assert.Equal(tc.content, *resultContent) diff --git a/coordinator/nodestate/nodestate.go b/coordinator/nodestate/nodestate.go index 34a47636f..56f8a7153 100644 --- a/coordinator/nodestate/nodestate.go +++ b/coordinator/nodestate/nodestate.go @@ -27,5 +27,5 @@ func FromFile(fileHandler file.Handler) (*NodeState, error) { // ToFile writes a NodeState to disk. func (nodeState *NodeState) ToFile(fileHandler file.Handler) error { - return fileHandler.WriteJSON(nodeStatePath, nodeState, false) + return fileHandler.WriteJSON(nodeStatePath, nodeState, false, true) } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index f31e3aa31..2cd0dad8c 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -77,7 +77,7 @@ func TestFromFile(t *testing.T) { require := require.New(t) fileHandler := file.NewHandler(afero.NewMemMapFs()) - require.NoError(fileHandler.WriteJSON(configName, tc.from, false)) + require.NoError(fileHandler.WriteJSON(configName, tc.from, file.OptNone)) result, err := FromFile(fileHandler, tc.configName)