mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-08-05 21:44:15 -04:00
AB#2262 Automatic recovery (#158)
* Update `constellation recover` to be fully automated * Update recovery docs Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
parent
4f596cde3c
commit
30f0554168
6 changed files with 408 additions and 240 deletions
|
@ -12,6 +12,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
|
||||
|
@ -24,10 +25,8 @@ import (
|
|||
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
|
||||
grpcRetry "github.com/edgelesssys/constellation/v2/internal/grpc/retry"
|
||||
"github.com/edgelesssys/constellation/v2/internal/retry"
|
||||
"github.com/edgelesssys/constellation/v2/internal/state"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
"go.uber.org/multierr"
|
||||
)
|
||||
|
||||
// NewRecoverCmd returns a new cobra.Command for the recover command.
|
||||
|
@ -40,8 +39,7 @@ func NewRecoverCmd() *cobra.Command {
|
|||
Args: cobra.ExactArgs(0),
|
||||
RunE: runRecover,
|
||||
}
|
||||
cmd.Flags().StringP("endpoint", "e", "", "endpoint of the instance, passed as HOST[:PORT] (required)")
|
||||
must(cmd.MarkFlagRequired("endpoint"))
|
||||
cmd.Flags().StringP("endpoint", "e", "", "endpoint of the instance, passed as HOST[:PORT]")
|
||||
cmd.Flags().String("master-secret", constants.MasterSecretFilename, "path to master secret file")
|
||||
return cmd
|
||||
}
|
||||
|
@ -51,11 +49,14 @@ func runRecover(cmd *cobra.Command, _ []string) error {
|
|||
newDialer := func(validator *cloudcmd.Validator) *dialer.Dialer {
|
||||
return dialer.New(nil, validator.V(cmd), &net.Dialer{})
|
||||
}
|
||||
return recover(cmd, fileHandler, newDialer)
|
||||
return recover(cmd, fileHandler, 5*time.Second, &recoverDoer{}, newDialer)
|
||||
}
|
||||
|
||||
func recover(cmd *cobra.Command, fileHandler file.Handler, newDialer func(validator *cloudcmd.Validator) *dialer.Dialer) error {
|
||||
flags, err := parseRecoverFlags(cmd)
|
||||
func recover(
|
||||
cmd *cobra.Command, fileHandler file.Handler, interval time.Duration,
|
||||
doer recoverDoerInterface, newDialer func(validator *cloudcmd.Validator) *dialer.Dialer,
|
||||
) error {
|
||||
flags, err := parseRecoverFlags(cmd, fileHandler)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -65,48 +66,81 @@ func recover(cmd *cobra.Command, fileHandler file.Handler, newDialer func(valida
|
|||
return err
|
||||
}
|
||||
|
||||
var stat state.ConstellationState
|
||||
if err := fileHandler.ReadJSON(constants.StateFilename, &stat); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
provider := cloudprovider.FromString(stat.CloudProvider)
|
||||
config, err := readConfig(cmd.OutOrStdout(), fileHandler, flags.configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading and validating config: %w", err)
|
||||
}
|
||||
provider := config.GetProvider()
|
||||
if provider == cloudprovider.Azure {
|
||||
interval = 20 * time.Second // Azure LB takes a while to remove unhealthy instances
|
||||
}
|
||||
|
||||
validator, err := cloudcmd.NewValidator(provider, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
doer.setDialer(newDialer(validator), flags.endpoint)
|
||||
|
||||
if err := recoverCall(cmd.Context(), newDialer(validator), flags.endpoint, masterSecret.Key, masterSecret.Salt); err != nil {
|
||||
return fmt.Errorf("recovering cluster: %w", err)
|
||||
}
|
||||
|
||||
cmd.Println("Pushed recovery key.")
|
||||
return nil
|
||||
}
|
||||
|
||||
func recoverCall(ctx context.Context, dialer grpcDialer, endpoint string, key, salt []byte) error {
|
||||
measurementSecret, err := attestation.DeriveMeasurementSecret(key, salt)
|
||||
measurementSecret, err := attestation.DeriveMeasurementSecret(masterSecret.Key, masterSecret.Salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
doer := &recoverDoer{
|
||||
dialer: dialer,
|
||||
endpoint: endpoint,
|
||||
getDiskKey: getStateDiskKeyFunc(key, salt),
|
||||
measurementSecret: measurementSecret,
|
||||
}
|
||||
retrier := retry.NewIntervalRetrier(doer, 30*time.Second, grpcRetry.ServiceIsUnavailable)
|
||||
if err := retrier.Do(ctx); err != nil {
|
||||
return err
|
||||
doer.setSecrets(getStateDiskKeyFunc(masterSecret.Key, masterSecret.Salt), measurementSecret)
|
||||
|
||||
if err := recoverCall(cmd.Context(), cmd.OutOrStdout(), interval, doer); err != nil {
|
||||
if grpcRetry.ServiceIsUnavailable(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("recovering cluster: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func recoverCall(ctx context.Context, out io.Writer, interval time.Duration, doer recoverDoerInterface) error {
|
||||
var err error
|
||||
ctr := 0
|
||||
for {
|
||||
once := sync.Once{}
|
||||
retryOnceOnFailure := func(err error) bool {
|
||||
// retry transient GCP LB errors
|
||||
if grpcRetry.LoadbalancerIsNotReady(err) {
|
||||
return true
|
||||
}
|
||||
retry := false
|
||||
|
||||
// retry connection errors once
|
||||
// this is necessary because Azure's LB takes a while to remove unhealthy instances
|
||||
once.Do(func() {
|
||||
retry = grpcRetry.ServiceIsUnavailable(err)
|
||||
})
|
||||
return retry
|
||||
}
|
||||
|
||||
retrier := retry.NewIntervalRetrier(doer, interval, retryOnceOnFailure)
|
||||
err = retrier.Do(ctx)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
fmt.Fprintln(out, "Pushed recovery key.")
|
||||
ctr++
|
||||
}
|
||||
|
||||
if ctr > 0 {
|
||||
fmt.Fprintf(out, "Recovered %d control-plane nodes.\n", ctr)
|
||||
} else if grpcRetry.ServiceIsUnavailable(err) {
|
||||
fmt.Fprintln(out, "No control-plane nodes in need of recovery found. Exiting.")
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
type recoverDoerInterface interface {
|
||||
Do(ctx context.Context) error
|
||||
setDialer(dialer grpcDialer, endpoint string)
|
||||
setSecrets(getDiskKey func(uuid string) ([]byte, error), measurementSecret []byte)
|
||||
}
|
||||
|
||||
type recoverDoer struct {
|
||||
dialer grpcDialer
|
||||
endpoint string
|
||||
|
@ -114,6 +148,7 @@ type recoverDoer struct {
|
|||
getDiskKey func(uuid string) (key []byte, err error)
|
||||
}
|
||||
|
||||
// Do performs the recover streaming rpc.
|
||||
func (d *recoverDoer) Do(ctx context.Context) (retErr error) {
|
||||
conn, err := d.dialer.Dial(ctx, d.endpoint)
|
||||
if err != nil {
|
||||
|
@ -125,12 +160,10 @@ func (d *recoverDoer) Do(ctx context.Context) (retErr error) {
|
|||
protoClient := recoverproto.NewAPIClient(conn)
|
||||
recoverclient, err := protoClient.Recover(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("creating client: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := recoverclient.CloseSend(); err != nil {
|
||||
multierr.AppendInto(&retErr, err)
|
||||
}
|
||||
_ = recoverclient.CloseSend()
|
||||
}()
|
||||
|
||||
// send measurement secret as first message
|
||||
|
@ -139,17 +172,17 @@ func (d *recoverDoer) Do(ctx context.Context) (retErr error) {
|
|||
MeasurementSecret: d.measurementSecret,
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("sending measurement secret: %w", err)
|
||||
}
|
||||
|
||||
// receive disk uuid
|
||||
res, err := recoverclient.Recv()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("receiving disk uuid: %w", err)
|
||||
}
|
||||
stateDiskKey, err := d.getDiskKey(res.DiskUuid)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("getting state disk key: %w", err)
|
||||
}
|
||||
|
||||
// send disk key
|
||||
|
@ -158,20 +191,42 @@ func (d *recoverDoer) Do(ctx context.Context) (retErr error) {
|
|||
StateDiskKey: stateDiskKey,
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("sending state disk key: %w", err)
|
||||
}
|
||||
|
||||
if _, err := recoverclient.Recv(); err != nil && !errors.Is(err, io.EOF) {
|
||||
return err
|
||||
return fmt.Errorf("receiving confirmation: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseRecoverFlags(cmd *cobra.Command) (recoverFlags, error) {
|
||||
func (d *recoverDoer) setDialer(dialer grpcDialer, endpoint string) {
|
||||
d.dialer = dialer
|
||||
d.endpoint = endpoint
|
||||
}
|
||||
|
||||
func (d *recoverDoer) setSecrets(getDiskKey func(string) ([]byte, error), measurementSecret []byte) {
|
||||
d.getDiskKey = getDiskKey
|
||||
d.measurementSecret = measurementSecret
|
||||
}
|
||||
|
||||
type recoverFlags struct {
|
||||
endpoint string
|
||||
secretPath string
|
||||
configPath string
|
||||
}
|
||||
|
||||
func parseRecoverFlags(cmd *cobra.Command, fileHandler file.Handler) (recoverFlags, error) {
|
||||
endpoint, err := cmd.Flags().GetString("endpoint")
|
||||
if err != nil {
|
||||
return recoverFlags{}, fmt.Errorf("parsing endpoint argument: %w", err)
|
||||
}
|
||||
if endpoint == "" {
|
||||
endpoint, err = readIPFromIDFile(fileHandler)
|
||||
if err != nil {
|
||||
return recoverFlags{}, fmt.Errorf("getting recovery endpoint: %w", err)
|
||||
}
|
||||
}
|
||||
endpoint, err = addPortIfMissing(endpoint, constants.RecoveryPort)
|
||||
if err != nil {
|
||||
return recoverFlags{}, fmt.Errorf("validating endpoint argument: %w", err)
|
||||
|
@ -194,12 +249,6 @@ func parseRecoverFlags(cmd *cobra.Command) (recoverFlags, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
type recoverFlags struct {
|
||||
endpoint string
|
||||
secretPath string
|
||||
configPath string
|
||||
}
|
||||
|
||||
func getStateDiskKeyFunc(masterKey, salt []byte) func(uuid string) ([]byte, error) {
|
||||
return func(uuid string) ([]byte, error) {
|
||||
return crypto.DeriveKey(masterKey, salt, []byte(crypto.HKDFInfoPrefix+uuid), crypto.StateDiskKeyLength)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue