From cac3b4700065512a759f531a86b715159bb88f75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Wei=C3=9Fe?= Date: Tue, 12 Mar 2024 11:18:51 +0100 Subject: [PATCH] Move disk wiping into main package MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Daniel Weiße --- bootstrapper/cmd/bootstrapper/run.go | 28 +++++++++++++++---- bootstrapper/internal/initserver/BUILD.bazel | 1 - .../internal/initserver/initserver.go | 22 ++++----------- .../internal/initserver/initserver_test.go | 5 +++- bootstrapper/internal/joinclient/BUILD.bazel | 1 - .../internal/joinclient/joinclient.go | 23 ++++----------- 6 files changed, 38 insertions(+), 42 deletions(-) diff --git a/bootstrapper/cmd/bootstrapper/run.go b/bootstrapper/cmd/bootstrapper/run.go index 0a5a8b273..95bd46b06 100644 --- a/bootstrapper/cmd/bootstrapper/run.go +++ b/bootstrapper/cmd/bootstrapper/run.go @@ -35,7 +35,8 @@ func run(issuer atls.Issuer, openDevice vtpm.TPMOpenFunc, fileHandler file.Handl ) { log.With(slog.String("version", constants.BinaryVersion().String())).Info("Starting bootstrapper") - uuid, err := getDiskUUID() + disk := diskencryption.New() + uuid, err := getDiskUUID(disk) if err != nil { log.With(slog.Any("error", err)).Error("Failed to get disk UUID") } else { @@ -57,14 +58,14 @@ func run(issuer atls.Issuer, openDevice vtpm.TPMOpenFunc, fileHandler file.Handl } nodeLock := nodelock.New(openDevice) - initServer, err := initserver.New(context.Background(), nodeLock, kube, issuer, fileHandler, metadata, log) + initServer, err := initserver.New(context.Background(), nodeLock, kube, issuer, disk, fileHandler, metadata, log) if err != nil { log.With(slog.Any("error", err)).Error("Failed to create init server") reboot(fmt.Errorf("creating init server: %w", err)) } dialer := dialer.New(issuer, nil, &net.Dialer{}) - joinClient := joinclient.New(nodeLock, dialer, kube, metadata, log) + joinClient := joinclient.New(nodeLock, dialer, kube, metadata, disk, 
log) cleaner := clean.New().With(initServer).With(joinClient) go cleaner.Start() @@ -77,6 +78,7 @@ func run(issuer atls.Issuer, openDevice vtpm.TPMOpenFunc, fileHandler file.Handl defer wg.Done() if err := joinClient.Start(cleaner); err != nil { log.With(slog.Any("error", err)).Error("Failed to join cluster") + markDiskForReset(disk) reboot(fmt.Errorf("joining cluster: %w", err)) } }() @@ -86,6 +88,7 @@ func run(issuer atls.Issuer, openDevice vtpm.TPMOpenFunc, fileHandler file.Handl defer wg.Done() if err := initServer.Serve(bindIP, bindPort, cleaner); err != nil { log.With(slog.Any("error", err)).Error("Failed to serve init server") + markDiskForReset(disk) reboot(fmt.Errorf("serving init server: %w", err)) } }() @@ -94,8 +97,7 @@ func run(issuer atls.Issuer, openDevice vtpm.TPMOpenFunc, fileHandler file.Handl log.Info("bootstrapper done") } -func getDiskUUID() (string, error) { - disk := diskencryption.New() +func getDiskUUID(disk *diskencryption.DiskEncryption) (string, error) { free, err := disk.Open() if err != nil { return "", err @@ -104,6 +106,22 @@ func getDiskUUID() (string, error) { return disk.UUID() } +// markDiskForReset sets a token in the cryptsetup header of the disk to indicate the disk should be reset on next boot. +// This is used to reset all state of a node in case the bootstrapper encountered a non-recoverable error +// after the node successfully retrieved a join ticket from the JoinService. +// As setting this token is safe as long as we are certain we don't need the data on the disk anymore, we call this +// unconditionally when either the JoinClient or the InitServer encounters an error. +// We don't call it before that, as the node may be restarting after a previous, successful bootstrapping, +// and has now encountered a transient error on rejoining the cluster. Wiping the disk in that case would delete existing data. 
+func markDiskForReset(disk *diskencryption.DiskEncryption) { + free, err := disk.Open() + if err != nil { + return + } + defer free() + _ = disk.MarkDiskForReset() +} + // reboot writes an error message to the system log and reboots the system. // We call this instead of os.Exit() since failures in the bootstrapper usually require a node reset. func reboot(e error) { diff --git a/bootstrapper/internal/initserver/BUILD.bazel b/bootstrapper/internal/initserver/BUILD.bazel index 009bb0594..b1d5e66ba 100644 --- a/bootstrapper/internal/initserver/BUILD.bazel +++ b/bootstrapper/internal/initserver/BUILD.bazel @@ -8,7 +8,6 @@ go_library( visibility = ["//bootstrapper:__subpackages__"], deps = [ "//bootstrapper/initproto", - "//bootstrapper/internal/diskencryption", "//bootstrapper/internal/journald", "//internal/atls", "//internal/attestation", diff --git a/bootstrapper/internal/initserver/initserver.go b/bootstrapper/internal/initserver/initserver.go index ae85f5823..a38bdbc8d 100644 --- a/bootstrapper/internal/initserver/initserver.go +++ b/bootstrapper/internal/initserver/initserver.go @@ -30,7 +30,6 @@ import ( "time" "github.com/edgelesssys/constellation/v2/bootstrapper/initproto" - "github.com/edgelesssys/constellation/v2/bootstrapper/internal/diskencryption" "github.com/edgelesssys/constellation/v2/bootstrapper/internal/journald" "github.com/edgelesssys/constellation/v2/internal/atls" "github.com/edgelesssys/constellation/v2/internal/attestation" @@ -77,7 +76,10 @@ type Server struct { } // New creates a new initialization server. 
-func New(ctx context.Context, lock locker, kube ClusterInitializer, issuer atls.Issuer, fh file.Handler, metadata MetadataAPI, log *slog.Logger) (*Server, error) { +func New( + ctx context.Context, lock locker, kube ClusterInitializer, issuer atls.Issuer, + disk encryptedDisk, fh file.Handler, metadata MetadataAPI, log *slog.Logger, +) (*Server, error) { log = log.WithGroup("initServer") initSecretHash, err := metadata.InitSecretHash(ctx) @@ -95,7 +97,7 @@ func New(ctx context.Context, lock locker, kube ClusterInitializer, issuer atls. server := &Server{ nodeLock: lock, - disk: diskencryption.New(), + disk: disk, initializer: kube, fileHandler: fh, issuer: issuer, @@ -130,8 +132,7 @@ func (s *Server) Serve(ip, port string, cleaner cleaner) error { // In this case we don't care about any potential errors from the grpc server if s.initFailure != nil { s.log.Error("Fatal error during Init request", "error", s.initFailure) - resetErr := s.markDiskForReset() - return errors.Join(s.initFailure, resetErr) + return err } return err @@ -332,15 +333,6 @@ func (s *Server) setupDisk(ctx context.Context, cloudKms kms.CloudKMS) error { return s.disk.UpdatePassphrase(string(diskKey)) } -func (s *Server) markDiskForReset() error { - free, err := s.disk.Open() - if err != nil { - return fmt.Errorf("opening disk: %w", err) - } - defer free() - return s.disk.MarkDiskForReset() -} - func deriveMeasurementValues(ctx context.Context, measurementSalt []byte, cloudKms kms.CloudKMS) (clusterID []byte, err error) { secret, err := cloudKms.GetDEK(ctx, crypto.DEKPrefix+crypto.MeasurementSecretKeyID, crypto.DerivedKeyLengthDefault) if err != nil { @@ -376,8 +368,6 @@ type encryptedDisk interface { UUID() (string, error) // UpdatePassphrase switches the initial random passphrase of the encrypted disk to a permanent passphrase. UpdatePassphrase(passphrase string) error - // MarkDiskForReset marks the state disk as not initialized so it may be wiped (reset) on reboot. 
- MarkDiskForReset() error } type serveStopper interface { diff --git a/bootstrapper/internal/initserver/initserver_test.go b/bootstrapper/internal/initserver/initserver_test.go index 7e22b5313..84d0316d7 100644 --- a/bootstrapper/internal/initserver/initserver_test.go +++ b/bootstrapper/internal/initserver/initserver_test.go @@ -67,7 +67,10 @@ func TestNew(t *testing.T) { t.Run(name, func(t *testing.T) { assert := assert.New(t) - server, err := New(context.TODO(), newFakeLock(), &stubClusterInitializer{}, atls.NewFakeIssuer(variant.Dummy{}), fh, &tc.metadata, logger.NewTest(t)) + server, err := New( + context.TODO(), newFakeLock(), &stubClusterInitializer{}, atls.NewFakeIssuer(variant.Dummy{}), + &stubDisk{}, fh, &tc.metadata, logger.NewTest(t), + ) if tc.wantErr { assert.Error(err) return diff --git a/bootstrapper/internal/joinclient/BUILD.bazel b/bootstrapper/internal/joinclient/BUILD.bazel index 3b8bf70b7..687ffd250 100644 --- a/bootstrapper/internal/joinclient/BUILD.bazel +++ b/bootstrapper/internal/joinclient/BUILD.bazel @@ -8,7 +8,6 @@ go_library( visibility = ["//bootstrapper:__subpackages__"], deps = [ "//bootstrapper/internal/certificate", - "//bootstrapper/internal/diskencryption", "//internal/attestation", "//internal/cloud/metadata", "//internal/constants", diff --git a/bootstrapper/internal/joinclient/joinclient.go b/bootstrapper/internal/joinclient/joinclient.go index 82fc81ef5..3e2944325 100644 --- a/bootstrapper/internal/joinclient/joinclient.go +++ b/bootstrapper/internal/joinclient/joinclient.go @@ -28,7 +28,6 @@ import ( "time" "github.com/edgelesssys/constellation/v2/bootstrapper/internal/certificate" - "github.com/edgelesssys/constellation/v2/bootstrapper/internal/diskencryption" "github.com/edgelesssys/constellation/v2/internal/attestation" "github.com/edgelesssys/constellation/v2/internal/cloud/metadata" "github.com/edgelesssys/constellation/v2/internal/constants" @@ -77,10 +76,10 @@ type JoinClient struct { } // New creates a new 
JoinClient. -func New(lock locker, dial grpcDialer, joiner ClusterJoiner, meta MetadataAPI, log *slog.Logger) *JoinClient { +func New(lock locker, dial grpcDialer, joiner ClusterJoiner, meta MetadataAPI, disk encryptedDisk, log *slog.Logger) *JoinClient { return &JoinClient{ nodeLock: lock, - disk: diskencryption.New(), + disk: disk, fileHandler: file.NewHandler(afero.NewOsFs()), timeout: timeout, joinTimeout: joinTimeout, @@ -109,7 +108,7 @@ func (c *JoinClient) Start(cleaner cleaner) error { diskUUID, err := c.getDiskUUID() if err != nil { c.log.With(slog.Any("error", err)).Error("Failed to get disk UUID") - return err // unrecoverable error, but disk wasn't initialized yet + return err } c.diskUUID = diskUUID @@ -149,9 +148,8 @@ func (c *JoinClient) Start(cleaner cleaner) error { } if err := c.startNodeAndJoin(ticket, kubeletKey, cleaner); err != nil { - c.log.With(slog.Any("error", err)).Error("Failed to start node and join cluster") // unrecoverable error - resetErr := c.markDiskForReset() - return errors.Join(err, resetErr) + c.log.With(slog.Any("error", err)).Error("Failed to start node and join cluster") + return err } return nil @@ -353,15 +351,6 @@ func (c *JoinClient) getDiskUUID() (string, error) { return c.disk.UUID() } -func (c *JoinClient) markDiskForReset() error { - free, err := c.disk.Open() - if err != nil { - return fmt.Errorf("opening disk: %w", err) - } - defer free() - return c.disk.MarkDiskForReset() -} - func (c *JoinClient) getControlPlaneIPs(ctx context.Context) ([]string, error) { instances, err := c.metadataAPI.List(ctx) if err != nil { @@ -431,8 +420,6 @@ type encryptedDisk interface { UUID() (string, error) // UpdatePassphrase switches the initial random passphrase of the encrypted disk to a permanent passphrase. UpdatePassphrase(passphrase string) error - // MarkDiskForReset marks the state disk as not initialized so it may be wiped (reset) on reboot. - MarkDiskForReset() error } type cleaner interface {