constellation/csi/cryptmapper/cryptmapper.go
Daniel Weiße f69ae26122
csi: fix concurrent use of cryptmapper package (#2408)
* Don't error on opening already active devices

* Fix concurrency issues when working with more than one device

---------

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
2023-10-05 11:20:22 +02:00

280 lines
8.8 KiB
Go

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
// Package cryptmapper provides a wrapper around libcryptsetup to manage dm-crypt volumes for CSI drivers.
package cryptmapper
import (
"context"
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"strings"
"time"
"github.com/edgelesssys/constellation/v2/internal/crypto"
"github.com/edgelesssys/constellation/v2/internal/cryptsetup"
)
const (
	// LUKSHeaderSize is the number of bytes occupied by the header of a LUKS2 partition.
	// The header is 16MiB.
	LUKSHeaderSize = 16 * 1024 * 1024

	// cryptPrefix is prepended to a volume ID to build the path of its mapped device.
	cryptPrefix = "/dev/mapper/"

	// integritySuffix is appended to a volume ID to name the accompanying integrity device.
	integritySuffix = "_dif"

	// integrityFSSuffix marks an fstype string as requesting integrity protection (see IsIntegrityFS).
	integrityFSSuffix = "-integrity"

	// keySizeIntegrity and keySizeCrypt are not referenced in this part of the file.
	// NOTE(review): presumably the key sizes for volumes with and without integrity — confirm unit and usage.
	keySizeIntegrity = 96
	keySizeCrypt     = 64
)
// CryptMapper manages dm-crypt volumes.
// Every method calls the mapper factory to obtain the deviceMapper it works on.
type CryptMapper struct {
// mapper returns the deviceMapper used for a single operation.
mapper func() deviceMapper
// kms supplies the data encryption keys (via GetDEK) that are used as volume passphrases.
kms keyCreator
// getDiskFormat reports the existing format of a disk.
// formatNewDevice only formats a disk when this returns an empty string.
getDiskFormat func(disk string) (string, error)
}
// New creates a CryptMapper that fetches the data encryption keys
// for its dm-crypt volumes from kms.
// The returned CryptMapper obtains a new mapper from cryptsetup.New for every operation.
func New(kms keyCreator) *CryptMapper {
	newMapper := func() deviceMapper {
		return cryptsetup.New()
	}

	cm := new(CryptMapper)
	cm.mapper = newMapper
	cm.kms = kms
	cm.getDiskFormat = getDiskFormat
	return cm
}
// CloseCryptDevice closes the crypt device mapped for volumeID and,
// if one exists, the accompanying integrity device (volumeID + "_dif").
//
// Returns nil if the crypt device does not exist. A missing integrity device
// is not an error either: the volume was simply created without integrity.
// "Does not exist" is detected as any *fs.PathError from resolving the
// /dev/mapper path; other resolution errors are returned to the caller.
func (c *CryptMapper) CloseCryptDevice(volumeID string) error {
	cryptPath := cryptPrefix + volumeID
	source, err := filepath.EvalSymlinks(cryptPath)
	if err != nil {
		var pathErr *fs.PathError
		if errors.As(err, &pathErr) {
			// crypt device does not exist, nothing to close
			return nil
		}
		return fmt.Errorf("getting device path for disk %q: %w", cryptPath, err)
	}
	if err := c.closeCryptDevice(source, volumeID, "crypt"); err != nil {
		return fmt.Errorf("closing crypt device: %w", err)
	}

	// If the device was created with integrity, the integrity device has to be closed as well.
	integrityName := volumeID + integritySuffix
	integrityPath := cryptPrefix + integrityName
	integritySource, err := filepath.EvalSymlinks(integrityPath)
	if err != nil {
		var pathErr *fs.PathError
		if errors.As(err, &pathErr) {
			// integrity device does not exist
			return nil
		}
		// Report the integrity path that failed to resolve, not the crypt path.
		return fmt.Errorf("getting device path for disk %q: %w", integrityPath, err)
	}
	if err := c.closeCryptDevice(integritySource, integrityName, "integrity"); err != nil {
		return fmt.Errorf("closing integrity device: %w", err)
	}
	return nil
}
// OpenCryptDevice maps the volume at source to the crypt device identified by volumeID
// and returns the path of the mapped device.
// The key used to unlock the volume is fetched using CryptMapper's kms client.
//
// If source carries no loadable LUKS2 header, it is formatted first (see formatNewDevice).
// If the mapped device already exists, the call is a no-op that returns its path;
// with integrity set, the integrity device must then exist as well.
func (c *CryptMapper) OpenCryptDevice(ctx context.Context, source, volumeID string, integrity bool) (string, error) {
	dm := c.mapper()
	release, initErr := dm.Init(source)
	if initErr != nil {
		return "", fmt.Errorf("initializing dm-crypt to map device %q: %w", source, initErr)
	}
	defer release()

	mappedPath := filepath.Join(cryptPrefix, volumeID)

	var key []byte
	if loadErr := dm.LoadLUKS2(); loadErr == nil {
		// The LUKS2 header loaded. If the mapping is already active there is nothing left to do.
		if _, statErr := os.Stat(mappedPath); statErr == nil {
			integrityPath := mappedPath + integritySuffix
			_, integrityStatErr := os.Stat(integrityPath)
			if integrity && integrityStatErr != nil {
				return "", fmt.Errorf("device %s already exists, but integrity device %s is missing", mappedPath, integrityPath)
			}
			return mappedPath, nil
		}

		// Not active yet: fetch the key belonging to this device's UUID.
		uuid, err := dm.GetUUID()
		if err != nil {
			return "", err
		}
		key, err = c.kms.GetDEK(ctx, uuid, crypto.StateDiskKeyLength)
		if err != nil {
			return "", err
		}
		if got := len(key); got != crypto.StateDiskKeyLength {
			return "", fmt.Errorf("expected key length to be [%d] but got [%d]", crypto.StateDiskKeyLength, got)
		}
	} else {
		// Loading the header failed: the device is either not formatted at all,
		// or already formatted with a different FS. formatNewDevice tells the two apart.
		var err error
		key, err = c.formatNewDevice(ctx, dm, volumeID, source, integrity)
		if err != nil {
			return "", fmt.Errorf("formatting device: %w", err)
		}
	}

	if err := dm.ActivateByPassphrase(volumeID, 0, string(key), cryptsetup.ReadWriteQueueBypass); err != nil {
		return "", fmt.Errorf("trying to activate dm-crypt volume: %w", err)
	}
	return mappedPath, nil
}
// ResizeCryptDevice resizes the crypt device mapped for volumeID
// and returns the mapped device path (/dev/mapper/<volumeID>).
//
// The device is looked up by its mapping name. Its key is fetched from the kms
// client using the device's UUID and must have the expected state disk key length.
func (c *CryptMapper) ResizeCryptDevice(ctx context.Context, volumeID string) (string, error) {
	mapper := c.mapper()
	free, err := mapper.InitByName(volumeID)
	if err != nil {
		return "", fmt.Errorf("initializing device: %w", err)
	}
	defer free()

	if err := mapper.LoadLUKS2(); err != nil {
		return "", fmt.Errorf("loading device: %w", err)
	}

	uuid, err := mapper.GetUUID()
	if err != nil {
		return "", fmt.Errorf("getting device UUID: %w", err)
	}
	passphrase, err := c.kms.GetDEK(ctx, uuid, crypto.StateDiskKeyLength)
	if err != nil {
		return "", fmt.Errorf("getting key: %w", err)
	}
	// Same sanity check as when opening or formatting a device:
	// fail early with a clear message instead of a generic activation error.
	if len(passphrase) != crypto.StateDiskKeyLength {
		return "", fmt.Errorf("expected key length to be [%d] but got [%d]", crypto.StateDiskKeyLength, len(passphrase))
	}

	// An empty device name is passed here, so no mapping name is given to the activation call.
	// NOTE(review): together with resizeFlags (defined elsewhere) this presumably only
	// loads the volume key into the keyring so the resize is authorized — confirm.
	if err := mapper.ActivateByPassphrase("", 0, string(passphrase), resizeFlags); err != nil {
		return "", fmt.Errorf("activating keyring for crypt device %q with passphrase: %w", volumeID, err)
	}

	// NOTE(review): a new size of 0 presumably means "use the full size of the underlying device" — confirm.
	if err := mapper.Resize(volumeID, 0); err != nil {
		return "", fmt.Errorf("resizing device: %w", err)
	}

	return cryptPrefix + volumeID, nil
}
// GetDevicePath returns the device path of a mapped crypt device.
// volumeID may be the bare mapping name or carry the /dev/mapper/ prefix;
// the prefix is stripped before the device is looked up by name.
// An error is returned if the mapper reports an empty device name.
func (c *CryptMapper) GetDevicePath(volumeID string) (string, error) {
	dm := c.mapper()

	release, err := dm.InitByName(strings.TrimPrefix(volumeID, cryptPrefix))
	if err != nil {
		return "", fmt.Errorf("initializing device: %w", err)
	}
	defer release()

	if path := dm.GetDeviceName(); path != "" {
		return path, nil
	}
	return "", errors.New("unable to determine device name")
}
// closeCryptDevice deactivates the mapped device named volumeID.
// source and deviceType ("crypt" or "integrity" at the call sites in this file)
// are only used to build error messages.
func (c *CryptMapper) closeCryptDevice(source, volumeID, deviceType string) error {
	dm := c.mapper()

	release, initErr := dm.InitByName(volumeID)
	if initErr != nil {
		return fmt.Errorf("initializing dm-%s to unmap device %q: %w", deviceType, source, initErr)
	}
	defer release()

	deactivateErr := dm.Deactivate(volumeID)
	if deactivateErr == nil {
		return nil
	}
	return fmt.Errorf("deactivating dm-%s volume %q for device %q: %w", deviceType, cryptPrefix+volumeID, source, deactivateErr)
}
// formatNewDevice creates a new LUKS2 partition on source and returns the
// passphrase (data encryption key) that was added as a keyslot.
//
// The disk is only formatted if getDiskFormat reports no existing format.
// With integrity enabled the device is wiped afterwards, printing progress to stdout.
func (c *CryptMapper) formatNewDevice(ctx context.Context, mapper deviceMapper, volumeID, source string, integrity bool) ([]byte, error) {
	existingFormat, err := c.getDiskFormat(source)
	switch {
	case err != nil:
		return nil, fmt.Errorf("determining if disk is formatted: %w", err)
	case existingFormat != "":
		return nil, fmt.Errorf("disk %q is already formatted as: %s", source, existingFormat)
	}

	// The disk carries no format, so creating a new LUKS2 partition is safe.
	if err := mapper.Format(integrity); err != nil {
		return nil, fmt.Errorf("formatting device %q: %w", source, err)
	}

	uuid, err := mapper.GetUUID()
	if err != nil {
		return nil, err
	}

	key, err := c.kms.GetDEK(ctx, uuid, crypto.StateDiskKeyLength)
	if err != nil {
		return nil, err
	}
	if got := len(key); got != crypto.StateDiskKeyLength {
		return nil, fmt.Errorf("expected key length to be [%d] but got [%d]", crypto.StateDiskKeyLength, got)
	}

	// Register the key as a passphrase in keyslot 0, using the internal volume key.
	if err := mapper.KeyslotAddByVolumeKey(0, "", string(key)); err != nil {
		return nil, fmt.Errorf("adding keyslot: %w", err)
	}

	if !integrity {
		return key, nil
	}

	// Integrity devices are wiped after formatting; progress is reported as a percentage.
	reportWipe := func(size, offset uint64) {
		fmt.Printf("Wipe in progress: %.2f%%\n", (float64(offset)/float64(size))*100)
	}
	if err := mapper.Wipe(volumeID, 1024*1024, 0, reportWipe, 30*time.Second); err != nil {
		return nil, fmt.Errorf("wiping device: %w", err)
	}
	return key, nil
}
// IsIntegrityFS reports whether fstype ends with the integrity suffix ("-integrity").
// If it does, the fstype without the suffix and true are returned;
// otherwise fstype is returned unchanged together with false.
func IsIntegrityFS(fstype string) (string, bool) {
	// The suffix is non-empty, so the string changes exactly when the suffix was present.
	base := strings.TrimSuffix(fstype, integrityFSSuffix)
	return base, base != fstype
}
// deviceMapper is an interface for device mapper methods.
// CryptMapper obtains a value of this type from its mapper factory for every operation;
// New wires that factory to cryptsetup.New.
type deviceMapper interface {
// Init initializes the mapper for the device at devicePath.
// Callers in this file defer the returned function after a successful call.
Init(devicePath string) (func(), error)
// InitByName initializes the mapper for an already mapped device identified by name.
// Callers in this file defer the returned function after a successful call.
InitByName(name string) (func(), error)
// ActivateByPassphrase activates deviceName using the passphrase stored in keyslot.
// ResizeCryptDevice calls it with an empty deviceName.
ActivateByPassphrase(deviceName string, keyslot int, passphrase string, flags int) error
// ActivateByVolumeKey is not called in this file.
ActivateByVolumeKey(deviceName string, volumeKey string, volumeKeySize int, flags int) error
// Deactivate deactivates the mapped device deviceName.
Deactivate(deviceName string) error
// Format formats the initialized device; formatNewDevice treats the result as a new LUKS2 partition.
Format(integrity bool) error
// Free is not called in this file.
Free()
// GetDeviceName returns the name of the initialized device.
// GetDevicePath treats an empty result as an error.
GetDeviceName() string
// GetUUID returns the UUID of the initialized device; it is used as the key ID for the kms.
GetUUID() (string, error)
// LoadLUKS2 loads the LUKS2 header of the initialized device.
// OpenCryptDevice takes a failure as a sign that the device still needs formatting.
LoadLUKS2() error
// KeyslotAddByVolumeKey adds passphrase to keyslot.
// formatNewDevice passes an empty volumeKey, described there as using the internal volume key.
KeyslotAddByVolumeKey(keyslot int, volumeKey string, passphrase string) error
// Wipe wipes the device name and reports progress through the progress callback.
Wipe(name string, wipeBlockSize int, flags int, progress func(size, offset uint64), frequency time.Duration) error
// Resize resizes the mapped device name; ResizeCryptDevice passes 0 as newSize.
Resize(name string, newSize uint64) error
}
// keyCreator is an interface to create data encryption keys.
type keyCreator interface {
// GetDEK returns the data encryption key for dekID.
// CryptMapper passes the UUID reported by deviceMapper.GetUUID as dekID and
// crypto.StateDiskKeyLength as dekSize, and rejects keys whose length differs from
// that size when opening or formatting a device.
GetDEK(ctx context.Context, dekID string, dekSize int) ([]byte, error)
}