mirror of
https://github.com/edgelesssys/constellation.git
synced 2024-10-01 01:36:09 -04:00
bootstrapper: use atomics in nodelock (#2001)
This commit is contained in:
parent
f8117b7223
commit
8ba0179137
@ -1,4 +1,5 @@
|
||||
load("@io_bazel_rules_go//go:def.bzl", "go_library")
|
||||
load("//bazel/go:go_test.bzl", "go_test")
|
||||
|
||||
go_library(
|
||||
name = "nodelock",
|
||||
@ -10,3 +11,13 @@ go_library(
|
||||
"//internal/attestation/vtpm",
|
||||
],
|
||||
)
|
||||
|
||||
go_test(
|
||||
name = "nodelock_test",
|
||||
srcs = ["nodelock_test.go"],
|
||||
embed = [":nodelock"],
|
||||
deps = [
|
||||
"//internal/attestation/vtpm",
|
||||
"@com_github_stretchr_testify//assert",
|
||||
],
|
||||
)
|
||||
|
@ -8,7 +8,8 @@ SPDX-License-Identifier: AGPL-3.0-only
|
||||
package nodelock
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/initialize"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
|
||||
@ -22,24 +23,36 @@ import (
|
||||
// There is no way to unlock, so the state changes only once from unlock to
|
||||
// locked.
|
||||
type Lock struct {
|
||||
tpm vtpm.TPMOpenFunc
|
||||
mux *sync.Mutex
|
||||
tpm vtpm.TPMOpenFunc
|
||||
marker markAsBootstrapped
|
||||
inner atomic.Bool
|
||||
}
|
||||
|
||||
// New creates a new NodeLock, which is unlocked.
|
||||
func New(tpm vtpm.TPMOpenFunc) *Lock {
|
||||
return &Lock{
|
||||
tpm: tpm,
|
||||
mux: &sync.Mutex{},
|
||||
tpm: tpm,
|
||||
marker: initialize.MarkNodeAsBootstrapped,
|
||||
}
|
||||
}
|
||||
|
||||
// TryLockOnce tries to lock the node. If the node is already locked, it
|
||||
// returns false. If the node is unlocked, it locks it and returns true.
|
||||
func (l *Lock) TryLockOnce(clusterID []byte) (bool, error) {
|
||||
if !l.mux.TryLock() {
|
||||
// CompareAndSwap first checks if the node is currently unlocked.
|
||||
// If it was already locked, it returns early.
|
||||
// If it is unlocked, it swaps the value to locked atomically and continues.
|
||||
if !l.inner.CompareAndSwap(unlocked, locked) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, initialize.MarkNodeAsBootstrapped(l.tpm, clusterID)
|
||||
return true, l.marker(l.tpm, clusterID)
|
||||
}
|
||||
|
||||
// markAsBootstrapped is a function that marks the node as bootstrapped in the TPM.
|
||||
type markAsBootstrapped func(openDevice func() (io.ReadWriteCloser, error), clusterID []byte) error
|
||||
|
||||
const (
|
||||
unlocked = false
|
||||
locked = true
|
||||
)
|
||||
|
64
bootstrapper/internal/nodelock/nodelock_test.go
Normal file
64
bootstrapper/internal/nodelock/nodelock_test.go
Normal file
@ -0,0 +1,64 @@
|
||||
/*
|
||||
Copyright (c) Edgeless Systems GmbH
|
||||
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package nodelock
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestTryLockOnce(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
tpm := spyDevice{}
|
||||
lock := Lock{
|
||||
tpm: tpm.Opener(),
|
||||
marker: stubMarker,
|
||||
}
|
||||
locked, err := lock.TryLockOnce(nil)
|
||||
assert.NoError(err)
|
||||
assert.True(locked)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
tryLock := func() {
|
||||
defer wg.Done()
|
||||
locked, err := lock.TryLockOnce(nil)
|
||||
assert.NoError(err)
|
||||
assert.False(locked)
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go tryLock()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
assert.EqualValues(1, tpm.counter.Load())
|
||||
}
|
||||
|
||||
type spyDevice struct {
|
||||
counter atomic.Uint64
|
||||
}
|
||||
|
||||
func (s *spyDevice) Opener() vtpm.TPMOpenFunc {
|
||||
return func() (io.ReadWriteCloser, error) {
|
||||
s.counter.Add(1)
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func stubMarker(openDevice func() (io.ReadWriteCloser, error), _ []byte) error {
|
||||
// this only needs to invoke the openDevice function
|
||||
// so that the spyTPM counter is incremented
|
||||
_, _ = openDevice()
|
||||
return nil
|
||||
}
|
Loading…
Reference in New Issue
Block a user