bootstrapper: use atomics in nodelock (#2001)

This commit is contained in:
Malte Poll 2023-07-04 16:26:37 +02:00 committed by GitHub
parent f8117b7223
commit 8ba0179137
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 95 additions and 7 deletions

View File

@ -1,4 +1,5 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
load("//bazel/go:go_test.bzl", "go_test")
go_library(
name = "nodelock",
@ -10,3 +11,13 @@ go_library(
"//internal/attestation/vtpm",
],
)
go_test(
name = "nodelock_test",
srcs = ["nodelock_test.go"],
embed = [":nodelock"],
deps = [
"//internal/attestation/vtpm",
"@com_github_stretchr_testify//assert",
],
)

View File

@ -8,7 +8,8 @@ SPDX-License-Identifier: AGPL-3.0-only
package nodelock
import (
"sync"
"io"
"sync/atomic"
"github.com/edgelesssys/constellation/v2/internal/attestation/initialize"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
@ -22,24 +23,36 @@ import (
// There is no way to unlock, so the state changes only once from unlock to
// locked.
type Lock struct {
tpm vtpm.TPMOpenFunc
mux *sync.Mutex
tpm vtpm.TPMOpenFunc
marker markAsBootstrapped
inner atomic.Bool
}
// New creates a new NodeLock, which is unlocked.
func New(tpm vtpm.TPMOpenFunc) *Lock {
return &Lock{
tpm: tpm,
mux: &sync.Mutex{},
tpm: tpm,
marker: initialize.MarkNodeAsBootstrapped,
}
}
// TryLockOnce tries to lock the node. If the node is already locked, it
// returns false. If the node is unlocked, it locks it and returns true.
func (l *Lock) TryLockOnce(clusterID []byte) (bool, error) {
if !l.mux.TryLock() {
// CompareAndSwap first checks if the node is currently unlocked.
// If it was already locked, it returns early.
// If it is unlocked, it swaps the value to locked atomically and continues.
if !l.inner.CompareAndSwap(unlocked, locked) {
return false, nil
}
return true, initialize.MarkNodeAsBootstrapped(l.tpm, clusterID)
return true, l.marker(l.tpm, clusterID)
}
// markAsBootstrapped is a function that marks the node as bootstrapped in the TPM.
type markAsBootstrapped func(openDevice func() (io.ReadWriteCloser, error), clusterID []byte) error
const (
unlocked = false
locked = true
)

View File

@ -0,0 +1,64 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package nodelock
import (
"io"
"sync"
"sync/atomic"
"testing"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/stretchr/testify/assert"
)
func TestTryLockOnce(t *testing.T) {
assert := assert.New(t)
tpm := spyDevice{}
lock := Lock{
tpm: tpm.Opener(),
marker: stubMarker,
}
locked, err := lock.TryLockOnce(nil)
assert.NoError(err)
assert.True(locked)
wg := sync.WaitGroup{}
tryLock := func() {
defer wg.Done()
locked, err := lock.TryLockOnce(nil)
assert.NoError(err)
assert.False(locked)
}
for i := 0; i < 10; i++ {
wg.Add(1)
go tryLock()
}
wg.Wait()
assert.EqualValues(1, tpm.counter.Load())
}
type spyDevice struct {
counter atomic.Uint64
}
func (s *spyDevice) Opener() vtpm.TPMOpenFunc {
return func() (io.ReadWriteCloser, error) {
s.counter.Add(1)
return nil, nil
}
}
func stubMarker(openDevice func() (io.ReadWriteCloser, error), _ []byte) error {
// this only needs to invoke the openDevice function
// so that the spyTPM counter is incremented
_, _ = openDevice()
return nil
}