From 10004875f4a8926d822a6c574ea5b1adc2431f54 Mon Sep 17 00:00:00 2001 From: katexochen <49727155+katexochen@users.noreply.github.com> Date: Fri, 7 Oct 2022 19:35:07 +0200 Subject: [PATCH] Add spinner interrrupt for rollback --- cli/internal/cmd/create.go | 12 +++--- cli/internal/cmd/create_test.go | 2 +- cli/internal/cmd/init.go | 9 +++-- cli/internal/cmd/init_test.go | 4 +- cli/internal/cmd/miniup.go | 12 ++++-- cli/internal/cmd/spinner.go | 65 +++++++++++++++++++++--------- cli/internal/cmd/spinner_test.go | 59 +++++++++++++++++++-------- cli/internal/cmd/terminate.go | 10 +++-- cli/internal/cmd/terminate_test.go | 2 +- 9 files changed, 119 insertions(+), 56 deletions(-) diff --git a/cli/internal/cmd/create.go b/cli/internal/cmd/create.go index 251f4f132..ac590cb10 100644 --- a/cli/internal/cmd/create.go +++ b/cli/internal/cmd/create.go @@ -42,12 +42,15 @@ func NewCreateCmd() *cobra.Command { func runCreate(cmd *cobra.Command, args []string) error { fileHandler := file.NewHandler(afero.NewOsFs()) - creator := cloudcmd.NewCreator(cmd.OutOrStdout()) + spinner, writer := newSpinner(cmd, cmd.OutOrStdout()) + defer spinner.Stop() + creator := cloudcmd.NewCreator(writer) - return create(cmd, creator, fileHandler) + return create(cmd, creator, fileHandler, spinner) } -func create(cmd *cobra.Command, creator cloudCreator, fileHandler file.Handler) (retErr error) { +func create(cmd *cobra.Command, creator cloudCreator, fileHandler file.Handler, spinner spinnerInterf, +) (retErr error) { flags, err := parseCreateFlags(cmd) if err != nil { return err @@ -114,8 +117,7 @@ func create(cmd *cobra.Command, creator cloudCreator, fileHandler file.Handler) } } - spinner := newSpinner(cmd, "Creating", false) - spinner.Start() + spinner.Start("Creating", false) state, err := creator.Create(cmd.Context(), provider, config, flags.name, instanceType, flags.controllerCount, flags.workerCount) spinner.Stop() if err != nil { diff --git a/cli/internal/cmd/create_test.go b/cli/internal/cmd/create_test.go index 1600166a2..150daa3e2 100644 --- a/cli/internal/cmd/create_test.go +++ b/cli/internal/cmd/create_test.go @@ -227,7 +227,7 @@ func TestCreate(t *testing.T) { fileHandler := file.NewHandler(tc.setupFs(require, tc.provider)) - err := create(cmd, tc.creator, fileHandler) + err := create(cmd, tc.creator, fileHandler, nopSpinner{}) if tc.wantErr { assert.Error(err) diff --git a/cli/internal/cmd/init.go b/cli/internal/cmd/init.go index d9969cc2f..3e21c3091 100644 --- a/cli/internal/cmd/init.go +++ b/cli/internal/cmd/init.go @@ -60,17 +60,19 @@ func runInitialize(cmd *cobra.Command, args []string) error { return dialer.New(nil, validator.V(cmd), &net.Dialer{}) } helmLoader := &helm.ChartLoader{} + spinner, _ := newSpinner(cmd, cmd.OutOrStdout()) + defer spinner.Stop() ctx, cancel := context.WithTimeout(cmd.Context(), time.Hour) defer cancel() cmd.SetContext(ctx) - return initialize(cmd, newDialer, fileHandler, helmLoader, license.NewClient()) + return initialize(cmd, newDialer, fileHandler, helmLoader, license.NewClient(), spinner) } // initialize initializes a Constellation. func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator) *dialer.Dialer, - fileHandler file.Handler, helmLoader helmLoader, quotaChecker license.QuotaChecker, + fileHandler file.Handler, helmLoader helmLoader, quotaChecker license.QuotaChecker, spinner spinnerInterf, ) error { flags, err := evalFlagArgs(cmd, fileHandler) if err != nil { @@ -124,8 +126,7 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator return fmt.Errorf("parsing or generating master secret from file %s: %w", flags.masterSecretPath, err) } - spinner := newSpinner(cmd, "Initializing cluster ", false) - spinner.Start() + spinner.Start("Initializing cluster ", false) req := &initproto.InitRequest{ MasterSecret: masterSecret.Key, Salt: masterSecret.Salt, diff --git a/cli/internal/cmd/init_test.go b/cli/internal/cmd/init_test.go index a87e455c1..2dd270ea4 100644 --- a/cli/internal/cmd/init_test.go +++ b/cli/internal/cmd/init_test.go @@ -187,7 +187,7 @@ func TestInitialize(t *testing.T) { defer cancel() cmd.SetContext(ctx) - err := initialize(cmd, newDialer, fileHandler, &tc.helmLoader, &stubLicenseClient{}) + err := initialize(cmd, newDialer, fileHandler, &tc.helmLoader, &stubLicenseClient{}, nopSpinner{}) if tc.wantErr { assert.Error(err) @@ -423,7 +423,7 @@ func TestAttestation(t *testing.T) { defer cancel() cmd.SetContext(ctx) - err := initialize(cmd, newDialer, fileHandler, &stubHelmLoader{}, &stubLicenseClient{}) + err := initialize(cmd, newDialer, fileHandler, &stubHelmLoader{}, &stubLicenseClient{}, nopSpinner{}) assert.Error(err) // make sure the error is actually a TLS handshake error assert.Contains(err.Error(), "transport: authentication handshake failed") diff --git a/cli/internal/cmd/miniup.go b/cli/internal/cmd/miniup.go index dbfa32eae..b14ffdf8f 100644 --- a/cli/internal/cmd/miniup.go +++ b/cli/internal/cmd/miniup.go @@ -51,6 +51,13 @@ func newMiniUpCmd() *cobra.Command { } func runUp(cmd *cobra.Command, args []string) error { + spinner, _ := newSpinner(cmd, cmd.OutOrStdout()) + defer spinner.Stop() + + return up(cmd, spinner) +} + +func up(cmd *cobra.Command, spinner spinnerInterf) error { if err := checkSystemRequirements(cmd.OutOrStdout()); err != nil { return fmt.Errorf("system requirements not met: %w", err) } @@ -64,8 +71,7 @@ func runUp(cmd *cobra.Command, args []string) error { } // create cluster - spinner := newSpinner(cmd, "Creating cluster in QEMU ", false) - spinner.Start() + spinner.Start("Creating cluster in QEMU ", false) err = createMiniCluster(cmd.Context(), fileHandler, cloudcmd.NewCreator(cmd.OutOrStdout()), config) spinner.Stop() if err != nil { @@ -224,7 +230,7 @@ func initializeMiniCluster(cmd *cobra.Command, fileHandler file.Handler) (retErr cmd.Flags().String("endpoint", "", "") cmd.Flags().Bool("conformance", false, "") - if err := initialize(cmd, newDialer, fileHandler, helmLoader, license.NewClient()); err != nil { + if err := initialize(cmd, newDialer, fileHandler, helmLoader, license.NewClient(), nopSpinner{}); err != nil { return err } return nil diff --git a/cli/internal/cmd/spinner.go b/cli/internal/cmd/spinner.go index 5d0d817e7..e64556f27 100644 --- a/cli/internal/cmd/spinner.go +++ b/cli/internal/cmd/spinner.go @@ -8,6 +8,7 @@ package cmd import ( "fmt" + "io" "sync" "sync/atomic" "time" @@ -20,27 +21,38 @@ var ( dotsStates = []string{".", "..", "..."} ) +type spinnerInterf interface { + Start(text string, showDots bool) + Stop() +} + type spinner struct { - out *cobra.Command - text string - showDots bool - delay time.Duration - wg *sync.WaitGroup - stop int32 + out *cobra.Command + delay time.Duration + wg *sync.WaitGroup + stop int32 } -func newSpinner(c *cobra.Command, text string, showDots bool) *spinner { - return &spinner{ - out: c, - text: text, - showDots: showDots, - wg: &sync.WaitGroup{}, - delay: 100 * time.Millisecond, - stop: 0, +func newSpinner(c *cobra.Command, writer io.Writer) (*spinner, *interruptSpinWriter) { + spinner := &spinner{ + out: c, + wg: &sync.WaitGroup{}, + delay: 100 * time.Millisecond, + stop: 0, } + + if writer != nil { + interruptWriter := &interruptSpinWriter{ + writer: writer, + spinner: spinner, + } + return spinner, interruptWriter + } + + return spinner, nil } -func (s *spinner) Start() { +func (s *spinner) Start(text string, showDots bool) { s.wg.Add(1) go func() { defer s.wg.Done() @@ -50,19 +62,19 @@ func (s *spinner) Start() { break } dotsState := "" - if s.showDots { + if showDots { dotsState = dotsStates[i%len(dotsStates)] } - state := fmt.Sprintf("\r%s %s%s", spinnerStates[i], s.text, dotsState) + state := fmt.Sprintf("\r%s %s%s", spinnerStates[i], text, dotsState) s.out.Print(state) time.Sleep(s.delay) } dotsState := "" - if s.showDots { + if showDots { dotsState = dotsStates[len(dotsStates)-1] } - finalState := fmt.Sprintf("\r%s%s ", s.text, dotsState) + finalState := fmt.Sprintf("\r%s%s ", text, dotsState) s.out.Println(finalState) }() } @@ -71,3 +83,18 @@ func (s *spinner) Stop() { atomic.StoreInt32(&s.stop, 1) s.wg.Wait() } + +type interruptSpinWriter struct { + spinner *spinner + writer io.Writer +} + +func (w *interruptSpinWriter) Write(p []byte) (n int, err error) { + w.spinner.Stop() + return w.writer.Write(p) +} + +type nopSpinner struct{} + +func (s nopSpinner) Start(string, bool) {} +func (s nopSpinner) Stop() {} diff --git a/cli/internal/cmd/spinner_test.go b/cli/internal/cmd/spinner_test.go index 2692a5b56..712983599 100644 --- a/cli/internal/cmd/spinner_test.go +++ b/cli/internal/cmd/spinner_test.go @@ -10,10 +10,11 @@ import ( "bytes" "fmt" "strings" + "sync/atomic" "testing" "time" - "github.com/stretchr/testify/require" + "github.com/stretchr/testify/assert" ) const ( @@ -22,58 +23,82 @@ const ( ) func TestSpinnerInitialState(t *testing.T) { + assert := assert.New(t) + cmd := NewInitCmd() var out bytes.Buffer cmd.SetOut(&out) var errOut bytes.Buffer cmd.SetErr(&errOut) - s := newSpinner(cmd, baseText, true) - s.Start() + s, _ := newSpinner(cmd, nil) + s.Start(baseText, true) time.Sleep(baseWait * time.Second) s.Stop() - require.True(t, out.Len() > 0) - require.True(t, errOut.Len() == 0) + assert.True(out.Len() > 0) + assert.True(errOut.Len() == 0) outStr := out.String() - require.True(t, strings.HasPrefix(outStr, generateAllStatesAsString(baseText, true))) + assert.True(strings.HasPrefix(outStr, generateAllStatesAsString(baseText, true))) } func TestSpinnerFinalState(t *testing.T) { + assert := assert.New(t) + cmd := NewInitCmd() var out bytes.Buffer cmd.SetOut(&out) var errOut bytes.Buffer cmd.SetErr(&errOut) - s := newSpinner(cmd, baseText, true) - s.Start() + s, _ := newSpinner(cmd, nil) + s.Start(baseText, true) time.Sleep(baseWait * time.Second) s.Stop() - require.True(t, out.Len() > 0) - require.True(t, errOut.Len() == 0) + assert.True(out.Len() > 0) + assert.True(errOut.Len() == 0) outStr := out.String() - require.True(t, strings.HasSuffix(outStr, baseText+"... \n")) + assert.True(strings.HasSuffix(outStr, baseText+"... \n")) } func TestSpinnerDisabledShowDotsFlag(t *testing.T) { + assert := assert.New(t) + cmd := NewInitCmd() var out bytes.Buffer cmd.SetOut(&out) var errOut bytes.Buffer cmd.SetErr(&errOut) - s := newSpinner(cmd, baseText, false) - s.Start() + s, _ := newSpinner(cmd, nil) + s.Start(baseText, false) time.Sleep(baseWait * time.Second) s.Stop() - require.True(t, out.Len() > 0) - require.True(t, errOut.Len() == 0) + assert.True(out.Len() > 0) + assert.True(errOut.Len() == 0) outStr := out.String() - require.True(t, strings.HasPrefix(outStr, generateAllStatesAsString(baseText, false))) - require.True(t, strings.HasSuffix(outStr, baseText+" \n")) + assert.True(strings.HasPrefix(outStr, generateAllStatesAsString(baseText, false))) + assert.True(strings.HasSuffix(outStr, baseText+" \n")) +} + +func TestSpinnerInterruptWriter(t *testing.T) { + assert := assert.New(t) + + cmd := NewInitCmd() + var out bytes.Buffer + cmd.SetOut(&out) + var errOut bytes.Buffer + cmd.SetErr(&errOut) + + s, interruptWriter := newSpinner(cmd, &out) + s.Start(baseText, false) + time.Sleep(200 * time.Millisecond) + _, err := interruptWriter.Write([]byte("test")) + assert.NoError(err) + assert.Equal(int32(1), atomic.LoadInt32(&s.stop)) + assert.True(strings.HasSuffix(out.String(), "test")) } func generateAllStatesAsString(text string, showDots bool) string { diff --git a/cli/internal/cmd/terminate.go b/cli/internal/cmd/terminate.go index c69dd02c3..a5fd08334 100644 --- a/cli/internal/cmd/terminate.go +++ b/cli/internal/cmd/terminate.go @@ -36,19 +36,21 @@ func NewTerminateCmd() *cobra.Command { // runTerminate runs the terminate command. func runTerminate(cmd *cobra.Command, args []string) error { fileHandler := file.NewHandler(afero.NewOsFs()) + spinner, _ := newSpinner(cmd, cmd.OutOrStdout()) + defer spinner.Stop() terminator := cloudcmd.NewTerminator() - return terminate(cmd, terminator, fileHandler) + return terminate(cmd, terminator, fileHandler, spinner) } -func terminate(cmd *cobra.Command, terminator cloudTerminator, fileHandler file.Handler) error { +func terminate(cmd *cobra.Command, terminator cloudTerminator, fileHandler file.Handler, spinner spinnerInterf, +) error { var stat state.ConstellationState if err := fileHandler.ReadJSON(constants.StateFilename, &stat); err != nil { return fmt.Errorf("reading Constellation state: %w", err) } - spinner := newSpinner(cmd, "Terminating", false) - spinner.Start() + spinner.Start("Terminating", false) err := terminator.Terminate(cmd.Context(), stat) spinner.Stop() if err != nil { diff --git a/cli/internal/cmd/terminate_test.go b/cli/internal/cmd/terminate_test.go index b8b011db8..0cca98423 100644 --- a/cli/internal/cmd/terminate_test.go +++ b/cli/internal/cmd/terminate_test.go @@ -119,7 +119,7 @@ func TestTerminate(t *testing.T) { require.NotNil(tc.setupFs) fileHandler := file.NewHandler(tc.setupFs(require, tc.state)) - err := terminate(cmd, tc.terminator, fileHandler) + err := terminate(cmd, tc.terminator, fileHandler, nopSpinner{}) if tc.wantErr { assert.Error(err)