AB#2262 Automatic recovery (#158)

* Update `constellation recover` to be fully automated

* Update recovery docs

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2022-09-26 09:57:40 +02:00 committed by GitHub
parent 4f596cde3c
commit 30f0554168
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 408 additions and 240 deletions

View file

@ -12,6 +12,7 @@ import (
"fmt"
"io"
"net"
"sync"
"time"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
@ -24,10 +25,8 @@ import (
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
grpcRetry "github.com/edgelesssys/constellation/v2/internal/grpc/retry"
"github.com/edgelesssys/constellation/v2/internal/retry"
"github.com/edgelesssys/constellation/v2/internal/state"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"go.uber.org/multierr"
)
// NewRecoverCmd returns a new cobra.Command for the recover command.
@ -40,8 +39,7 @@ func NewRecoverCmd() *cobra.Command {
Args: cobra.ExactArgs(0),
RunE: runRecover,
}
cmd.Flags().StringP("endpoint", "e", "", "endpoint of the instance, passed as HOST[:PORT] (required)")
must(cmd.MarkFlagRequired("endpoint"))
cmd.Flags().StringP("endpoint", "e", "", "endpoint of the instance, passed as HOST[:PORT]")
cmd.Flags().String("master-secret", constants.MasterSecretFilename, "path to master secret file")
return cmd
}
@ -51,11 +49,14 @@ func runRecover(cmd *cobra.Command, _ []string) error {
newDialer := func(validator *cloudcmd.Validator) *dialer.Dialer {
return dialer.New(nil, validator.V(cmd), &net.Dialer{})
}
return recover(cmd, fileHandler, newDialer)
return recover(cmd, fileHandler, 5*time.Second, &recoverDoer{}, newDialer)
}
func recover(cmd *cobra.Command, fileHandler file.Handler, newDialer func(validator *cloudcmd.Validator) *dialer.Dialer) error {
flags, err := parseRecoverFlags(cmd)
func recover(
cmd *cobra.Command, fileHandler file.Handler, interval time.Duration,
doer recoverDoerInterface, newDialer func(validator *cloudcmd.Validator) *dialer.Dialer,
) error {
flags, err := parseRecoverFlags(cmd, fileHandler)
if err != nil {
return err
}
@ -65,48 +66,81 @@ func recover(cmd *cobra.Command, fileHandler file.Handler, newDialer func(valida
return err
}
var stat state.ConstellationState
if err := fileHandler.ReadJSON(constants.StateFilename, &stat); err != nil {
return err
}
provider := cloudprovider.FromString(stat.CloudProvider)
config, err := readConfig(cmd.OutOrStdout(), fileHandler, flags.configPath)
if err != nil {
return fmt.Errorf("reading and validating config: %w", err)
}
provider := config.GetProvider()
if provider == cloudprovider.Azure {
interval = 20 * time.Second // Azure LB takes a while to remove unhealthy instances
}
validator, err := cloudcmd.NewValidator(provider, config)
if err != nil {
return err
}
doer.setDialer(newDialer(validator), flags.endpoint)
if err := recoverCall(cmd.Context(), newDialer(validator), flags.endpoint, masterSecret.Key, masterSecret.Salt); err != nil {
return fmt.Errorf("recovering cluster: %w", err)
}
cmd.Println("Pushed recovery key.")
return nil
}
func recoverCall(ctx context.Context, dialer grpcDialer, endpoint string, key, salt []byte) error {
measurementSecret, err := attestation.DeriveMeasurementSecret(key, salt)
measurementSecret, err := attestation.DeriveMeasurementSecret(masterSecret.Key, masterSecret.Salt)
if err != nil {
return err
}
doer := &recoverDoer{
dialer: dialer,
endpoint: endpoint,
getDiskKey: getStateDiskKeyFunc(key, salt),
measurementSecret: measurementSecret,
}
retrier := retry.NewIntervalRetrier(doer, 30*time.Second, grpcRetry.ServiceIsUnavailable)
if err := retrier.Do(ctx); err != nil {
return err
doer.setSecrets(getStateDiskKeyFunc(masterSecret.Key, masterSecret.Salt), measurementSecret)
if err := recoverCall(cmd.Context(), cmd.OutOrStdout(), interval, doer); err != nil {
if grpcRetry.ServiceIsUnavailable(err) {
return nil
}
return fmt.Errorf("recovering cluster: %w", err)
}
return nil
}
func recoverCall(ctx context.Context, out io.Writer, interval time.Duration, doer recoverDoerInterface) error {
var err error
ctr := 0
for {
once := sync.Once{}
retryOnceOnFailure := func(err error) bool {
// retry transient GCP LB errors
if grpcRetry.LoadbalancerIsNotReady(err) {
return true
}
retry := false
// retry connection errors once
// this is necessary because Azure's LB takes a while to remove unhealthy instances
once.Do(func() {
retry = grpcRetry.ServiceIsUnavailable(err)
})
return retry
}
retrier := retry.NewIntervalRetrier(doer, interval, retryOnceOnFailure)
err = retrier.Do(ctx)
if err != nil {
break
}
fmt.Fprintln(out, "Pushed recovery key.")
ctr++
}
if ctr > 0 {
fmt.Fprintf(out, "Recovered %d control-plane nodes.\n", ctr)
} else if grpcRetry.ServiceIsUnavailable(err) {
fmt.Fprintln(out, "No control-plane nodes in need of recovery found. Exiting.")
return nil
}
return err
}
type recoverDoerInterface interface {
Do(ctx context.Context) error
setDialer(dialer grpcDialer, endpoint string)
setSecrets(getDiskKey func(uuid string) ([]byte, error), measurementSecret []byte)
}
type recoverDoer struct {
dialer grpcDialer
endpoint string
@ -114,6 +148,7 @@ type recoverDoer struct {
getDiskKey func(uuid string) (key []byte, err error)
}
// Do performs the recover streaming rpc.
func (d *recoverDoer) Do(ctx context.Context) (retErr error) {
conn, err := d.dialer.Dial(ctx, d.endpoint)
if err != nil {
@ -125,12 +160,10 @@ func (d *recoverDoer) Do(ctx context.Context) (retErr error) {
protoClient := recoverproto.NewAPIClient(conn)
recoverclient, err := protoClient.Recover(ctx)
if err != nil {
return err
return fmt.Errorf("creating client: %w", err)
}
defer func() {
if err := recoverclient.CloseSend(); err != nil {
multierr.AppendInto(&retErr, err)
}
_ = recoverclient.CloseSend()
}()
// send measurement secret as first message
@ -139,17 +172,17 @@ func (d *recoverDoer) Do(ctx context.Context) (retErr error) {
MeasurementSecret: d.measurementSecret,
},
}); err != nil {
return err
return fmt.Errorf("sending measurement secret: %w", err)
}
// receive disk uuid
res, err := recoverclient.Recv()
if err != nil {
return err
return fmt.Errorf("receiving disk uuid: %w", err)
}
stateDiskKey, err := d.getDiskKey(res.DiskUuid)
if err != nil {
return err
return fmt.Errorf("getting state disk key: %w", err)
}
// send disk key
@ -158,20 +191,42 @@ func (d *recoverDoer) Do(ctx context.Context) (retErr error) {
StateDiskKey: stateDiskKey,
},
}); err != nil {
return err
return fmt.Errorf("sending state disk key: %w", err)
}
if _, err := recoverclient.Recv(); err != nil && !errors.Is(err, io.EOF) {
return err
return fmt.Errorf("receiving confirmation: %w", err)
}
return nil
}
func parseRecoverFlags(cmd *cobra.Command) (recoverFlags, error) {
func (d *recoverDoer) setDialer(dialer grpcDialer, endpoint string) {
d.dialer = dialer
d.endpoint = endpoint
}
func (d *recoverDoer) setSecrets(getDiskKey func(string) ([]byte, error), measurementSecret []byte) {
d.getDiskKey = getDiskKey
d.measurementSecret = measurementSecret
}
type recoverFlags struct {
endpoint string
secretPath string
configPath string
}
func parseRecoverFlags(cmd *cobra.Command, fileHandler file.Handler) (recoverFlags, error) {
endpoint, err := cmd.Flags().GetString("endpoint")
if err != nil {
return recoverFlags{}, fmt.Errorf("parsing endpoint argument: %w", err)
}
if endpoint == "" {
endpoint, err = readIPFromIDFile(fileHandler)
if err != nil {
return recoverFlags{}, fmt.Errorf("getting recovery endpoint: %w", err)
}
}
endpoint, err = addPortIfMissing(endpoint, constants.RecoveryPort)
if err != nil {
return recoverFlags{}, fmt.Errorf("validating endpoint argument: %w", err)
@ -194,12 +249,6 @@ func parseRecoverFlags(cmd *cobra.Command) (recoverFlags, error) {
}, nil
}
type recoverFlags struct {
endpoint string
secretPath string
configPath string
}
func getStateDiskKeyFunc(masterKey, salt []byte) func(uuid string) ([]byte, error) {
return func(uuid string) ([]byte, error) {
return crypto.DeriveKey(masterKey, salt, []byte(crypto.HKDFInfoPrefix+uuid), crypto.StateDiskKeyLength)