Let JoinClient return fatal errors

Signed-off-by: Daniel Weiße <dw@edgeless.systems>

commit 1e44c20561 (parent 108784c580)
3 changed files with 89 additions and 160 deletions
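In short: JoinClient.Start no longer runs as a fire-and-forget goroutine that merely logs fatal errors. It is now a blocking call that returns an error, the bootstrapper's run function supervises it alongside the init server and exits on the first fatal failure, and the unrecoverableError marker type, its isUnrecoverable helper, and the now-obsolete concurrency tests are removed.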
@@ -12,6 +12,7 @@ import (
 	"log/slog"
 	"net"
 	"os"
+	"sync"

 	"github.com/edgelesssys/constellation/v2/bootstrapper/internal/clean"
 	"github.com/edgelesssys/constellation/v2/bootstrapper/internal/diskencryption"
@@ -67,12 +68,26 @@ func run(issuer atls.Issuer, openDevice vtpm.TPMOpenFunc, fileHandler file.Handl
 	go cleaner.Start()
 	defer cleaner.Done()

-	joinClient.Start(cleaner)
+	var wg sync.WaitGroup

-	if err := initServer.Serve(bindIP, bindPort, cleaner); err != nil {
-		log.With(slog.Any("error", err)).Error("Failed to serve init server")
-		os.Exit(1)
-	}
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		if err := joinClient.Start(cleaner); err != nil {
+			log.With(slog.Any("error", err)).Error("Failed to join cluster")
+			os.Exit(1)
+		}
+	}()
+
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		if err := initServer.Serve(bindIP, bindPort, cleaner); err != nil {
+			log.With(slog.Any("error", err)).Error("Failed to serve init server")
+			os.Exit(1)
+		}
+	}()
+	wg.Wait()

 	log.Info("bootstrapper done")
 }
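The run function now supervises both routines with a sync.WaitGroup and treats any returned error as fatal. A minimal, self-contained sketch of that supervision pattern (task names and messages are illustrative, not from the commit):

package main

import (
	"errors"
	"log/slog"
	"os"
	"sync"
)

// Illustrative stand-ins for joinClient.Start and initServer.Serve.
func task1() error { return nil }
func task2() error { return errors.New("listen failed") }

func main() {
	var wg sync.WaitGroup
	for _, task := range []func() error{task1, task2} {
		wg.Add(1)
		go func(run func() error) {
			defer wg.Done()
			if err := run(); err != nil {
				// Note: os.Exit terminates the process immediately;
				// deferred functions in other goroutines do not run.
				slog.Error("fatal task error", "error", err)
				os.Exit(1)
			}
		}(task)
	}
	wg.Wait()
	slog.Info("bootstrapper done")
}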
@@ -99,70 +99,70 @@ func New(lock locker, dial grpcDialer, joiner ClusterJoiner, meta MetadataAPI, l
 // Start starts the client routine. The client will make the needed API calls to join
 // the cluster with the role it receives from the metadata API.
 // After receiving the needed information, the node will join the cluster.
 // Multiple calls of start on the same client won't start a second routine if there is
 // already a routine running.
-func (c *JoinClient) Start(cleaner cleaner) {
+func (c *JoinClient) Start(cleaner cleaner) error {
+	// Locked set up section
+	// We need to make sure this is not executed synchronously with Stop
 	c.mux.Lock()
-	defer c.mux.Unlock()
-
-	if c.stopC != nil { // daemon already running
-		return
-	}
-
-	c.log.Info("Starting")
 	c.stopC = make(chan struct{}, 1)
 	c.stopDone = make(chan struct{}, 1)
 	c.cleaner = cleaner
+	c.mux.Unlock()
+	// End of locked set up section

+	c.log.Info("Starting")
 	ticker := c.clock.NewTicker(c.interval)
-	go func() {
-		defer ticker.Stop()
-		defer func() { c.stopDone <- struct{}{} }()
-		defer c.log.Info("Client stopped")
-
-		diskUUID, err := c.getDiskUUID()
-		if err != nil {
-			c.log.With(slog.Any("error", err)).Error("Failed to get disk UUID")
-			return
-		}
-		c.diskUUID = diskUUID
-
-		for {
-			err := c.getNodeMetadata()
-			if err == nil {
-				c.log.With(slog.String("role", c.role.String()), slog.String("name", c.nodeName)).Info("Received own instance metadata")
-				break
-			}
-			c.log.With(slog.Any("error", err)).Error("Failed to retrieve instance metadata")
-
-			c.log.With(slog.Duration("interval", c.interval)).Info("Sleeping")
-			select {
-			case <-c.stopC:
-				return
-			case <-ticker.C():
-			}
-		}
-
-		for {
-			err := c.tryJoinWithAvailableServices()
-			if err == nil {
-				c.log.Info("Joined successfully. Client is shutting down")
-				return
-			} else if isUnrecoverable(err) {
-				c.log.With(slog.Any("error", err)).Error("Unrecoverable error occurred")
-				// TODO(burgerdev): this should eventually lead to a full node reset
-				return
-			}
-			c.log.With(slog.Any("error", err)).Warn("Join failed for all available endpoints")
-
-			c.log.With(slog.Duration("interval", c.interval)).Info("Sleeping")
-			select {
-			case <-c.stopC:
-				return
-			case <-ticker.C():
-			}
-		}
-	}()
+	defer ticker.Stop()
+	defer func() { c.stopDone <- struct{}{} }()
+	defer c.log.Info("Client stopped")
+
+	diskUUID, err := c.getDiskUUID()
+	if err != nil {
+		c.log.With(slog.Any("error", err)).Error("Failed to get disk UUID")
+		return err
+	}
+	c.diskUUID = diskUUID
+
+	for {
+		err := c.getNodeMetadata()
+		if err == nil {
+			c.log.With(slog.String("role", c.role.String()), slog.String("name", c.nodeName)).Info("Received own instance metadata")
+			break
+		}
+		c.log.With(slog.Any("error", err)).Error("Failed to retrieve instance metadata")
+
+		c.log.With(slog.Duration("interval", c.interval)).Info("Sleeping")
+		select {
+		case <-c.stopC:
+			return nil
+		case <-ticker.C():
+		}
+	}
+
+	var ticket *joinproto.IssueJoinTicketResponse
+	var kubeletKey []byte
+
+	for {
+		ticket, kubeletKey, err = c.tryJoinWithAvailableServices()
+		if err == nil {
+			c.log.Info("Successfully retrieved join ticket, starting Kubernetes node")
+			break
+		}
+		c.log.With(slog.Any("error", err)).Warn("Join failed for all available endpoints")
+
+		c.log.With(slog.Duration("interval", c.interval)).Info("Sleeping")
+		select {
+		case <-c.stopC:
+			return nil
+		case <-ticker.C():
+		}
+	}
+
+	if err := c.startNodeAndJoin(ticket, kubeletKey); err != nil {
+		c.log.With(slog.Any("error", err)).Error("Failed to start node and join cluster") // unrecoverable error
+		return err
+	}
+
+	return nil
 }

 // Stop stops the client and blocks until the client's routine is stopped.
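Start is now synchronous: it retries until it obtains a join ticket, returns nil when signaled via the stop channel, and surfaces fatal errors to the caller. A condensed sketch of that retry-until-stopped shape (generic names, not the actual JoinClient API):

package main

import (
	"errors"
	"fmt"
	"time"
)

// retryUntilStopped mirrors the loop shape Start now uses: retry op on each
// tick until it succeeds or a stop signal arrives. A stop is not an error.
func retryUntilStopped(op func() error, stopC <-chan struct{}, tick <-chan time.Time) error {
	for {
		if err := op(); err == nil {
			return nil
		}
		select {
		case <-stopC:
			return nil // stopped externally, like Start's `case <-c.stopC`
		case <-tick:
		}
	}
}

func main() {
	attempts := 0
	op := func() error {
		attempts++
		if attempts < 3 {
			return errors.New("join service not reachable yet")
		}
		return nil
	}
	ticker := time.NewTicker(10 * time.Millisecond)
	defer ticker.Stop()
	fmt.Println(retryUntilStopped(op, make(chan struct{}), ticker.C)) // <nil> after 3 attempts
}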
@@ -185,7 +185,7 @@ func (c *JoinClient) Stop() {
 	c.log.Info("Stopped")
 }

-func (c *JoinClient) tryJoinWithAvailableServices() error {
+func (c *JoinClient) tryJoinWithAvailableServices() (ticket *joinproto.IssueJoinTicketResponse, kubeletKey []byte, err error) {
 	ctx, cancel := c.timeoutCtx()
 	defer cancel()

@@ -193,46 +193,46 @@ func (c *JoinClient) tryJoinWithAvailableServices() error {

 	endpoint, _, err := c.metadataAPI.GetLoadBalancerEndpoint(ctx)
 	if err != nil {
-		return fmt.Errorf("failed to get load balancer endpoint: %w", err)
+		return nil, nil, fmt.Errorf("failed to get load balancer endpoint: %w", err)
 	}
 	endpoints = append(endpoints, endpoint)

 	ips, err := c.getControlPlaneIPs(ctx)
 	if err != nil {
-		return fmt.Errorf("failed to get control plane IPs: %w", err)
+		return nil, nil, fmt.Errorf("failed to get control plane IPs: %w", err)
 	}
 	endpoints = append(endpoints, ips...)

 	if len(endpoints) == 0 {
-		return errors.New("no control plane IPs found")
+		return nil, nil, errors.New("no control plane IPs found")
 	}

+	var joinErrs error
 	for _, endpoint := range endpoints {
-		err = c.join(net.JoinHostPort(endpoint, strconv.Itoa(constants.JoinServiceNodePort)))
+		ticket, kubeletKey, err := c.requestJoinTicket(net.JoinHostPort(endpoint, strconv.Itoa(constants.JoinServiceNodePort)))
 		if err == nil {
-			return nil
-		}
-		if isUnrecoverable(err) {
-			return err
+			return ticket, kubeletKey, nil
 		}
+
+		joinErrs = errors.Join(joinErrs, err)
 	}

-	return err
+	return nil, nil, fmt.Errorf("trying to join on all endpoints %v: %w", endpoints, joinErrs)
 }

-func (c *JoinClient) join(serviceEndpoint string) error {
+func (c *JoinClient) requestJoinTicket(serviceEndpoint string) (ticket *joinproto.IssueJoinTicketResponse, kubeletKey []byte, err error) {
 	ctx, cancel := c.timeoutCtx()
 	defer cancel()

 	certificateRequest, kubeletKey, err := certificate.GetKubeletCertificateRequest(c.nodeName, c.validIPs)
 	if err != nil {
-		return err
+		return nil, nil, err
 	}

 	conn, err := c.dialer.Dial(ctx, serviceEndpoint)
 	if err != nil {
 		c.log.With(slog.String("endpoint", serviceEndpoint), slog.Any("error", err)).Error("Join service unreachable")
-		return fmt.Errorf("dialing join service endpoint: %w", err)
+		return nil, nil, fmt.Errorf("dialing join service endpoint: %w", err)
 	}
 	defer conn.Close()

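tryJoinWithAvailableServices now collects the failure from every endpoint with errors.Join instead of returning only the last one. A small standalone sketch of that aggregation pattern (the endpoints and failure are illustrative):

package main

import (
	"errors"
	"fmt"
)

func main() {
	endpoints := []string{"10.0.0.1:30090", "10.0.0.2:30090"}

	var joinErrs error
	for _, ep := range endpoints {
		// Illustrative failure; the real code issues a join ticket request.
		err := fmt.Errorf("dialing %s: connection refused", ep)
		// errors.Join discards nil values, so it works as an accumulator.
		joinErrs = errors.Join(joinErrs, err)
	}
	// %w keeps the joined errors unwrappable with errors.Is / errors.As.
	fmt.Println(fmt.Errorf("trying to join on all endpoints %v: %w", endpoints, joinErrs))
}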
@@ -242,26 +242,19 @@ func (c *JoinClient) join(serviceEndpoint string) error {
 		CertificateRequest: certificateRequest,
 		IsControlPlane:     c.role == role.ControlPlane,
 	}
-	ticket, err := protoClient.IssueJoinTicket(ctx, req)
+	ticket, err = protoClient.IssueJoinTicket(ctx, req)
 	if err != nil {
 		c.log.With(slog.String("endpoint", serviceEndpoint), slog.Any("error", err)).Error("Issuing join ticket failed")
-		return fmt.Errorf("issuing join ticket: %w", err)
+		return nil, nil, fmt.Errorf("issuing join ticket: %w", err)
 	}

-	return c.startNodeAndJoin(ticket, kubeletKey)
+	return ticket, kubeletKey, err
 }

-func (c *JoinClient) startNodeAndJoin(ticket *joinproto.IssueJoinTicketResponse, kubeletKey []byte) (retErr error) {
+func (c *JoinClient) startNodeAndJoin(ticket *joinproto.IssueJoinTicketResponse, kubeletKey []byte) error {
 	ctx, cancel := context.WithTimeout(context.Background(), c.joinTimeout)
 	defer cancel()

-	// If an error occurs in this func, the client cannot continue.
-	defer func() {
-		if retErr != nil {
-			retErr = unrecoverableError{retErr}
-		}
-	}()
-
 	clusterID, err := attestation.DeriveClusterID(ticket.MeasurementSecret, ticket.MeasurementSalt)
 	if err != nil {
 		return err
@@ -412,13 +405,6 @@ func (c *JoinClient) timeoutCtx() (context.Context, context.CancelFunc) {
 	return context.WithTimeout(context.Background(), c.timeout)
 }

-type unrecoverableError struct{ error }
-
-func isUnrecoverable(err error) bool {
-	_, ok := err.(unrecoverableError)
-	return ok
-}
-
 type grpcDialer interface {
 	Dial(ctx context.Context, target string) (*grpc.ClientConn, error)
 }
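With errors now returned directly to the caller, the unrecoverableError marker type and its type assertion become dead code. For reference, the removed pattern looked roughly like the sketch below; a modern variant would use errors.As so the check survives wrapping (illustrative, not part of the commit):

package main

import (
	"errors"
	"fmt"
)

// unrecoverableError marks errors the client cannot retry (the removed pattern).
type unrecoverableError struct{ error }

func isUnrecoverable(err error) bool {
	// errors.As also matches wrapped errors, unlike the removed
	// direct type assertion err.(unrecoverableError).
	var u unrecoverableError
	return errors.As(err, &u)
}

func main() {
	err := fmt.Errorf("joining: %w", unrecoverableError{errors.New("bad ticket")})
	fmt.Println(isUnrecoverable(err)) // true, even though err wraps the marker
}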
@@ -248,7 +248,7 @@ func TestClient(t *testing.T) {
 	go joinServer.Serve(listener)
 	defer joinServer.GracefulStop()

-	client.Start(stubCleaner{})
+	go client.Start(stubCleaner{})

 	for _, a := range tc.apiAnswers {
 		switch a := a.(type) {
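Because Start now blocks, the test has to run it in its own goroutine before feeding stubbed API answers. A reduced sketch of that shape, assuming a client and stubCleaner fixture like TestClient's (names and setup are not from the commit):

func TestStartReturnsAfterStop(t *testing.T) {
	// Assumes `client` is constructed with stubs as in TestClient above.
	done := make(chan error, 1)
	go func() { done <- client.Start(stubCleaner{}) }()

	// ... feed stubbed metadata and join-service answers here ...

	client.Stop() // signals the stop channel; Start returns nil
	if err := <-done; err != nil {
		t.Fatal(err)
	}
}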
@@ -281,78 +281,6 @@ func TestClient(t *testing.T) {
 	}
 }

-func TestClientConcurrentStartStop(t *testing.T) {
-	netDialer := testdialer.NewBufconnDialer()
-	dialer := dialer.New(nil, nil, netDialer)
-	client := &JoinClient{
-		nodeLock:    newFakeLock(),
-		timeout:     30 * time.Second,
-		interval:    30 * time.Second,
-		dialer:      dialer,
-		disk:        &stubDisk{},
-		joiner:      &stubClusterJoiner{},
-		fileHandler: file.NewHandler(afero.NewMemMapFs()),
-		metadataAPI: &stubRepeaterMetadataAPI{},
-		clock:       testclock.NewFakeClock(time.Now()),
-		log:         logger.NewTest(t),
-	}
-
-	wg := sync.WaitGroup{}
-
-	start := func() {
-		defer wg.Done()
-		client.Start(stubCleaner{})
-	}
-
-	stop := func() {
-		defer wg.Done()
-		client.Stop()
-	}
-
-	wg.Add(10)
-	go stop()
-	go start()
-	go start()
-	go stop()
-	go stop()
-	go start()
-	go start()
-	go stop()
-	go stop()
-	go start()
-	wg.Wait()
-
-	client.Stop()
-}
-
-func TestIsUnrecoverable(t *testing.T) {
-	assert := assert.New(t)
-
-	some := errors.New("failed")
-	unrec := unrecoverableError{some}
-	assert.True(isUnrecoverable(unrec))
-	assert.False(isUnrecoverable(some))
-}
-
-type stubRepeaterMetadataAPI struct {
-	selfInstance  metadata.InstanceMetadata
-	selfErr       error
-	listInstances []metadata.InstanceMetadata
-	listErr       error
-}
-
-func (s *stubRepeaterMetadataAPI) Self(_ context.Context) (metadata.InstanceMetadata, error) {
-	return s.selfInstance, s.selfErr
-}
-
-func (s *stubRepeaterMetadataAPI) List(_ context.Context) ([]metadata.InstanceMetadata, error) {
-	return s.listInstances, s.listErr
-}
-
-func (s *stubRepeaterMetadataAPI) GetLoadBalancerEndpoint(_ context.Context) (string, string, error) {
-	return "", "", nil
-}
-
 type stubMetadataAPI struct {
 	selfAnswerC chan selfAnswer
 	listAnswerC chan listAnswer