Add spinner interrrupt for rollback

This commit is contained in:
katexochen 2022-10-07 19:35:07 +02:00 committed by Paul Meyer
parent 75439344c9
commit 10004875f4
9 changed files with 119 additions and 56 deletions

View File

@ -42,12 +42,15 @@ func NewCreateCmd() *cobra.Command {
func runCreate(cmd *cobra.Command, args []string) error { func runCreate(cmd *cobra.Command, args []string) error {
fileHandler := file.NewHandler(afero.NewOsFs()) fileHandler := file.NewHandler(afero.NewOsFs())
creator := cloudcmd.NewCreator(cmd.OutOrStdout()) spinner, writer := newSpinner(cmd, cmd.OutOrStdout())
defer spinner.Stop()
creator := cloudcmd.NewCreator(writer)
return create(cmd, creator, fileHandler) return create(cmd, creator, fileHandler, spinner)
} }
func create(cmd *cobra.Command, creator cloudCreator, fileHandler file.Handler) (retErr error) { func create(cmd *cobra.Command, creator cloudCreator, fileHandler file.Handler, spinner spinnerInterf,
) (retErr error) {
flags, err := parseCreateFlags(cmd) flags, err := parseCreateFlags(cmd)
if err != nil { if err != nil {
return err return err
@ -114,8 +117,7 @@ func create(cmd *cobra.Command, creator cloudCreator, fileHandler file.Handler)
} }
} }
spinner := newSpinner(cmd, "Creating", false) spinner.Start("Creating", false)
spinner.Start()
state, err := creator.Create(cmd.Context(), provider, config, flags.name, instanceType, flags.controllerCount, flags.workerCount) state, err := creator.Create(cmd.Context(), provider, config, flags.name, instanceType, flags.controllerCount, flags.workerCount)
spinner.Stop() spinner.Stop()
if err != nil { if err != nil {

View File

@ -227,7 +227,7 @@ func TestCreate(t *testing.T) {
fileHandler := file.NewHandler(tc.setupFs(require, tc.provider)) fileHandler := file.NewHandler(tc.setupFs(require, tc.provider))
err := create(cmd, tc.creator, fileHandler) err := create(cmd, tc.creator, fileHandler, nopSpinner{})
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)

View File

@ -60,17 +60,19 @@ func runInitialize(cmd *cobra.Command, args []string) error {
return dialer.New(nil, validator.V(cmd), &net.Dialer{}) return dialer.New(nil, validator.V(cmd), &net.Dialer{})
} }
helmLoader := &helm.ChartLoader{} helmLoader := &helm.ChartLoader{}
spinner, _ := newSpinner(cmd, cmd.OutOrStdout())
defer spinner.Stop()
ctx, cancel := context.WithTimeout(cmd.Context(), time.Hour) ctx, cancel := context.WithTimeout(cmd.Context(), time.Hour)
defer cancel() defer cancel()
cmd.SetContext(ctx) cmd.SetContext(ctx)
return initialize(cmd, newDialer, fileHandler, helmLoader, license.NewClient()) return initialize(cmd, newDialer, fileHandler, helmLoader, license.NewClient(), spinner)
} }
// initialize initializes a Constellation. // initialize initializes a Constellation.
func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator) *dialer.Dialer, func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator) *dialer.Dialer,
fileHandler file.Handler, helmLoader helmLoader, quotaChecker license.QuotaChecker, fileHandler file.Handler, helmLoader helmLoader, quotaChecker license.QuotaChecker, spinner spinnerInterf,
) error { ) error {
flags, err := evalFlagArgs(cmd, fileHandler) flags, err := evalFlagArgs(cmd, fileHandler)
if err != nil { if err != nil {
@ -124,8 +126,7 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
return fmt.Errorf("parsing or generating master secret from file %s: %w", flags.masterSecretPath, err) return fmt.Errorf("parsing or generating master secret from file %s: %w", flags.masterSecretPath, err)
} }
spinner := newSpinner(cmd, "Initializing cluster ", false) spinner.Start("Initializing cluster ", false)
spinner.Start()
req := &initproto.InitRequest{ req := &initproto.InitRequest{
MasterSecret: masterSecret.Key, MasterSecret: masterSecret.Key,
Salt: masterSecret.Salt, Salt: masterSecret.Salt,

View File

@ -187,7 +187,7 @@ func TestInitialize(t *testing.T) {
defer cancel() defer cancel()
cmd.SetContext(ctx) cmd.SetContext(ctx)
err := initialize(cmd, newDialer, fileHandler, &tc.helmLoader, &stubLicenseClient{}) err := initialize(cmd, newDialer, fileHandler, &tc.helmLoader, &stubLicenseClient{}, nopSpinner{})
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
@ -423,7 +423,7 @@ func TestAttestation(t *testing.T) {
defer cancel() defer cancel()
cmd.SetContext(ctx) cmd.SetContext(ctx)
err := initialize(cmd, newDialer, fileHandler, &stubHelmLoader{}, &stubLicenseClient{}) err := initialize(cmd, newDialer, fileHandler, &stubHelmLoader{}, &stubLicenseClient{}, nopSpinner{})
assert.Error(err) assert.Error(err)
// make sure the error is actually a TLS handshake error // make sure the error is actually a TLS handshake error
assert.Contains(err.Error(), "transport: authentication handshake failed") assert.Contains(err.Error(), "transport: authentication handshake failed")

View File

@ -51,6 +51,13 @@ func newMiniUpCmd() *cobra.Command {
} }
func runUp(cmd *cobra.Command, args []string) error { func runUp(cmd *cobra.Command, args []string) error {
spinner, _ := newSpinner(cmd, cmd.OutOrStdout())
defer spinner.Stop()
return up(cmd, spinner)
}
func up(cmd *cobra.Command, spinner spinnerInterf) error {
if err := checkSystemRequirements(cmd.OutOrStdout()); err != nil { if err := checkSystemRequirements(cmd.OutOrStdout()); err != nil {
return fmt.Errorf("system requirements not met: %w", err) return fmt.Errorf("system requirements not met: %w", err)
} }
@ -64,8 +71,7 @@ func runUp(cmd *cobra.Command, args []string) error {
} }
// create cluster // create cluster
spinner := newSpinner(cmd, "Creating cluster in QEMU ", false) spinner.Start("Creating cluster in QEMU ", false)
spinner.Start()
err = createMiniCluster(cmd.Context(), fileHandler, cloudcmd.NewCreator(cmd.OutOrStdout()), config) err = createMiniCluster(cmd.Context(), fileHandler, cloudcmd.NewCreator(cmd.OutOrStdout()), config)
spinner.Stop() spinner.Stop()
if err != nil { if err != nil {
@ -224,7 +230,7 @@ func initializeMiniCluster(cmd *cobra.Command, fileHandler file.Handler) (retErr
cmd.Flags().String("endpoint", "", "") cmd.Flags().String("endpoint", "", "")
cmd.Flags().Bool("conformance", false, "") cmd.Flags().Bool("conformance", false, "")
if err := initialize(cmd, newDialer, fileHandler, helmLoader, license.NewClient()); err != nil { if err := initialize(cmd, newDialer, fileHandler, helmLoader, license.NewClient(), nopSpinner{}); err != nil {
return err return err
} }
return nil return nil

View File

@ -8,6 +8,7 @@ package cmd
import ( import (
"fmt" "fmt"
"io"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -20,27 +21,38 @@ var (
dotsStates = []string{".", "..", "..."} dotsStates = []string{".", "..", "..."}
) )
type spinnerInterf interface {
Start(text string, showDots bool)
Stop()
}
type spinner struct { type spinner struct {
out *cobra.Command out *cobra.Command
text string delay time.Duration
showDots bool wg *sync.WaitGroup
delay time.Duration stop int32
wg *sync.WaitGroup
stop int32
} }
func newSpinner(c *cobra.Command, text string, showDots bool) *spinner { func newSpinner(c *cobra.Command, writer io.Writer) (*spinner, *interruptSpinWriter) {
return &spinner{ spinner := &spinner{
out: c, out: c,
text: text, wg: &sync.WaitGroup{},
showDots: showDots, delay: 100 * time.Millisecond,
wg: &sync.WaitGroup{}, stop: 0,
delay: 100 * time.Millisecond,
stop: 0,
} }
if writer != nil {
interruptWriter := &interruptSpinWriter{
writer: writer,
spinner: spinner,
}
return spinner, interruptWriter
}
return spinner, nil
} }
func (s *spinner) Start() { func (s *spinner) Start(text string, showDots bool) {
s.wg.Add(1) s.wg.Add(1)
go func() { go func() {
defer s.wg.Done() defer s.wg.Done()
@ -50,19 +62,19 @@ func (s *spinner) Start() {
break break
} }
dotsState := "" dotsState := ""
if s.showDots { if showDots {
dotsState = dotsStates[i%len(dotsStates)] dotsState = dotsStates[i%len(dotsStates)]
} }
state := fmt.Sprintf("\r%s %s%s", spinnerStates[i], s.text, dotsState) state := fmt.Sprintf("\r%s %s%s", spinnerStates[i], text, dotsState)
s.out.Print(state) s.out.Print(state)
time.Sleep(s.delay) time.Sleep(s.delay)
} }
dotsState := "" dotsState := ""
if s.showDots { if showDots {
dotsState = dotsStates[len(dotsStates)-1] dotsState = dotsStates[len(dotsStates)-1]
} }
finalState := fmt.Sprintf("\r%s%s ", s.text, dotsState) finalState := fmt.Sprintf("\r%s%s ", text, dotsState)
s.out.Println(finalState) s.out.Println(finalState)
}() }()
} }
@ -71,3 +83,18 @@ func (s *spinner) Stop() {
atomic.StoreInt32(&s.stop, 1) atomic.StoreInt32(&s.stop, 1)
s.wg.Wait() s.wg.Wait()
} }
type interruptSpinWriter struct {
spinner *spinner
writer io.Writer
}
func (w *interruptSpinWriter) Write(p []byte) (n int, err error) {
w.spinner.Stop()
return w.writer.Write(p)
}
type nopSpinner struct{}
func (s nopSpinner) Start(string, bool) {}
func (s nopSpinner) Stop() {}

View File

@ -10,10 +10,11 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"strings" "strings"
"sync/atomic"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/assert"
) )
const ( const (
@ -22,58 +23,82 @@ const (
) )
func TestSpinnerInitialState(t *testing.T) { func TestSpinnerInitialState(t *testing.T) {
assert := assert.New(t)
cmd := NewInitCmd() cmd := NewInitCmd()
var out bytes.Buffer var out bytes.Buffer
cmd.SetOut(&out) cmd.SetOut(&out)
var errOut bytes.Buffer var errOut bytes.Buffer
cmd.SetErr(&errOut) cmd.SetErr(&errOut)
s := newSpinner(cmd, baseText, true) s, _ := newSpinner(cmd, nil)
s.Start() s.Start(baseText, true)
time.Sleep(baseWait * time.Second) time.Sleep(baseWait * time.Second)
s.Stop() s.Stop()
require.True(t, out.Len() > 0) assert.True(out.Len() > 0)
require.True(t, errOut.Len() == 0) assert.True(errOut.Len() == 0)
outStr := out.String() outStr := out.String()
require.True(t, strings.HasPrefix(outStr, generateAllStatesAsString(baseText, true))) assert.True(strings.HasPrefix(outStr, generateAllStatesAsString(baseText, true)))
} }
func TestSpinnerFinalState(t *testing.T) { func TestSpinnerFinalState(t *testing.T) {
assert := assert.New(t)
cmd := NewInitCmd() cmd := NewInitCmd()
var out bytes.Buffer var out bytes.Buffer
cmd.SetOut(&out) cmd.SetOut(&out)
var errOut bytes.Buffer var errOut bytes.Buffer
cmd.SetErr(&errOut) cmd.SetErr(&errOut)
s := newSpinner(cmd, baseText, true) s, _ := newSpinner(cmd, nil)
s.Start() s.Start(baseText, true)
time.Sleep(baseWait * time.Second) time.Sleep(baseWait * time.Second)
s.Stop() s.Stop()
require.True(t, out.Len() > 0) assert.True(out.Len() > 0)
require.True(t, errOut.Len() == 0) assert.True(errOut.Len() == 0)
outStr := out.String() outStr := out.String()
require.True(t, strings.HasSuffix(outStr, baseText+"... \n")) assert.True(strings.HasSuffix(outStr, baseText+"... \n"))
} }
func TestSpinnerDisabledShowDotsFlag(t *testing.T) { func TestSpinnerDisabledShowDotsFlag(t *testing.T) {
assert := assert.New(t)
cmd := NewInitCmd() cmd := NewInitCmd()
var out bytes.Buffer var out bytes.Buffer
cmd.SetOut(&out) cmd.SetOut(&out)
var errOut bytes.Buffer var errOut bytes.Buffer
cmd.SetErr(&errOut) cmd.SetErr(&errOut)
s := newSpinner(cmd, baseText, false) s, _ := newSpinner(cmd, nil)
s.Start() s.Start(baseText, false)
time.Sleep(baseWait * time.Second) time.Sleep(baseWait * time.Second)
s.Stop() s.Stop()
require.True(t, out.Len() > 0) assert.True(out.Len() > 0)
require.True(t, errOut.Len() == 0) assert.True(errOut.Len() == 0)
outStr := out.String() outStr := out.String()
require.True(t, strings.HasPrefix(outStr, generateAllStatesAsString(baseText, false))) assert.True(strings.HasPrefix(outStr, generateAllStatesAsString(baseText, false)))
require.True(t, strings.HasSuffix(outStr, baseText+" \n")) assert.True(strings.HasSuffix(outStr, baseText+" \n"))
}
func TestSpinnerInterruptWriter(t *testing.T) {
assert := assert.New(t)
cmd := NewInitCmd()
var out bytes.Buffer
cmd.SetOut(&out)
var errOut bytes.Buffer
cmd.SetErr(&errOut)
s, interruptWriter := newSpinner(cmd, &out)
s.Start(baseText, false)
time.Sleep(200 * time.Millisecond)
_, err := interruptWriter.Write([]byte("test"))
assert.NoError(err)
assert.Equal(int32(1), atomic.LoadInt32(&s.stop))
assert.True(strings.HasSuffix(out.String(), "test"))
} }
func generateAllStatesAsString(text string, showDots bool) string { func generateAllStatesAsString(text string, showDots bool) string {

View File

@ -36,19 +36,21 @@ func NewTerminateCmd() *cobra.Command {
// runTerminate runs the terminate command. // runTerminate runs the terminate command.
func runTerminate(cmd *cobra.Command, args []string) error { func runTerminate(cmd *cobra.Command, args []string) error {
fileHandler := file.NewHandler(afero.NewOsFs()) fileHandler := file.NewHandler(afero.NewOsFs())
spinner, _ := newSpinner(cmd, cmd.OutOrStdout())
defer spinner.Stop()
terminator := cloudcmd.NewTerminator() terminator := cloudcmd.NewTerminator()
return terminate(cmd, terminator, fileHandler) return terminate(cmd, terminator, fileHandler, spinner)
} }
func terminate(cmd *cobra.Command, terminator cloudTerminator, fileHandler file.Handler) error { func terminate(cmd *cobra.Command, terminator cloudTerminator, fileHandler file.Handler, spinner spinnerInterf,
) error {
var stat state.ConstellationState var stat state.ConstellationState
if err := fileHandler.ReadJSON(constants.StateFilename, &stat); err != nil { if err := fileHandler.ReadJSON(constants.StateFilename, &stat); err != nil {
return fmt.Errorf("reading Constellation state: %w", err) return fmt.Errorf("reading Constellation state: %w", err)
} }
spinner := newSpinner(cmd, "Terminating", false) spinner.Start("Terminating", false)
spinner.Start()
err := terminator.Terminate(cmd.Context(), stat) err := terminator.Terminate(cmd.Context(), stat)
spinner.Stop() spinner.Stop()
if err != nil { if err != nil {

View File

@ -119,7 +119,7 @@ func TestTerminate(t *testing.T) {
require.NotNil(tc.setupFs) require.NotNil(tc.setupFs)
fileHandler := file.NewHandler(tc.setupFs(require, tc.state)) fileHandler := file.NewHandler(tc.setupFs(require, tc.state))
err := terminate(cmd, tc.terminator, fileHandler) err := terminate(cmd, tc.terminator, fileHandler, nopSpinner{})
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)