mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-01-09 22:49:39 -05:00
51068abc27
consistent naming for test values using 'want' instead of 'expect/ed'
318 lines
7.8 KiB
Go
318 lines
7.8 KiB
Go
package setup
|
|
|
|
import (
|
|
"errors"
|
|
"io"
|
|
"io/fs"
|
|
"path/filepath"
|
|
"testing"
|
|
|
|
"github.com/edgelesssys/constellation/cli/file"
|
|
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
|
|
"github.com/edgelesssys/constellation/coordinator/config"
|
|
"github.com/edgelesssys/constellation/coordinator/nodestate"
|
|
"github.com/spf13/afero"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestPrepareExistingDisk(t *testing.T) {
|
|
someErr := errors.New("error")
|
|
|
|
testCases := map[string]struct {
|
|
fs afero.Afero
|
|
keyWaiter *stubKeyWaiter
|
|
mapper *stubMapper
|
|
mounter *stubMounter
|
|
openTPM vtpm.TPMOpenFunc
|
|
missingState bool
|
|
wantErr bool
|
|
}{
|
|
"success": {
|
|
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
|
keyWaiter: &stubKeyWaiter{},
|
|
mapper: &stubMapper{uuid: "test"},
|
|
mounter: &stubMounter{},
|
|
openTPM: vtpm.OpenNOPTPM,
|
|
},
|
|
"WaitForDecryptionKey fails": {
|
|
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
|
keyWaiter: &stubKeyWaiter{waitErr: someErr},
|
|
mapper: &stubMapper{uuid: "test"},
|
|
mounter: &stubMounter{},
|
|
openTPM: vtpm.OpenNOPTPM,
|
|
wantErr: true,
|
|
},
|
|
"MapDisk fails causes a repeat": {
|
|
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
|
keyWaiter: &stubKeyWaiter{},
|
|
mapper: &stubMapper{
|
|
uuid: "test",
|
|
mapDiskErr: someErr,
|
|
mapDiskRepeatedCalls: 2,
|
|
},
|
|
mounter: &stubMounter{},
|
|
openTPM: vtpm.OpenNOPTPM,
|
|
wantErr: false,
|
|
},
|
|
"MkdirAll fails": {
|
|
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
|
keyWaiter: &stubKeyWaiter{},
|
|
mapper: &stubMapper{uuid: "test"},
|
|
mounter: &stubMounter{mkdirAllErr: someErr},
|
|
openTPM: vtpm.OpenNOPTPM,
|
|
wantErr: true,
|
|
},
|
|
"Mount fails": {
|
|
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
|
keyWaiter: &stubKeyWaiter{},
|
|
mapper: &stubMapper{uuid: "test"},
|
|
mounter: &stubMounter{mountErr: someErr},
|
|
openTPM: vtpm.OpenNOPTPM,
|
|
wantErr: true,
|
|
},
|
|
"Unmount fails": {
|
|
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
|
keyWaiter: &stubKeyWaiter{},
|
|
mapper: &stubMapper{uuid: "test"},
|
|
mounter: &stubMounter{unmountErr: someErr},
|
|
openTPM: vtpm.OpenNOPTPM,
|
|
wantErr: true,
|
|
},
|
|
"MarkNodeAsInitialized fails": {
|
|
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
|
keyWaiter: &stubKeyWaiter{},
|
|
mapper: &stubMapper{uuid: "test"},
|
|
mounter: &stubMounter{unmountErr: someErr},
|
|
openTPM: failOpener,
|
|
wantErr: true,
|
|
},
|
|
"no state file": {
|
|
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
|
keyWaiter: &stubKeyWaiter{},
|
|
mapper: &stubMapper{uuid: "test"},
|
|
mounter: &stubMounter{},
|
|
openTPM: vtpm.OpenNOPTPM,
|
|
missingState: true,
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for name, tc := range testCases {
|
|
t.Run(name, func(t *testing.T) {
|
|
assert := assert.New(t)
|
|
|
|
if !tc.missingState {
|
|
handler := file.NewHandler(tc.fs)
|
|
require.NoError(t, handler.WriteJSON(stateInfoPath, nodestate.NodeState{OwnerID: []byte("ownerID"), ClusterID: []byte("clusterID")}, file.OptMkdirAll))
|
|
}
|
|
|
|
setupManager := New("test", tc.fs, tc.keyWaiter, tc.mapper, tc.mounter, tc.openTPM)
|
|
|
|
err := setupManager.PrepareExistingDisk()
|
|
if tc.wantErr {
|
|
assert.Error(err)
|
|
} else {
|
|
assert.NoError(err)
|
|
assert.Equal(tc.mapper.uuid, tc.keyWaiter.receivedUUID)
|
|
assert.True(tc.mapper.mapDiskCalled)
|
|
assert.True(tc.mounter.mountCalled)
|
|
assert.True(tc.mounter.unmountCalled)
|
|
assert.False(tc.mapper.formatDiskCalled)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// failOpener is an opener that always fails: it returns a nil
// io.ReadWriteCloser together with a non-nil error.
func failOpener() (io.ReadWriteCloser, error) {
	var noDevice io.ReadWriteCloser
	openErr := errors.New("error")
	return noDevice, openErr
}
|
|
|
|
func TestPrepareNewDisk(t *testing.T) {
|
|
someErr := errors.New("error")
|
|
testCases := map[string]struct {
|
|
fs afero.Afero
|
|
mapper *stubMapper
|
|
wantErr bool
|
|
}{
|
|
"success": {
|
|
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
|
mapper: &stubMapper{uuid: "test"},
|
|
},
|
|
"creating directory fails": {
|
|
fs: afero.Afero{Fs: afero.NewReadOnlyFs(afero.NewMemMapFs())},
|
|
mapper: &stubMapper{},
|
|
wantErr: true,
|
|
},
|
|
"FormatDisk fails": {
|
|
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
|
mapper: &stubMapper{
|
|
uuid: "test",
|
|
formatDiskErr: someErr,
|
|
},
|
|
wantErr: true,
|
|
},
|
|
"MapDisk fails": {
|
|
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
|
mapper: &stubMapper{
|
|
uuid: "test",
|
|
mapDiskErr: someErr,
|
|
mapDiskRepeatedCalls: 1,
|
|
},
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for name, tc := range testCases {
|
|
t.Run(name, func(t *testing.T) {
|
|
assert := assert.New(t)
|
|
|
|
setupManager := New("test", tc.fs, nil, tc.mapper, nil, nil)
|
|
|
|
err := setupManager.PrepareNewDisk()
|
|
if tc.wantErr {
|
|
assert.Error(err)
|
|
} else {
|
|
assert.NoError(err)
|
|
assert.True(tc.mapper.formatDiskCalled)
|
|
assert.True(tc.mapper.mapDiskCalled)
|
|
|
|
data, err := tc.fs.ReadFile(filepath.Join(keyPath, keyFile))
|
|
require.NoError(t, err)
|
|
assert.Len(data, config.RNGLengthDefault)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestReadInitSecrets(t *testing.T) {
|
|
testCases := map[string]struct {
|
|
fs afero.Afero
|
|
ownerID string
|
|
clusterID string
|
|
writeFile bool
|
|
wantErr bool
|
|
}{
|
|
"success": {
|
|
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
|
ownerID: "ownerID",
|
|
clusterID: "clusterID",
|
|
writeFile: true,
|
|
},
|
|
"no state file": {
|
|
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
|
wantErr: true,
|
|
},
|
|
"missing ownerID": {
|
|
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
|
clusterID: "clusterID",
|
|
writeFile: true,
|
|
wantErr: true,
|
|
},
|
|
"missing clusterID": {
|
|
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
|
ownerID: "ownerID",
|
|
writeFile: true,
|
|
wantErr: true,
|
|
},
|
|
"no IDs": {
|
|
fs: afero.Afero{Fs: afero.NewMemMapFs()},
|
|
writeFile: true,
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for name, tc := range testCases {
|
|
t.Run(name, func(t *testing.T) {
|
|
assert := assert.New(t)
|
|
require := require.New(t)
|
|
|
|
if tc.writeFile {
|
|
handler := file.NewHandler(tc.fs)
|
|
state := nodestate.NodeState{ClusterID: []byte(tc.clusterID), OwnerID: []byte(tc.ownerID)}
|
|
require.NoError(handler.WriteJSON("/tmp/test-state.json", state, file.OptMkdirAll))
|
|
}
|
|
|
|
setupManager := New("test", tc.fs, nil, nil, nil, nil)
|
|
|
|
ownerID, clusterID, err := setupManager.readInitSecrets("/tmp/test-state.json")
|
|
if tc.wantErr {
|
|
assert.Error(err)
|
|
} else {
|
|
assert.NoError(err)
|
|
assert.Equal([]byte(tc.ownerID), ownerID)
|
|
assert.Equal([]byte(tc.clusterID), clusterID)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// stubMapper is a test double for the disk mapper. It records which methods
// were called and returns the errors configured on it.
type stubMapper struct {
	// formatDiskCalled is set once FormatDisk has been invoked.
	formatDiskCalled bool
	// formatDiskErr is returned by every FormatDisk call.
	formatDiskErr error
	// mapDiskRepeatedCalls is the number of MapDisk calls that still return
	// mapDiskErr before the error is cleared.
	mapDiskRepeatedCalls int
	// mapDiskCalled is set once MapDisk has been invoked.
	mapDiskCalled bool
	// mapDiskErr is returned by MapDisk until the call budget is used up.
	mapDiskErr error
	// uuid is the value reported by DiskUUID.
	uuid string
}

// DiskUUID reports the configured UUID.
func (m *stubMapper) DiskUUID() string {
	return m.uuid
}

// FormatDisk records the call and returns the configured format error.
func (m *stubMapper) FormatDisk(passphrase string) error {
	m.formatDiskCalled = true
	return m.formatDiskErr
}

// MapDisk records the call and returns mapDiskErr. When it is entered with a
// call budget of exactly zero, the error is cleared permanently, so this call
// and all later ones succeed. The budget is decremented on every call.
func (m *stubMapper) MapDisk(target string, passphrase string) error {
	m.mapDiskCalled = true

	remaining := m.mapDiskRepeatedCalls
	m.mapDiskRepeatedCalls = remaining - 1
	if remaining == 0 {
		m.mapDiskErr = nil
	}

	return m.mapDiskErr
}
|
|
|
|
// stubMounter is a test double for the mounter. It records Mount and Unmount
// calls and returns the errors configured on it.
type stubMounter struct {
	// mountCalled is set once Mount has been invoked.
	mountCalled bool
	// mountErr is returned by every Mount call.
	mountErr error
	// unmountCalled is set once Unmount has been invoked.
	unmountCalled bool
	// unmountErr is returned by every Unmount call.
	unmountErr error
	// mkdirAllErr is returned by every MkdirAll call.
	mkdirAllErr error
}

// Mount records the call and returns the configured mount error. All
// arguments are ignored.
func (m *stubMounter) Mount(source, target, fstype string, flags uintptr, data string) error {
	m.mountCalled = true
	return m.mountErr
}

// Unmount records the call and returns the configured unmount error. All
// arguments are ignored.
func (m *stubMounter) Unmount(target string, flags int) error {
	m.unmountCalled = true
	return m.unmountErr
}

// MkdirAll returns the configured error without creating anything and without
// recording the call.
func (m *stubMounter) MkdirAll(path string, perm fs.FileMode) error {
	return m.mkdirAllErr
}
|
|
|
|
// stubKeyWaiter is a test double for the key waiter. It hands out a
// preconfigured key and error, and refuses a second wait until ResetKey has
// been called.
type stubKeyWaiter struct {
	// receivedUUID is the uuid passed to the last accepted wait call.
	receivedUUID string
	// decryptionKey is returned by an accepted wait call.
	decryptionKey []byte
	// waitErr is returned by an accepted wait call.
	waitErr error
	// waitCalled is true between an accepted wait call and the next ResetKey.
	waitCalled bool
}

// WaitForDecryptionKey stores uuid and returns the configured key and error.
// If a previous wait has not been cleared with ResetKey, it returns a nil key
// and an error, leaving the stored uuid untouched. addr is ignored.
func (w *stubKeyWaiter) WaitForDecryptionKey(uuid, addr string) ([]byte, error) {
	alreadyWaited := w.waitCalled
	if alreadyWaited {
		return nil, errors.New("wait called before key was reset")
	}

	w.waitCalled, w.receivedUUID = true, uuid
	return w.decryptionKey, w.waitErr
}

// ResetKey allows WaitForDecryptionKey to be called again.
func (w *stubKeyWaiter) ResetKey() {
	w.waitCalled = false
}
|