diff --git a/bootstrapper/cmd/bootstrapper/run.go b/bootstrapper/cmd/bootstrapper/run.go index 733444bee..f8a201349 100644 --- a/bootstrapper/cmd/bootstrapper/run.go +++ b/bootstrapper/cmd/bootstrapper/run.go @@ -12,6 +12,7 @@ import ( "log/slog" "net" "os" + "sync" "github.com/edgelesssys/constellation/v2/bootstrapper/internal/clean" "github.com/edgelesssys/constellation/v2/bootstrapper/internal/diskencryption" @@ -67,12 +68,26 @@ func run(issuer atls.Issuer, openDevice vtpm.TPMOpenFunc, fileHandler file.Handl go cleaner.Start() defer cleaner.Done() - joinClient.Start(cleaner) + var wg sync.WaitGroup - if err := initServer.Serve(bindIP, bindPort, cleaner); err != nil { - log.With(slog.Any("error", err)).Error("Failed to serve init server") - os.Exit(1) - } + wg.Add(1) + go func() { + defer wg.Done() + if err := joinClient.Start(cleaner); err != nil { + log.With(slog.Any("error", err)).Error("Failed to join cluster") + os.Exit(1) + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + if err := initServer.Serve(bindIP, bindPort, cleaner); err != nil { + log.With(slog.Any("error", err)).Error("Failed to serve init server") + os.Exit(1) + } + }() + wg.Wait() log.Info("bootstrapper done") } diff --git a/bootstrapper/internal/joinclient/joinclient.go b/bootstrapper/internal/joinclient/joinclient.go index 8f44fa115..a09ac5606 100644 --- a/bootstrapper/internal/joinclient/joinclient.go +++ b/bootstrapper/internal/joinclient/joinclient.go @@ -99,70 +99,70 @@ func New(lock locker, dial grpcDialer, joiner ClusterJoiner, meta MetadataAPI, l // Start starts the client routine. The client will make the needed API calls to join // the cluster with the role it receives from the metadata API. // After receiving the needed information, the node will join the cluster. -// Multiple calls of start on the same client won't start a second routine if there is -// already a routine running. -func (c *JoinClient) Start(cleaner cleaner) { +func (c *JoinClient) Start(cleaner cleaner) error { + // Locked set up section + // We need to make sure this is not executed synchronously with Stop c.mux.Lock() - defer c.mux.Unlock() - - if c.stopC != nil { // daemon already running - return - } - - c.log.Info("Starting") c.stopC = make(chan struct{}, 1) c.stopDone = make(chan struct{}, 1) c.cleaner = cleaner + c.mux.Unlock() + // End of locked set up section + c.log.Info("Starting") ticker := c.clock.NewTicker(c.interval) - go func() { - defer ticker.Stop() - defer func() { c.stopDone <- struct{}{} }() - defer c.log.Info("Client stopped") + defer ticker.Stop() + defer func() { c.stopDone <- struct{}{} }() + defer c.log.Info("Client stopped") - diskUUID, err := c.getDiskUUID() - if err != nil { - c.log.With(slog.Any("error", err)).Error("Failed to get disk UUID") - return + diskUUID, err := c.getDiskUUID() + if err != nil { + c.log.With(slog.Any("error", err)).Error("Failed to get disk UUID") + return err + } + c.diskUUID = diskUUID + + for { + err := c.getNodeMetadata() + if err == nil { + c.log.With(slog.String("role", c.role.String()), slog.String("name", c.nodeName)).Info("Received own instance metadata") + break } - c.diskUUID = diskUUID + c.log.With(slog.Any("error", err)).Error("Failed to retrieve instance metadata") - for { - err := c.getNodeMetadata() - if err == nil { - c.log.With(slog.String("role", c.role.String()), slog.String("name", c.nodeName)).Info("Received own instance metadata") - break - } - c.log.With(slog.Any("error", err)).Error("Failed to retrieve instance metadata") - - c.log.With(slog.Duration("interval", c.interval)).Info("Sleeping") - select { - case <-c.stopC: - return - case <-ticker.C(): - } + c.log.With(slog.Duration("interval", c.interval)).Info("Sleeping") + select { + case <-c.stopC: + return nil + case <-ticker.C(): } + } - for { - err := c.tryJoinWithAvailableServices() - if err == nil { - c.log.Info("Joined successfully. Client is shutting down") - return - } else if isUnrecoverable(err) { - c.log.With(slog.Any("error", err)).Error("Unrecoverable error occurred") - // TODO(burgerdev): this should eventually lead to a full node reset - return - } - c.log.With(slog.Any("error", err)).Warn("Join failed for all available endpoints") + var ticket *joinproto.IssueJoinTicketResponse + var kubeletKey []byte - c.log.With(slog.Duration("interval", c.interval)).Info("Sleeping") - select { - case <-c.stopC: - return - case <-ticker.C(): - } + for { + ticket, kubeletKey, err = c.tryJoinWithAvailableServices() + if err == nil { + c.log.Info("Successfully retrieved join ticket, starting Kubernetes node") + break } - }() + c.log.With(slog.Any("error", err)).Warn("Join failed for all available endpoints") + + c.log.With(slog.Duration("interval", c.interval)).Info("Sleeping") + select { + case <-c.stopC: + return nil + case <-ticker.C(): + } + } + + if err := c.startNodeAndJoin(ticket, kubeletKey); err != nil { + c.log.With(slog.Any("error", err)).Error("Failed to start node and join cluster") // unrecoverable error + return err + } + + return nil } // Stop stops the client and blocks until the client's routine is stopped. @@ -185,7 +185,7 @@ func (c *JoinClient) Stop() { c.log.Info("Stopped") } -func (c *JoinClient) tryJoinWithAvailableServices() error { +func (c *JoinClient) tryJoinWithAvailableServices() (ticket *joinproto.IssueJoinTicketResponse, kubeletKey []byte, err error) { ctx, cancel := c.timeoutCtx() defer cancel() @@ -193,46 +193,46 @@ func (c *JoinClient) tryJoinWithAvailableServices() error { endpoint, _, err := c.metadataAPI.GetLoadBalancerEndpoint(ctx) if err != nil { - return fmt.Errorf("failed to get load balancer endpoint: %w", err) + return nil, nil, fmt.Errorf("failed to get load balancer endpoint: %w", err) } endpoints = append(endpoints, endpoint) ips, err := c.getControlPlaneIPs(ctx) if err != nil { - return fmt.Errorf("failed to get control plane IPs: %w", err) + return nil, nil, fmt.Errorf("failed to get control plane IPs: %w", err) } endpoints = append(endpoints, ips...) if len(endpoints) == 0 { - return errors.New("no control plane IPs found") + return nil, nil, errors.New("no control plane IPs found") } + var joinErrs error for _, endpoint := range endpoints { - err = c.join(net.JoinHostPort(endpoint, strconv.Itoa(constants.JoinServiceNodePort))) + ticket, kubeletKey, err := c.requestJoinTicket(net.JoinHostPort(endpoint, strconv.Itoa(constants.JoinServiceNodePort))) if err == nil { - return nil - } - if isUnrecoverable(err) { - return err + return ticket, kubeletKey, nil } + + joinErrs = errors.Join(joinErrs, err) } - return err + return nil, nil, fmt.Errorf("trying to join on all endpoints %v: %w", endpoints, joinErrs) } -func (c *JoinClient) join(serviceEndpoint string) error { +func (c *JoinClient) requestJoinTicket(serviceEndpoint string) (ticket *joinproto.IssueJoinTicketResponse, kubeletKey []byte, err error) { ctx, cancel := c.timeoutCtx() defer cancel() certificateRequest, kubeletKey, err := certificate.GetKubeletCertificateRequest(c.nodeName, c.validIPs) if err != nil { - return err + return nil, nil, err } conn, err := c.dialer.Dial(ctx, serviceEndpoint) if err != nil { c.log.With(slog.String("endpoint", serviceEndpoint), slog.Any("error", err)).Error("Join service unreachable") - return fmt.Errorf("dialing join service endpoint: %w", err) + return nil, nil, fmt.Errorf("dialing join service endpoint: %w", err) } defer conn.Close() @@ -242,26 +242,19 @@ func (c *JoinClient) join(serviceEndpoint string) error { CertificateRequest: certificateRequest, IsControlPlane: c.role == role.ControlPlane, } - ticket, err := protoClient.IssueJoinTicket(ctx, req) + ticket, err = protoClient.IssueJoinTicket(ctx, req) if err != nil { c.log.With(slog.String("endpoint", serviceEndpoint), slog.Any("error", err)).Error("Issuing join ticket failed") - return fmt.Errorf("issuing join ticket: %w", err) + return nil, nil, fmt.Errorf("issuing join ticket: %w", err) } - return c.startNodeAndJoin(ticket, kubeletKey) + return ticket, kubeletKey, err } -func (c *JoinClient) startNodeAndJoin(ticket *joinproto.IssueJoinTicketResponse, kubeletKey []byte) (retErr error) { +func (c *JoinClient) startNodeAndJoin(ticket *joinproto.IssueJoinTicketResponse, kubeletKey []byte) error { ctx, cancel := context.WithTimeout(context.Background(), c.joinTimeout) defer cancel() - // If an error occurs in this func, the client cannot continue. - defer func() { - if retErr != nil { - retErr = unrecoverableError{retErr} - } - }() - clusterID, err := attestation.DeriveClusterID(ticket.MeasurementSecret, ticket.MeasurementSalt) if err != nil { return err @@ -412,13 +405,6 @@ func (c *JoinClient) timeoutCtx() (context.Context, context.CancelFunc) { return context.WithTimeout(context.Background(), c.timeout) } -type unrecoverableError struct{ error } - -func isUnrecoverable(err error) bool { - _, ok := err.(unrecoverableError) - return ok -} - type grpcDialer interface { Dial(ctx context.Context, target string) (*grpc.ClientConn, error) } diff --git a/bootstrapper/internal/joinclient/joinclient_test.go b/bootstrapper/internal/joinclient/joinclient_test.go index d22ed4fb9..c9f4e048a 100644 --- a/bootstrapper/internal/joinclient/joinclient_test.go +++ b/bootstrapper/internal/joinclient/joinclient_test.go @@ -248,7 +248,7 @@ func TestClient(t *testing.T) { go joinServer.Serve(listener) defer joinServer.GracefulStop() - client.Start(stubCleaner{}) + go client.Start(stubCleaner{}) for _, a := range tc.apiAnswers { switch a := a.(type) { @@ -281,78 +281,6 @@ func TestClient(t *testing.T) { } } -func TestClientConcurrentStartStop(t *testing.T) { - netDialer := testdialer.NewBufconnDialer() - dialer := dialer.New(nil, nil, netDialer) - client := &JoinClient{ - nodeLock: newFakeLock(), - timeout: 30 * time.Second, - interval: 30 * time.Second, - dialer: dialer, - disk: &stubDisk{}, - joiner: &stubClusterJoiner{}, - fileHandler: file.NewHandler(afero.NewMemMapFs()), - metadataAPI: &stubRepeaterMetadataAPI{}, - clock: testclock.NewFakeClock(time.Now()), - log: logger.NewTest(t), - } - - wg := sync.WaitGroup{} - - start := func() { - defer wg.Done() - client.Start(stubCleaner{}) - } - - stop := func() { - defer wg.Done() - client.Stop() - } - - wg.Add(10) - go stop() - go start() - go start() - go stop() - go stop() - go start() - go start() - go stop() - go stop() - go start() - wg.Wait() - - client.Stop() -} - -func TestIsUnrecoverable(t *testing.T) { - assert := assert.New(t) - - some := errors.New("failed") - unrec := unrecoverableError{some} - assert.True(isUnrecoverable(unrec)) - assert.False(isUnrecoverable(some)) -} - -type stubRepeaterMetadataAPI struct { - selfInstance metadata.InstanceMetadata - selfErr error - listInstances []metadata.InstanceMetadata - listErr error -} - -func (s *stubRepeaterMetadataAPI) Self(_ context.Context) (metadata.InstanceMetadata, error) { - return s.selfInstance, s.selfErr -} - -func (s *stubRepeaterMetadataAPI) List(_ context.Context) ([]metadata.InstanceMetadata, error) { - return s.listInstances, s.listErr -} - -func (s *stubRepeaterMetadataAPI) GetLoadBalancerEndpoint(_ context.Context) (string, string, error) { - return "", "", nil -} - type stubMetadataAPI struct { selfAnswerC chan selfAnswer listAnswerC chan listAnswer