bootstrapper: wipe disk and reboot on non-recoverable error (#2971)

* Let JoinClient return fatal errors
* Mark disk for wiping if JoinClient or InitServer return errors
* Reboot system if bootstrapper detects an error
* Refactor joinClient start/stop implementation
* Fix joining nodes retrying kubeadm 3 times in all cases
* Write non-recoverable failures to syslog before rebooting
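
The failure path these bullets describe boils down to: write the fatal error to syslog, flag the state disk for wiping, sync, and reboot. The sketch below is illustrative only — `handleNonRecoverable` and `markDiskForWipe` are hypothetical names, not code from this commit (the real logic is spread across the changed files):

// Hedged sketch of the non-recoverable error path; not code from this commit.
package main

import (
	"log/syslog"

	"golang.org/x/sys/unix"
)

// handleNonRecoverable writes the fatal error to syslog, marks the state disk
// for wiping on the next boot, and reboots the node.
// markDiskForWipe is a hypothetical callback standing in for however the
// bootstrapper actually flags the disk.
func handleNonRecoverable(fatal error, markDiskForWipe func() error) {
	// Log to syslog first: once the disk is wiped, local state is gone,
	// so this is the only surviving trace of the failure.
	if w, err := syslog.New(syslog.LOG_ERR, "bootstrapper"); err == nil {
		_ = w.Err("non-recoverable error, wiping disk and rebooting: " + fatal.Error())
		_ = w.Close()
	}

	// Mark the disk so the next boot wipes it instead of reusing
	// half-initialized state.
	_ = markDiskForWipe()

	// Flush filesystem buffers, then restart the machine.
	unix.Sync()
	_ = unix.Reboot(unix.LINUX_REBOOT_CMD_RESTART)
}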

---------

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
Daniel Weiße, 2024-03-12 11:43:38 +01:00 (committed by GitHub)
commit 1077b7a48e, parent 1b973bf23f
10 changed files with 199 additions and 220 deletions

@@ -25,11 +25,9 @@ import (
 	"net"
 	"path/filepath"
 	"strconv"
-	"sync"
 	"time"
 
 	"github.com/edgelesssys/constellation/v2/bootstrapper/internal/certificate"
-	"github.com/edgelesssys/constellation/v2/bootstrapper/internal/diskencryption"
 	"github.com/edgelesssys/constellation/v2/internal/attestation"
 	"github.com/edgelesssys/constellation/v2/internal/cloud/metadata"
 	"github.com/edgelesssys/constellation/v2/internal/constants"
@@ -69,21 +67,19 @@ type JoinClient struct {
 	dialer      grpcDialer
 	joiner      ClusterJoiner
-	cleaner     cleaner
 	metadataAPI MetadataAPI
 
 	log *slog.Logger
 
-	mux      sync.Mutex
 	stopC    chan struct{}
 	stopDone chan struct{}
 }
 
 // New creates a new JoinClient.
-func New(lock locker, dial grpcDialer, joiner ClusterJoiner, meta MetadataAPI, log *slog.Logger) *JoinClient {
+func New(lock locker, dial grpcDialer, joiner ClusterJoiner, meta MetadataAPI, disk encryptedDisk, log *slog.Logger) *JoinClient {
 	return &JoinClient{
 		nodeLock:    lock,
-		disk:        diskencryption.New(),
+		disk:        disk,
 		fileHandler: file.NewHandler(afero.NewOsFs()),
 		timeout:     timeout,
 		joinTimeout: joinTimeout,
@@ -93,99 +89,83 @@ func New(lock locker, dial grpcDialer, joiner ClusterJoiner, meta MetadataAPI, l
 		joiner:      joiner,
 		metadataAPI: meta,
 		log:         log.WithGroup("join-client"),
+
+		stopC:    make(chan struct{}, 1),
+		stopDone: make(chan struct{}, 1),
 	}
 }
 
 // Start starts the client routine. The client will make the needed API calls to join
 // the cluster with the role it receives from the metadata API.
 // After receiving the needed information, the node will join the cluster.
-// Multiple calls of start on the same client won't start a second routine if there is
-// already a routine running.
-func (c *JoinClient) Start(cleaner cleaner) {
-	c.mux.Lock()
-	defer c.mux.Unlock()
+func (c *JoinClient) Start(cleaner cleaner) error {
+	c.log.Info("Starting")
+	ticker := c.clock.NewTicker(c.interval)
+	defer ticker.Stop()
+	defer func() { c.stopDone <- struct{}{} }()
+	defer c.log.Info("Client stopped")
 
-	if c.stopC != nil { // daemon already running
-		return
+	diskUUID, err := c.getDiskUUID()
+	if err != nil {
+		c.log.With(slog.Any("error", err)).Error("Failed to get disk UUID")
+		return err
 	}
+	c.diskUUID = diskUUID
 
-	c.log.Info("Starting")
-	c.stopC = make(chan struct{}, 1)
-	c.stopDone = make(chan struct{}, 1)
-	c.cleaner = cleaner
+	for {
+		err := c.getNodeMetadata()
+		if err == nil {
+			c.log.With(slog.String("role", c.role.String()), slog.String("name", c.nodeName)).Info("Received own instance metadata")
+			break
+		}
+		c.log.With(slog.Any("error", err)).Error("Failed to retrieve instance metadata")
 
-	ticker := c.clock.NewTicker(c.interval)
-	go func() {
-		defer ticker.Stop()
-		defer func() { c.stopDone <- struct{}{} }()
-		defer c.log.Info("Client stopped")
+		c.log.With(slog.Duration("interval", c.interval)).Info("Sleeping")
+		select {
+		case <-c.stopC:
+			return nil
+		case <-ticker.C():
+		}
+	}
 
-		diskUUID, err := c.getDiskUUID()
-		if err != nil {
-			c.log.With(slog.Any("error", err)).Error("Failed to get disk UUID")
-			return
-		}
-		c.diskUUID = diskUUID
+	var ticket *joinproto.IssueJoinTicketResponse
+	var kubeletKey []byte
 
-		for {
-			err := c.getNodeMetadata()
-			if err == nil {
-				c.log.With(slog.String("role", c.role.String()), slog.String("name", c.nodeName)).Info("Received own instance metadata")
-				break
-			}
-			c.log.With(slog.Any("error", err)).Error("Failed to retrieve instance metadata")
+	for {
+		ticket, kubeletKey, err = c.tryJoinWithAvailableServices()
+		if err == nil {
+			c.log.Info("Successfully retrieved join ticket, starting Kubernetes node")
+			break
+		}
+		c.log.With(slog.Any("error", err)).Warn("Join failed for all available endpoints")
 
-			c.log.With(slog.Duration("interval", c.interval)).Info("Sleeping")
-			select {
-			case <-c.stopC:
-				return
-			case <-ticker.C():
-			}
-		}
+		c.log.With(slog.Duration("interval", c.interval)).Info("Sleeping")
+		select {
+		case <-c.stopC:
+			return nil
+		case <-ticker.C():
+		}
+	}
 
-		for {
-			err := c.tryJoinWithAvailableServices()
-			if err == nil {
-				c.log.Info("Joined successfully. Client is shutting down")
-				return
-			} else if isUnrecoverable(err) {
-				c.log.With(slog.Any("error", err)).Error("Unrecoverable error occurred")
-				// TODO(burgerdev): this should eventually lead to a full node reset
-				return
-			}
-			c.log.With(slog.Any("error", err)).Warn("Join failed for all available endpoints")
+	if err := c.startNodeAndJoin(ticket, kubeletKey, cleaner); err != nil {
+		c.log.With(slog.Any("error", err)).Error("Failed to start node and join cluster")
+		return err
+	}
 
-			c.log.With(slog.Duration("interval", c.interval)).Info("Sleeping")
-			select {
-			case <-c.stopC:
-				return
-			case <-ticker.C():
-			}
-		}
-	}()
+	return nil
 }
 
 // Stop stops the client and blocks until the client's routine is stopped.
 func (c *JoinClient) Stop() {
-	c.mux.Lock()
-	defer c.mux.Unlock()
-
-	if c.stopC == nil { // daemon not running
-		return
-	}
-
 	c.log.Info("Stopping")
 	c.stopC <- struct{}{}
 	<-c.stopDone
-	c.stopC = nil
-	c.stopDone = nil
 
 	c.log.Info("Stopped")
 }
 
-func (c *JoinClient) tryJoinWithAvailableServices() error {
+func (c *JoinClient) tryJoinWithAvailableServices() (ticket *joinproto.IssueJoinTicketResponse, kubeletKey []byte, err error) {
 	ctx, cancel := c.timeoutCtx()
 	defer cancel()
@@ -193,46 +173,46 @@ func (c *JoinClient) tryJoinWithAvailableServices() error {
 	endpoint, _, err := c.metadataAPI.GetLoadBalancerEndpoint(ctx)
 	if err != nil {
-		return fmt.Errorf("failed to get load balancer endpoint: %w", err)
+		return nil, nil, fmt.Errorf("failed to get load balancer endpoint: %w", err)
 	}
 	endpoints = append(endpoints, endpoint)
 
 	ips, err := c.getControlPlaneIPs(ctx)
 	if err != nil {
-		return fmt.Errorf("failed to get control plane IPs: %w", err)
+		return nil, nil, fmt.Errorf("failed to get control plane IPs: %w", err)
 	}
 	endpoints = append(endpoints, ips...)
 
 	if len(endpoints) == 0 {
-		return errors.New("no control plane IPs found")
+		return nil, nil, errors.New("no control plane IPs found")
 	}
 
+	var joinErrs error
 	for _, endpoint := range endpoints {
-		err = c.join(net.JoinHostPort(endpoint, strconv.Itoa(constants.JoinServiceNodePort)))
+		ticket, kubeletKey, err := c.requestJoinTicket(net.JoinHostPort(endpoint, strconv.Itoa(constants.JoinServiceNodePort)))
 		if err == nil {
-			return nil
-		}
-		if isUnrecoverable(err) {
-			return err
+			return ticket, kubeletKey, nil
 		}
+
+		joinErrs = errors.Join(joinErrs, err)
 	}
 
-	return err
+	return nil, nil, fmt.Errorf("trying to join on all endpoints %v: %w", endpoints, joinErrs)
 }
 
-func (c *JoinClient) join(serviceEndpoint string) error {
+func (c *JoinClient) requestJoinTicket(serviceEndpoint string) (ticket *joinproto.IssueJoinTicketResponse, kubeletKey []byte, err error) {
 	ctx, cancel := c.timeoutCtx()
 	defer cancel()
 
 	certificateRequest, kubeletKey, err := certificate.GetKubeletCertificateRequest(c.nodeName, c.validIPs)
 	if err != nil {
-		return err
+		return nil, nil, err
 	}
 
 	conn, err := c.dialer.Dial(ctx, serviceEndpoint)
 	if err != nil {
 		c.log.With(slog.String("endpoint", serviceEndpoint), slog.Any("error", err)).Error("Join service unreachable")
-		return fmt.Errorf("dialing join service endpoint: %w", err)
+		return nil, nil, fmt.Errorf("dialing join service endpoint: %w", err)
 	}
 	defer conn.Close()
@@ -242,26 +222,19 @@ func (c *JoinClient) join(serviceEndpoint string) error {
 		CertificateRequest: certificateRequest,
 		IsControlPlane:     c.role == role.ControlPlane,
 	}
 
-	ticket, err := protoClient.IssueJoinTicket(ctx, req)
+	ticket, err = protoClient.IssueJoinTicket(ctx, req)
 	if err != nil {
 		c.log.With(slog.String("endpoint", serviceEndpoint), slog.Any("error", err)).Error("Issuing join ticket failed")
-		return fmt.Errorf("issuing join ticket: %w", err)
+		return nil, nil, fmt.Errorf("issuing join ticket: %w", err)
 	}
 
-	return c.startNodeAndJoin(ticket, kubeletKey)
+	return ticket, kubeletKey, err
 }
 
-func (c *JoinClient) startNodeAndJoin(ticket *joinproto.IssueJoinTicketResponse, kubeletKey []byte) (retErr error) {
+func (c *JoinClient) startNodeAndJoin(ticket *joinproto.IssueJoinTicketResponse, kubeletKey []byte, cleaner cleaner) error {
 	ctx, cancel := context.WithTimeout(context.Background(), c.joinTimeout)
 	defer cancel()
 
-	// If an error occurs in this func, the client cannot continue.
-	defer func() {
-		if retErr != nil {
-			retErr = unrecoverableError{retErr}
-		}
-	}()
-
 	clusterID, err := attestation.DeriveClusterID(ticket.MeasurementSecret, ticket.MeasurementSalt)
 	if err != nil {
 		return err
@@ -276,10 +249,11 @@ func (c *JoinClient) startNodeAndJoin(ticket *joinproto.IssueJoinTicketResponse,
 		// There is already a cluster initialization in progress on
 		// this node, so there is no need to also join the cluster,
 		// as the initializing node is automatically part of the cluster.
-		return errors.New("node is already being initialized")
+		c.log.Info("Node is already being initialized. Aborting join process.")
+		return nil
 	}
 
-	c.cleaner.Clean()
+	cleaner.Clean()
 
 	if err := c.updateDiskPassphrase(string(ticket.StateDiskKey)); err != nil {
 		return fmt.Errorf("updating disk passphrase: %w", err)
@@ -313,11 +287,12 @@ func (c *JoinClient) startNodeAndJoin(ticket *joinproto.IssueJoinTicketResponse,
 	// We currently cannot recover from any failure in this function. Joining the k8s cluster
 	// sometimes fails transiently, and we don't want to brick the node because of that.
-	for i := 0; i < 3; i++ {
+	for i := range 3 {
 		err = c.joiner.JoinCluster(ctx, btd, c.role, ticket.KubernetesComponents, c.log)
-		if err != nil {
-			c.log.Error("failed to join k8s cluster", "role", c.role, "attempt", i, "error", err)
+		if err == nil {
+			break
 		}
+		c.log.Error("failed to join k8s cluster", "role", c.role, "attempt", i, "error", err)
 	}
 	if err != nil {
 		return fmt.Errorf("joining Kubernetes cluster: %w", err)
@@ -412,13 +387,6 @@ func (c *JoinClient) timeoutCtx() (context.Context, context.CancelFunc) {
 	return context.WithTimeout(context.Background(), c.timeout)
 }
 
-type unrecoverableError struct{ error }
-
-func isUnrecoverable(err error) bool {
-	_, ok := err.(unrecoverableError)
-	return ok
-}
-
 type grpcDialer interface {
 	Dial(ctx context.Context, target string) (*grpc.ClientConn, error)
 }
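
With Start now blocking and returning an error instead of spawning its own goroutine, the caller owns both the goroutine and the error handling. A rough usage sketch, assuming placeholder arguments (nodeLock, dialer, kubeJoiner, metadataAPI, disk, cleaner) and the handleNonRecoverable helper sketched above — not the bootstrapper's actual wiring:

// Hypothetical caller-side wiring for the reworked JoinClient.
joinClient := joinclient.New(nodeLock, dialer, kubeJoiner, metadataAPI, disk, logger)

go func() {
	// Start blocks until the node has joined, Stop was called, or a fatal
	// error occurred; errors now surface here instead of being swallowed
	// inside the client's own goroutine.
	if err := joinClient.Start(cleaner); err != nil {
		// Non-recoverable: mark the disk for wiping and reboot.
		handleNonRecoverable(err, markDiskForWipe)
	}
}()

// Later, on orderly shutdown:
joinClient.Stop()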