constellation/csi/cryptmapper/cryptmapper_test.go

464 lines
11 KiB
Go
Raw Normal View History

package cryptmapper
import (
"context"
"errors"
"testing"
cryptsetup "github.com/martinjungblut/go-cryptsetup"
"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
type stubCryptDevice struct {
deviceName string
uuid string
initErr error
initByNameErr error
activateErr error
activatePassErr error
deactivateErr error
formatErr error
loadErr error
keySlotAddCalled bool
keySlotAddErr error
wipeErr error
resizeErr error
}
func (c *stubCryptDevice) Init(string) error {
return c.initErr
}
func (c *stubCryptDevice) InitByName(string) error {
return c.initByNameErr
}
func (c *stubCryptDevice) ActivateByVolumeKey(string, string, int, int) error {
return c.activateErr
}
func (c *stubCryptDevice) ActivateByPassphrase(string, int, string, int) error {
return c.activatePassErr
}
func (c *stubCryptDevice) Deactivate(string) error {
return c.deactivateErr
}
func (c *stubCryptDevice) Format(cryptsetup.DeviceType, cryptsetup.GenericParams) error {
return c.formatErr
}
func (c *stubCryptDevice) Free() bool {
return true
}
func (c *stubCryptDevice) GetDeviceName() string {
return c.deviceName
}
func (c *stubCryptDevice) GetUUID() string {
return c.uuid
}
func (c *stubCryptDevice) Load(cryptsetup.DeviceType) error {
return c.loadErr
}
func (c *stubCryptDevice) KeyslotAddByVolumeKey(int, string, string) error {
c.keySlotAddCalled = true
return c.keySlotAddErr
}
func (c *stubCryptDevice) Wipe(string, int, uint64, uint64, int, int, func(size, offset uint64) int) error {
return c.wipeErr
}
func (c *stubCryptDevice) Resize(string, uint64) error {
return c.resizeErr
}
func TestCloseCryptDevice(t *testing.T) {
testCases := map[string]struct {
mapper *stubCryptDevice
wantErr bool
}{
"success": {
mapper: &stubCryptDevice{},
wantErr: false,
},
"error on InitByName": {
mapper: &stubCryptDevice{initByNameErr: errors.New("error")},
wantErr: true,
},
"error on Deactivate": {
mapper: &stubCryptDevice{deactivateErr: errors.New("error")},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := closeCryptDevice(tc.mapper, "/dev/some-device", "volume0", "test")
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
mapper := New(&fakeKMS{}, &stubCryptDevice{})
err := mapper.CloseCryptDevice("volume01-unit-test")
assert.NoError(t, err)
}
func TestOpenCryptDevice(t *testing.T) {
someErr := errors.New("error")
getKeyFunc := func(context.Context, string, int) ([]byte, error) {
return []byte{
0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA,
0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA,
0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA,
0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA,
}, nil
}
testCases := map[string]struct {
source string
volumeID string
integrity bool
mapper *stubCryptDevice
getKey func(context.Context, string, int) ([]byte, error)
diskInfo func(disk string) (string, error)
wantErr bool
}{
"success with Load": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{},
getKey: getKeyFunc,
diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: false,
},
"success with error on Load": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{loadErr: someErr},
getKey: getKeyFunc,
diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: false,
},
"success with integrity": {
source: "/dev/some-device",
volumeID: "volume0",
integrity: true,
mapper: &stubCryptDevice{loadErr: someErr},
getKey: getKeyFunc,
diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: false,
},
"error on Init": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{initErr: someErr},
getKey: getKeyFunc,
diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: true,
},
"error on Format": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{loadErr: someErr, formatErr: someErr},
getKey: getKeyFunc,
diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: true,
},
"error on Activate": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{activatePassErr: someErr},
getKey: getKeyFunc,
diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: true,
},
"error on diskInfo": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{loadErr: someErr},
getKey: getKeyFunc,
diskInfo: func(disk string) (string, error) { return "", someErr },
wantErr: true,
},
"disk is already formatted": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{loadErr: someErr},
getKey: getKeyFunc,
diskInfo: func(disk string) (string, error) { return "ext4", nil },
wantErr: true,
},
"error with integrity on wipe": {
source: "/dev/some-device",
volumeID: "volume0",
integrity: true,
mapper: &stubCryptDevice{loadErr: someErr, wipeErr: someErr},
getKey: getKeyFunc,
diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: true,
},
"error with integrity on activate": {
source: "/dev/some-device",
volumeID: "volume0",
integrity: true,
mapper: &stubCryptDevice{loadErr: someErr, activateErr: someErr},
getKey: getKeyFunc,
diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: true,
},
"error with integrity on deactivate": {
source: "/dev/some-device",
volumeID: "volume0",
integrity: true,
mapper: &stubCryptDevice{loadErr: someErr, deactivateErr: someErr},
getKey: getKeyFunc,
diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: true,
},
"error on adding keyslot": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{
loadErr: someErr,
keySlotAddErr: someErr,
},
getKey: getKeyFunc,
diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: true,
},
"incorrect key length": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{},
getKey: func(ctx context.Context, s string, i int) ([]byte, error) { return []byte{0x1, 0x2, 0x3}, nil },
diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: true,
},
"incorrect key length with error on Load": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{loadErr: someErr},
getKey: func(ctx context.Context, s string, i int) ([]byte, error) { return []byte{0x1, 0x2, 0x3}, nil },
diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: true,
},
"getKey fails": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{},
getKey: func(ctx context.Context, s string, i int) ([]byte, error) { return nil, someErr },
diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: true,
},
"getKey fails with error on Load": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{loadErr: someErr},
getKey: func(ctx context.Context, s string, i int) ([]byte, error) { return nil, someErr },
diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
out, err := openCryptDevice(
context.Background(),
tc.mapper,
tc.source,
tc.volumeID,
tc.integrity,
tc.getKey,
tc.diskInfo,
)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(cryptPrefix+tc.volumeID, out)
if tc.mapper.loadErr == nil {
assert.False(tc.mapper.keySlotAddCalled)
} else {
assert.True(tc.mapper.keySlotAddCalled)
}
}
})
}
mapper := New(&fakeKMS{}, &stubCryptDevice{})
_, err := mapper.OpenCryptDevice(context.Background(), "/dev/some-device", "volume01", false)
assert.NoError(t, err)
}
func TestResizeCryptDevice(t *testing.T) {
volumeID := "pvc-123"
someErr := errors.New("error")
testCases := map[string]struct {
volumeID string
device *stubCryptDevice
wantErr bool
}{
"success": {
volumeID: volumeID,
device: &stubCryptDevice{},
},
"InitByName fails": {
volumeID: volumeID,
device: &stubCryptDevice{initByNameErr: someErr},
wantErr: true,
},
"Load fails": {
volumeID: volumeID,
device: &stubCryptDevice{loadErr: someErr},
wantErr: true,
},
"Resize fails": {
volumeID: volumeID,
device: &stubCryptDevice{resizeErr: someErr},
wantErr: true,
},
"ActivateByPassphrase fails": {
volumeID: volumeID,
device: &stubCryptDevice{activatePassErr: someErr},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
mapper := &CryptMapper{
kms: &fakeKMS{},
mapper: tc.device,
}
res, err := mapper.ResizeCryptDevice(context.Background(), tc.volumeID)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(cryptPrefix+tc.volumeID, res)
}
})
}
}
func TestGetDevicePath(t *testing.T) {
volumeID := "pvc-123"
someErr := errors.New("error")
testCases := map[string]struct {
volumeID string
device *stubCryptDevice
wantErr bool
}{
"success": {
volumeID: volumeID,
device: &stubCryptDevice{deviceName: volumeID},
},
"InitByName fails": {
volumeID: volumeID,
device: &stubCryptDevice{initByNameErr: someErr},
wantErr: true,
},
"GetDeviceName returns nothing": {
volumeID: volumeID,
device: &stubCryptDevice{},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
mapper := &CryptMapper{
mapper: tc.device,
}
res, err := mapper.GetDevicePath(tc.volumeID)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(tc.device.deviceName, res)
}
})
}
}
func TestIsIntegrityFS(t *testing.T) {
testCases := map[string]struct {
wantIntegrity bool
fstype string
}{
"plain ext4": {
wantIntegrity: false,
fstype: "ext4",
},
"integrity ext4": {
wantIntegrity: true,
fstype: "ext4",
},
"integrity fs": {
wantIntegrity: false,
fstype: "integrity",
},
"double integrity": {
wantIntegrity: true,
fstype: "ext4-integrity",
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
request := tc.fstype
if tc.wantIntegrity {
request = tc.fstype + integrityFSSuffix
}
fstype, isIntegrity := IsIntegrityFS(request)
if tc.wantIntegrity {
assert.True(isIntegrity)
assert.Equal(tc.fstype, fstype)
} else {
assert.False(isIntegrity)
assert.Equal(tc.fstype, fstype)
}
})
}
}
type fakeKMS struct{}
func (k *fakeKMS) GetDEK(ctx context.Context, dekID string, dekSize int) ([]byte, error) {
key := make([]byte, dekSize)
for i := range key {
key[i] = 0x41
}
return key, nil
}