diff --git a/bootstrapper/internal/nodelock/BUILD.bazel b/bootstrapper/internal/nodelock/BUILD.bazel index 3efb9c78f..22ee8ee98 100644 --- a/bootstrapper/internal/nodelock/BUILD.bazel +++ b/bootstrapper/internal/nodelock/BUILD.bazel @@ -1,4 +1,5 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//bazel/go:go_test.bzl", "go_test") go_library( name = "nodelock", @@ -10,3 +11,13 @@ go_library( "//internal/attestation/vtpm", ], ) + +go_test( + name = "nodelock_test", + srcs = ["nodelock_test.go"], + embed = [":nodelock"], + deps = [ + "//internal/attestation/vtpm", + "@com_github_stretchr_testify//assert", + ], +) diff --git a/bootstrapper/internal/nodelock/nodelock.go b/bootstrapper/internal/nodelock/nodelock.go index 0e4053eda..2a3865c8d 100644 --- a/bootstrapper/internal/nodelock/nodelock.go +++ b/bootstrapper/internal/nodelock/nodelock.go @@ -8,7 +8,8 @@ SPDX-License-Identifier: AGPL-3.0-only package nodelock import ( - "sync" + "io" + "sync/atomic" "github.com/edgelesssys/constellation/v2/internal/attestation/initialize" "github.com/edgelesssys/constellation/v2/internal/attestation/vtpm" @@ -22,24 +23,36 @@ import ( // There is no way to unlock, so the state changes only once from unlock to // locked. type Lock struct { - tpm vtpm.TPMOpenFunc - mux *sync.Mutex + tpm vtpm.TPMOpenFunc + marker markAsBootstrapped + inner atomic.Bool } // New creates a new NodeLock, which is unlocked. func New(tpm vtpm.TPMOpenFunc) *Lock { return &Lock{ - tpm: tpm, - mux: &sync.Mutex{}, + tpm: tpm, + marker: initialize.MarkNodeAsBootstrapped, } } // TryLockOnce tries to lock the node. If the node is already locked, it // returns false. If the node is unlocked, it locks it and returns true. func (l *Lock) TryLockOnce(clusterID []byte) (bool, error) { - if !l.mux.TryLock() { + // CompareAndSwap first checks if the node is currently unlocked. + // If it was already locked, it returns early. + // If it is unlocked, it swaps the value to locked atomically and continues. + if !l.inner.CompareAndSwap(unlocked, locked) { return false, nil } - return true, initialize.MarkNodeAsBootstrapped(l.tpm, clusterID) + return true, l.marker(l.tpm, clusterID) } + +// markAsBootstrapped is a function that marks the node as bootstrapped in the TPM. +type markAsBootstrapped func(openDevice func() (io.ReadWriteCloser, error), clusterID []byte) error + +const ( + unlocked = false + locked = true +) diff --git a/bootstrapper/internal/nodelock/nodelock_test.go b/bootstrapper/internal/nodelock/nodelock_test.go new file mode 100644 index 000000000..c5738fec1 --- /dev/null +++ b/bootstrapper/internal/nodelock/nodelock_test.go @@ -0,0 +1,64 @@ +/* +Copyright (c) Edgeless Systems GmbH + +SPDX-License-Identifier: AGPL-3.0-only +*/ + +package nodelock + +import ( + "io" + "sync" + "sync/atomic" + "testing" + + "github.com/edgelesssys/constellation/v2/internal/attestation/vtpm" + "github.com/stretchr/testify/assert" +) + +func TestTryLockOnce(t *testing.T) { + assert := assert.New(t) + tpm := spyDevice{} + lock := Lock{ + tpm: tpm.Opener(), + marker: stubMarker, + } + locked, err := lock.TryLockOnce(nil) + assert.NoError(err) + assert.True(locked) + + wg := sync.WaitGroup{} + tryLock := func() { + defer wg.Done() + locked, err := lock.TryLockOnce(nil) + assert.NoError(err) + assert.False(locked) + } + + for i := 0; i < 10; i++ { + wg.Add(1) + go tryLock() + } + + wg.Wait() + + assert.EqualValues(1, tpm.counter.Load()) +} + +type spyDevice struct { + counter atomic.Uint64 +} + +func (s *spyDevice) Opener() vtpm.TPMOpenFunc { + return func() (io.ReadWriteCloser, error) { + s.counter.Add(1) + return nil, nil + } +} + +func stubMarker(openDevice func() (io.ReadWriteCloser, error), _ []byte) error { + // this only needs to invoke the openDevice function + // so that the spyTPM counter is incremented + _, _ = openDevice() + return nil +}