constellation/disk-mapper/internal/rejoinclient/rejoinclient.go

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/

/*
Package rejoinclient handles the automatic rejoining of a restarting node.
It does so by continuously sending rejoin requests to the JoinService of available control-plane endpoints.
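A typical caller wires the client up roughly as follows (a sketch only; dialer,
nodeInfo, metadataClient, logger, and diskUUID are placeholders for values the
caller supplies):

	client := rejoinclient.New(dialer, nodeInfo, metadataClient, logger)
	diskKey, measurementSecret := client.Start(ctx, diskUUID)
	// diskKey is used to decrypt the state disk; both return values are nil
	// if ctx was canceled before a rejoin ticket could be retrieved.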
*/
package rejoinclient

import (
"context"
"errors"
"fmt"
"log/slog"
"net"
"strconv"
"time"
"github.com/edgelesssys/constellation/v2/internal/cloud/metadata"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/role"
"github.com/edgelesssys/constellation/v2/joinservice/joinproto"
"google.golang.org/grpc"
"k8s.io/utils/clock"
)

const (
interval = 30 * time.Second
timeout = 30 * time.Second
)

// RejoinClient is a client for requesting the information a restarting
// worker or control-plane node needs to rejoin its cluster.
type RejoinClient struct {
diskUUID string
nodeInfo metadata.InstanceMetadata
timeout time.Duration
interval time.Duration
clock clock.WithTicker
dialer grpcDialer
metadataAPI metadataAPI
log *slog.Logger
}

// New returns a new RejoinClient.
func New(dial grpcDialer, nodeInfo metadata.InstanceMetadata,
meta metadataAPI, log *slog.Logger,
) *RejoinClient {
return &RejoinClient{
nodeInfo: nodeInfo,
timeout: timeout,
interval: interval,
clock: clock.RealClock{},
dialer: dial,
metadataAPI: meta,
log: log,
}
}

// Start starts the rejoin client.
// The client continuously requests the available control-plane endpoints
// from the metadata API and sends rejoin requests to them.
// The function returns once a rejoin request succeeds; if ctx is canceled
// first, it returns nil values.
func (c *RejoinClient) Start(ctx context.Context, diskUUID string) (diskKey, measurementSecret []byte) {
c.log.Info("Starting RejoinClient")
c.diskUUID = diskUUID
ticker := c.clock.NewTicker(c.interval)
defer ticker.Stop()
defer c.log.Info("RejoinClient stopped")
for {
endpoints, err := c.getJoinEndpoints()
if err != nil {
c.log.With(slog.Any("error", err)).Error("Failed to get control-plane endpoints")
} else {
c.log.With(slog.Any("endpoints", endpoints)).Info("Received list with JoinService endpoints")
diskKey, measurementSecret, err = c.tryRejoinWithAvailableServices(ctx, endpoints)
if err == nil {
c.log.Info("Successfully retrieved rejoin ticket")
return diskKey, measurementSecret
}
}
select {
case <-ctx.Done():
return nil, nil
case <-ticker.C():
}
}
}

// tryRejoinWithAvailableServices tries sending rejoin requests to the available endpoints.
func (c *RejoinClient) tryRejoinWithAvailableServices(ctx context.Context, endpoints []string) (diskKey, measurementSecret []byte, err error) {
for _, endpoint := range endpoints {
c.log.With(slog.String("endpoint", endpoint)).Info("Requesting rejoin ticket")
rejoinTicket, err := c.requestRejoinTicket(endpoint)
if err == nil {
return rejoinTicket.StateDiskKey, rejoinTicket.MeasurementSecret, nil
}
c.log.With(slog.Any("error", err), slog.String("endpoint", endpoint)).Warn("Failed to rejoin on endpoint")
// stop requesting additional endpoints if the context is done
select {
case <-ctx.Done():
return nil, nil, ctx.Err()
default:
}
}
c.log.Error("Failed to rejoin on all endpoints")
return nil, nil, errors.New("failed to rejoin on all endpoints")
}

// requestRejoinTicket requests a rejoin ticket from the endpoint.
func (c *RejoinClient) requestRejoinTicket(endpoint string) (*joinproto.IssueRejoinTicketResponse, error) {
ctx, cancel := c.timeoutCtx()
defer cancel()
conn, err := c.dialer.Dial(endpoint)
if err != nil {
return nil, err
}
defer conn.Close()
return joinproto.NewAPIClient(conn).IssueRejoinTicket(ctx, &joinproto.IssueRejoinTicketRequest{DiskUuid: c.diskUUID})
}

// getJoinEndpoints requests the available control-plane endpoints from the metadata API.
// The list is filtered to remove *this* node if it is a restarting control-plane node.
// Furthermore, the load balancer's endpoint is added.
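// For illustration (example addresses, with <port> standing for
// constants.JoinServiceNodePort), a worker node with a load balancer at
// 203.0.113.10 and control-plane nodes at 10.118.0.11 and 10.118.0.12
// receives:
//
//	["203.0.113.10:<port>", "10.118.0.11:<port>", "10.118.0.12:<port>"]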
func (c *RejoinClient) getJoinEndpoints() ([]string, error) {
ctx, cancel := c.timeoutCtx()
defer cancel()
joinEndpoints := []string{}
lbEndpoint, _, err := c.metadataAPI.GetLoadBalancerEndpoint(ctx)
if err != nil {
return nil, fmt.Errorf("retrieving load balancer endpoint from cloud provider: %w", err)
}
joinEndpoints = append(joinEndpoints, net.JoinHostPort(lbEndpoint, strconv.Itoa(constants.JoinServiceNodePort)))
instances, err := c.metadataAPI.List(ctx)
if err != nil {
return nil, fmt.Errorf("retrieving instances list from cloud provider: %w", err)
}
for _, instance := range instances {
if instance.Role == role.ControlPlane {
if instance.VPCIP != "" {
joinEndpoints = append(joinEndpoints, net.JoinHostPort(instance.VPCIP, strconv.Itoa(constants.JoinServiceNodePort)))
}
}
}
if c.nodeInfo.Role == role.ControlPlane {
return removeSelfFromEndpoints(c.nodeInfo.VPCIP, joinEndpoints), nil
}
return joinEndpoints, nil
}

// removeSelfFromEndpoints removes *this* node from the list of endpoints.
// Entries that cannot be parsed as host:port are dropped as well.
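// For illustration (example addresses and port only): with self set to
// "10.118.0.11", the input
//
//	["203.0.113.10:30000", "10.118.0.11:30000", "10.118.0.12:30000"]
//
// is reduced to
//
//	["203.0.113.10:30000", "10.118.0.12:30000"]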
func removeSelfFromEndpoints(self string, endpoints []string) []string {
var result []string
for _, endpoint := range endpoints {
host, _, err := net.SplitHostPort(endpoint)
if err == nil && host != self {
result = append(result, endpoint)
}
}
return result
}

func (c *RejoinClient) timeoutCtx() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), c.timeout)
}
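// grpcDialer is satisfied by any type that can open a gRPC client connection
// to a given target. The sketch below is illustrative only: it uses plain
// insecure transport credentials (google.golang.org/grpc/credentials/insecure),
// whereas the real caller is expected to supply a dialer with the appropriate,
// attestation-aware transport security.
//
//	type plainDialer struct{}
//
//	func (plainDialer) Dial(target string) (*grpc.ClientConn, error) {
//		return grpc.NewClient(target, grpc.WithTransportCredentials(insecure.NewCredentials()))
//	}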
type grpcDialer interface {
Dial(target string) (*grpc.ClientConn, error)
}

type metadataAPI interface {
// List retrieves all instances belonging to the current constellation.
List(ctx context.Context) ([]metadata.InstanceMetadata, error)
// GetLoadBalancerEndpoint retrieves the load balancer endpoint.
GetLoadBalancerEndpoint(ctx context.Context) (host, port string, err error)
}