diff --git a/cli/internal/azure/loadbalancer.go b/cli/internal/azure/loadbalancer.go index d56e5b7a5..3bcbe9f13 100644 --- a/cli/internal/azure/loadbalancer.go +++ b/cli/internal/azure/loadbalancer.go @@ -98,8 +98,9 @@ func (l LoadBalancer) Azure() armnetwork.LoadBalancer { { Name: to.Ptr(recoveryHealthProbeName), Properties: &armnetwork.ProbePropertiesFormat{ - Protocol: to.Ptr(armnetwork.ProbeProtocolTCP), - Port: to.Ptr[int32](constants.RecoveryPort), + Protocol: to.Ptr(armnetwork.ProbeProtocolTCP), + Port: to.Ptr[int32](constants.RecoveryPort), + IntervalInSeconds: to.Ptr[int32](5), }, }, }, diff --git a/cli/internal/cmd/recover.go b/cli/internal/cmd/recover.go index abb4ecf35..ae1af2b95 100644 --- a/cli/internal/cmd/recover.go +++ b/cli/internal/cmd/recover.go @@ -12,6 +12,7 @@ import ( "fmt" "io" "net" + "sync" "time" "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd" @@ -24,10 +25,8 @@ import ( "github.com/edgelesssys/constellation/v2/internal/grpc/dialer" grpcRetry "github.com/edgelesssys/constellation/v2/internal/grpc/retry" "github.com/edgelesssys/constellation/v2/internal/retry" - "github.com/edgelesssys/constellation/v2/internal/state" "github.com/spf13/afero" "github.com/spf13/cobra" - "go.uber.org/multierr" ) // NewRecoverCmd returns a new cobra.Command for the recover command. @@ -40,8 +39,7 @@ func NewRecoverCmd() *cobra.Command { Args: cobra.ExactArgs(0), RunE: runRecover, } - cmd.Flags().StringP("endpoint", "e", "", "endpoint of the instance, passed as HOST[:PORT] (required)") - must(cmd.MarkFlagRequired("endpoint")) + cmd.Flags().StringP("endpoint", "e", "", "endpoint of the instance, passed as HOST[:PORT]") cmd.Flags().String("master-secret", constants.MasterSecretFilename, "path to master secret file") return cmd } @@ -51,11 +49,14 @@ func runRecover(cmd *cobra.Command, _ []string) error { newDialer := func(validator *cloudcmd.Validator) *dialer.Dialer { return dialer.New(nil, validator.V(cmd), &net.Dialer{}) } - return recover(cmd, fileHandler, newDialer) + return recover(cmd, fileHandler, 5*time.Second, &recoverDoer{}, newDialer) } -func recover(cmd *cobra.Command, fileHandler file.Handler, newDialer func(validator *cloudcmd.Validator) *dialer.Dialer) error { - flags, err := parseRecoverFlags(cmd) +func recover( + cmd *cobra.Command, fileHandler file.Handler, interval time.Duration, + doer recoverDoerInterface, newDialer func(validator *cloudcmd.Validator) *dialer.Dialer, +) error { + flags, err := parseRecoverFlags(cmd, fileHandler) if err != nil { return err } @@ -65,48 +66,81 @@ func recover(cmd *cobra.Command, fileHandler file.Handler, newDialer func(valida return err } - var stat state.ConstellationState - if err := fileHandler.ReadJSON(constants.StateFilename, &stat); err != nil { - return err - } - - provider := cloudprovider.FromString(stat.CloudProvider) config, err := readConfig(cmd.OutOrStdout(), fileHandler, flags.configPath) if err != nil { return fmt.Errorf("reading and validating config: %w", err) } + provider := config.GetProvider() + if provider == cloudprovider.Azure { + interval = 20 * time.Second // Azure LB takes a while to remove unhealthy instances + } validator, err := cloudcmd.NewValidator(provider, config) if err != nil { return err } + doer.setDialer(newDialer(validator), flags.endpoint) - if err := recoverCall(cmd.Context(), newDialer(validator), flags.endpoint, masterSecret.Key, masterSecret.Salt); err != nil { - return fmt.Errorf("recovering cluster: %w", err) - } - - cmd.Println("Pushed recovery key.") - return nil -} - -func recoverCall(ctx context.Context, dialer grpcDialer, endpoint string, key, salt []byte) error { - measurementSecret, err := attestation.DeriveMeasurementSecret(key, salt) + measurementSecret, err := attestation.DeriveMeasurementSecret(masterSecret.Key, masterSecret.Salt) if err != nil { return err } - doer := &recoverDoer{ - dialer: dialer, - endpoint: endpoint, - getDiskKey: getStateDiskKeyFunc(key, salt), - measurementSecret: measurementSecret, - } - retrier := retry.NewIntervalRetrier(doer, 30*time.Second, grpcRetry.ServiceIsUnavailable) - if err := retrier.Do(ctx); err != nil { - return err + doer.setSecrets(getStateDiskKeyFunc(masterSecret.Key, masterSecret.Salt), measurementSecret) + + if err := recoverCall(cmd.Context(), cmd.OutOrStdout(), interval, doer); err != nil { + if grpcRetry.ServiceIsUnavailable(err) { + return nil + } + return fmt.Errorf("recovering cluster: %w", err) } return nil } +func recoverCall(ctx context.Context, out io.Writer, interval time.Duration, doer recoverDoerInterface) error { + var err error + ctr := 0 + for { + once := sync.Once{} + retryOnceOnFailure := func(err error) bool { + // retry transient GCP LB errors + if grpcRetry.LoadbalancerIsNotReady(err) { + return true + } + retry := false + + // retry connection errors once + // this is necessary because Azure's LB takes a while to remove unhealthy instances + once.Do(func() { + retry = grpcRetry.ServiceIsUnavailable(err) + }) + return retry + } + + retrier := retry.NewIntervalRetrier(doer, interval, retryOnceOnFailure) + err = retrier.Do(ctx) + if err != nil { + break + } + fmt.Fprintln(out, "Pushed recovery key.") + ctr++ + } + + if ctr > 0 { + fmt.Fprintf(out, "Recovered %d control-plane nodes.\n", ctr) + } else if grpcRetry.ServiceIsUnavailable(err) { + fmt.Fprintln(out, "No control-plane nodes in need of recovery found. Exiting.") + return nil + } + + return err +} + +type recoverDoerInterface interface { + Do(ctx context.Context) error + setDialer(dialer grpcDialer, endpoint string) + setSecrets(getDiskKey func(uuid string) ([]byte, error), measurementSecret []byte) +} + type recoverDoer struct { dialer grpcDialer endpoint string @@ -114,6 +148,7 @@ type recoverDoer struct { getDiskKey func(uuid string) (key []byte, err error) } +// Do performs the recover streaming rpc. func (d *recoverDoer) Do(ctx context.Context) (retErr error) { conn, err := d.dialer.Dial(ctx, d.endpoint) if err != nil { @@ -125,12 +160,10 @@ func (d *recoverDoer) Do(ctx context.Context) (retErr error) { protoClient := recoverproto.NewAPIClient(conn) recoverclient, err := protoClient.Recover(ctx) if err != nil { - return err + return fmt.Errorf("creating client: %w", err) } defer func() { - if err := recoverclient.CloseSend(); err != nil { - multierr.AppendInto(&retErr, err) - } + _ = recoverclient.CloseSend() }() // send measurement secret as first message @@ -139,17 +172,17 @@ func (d *recoverDoer) Do(ctx context.Context) (retErr error) { MeasurementSecret: d.measurementSecret, }, }); err != nil { - return err + return fmt.Errorf("sending measurement secret: %w", err) } // receive disk uuid res, err := recoverclient.Recv() if err != nil { - return err + return fmt.Errorf("receiving disk uuid: %w", err) } stateDiskKey, err := d.getDiskKey(res.DiskUuid) if err != nil { - return err + return fmt.Errorf("getting state disk key: %w", err) } // send disk key @@ -158,20 +191,42 @@ func (d *recoverDoer) Do(ctx context.Context) (retErr error) { StateDiskKey: stateDiskKey, }, }); err != nil { - return err + return fmt.Errorf("sending state disk key: %w", err) } if _, err := recoverclient.Recv(); err != nil && !errors.Is(err, io.EOF) { - return err + return fmt.Errorf("receiving confirmation: %w", err) } return nil } -func parseRecoverFlags(cmd *cobra.Command) (recoverFlags, error) { +func (d *recoverDoer) setDialer(dialer grpcDialer, endpoint string) { + d.dialer = dialer + d.endpoint = endpoint +} + +func (d *recoverDoer) setSecrets(getDiskKey func(string) ([]byte, error), measurementSecret []byte) { + d.getDiskKey = getDiskKey + d.measurementSecret = measurementSecret +} + +type recoverFlags struct { + endpoint string + secretPath string + configPath string +} + +func parseRecoverFlags(cmd *cobra.Command, fileHandler file.Handler) (recoverFlags, error) { endpoint, err := cmd.Flags().GetString("endpoint") if err != nil { return recoverFlags{}, fmt.Errorf("parsing endpoint argument: %w", err) } + if endpoint == "" { + endpoint, err = readIPFromIDFile(fileHandler) + if err != nil { + return recoverFlags{}, fmt.Errorf("getting recovery endpoint: %w", err) + } + } endpoint, err = addPortIfMissing(endpoint, constants.RecoveryPort) if err != nil { return recoverFlags{}, fmt.Errorf("validating endpoint argument: %w", err) @@ -194,12 +249,6 @@ func parseRecoverFlags(cmd *cobra.Command) (recoverFlags, error) { }, nil } -type recoverFlags struct { - endpoint string - secretPath string - configPath string -} - func getStateDiskKeyFunc(masterKey, salt []byte) func(uuid string) ([]byte, error) { return func(uuid string) ([]byte, error) { return crypto.DeriveKey(masterKey, salt, []byte(crypto.HKDFInfoPrefix+uuid), crypto.StateDiskKeyLength) diff --git a/cli/internal/cmd/recover_test.go b/cli/internal/cmd/recover_test.go index 73b8ffa11..e668c29fa 100644 --- a/cli/internal/cmd/recover_test.go +++ b/cli/internal/cmd/recover_test.go @@ -13,6 +13,7 @@ import ( "net" "strconv" "testing" + "time" "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd" "github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto" @@ -24,11 +25,12 @@ import ( "github.com/edgelesssys/constellation/v2/internal/grpc/atlscredentials" "github.com/edgelesssys/constellation/v2/internal/grpc/dialer" "github.com/edgelesssys/constellation/v2/internal/grpc/testdialer" - "github.com/edgelesssys/constellation/v2/internal/state" "github.com/spf13/afero" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func TestRecoverCmdArgumentValidation(t *testing.T) { @@ -57,65 +59,73 @@ func TestRecoverCmdArgumentValidation(t *testing.T) { } func TestRecover(t *testing.T) { - validState := state.ConstellationState{CloudProvider: "GCP"} - invalidCSPState := state.ConstellationState{CloudProvider: "invalid"} - successActions := []func(stream recoverproto.API_RecoverServer) error{ - func(stream recoverproto.API_RecoverServer) error { - _, err := stream.Recv() - return err - }, - func(stream recoverproto.API_RecoverServer) error { - return stream.Send(&recoverproto.RecoverResponse{ - DiskUuid: "00000000-0000-0000-0000-000000000000", - }) - }, - func(stream recoverproto.API_RecoverServer) error { - _, err := stream.Recv() - return err - }, - } + someErr := errors.New("error") + unavailableErr := status.Error(codes.Unavailable, "unavailable") + lbErr := status.Error(codes.Unavailable, `connection error: desc = "transport: authentication handshake failed: read tcp`) testCases := map[string]struct { - existingState state.ConstellationState - recoverServerAPI *stubRecoveryServer - masterSecret testvector.HKDF - endpointFlag string - masterSecretFlag string - configFlag string - stateless bool - wantErr bool + doer *stubDoer + masterSecret testvector.HKDF + endpoint string + configFlag string + successfulCalls int + wantErr bool }{ "works": { - existingState: validState, - recoverServerAPI: &stubRecoveryServer{actions: successActions}, - endpointFlag: "192.0.2.1", - masterSecret: testvector.HKDFZero, - }, - "missing flags": { - recoverServerAPI: &stubRecoveryServer{actions: successActions}, - wantErr: true, + doer: &stubDoer{returns: []error{nil}}, + endpoint: "192.0.2.90", + masterSecret: testvector.HKDFZero, + successfulCalls: 1, }, "missing config": { - recoverServerAPI: &stubRecoveryServer{actions: successActions}, - endpointFlag: "192.0.2.1", - masterSecret: testvector.HKDFZero, - configFlag: "nonexistent-config", - wantErr: true, + doer: &stubDoer{returns: []error{nil}}, + endpoint: "192.0.2.89", + masterSecret: testvector.HKDFZero, + configFlag: "nonexistent-config", + wantErr: true, }, - "missing state": { - existingState: validState, - recoverServerAPI: &stubRecoveryServer{actions: successActions}, - endpointFlag: "192.0.2.1", - masterSecret: testvector.HKDFZero, - stateless: true, - wantErr: true, + "success multiple nodes": { + doer: &stubDoer{returns: []error{nil, nil}}, + endpoint: "192.0.2.90", + masterSecret: testvector.HKDFZero, + successfulCalls: 2, }, - "invalid cloud provider": { - existingState: invalidCSPState, - recoverServerAPI: &stubRecoveryServer{actions: successActions}, - endpointFlag: "192.0.2.1", - masterSecret: testvector.HKDFZero, - wantErr: true, + "no nodes to recover does not error": { + doer: &stubDoer{returns: []error{unavailableErr}}, + endpoint: "192.0.2.90", + masterSecret: testvector.HKDFZero, + successfulCalls: 0, + }, + "error on first node": { + doer: &stubDoer{returns: []error{someErr, nil}}, + endpoint: "192.0.2.90", + masterSecret: testvector.HKDFZero, + successfulCalls: 0, + wantErr: true, + }, + "unavailable error is retried once": { + doer: &stubDoer{returns: []error{unavailableErr, nil}}, + endpoint: "192.0.2.90", + masterSecret: testvector.HKDFZero, + successfulCalls: 1, + }, + "unavailable error is not retried twice": { + doer: &stubDoer{returns: []error{unavailableErr, unavailableErr, nil}}, + endpoint: "192.0.2.90", + masterSecret: testvector.HKDFZero, + successfulCalls: 0, + }, + "unavailable error is not retried twice after success": { + doer: &stubDoer{returns: []error{nil, unavailableErr, unavailableErr, nil}}, + endpoint: "192.0.2.90", + masterSecret: testvector.HKDFZero, + successfulCalls: 1, + }, + "transient LB errors are retried": { + doer: &stubDoer{returns: []error{lbErr, lbErr, lbErr, nil}}, + endpoint: "192.0.2.90", + masterSecret: testvector.HKDFZero, + successfulCalls: 1, }, } @@ -129,13 +139,9 @@ func TestRecover(t *testing.T) { cmd.Flags().String("config", constants.ConfigFilename, "") // register persistent flag manually out := &bytes.Buffer{} cmd.SetOut(out) - cmd.SetErr(&bytes.Buffer{}) - if tc.endpointFlag != "" { - require.NoError(cmd.Flags().Set("endpoint", tc.endpointFlag)) - } - if tc.masterSecretFlag != "" { - require.NoError(cmd.Flags().Set("master-secret", tc.masterSecretFlag)) - } + cmd.SetErr(out) + require.NoError(cmd.Flags().Set("endpoint", tc.endpoint)) + if tc.configFlag != "" { require.NoError(cmd.Flags().Set("config", tc.configFlag)) } @@ -143,7 +149,7 @@ func TestRecover(t *testing.T) { fs := afero.NewMemMapFs() fileHandler := file.NewHandler(fs) - config := defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.FromString(tc.existingState.CloudProvider)) + config := defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.GCP) require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, config)) require.NoError(fileHandler.WriteJSON( @@ -152,62 +158,57 @@ func TestRecover(t *testing.T) { file.OptNone, )) - if !tc.stateless { - require.NoError(fileHandler.WriteJSON( - constants.StateFilename, - tc.existingState, - file.OptNone, - )) - } - - netDialer := testdialer.NewBufconnDialer() - newDialer := func(*cloudcmd.Validator) *dialer.Dialer { - return dialer.New(nil, nil, netDialer) - } - serverCreds := atlscredentials.New(nil, nil) - recoverServer := grpc.NewServer(grpc.Creds(serverCreds)) - recoverproto.RegisterAPIServer(recoverServer, tc.recoverServerAPI) - listener := netDialer.GetListener(net.JoinHostPort("192.0.2.1", strconv.Itoa(constants.RecoveryPort))) - go recoverServer.Serve(listener) - defer recoverServer.GracefulStop() - - err := recover(cmd, fileHandler, newDialer) + newDialer := func(*cloudcmd.Validator) *dialer.Dialer { return nil } + err := recover(cmd, fileHandler, time.Millisecond, tc.doer, newDialer) if tc.wantErr { assert.Error(err) + if tc.successfulCalls > 0 { + assert.Contains(out.String(), strconv.Itoa(tc.successfulCalls)) + } return } assert.NoError(err) - assert.Contains(out.String(), "Pushed recovery key.") + if tc.successfulCalls > 0 { + assert.Contains(out.String(), "Pushed recovery key.") + assert.Contains(out.String(), strconv.Itoa(tc.successfulCalls)) + } else { + assert.Contains(out.String(), "No control-plane nodes in need of recovery found.") + } }) } } func TestParseRecoverFlags(t *testing.T) { testCases := map[string]struct { - args []string - wantFlags recoverFlags - wantErr bool + args []string + wantFlags recoverFlags + writeIDFile bool + wantErr bool }{ "no flags": { - wantErr: true, - }, - "invalid ip": { - args: []string{"-e", "192.0.2.1:2:2"}, - wantErr: true, - }, - "minimal args set": { - args: []string{"-e", "192.0.2.1:2"}, wantFlags: recoverFlags{ - endpoint: "192.0.2.1:2", + endpoint: "192.0.2.42:9999", secretPath: "constellation-mastersecret.json", }, + writeIDFile: true, + }, + "no flags, no ID file": { + wantFlags: recoverFlags{ + endpoint: "192.0.2.42:9999", + secretPath: "constellation-mastersecret.json", + }, + wantErr: true, + }, + "invalid endpoint": { + args: []string{"-e", "192.0.2.42:2:2"}, + wantErr: true, }, "all args set": { - args: []string{"-e", "192.0.2.1:2", "--config", "config-path", "--master-secret", "/path/super-secret.json"}, + args: []string{"-e", "192.0.2.42:2", "--config", "config-path", "--master-secret", "/path/super-secret.json"}, wantFlags: recoverFlags{ - endpoint: "192.0.2.1:2", + endpoint: "192.0.2.42:2", secretPath: "/path/super-secret.json", configPath: "config-path", }, @@ -222,7 +223,13 @@ func TestParseRecoverFlags(t *testing.T) { cmd := NewRecoverCmd() cmd.Flags().String("config", "", "") // register persistent flag manually require.NoError(cmd.ParseFlags(tc.args)) - flags, err := parseRecoverFlags(cmd) + + fileHandler := file.NewHandler(afero.NewMemMapFs()) + if tc.writeIDFile { + require.NoError(fileHandler.WriteJSON(constants.ClusterIDsFileName, &clusterIDsFile{IP: "192.0.2.42"})) + } + + flags, err := parseRecoverFlags(cmd, fileHandler) if tc.wantErr { assert.Error(err) @@ -241,78 +248,94 @@ func TestDoRecovery(t *testing.T) { wantErr bool }{ "success": { - recoveryServer: &stubRecoveryServer{actions: []func(stream recoverproto.API_RecoverServer) error{ - func(stream recoverproto.API_RecoverServer) error { - _, err := stream.Recv() - return err - }, - func(stream recoverproto.API_RecoverServer) error { - return stream.Send(&recoverproto.RecoverResponse{ - DiskUuid: "00000000-0000-0000-0000-000000000000", - }) - }, - func(stream recoverproto.API_RecoverServer) error { - _, err := stream.Recv() - return err - }, - }}, + recoveryServer: &stubRecoveryServer{ + actions: [][]func(stream recoverproto.API_RecoverServer) error{{ + func(stream recoverproto.API_RecoverServer) error { + _, err := stream.Recv() + return err + }, + func(stream recoverproto.API_RecoverServer) error { + return stream.Send(&recoverproto.RecoverResponse{ + DiskUuid: "00000000-0000-0000-0000-000000000000", + }) + }, + func(stream recoverproto.API_RecoverServer) error { + _, err := stream.Recv() + return err + }, + }}, + }, }, "error on first recv": { - recoveryServer: &stubRecoveryServer{actions: []func(stream recoverproto.API_RecoverServer) error{ - func(stream recoverproto.API_RecoverServer) error { - return someErr + recoveryServer: &stubRecoveryServer{ + actions: [][]func(stream recoverproto.API_RecoverServer) error{ + { + func(stream recoverproto.API_RecoverServer) error { + return someErr + }, + }, }, - }}, + }, wantErr: true, }, "error on send": { - recoveryServer: &stubRecoveryServer{actions: []func(stream recoverproto.API_RecoverServer) error{ - func(stream recoverproto.API_RecoverServer) error { - _, err := stream.Recv() - return err + recoveryServer: &stubRecoveryServer{ + actions: [][]func(stream recoverproto.API_RecoverServer) error{ + { + func(stream recoverproto.API_RecoverServer) error { + _, err := stream.Recv() + return err + }, + func(stream recoverproto.API_RecoverServer) error { + return someErr + }, + }, }, - func(stream recoverproto.API_RecoverServer) error { - return someErr - }, - }}, + }, wantErr: true, }, "error on second recv": { - recoveryServer: &stubRecoveryServer{actions: []func(stream recoverproto.API_RecoverServer) error{ - func(stream recoverproto.API_RecoverServer) error { - _, err := stream.Recv() - return err + recoveryServer: &stubRecoveryServer{ + actions: [][]func(stream recoverproto.API_RecoverServer) error{ + { + func(stream recoverproto.API_RecoverServer) error { + _, err := stream.Recv() + return err + }, + func(stream recoverproto.API_RecoverServer) error { + return stream.Send(&recoverproto.RecoverResponse{ + DiskUuid: "00000000-0000-0000-0000-000000000000", + }) + }, + func(stream recoverproto.API_RecoverServer) error { + return someErr + }, + }, }, - func(stream recoverproto.API_RecoverServer) error { - return stream.Send(&recoverproto.RecoverResponse{ - DiskUuid: "00000000-0000-0000-0000-000000000000", - }) - }, - func(stream recoverproto.API_RecoverServer) error { - return someErr - }, - }}, + }, wantErr: true, }, "final message is an error": { - recoveryServer: &stubRecoveryServer{actions: []func(stream recoverproto.API_RecoverServer) error{ - func(stream recoverproto.API_RecoverServer) error { - _, err := stream.Recv() - return err - }, - func(stream recoverproto.API_RecoverServer) error { - return stream.Send(&recoverproto.RecoverResponse{ - DiskUuid: "00000000-0000-0000-0000-000000000000", - }) - }, - func(stream recoverproto.API_RecoverServer) error { - _, err := stream.Recv() - return err - }, - func(stream recoverproto.API_RecoverServer) error { - return someErr - }, - }}, + recoveryServer: &stubRecoveryServer{ + actions: [][]func(stream recoverproto.API_RecoverServer) error{{ + func(stream recoverproto.API_RecoverServer) error { + _, err := stream.Recv() + return err + }, + func(stream recoverproto.API_RecoverServer) error { + return stream.Send(&recoverproto.RecoverResponse{ + DiskUuid: "00000000-0000-0000-0000-000000000000", + }) + }, + func(stream recoverproto.API_RecoverServer) error { + _, err := stream.Recv() + return err + }, + func(stream recoverproto.API_RecoverServer) error { + return someErr + }, + }}, + }, wantErr: true, }, } @@ -325,7 +348,7 @@ func TestDoRecovery(t *testing.T) { serverCreds := atlscredentials.New(nil, nil) recoverServer := grpc.NewServer(grpc.Creds(serverCreds)) recoverproto.RegisterAPIServer(recoverServer, tc.recoveryServer) - addr := net.JoinHostPort("192.0.2.1", strconv.Itoa(constants.RecoveryPort)) + addr := net.JoinHostPort("192.0.42.42", strconv.Itoa(constants.RecoveryPort)) listener := netDialer.GetListener(addr) go recoverServer.Serve(listener) defer recoverServer.GracefulStop() @@ -375,15 +398,39 @@ func TestDeriveStateDiskKey(t *testing.T) { } type stubRecoveryServer struct { - actions []func(recoverproto.API_RecoverServer) error + actions [][]func(recoverproto.API_RecoverServer) error + calls int recoverproto.UnimplementedAPIServer } func (s *stubRecoveryServer) Recover(stream recoverproto.API_RecoverServer) error { - for _, action := range s.actions { + if s.calls >= len(s.actions) { + return status.Error(codes.Unavailable, "server is unavailable") + } + s.calls++ + + for _, action := range s.actions[s.calls-1] { if err := action(stream); err != nil { return err } } return nil } + +type stubDoer struct { + returns []error +} + +func (d *stubDoer) Do(context.Context) error { + err := d.returns[0] + if len(d.returns) > 1 { + d.returns = d.returns[1:] + } else { + d.returns = []error{status.Error(codes.Unavailable, "unavailable")} + } + return err +} + +func (d *stubDoer) setDialer(grpcDialer, string) {} + +func (d *stubDoer) setSecrets(func(string) ([]byte, error), []byte) {} diff --git a/docs/docs/workflows/recovery.md b/docs/docs/workflows/recovery.md index 8698b9d87..24b57500c 100644 --- a/docs/docs/workflows/recovery.md +++ b/docs/docs/workflows/recovery.md @@ -51,7 +51,7 @@ If that fails, because the control plane is unhealthy, you will see log messages {"level":"ERROR","ts":"2022-09-08T09:57:23Z","logger":"rejoinClient","caller":"rejoinclient/client.go:110","msg":"Failed to rejoin on all endpoints"} ``` -This means that you have to recover the node manually. For this, you need its IP address, which can be obtained from the *Overview* page under *Private IP address*. +This means that you have to recover the node manually. @@ -88,33 +88,26 @@ If that fails, because the control plane is unhealthy, you will see log messages {"level":"ERROR","ts":"2022-09-08T10:22:13Z","logger":"rejoinClient","caller":"rejoinclient/client.go:110","msg":"Failed to rejoin on all endpoints"} ``` -This means that you have to recover the node manually. For this, you need its IP address, which can be obtained from the *"VM Instance" -> "network interfaces"* page under *"Primary internal IP address."* +This means that you have to recover the node manually. ## Recover your cluster -The following process needs to be repeated until a [member quorum for etcd](https://etcd.io/docs/v3.5/faq/#what-is-failure-tolerance) is established. -For example, assume you have 5 control-plane nodes in your cluster and 4 of them have been rebooted due to a maintenance downtime in the cloud environment. -You have to run through the following process for 2 of these nodes and recover them manually to recover the quorum. -From there, your cluster will auto heal the remaining 2 control-plane nodes and the rest of your cluster. +Recovering a cluster requires the following parameters: -Recovering a node requires the following parameters: - -* The node's IP address +* The `constellation-id.json` file in your working directory or the cluster's load balancer IP address * Access to the master secret of the cluster -See the [Identify unhealthy clusters](#identify-unhealthy-clusters) description of how to obtain the node's IP address. -Note that the recovery command needs to connect to the recovering nodes. -Nodes only have private IP addresses in the VPC of the cluster, hence, the command needs to be issued from within the VPC network of the cluster. -The easiest approach is to set up a jump host connected to the VPC network and perform the recovery from there. +A cluster can be recovered like this: -Given these prerequisites a node can be recovered like this: - -``` -$ constellation recover -e 34.107.89.208 --master-secret constellation-mastersecret.json +```bash +$ constellation recover --master-secret constellation-mastersecret.json Pushed recovery key. +Pushed recovery key. +Pushed recovery key. +Recovered 3 control-plane nodes. ``` In the serial console output of the node you'll see a similar output to the following: diff --git a/internal/grpc/retry/retry.go b/internal/grpc/retry/retry.go index c1922397b..48da4b1dd 100644 --- a/internal/grpc/retry/retry.go +++ b/internal/grpc/retry/retry.go @@ -14,15 +14,54 @@ import ( "google.golang.org/grpc/status" ) +const ( + authEOFErr = `connection error: desc = "transport: authentication handshake failed: EOF"` + authReadTCPErr = `connection error: desc = "transport: authentication handshake failed: read tcp` + authHandshakeErr = `connection error: desc = "transport: authentication handshake failed` +) + +// grpcErr is the error type that is returned by the grpc client. +// taken from google.golang.org/grpc/status.FromError. +type grpcErr interface { + GRPCStatus() *status.Status + Error() string +} + // ServiceIsUnavailable checks if the error is a grpc status with code Unavailable. // In the special case of an authentication handshake failure, false is returned to prevent further retries. +// Since the GCP proxy loadbalancer may error with an authentication handshake failure if no available backends are ready, +// the special handshake errors caused by the GCP LB (e.g. "read tcp", "EOF") are retried. func ServiceIsUnavailable(err error) bool { - // taken from google.golang.org/grpc/status.FromError - var targetErr interface { - GRPCStatus() *status.Status - Error() string + var targetErr grpcErr + if !errors.As(err, &targetErr) { + return false } + statusErr, ok := status.FromError(targetErr) + if !ok { + return false + } + + if statusErr.Code() != codes.Unavailable { + return false + } + + // retry if GCP proxy LB isn't available + if strings.HasPrefix(statusErr.Message(), authEOFErr) { + return true + } + + // retry if GCP proxy LB isn't fully available yet + if strings.HasPrefix(statusErr.Message(), authReadTCPErr) { + return true + } + + return !strings.HasPrefix(statusErr.Message(), authHandshakeErr) +} + +// LoadbalancerIsNotReady checks if the error was caused by a GCP LB not being ready yet. +func LoadbalancerIsNotReady(err error) bool { + var targetErr grpcErr if !errors.As(err, &targetErr) { return false } @@ -37,15 +76,5 @@ func ServiceIsUnavailable(err error) bool { } // retry if GCP proxy LB isn't fully available yet - if strings.HasPrefix(statusErr.Message(), `connection error: desc = "transport: authentication handshake failed: EOF"`) { - return true - } - - // retry if GCP proxy LB isn't fully available yet - if strings.HasPrefix(statusErr.Message(), `connection error: desc = "transport: authentication handshake failed: read tcp`) { - return true - } - - // ideally we would check the error type directly, but grpc only provides a string - return !strings.HasPrefix(statusErr.Message(), `connection error: desc = "transport: authentication handshake failed`) + return strings.HasPrefix(statusErr.Message(), authReadTCPErr) } diff --git a/internal/grpc/retry/retry_test.go b/internal/grpc/retry/retry_test.go index 7ed75dbf6..a1b44dce4 100644 --- a/internal/grpc/retry/retry_test.go +++ b/internal/grpc/retry/retry_test.go @@ -29,12 +29,20 @@ func TestServiceIsUnavailable(t *testing.T) { err: status.Error(codes.Internal, "error"), }, "unavailable error with authentication handshake failure": { - err: status.Error(codes.Unavailable, `connection error: desc = "transport: authentication handshake failed`), + err: status.Error(codes.Unavailable, `connection error: desc = "transport: authentication handshake failed: bad certificate"`), }, "normal unavailable error": { err: status.Error(codes.Unavailable, "error"), wantUnavailable: true, }, + "handshake EOF error": { + err: status.Error(codes.Unavailable, `connection error: desc = "transport: authentication handshake failed: EOF"`), + wantUnavailable: true, + }, + "handshake read tcp error": { + err: status.Error(codes.Unavailable, `connection error: desc = "transport: authentication handshake failed: read tcp error"`), + wantUnavailable: true, + }, "wrapped error": { err: fmt.Errorf("some wrapping: %w", status.Error(codes.Unavailable, "error")), wantUnavailable: true, @@ -51,3 +59,44 @@ func TestServiceIsUnavailable(t *testing.T) { }) } } + +func TestLoadbalancerIsNotReady(t *testing.T) { + testCases := map[string]struct { + err error + wantNotReady bool + }{ + "nil": {}, + "not status error": { + err: errors.New("error"), + }, + "not unavailable": { + err: status.Error(codes.Internal, "error"), + }, + "unavailable error with authentication handshake failure": { + err: status.Error(codes.Unavailable, `connection error: desc = "transport: authentication handshake failed: bad certificate"`), + }, + "handshake EOF error": { + err: status.Error(codes.Unavailable, `connection error: desc = "transport: authentication handshake failed: EOF"`), + }, + "handshake read tcp error": { + err: status.Error(codes.Unavailable, `connection error: desc = "transport: authentication handshake failed: read tcp error"`), + wantNotReady: true, + }, + "normal unavailable error": { + err: status.Error(codes.Unavailable, "error"), + }, + "wrapped error": { + err: fmt.Errorf("some wrapping: %w", status.Error(codes.Unavailable, "error")), + }, + "code unknown": { + err: status.Error(codes.Unknown, "unknown"), + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + assert.Equal(tc.wantNotReady, LoadbalancerIsNotReady(tc.err)) + }) + } +}