bootstrapper: wipe disk and reboot on non-recoverable error (#2971)

* Let JoinClient return fatal errors
* Mark disk for wiping if JoinClient or InitServer return errors
* Reboot system if bootstrapper detects an error
* Refactor joinClient start/stop implementation
* Fix joining nodes retrying kubeadm 3 times in all cases
* Write non-recoverable failures to syslog before rebooting

---------

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
Daniel Weiße, 2024-03-12 11:43:38 +01:00, committed by GitHub
parent 1b973bf23f
commit 1077b7a48e
GPG Key ID: B5690EEEBB952194
10 changed files with 199 additions and 220 deletions
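
For orientation: the core of this change is that fatal bootstrapper errors no longer call os.Exit(1); they are written to syslog and the machine is rebooted, and (once a join ticket was obtained) the state disk is marked for wiping. The following is a minimal, Linux-only sketch of that reboot-on-fatal-error pattern, condensed from the diff below. It is illustrative only, not part of the diff, and needs root privileges to actually reboot.

    // Sketch: log a non-recoverable bootstrapper error to syslog, then restart the machine.
    package main

    import (
        "errors"
        "log/syslog"
        "syscall"
        "time"
    )

    // rebootOnFatal mirrors the reboot() helper added in this commit: write the error
    // to the kernel log facility, give the message time to land, then restart.
    func rebootOnFatal(e error) {
        syslogWriter, err := syslog.New(syslog.LOG_EMERG|syslog.LOG_KERN, "bootstrapper")
        if err != nil {
            // No way to report the error; restart immediately.
            _ = syscall.Reboot(syscall.LINUX_REBOOT_CMD_RESTART)
            return
        }
        _ = syslogWriter.Err(e.Error())
        _ = syslogWriter.Emerg("bootstrapper has encountered a non recoverable error. Rebooting...")
        time.Sleep(time.Minute) // allow the message to reach syslog before the node goes down
        _ = syscall.Reboot(syscall.LINUX_REBOOT_CMD_RESTART)
    }

    func main() {
        rebootOnFatal(errors.New("joining cluster: example failure"))
    }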

View File

@@ -10,8 +10,11 @@ import (
     "context"
     "fmt"
     "log/slog"
+    "log/syslog"
     "net"
-    "os"
+    "sync"
+    "syscall"
+    "time"

     "github.com/edgelesssys/constellation/v2/bootstrapper/internal/clean"
     "github.com/edgelesssys/constellation/v2/bootstrapper/internal/diskencryption"
@@ -32,7 +35,8 @@ func run(issuer atls.Issuer, openDevice vtpm.TPMOpenFunc, fileHandler file.Handl
 ) {
     log.With(slog.String("version", constants.BinaryVersion().String())).Info("Starting bootstrapper")

-    uuid, err := getDiskUUID()
+    disk := diskencryption.New()
+    uuid, err := getDiskUUID(disk)
     if err != nil {
         log.With(slog.Any("error", err)).Error("Failed to get disk UUID")
     } else {
@@ -42,43 +46,58 @@ func run(issuer atls.Issuer, openDevice vtpm.TPMOpenFunc, fileHandler file.Handl
     nodeBootstrapped, err := initialize.IsNodeBootstrapped(openDevice)
     if err != nil {
         log.With(slog.Any("error", err)).Error("Failed to check if node was previously bootstrapped")
-        os.Exit(1)
+        reboot(fmt.Errorf("checking if node was previously bootstrapped: %w", err))
     }

     if nodeBootstrapped {
         if err := kube.StartKubelet(); err != nil {
             log.With(slog.Any("error", err)).Error("Failed to restart kubelet")
-            os.Exit(1)
+            reboot(fmt.Errorf("restarting kubelet: %w", err))
         }
         return
     }

     nodeLock := nodelock.New(openDevice)
-    initServer, err := initserver.New(context.Background(), nodeLock, kube, issuer, fileHandler, metadata, log)
+    initServer, err := initserver.New(context.Background(), nodeLock, kube, issuer, disk, fileHandler, metadata, log)
     if err != nil {
         log.With(slog.Any("error", err)).Error("Failed to create init server")
-        os.Exit(1)
+        reboot(fmt.Errorf("creating init server: %w", err))
     }

     dialer := dialer.New(issuer, nil, &net.Dialer{})
-    joinClient := joinclient.New(nodeLock, dialer, kube, metadata, log)
+    joinClient := joinclient.New(nodeLock, dialer, kube, metadata, disk, log)

     cleaner := clean.New().With(initServer).With(joinClient)
     go cleaner.Start()
     defer cleaner.Done()

-    joinClient.Start(cleaner)
+    var wg sync.WaitGroup
+    wg.Add(1)
+    go func() {
+        defer wg.Done()
+        if err := joinClient.Start(cleaner); err != nil {
+            log.With(slog.Any("error", err)).Error("Failed to join cluster")
+            markDiskForReset(disk)
+            reboot(fmt.Errorf("joining cluster: %w", err))
+        }
+    }()

-    if err := initServer.Serve(bindIP, bindPort, cleaner); err != nil {
-        log.With(slog.Any("error", err)).Error("Failed to serve init server")
-        os.Exit(1)
-    }
+    wg.Add(1)
+    go func() {
+        defer wg.Done()
+        if err := initServer.Serve(bindIP, bindPort, cleaner); err != nil {
+            log.With(slog.Any("error", err)).Error("Failed to serve init server")
+            markDiskForReset(disk)
+            reboot(fmt.Errorf("serving init server: %w", err))
+        }
+    }()
+
+    wg.Wait()

     log.Info("bootstrapper done")
 }

-func getDiskUUID() (string, error) {
-    disk := diskencryption.New()
+func getDiskUUID(disk *diskencryption.DiskEncryption) (string, error) {
     free, err := disk.Open()
     if err != nil {
         return "", err
@@ -87,6 +106,36 @@ func getDiskUUID() (string, error) {
     return disk.UUID()
 }

+// markDiskForReset sets a token in the cryptsetup header of the disk to indicate the disk should be reset on next boot.
+// This is used to reset all state of a node in case the bootstrapper encountered a non recoverable error
+// after the node successfully retrieved a join ticket from the JoinService.
+// As setting this token is safe as long as we are certain we don't need the data on the disk anymore, we call this
+// unconditionally when either the JoinClient or the InitServer encounter an error.
+// We don't call it before that, as the node may be restarting after a previous, successful bootstrapping,
+// and now encountered a transient error on rejoining the cluster. Wiping the disk now would delete existing data.
+func markDiskForReset(disk *diskencryption.DiskEncryption) {
+    free, err := disk.Open()
+    if err != nil {
+        return
+    }
+    defer free()
+    _ = disk.MarkDiskForReset()
+}
+
+// reboot writes an error message to the system log and reboots the system.
+// We call this instead of os.Exit() since failures in the bootstrapper usually require a node reset.
+func reboot(e error) {
+    syslogWriter, err := syslog.New(syslog.LOG_EMERG|syslog.LOG_KERN, "bootstrapper")
+    if err != nil {
+        _ = syscall.Reboot(syscall.LINUX_REBOOT_CMD_RESTART)
+    }
+    _ = syslogWriter.Err(e.Error())
+    _ = syslogWriter.Emerg("bootstrapper has encountered a non recoverable error. Rebooting...")
+    time.Sleep(time.Minute) // sleep to allow the message to be written to syslog and seen by the user
+    _ = syscall.Reboot(syscall.LINUX_REBOOT_CMD_RESTART)
+}
+
 type clusterInitJoiner interface {
     joinclient.ClusterJoiner
     initserver.ClusterInitializer

View File

@@ -60,6 +60,11 @@ func (c *DiskEncryption) UpdatePassphrase(passphrase string) error {
     return c.device.SetConstellationStateDiskToken(cryptsetup.SetDiskInitialized)
 }

+// MarkDiskForReset marks the state disk as not initialized so it may be wiped (reset) on reboot.
+func (c *DiskEncryption) MarkDiskForReset() error {
+    return c.device.SetConstellationStateDiskToken(cryptsetup.SetDiskNotInitialized)
+}
+
 // getInitialPassphrase retrieves the initial passphrase used on first boot.
 func (c *DiskEncryption) getInitialPassphrase() (string, error) {
     passphrase, err := afero.ReadFile(c.fs, initialKeyPath)

View File

@@ -8,7 +8,6 @@ go_library(
     visibility = ["//bootstrapper:__subpackages__"],
     deps = [
         "//bootstrapper/initproto",
-        "//bootstrapper/internal/diskencryption",
         "//bootstrapper/internal/journald",
         "//internal/atls",
         "//internal/attestation",

View File

@@ -30,7 +30,6 @@ import (
     "time"

     "github.com/edgelesssys/constellation/v2/bootstrapper/initproto"
-    "github.com/edgelesssys/constellation/v2/bootstrapper/internal/diskencryption"
     "github.com/edgelesssys/constellation/v2/bootstrapper/internal/journald"
     "github.com/edgelesssys/constellation/v2/internal/atls"
     "github.com/edgelesssys/constellation/v2/internal/attestation"
@@ -65,6 +64,7 @@ type Server struct {
     shutdownLock sync.RWMutex

     initSecretHash []byte
+    initFailure    error

     kmsURI string
@@ -76,7 +76,10 @@ type Server struct {
 }

 // New creates a new initialization server.
-func New(ctx context.Context, lock locker, kube ClusterInitializer, issuer atls.Issuer, fh file.Handler, metadata MetadataAPI, log *slog.Logger) (*Server, error) {
+func New(
+    ctx context.Context, lock locker, kube ClusterInitializer, issuer atls.Issuer,
+    disk encryptedDisk, fh file.Handler, metadata MetadataAPI, log *slog.Logger,
+) (*Server, error) {
     log = log.WithGroup("initServer")

     initSecretHash, err := metadata.InitSecretHash(ctx)
@@ -94,7 +97,7 @@ func New(ctx context.Context, lock locker, kube ClusterInitializer, issuer atls.
     server := &Server{
         nodeLock:    lock,
-        disk:        diskencryption.New(),
+        disk:        disk,
         initializer: kube,
         fileHandler: fh,
         issuer:      issuer,
@@ -123,11 +126,20 @@ func (s *Server) Serve(ip, port string, cleaner cleaner) error {
     }

     s.log.Info("Starting")
-    return s.grpcServer.Serve(lis)
+    err = s.grpcServer.Serve(lis)
+
+    // If Init failed, we mark the disk for reset, so the node can restart the process
+    // In this case we don't care about any potential errors from the grpc server
+    if s.initFailure != nil {
+        s.log.Error("Fatal error during Init request", "error", s.initFailure)
+        return err
+    }
+
+    return err
 }

 // Init initializes the cluster.
-func (s *Server) Init(req *initproto.InitRequest, stream initproto.API_InitServer) (err error) {
+func (s *Server) Init(req *initproto.InitRequest, stream initproto.API_InitServer) (retErr error) {
     // Acquire lock to prevent shutdown while Init is still running
     s.shutdownLock.RLock()
     defer s.shutdownLock.RUnlock()
@@ -188,6 +200,9 @@ func (s *Server) Init(req *initproto.InitRequest, stream initproto.API_InitServe
     // since we are bootstrapping a new one.
     // Any errors following this call will result in a failed node that may not join any cluster.
     s.cleaner.Clean()
+    defer func() {
+        s.initFailure = retErr
+    }()

     if err := s.setupDisk(stream.Context(), cloudKms); err != nil {
         if e := s.sendLogsWithMessage(stream, status.Errorf(codes.Internal, "setting up disk: %s", err)); e != nil {

View File

@@ -67,7 +67,10 @@ func TestNew(t *testing.T) {
         t.Run(name, func(t *testing.T) {
             assert := assert.New(t)

-            server, err := New(context.TODO(), newFakeLock(), &stubClusterInitializer{}, atls.NewFakeIssuer(variant.Dummy{}), fh, &tc.metadata, logger.NewTest(t))
+            server, err := New(
+                context.TODO(), newFakeLock(), &stubClusterInitializer{}, atls.NewFakeIssuer(variant.Dummy{}),
+                &stubDisk{}, fh, &tc.metadata, logger.NewTest(t),
+            )
             if tc.wantErr {
                 assert.Error(err)
                 return
@@ -381,6 +384,10 @@ func (d *fakeDisk) UpdatePassphrase(passphrase string) error {
     return nil
 }

+func (d *fakeDisk) MarkDiskForReset() error {
+    return nil
+}
+
 type stubDisk struct {
     openErr error
     uuid    string
@@ -402,6 +409,10 @@ func (d *stubDisk) UpdatePassphrase(string) error {
     return d.updatePassphraseErr
 }

+func (d *stubDisk) MarkDiskForReset() error {
+    return nil
+}
+
 type stubClusterInitializer struct {
     initClusterKubeconfig []byte
     initClusterErr        error

View File

@@ -8,7 +8,6 @@ go_library(
     visibility = ["//bootstrapper:__subpackages__"],
    deps = [
         "//bootstrapper/internal/certificate",
-        "//bootstrapper/internal/diskencryption",
         "//internal/attestation",
         "//internal/cloud/metadata",
         "//internal/constants",

View File

@@ -25,11 +25,9 @@ import (
     "net"
     "path/filepath"
     "strconv"
-    "sync"
     "time"

     "github.com/edgelesssys/constellation/v2/bootstrapper/internal/certificate"
-    "github.com/edgelesssys/constellation/v2/bootstrapper/internal/diskencryption"
     "github.com/edgelesssys/constellation/v2/internal/attestation"
     "github.com/edgelesssys/constellation/v2/internal/cloud/metadata"
     "github.com/edgelesssys/constellation/v2/internal/constants"
@@ -69,21 +67,19 @@ type JoinClient struct {
     dialer      grpcDialer
     joiner      ClusterJoiner
-    cleaner     cleaner
     metadataAPI MetadataAPI
     log         *slog.Logger

-    mux      sync.Mutex
     stopC    chan struct{}
     stopDone chan struct{}
 }

 // New creates a new JoinClient.
-func New(lock locker, dial grpcDialer, joiner ClusterJoiner, meta MetadataAPI, log *slog.Logger) *JoinClient {
+func New(lock locker, dial grpcDialer, joiner ClusterJoiner, meta MetadataAPI, disk encryptedDisk, log *slog.Logger) *JoinClient {
     return &JoinClient{
         nodeLock:    lock,
-        disk:        diskencryption.New(),
+        disk:        disk,
         fileHandler: file.NewHandler(afero.NewOsFs()),
         timeout:     timeout,
         joinTimeout: joinTimeout,
@@ -93,29 +89,18 @@ func New(lock locker, dial grpcDialer, joiner ClusterJoiner, meta MetadataAPI, l
         joiner:      joiner,
         metadataAPI: meta,
         log:         log.WithGroup("join-client"),
+        stopC:       make(chan struct{}, 1),
+        stopDone:    make(chan struct{}, 1),
     }
 }

 // Start starts the client routine. The client will make the needed API calls to join
 // the cluster with the role it receives from the metadata API.
 // After receiving the needed information, the node will join the cluster.
-// Multiple calls of start on the same client won't start a second routine if there is
-// already a routine running.
-func (c *JoinClient) Start(cleaner cleaner) {
-    c.mux.Lock()
-    defer c.mux.Unlock()
-
-    if c.stopC != nil { // daemon already running
-        return
-    }
-
+func (c *JoinClient) Start(cleaner cleaner) error {
     c.log.Info("Starting")
-    c.stopC = make(chan struct{}, 1)
-    c.stopDone = make(chan struct{}, 1)
-    c.cleaner = cleaner
-
     ticker := c.clock.NewTicker(c.interval)
-    go func() {
     defer ticker.Stop()
     defer func() { c.stopDone <- struct{}{} }()
     defer c.log.Info("Client stopped")
@@ -123,7 +108,7 @@ func (c *JoinClient) Start(cleaner cleaner) {
     diskUUID, err := c.getDiskUUID()
     if err != nil {
         c.log.With(slog.Any("error", err)).Error("Failed to get disk UUID")
-        return
+        return err
     }
     c.diskUUID = diskUUID
@@ -138,54 +123,49 @@ func (c *JoinClient) Start(cleaner cleaner) {
         c.log.With(slog.Duration("interval", c.interval)).Info("Sleeping")
         select {
         case <-c.stopC:
-            return
+            return nil
         case <-ticker.C():
         }
     }

+    var ticket *joinproto.IssueJoinTicketResponse
+    var kubeletKey []byte
     for {
-        err := c.tryJoinWithAvailableServices()
+        ticket, kubeletKey, err = c.tryJoinWithAvailableServices()
         if err == nil {
-            c.log.Info("Joined successfully. Client is shutting down")
-            return
-        } else if isUnrecoverable(err) {
-            c.log.With(slog.Any("error", err)).Error("Unrecoverable error occurred")
-            // TODO(burgerdev): this should eventually lead to a full node reset
-            return
+            c.log.Info("Successfully retrieved join ticket, starting Kubernetes node")
+            break
         }
         c.log.With(slog.Any("error", err)).Warn("Join failed for all available endpoints")

         c.log.With(slog.Duration("interval", c.interval)).Info("Sleeping")
         select {
         case <-c.stopC:
-            return
+            return nil
         case <-ticker.C():
         }
     }
-    }()

+    if err := c.startNodeAndJoin(ticket, kubeletKey, cleaner); err != nil {
+        c.log.With(slog.Any("error", err)).Error("Failed to start node and join cluster")
+        return err
+    }
+
+    return nil
 }

 // Stop stops the client and blocks until the client's routine is stopped.
 func (c *JoinClient) Stop() {
-    c.mux.Lock()
-    defer c.mux.Unlock()
-
-    if c.stopC == nil { // daemon not running
-        return
-    }
-
     c.log.Info("Stopping")
     c.stopC <- struct{}{}
     <-c.stopDone

-    c.stopC = nil
-    c.stopDone = nil
-
     c.log.Info("Stopped")
 }

-func (c *JoinClient) tryJoinWithAvailableServices() error {
+func (c *JoinClient) tryJoinWithAvailableServices() (ticket *joinproto.IssueJoinTicketResponse, kubeletKey []byte, err error) {
     ctx, cancel := c.timeoutCtx()
     defer cancel()
@@ -193,46 +173,46 @@ func (c *JoinClient) tryJoinWithAvailableServices() error {
     endpoint, _, err := c.metadataAPI.GetLoadBalancerEndpoint(ctx)
     if err != nil {
-        return fmt.Errorf("failed to get load balancer endpoint: %w", err)
+        return nil, nil, fmt.Errorf("failed to get load balancer endpoint: %w", err)
     }
     endpoints = append(endpoints, endpoint)

     ips, err := c.getControlPlaneIPs(ctx)
     if err != nil {
-        return fmt.Errorf("failed to get control plane IPs: %w", err)
+        return nil, nil, fmt.Errorf("failed to get control plane IPs: %w", err)
     }
     endpoints = append(endpoints, ips...)

     if len(endpoints) == 0 {
-        return errors.New("no control plane IPs found")
+        return nil, nil, errors.New("no control plane IPs found")
     }

+    var joinErrs error
     for _, endpoint := range endpoints {
-        err = c.join(net.JoinHostPort(endpoint, strconv.Itoa(constants.JoinServiceNodePort)))
+        ticket, kubeletKey, err := c.requestJoinTicket(net.JoinHostPort(endpoint, strconv.Itoa(constants.JoinServiceNodePort)))
         if err == nil {
-            return nil
-        }
-        if isUnrecoverable(err) {
-            return err
+            return ticket, kubeletKey, nil
         }
+
+        joinErrs = errors.Join(joinErrs, err)
     }

-    return err
+    return nil, nil, fmt.Errorf("trying to join on all endpoints %v: %w", endpoints, joinErrs)
 }

-func (c *JoinClient) join(serviceEndpoint string) error {
+func (c *JoinClient) requestJoinTicket(serviceEndpoint string) (ticket *joinproto.IssueJoinTicketResponse, kubeletKey []byte, err error) {
     ctx, cancel := c.timeoutCtx()
     defer cancel()

     certificateRequest, kubeletKey, err := certificate.GetKubeletCertificateRequest(c.nodeName, c.validIPs)
     if err != nil {
-        return err
+        return nil, nil, err
     }

     conn, err := c.dialer.Dial(ctx, serviceEndpoint)
     if err != nil {
         c.log.With(slog.String("endpoint", serviceEndpoint), slog.Any("error", err)).Error("Join service unreachable")
-        return fmt.Errorf("dialing join service endpoint: %w", err)
+        return nil, nil, fmt.Errorf("dialing join service endpoint: %w", err)
     }
     defer conn.Close()
@@ -242,26 +222,19 @@
         CertificateRequest: certificateRequest,
         IsControlPlane:     c.role == role.ControlPlane,
     }

-    ticket, err := protoClient.IssueJoinTicket(ctx, req)
+    ticket, err = protoClient.IssueJoinTicket(ctx, req)
     if err != nil {
         c.log.With(slog.String("endpoint", serviceEndpoint), slog.Any("error", err)).Error("Issuing join ticket failed")
-        return fmt.Errorf("issuing join ticket: %w", err)
+        return nil, nil, fmt.Errorf("issuing join ticket: %w", err)
     }

-    return c.startNodeAndJoin(ticket, kubeletKey)
+    return ticket, kubeletKey, err
 }

-func (c *JoinClient) startNodeAndJoin(ticket *joinproto.IssueJoinTicketResponse, kubeletKey []byte) (retErr error) {
+func (c *JoinClient) startNodeAndJoin(ticket *joinproto.IssueJoinTicketResponse, kubeletKey []byte, cleaner cleaner) error {
     ctx, cancel := context.WithTimeout(context.Background(), c.joinTimeout)
     defer cancel()

-    // If an error occurs in this func, the client cannot continue.
-    defer func() {
-        if retErr != nil {
-            retErr = unrecoverableError{retErr}
-        }
-    }()
-
     clusterID, err := attestation.DeriveClusterID(ticket.MeasurementSecret, ticket.MeasurementSalt)
     if err != nil {
         return err
@@ -276,10 +249,11 @@
         // There is already a cluster initialization in progress on
         // this node, so there is no need to also join the cluster,
         // as the initializing node is automatically part of the cluster.
-        return errors.New("node is already being initialized")
+        c.log.Info("Node is already being initialized. Aborting join process.")
+        return nil
     }
-    c.cleaner.Clean()
+    cleaner.Clean()

     if err := c.updateDiskPassphrase(string(ticket.StateDiskKey)); err != nil {
         return fmt.Errorf("updating disk passphrase: %w", err)
@@ -313,11 +287,12 @@
     // We currently cannot recover from any failure in this function. Joining the k8s cluster
     // sometimes fails transiently, and we don't want to brick the node because of that.
-    for i := 0; i < 3; i++ {
+    for i := range 3 {
         err = c.joiner.JoinCluster(ctx, btd, c.role, ticket.KubernetesComponents, c.log)
-        if err != nil {
-            c.log.Error("failed to join k8s cluster", "role", c.role, "attempt", i, "error", err)
+        if err == nil {
+            break
         }
+        c.log.Error("failed to join k8s cluster", "role", c.role, "attempt", i, "error", err)
     }
     if err != nil {
         return fmt.Errorf("joining Kubernetes cluster: %w", err)
@@ -412,13 +387,6 @@ func (c *JoinClient) timeoutCtx() (context.Context, context.CancelFunc) {
     return context.WithTimeout(context.Background(), c.timeout)
 }

-type unrecoverableError struct{ error }
-
-func isUnrecoverable(err error) bool {
-    _, ok := err.(unrecoverableError)
-    return ok
-}
-
 type grpcDialer interface {
     Dial(ctx context.Context, target string) (*grpc.ClientConn, error)
 }

View File

@@ -8,7 +8,6 @@ package joinclient

 import (
     "context"
-    "errors"
     "log/slog"
     "net"
     "strconv"
@@ -40,7 +39,6 @@ func TestMain(m *testing.M) {
 }

 func TestClient(t *testing.T) {
-    someErr := errors.New("failed")
     lockedLock := newFakeLock()
     aqcuiredLock, lockErr := lockedLock.TryLockOnce(nil)
     require.True(t, aqcuiredLock)
@@ -67,9 +65,9 @@ func TestClient(t *testing.T) {
         "on worker: metadata self: errors occur": {
             role: role.Worker,
             apiAnswers: []any{
-                selfAnswer{err: someErr},
-                selfAnswer{err: someErr},
-                selfAnswer{err: someErr},
+                selfAnswer{err: assert.AnError},
+                selfAnswer{err: assert.AnError},
+                selfAnswer{err: assert.AnError},
                 selfAnswer{instance: workerSelf},
                 listAnswer{instances: peers},
                 issueJoinTicketAnswer{},
@@ -100,9 +98,9 @@ func TestClient(t *testing.T) {
             role: role.Worker,
             apiAnswers: []any{
                 selfAnswer{instance: workerSelf},
-                listAnswer{err: someErr},
-                listAnswer{err: someErr},
-                listAnswer{err: someErr},
+                listAnswer{err: assert.AnError},
+                listAnswer{err: assert.AnError},
+                listAnswer{err: assert.AnError},
                 listAnswer{instances: peers},
                 issueJoinTicketAnswer{},
             },
@@ -133,9 +131,9 @@ func TestClient(t *testing.T) {
             apiAnswers: []any{
                 selfAnswer{instance: workerSelf},
                 listAnswer{instances: peers},
-                issueJoinTicketAnswer{err: someErr},
+                issueJoinTicketAnswer{err: assert.AnError},
                 listAnswer{instances: peers},
-                issueJoinTicketAnswer{err: someErr},
+                issueJoinTicketAnswer{err: assert.AnError},
                 listAnswer{instances: peers},
                 issueJoinTicketAnswer{},
             },
@@ -150,9 +148,9 @@ func TestClient(t *testing.T) {
             apiAnswers: []any{
                 selfAnswer{instance: controlSelf},
                 listAnswer{instances: peers},
-                issueJoinTicketAnswer{err: someErr},
+                issueJoinTicketAnswer{err: assert.AnError},
                 listAnswer{instances: peers},
-                issueJoinTicketAnswer{err: someErr},
+                issueJoinTicketAnswer{err: assert.AnError},
                 listAnswer{instances: peers},
                 issueJoinTicketAnswer{},
             },
@@ -169,7 +167,7 @@ func TestClient(t *testing.T) {
                 listAnswer{instances: peers},
                 issueJoinTicketAnswer{},
             },
-            clusterJoiner: &stubClusterJoiner{numBadCalls: -1, joinClusterErr: someErr},
+            clusterJoiner: &stubClusterJoiner{numBadCalls: -1, joinClusterErr: assert.AnError},
             nodeLock:      newFakeLock(),
             disk:          &stubDisk{},
             wantJoin:      true,
@@ -182,7 +180,7 @@ func TestClient(t *testing.T) {
                 listAnswer{instances: peers},
                 issueJoinTicketAnswer{},
             },
-            clusterJoiner: &stubClusterJoiner{numBadCalls: 1, joinClusterErr: someErr},
+            clusterJoiner: &stubClusterJoiner{numBadCalls: 1, joinClusterErr: assert.AnError},
             nodeLock:      newFakeLock(),
             disk:          &stubDisk{},
             wantJoin:      true,
@@ -205,13 +203,13 @@ func TestClient(t *testing.T) {
             role:          role.ControlPlane,
             clusterJoiner: &stubClusterJoiner{},
             nodeLock:      newFakeLock(),
-            disk:          &stubDisk{openErr: someErr},
+            disk:          &stubDisk{openErr: assert.AnError},
         },
         "on control plane: disk uuid fails": {
             role:          role.ControlPlane,
             clusterJoiner: &stubClusterJoiner{},
             nodeLock:      newFakeLock(),
-            disk:          &stubDisk{uuidErr: someErr},
+            disk:          &stubDisk{uuidErr: assert.AnError},
         },
     }
@@ -237,6 +235,9 @@ func TestClient(t *testing.T) {
                 metadataAPI: metadataAPI,
                 clock:       clock,
                 log:         logger.NewTest(t),
+
+                stopC:    make(chan struct{}, 1),
+                stopDone: make(chan struct{}, 1),
             }

             serverCreds := atlscredentials.New(nil, nil)
@@ -248,7 +249,7 @@ func TestClient(t *testing.T) {
             go joinServer.Serve(listener)
             defer joinServer.GracefulStop()

-            client.Start(stubCleaner{})
+            go func() { _ = client.Start(stubCleaner{}) }()

             for _, a := range tc.apiAnswers {
                 switch a := a.(type) {
@@ -281,78 +282,6 @@ func TestClient(t *testing.T) {
     }
 }

-func TestClientConcurrentStartStop(t *testing.T) {
-    netDialer := testdialer.NewBufconnDialer()
-    dialer := dialer.New(nil, nil, netDialer)
-    client := &JoinClient{
-        nodeLock:    newFakeLock(),
-        timeout:     30 * time.Second,
-        interval:    30 * time.Second,
-        dialer:      dialer,
-        disk:        &stubDisk{},
-        joiner:      &stubClusterJoiner{},
-        fileHandler: file.NewHandler(afero.NewMemMapFs()),
-        metadataAPI: &stubRepeaterMetadataAPI{},
-        clock:       testclock.NewFakeClock(time.Now()),
-        log:         logger.NewTest(t),
-    }
-
-    wg := sync.WaitGroup{}
-
-    start := func() {
-        defer wg.Done()
-        client.Start(stubCleaner{})
-    }
-    stop := func() {
-        defer wg.Done()
-        client.Stop()
-    }
-
-    wg.Add(10)
-    go stop()
-    go start()
-    go start()
-    go stop()
-    go stop()
-    go start()
-    go start()
-    go stop()
-    go stop()
-    go start()
-    wg.Wait()
-
-    client.Stop()
-}
-
-func TestIsUnrecoverable(t *testing.T) {
-    assert := assert.New(t)
-
-    some := errors.New("failed")
-    unrec := unrecoverableError{some}
-    assert.True(isUnrecoverable(unrec))
-    assert.False(isUnrecoverable(some))
-}
-
-type stubRepeaterMetadataAPI struct {
-    selfInstance  metadata.InstanceMetadata
-    selfErr       error
-    listInstances []metadata.InstanceMetadata
-    listErr       error
-}
-
-func (s *stubRepeaterMetadataAPI) Self(_ context.Context) (metadata.InstanceMetadata, error) {
-    return s.selfInstance, s.selfErr
-}
-
-func (s *stubRepeaterMetadataAPI) List(_ context.Context) ([]metadata.InstanceMetadata, error) {
-    return s.listInstances, s.listErr
-}
-
-func (s *stubRepeaterMetadataAPI) GetLoadBalancerEndpoint(_ context.Context) (string, string, error) {
-    return "", "", nil
-}
-
 type stubMetadataAPI struct {
     selfAnswerC chan selfAnswer
     listAnswerC chan listAnswer
@@ -451,6 +380,10 @@ func (d *stubDisk) UpdatePassphrase(string) error {
     return d.updatePassphraseErr
 }

+func (d *stubDisk) MarkDiskForReset() error {
+    return nil
+}
+
 type stubCleaner struct{}

 func (c stubCleaner) Clean() {}

View File

@@ -31,7 +31,7 @@ const (
     deviceName string = "testDeviceName"
 )

-var toolsEnvs []string = []string{"CP", "DD", "RM", "FSCK_EXT4", "MKFS_EXT4", "BLKID", "FSCK", "MOUNT", "UMOUNT"}
+var toolsEnvs = []string{"CP", "DD", "RM", "FSCK_EXT4", "MKFS_EXT4", "BLKID", "FSCK", "MOUNT", "UMOUNT"}

 // addToolsToPATH is used to update the PATH to contain necessary tool binaries for
 // coreutils, util-linux and ext4.

View File

@@ -37,7 +37,7 @@ const (
 var diskPath = flag.String("disk", "", "Path to the disk to use for the benchmark")

-var toolsEnvs []string = []string{"DD", "RM"}
+var toolsEnvs = []string{"DD", "RM"}

 // addToolsToPATH is used to update the PATH to contain necessary tool binaries for
 // coreutils.