constellation/csi/cryptmapper/cryptmapper_test.go

460 lines
11 KiB
Go
Raw Normal View History

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cryptmapper
import (
"bytes"
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
)
// TestMain runs the package's tests through goleak, which fails the run if
// goroutines are still alive afterwards. The timeout-handler goroutine that
// the Bazel rules_go test wrapper registers is excluded from that check.
func TestMain(m *testing.M) {
	ignoreBazelTimeoutHandler := goleak.IgnoreAnyFunction("github.com/bazelbuild/rules_go/go/tools/bzltestutil.RegisterTimeoutHandler.func1")
	goleak.VerifyTestMain(m, ignoreBazelTimeoutHandler)
}
// TestCloseCryptDevice checks that closeCryptDevice reports an error when the
// stub's InitByName or Deactivate fails and succeeds otherwise. It then calls
// the exported CloseCryptDevice once against a stub whose methods all succeed.
func TestCloseCryptDevice(t *testing.T) {
	tests := map[string]struct {
		mapper  *stubCryptDevice
		wantErr bool
	}{
		"success": {
			mapper: &stubCryptDevice{},
		},
		"error on InitByName": {
			mapper:  &stubCryptDevice{initByNameErr: assert.AnError},
			wantErr: true,
		},
		"error on Deactivate": {
			mapper:  &stubCryptDevice{deactivateErr: assert.AnError},
			wantErr: true,
		},
	}

	for name, tt := range tests {
		t.Run(name, func(t *testing.T) {
			cm := &CryptMapper{
				kms:    &fakeKMS{},
				mapper: testMapper(tt.mapper),
			}

			err := cm.closeCryptDevice("/dev/mapper/volume01", "volume01-unit-test", "crypt")
			if !tt.wantErr {
				assert.NoError(t, err)
				return
			}
			assert.Error(t, err)
		})
	}

	// Exported entry point, exercised with a zero-valued (always-succeeding) stub.
	cm := &CryptMapper{
		mapper:        testMapper(&stubCryptDevice{}),
		kms:           &fakeKMS{},
		getDiskFormat: getDiskFormat,
	}
	assert.NoError(t, cm.CloseCryptDevice("volume01-unit-test"))
}
func TestOpenCryptDevice(t *testing.T) {
testCases := map[string]struct {
source string
volumeID string
integrity bool
mapper *stubCryptDevice
kms keyCreator
diskInfo func(disk string) (string, error)
wantErr bool
}{
"success with Load": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{},
kms: &fakeKMS{},
diskInfo: func(_ string) (string, error) { return "", nil },
wantErr: false,
},
"success with error on Load": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{loadErr: assert.AnError},
kms: &fakeKMS{},
diskInfo: func(_ string) (string, error) { return "", nil },
wantErr: false,
},
"success with integrity": {
source: "/dev/some-device",
volumeID: "volume0",
integrity: true,
mapper: &stubCryptDevice{loadErr: assert.AnError},
kms: &fakeKMS{},
diskInfo: func(_ string) (string, error) { return "", nil },
wantErr: false,
},
"error on Init": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{initErr: assert.AnError},
kms: &fakeKMS{},
diskInfo: func(_ string) (string, error) { return "", nil },
wantErr: true,
},
"error on Format": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{loadErr: assert.AnError, formatErr: assert.AnError},
kms: &fakeKMS{},
diskInfo: func(_ string) (string, error) { return "", nil },
wantErr: true,
},
"error on Activate": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{activatePassErr: assert.AnError},
kms: &fakeKMS{},
diskInfo: func(_ string) (string, error) { return "", nil },
wantErr: true,
},
"error on diskInfo": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{loadErr: assert.AnError},
kms: &fakeKMS{},
diskInfo: func(_ string) (string, error) { return "", assert.AnError },
wantErr: true,
},
"disk is already formatted": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{loadErr: assert.AnError},
kms: &fakeKMS{},
diskInfo: func(_ string) (string, error) { return "ext4", nil },
wantErr: true,
},
"error with integrity on wipe": {
source: "/dev/some-device",
volumeID: "volume0",
integrity: true,
mapper: &stubCryptDevice{loadErr: assert.AnError, wipeErr: assert.AnError},
kms: &fakeKMS{},
diskInfo: func(_ string) (string, error) { return "", nil },
wantErr: true,
},
"error on adding keyslot": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{loadErr: assert.AnError, keySlotAddErr: assert.AnError},
kms: &fakeKMS{},
diskInfo: func(_ string) (string, error) { return "", nil },
wantErr: true,
},
"incorrect key length": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{},
kms: &fakeKMS{presetKey: []byte{0x1, 0x2, 0x3}},
diskInfo: func(_ string) (string, error) { return "", nil },
wantErr: true,
},
"incorrect key length with error on Load": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{loadErr: assert.AnError},
kms: &fakeKMS{presetKey: []byte{0x1, 0x2, 0x3}},
diskInfo: func(_ string) (string, error) { return "", nil },
wantErr: true,
},
"getKey fails": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{},
kms: &fakeKMS{getDEKErr: assert.AnError},
diskInfo: func(_ string) (string, error) { return "", nil },
wantErr: true,
},
"getKey fails with error on Load": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{loadErr: assert.AnError},
kms: &fakeKMS{getDEKErr: assert.AnError},
diskInfo: func(_ string) (string, error) { return "", nil },
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
mapper := &CryptMapper{
mapper: testMapper(tc.mapper),
kms: tc.kms,
getDiskFormat: tc.diskInfo,
}
out, err := mapper.OpenCryptDevice(context.Background(), tc.source, tc.volumeID, tc.integrity)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(cryptPrefix+tc.volumeID, out)
if tc.mapper.loadErr == nil {
assert.False(tc.mapper.keySlotAddCalled)
} else {
assert.True(tc.mapper.keySlotAddCalled)
}
}
})
}
mapper := &CryptMapper{
mapper: testMapper(&stubCryptDevice{}),
kms: &fakeKMS{},
getDiskFormat: getDiskFormat,
}
_, err := mapper.OpenCryptDevice(context.Background(), "/dev/some-device", "volume01", false)
assert.NoError(t, err)
}
func TestResizeCryptDevice(t *testing.T) {
volumeID := "pvc-123"
someErr := errors.New("error")
testCases := map[string]struct {
volumeID string
device *stubCryptDevice
wantErr bool
}{
"success": {
volumeID: volumeID,
device: &stubCryptDevice{},
},
"InitByName fails": {
volumeID: volumeID,
device: &stubCryptDevice{initByNameErr: someErr},
wantErr: true,
},
"Load fails": {
volumeID: volumeID,
device: &stubCryptDevice{loadErr: someErr},
wantErr: true,
},
"Resize fails": {
volumeID: volumeID,
device: &stubCryptDevice{resizeErr: someErr},
wantErr: true,
},
"ActivateByPassphrase fails": {
volumeID: volumeID,
device: &stubCryptDevice{activatePassErr: someErr},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
mapper := &CryptMapper{
kms: &fakeKMS{},
mapper: testMapper(tc.device),
}
res, err := mapper.ResizeCryptDevice(context.Background(), tc.volumeID)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(cryptPrefix+tc.volumeID, res)
}
})
}
}
func TestGetDevicePath(t *testing.T) {
volumeID := "pvc-123"
someErr := errors.New("error")
testCases := map[string]struct {
volumeID string
device *stubCryptDevice
wantErr bool
}{
"success": {
volumeID: volumeID,
device: &stubCryptDevice{deviceName: volumeID},
},
"InitByName fails": {
volumeID: volumeID,
device: &stubCryptDevice{initByNameErr: someErr},
wantErr: true,
},
"GetDeviceName returns nothing": {
volumeID: volumeID,
device: &stubCryptDevice{},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
mapper := &CryptMapper{
mapper: testMapper(tc.device),
}
res, err := mapper.GetDevicePath(tc.volumeID)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(tc.device.deviceName, res)
}
})
}
}
// TestIsIntegrityFS checks two things. First, IsIntegrityFS recognizes
// integrityFSSuffix and strips exactly one copy of it. Second, it leaves other
// fstype strings unchanged and reports them as non-integrity, including a bare
// "integrity".
func TestIsIntegrityFS(t *testing.T) {
	testCases := map[string]struct {
		wantIntegrity bool
		fstype        string
	}{
		"plain ext4":       {wantIntegrity: false, fstype: "ext4"},
		"integrity ext4":   {wantIntegrity: true, fstype: "ext4"},
		"integrity fs":     {wantIntegrity: false, fstype: "integrity"},
		"double integrity": {wantIntegrity: true, fstype: "ext4-integrity"},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			// Integrity cases append the suffix so the call should strip it again.
			request := tc.fstype
			if tc.wantIntegrity {
				request += integrityFSSuffix
			}

			gotFSType, gotIntegrity := IsIntegrityFS(request)
			assert.Equal(t, tc.wantIntegrity, gotIntegrity)
			assert.Equal(t, tc.fstype, gotFSType)
		})
	}
}
// fakeKMS is an in-memory keyCreator for tests. It never contacts a real KMS.
type fakeKMS struct {
	presetKey []byte // when non-nil, returned as-is (not copied) by GetDEK
	getDEKErr error  // when non-nil, GetDEK fails with it; checked before presetKey
}

// GetDEK ignores the context and key ID. It resolves in priority order:
// getDEKErr, then presetKey, then a dekSize-byte key filled with 0xAA.
func (k *fakeKMS) GetDEK(_ context.Context, _ string, dekSize int) ([]byte, error) {
	switch {
	case k.getDEKErr != nil:
		return nil, k.getDEKErr
	case k.presetKey != nil:
		return k.presetKey, nil
	}
	return bytes.Repeat([]byte{0xAA}, dekSize), nil
}
// stubCryptDevice is a scriptable test double handed out by testMapper. Every
// error-returning method returns the matching error field, so a test case can
// fail any single step. Zero-valued fields make the stub succeed.
type stubCryptDevice struct {
	// Values reported by GetDeviceName and GetUUID.
	deviceName string
	uuid       string
	uuidErr    error

	// Errors returned by Init, InitByName, ActivateByVolumeKey,
	// ActivateByPassphrase, Deactivate, Format and LoadLUKS2 respectively.
	initErr         error
	initByNameErr   error
	activateErr     error
	activatePassErr error
	deactivateErr   error
	formatErr       error
	loadErr         error

	// keySlotAddCalled is set once KeyslotAddByVolumeKey has been invoked.
	keySlotAddCalled bool
	keySlotAddErr    error
	wipeErr          error
	resizeErr        error
}

// Init ignores the device path and returns a no-op func together with initErr.
func (s *stubCryptDevice) Init(_ string) (func(), error) {
	noop := func() {}
	return noop, s.initErr
}

// InitByName ignores the name and returns a no-op func together with initByNameErr.
func (s *stubCryptDevice) InitByName(_ string) (func(), error) {
	noop := func() {}
	return noop, s.initByNameErr
}

// ActivateByVolumeKey ignores its arguments and returns activateErr.
func (s *stubCryptDevice) ActivateByVolumeKey(_, _ string, _, _ int) error {
	return s.activateErr
}

// ActivateByPassphrase ignores its arguments and returns activatePassErr.
func (s *stubCryptDevice) ActivateByPassphrase(_ string, _ int, _ string, _ int) error {
	return s.activatePassErr
}

// Deactivate ignores the name and returns deactivateErr.
func (s *stubCryptDevice) Deactivate(_ string) error {
	return s.deactivateErr
}

// Format ignores its flag and returns formatErr.
func (s *stubCryptDevice) Format(_ bool) error {
	return s.formatErr
}

// Free does nothing.
func (s *stubCryptDevice) Free() {}

// GetDeviceName returns the configured deviceName.
func (s *stubCryptDevice) GetDeviceName() string {
	return s.deviceName
}

// GetUUID returns the configured uuid and uuidErr.
func (s *stubCryptDevice) GetUUID() (string, error) {
	return s.uuid, s.uuidErr
}

// LoadLUKS2 returns loadErr.
func (s *stubCryptDevice) LoadLUKS2() error {
	return s.loadErr
}

// KeyslotAddByVolumeKey records the call in keySlotAddCalled and returns keySlotAddErr.
func (s *stubCryptDevice) KeyslotAddByVolumeKey(_ int, _ string, _ string) error {
	s.keySlotAddCalled = true
	return s.keySlotAddErr
}

// Wipe ignores its arguments (the progress callback is never invoked) and returns wipeErr.
func (s *stubCryptDevice) Wipe(_ string, _ int, _ int, _ func(size, offset uint64), _ time.Duration) error {
	return s.wipeErr
}

// Resize ignores its arguments and returns resizeErr.
func (s *stubCryptDevice) Resize(_ string, _ uint64) error {
	return s.resizeErr
}
// testMapper returns a factory that always yields the same stub instance. The
// stub is therefore shared with the caller, which can inspect fields the code
// under test mutated (such as keySlotAddCalled) afterwards.
func testMapper(stub *stubCryptDevice) func() deviceMapper {
	factory := func() deviceMapper { return stub }
	return factory
}