mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-01-07 21:58:01 -05:00
19871ee422
Signed-off-by: Daniel Weiße <dw@edgeless.systems>
320 lines
7.8 KiB
Go
320 lines
7.8 KiB
Go
package setup
|
|
|
|
import (
|
|
"errors"
|
|
"io"
|
|
"io/fs"
|
|
"path/filepath"
|
|
"testing"
|
|
|
|
"github.com/edgelesssys/constellation/bootstrapper/nodestate"
|
|
"github.com/edgelesssys/constellation/internal/attestation/vtpm"
|
|
"github.com/edgelesssys/constellation/internal/crypto"
|
|
"github.com/edgelesssys/constellation/internal/file"
|
|
"github.com/edgelesssys/constellation/internal/logger"
|
|
"github.com/spf13/afero"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/goleak"
|
|
)
|
|
|
|
// TestMain hands the package's test run to goleak.VerifyTestMain so that
// goroutines still running after the tests finish are reported as leaks.
func TestMain(m *testing.M) { goleak.VerifyTestMain(m) }
|
|
|
|
// TestPrepareExistingDisk drives SetupManager.PrepareExistingDisk with stubbed
// key waiter, mapper, mounter and TPM opener. It covers the happy path, a
// MapDisk failure that is expected to be retried, and one case per step that
// is expected to surface an error. Each failing case injects exactly one fault
// so that it can only pass through the step it is named after.
func TestPrepareExistingDisk(t *testing.T) {
	someErr := errors.New("error")

	testCases := map[string]struct {
		fs           afero.Afero
		keyWaiter    *stubKeyWaiter
		mapper       *stubMapper
		mounter      *stubMounter
		openTPM      vtpm.TPMOpenFunc
		missingState bool
		wantErr      bool
	}{
		"success": {
			fs:        afero.Afero{Fs: afero.NewMemMapFs()},
			keyWaiter: &stubKeyWaiter{},
			mapper:    &stubMapper{uuid: "test"},
			mounter:   &stubMounter{},
			openTPM:   vtpm.OpenNOPTPM,
		},
		"WaitForDecryptionKey fails": {
			fs:        afero.Afero{Fs: afero.NewMemMapFs()},
			keyWaiter: &stubKeyWaiter{waitErr: someErr},
			mapper:    &stubMapper{uuid: "test"},
			mounter:   &stubMounter{},
			openTPM:   vtpm.OpenNOPTPM,
			wantErr:   true,
		},
		// MapDisk fails twice before succeeding. The case expects
		// PrepareExistingDisk to recover; the stub key waiter additionally
		// rejects a second wait unless ResetKey was called in between.
		"MapDisk fails causes a repeat": {
			fs:        afero.Afero{Fs: afero.NewMemMapFs()},
			keyWaiter: &stubKeyWaiter{},
			mapper: &stubMapper{
				uuid:                 "test",
				mapDiskErr:           someErr,
				mapDiskRepeatedCalls: 2,
			},
			mounter: &stubMounter{},
			openTPM: vtpm.OpenNOPTPM,
		},
		"MkdirAll fails": {
			fs:        afero.Afero{Fs: afero.NewMemMapFs()},
			keyWaiter: &stubKeyWaiter{},
			mapper:    &stubMapper{uuid: "test"},
			mounter:   &stubMounter{mkdirAllErr: someErr},
			openTPM:   vtpm.OpenNOPTPM,
			wantErr:   true,
		},
		"Mount fails": {
			fs:        afero.Afero{Fs: afero.NewMemMapFs()},
			keyWaiter: &stubKeyWaiter{},
			mapper:    &stubMapper{uuid: "test"},
			mounter:   &stubMounter{mountErr: someErr},
			openTPM:   vtpm.OpenNOPTPM,
			wantErr:   true,
		},
		"Unmount fails": {
			fs:        afero.Afero{Fs: afero.NewMemMapFs()},
			keyWaiter: &stubKeyWaiter{},
			mapper:    &stubMapper{uuid: "test"},
			mounter:   &stubMounter{unmountErr: someErr},
			openTPM:   vtpm.OpenNOPTPM,
			wantErr:   true,
		},
		// The mounter is healthy here so that the failing TPM opener is the
		// only injected fault; the case must not pass merely because Unmount
		// returns an error.
		"MarkNodeAsInitialized fails": {
			fs:        afero.Afero{Fs: afero.NewMemMapFs()},
			keyWaiter: &stubKeyWaiter{},
			mapper:    &stubMapper{uuid: "test"},
			mounter:   &stubMounter{},
			openTPM:   failOpener,
			wantErr:   true,
		},
		"no state file": {
			fs:           afero.Afero{Fs: afero.NewMemMapFs()},
			keyWaiter:    &stubKeyWaiter{},
			mapper:       &stubMapper{uuid: "test"},
			mounter:      &stubMounter{},
			openTPM:      vtpm.OpenNOPTPM,
			missingState: true,
			wantErr:      true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			require := require.New(t)

			// Seed the node state file with a measurement salt unless the case
			// explicitly tests its absence.
			salt := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
			if !tc.missingState {
				handler := file.NewHandler(tc.fs)
				require.NoError(handler.WriteJSON(stateInfoPath, nodestate.NodeState{MeasurementSalt: salt}, file.OptMkdirAll))
			}

			setupManager := New(
				logger.NewTest(t),
				"test",
				tc.fs,
				tc.keyWaiter,
				tc.mapper,
				tc.mounter,
				tc.openTPM,
			)

			err := setupManager.PrepareExistingDisk()
			if tc.wantErr {
				assert.Error(err)
				return
			}

			assert.NoError(err)
			// The key waiter must have been asked for the key of the mapper's disk.
			assert.Equal(tc.mapper.uuid, tc.keyWaiter.receivedUUID)
			assert.True(tc.mapper.mapDiskCalled)
			assert.True(tc.mounter.mountCalled)
			assert.True(tc.mounter.unmountCalled)
			// An existing disk must never be reformatted.
			assert.False(tc.mapper.formatDiskCalled)
		})
	}
}
|
|
|
|
func failOpener() (io.ReadWriteCloser, error) {
|
|
return nil, errors.New("error")
|
|
}
|
|
|
|
// TestPrepareNewDisk exercises SetupManager.PrepareNewDisk with a stubbed
// mapper. A successful run must format and map the disk and write a key of
// crypto.RNGLengthDefault bytes to keyPath/keyFile. A read-only filesystem, a
// FormatDisk error or a MapDisk error must each make the call fail.
func TestPrepareNewDisk(t *testing.T) {
	someErr := errors.New("error")
	testCases := map[string]struct {
		fs      afero.Afero
		mapper  *stubMapper
		wantErr bool
	}{
		"success": {
			fs:     afero.Afero{Fs: afero.NewMemMapFs()},
			mapper: &stubMapper{uuid: "test"},
		},
		"creating directory fails": {
			fs:      afero.Afero{Fs: afero.NewReadOnlyFs(afero.NewMemMapFs())},
			mapper:  &stubMapper{},
			wantErr: true,
		},
		"FormatDisk fails": {
			fs: afero.Afero{Fs: afero.NewMemMapFs()},
			mapper: &stubMapper{
				uuid:          "test",
				formatDiskErr: someErr,
			},
			wantErr: true,
		},
		"MapDisk fails": {
			fs: afero.Afero{Fs: afero.NewMemMapFs()},
			mapper: &stubMapper{
				uuid:                 "test",
				mapDiskErr:           someErr,
				mapDiskRepeatedCalls: 1,
			},
			wantErr: true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			require := require.New(t)

			// Key waiter, mounter and TPM opener are passed as nil; the paths
			// under test are expected not to touch them.
			setupManager := New(logger.NewTest(t), "test", tc.fs, nil, tc.mapper, nil, nil)

			err := setupManager.PrepareNewDisk()
			if tc.wantErr {
				assert.Error(err)
				return
			}

			assert.NoError(err)
			assert.True(tc.mapper.formatDiskCalled)
			assert.True(tc.mapper.mapDiskCalled)

			keyData, err := tc.fs.ReadFile(filepath.Join(keyPath, keyFile))
			require.NoError(err)
			assert.Len(keyData, crypto.RNGLengthDefault)
		})
	}
}
|
|
|
|
func TestReadMeasurementSalt(t *testing.T) {
|
|
salt := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
|
|
testCases := map[string]struct {
|
|
fs afero.Afero
|
|
salt []byte
|
|
writeFile bool
|
|
wantErr bool
|
|
}{
|
|
"success": {
|
|
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
|
salt: salt,
|
|
writeFile: true,
|
|
},
|
|
"no state file": {
|
|
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
|
wantErr: true,
|
|
},
|
|
"missing salt": {
|
|
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
|
writeFile: true,
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for name, tc := range testCases {
|
|
t.Run(name, func(t *testing.T) {
|
|
assert := assert.New(t)
|
|
require := require.New(t)
|
|
|
|
if tc.writeFile {
|
|
handler := file.NewHandler(tc.fs)
|
|
state := nodestate.NodeState{MeasurementSalt: tc.salt}
|
|
require.NoError(handler.WriteJSON("test-state.json", state, file.OptMkdirAll))
|
|
}
|
|
|
|
setupManager := New(logger.NewTest(t), "test", tc.fs, nil, nil, nil, nil)
|
|
|
|
measurementSalt, err := setupManager.readMeasurementSalt("test-state.json")
|
|
if tc.wantErr {
|
|
assert.Error(err)
|
|
} else {
|
|
assert.NoError(err)
|
|
assert.Equal(tc.salt, measurementSalt)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// stubMapper is the mapper test double passed to New. It records which
// operations were invoked and returns preconfigured errors.
type stubMapper struct {
	formatDiskCalled bool
	formatDiskErr    error
	// mapDiskRepeatedCalls is how many leading MapDisk calls still return
	// mapDiskErr (for a non-negative start value). When a call finds it at
	// zero, mapDiskErr is cleared.
	mapDiskRepeatedCalls int
	mapDiskCalled        bool
	mapDiskErr           error
	uuid                 string
}

// DiskUUID returns the configured uuid.
func (m *stubMapper) DiskUUID() string { return m.uuid }

// FormatDisk records the call and returns formatDiskErr.
func (m *stubMapper) FormatDisk(string) error {
	m.formatDiskCalled = true
	return m.formatDiskErr
}

// MapDisk records the call. It returns mapDiskErr until mapDiskRepeatedCalls
// has counted down to zero, and nil from then on.
func (m *stubMapper) MapDisk(string, string) error {
	m.mapDiskCalled = true
	if m.mapDiskRepeatedCalls == 0 {
		// Failure budget exhausted; the error stays cleared afterwards.
		m.mapDiskErr = nil
	}
	m.mapDiskRepeatedCalls--
	return m.mapDiskErr
}
|
|
|
|
type stubMounter struct {
|
|
mountCalled bool
|
|
mountErr error
|
|
unmountCalled bool
|
|
unmountErr error
|
|
mkdirAllErr error
|
|
}
|
|
|
|
func (s *stubMounter) Mount(source string, target string, fstype string, flags uintptr, data string) error {
|
|
s.mountCalled = true
|
|
return s.mountErr
|
|
}
|
|
|
|
func (s *stubMounter) Unmount(target string, flags int) error {
|
|
s.unmountCalled = true
|
|
return s.unmountErr
|
|
}
|
|
|
|
func (s *stubMounter) MkdirAll(path string, perm fs.FileMode) error {
|
|
return s.mkdirAllErr
|
|
}
|
|
|
|
type stubKeyWaiter struct {
|
|
receivedUUID string
|
|
decryptionKey []byte
|
|
measurementSecret []byte
|
|
waitErr error
|
|
waitCalled bool
|
|
}
|
|
|
|
func (s *stubKeyWaiter) WaitForDecryptionKey(uuid, addr string) ([]byte, []byte, error) {
|
|
if s.waitCalled {
|
|
return nil, nil, errors.New("wait called before key was reset")
|
|
}
|
|
s.waitCalled = true
|
|
s.receivedUUID = uuid
|
|
return s.decryptionKey, s.measurementSecret, s.waitErr
|
|
}
|
|
|
|
func (s *stubKeyWaiter) ResetKey() {
|
|
s.waitCalled = false
|
|
}
|