Add warning about non retriable error during init (#644)

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2022-11-25 10:02:12 +01:00 committed by GitHub
parent e76a87fcfc
commit 1968dfe70c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 3 deletions

View File

@ -9,6 +9,7 @@ package cmd
import ( import (
"context" "context"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -141,6 +142,11 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
resp, err := initCall(cmd.Context(), newDialer(validator), idFile.IP, req) resp, err := initCall(cmd.Context(), newDialer(validator), idFile.IP, req)
spinner.Stop() spinner.Stop()
if err != nil { if err != nil {
var nonRetriable *nonRetriableError
if errors.As(err, &nonRetriable) {
cmd.PrintErrln("Cluster initialization failed. This error is not recoverable.")
cmd.PrintErrln("Terminate your cluster and try again.")
}
return err return err
} }
@ -181,7 +187,7 @@ func (d *initDoer) Do(ctx context.Context) error {
protoClient := initproto.NewAPIClient(conn) protoClient := initproto.NewAPIClient(conn)
resp, err := protoClient.Init(ctx, d.req) resp, err := protoClient.Init(ctx, d.req)
if err != nil { if err != nil {
return fmt.Errorf("init call: %w", err) return &nonRetriableError{fmt.Errorf("init call: %w", err)}
} }
d.resp = resp d.resp = resp
return nil return nil
@ -339,3 +345,17 @@ func getMarshaledServiceAccountURI(provider cloudprovider.Provider, config *conf
type grpcDialer interface { type grpcDialer interface {
Dial(ctx context.Context, target string) (*grpc.ClientConn, error) Dial(ctx context.Context, target string) (*grpc.ClientConn, error)
} }
type nonRetriableError struct {
err error
}
// Error returns the error message.
func (e *nonRetriableError) Error() string {
return e.err.Error()
}
// Unwrap returns the wrapped error.
func (e *nonRetriableError) Unwrap() error {
return e.err
}

View File

@ -65,6 +65,7 @@ func TestInitialize(t *testing.T) {
configMutator func(*config.Config) configMutator func(*config.Config)
serviceAccKey *gcpshared.ServiceAccountKey serviceAccKey *gcpshared.ServiceAccountKey
initServerAPI *stubInitServer initServerAPI *stubInitServer
retriable bool
masterSecretShouldExist bool masterSecretShouldExist bool
wantErr bool wantErr bool
}{ }{
@ -85,20 +86,31 @@ func TestInitialize(t *testing.T) {
idFile: &clusterid.File{IP: "192.0.2.1"}, idFile: &clusterid.File{IP: "192.0.2.1"},
initServerAPI: &stubInitServer{initResp: testInitResp}, initServerAPI: &stubInitServer{initResp: testInitResp},
}, },
"non retriable error": {
provider: cloudprovider.QEMU,
idFile: &clusterid.File{IP: "192.0.2.1"},
initServerAPI: &stubInitServer{initErr: &nonRetriableError{someErr}},
retriable: false,
masterSecretShouldExist: true,
wantErr: true,
},
"empty id file": { "empty id file": {
provider: cloudprovider.GCP, provider: cloudprovider.GCP,
idFile: &clusterid.File{}, idFile: &clusterid.File{},
initServerAPI: &stubInitServer{}, initServerAPI: &stubInitServer{},
retriable: true,
wantErr: true, wantErr: true,
}, },
"no id file": { "no id file": {
provider: cloudprovider.GCP, provider: cloudprovider.GCP,
wantErr: true, retriable: true,
wantErr: true,
}, },
"init call fails": { "init call fails": {
provider: cloudprovider.GCP, provider: cloudprovider.GCP,
idFile: &clusterid.File{IP: "192.0.2.1"}, idFile: &clusterid.File{IP: "192.0.2.1"},
initServerAPI: &stubInitServer{initErr: someErr}, initServerAPI: &stubInitServer{initErr: someErr},
retriable: true,
wantErr: true, wantErr: true,
}, },
} }
@ -156,6 +168,11 @@ func TestInitialize(t *testing.T) {
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
if !tc.retriable {
assert.Contains(errOut.String(), "This error is not recoverable")
} else {
assert.Empty(errOut.String())
}
if !tc.masterSecretShouldExist { if !tc.masterSecretShouldExist {
_, err = fileHandler.Stat(constants.MasterSecretFilename) _, err = fileHandler.Stat(constants.MasterSecretFilename)
assert.Error(err) assert.Error(err)