constellation/internal/constellation/applyinit.go
2024-05-15 16:17:12 +02:00

307 lines
9.5 KiB
Go

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package constellation
import (
"context"
"encoding/hex"
"errors"
"fmt"
"io"
"net"
"net/url"
"strconv"
"sync"
"time"
"github.com/edgelesssys/constellation/v2/bootstrapper/initproto"
"github.com/edgelesssys/constellation/v2/internal/atls"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/constellation/state"
"github.com/edgelesssys/constellation/v2/internal/grpc/grpclog"
grpcRetry "github.com/edgelesssys/constellation/v2/internal/grpc/retry"
"github.com/edgelesssys/constellation/v2/internal/kms/uri"
"github.com/edgelesssys/constellation/v2/internal/retry"
"github.com/edgelesssys/constellation/v2/internal/versions"
"google.golang.org/grpc"
"k8s.io/client-go/tools/clientcmd"
)
// InitPayload contains the configurable data for the init RPC.
type InitPayload struct {
MasterSecret uri.MasterSecret
MeasurementSalt []byte
K8sVersion versions.ValidK8sVersion
ConformanceMode bool
ServiceCIDR string
}
// GrpcDialer dials a gRPC server.
type GrpcDialer interface {
Dial(ctx context.Context, target string) (*grpc.ClientConn, error)
}
// Init performs the init RPC.
func (a *Applier) Init(
ctx context.Context,
validator atls.Validator,
state *state.State,
clusterLogWriter io.Writer,
payload InitPayload,
) (
InitOutput,
error,
) {
// Prepare the Request
req := &initproto.InitRequest{
KmsUri: payload.MasterSecret.EncodeToURI(),
StorageUri: uri.NoStoreURI,
MeasurementSalt: payload.MeasurementSalt,
KubernetesVersion: versions.VersionConfigs[payload.K8sVersion].ClusterVersion,
KubernetesComponents: versions.VersionConfigs[payload.K8sVersion].KubernetesComponents,
ConformanceMode: payload.ConformanceMode,
InitSecret: state.Infrastructure.InitSecret,
ClusterName: state.Infrastructure.Name,
ApiserverCertSans: state.Infrastructure.APIServerCertSANs,
ServiceCidr: payload.ServiceCIDR,
}
doer := &initDoer{
dialer: a.newDialer(validator),
endpoint: net.JoinHostPort(
state.Infrastructure.ClusterEndpoint,
strconv.Itoa(constants.BootstrapperPort),
),
req: req,
log: a.log,
clusterLogWriter: clusterLogWriter,
spinner: a.spinner,
}
// Create a wrapper function that allows logging any returned error from the retrier before checking if it's the expected retriable one.
serviceIsUnavailable := func(err error) bool {
isServiceUnavailable := grpcRetry.ServiceIsUnavailable(err)
a.log.Debug(fmt.Sprintf("Encountered error (retriable: %t): %q", isServiceUnavailable, err))
return isServiceUnavailable
}
// Perform the RPC
a.log.Debug("Initialization call", "endpoint", doer.endpoint)
a.spinner.Start("Connecting ", false)
retrier := retry.NewIntervalRetrier(doer, 30*time.Second, serviceIsUnavailable)
if err := retrier.Do(ctx); err != nil {
return InitOutput{}, fmt.Errorf("doing init call: %w", err)
}
a.spinner.Stop()
a.log.Debug("Initialization request finished")
a.log.Debug(fmt.Sprintf("Rewriting cluster server address in kubeconfig to %q", state.Infrastructure.ClusterEndpoint))
kubeconfig, err := clientcmd.Load(doer.resp.Kubeconfig)
if err != nil {
return InitOutput{}, fmt.Errorf("loading kubeconfig: %w", err)
}
if len(kubeconfig.Clusters) != 1 {
return InitOutput{}, fmt.Errorf("expected exactly one cluster in kubeconfig, got %d", len(kubeconfig.Clusters))
}
for _, cluster := range kubeconfig.Clusters {
kubeEndpoint, err := url.Parse(cluster.Server)
if err != nil {
return InitOutput{}, fmt.Errorf("parsing kubeconfig server URL: %w", err)
}
kubeEndpoint.Host = net.JoinHostPort(state.Infrastructure.ClusterEndpoint, kubeEndpoint.Port())
cluster.Server = kubeEndpoint.String()
}
kubeconfigBytes, err := clientcmd.Write(*kubeconfig)
if err != nil {
return InitOutput{}, fmt.Errorf("writing kubeconfig: %w", err)
}
return InitOutput{
ClusterID: hex.EncodeToString(doer.resp.ClusterId),
OwnerID: hex.EncodeToString(doer.resp.OwnerId),
Kubeconfig: kubeconfigBytes,
}, nil
}
// InitOutput contains the output of the init RPC.
type InitOutput struct {
// ClusterID is the hex encoded ID of the cluster.
ClusterID string
// OwnerID is the hex encoded ID of the owner of the cluster.
OwnerID string
// Kubeconfig is the kubeconfig for the cluster.
Kubeconfig []byte
}
// the initDoer performs the actual init RPC with retry logic.
type initDoer struct {
dialer GrpcDialer
endpoint string
req *initproto.InitRequest
log debugLog
connectedOnce bool
spinner spinnerInterf
// clusterLogWriter is the writer to which the cluster logs are written.
clusterLogWriter io.Writer
// Read-Only-fields:
// resp is the response returned upon successful initialization.
resp *initproto.InitSuccessResponse
}
type spinnerInterf interface {
Start(text string, showDots bool)
Stop()
io.Writer
}
// Do performs the init gRPC call.
func (d *initDoer) Do(ctx context.Context) error {
// connectedOnce is set in handleGRPCStateChanges when a connection was established in one retry attempt.
// This should cancel any other retry attempts when the connection is lost since the bootstrapper likely won't accept any new attempts anymore.
if d.connectedOnce {
return &NonRetriableInitError{
LogCollectionErr: errors.New("init already connected to the remote server in a previous attempt - resumption is not supported"),
Err: errors.New("init already connected to the remote server in a previous attempt - resumption is not supported"),
}
}
conn, err := d.dialer.Dial(ctx, d.endpoint)
if err != nil {
d.log.Debug(fmt.Sprintf("Dialing init server failed: %q. Retrying...", err))
return fmt.Errorf("dialing init server: %w", err)
}
defer conn.Close()
var wg sync.WaitGroup
defer wg.Wait()
grpcStateLogCtx, grpcStateLogCancel := context.WithCancel(ctx)
defer grpcStateLogCancel()
d.handleGRPCStateChanges(grpcStateLogCtx, &wg, conn)
protoClient := initproto.NewAPIClient(conn)
d.log.Debug("Created protoClient")
resp, err := protoClient.Init(ctx, d.req)
if err != nil {
return &NonRetriableInitError{
LogCollectionErr: errors.New("rpc failed before first response was received - no logs available"),
Err: fmt.Errorf("init call: %w", err),
}
}
res, err := resp.Recv() // get first response, either success or failure
if err != nil {
if e := d.getLogs(resp); e != nil {
d.log.Debug(fmt.Sprintf("Failed to collect logs: %q", e))
return &NonRetriableInitError{
LogCollectionErr: e,
Err: err,
}
}
return &NonRetriableInitError{Err: err}
}
switch res.Kind.(type) {
case *initproto.InitResponse_InitSuccess:
d.resp = res.GetInitSuccess()
case *initproto.InitResponse_InitFailure:
if e := d.getLogs(resp); e != nil {
d.log.Debug(fmt.Sprintf("Failed to get logs from cluster: %q", e))
return &NonRetriableInitError{
LogCollectionErr: e,
Err: errors.New(res.GetInitFailure().GetError()),
}
}
return &NonRetriableInitError{Err: errors.New(res.GetInitFailure().GetError())}
case nil:
d.log.Debug("Cluster returned nil response type")
err = errors.New("empty response from cluster")
if e := d.getLogs(resp); e != nil {
d.log.Debug(fmt.Sprintf("Failed to collect logs: %q", e))
return &NonRetriableInitError{
LogCollectionErr: e,
Err: err,
}
}
return &NonRetriableInitError{Err: err}
default:
d.log.Debug("Cluster returned unknown response type")
err = errors.New("unknown response from cluster")
if e := d.getLogs(resp); e != nil {
d.log.Debug(fmt.Sprintf("Failed to collect logs: %q", e))
return &NonRetriableInitError{
LogCollectionErr: e,
Err: err,
}
}
return &NonRetriableInitError{Err: err}
}
return nil
}
// getLogs retrieves the cluster logs from the bootstrapper and saves them in the initDoer.
func (d *initDoer) getLogs(resp initproto.API_InitClient) error {
d.log.Debug("Attempting to collect cluster logs")
for {
res, err := resp.Recv()
if err == io.EOF {
break
}
if err != nil {
return fmt.Errorf("receiving logs: %w", err)
}
switch res.Kind.(type) {
case *initproto.InitResponse_InitFailure:
return errors.New("trying to collect logs: received init failure response, expected log response")
case *initproto.InitResponse_InitSuccess:
return errors.New("trying to collect logs: received init success response, expected log response")
case nil:
return errors.New("trying to collect logs: received nil response, expected log response")
}
log := res.GetLog().GetLog()
if log == nil {
return errors.New("received empty logs")
}
if _, err := d.clusterLogWriter.Write(log); err != nil {
return fmt.Errorf("writing logs: %w", err)
}
}
d.log.Debug("Received cluster logs")
return nil
}
func (d *initDoer) handleGRPCStateChanges(ctx context.Context, wg *sync.WaitGroup, conn *grpc.ClientConn) {
grpclog.LogStateChangesUntilReady(ctx, conn, d.log, wg, func() {
d.connectedOnce = true
d.spinner.Stop()
d.spinner.Start("Initializing cluster ", false)
})
}
// NonRetriableInitError is returned when the init RPC fails and the error is not retriable.
type NonRetriableInitError struct {
LogCollectionErr error
Err error
}
// Error returns the error message.
func (e *NonRetriableInitError) Error() string {
return e.Err.Error()
}
// Unwrap returns the wrapped error.
func (e *NonRetriableInitError) Unwrap() error {
return e.Err
}