mirror of
https://github.com/edgelesssys/constellation.git
synced 2024-10-01 01:36:09 -04:00
bootstrapper: use atomics in nodelock (#2001)
This commit is contained in:
parent
f8117b7223
commit
8ba0179137
@ -1,4 +1,5 @@
|
|||||||
load("@io_bazel_rules_go//go:def.bzl", "go_library")
|
load("@io_bazel_rules_go//go:def.bzl", "go_library")
|
||||||
|
load("//bazel/go:go_test.bzl", "go_test")
|
||||||
|
|
||||||
go_library(
|
go_library(
|
||||||
name = "nodelock",
|
name = "nodelock",
|
||||||
@ -10,3 +11,13 @@ go_library(
|
|||||||
"//internal/attestation/vtpm",
|
"//internal/attestation/vtpm",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
go_test(
|
||||||
|
name = "nodelock_test",
|
||||||
|
srcs = ["nodelock_test.go"],
|
||||||
|
embed = [":nodelock"],
|
||||||
|
deps = [
|
||||||
|
"//internal/attestation/vtpm",
|
||||||
|
"@com_github_stretchr_testify//assert",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -8,7 +8,8 @@ SPDX-License-Identifier: AGPL-3.0-only
|
|||||||
package nodelock
|
package nodelock
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sync"
|
"io"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/edgelesssys/constellation/v2/internal/attestation/initialize"
|
"github.com/edgelesssys/constellation/v2/internal/attestation/initialize"
|
||||||
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
|
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
|
||||||
@ -23,23 +24,35 @@ import (
|
|||||||
// locked.
|
// locked.
|
||||||
type Lock struct {
|
type Lock struct {
|
||||||
tpm vtpm.TPMOpenFunc
|
tpm vtpm.TPMOpenFunc
|
||||||
mux *sync.Mutex
|
marker markAsBootstrapped
|
||||||
|
inner atomic.Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new NodeLock, which is unlocked.
|
// New creates a new NodeLock, which is unlocked.
|
||||||
func New(tpm vtpm.TPMOpenFunc) *Lock {
|
func New(tpm vtpm.TPMOpenFunc) *Lock {
|
||||||
return &Lock{
|
return &Lock{
|
||||||
tpm: tpm,
|
tpm: tpm,
|
||||||
mux: &sync.Mutex{},
|
marker: initialize.MarkNodeAsBootstrapped,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TryLockOnce tries to lock the node. If the node is already locked, it
|
// TryLockOnce tries to lock the node. If the node is already locked, it
|
||||||
// returns false. If the node is unlocked, it locks it and returns true.
|
// returns false. If the node is unlocked, it locks it and returns true.
|
||||||
func (l *Lock) TryLockOnce(clusterID []byte) (bool, error) {
|
func (l *Lock) TryLockOnce(clusterID []byte) (bool, error) {
|
||||||
if !l.mux.TryLock() {
|
// CompareAndSwap first checks if the node is currently unlocked.
|
||||||
|
// If it was already locked, it returns early.
|
||||||
|
// If it is unlocked, it swaps the value to locked atomically and continues.
|
||||||
|
if !l.inner.CompareAndSwap(unlocked, locked) {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return true, initialize.MarkNodeAsBootstrapped(l.tpm, clusterID)
|
return true, l.marker(l.tpm, clusterID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// markAsBootstrapped is a function that marks the node as bootstrapped in the TPM.
|
||||||
|
type markAsBootstrapped func(openDevice func() (io.ReadWriteCloser, error), clusterID []byte) error
|
||||||
|
|
||||||
|
const (
|
||||||
|
unlocked = false
|
||||||
|
locked = true
|
||||||
|
)
|
||||||
|
64
bootstrapper/internal/nodelock/nodelock_test.go
Normal file
64
bootstrapper/internal/nodelock/nodelock_test.go
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) Edgeless Systems GmbH
|
||||||
|
|
||||||
|
SPDX-License-Identifier: AGPL-3.0-only
|
||||||
|
*/
|
||||||
|
|
||||||
|
package nodelock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTryLockOnce(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
tpm := spyDevice{}
|
||||||
|
lock := Lock{
|
||||||
|
tpm: tpm.Opener(),
|
||||||
|
marker: stubMarker,
|
||||||
|
}
|
||||||
|
locked, err := lock.TryLockOnce(nil)
|
||||||
|
assert.NoError(err)
|
||||||
|
assert.True(locked)
|
||||||
|
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
tryLock := func() {
|
||||||
|
defer wg.Done()
|
||||||
|
locked, err := lock.TryLockOnce(nil)
|
||||||
|
assert.NoError(err)
|
||||||
|
assert.False(locked)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go tryLock()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
assert.EqualValues(1, tpm.counter.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
type spyDevice struct {
|
||||||
|
counter atomic.Uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *spyDevice) Opener() vtpm.TPMOpenFunc {
|
||||||
|
return func() (io.ReadWriteCloser, error) {
|
||||||
|
s.counter.Add(1)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func stubMarker(openDevice func() (io.ReadWriteCloser, error), _ []byte) error {
|
||||||
|
// this only needs to invoke the openDevice function
|
||||||
|
// so that the spyTPM counter is incremented
|
||||||
|
_, _ = openDevice()
|
||||||
|
return nil
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user