65 lines
1.1 KiB
Go
Raw Normal View History

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package nodelock
import (
"io"
"sync"
"sync/atomic"
"testing"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/stretchr/testify/assert"
)
// TestTryLockOnce verifies that only the first TryLockOnce call acquires the
// lock, that all later calls (issued concurrently) are rejected without error,
// and that the TPM opener is invoked exactly once across all attempts.
func TestTryLockOnce(t *testing.T) {
	assert := assert.New(t)

	var device spyDevice
	lock := Lock{
		tpm:    device.Opener(),
		marker: stubMarker,
	}

	// The very first attempt must succeed.
	gotLock, err := lock.TryLockOnce(nil)
	assert.NoError(err)
	assert.True(gotLock)

	// Every further attempt, run concurrently, must be refused.
	const contenders = 10
	var wg sync.WaitGroup
	wg.Add(contenders)
	for n := 0; n < contenders; n++ {
		go func() {
			defer wg.Done()
			gotLock, err := lock.TryLockOnce(nil)
			assert.NoError(err)
			assert.False(gotLock)
		}()
	}
	wg.Wait()

	// The device opener must have been invoked exactly once overall.
	assert.EqualValues(1, device.counter.Load())
}
// spyDevice is a test double for the TPM device. It records how many times the
// open function returned by Opener has been called.
type spyDevice struct {
	// counter is incremented on every call of the func returned by Opener.
	// It is atomic because that func may be reached from several goroutines
	// (see the concurrent TryLockOnce calls in the test).
	counter atomic.Uint64
}
// Opener returns a vtpm.TPMOpenFunc that bumps s.counter each time it is
// called. The returned func never opens a real device: it yields a nil
// io.ReadWriteCloser and a nil error.
func (s *spyDevice) Opener() vtpm.TPMOpenFunc {
	openSpy := func() (io.ReadWriteCloser, error) {
		s.counter.Add(1)
		var noDevice io.ReadWriteCloser
		return noDevice, nil
	}
	return openSpy
}
// stubMarker stands in for the real marker function used by Lock. It only
// invokes openDevice so that the spyDevice counter is incremented, letting the
// test observe how often the marker ran. The marker payload is ignored.
//
// A successfully opened, non-nil device is closed again so the stub does not
// leak a handle when paired with an opener that returns a real
// io.ReadWriteCloser. Open and close errors are deliberately ignored: the stub
// always reports success.
func stubMarker(openDevice func() (io.ReadWriteCloser, error), _ []byte) error {
	rwc, err := openDevice()
	if err == nil && rwc != nil {
		// The close result is irrelevant to the stub; it always succeeds.
		_ = rwc.Close()
	}
	return nil
}