AB#2103 Derive key from LUKS UUID instead of disk name (#156)

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2022-05-19 08:47:17 +02:00 committed by GitHub
parent daf356d88e
commit 0a24de24ee
3 changed files with 216 additions and 116 deletions

View File

@ -68,6 +68,8 @@ type DeviceMapper interface {
Free() bool Free() bool
// GetDeviceName gets the path to the underlying device. // GetDeviceName gets the path to the underlying device.
GetDeviceName() string GetDeviceName() string
// GetUUID gets the devices UUID
GetUUID() string
// Load loads crypt device parameters from the on-disk header. // Load loads crypt device parameters from the on-disk header.
// Returns nil on success, or an error otherwise. // Returns nil on success, or an error otherwise.
Load(cryptsetup.DeviceType) error Load(cryptsetup.DeviceType) error
@ -171,18 +173,8 @@ func (c *CryptMapper) CloseCryptDevice(volumeID string) error {
// OpenCryptDevice maps the volume at source to the crypt device identified by volumeID. // OpenCryptDevice maps the volume at source to the crypt device identified by volumeID.
// The key used to encrypt the volume is fetched using CryptMapper's kms client. // The key used to encrypt the volume is fetched using CryptMapper's kms client.
func (c *CryptMapper) OpenCryptDevice(ctx context.Context, source, volumeID string, integrity bool) (string, error) { func (c *CryptMapper) OpenCryptDevice(ctx context.Context, source, volumeID string, integrity bool) (string, error) {
klog.V(4).Infof("Fetching data encryption key for volume %q", volumeID)
passphrase, err := c.kms.GetDEK(ctx, volumeID, constants.StateDiskKeyLength)
if err != nil {
return "", err
}
if len(passphrase) != constants.StateDiskKeyLength {
return "", fmt.Errorf("expected key length to be [%d] but got [%d]", constants.StateDiskKeyLength, len(passphrase))
}
m := &mount.SafeFormatAndMount{Exec: utilexec.New()} m := &mount.SafeFormatAndMount{Exec: utilexec.New()}
return openCryptDevice(c.mapper, source, volumeID, string(passphrase), integrity, m.GetDiskFormat) return openCryptDevice(ctx, c.mapper, source, volumeID, integrity, c.kms.GetDEK, m.GetDiskFormat)
} }
// ResizeCryptDevice resizes the underlying crypt device and returns the mapped device path. // ResizeCryptDevice resizes the underlying crypt device and returns the mapped device path.
@ -228,7 +220,9 @@ func closeCryptDevice(device DeviceMapper, source, volumeID, deviceType string)
} }
// openCryptDevice maps the volume at source to the crypt device identified by volumeID. // openCryptDevice maps the volume at source to the crypt device identified by volumeID.
func openCryptDevice(device DeviceMapper, source, volumeID, passphrase string, integrity bool, diskInfo func(disk string) (string, error)) (string, error) { func openCryptDevice(ctx context.Context, device DeviceMapper, source, volumeID string, integrity bool,
getKey func(ctx context.Context, keyID string, keySize int) ([]byte, error), diskInfo func(disk string) (string, error),
) (string, error) {
packageLock.Lock() packageLock.Lock()
defer packageLock.Unlock() defer packageLock.Unlock()
@ -248,7 +242,7 @@ func openCryptDevice(device DeviceMapper, source, volumeID, passphrase string, i
} }
defer device.Free() defer device.Free()
needWipe := false var passphrase []byte
// Try to load LUKS headers // Try to load LUKS headers
// If this fails, the device is either not formatted at all, or already formatted with a different FS // If this fails, the device is either not formatted at all, or already formatted with a different FS
if err := device.Load(cryptsetup.LUKS2{}); err != nil { if err := device.Load(cryptsetup.LUKS2{}); err != nil {
@ -287,22 +281,41 @@ func openCryptDevice(device DeviceMapper, source, volumeID, passphrase string, i
return "", fmt.Errorf("formatting device %q failed: %w", source, err) return "", fmt.Errorf("formatting device %q failed: %w", source, err)
} }
uuid := device.GetUUID()
klog.V(4).Infof("Fetching data encryption key for volume %q", volumeID)
passphrase, err = getKey(ctx, uuid, constants.StateDiskKeyLength)
if err != nil {
return "", err
}
if len(passphrase) != constants.StateDiskKeyLength {
return "", fmt.Errorf("expected key length to be [%d] but got [%d]", constants.StateDiskKeyLength, len(passphrase))
}
// Add a new keyslot using the internal volume key // Add a new keyslot using the internal volume key
if err := device.KeyslotAddByVolumeKey(0, "", passphrase); err != nil { if err := device.KeyslotAddByVolumeKey(0, "", string(passphrase)); err != nil {
return "", fmt.Errorf("adding keyslot: %w", err) return "", fmt.Errorf("adding keyslot: %w", err)
} }
needWipe = true
}
if integrity && needWipe { if integrity {
if err := performWipe(device, volumeID); err != nil { if err := performWipe(device, volumeID); err != nil {
return "", fmt.Errorf("wiping device: %w", err) return "", fmt.Errorf("wiping device: %w", err)
}
}
} else {
uuid := device.GetUUID()
klog.V(4).Infof("Fetching data encryption key for volume %q", volumeID)
passphrase, err = getKey(ctx, uuid, constants.StateDiskKeyLength)
if err != nil {
return "", err
}
if len(passphrase) != constants.StateDiskKeyLength {
return "", fmt.Errorf("expected key length to be [%d] but got [%d]", constants.StateDiskKeyLength, len(passphrase))
} }
} }
klog.V(4).Infof("Activating LUKS2 device %q", cryptPrefix+volumeID) klog.V(4).Infof("Activating LUKS2 device %q", cryptPrefix+volumeID)
if err := device.ActivateByPassphrase(volumeID, 0, passphrase, 0); err != nil { if err := device.ActivateByPassphrase(volumeID, 0, string(passphrase), 0); err != nil {
klog.Errorf("Trying to activate dm-crypt volume: %s", err) klog.Errorf("Trying to activate dm-crypt volume: %s", err)
return "", fmt.Errorf("trying to activate dm-crypt volume: %w", err) return "", fmt.Errorf("trying to activate dm-crypt volume: %w", err)
} }

View File

@ -10,19 +10,9 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
var testDEK = []byte{
0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA,
0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA,
0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA,
0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA,
0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA,
0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA,
0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA,
0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA,
}
type stubCryptDevice struct { type stubCryptDevice struct {
deviceName string deviceName string
uuid string
initErr error initErr error
initByNameErr error initByNameErr error
activateErr error activateErr error
@ -68,6 +58,10 @@ func (c *stubCryptDevice) GetDeviceName() string {
return c.deviceName return c.deviceName
} }
func (c *stubCryptDevice) GetUUID() string {
return c.uuid
}
func (c *stubCryptDevice) Load(cryptsetup.DeviceType) error { func (c *stubCryptDevice) Load(cryptsetup.DeviceType) error {
return c.loadErr return c.loadErr
} }
@ -124,116 +118,156 @@ func TestCloseCryptDevice(t *testing.T) {
func TestOpenCryptDevice(t *testing.T) { func TestOpenCryptDevice(t *testing.T) {
someErr := errors.New("error") someErr := errors.New("error")
getKeyFunc := func(context.Context, string, int) ([]byte, error) {
return []byte{
0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA,
0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA,
0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA,
0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA,
}, nil
}
testCases := map[string]struct { testCases := map[string]struct {
source string source string
volumeID string volumeID string
passphrase string integrity bool
integrity bool mapper *stubCryptDevice
mapper *stubCryptDevice getKey func(context.Context, string, int) ([]byte, error)
diskInfo func(disk string) (string, error) diskInfo func(disk string) (string, error)
wantErr bool wantErr bool
}{ }{
"success with Load": { "success with Load": {
source: "/dev/some-device", source: "/dev/some-device",
volumeID: "volume0", volumeID: "volume0",
passphrase: string(testDEK), mapper: &stubCryptDevice{},
mapper: &stubCryptDevice{}, getKey: getKeyFunc,
diskInfo: func(disk string) (string, error) { return "", nil }, diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: false, wantErr: false,
}, },
"success with error on Load": { "success with error on Load": {
source: "/dev/some-device", source: "/dev/some-device",
volumeID: "volume0", volumeID: "volume0",
passphrase: string(testDEK), mapper: &stubCryptDevice{loadErr: someErr},
mapper: &stubCryptDevice{loadErr: someErr}, getKey: getKeyFunc,
diskInfo: func(disk string) (string, error) { return "", nil }, diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: false, wantErr: false,
}, },
"success with integrity": { "success with integrity": {
source: "/dev/some-device", source: "/dev/some-device",
volumeID: "volume0", volumeID: "volume0",
passphrase: string(append(testDEK, testDEK[:32]...)), integrity: true,
integrity: true, mapper: &stubCryptDevice{loadErr: someErr},
mapper: &stubCryptDevice{loadErr: someErr}, getKey: getKeyFunc,
diskInfo: func(disk string) (string, error) { return "", nil }, diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: false, wantErr: false,
}, },
"error on Init": { "error on Init": {
source: "/dev/some-device", source: "/dev/some-device",
volumeID: "volume0", volumeID: "volume0",
passphrase: string(testDEK), mapper: &stubCryptDevice{initErr: someErr},
mapper: &stubCryptDevice{initErr: someErr}, getKey: getKeyFunc,
diskInfo: func(disk string) (string, error) { return "", nil }, diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: true, wantErr: true,
}, },
"error on Format": { "error on Format": {
source: "/dev/some-device", source: "/dev/some-device",
volumeID: "volume0", volumeID: "volume0",
passphrase: string(testDEK), mapper: &stubCryptDevice{loadErr: someErr, formatErr: someErr},
mapper: &stubCryptDevice{loadErr: someErr, formatErr: someErr}, getKey: getKeyFunc,
diskInfo: func(disk string) (string, error) { return "", nil }, diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: true, wantErr: true,
}, },
"error on Activate": { "error on Activate": {
source: "/dev/some-device", source: "/dev/some-device",
volumeID: "volume0", volumeID: "volume0",
passphrase: string(testDEK), mapper: &stubCryptDevice{activatePassErr: someErr},
mapper: &stubCryptDevice{activatePassErr: someErr}, getKey: getKeyFunc,
diskInfo: func(disk string) (string, error) { return "", nil }, diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: true, wantErr: true,
}, },
"error on diskInfo": { "error on diskInfo": {
source: "/dev/some-device", source: "/dev/some-device",
volumeID: "volume0", volumeID: "volume0",
passphrase: string(testDEK), mapper: &stubCryptDevice{loadErr: someErr},
mapper: &stubCryptDevice{loadErr: someErr}, getKey: getKeyFunc,
diskInfo: func(disk string) (string, error) { return "", someErr }, diskInfo: func(disk string) (string, error) { return "", someErr },
wantErr: true, wantErr: true,
}, },
"disk is already formatted": { "disk is already formatted": {
source: "/dev/some-device", source: "/dev/some-device",
volumeID: "volume0", volumeID: "volume0",
passphrase: string(testDEK), mapper: &stubCryptDevice{loadErr: someErr},
mapper: &stubCryptDevice{loadErr: someErr}, getKey: getKeyFunc,
diskInfo: func(disk string) (string, error) { return "ext4", nil }, diskInfo: func(disk string) (string, error) { return "ext4", nil },
wantErr: true, wantErr: true,
}, },
"error with integrity on wipe": { "error with integrity on wipe": {
source: "/dev/some-device", source: "/dev/some-device",
volumeID: "volume0", volumeID: "volume0",
passphrase: string(append(testDEK, testDEK[:32]...)), integrity: true,
integrity: true, mapper: &stubCryptDevice{loadErr: someErr, wipeErr: someErr},
mapper: &stubCryptDevice{loadErr: someErr, wipeErr: someErr}, getKey: getKeyFunc,
diskInfo: func(disk string) (string, error) { return "", nil }, diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: true, wantErr: true,
}, },
"error with integrity on activate": { "error with integrity on activate": {
source: "/dev/some-device", source: "/dev/some-device",
volumeID: "volume0", volumeID: "volume0",
passphrase: string(append(testDEK, testDEK[:32]...)), integrity: true,
integrity: true, mapper: &stubCryptDevice{loadErr: someErr, activateErr: someErr},
mapper: &stubCryptDevice{loadErr: someErr, activateErr: someErr}, getKey: getKeyFunc,
diskInfo: func(disk string) (string, error) { return "", nil }, diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: true, wantErr: true,
}, },
"error with integrity on deactivate": { "error with integrity on deactivate": {
source: "/dev/some-device", source: "/dev/some-device",
volumeID: "volume0", volumeID: "volume0",
passphrase: string(append(testDEK, testDEK[:32]...)), integrity: true,
integrity: true, mapper: &stubCryptDevice{loadErr: someErr, deactivateErr: someErr},
mapper: &stubCryptDevice{loadErr: someErr, deactivateErr: someErr}, getKey: getKeyFunc,
diskInfo: func(disk string) (string, error) { return "", nil }, diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: true, wantErr: true,
}, },
"error on adding keyslot": { "error on adding keyslot": {
source: "/dev/some-device", source: "/dev/some-device",
volumeID: "volume0", volumeID: "volume0",
passphrase: string(testDEK),
mapper: &stubCryptDevice{ mapper: &stubCryptDevice{
loadErr: someErr, loadErr: someErr,
keySlotAddErr: someErr, keySlotAddErr: someErr,
}, },
getKey: getKeyFunc,
diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: true,
},
"incorrect key length": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{},
getKey: func(ctx context.Context, s string, i int) ([]byte, error) { return []byte{0x1, 0x2, 0x3}, nil },
diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: true,
},
"incorrect key length with error on Load": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{loadErr: someErr},
getKey: func(ctx context.Context, s string, i int) ([]byte, error) { return []byte{0x1, 0x2, 0x3}, nil },
diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: true,
},
"getKey fails": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{},
getKey: func(ctx context.Context, s string, i int) ([]byte, error) { return nil, someErr },
diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: true,
},
"getKey fails with error on Load": {
source: "/dev/some-device",
volumeID: "volume0",
mapper: &stubCryptDevice{loadErr: someErr},
getKey: func(ctx context.Context, s string, i int) ([]byte, error) { return nil, someErr },
diskInfo: func(disk string) (string, error) { return "", nil }, diskInfo: func(disk string) (string, error) { return "", nil },
wantErr: true, wantErr: true,
}, },
@ -243,7 +277,15 @@ func TestOpenCryptDevice(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
out, err := openCryptDevice(tc.mapper, tc.source, tc.volumeID, tc.passphrase, tc.integrity, tc.diskInfo) out, err := openCryptDevice(
context.Background(),
tc.mapper,
tc.source,
tc.volumeID,
tc.integrity,
tc.getKey,
tc.diskInfo,
)
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
} else { } else {

View File

@ -25,8 +25,12 @@ func setup() {
exec.Command("/bin/dd", "if=/dev/zero", fmt.Sprintf("of=%s", DevicePath), "bs=64M", "count=1").Run() exec.Command("/bin/dd", "if=/dev/zero", fmt.Sprintf("of=%s", DevicePath), "bs=64M", "count=1").Run()
} }
func teardown() { func teardown(devicePath string) {
exec.Command("/bin/rm", "-f", DevicePath).Run() exec.Command("/bin/rm", "-f", devicePath).Run()
}
func copy(source, target string) error {
return exec.Command("cp", source, target).Run()
} }
func resize() { func resize() {
@ -50,7 +54,7 @@ func TestOpenAndClose(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
require := require.New(t) require := require.New(t)
setup() setup()
defer teardown() defer teardown(DevicePath)
kms := kms.NewStaticKMS() kms := kms.NewStaticKMS()
mapper := cryptmapper.New(kms, &cryptmapper.CryptDevice{}) mapper := cryptmapper.New(kms, &cryptmapper.CryptDevice{})
@ -78,13 +82,18 @@ func TestOpenAndClose(t *testing.T) {
// assert crypt device got removed // assert crypt device got removed
_, err = os.Stat(newPath) _, err = os.Stat(newPath)
assert.True(os.IsNotExist(err)) assert.True(os.IsNotExist(err))
// check if we can reopen the device
_, err = mapper.OpenCryptDevice(context.Background(), DevicePath, DeviceName, true)
assert.NoError(err)
assert.NoError(mapper.CloseCryptDevice(DeviceName))
} }
func TestOpenAndCloseIntegrity(t *testing.T) { func TestOpenAndCloseIntegrity(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
require := require.New(t) require := require.New(t)
setup() setup()
defer teardown() defer teardown(DevicePath)
kms := kms.NewStaticKMS() kms := kms.NewStaticKMS()
mapper := cryptmapper.New(kms, &cryptmapper.CryptDevice{}) mapper := cryptmapper.New(kms, &cryptmapper.CryptDevice{})
@ -113,4 +122,40 @@ func TestOpenAndCloseIntegrity(t *testing.T) {
// assert integrity device got removed // assert integrity device got removed
_, err = os.Stat(newPath + "_dif") _, err = os.Stat(newPath + "_dif")
assert.True(os.IsNotExist(err)) assert.True(os.IsNotExist(err))
// check if we can reopen the device
_, err = mapper.OpenCryptDevice(context.Background(), DevicePath, DeviceName, true)
assert.NoError(err)
assert.NoError(mapper.CloseCryptDevice(DeviceName))
}
func TestDeviceCloning(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
setup()
defer teardown(DevicePath)
mapper := cryptmapper.New(&dynamicKMS{}, &cryptmapper.CryptDevice{})
_, err := mapper.OpenCryptDevice(context.Background(), DevicePath, DeviceName, false)
assert.NoError(err)
require.NoError(copy(DevicePath, DevicePath+"-copy"))
defer teardown(DevicePath + "-copy")
_, err = mapper.OpenCryptDevice(context.Background(), DevicePath+"-copy", DeviceName+"-copy", false)
assert.NoError(err)
assert.NoError(mapper.CloseCryptDevice(DeviceName))
assert.NoError(mapper.CloseCryptDevice(DeviceName + "-copy"))
}
type dynamicKMS struct{}
func (k *dynamicKMS) GetDEK(ctx context.Context, dekID string, dekSize int) ([]byte, error) {
key := make([]byte, dekSize)
for i := range key {
key[i] = 0x41 ^ dekID[i%len(dekID)]
}
return key, nil
} }