csi: fix concurrent use of cryptmapper package (#2408)

* Dont error on opening already active devices

* Fix concurrency issues when working with more than one device

---------

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2023-10-05 11:20:22 +02:00 committed by GitHub
parent 6ba43b03ee
commit f69ae26122
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 105 additions and 36 deletions

View File

@ -12,6 +12,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io/fs" "io/fs"
"os"
"path/filepath" "path/filepath"
"strings" "strings"
"time" "time"
@ -33,7 +34,7 @@ const (
// CryptMapper manages dm-crypt volumes. // CryptMapper manages dm-crypt volumes.
type CryptMapper struct { type CryptMapper struct {
mapper deviceMapper mapper func() deviceMapper
kms keyCreator kms keyCreator
getDiskFormat func(disk string) (string, error) getDiskFormat func(disk string) (string, error)
} }
@ -42,7 +43,7 @@ type CryptMapper struct {
// kms is used to fetch data encryption keys for the dm-crypt volumes. // kms is used to fetch data encryption keys for the dm-crypt volumes.
func New(kms keyCreator) *CryptMapper { func New(kms keyCreator) *CryptMapper {
return &CryptMapper{ return &CryptMapper{
mapper: cryptsetup.New(), mapper: func() deviceMapper { return cryptsetup.New() },
kms: kms, kms: kms,
getDiskFormat: getDiskFormat, getDiskFormat: getDiskFormat,
} }
@ -87,22 +88,35 @@ func (c *CryptMapper) CloseCryptDevice(volumeID string) error {
// The key used to encrypt the volume is fetched using CryptMapper's kms client. // The key used to encrypt the volume is fetched using CryptMapper's kms client.
func (c *CryptMapper) OpenCryptDevice(ctx context.Context, source, volumeID string, integrity bool) (string, error) { func (c *CryptMapper) OpenCryptDevice(ctx context.Context, source, volumeID string, integrity bool) (string, error) {
// Initialize the block device // Initialize the block device
free, err := c.mapper.Init(source) mapper := c.mapper()
free, err := mapper.Init(source)
if err != nil { if err != nil {
return "", fmt.Errorf("initializing dm-crypt to map device %q: %w", source, err) return "", fmt.Errorf("initializing dm-crypt to map device %q: %w", source, err)
} }
defer free() defer free()
deviceName := filepath.Join(cryptPrefix, volumeID)
var passphrase []byte var passphrase []byte
// Try to load LUKS headers // Try to load LUKS headers
// If this fails, the device is either not formatted at all, or already formatted with a different FS // If this fails, the device is either not formatted at all, or already formatted with a different FS
if err := c.mapper.LoadLUKS2(); err != nil { if err := mapper.LoadLUKS2(); err != nil {
passphrase, err = c.formatNewDevice(ctx, volumeID, source, integrity) passphrase, err = c.formatNewDevice(ctx, mapper, volumeID, source, integrity)
if err != nil { if err != nil {
return "", fmt.Errorf("formatting device: %w", err) return "", fmt.Errorf("formatting device: %w", err)
} }
} else { } else {
uuid, err := c.mapper.GetUUID() // Check if device is already active
// If yes, this is a no-op
// Simply return the device name
if _, err := os.Stat(deviceName); err == nil {
_, err := os.Stat(deviceName + integritySuffix)
if integrity && err != nil {
return "", fmt.Errorf("device %s already exists, but integrity device %s is missing", deviceName, deviceName+integritySuffix)
}
return deviceName, nil
}
uuid, err := mapper.GetUUID()
if err != nil { if err != nil {
return "", err return "", err
} }
@ -115,26 +129,27 @@ func (c *CryptMapper) OpenCryptDevice(ctx context.Context, source, volumeID stri
} }
} }
if err := c.mapper.ActivateByPassphrase(volumeID, 0, string(passphrase), cryptsetup.ReadWriteQueueBypass); err != nil { if err := mapper.ActivateByPassphrase(volumeID, 0, string(passphrase), cryptsetup.ReadWriteQueueBypass); err != nil {
return "", fmt.Errorf("trying to activate dm-crypt volume: %w", err) return "", fmt.Errorf("trying to activate dm-crypt volume: %w", err)
} }
return cryptPrefix + volumeID, nil return deviceName, nil
} }
// ResizeCryptDevice resizes the underlying crypt device and returns the mapped device path. // ResizeCryptDevice resizes the underlying crypt device and returns the mapped device path.
func (c *CryptMapper) ResizeCryptDevice(ctx context.Context, volumeID string) (string, error) { func (c *CryptMapper) ResizeCryptDevice(ctx context.Context, volumeID string) (string, error) {
free, err := c.mapper.InitByName(volumeID) mapper := c.mapper()
free, err := mapper.InitByName(volumeID)
if err != nil { if err != nil {
return "", fmt.Errorf("initializing device: %w", err) return "", fmt.Errorf("initializing device: %w", err)
} }
defer free() defer free()
if err := c.mapper.LoadLUKS2(); err != nil { if err := mapper.LoadLUKS2(); err != nil {
return "", fmt.Errorf("loading device: %w", err) return "", fmt.Errorf("loading device: %w", err)
} }
uuid, err := c.mapper.GetUUID() uuid, err := mapper.GetUUID()
if err != nil { if err != nil {
return "", err return "", err
} }
@ -143,11 +158,11 @@ func (c *CryptMapper) ResizeCryptDevice(ctx context.Context, volumeID string) (s
return "", fmt.Errorf("getting key: %w", err) return "", fmt.Errorf("getting key: %w", err)
} }
if err := c.mapper.ActivateByPassphrase("", 0, string(passphrase), resizeFlags); err != nil { if err := mapper.ActivateByPassphrase("", 0, string(passphrase), resizeFlags); err != nil {
return "", fmt.Errorf("activating keyring for crypt device %q with passphrase: %w", volumeID, err) return "", fmt.Errorf("activating keyring for crypt device %q with passphrase: %w", volumeID, err)
} }
if err := c.mapper.Resize(volumeID, 0); err != nil { if err := mapper.Resize(volumeID, 0); err != nil {
return "", fmt.Errorf("resizing device: %w", err) return "", fmt.Errorf("resizing device: %w", err)
} }
@ -156,14 +171,15 @@ func (c *CryptMapper) ResizeCryptDevice(ctx context.Context, volumeID string) (s
// GetDevicePath returns the device path of a mapped crypt device. // GetDevicePath returns the device path of a mapped crypt device.
func (c *CryptMapper) GetDevicePath(volumeID string) (string, error) { func (c *CryptMapper) GetDevicePath(volumeID string) (string, error) {
mapper := c.mapper()
name := strings.TrimPrefix(volumeID, cryptPrefix) name := strings.TrimPrefix(volumeID, cryptPrefix)
free, err := c.mapper.InitByName(name) free, err := mapper.InitByName(name)
if err != nil { if err != nil {
return "", fmt.Errorf("initializing device: %w", err) return "", fmt.Errorf("initializing device: %w", err)
} }
defer free() defer free()
deviceName := c.mapper.GetDeviceName() deviceName := mapper.GetDeviceName()
if deviceName == "" { if deviceName == "" {
return "", errors.New("unable to determine device name") return "", errors.New("unable to determine device name")
} }
@ -172,20 +188,21 @@ func (c *CryptMapper) GetDevicePath(volumeID string) (string, error) {
// closeCryptDevice closes the crypt device mapped for volumeID. // closeCryptDevice closes the crypt device mapped for volumeID.
func (c *CryptMapper) closeCryptDevice(source, volumeID, deviceType string) error { func (c *CryptMapper) closeCryptDevice(source, volumeID, deviceType string) error {
free, err := c.mapper.InitByName(volumeID) mapper := c.mapper()
free, err := mapper.InitByName(volumeID)
if err != nil { if err != nil {
return fmt.Errorf("initializing dm-%s to unmap device %q: %w", deviceType, source, err) return fmt.Errorf("initializing dm-%s to unmap device %q: %w", deviceType, source, err)
} }
defer free() defer free()
if err := c.mapper.Deactivate(volumeID); err != nil { if err := mapper.Deactivate(volumeID); err != nil {
return fmt.Errorf("deactivating dm-%s volume %q for device %q: %w", deviceType, cryptPrefix+volumeID, source, err) return fmt.Errorf("deactivating dm-%s volume %q for device %q: %w", deviceType, cryptPrefix+volumeID, source, err)
} }
return nil return nil
} }
func (c *CryptMapper) formatNewDevice(ctx context.Context, volumeID, source string, integrity bool) ([]byte, error) { func (c *CryptMapper) formatNewDevice(ctx context.Context, mapper deviceMapper, volumeID, source string, integrity bool) ([]byte, error) {
format, err := c.getDiskFormat(source) format, err := c.getDiskFormat(source)
if err != nil { if err != nil {
return nil, fmt.Errorf("determining if disk is formatted: %w", err) return nil, fmt.Errorf("determining if disk is formatted: %w", err)
@ -195,11 +212,11 @@ func (c *CryptMapper) formatNewDevice(ctx context.Context, volumeID, source stri
} }
// Device is not formatted, so we can safely create a new LUKS2 partition // Device is not formatted, so we can safely create a new LUKS2 partition
if err := c.mapper.Format(integrity); err != nil { if err := mapper.Format(integrity); err != nil {
return nil, fmt.Errorf("formatting device %q: %w", source, err) return nil, fmt.Errorf("formatting device %q: %w", source, err)
} }
uuid, err := c.mapper.GetUUID() uuid, err := mapper.GetUUID()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -212,7 +229,7 @@ func (c *CryptMapper) formatNewDevice(ctx context.Context, volumeID, source stri
} }
// Add a new keyslot using the internal volume key // Add a new keyslot using the internal volume key
if err := c.mapper.KeyslotAddByVolumeKey(0, "", string(passphrase)); err != nil { if err := mapper.KeyslotAddByVolumeKey(0, "", string(passphrase)); err != nil {
return nil, fmt.Errorf("adding keyslot: %w", err) return nil, fmt.Errorf("adding keyslot: %w", err)
} }
@ -222,7 +239,7 @@ func (c *CryptMapper) formatNewDevice(ctx context.Context, volumeID, source stri
fmt.Printf("Wipe in progress: %.2f%%\n", prog) fmt.Printf("Wipe in progress: %.2f%%\n", prog)
} }
if err := c.mapper.Wipe(volumeID, 1024*1024, 0, logProgress, 30*time.Second); err != nil { if err := mapper.Wipe(volumeID, 1024*1024, 0, logProgress, 30*time.Second); err != nil {
return nil, fmt.Errorf("wiping device: %w", err) return nil, fmt.Errorf("wiping device: %w", err)
} }
} }

View File

@ -46,7 +46,7 @@ func TestCloseCryptDevice(t *testing.T) {
mapper := &CryptMapper{ mapper := &CryptMapper{
kms: &fakeKMS{}, kms: &fakeKMS{},
mapper: tc.mapper, mapper: testMapper(tc.mapper),
} }
err := mapper.closeCryptDevice("/dev/mapper/volume01", "volume01-unit-test", "crypt") err := mapper.closeCryptDevice("/dev/mapper/volume01", "volume01-unit-test", "crypt")
if tc.wantErr { if tc.wantErr {
@ -58,7 +58,7 @@ func TestCloseCryptDevice(t *testing.T) {
} }
mapper := &CryptMapper{ mapper := &CryptMapper{
mapper: &stubCryptDevice{}, mapper: testMapper(&stubCryptDevice{}),
kms: &fakeKMS{}, kms: &fakeKMS{},
getDiskFormat: getDiskFormat, getDiskFormat: getDiskFormat,
} }
@ -197,7 +197,7 @@ func TestOpenCryptDevice(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
mapper := &CryptMapper{ mapper := &CryptMapper{
mapper: tc.mapper, mapper: testMapper(tc.mapper),
kms: tc.kms, kms: tc.kms,
getDiskFormat: tc.diskInfo, getDiskFormat: tc.diskInfo,
} }
@ -219,7 +219,7 @@ func TestOpenCryptDevice(t *testing.T) {
} }
mapper := &CryptMapper{ mapper := &CryptMapper{
mapper: &stubCryptDevice{}, mapper: testMapper(&stubCryptDevice{}),
kms: &fakeKMS{}, kms: &fakeKMS{},
getDiskFormat: getDiskFormat, getDiskFormat: getDiskFormat,
} }
@ -267,7 +267,7 @@ func TestResizeCryptDevice(t *testing.T) {
mapper := &CryptMapper{ mapper := &CryptMapper{
kms: &fakeKMS{}, kms: &fakeKMS{},
mapper: tc.device, mapper: testMapper(tc.device),
} }
res, err := mapper.ResizeCryptDevice(context.Background(), tc.volumeID) res, err := mapper.ResizeCryptDevice(context.Background(), tc.volumeID)
@ -310,7 +310,7 @@ func TestGetDevicePath(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
mapper := &CryptMapper{ mapper := &CryptMapper{
mapper: tc.device, mapper: testMapper(tc.device),
} }
res, err := mapper.GetDevicePath(tc.volumeID) res, err := mapper.GetDevicePath(tc.volumeID)
@ -451,3 +451,9 @@ func (c *stubCryptDevice) Wipe(_ string, _ int, _ int, _ func(size, offset uint6
func (c *stubCryptDevice) Resize(_ string, _ uint64) error { func (c *stubCryptDevice) Resize(_ string, _ uint64) error {
return c.resizeErr return c.resizeErr
} }
func testMapper(stub *stubCryptDevice) func() deviceMapper {
return func() deviceMapper {
return stub
}
}

View File

@ -13,6 +13,7 @@ import (
"fmt" "fmt"
"os" "os"
"os/exec" "os/exec"
"sync"
"testing" "testing"
"github.com/edgelesssys/constellation/v2/csi/cryptmapper" "github.com/edgelesssys/constellation/v2/csi/cryptmapper"
@ -23,10 +24,10 @@ import (
const ( const (
devicePath string = "testDevice" devicePath string = "testDevice"
deviceName string = "testdeviceName" deviceName string = "testDeviceName"
) )
func setup() { func setup(devicePath string) {
if err := exec.Command("/bin/dd", "if=/dev/zero", fmt.Sprintf("of=%s", devicePath), "bs=64M", "count=1").Run(); err != nil { if err := exec.Command("/bin/dd", "if=/dev/zero", fmt.Sprintf("of=%s", devicePath), "bs=64M", "count=1").Run(); err != nil {
panic(err) panic(err)
} }
@ -42,7 +43,7 @@ func cp(source, target string) error {
return exec.Command("cp", source, target).Run() return exec.Command("cp", source, target).Run()
} }
func resize() { func resize(devicePath string) {
if err := exec.Command("/bin/dd", "if=/dev/zero", fmt.Sprintf("of=%s", devicePath), "bs=32M", "count=1", "oflag=append", "conv=notrunc").Run(); err != nil { if err := exec.Command("/bin/dd", "if=/dev/zero", fmt.Sprintf("of=%s", devicePath), "bs=32M", "count=1", "oflag=append", "conv=notrunc").Run(); err != nil {
panic(err) panic(err)
} }
@ -63,7 +64,7 @@ func TestMain(m *testing.M) {
func TestOpenAndClose(t *testing.T) { func TestOpenAndClose(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
require := require.New(t) require := require.New(t)
setup() setup(devicePath)
defer teardown(devicePath) defer teardown(devicePath)
mapper := cryptmapper.New(&fakeKMS{}) mapper := cryptmapper.New(&fakeKMS{})
@ -81,8 +82,13 @@ func TestOpenAndClose(t *testing.T) {
_, err = os.Stat(newPath + "_dif") _, err = os.Stat(newPath + "_dif")
assert.True(os.IsNotExist(err)) assert.True(os.IsNotExist(err))
// Opening the same device should return the same path and not error
newPath2, err := mapper.OpenCryptDevice(context.Background(), devicePath, deviceName, false)
require.NoError(err)
assert.Equal(newPath, newPath2)
// Resize the device // Resize the device
resize() resize(devicePath)
resizedPath, err := mapper.ResizeCryptDevice(context.Background(), deviceName) resizedPath, err := mapper.ResizeCryptDevice(context.Background(), deviceName)
require.NoError(err) require.NoError(err)
@ -103,7 +109,7 @@ func TestOpenAndClose(t *testing.T) {
func TestOpenAndCloseIntegrity(t *testing.T) { func TestOpenAndCloseIntegrity(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
require := require.New(t) require := require.New(t)
setup() setup(devicePath)
defer teardown(devicePath) defer teardown(devicePath)
mapper := cryptmapper.New(&fakeKMS{}) mapper := cryptmapper.New(&fakeKMS{})
@ -119,8 +125,13 @@ func TestOpenAndCloseIntegrity(t *testing.T) {
_, err = os.Stat(newPath + "_dif") _, err = os.Stat(newPath + "_dif")
assert.NoError(err) assert.NoError(err)
// Opening the same device should return the same path and not error
newPath2, err := mapper.OpenCryptDevice(context.Background(), devicePath, deviceName, true)
require.NoError(err)
assert.Equal(newPath, newPath2)
// integrity devices do not support resizing // integrity devices do not support resizing
resize() resize(devicePath)
_, err = mapper.ResizeCryptDevice(context.Background(), deviceName) _, err = mapper.ResizeCryptDevice(context.Background(), deviceName)
assert.Error(err) assert.Error(err)
@ -142,7 +153,7 @@ func TestOpenAndCloseIntegrity(t *testing.T) {
func TestDeviceCloning(t *testing.T) { func TestDeviceCloning(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
require := require.New(t) require := require.New(t)
setup() setup(devicePath)
defer teardown(devicePath) defer teardown(devicePath)
mapper := cryptmapper.New(&dynamicKMS{}) mapper := cryptmapper.New(&dynamicKMS{})
@ -160,6 +171,41 @@ func TestDeviceCloning(t *testing.T) {
assert.NoError(mapper.CloseCryptDevice(deviceName + "-copy")) assert.NoError(mapper.CloseCryptDevice(deviceName + "-copy"))
} }
func TestConcurrency(t *testing.T) {
assert := assert.New(t)
setup(devicePath)
defer teardown(devicePath)
device2 := devicePath + "-2"
setup(device2)
defer teardown(device2)
mapper := cryptmapper.New(&fakeKMS{})
wg := sync.WaitGroup{}
runTest := func(path, name string) {
newPath, err := mapper.OpenCryptDevice(context.Background(), path, name, false)
assert.NoError(err)
defer func() {
_ = mapper.CloseCryptDevice(name)
}()
// assert crypt device got created
_, err = os.Stat(newPath)
assert.NoError(err)
// assert no integrity device got created
_, err = os.Stat(newPath + "_dif")
assert.True(os.IsNotExist(err))
assert.NoError(mapper.CloseCryptDevice(name))
wg.Done()
}
wg.Add(2)
go runTest(devicePath, deviceName)
go runTest(device2, deviceName+"-2")
wg.Wait()
}
type fakeKMS struct{} type fakeKMS struct{}
func (k *fakeKMS) GetDEK(_ context.Context, _ string, dekSize int) ([]byte, error) { func (k *fakeKMS) GetDEK(_ context.Context, _ string, dekSize int) ([]byte, error) {