Refactor init/recovery to use kms URI

So far the masterSecret was sent to the initial bootstrapper
on init/recovery. With this commit this information is encoded
in the kmsURI that is sent during init.
For recover, the communication with the recoveryserver is
changed. Before a streaming gRPC call was used to
exchanges UUID for measurementSecret and state disk key.
Now a standard gRPC is made that includes the same kmsURI &
storageURI that are sent during init.
This commit is contained in:
Otto Bittner 2023-01-16 11:19:03 +01:00
parent 0e71322e2e
commit 9a1f52e94e
35 changed files with 466 additions and 623 deletions

View file

@ -23,7 +23,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/grpc/atlscredentials" "github.com/edgelesssys/constellation/v2/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/v2/internal/grpc/grpclog" "github.com/edgelesssys/constellation/v2/internal/grpc/grpclog"
"github.com/edgelesssys/constellation/v2/internal/kms/kms" "github.com/edgelesssys/constellation/v2/internal/kms/kms"
kmsSetup "github.com/edgelesssys/constellation/v2/internal/kms/setup" kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
"github.com/edgelesssys/constellation/v2/internal/logger" "github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/internal/nodestate" "github.com/edgelesssys/constellation/v2/internal/nodestate"
"github.com/edgelesssys/constellation/v2/internal/role" "github.com/edgelesssys/constellation/v2/internal/role"
@ -110,8 +110,13 @@ func (s *Server) Init(ctx context.Context, req *initproto.InitRequest) (*initpro
return nil, status.Errorf(codes.Internal, "invalid init secret %s", err) return nil, status.Errorf(codes.Internal, "invalid init secret %s", err)
} }
cloudKms, err := kmssetup.KMS(ctx, req.StorageUri, req.KmsUri)
if err != nil {
return nil, fmt.Errorf("creating kms client: %w", err)
}
// generate values for cluster attestation // generate values for cluster attestation
measurementSalt, clusterID, err := deriveMeasurementValues(req.MasterSecret, req.Salt) measurementSalt, clusterID, err := deriveMeasurementValues(ctx, cloudKms)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "deriving measurement values: %s", err) return nil, status.Errorf(codes.Internal, "deriving measurement values: %s", err)
} }
@ -130,7 +135,7 @@ func (s *Server) Init(ctx context.Context, req *initproto.InitRequest) (*initpro
return nil, status.Error(codes.FailedPrecondition, "node is already being activated") return nil, status.Error(codes.FailedPrecondition, "node is already being activated")
} }
if err := s.setupDisk(req.MasterSecret, req.Salt); err != nil { if err := s.setupDisk(ctx, cloudKms); err != nil {
return nil, status.Errorf(codes.Internal, "setting up disk: %s", err) return nil, status.Errorf(codes.Internal, "setting up disk: %s", err)
} }
@ -177,7 +182,7 @@ func (s *Server) Stop() {
s.log.Infof("Stopped") s.log.Infof("Stopped")
} }
func (s *Server) setupDisk(masterSecret, salt []byte) error { func (s *Server) setupDisk(ctx context.Context, cloudKms kms.CloudKMS) error {
if err := s.disk.Open(); err != nil { if err := s.disk.Open(); err != nil {
return fmt.Errorf("opening encrypted disk: %w", err) return fmt.Errorf("opening encrypted disk: %w", err)
} }
@ -189,7 +194,7 @@ func (s *Server) setupDisk(masterSecret, salt []byte) error {
} }
uuid = strings.ToLower(uuid) uuid = strings.ToLower(uuid)
diskKey, err := crypto.DeriveKey(masterSecret, salt, []byte(crypto.HKDFInfoPrefix+uuid), crypto.DerivedKeyLengthDefault) diskKey, err := cloudKms.GetDEK(ctx, crypto.DEKPrefix+uuid, crypto.StateDiskKeyLength)
if err != nil { if err != nil {
return err return err
} }
@ -197,12 +202,12 @@ func (s *Server) setupDisk(masterSecret, salt []byte) error {
return s.disk.UpdatePassphrase(string(diskKey)) return s.disk.UpdatePassphrase(string(diskKey))
} }
func deriveMeasurementValues(masterSecret, hkdfSalt []byte) (salt, clusterID []byte, err error) { func deriveMeasurementValues(ctx context.Context, cloudKms kms.CloudKMS) (salt, clusterID []byte, err error) {
salt, err = crypto.GenerateRandomBytes(crypto.RNGLengthDefault) salt, err = crypto.GenerateRandomBytes(crypto.RNGLengthDefault)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
secret, err := attestation.DeriveMeasurementSecret(masterSecret, hkdfSalt) secret, err := cloudKms.GetDEK(ctx, crypto.DEKPrefix+crypto.MeasurementSecretKeyID, crypto.DerivedKeyLengthDefault)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View file

@ -19,6 +19,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/atls" "github.com/edgelesssys/constellation/v2/internal/atls"
"github.com/edgelesssys/constellation/v2/internal/crypto/testvector" "github.com/edgelesssys/constellation/v2/internal/crypto/testvector"
"github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/file"
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
"github.com/edgelesssys/constellation/v2/internal/logger" "github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/internal/oid" "github.com/edgelesssys/constellation/v2/internal/oid"
"github.com/edgelesssys/constellation/v2/internal/versions/components" "github.com/edgelesssys/constellation/v2/internal/versions/components"
@ -30,7 +31,10 @@ import (
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
goleak.VerifyTestMain(m) goleak.VerifyTestMain(m,
// https://github.com/census-instrumentation/opencensus-go/issues/1262
goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
)
} }
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
@ -86,6 +90,8 @@ func TestInit(t *testing.T) {
initSecretHash, err := bcrypt.GenerateFromPassword(initSecret, bcrypt.DefaultCost) initSecretHash, err := bcrypt.GenerateFromPassword(initSecret, bcrypt.DefaultCost)
require.NoError(t, err) require.NoError(t, err)
masterSecret := kmssetup.MasterSecret{Key: []byte("secret"), Salt: []byte("salt")}
testCases := map[string]struct { testCases := map[string]struct {
nodeLock *fakeLock nodeLock *fakeLock
initializer ClusterInitializer initializer ClusterInitializer
@ -102,14 +108,14 @@ func TestInit(t *testing.T) {
disk: &stubDisk{}, disk: &stubDisk{},
fileHandler: file.NewHandler(afero.NewMemMapFs()), fileHandler: file.NewHandler(afero.NewMemMapFs()),
initSecretHash: initSecretHash, initSecretHash: initSecretHash,
req: &initproto.InitRequest{InitSecret: initSecret}, req: &initproto.InitRequest{InitSecret: initSecret, KmsUri: masterSecret.EncodeToURI(), StorageUri: kmssetup.NoStoreURI},
}, },
"node locked": { "node locked": {
nodeLock: lockedLock, nodeLock: lockedLock,
initializer: &stubClusterInitializer{}, initializer: &stubClusterInitializer{},
disk: &stubDisk{}, disk: &stubDisk{},
fileHandler: file.NewHandler(afero.NewMemMapFs()), fileHandler: file.NewHandler(afero.NewMemMapFs()),
req: &initproto.InitRequest{InitSecret: initSecret}, req: &initproto.InitRequest{InitSecret: initSecret, KmsUri: masterSecret.EncodeToURI(), StorageUri: kmssetup.NoStoreURI},
initSecretHash: initSecretHash, initSecretHash: initSecretHash,
wantErr: true, wantErr: true,
wantShutdown: true, wantShutdown: true,
@ -119,7 +125,7 @@ func TestInit(t *testing.T) {
initializer: &stubClusterInitializer{}, initializer: &stubClusterInitializer{},
disk: &stubDisk{openErr: someErr}, disk: &stubDisk{openErr: someErr},
fileHandler: file.NewHandler(afero.NewMemMapFs()), fileHandler: file.NewHandler(afero.NewMemMapFs()),
req: &initproto.InitRequest{InitSecret: initSecret}, req: &initproto.InitRequest{InitSecret: initSecret, KmsUri: masterSecret.EncodeToURI(), StorageUri: kmssetup.NoStoreURI},
initSecretHash: initSecretHash, initSecretHash: initSecretHash,
wantErr: true, wantErr: true,
}, },
@ -128,7 +134,7 @@ func TestInit(t *testing.T) {
initializer: &stubClusterInitializer{}, initializer: &stubClusterInitializer{},
disk: &stubDisk{uuidErr: someErr}, disk: &stubDisk{uuidErr: someErr},
fileHandler: file.NewHandler(afero.NewMemMapFs()), fileHandler: file.NewHandler(afero.NewMemMapFs()),
req: &initproto.InitRequest{InitSecret: initSecret}, req: &initproto.InitRequest{InitSecret: initSecret, KmsUri: masterSecret.EncodeToURI(), StorageUri: kmssetup.NoStoreURI},
initSecretHash: initSecretHash, initSecretHash: initSecretHash,
wantErr: true, wantErr: true,
}, },
@ -137,7 +143,7 @@ func TestInit(t *testing.T) {
initializer: &stubClusterInitializer{}, initializer: &stubClusterInitializer{},
disk: &stubDisk{updatePassphraseErr: someErr}, disk: &stubDisk{updatePassphraseErr: someErr},
fileHandler: file.NewHandler(afero.NewMemMapFs()), fileHandler: file.NewHandler(afero.NewMemMapFs()),
req: &initproto.InitRequest{InitSecret: initSecret}, req: &initproto.InitRequest{InitSecret: initSecret, KmsUri: masterSecret.EncodeToURI(), StorageUri: kmssetup.NoStoreURI},
initSecretHash: initSecretHash, initSecretHash: initSecretHash,
wantErr: true, wantErr: true,
}, },
@ -146,7 +152,7 @@ func TestInit(t *testing.T) {
initializer: &stubClusterInitializer{}, initializer: &stubClusterInitializer{},
disk: &stubDisk{}, disk: &stubDisk{},
fileHandler: file.NewHandler(afero.NewReadOnlyFs(afero.NewMemMapFs())), fileHandler: file.NewHandler(afero.NewReadOnlyFs(afero.NewMemMapFs())),
req: &initproto.InitRequest{InitSecret: initSecret}, req: &initproto.InitRequest{InitSecret: initSecret, KmsUri: masterSecret.EncodeToURI(), StorageUri: kmssetup.NoStoreURI},
initSecretHash: initSecretHash, initSecretHash: initSecretHash,
wantErr: true, wantErr: true,
}, },
@ -155,7 +161,7 @@ func TestInit(t *testing.T) {
initializer: &stubClusterInitializer{initClusterErr: someErr}, initializer: &stubClusterInitializer{initClusterErr: someErr},
disk: &stubDisk{}, disk: &stubDisk{},
fileHandler: file.NewHandler(afero.NewMemMapFs()), fileHandler: file.NewHandler(afero.NewMemMapFs()),
req: &initproto.InitRequest{InitSecret: initSecret}, req: &initproto.InitRequest{InitSecret: initSecret, KmsUri: masterSecret.EncodeToURI(), StorageUri: kmssetup.NoStoreURI},
initSecretHash: initSecretHash, initSecretHash: initSecretHash,
wantErr: true, wantErr: true,
}, },
@ -212,19 +218,19 @@ func TestInit(t *testing.T) {
func TestSetupDisk(t *testing.T) { func TestSetupDisk(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
uuid string uuid string
masterSecret []byte masterKey []byte
salt []byte salt []byte
wantKey []byte wantKey []byte
}{ }{
"lower case uuid": { "lower case uuid": {
uuid: strings.ToLower(testvector.HKDF0xFF.Info), uuid: strings.ToLower(testvector.HKDF0xFF.Info),
masterSecret: testvector.HKDF0xFF.Secret, masterKey: testvector.HKDF0xFF.Secret,
salt: testvector.HKDF0xFF.Salt, salt: testvector.HKDF0xFF.Salt,
wantKey: testvector.HKDF0xFF.Output, wantKey: testvector.HKDF0xFF.Output,
}, },
"upper case uuid": { "upper case uuid": {
uuid: strings.ToUpper(testvector.HKDF0xFF.Info), uuid: strings.ToUpper(testvector.HKDF0xFF.Info),
masterSecret: testvector.HKDF0xFF.Secret, masterKey: testvector.HKDF0xFF.Secret,
salt: testvector.HKDF0xFF.Salt, salt: testvector.HKDF0xFF.Salt,
wantKey: testvector.HKDF0xFF.Output, wantKey: testvector.HKDF0xFF.Output,
}, },
@ -233,6 +239,7 @@ func TestSetupDisk(t *testing.T) {
for name, tc := range testCases { for name, tc := range testCases {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
require := require.New(t)
disk := &fakeDisk{ disk := &fakeDisk{
uuid: tc.uuid, uuid: tc.uuid,
@ -242,7 +249,11 @@ func TestSetupDisk(t *testing.T) {
disk: disk, disk: disk,
} }
assert.NoError(server.setupDisk(tc.masterSecret, tc.salt)) masterSecret := kmssetup.MasterSecret{Key: tc.masterKey, Salt: tc.salt}
cloudKms, err := kmssetup.KMS(context.Background(), kmssetup.NoStoreURI, masterSecret.EncodeToURI())
require.NoError(err)
assert.NoError(server.setupDisk(context.Background(), cloudKms))
}) })
} }
} }

View file

@ -30,7 +30,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer" "github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
grpcRetry "github.com/edgelesssys/constellation/v2/internal/grpc/retry" grpcRetry "github.com/edgelesssys/constellation/v2/internal/grpc/retry"
keyservice "github.com/edgelesssys/constellation/v2/internal/kms/setup" kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
"github.com/edgelesssys/constellation/v2/internal/license" "github.com/edgelesssys/constellation/v2/internal/license"
"github.com/edgelesssys/constellation/v2/internal/retry" "github.com/edgelesssys/constellation/v2/internal/retry"
"github.com/edgelesssys/constellation/v2/internal/versions" "github.com/edgelesssys/constellation/v2/internal/versions"
@ -146,8 +146,8 @@ func (i *initCmd) initialize(cmd *cobra.Command, newDialer func(validator *cloud
req := &initproto.InitRequest{ req := &initproto.InitRequest{
MasterSecret: masterSecret.Key, MasterSecret: masterSecret.Key,
Salt: masterSecret.Salt, Salt: masterSecret.Salt,
KmsUri: keyservice.ClusterKMSURI, KmsUri: masterSecret.EncodeToURI(),
StorageUri: keyservice.NoStoreURI, StorageUri: kmssetup.NoStoreURI,
KeyEncryptionKeyId: "", KeyEncryptionKeyId: "",
UseExistingKek: false, UseExistingKek: false,
CloudServiceAccountUri: serviceAccURI, CloudServiceAccountUri: serviceAccURI,
@ -296,26 +296,20 @@ type initFlags struct {
conformance bool conformance bool
} }
// masterSecret holds the master key and salt for deriving keys.
type masterSecret struct {
Key []byte `json:"key"`
Salt []byte `json:"salt"`
}
// readOrGenerateMasterSecret reads a base64 encoded master secret from file or generates a new 32 byte secret. // readOrGenerateMasterSecret reads a base64 encoded master secret from file or generates a new 32 byte secret.
func (i *initCmd) readOrGenerateMasterSecret(outWriter io.Writer, fileHandler file.Handler, filename string) (masterSecret, error) { func (i *initCmd) readOrGenerateMasterSecret(outWriter io.Writer, fileHandler file.Handler, filename string) (kmssetup.MasterSecret, error) {
if filename != "" { if filename != "" {
i.log.Debugf("Reading master secret from file %q", filename) i.log.Debugf("Reading master secret from file %q", filename)
var secret masterSecret var secret kmssetup.MasterSecret
if err := fileHandler.ReadJSON(filename, &secret); err != nil { if err := fileHandler.ReadJSON(filename, &secret); err != nil {
return masterSecret{}, err return kmssetup.MasterSecret{}, err
} }
if len(secret.Key) < crypto.MasterSecretLengthMin { if len(secret.Key) < crypto.MasterSecretLengthMin {
return masterSecret{}, fmt.Errorf("provided master secret is smaller than the required minimum of %d Bytes", crypto.MasterSecretLengthMin) return kmssetup.MasterSecret{}, fmt.Errorf("provided master secret is smaller than the required minimum of %d Bytes", crypto.MasterSecretLengthMin)
} }
if len(secret.Salt) < crypto.RNGLengthDefault { if len(secret.Salt) < crypto.RNGLengthDefault {
return masterSecret{}, fmt.Errorf("provided salt is smaller than the required minimum of %d Bytes", crypto.RNGLengthDefault) return kmssetup.MasterSecret{}, fmt.Errorf("provided salt is smaller than the required minimum of %d Bytes", crypto.RNGLengthDefault)
} }
return secret, nil return secret, nil
} }
@ -324,19 +318,19 @@ func (i *initCmd) readOrGenerateMasterSecret(outWriter io.Writer, fileHandler fi
i.log.Debugf("Generating new master secret") i.log.Debugf("Generating new master secret")
key, err := crypto.GenerateRandomBytes(crypto.MasterSecretLengthDefault) key, err := crypto.GenerateRandomBytes(crypto.MasterSecretLengthDefault)
if err != nil { if err != nil {
return masterSecret{}, err return kmssetup.MasterSecret{}, err
} }
salt, err := crypto.GenerateRandomBytes(crypto.RNGLengthDefault) salt, err := crypto.GenerateRandomBytes(crypto.RNGLengthDefault)
if err != nil { if err != nil {
return masterSecret{}, err return kmssetup.MasterSecret{}, err
} }
secret := masterSecret{ secret := kmssetup.MasterSecret{
Key: key, Key: key,
Salt: salt, Salt: salt,
} }
i.log.Debugf("Generated master secret key and salt values") i.log.Debugf("Generated master secret key and salt values")
if err := fileHandler.WriteJSON(constants.MasterSecretFilename, secret, file.OptNone); err != nil { if err := fileHandler.WriteJSON(constants.MasterSecretFilename, secret, file.OptNone); err != nil {
return masterSecret{}, err return kmssetup.MasterSecret{}, err
} }
fmt.Fprintf(outWriter, "Your Constellation master secret was successfully written to ./%s\n", constants.MasterSecretFilename) fmt.Fprintf(outWriter, "Your Constellation master secret was successfully written to ./%s\n", constants.MasterSecretFilename)
return secret, nil return secret, nil

View file

@ -18,6 +18,8 @@ import (
"testing" "testing"
"time" "time"
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
"github.com/edgelesssys/constellation/v2/bootstrapper/initproto" "github.com/edgelesssys/constellation/v2/bootstrapper/initproto"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd" "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/clusterid" "github.com/edgelesssys/constellation/v2/cli/internal/clusterid"
@ -183,7 +185,7 @@ func TestInitialize(t *testing.T) {
require.NoError(err) require.NoError(err)
// assert.Contains(out.String(), base64.StdEncoding.EncodeToString([]byte("ownerID"))) // assert.Contains(out.String(), base64.StdEncoding.EncodeToString([]byte("ownerID")))
assert.Contains(out.String(), hex.EncodeToString([]byte("clusterID"))) assert.Contains(out.String(), hex.EncodeToString([]byte("clusterID")))
var secret masterSecret var secret kmssetup.MasterSecret
assert.NoError(fileHandler.ReadJSON(constants.MasterSecretFilename, &secret)) assert.NoError(fileHandler.ReadJSON(constants.MasterSecretFilename, &secret))
assert.NotEmpty(secret.Key) assert.NotEmpty(secret.Key)
assert.NotEmpty(secret.Salt) assert.NotEmpty(secret.Salt)
@ -251,7 +253,7 @@ func TestReadOrGenerateMasterSecret(t *testing.T) {
createFileFunc: func(handler file.Handler) error { createFileFunc: func(handler file.Handler) error {
return handler.WriteJSON( return handler.WriteJSON(
"someSecret", "someSecret",
masterSecret{Key: []byte("constellation-master-secret"), Salt: []byte("constellation-32Byte-length-salt")}, kmssetup.MasterSecret{Key: []byte("constellation-master-secret"), Salt: []byte("constellation-32Byte-length-salt")},
file.OptNone, file.OptNone,
) )
}, },
@ -282,7 +284,7 @@ func TestReadOrGenerateMasterSecret(t *testing.T) {
createFileFunc: func(handler file.Handler) error { createFileFunc: func(handler file.Handler) error {
return handler.WriteJSON( return handler.WriteJSON(
"shortSecret", "shortSecret",
masterSecret{Key: []byte("constellation-master-secret"), Salt: []byte("short")}, kmssetup.MasterSecret{Key: []byte("constellation-master-secret"), Salt: []byte("short")},
file.OptNone, file.OptNone,
) )
}, },
@ -294,7 +296,7 @@ func TestReadOrGenerateMasterSecret(t *testing.T) {
createFileFunc: func(handler file.Handler) error { createFileFunc: func(handler file.Handler) error {
return handler.WriteJSON( return handler.WriteJSON(
"shortSecret", "shortSecret",
masterSecret{Key: []byte("short"), Salt: []byte("constellation-32Byte-length-salt")}, kmssetup.MasterSecret{Key: []byte("short"), Salt: []byte("constellation-32Byte-length-salt")},
file.OptNone, file.OptNone,
) )
}, },
@ -340,7 +342,7 @@ func TestReadOrGenerateMasterSecret(t *testing.T) {
tc.filename = strings.Trim(filename[1], "\n") tc.filename = strings.Trim(filename[1], "\n")
} }
var masterSecret masterSecret var masterSecret kmssetup.MasterSecret
require.NoError(fileHandler.ReadJSON(tc.filename, &masterSecret)) require.NoError(fileHandler.ReadJSON(tc.filename, &masterSecret))
assert.Equal(masterSecret.Key, secret.Key) assert.Equal(masterSecret.Key, secret.Key)
assert.Equal(masterSecret.Salt, secret.Salt) assert.Equal(masterSecret.Salt, secret.Salt)

View file

@ -8,7 +8,6 @@ package cmd
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -17,7 +16,6 @@ import (
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd" "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto" "github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto"
"github.com/edgelesssys/constellation/v2/internal/attestation"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
@ -25,6 +23,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer" "github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
grpcRetry "github.com/edgelesssys/constellation/v2/internal/grpc/retry" grpcRetry "github.com/edgelesssys/constellation/v2/internal/grpc/retry"
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
"github.com/edgelesssys/constellation/v2/internal/retry" "github.com/edgelesssys/constellation/v2/internal/retry"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -73,7 +72,7 @@ func (r *recoverCmd) recover(
} }
r.log.Debugf("Using flags: %+v", flags) r.log.Debugf("Using flags: %+v", flags)
var masterSecret masterSecret var masterSecret kmssetup.MasterSecret
r.log.Debugf("Loading master secret file from %s", flags.secretPath) r.log.Debugf("Loading master secret file from %s", flags.secretPath)
if err := fileHandler.ReadJSON(flags.secretPath, &masterSecret); err != nil { if err := fileHandler.ReadJSON(flags.secretPath, &masterSecret); err != nil {
return err return err
@ -97,12 +96,7 @@ func (r *recoverCmd) recover(
r.log.Debugf("Created a new validator") r.log.Debugf("Created a new validator")
doer.setDialer(newDialer(validator), flags.endpoint) doer.setDialer(newDialer(validator), flags.endpoint)
r.log.Debugf("Set dialer for endpoint %s", flags.endpoint) r.log.Debugf("Set dialer for endpoint %s", flags.endpoint)
measurementSecret, err := attestation.DeriveMeasurementSecret(masterSecret.Key, masterSecret.Salt) doer.setURIs(masterSecret.EncodeToURI(), kmssetup.NoStoreURI)
r.log.Debugf("Derived measurementSecret")
if err != nil {
return err
}
doer.setSecrets(getStateDiskKeyFunc(masterSecret.Key, masterSecret.Salt), measurementSecret)
r.log.Debugf("Set secrets") r.log.Debugf("Set secrets")
if err := r.recoverCall(cmd.Context(), cmd.OutOrStdout(), interval, doer); err != nil { if err := r.recoverCall(cmd.Context(), cmd.OutOrStdout(), interval, doer); err != nil {
if grpcRetry.ServiceIsUnavailable(err) { if grpcRetry.ServiceIsUnavailable(err) {
@ -157,14 +151,14 @@ func (r *recoverCmd) recoverCall(ctx context.Context, out io.Writer, interval ti
type recoverDoerInterface interface { type recoverDoerInterface interface {
Do(ctx context.Context) error Do(ctx context.Context) error
setDialer(dialer grpcDialer, endpoint string) setDialer(dialer grpcDialer, endpoint string)
setSecrets(getDiskKey func(uuid string) ([]byte, error), measurementSecret []byte) setURIs(kmsURI, storageURI string)
} }
type recoverDoer struct { type recoverDoer struct {
dialer grpcDialer dialer grpcDialer
endpoint string endpoint string
measurementSecret []byte kmsURI string // encodes masterSecret
getDiskKey func(uuid string) (key []byte, err error) storageURI string
log debugLog log debugLog
} }
@ -177,53 +171,19 @@ func (d *recoverDoer) Do(ctx context.Context) (retErr error) {
d.log.Debugf("Dialed recovery server") d.log.Debugf("Dialed recovery server")
defer conn.Close() defer conn.Close()
// set up streaming client
protoClient := recoverproto.NewAPIClient(conn) protoClient := recoverproto.NewAPIClient(conn)
d.log.Debugf("Created protoClient") d.log.Debugf("Created protoClient")
recoverclient, err := protoClient.Recover(ctx)
d.log.Debugf("Created recoverclient") req := &recoverproto.RecoverMessage{
KmsUri: d.kmsURI,
StorageUri: d.storageURI,
}
_, err = protoClient.Recover(ctx, req)
if err != nil { if err != nil {
return fmt.Errorf("creating client: %w", err) return fmt.Errorf("calling recover: %w", err)
} }
defer func() {
_ = recoverclient.CloseSend()
}()
// send measurement secret as first message
if err := recoverclient.Send(&recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_MeasurementSecret{
MeasurementSecret: d.measurementSecret,
},
}); err != nil {
return fmt.Errorf("sending measurement secret: %w", err)
}
d.log.Debugf("Sent measurement secret")
// receive disk uuid
res, err := recoverclient.Recv()
if err != nil {
return fmt.Errorf("receiving disk uuid: %w", err)
}
d.log.Debugf("Received disk uuid")
stateDiskKey, err := d.getDiskKey(res.DiskUuid)
if err != nil {
return fmt.Errorf("getting state disk key: %w", err)
}
d.log.Debugf("Got state disk key")
// send disk key
if err := recoverclient.Send(&recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_StateDiskKey{
StateDiskKey: stateDiskKey,
},
}); err != nil {
return fmt.Errorf("sending state disk key: %w", err)
}
d.log.Debugf("Sent state disk key")
if _, err := recoverclient.Recv(); err != nil && !errors.Is(err, io.EOF) {
return fmt.Errorf("receiving confirmation: %w", err)
}
d.log.Debugf("Received confirmation") d.log.Debugf("Received confirmation")
return nil return nil
} }
@ -233,9 +193,9 @@ func (d *recoverDoer) setDialer(dialer grpcDialer, endpoint string) {
d.endpoint = endpoint d.endpoint = endpoint
} }
func (d *recoverDoer) setSecrets(getDiskKey func(string) ([]byte, error), measurementSecret []byte) { func (d *recoverDoer) setURIs(kmsURI, storageURI string) {
d.getDiskKey = getDiskKey d.kmsURI = kmsURI
d.measurementSecret = measurementSecret d.storageURI = storageURI
} }
type recoverFlags struct { type recoverFlags struct {
@ -281,6 +241,6 @@ func (r *recoverCmd) parseRecoverFlags(cmd *cobra.Command, fileHandler file.Hand
func getStateDiskKeyFunc(masterKey, salt []byte) func(uuid string) ([]byte, error) { func getStateDiskKeyFunc(masterKey, salt []byte) func(uuid string) ([]byte, error) {
return func(uuid string) ([]byte, error) { return func(uuid string) ([]byte, error) {
return crypto.DeriveKey(masterKey, salt, []byte(crypto.HKDFInfoPrefix+uuid), crypto.StateDiskKeyLength) return crypto.DeriveKey(masterKey, salt, []byte(crypto.DEKPrefix+uuid), crypto.StateDiskKeyLength)
} }
} }

View file

@ -26,6 +26,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/grpc/atlscredentials" "github.com/edgelesssys/constellation/v2/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer" "github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
"github.com/edgelesssys/constellation/v2/internal/grpc/testdialer" "github.com/edgelesssys/constellation/v2/internal/grpc/testdialer"
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
"github.com/edgelesssys/constellation/v2/internal/logger" "github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -156,7 +157,7 @@ func TestRecover(t *testing.T) {
require.NoError(fileHandler.WriteJSON( require.NoError(fileHandler.WriteJSON(
"constellation-mastersecret.json", "constellation-mastersecret.json",
masterSecret{Key: tc.masterSecret.Secret, Salt: tc.masterSecret.Salt}, kmssetup.MasterSecret{Key: tc.masterSecret.Secret, Salt: tc.masterSecret.Salt},
file.OptNone, file.OptNone,
)) ))
@ -244,100 +245,15 @@ func TestParseRecoverFlags(t *testing.T) {
} }
func TestDoRecovery(t *testing.T) { func TestDoRecovery(t *testing.T) {
someErr := errors.New("error")
testCases := map[string]struct { testCases := map[string]struct {
recoveryServer *stubRecoveryServer recoveryServer *stubRecoveryServer
wantErr bool wantErr bool
}{ }{
"success": { "success": {
recoveryServer: &stubRecoveryServer{ recoveryServer: &stubRecoveryServer{},
actions: [][]func(stream recoverproto.API_RecoverServer) error{{
func(stream recoverproto.API_RecoverServer) error {
_, err := stream.Recv()
return err
},
func(stream recoverproto.API_RecoverServer) error {
return stream.Send(&recoverproto.RecoverResponse{
DiskUuid: "00000000-0000-0000-0000-000000000000",
})
},
func(stream recoverproto.API_RecoverServer) error {
_, err := stream.Recv()
return err
},
}},
},
},
"error on first recv": {
recoveryServer: &stubRecoveryServer{
actions: [][]func(stream recoverproto.API_RecoverServer) error{
{
func(stream recoverproto.API_RecoverServer) error {
return someErr
},
},
},
},
wantErr: true,
},
"error on send": {
recoveryServer: &stubRecoveryServer{
actions: [][]func(stream recoverproto.API_RecoverServer) error{
{
func(stream recoverproto.API_RecoverServer) error {
_, err := stream.Recv()
return err
},
func(stream recoverproto.API_RecoverServer) error {
return someErr
},
},
},
},
wantErr: true,
},
"error on second recv": {
recoveryServer: &stubRecoveryServer{
actions: [][]func(stream recoverproto.API_RecoverServer) error{
{
func(stream recoverproto.API_RecoverServer) error {
_, err := stream.Recv()
return err
},
func(stream recoverproto.API_RecoverServer) error {
return stream.Send(&recoverproto.RecoverResponse{
DiskUuid: "00000000-0000-0000-0000-000000000000",
})
},
func(stream recoverproto.API_RecoverServer) error {
return someErr
},
},
},
},
wantErr: true,
},
"final message is an error": {
recoveryServer: &stubRecoveryServer{
actions: [][]func(stream recoverproto.API_RecoverServer) error{{
func(stream recoverproto.API_RecoverServer) error {
_, err := stream.Recv()
return err
},
func(stream recoverproto.API_RecoverServer) error {
return stream.Send(&recoverproto.RecoverResponse{
DiskUuid: "00000000-0000-0000-0000-000000000000",
})
},
func(stream recoverproto.API_RecoverServer) error {
_, err := stream.Recv()
return err
},
func(stream recoverproto.API_RecoverServer) error {
return someErr
},
}},
}, },
"server responds with error": {
recoveryServer: &stubRecoveryServer{recoverError: errors.New("someErr")},
wantErr: true, wantErr: true,
}, },
} }
@ -359,10 +275,6 @@ func TestDoRecovery(t *testing.T) {
recoverDoer := &recoverDoer{ recoverDoer := &recoverDoer{
dialer: dialer.New(nil, nil, netDialer), dialer: dialer.New(nil, nil, netDialer),
endpoint: addr, endpoint: addr,
measurementSecret: []byte("measurement-secret"),
getDiskKey: func(string) ([]byte, error) {
return []byte("disk-key"), nil
},
log: r.log, log: r.log,
} }
@ -402,23 +314,15 @@ func TestDeriveStateDiskKey(t *testing.T) {
} }
type stubRecoveryServer struct { type stubRecoveryServer struct {
actions [][]func(recoverproto.API_RecoverServer) error recoverError error
calls int
recoverproto.UnimplementedAPIServer recoverproto.UnimplementedAPIServer
} }
func (s *stubRecoveryServer) Recover(stream recoverproto.API_RecoverServer) error { func (s *stubRecoveryServer) Recover(context.Context, *recoverproto.RecoverMessage) (*recoverproto.RecoverResponse, error) {
if s.calls >= len(s.actions) { if s.recoverError != nil {
return status.Error(codes.Unavailable, "server is unavailable") return nil, s.recoverError
} }
s.calls++ return &recoverproto.RecoverResponse{}, nil
for _, action := range s.actions[s.calls-1] {
if err := action(stream); err != nil {
return err
}
}
return nil
} }
type stubDoer struct { type stubDoer struct {
@ -437,4 +341,4 @@ func (d *stubDoer) Do(context.Context) error {
func (d *stubDoer) setDialer(grpcDialer, string) {} func (d *stubDoer) setDialer(grpcDialer, string) {}
func (d *stubDoer) setSecrets(func(string) ([]byte, error), []byte) {} func (d *stubDoer) setURIs(kmsURI, storageURI string) {}

View file

@ -34,6 +34,7 @@ import (
qemucloud "github.com/edgelesssys/constellation/v2/internal/cloud/qemu" qemucloud "github.com/edgelesssys/constellation/v2/internal/cloud/qemu"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer" "github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
"github.com/edgelesssys/constellation/v2/internal/logger" "github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/internal/role" "github.com/edgelesssys/constellation/v2/internal/role"
tpmClient "github.com/google/go-tpm-tools/client" tpmClient "github.com/google/go-tpm-tools/client"
@ -151,7 +152,7 @@ func main() {
// set up recovery server if control-plane node // set up recovery server if control-plane node
var recoveryServer setup.RecoveryServer var recoveryServer setup.RecoveryServer
if self.Role == role.ControlPlane { if self.Role == role.ControlPlane {
recoveryServer = recoveryserver.New(issuer, log.Named("recoveryServer")) recoveryServer = recoveryserver.New(issuer, kmssetup.KMS, log.Named("recoveryServer"))
} else { } else {
recoveryServer = recoveryserver.NewStub(log.Named("recoveryServer")) recoveryServer = recoveryserver.NewStub(log.Named("recoveryServer"))
} }

View file

@ -13,8 +13,10 @@ import (
"github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto" "github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto"
"github.com/edgelesssys/constellation/v2/internal/atls" "github.com/edgelesssys/constellation/v2/internal/atls"
"github.com/edgelesssys/constellation/v2/internal/crypto"
"github.com/edgelesssys/constellation/v2/internal/grpc/atlscredentials" "github.com/edgelesssys/constellation/v2/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/v2/internal/grpc/grpclog" "github.com/edgelesssys/constellation/v2/internal/grpc/grpclog"
"github.com/edgelesssys/constellation/v2/internal/kms/kms"
"github.com/edgelesssys/constellation/v2/internal/logger" "github.com/edgelesssys/constellation/v2/internal/logger"
"go.uber.org/zap" "go.uber.org/zap"
"google.golang.org/grpc" "google.golang.org/grpc"
@ -22,6 +24,8 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
type kmsFactory func(ctx context.Context, storageURI string, kmsURI string) (kms.CloudKMS, error)
// RecoveryServer is a gRPC server that can be used by an admin to recover a restarting node. // RecoveryServer is a gRPC server that can be used by an admin to recover a restarting node.
type RecoveryServer struct { type RecoveryServer struct {
mux sync.Mutex mux sync.Mutex
@ -30,6 +34,7 @@ type RecoveryServer struct {
stateDiskKey []byte stateDiskKey []byte
measurementSecret []byte measurementSecret []byte
grpcServer server grpcServer server
factory kmsFactory
log *logger.Logger log *logger.Logger
@ -37,9 +42,10 @@ type RecoveryServer struct {
} }
// New returns a new RecoveryServer. // New returns a new RecoveryServer.
func New(issuer atls.Issuer, log *logger.Logger) *RecoveryServer { func New(issuer atls.Issuer, factory kmsFactory, log *logger.Logger) *RecoveryServer {
server := &RecoveryServer{ server := &RecoveryServer{
log: log, log: log,
factory: factory,
} }
grpcServer := grpc.NewServer( grpcServer := grpc.NewServer(
@ -87,47 +93,32 @@ func (s *RecoveryServer) Serve(ctx context.Context, listener net.Listener, diskU
} }
// Recover is a bidirectional streaming RPC that is used to send recovery keys to a restarting node. // Recover is a bidirectional streaming RPC that is used to send recovery keys to a restarting node.
func (s *RecoveryServer) Recover(stream recoverproto.API_RecoverServer) error { func (s *RecoveryServer) Recover(ctx context.Context, req *recoverproto.RecoverMessage) (*recoverproto.RecoverResponse, error) {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
log := s.log.With(zap.String("peer", grpclog.PeerAddrFromContext(stream.Context()))) log := s.log.With(zap.String("peer", grpclog.PeerAddrFromContext(ctx)))
log.Infof("Received recover call") log.Infof("Received recover call")
msg, err := stream.Recv() cloudKms, err := s.factory(ctx, req.StorageUri, req.KmsUri)
if err != nil { if err != nil {
return status.Error(codes.Internal, "failed to receive message") return nil, status.Errorf(codes.Internal, "creating kms client: %s", err)
} }
measurementSecret, ok := msg.GetRequest().(*recoverproto.RecoverMessage_MeasurementSecret) measurementSecret, err := cloudKms.GetDEK(ctx, crypto.DEKPrefix+crypto.MeasurementSecretKeyID, crypto.DerivedKeyLengthDefault)
if !ok {
log.Errorf("Received invalid first message: not a measurement secret")
return status.Error(codes.InvalidArgument, "first message is not a measurement secret")
}
if err := stream.Send(&recoverproto.RecoverResponse{DiskUuid: s.diskUUID}); err != nil {
log.With(zap.Error(err)).Errorf("Failed to send disk UUID")
return status.Error(codes.Internal, "failed to send response")
}
msg, err = stream.Recv()
if err != nil { if err != nil {
log.With(zap.Error(err)).Errorf("Failed to receive disk key") return nil, status.Errorf(codes.Internal, "requesting measurementSecret: %s", err)
return status.Error(codes.Internal, "failed to receive message")
} }
stateDiskKey, err := cloudKms.GetDEK(ctx, crypto.DEKPrefix+s.diskUUID, crypto.StateDiskKeyLength)
stateDiskKey, ok := msg.GetRequest().(*recoverproto.RecoverMessage_StateDiskKey) if err != nil {
if !ok { return nil, status.Errorf(codes.Internal, "requesting stateDiskKey: %s", err)
log.Errorf("Received invalid second message: not a state disk key")
return status.Error(codes.InvalidArgument, "second message is not a state disk key")
} }
s.stateDiskKey = stateDiskKey
s.stateDiskKey = stateDiskKey.StateDiskKey s.measurementSecret = measurementSecret
s.measurementSecret = measurementSecret.MeasurementSecret
log.Infof("Received state disk key and measurement secret, shutting down server") log.Infof("Received state disk key and measurement secret, shutting down server")
go s.grpcServer.GracefulStop() go s.grpcServer.GracefulStop()
return nil return &recoverproto.RecoverResponse{}, nil
} }
// StubServer implements the RecoveryServer interface but does not actually start a server. // StubServer implements the RecoveryServer interface but does not actually start a server.

View file

@ -8,7 +8,7 @@ package recoveryserver
import ( import (
"context" "context"
"io" "errors"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -17,6 +17,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/atls" "github.com/edgelesssys/constellation/v2/internal/atls"
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer" "github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
"github.com/edgelesssys/constellation/v2/internal/grpc/testdialer" "github.com/edgelesssys/constellation/v2/internal/grpc/testdialer"
"github.com/edgelesssys/constellation/v2/internal/kms/kms"
"github.com/edgelesssys/constellation/v2/internal/logger" "github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/internal/oid" "github.com/edgelesssys/constellation/v2/internal/oid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -25,14 +26,17 @@ import (
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
goleak.VerifyTestMain(m) goleak.VerifyTestMain(m,
// https://github.com/census-instrumentation/opencensus-go/issues/1262
goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
)
} }
func TestServe(t *testing.T) { func TestServe(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
log := logger.NewTest(t) log := logger.NewTest(t)
uuid := "uuid" uuid := "uuid"
server := New(atls.NewFakeIssuer(oid.Dummy{}), log) server := New(atls.NewFakeIssuer(oid.Dummy{}), newStubKMS(nil, nil), log)
dialer := testdialer.NewBufconnDialer() dialer := testdialer.NewBufconnDialer()
listener := dialer.GetListener("192.0.2.1:1234") listener := dialer.GetListener("192.0.2.1:1234")
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@ -49,7 +53,7 @@ func TestServe(t *testing.T) {
cancel() cancel()
wg.Wait() wg.Wait()
server = New(atls.NewFakeIssuer(oid.Dummy{}), log) server = New(atls.NewFakeIssuer(oid.Dummy{}), newStubKMS(nil, nil), log)
dialer = testdialer.NewBufconnDialer() dialer = testdialer.NewBufconnDialer()
listener = dialer.GetListener("192.0.2.1:1234") listener = dialer.GetListener("192.0.2.1:1234")
@ -71,60 +75,27 @@ func TestServe(t *testing.T) {
func TestRecover(t *testing.T) { func TestRecover(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
initialMsg message kmsURI string
keyMsg message storageURI string
factory kmsFactory
wantErr bool wantErr bool
}{ }{
"success": { "success": {
initialMsg: message{ // base64 encoded: key=masterkey&salt=somesalt
recoverMsg: &recoverproto.RecoverMessage{ kmsURI: "kms://cluster-kms?key=bWFzdGVya2V5&salt=c29tZXNhbHQ=",
Request: &recoverproto.RecoverMessage_MeasurementSecret{ storageURI: "storage://no-store",
MeasurementSecret: []byte("measurementSecret"), factory: newStubKMS(nil, nil),
},
},
},
keyMsg: message{
recoverMsg: &recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_StateDiskKey{
StateDiskKey: []byte("diskKey"),
},
},
},
},
"first message is not a measurement secret": {
initialMsg: message{
recoverMsg: &recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_StateDiskKey{
StateDiskKey: []byte("diskKey"),
},
}, },
"kms init fails": {
factory: newStubKMS(errors.New("setup failed"), nil),
wantErr: true, wantErr: true,
}, },
keyMsg: message{ "GetDEK fails": {
recoverMsg: &recoverproto.RecoverMessage{ kmsURI: "kms://cluster-kms?key=bWFzdGVya2V5&salt=c29tZXNhbHQ=",
Request: &recoverproto.RecoverMessage_StateDiskKey{ storageURI: "storage://no-store",
StateDiskKey: []byte("diskKey"), factory: newStubKMS(nil, errors.New("GetDEK failed")),
},
},
},
},
"second message is not a state disk key": {
initialMsg: message{
recoverMsg: &recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_MeasurementSecret{
MeasurementSecret: []byte("measurementSecret"),
},
},
},
keyMsg: message{
recoverMsg: &recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_MeasurementSecret{
MeasurementSecret: []byte("measurementSecret"),
},
},
wantErr: true, wantErr: true,
}, },
},
} }
for name, tc := range testCases { for name, tc := range testCases {
@ -134,7 +105,7 @@ func TestRecover(t *testing.T) {
ctx := context.Background() ctx := context.Background()
serverUUID := "uuid" serverUUID := "uuid"
server := New(atls.NewFakeIssuer(oid.Dummy{}), logger.NewTest(t)) server := New(atls.NewFakeIssuer(oid.Dummy{}), tc.factory, logger.NewTest(t))
netDialer := testdialer.NewBufconnDialer() netDialer := testdialer.NewBufconnDialer()
listener := netDialer.GetListener("192.0.2.1:1234") listener := netDialer.GetListener("192.0.2.1:1234")
@ -154,41 +125,46 @@ func TestRecover(t *testing.T) {
conn, err := dialer.New(nil, nil, netDialer).Dial(ctx, "192.0.2.1:1234") conn, err := dialer.New(nil, nil, netDialer).Dial(ctx, "192.0.2.1:1234")
require.NoError(err) require.NoError(err)
defer conn.Close() defer conn.Close()
client, err := recoverproto.NewAPIClient(conn).Recover(ctx)
require.NoError(err)
// Send initial message req := recoverproto.RecoverMessage{
err = client.Send(tc.initialMsg.recoverMsg) KmsUri: tc.kmsURI,
require.NoError(err) StorageUri: tc.storageURI,
}
_, err = recoverproto.NewAPIClient(conn).Recover(ctx, &req)
// Receive uuid if tc.wantErr {
uuid, err := client.Recv()
if tc.initialMsg.wantErr {
assert.Error(err) assert.Error(err)
return return
} }
assert.Equal(serverUUID, uuid.DiskUuid)
// Send key message
err = client.Send(tc.keyMsg.recoverMsg)
require.NoError(err)
_, err = client.Recv()
if tc.keyMsg.wantErr {
assert.Error(err)
return
}
assert.ErrorIs(io.EOF, err)
wg.Wait() wg.Wait()
assert.NoError(serveErr) require.NoError(serveErr)
assert.Equal(tc.initialMsg.recoverMsg.GetMeasurementSecret(), measurementSecret) assert.NoError(err)
assert.Equal(tc.keyMsg.recoverMsg.GetStateDiskKey(), diskKey) assert.NotNil(measurementSecret)
assert.NotNil(diskKey)
}) })
} }
} }
type message struct { func newStubKMS(setupErr, getDEKErr error) kmsFactory {
recoverMsg *recoverproto.RecoverMessage return func(ctx context.Context, storageURI string, kmsURI string) (kms.CloudKMS, error) {
wantErr bool if setupErr != nil {
return nil, setupErr
}
return &stubKMS{getDEKErr: getDEKErr}, nil
}
}
type stubKMS struct {
getDEKErr error
}
func (s *stubKMS) CreateKEK(ctx context.Context, keyID string, kek []byte) error {
return nil
}
func (s *stubKMS) GetDEK(ctx context.Context, dekID string, dekSize int) ([]byte, error) {
if s.getDEKErr != nil {
return nil, s.getDEKErr
}
return []byte("someDEK"), nil
} }

View file

@ -25,11 +25,10 @@ type RecoverMessage struct {
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
// Types that are assignable to Request: // bytes state_disk_key = 1; removed
// // bytes measurement_secret = 2; removed
// *RecoverMessage_StateDiskKey KmsUri string `protobuf:"bytes,3,opt,name=kms_uri,json=kmsUri,proto3" json:"kms_uri,omitempty"`
// *RecoverMessage_MeasurementSecret StorageUri string `protobuf:"bytes,4,opt,name=storage_uri,json=storageUri,proto3" json:"storage_uri,omitempty"`
Request isRecoverMessage_Request `protobuf_oneof:"request"`
} }
func (x *RecoverMessage) Reset() { func (x *RecoverMessage) Reset() {
@ -64,49 +63,24 @@ func (*RecoverMessage) Descriptor() ([]byte, []int) {
return file_recover_proto_rawDescGZIP(), []int{0} return file_recover_proto_rawDescGZIP(), []int{0}
} }
func (m *RecoverMessage) GetRequest() isRecoverMessage_Request { func (x *RecoverMessage) GetKmsUri() string {
if m != nil { if x != nil {
return m.Request return x.KmsUri
} }
return nil return ""
} }
func (x *RecoverMessage) GetStateDiskKey() []byte { func (x *RecoverMessage) GetStorageUri() string {
if x, ok := x.GetRequest().(*RecoverMessage_StateDiskKey); ok { if x != nil {
return x.StateDiskKey return x.StorageUri
} }
return nil return ""
} }
func (x *RecoverMessage) GetMeasurementSecret() []byte {
if x, ok := x.GetRequest().(*RecoverMessage_MeasurementSecret); ok {
return x.MeasurementSecret
}
return nil
}
type isRecoverMessage_Request interface {
isRecoverMessage_Request()
}
type RecoverMessage_StateDiskKey struct {
StateDiskKey []byte `protobuf:"bytes,1,opt,name=state_disk_key,json=stateDiskKey,proto3,oneof"`
}
type RecoverMessage_MeasurementSecret struct {
MeasurementSecret []byte `protobuf:"bytes,2,opt,name=measurement_secret,json=measurementSecret,proto3,oneof"`
}
func (*RecoverMessage_StateDiskKey) isRecoverMessage_Request() {}
func (*RecoverMessage_MeasurementSecret) isRecoverMessage_Request() {}
type RecoverResponse struct { type RecoverResponse struct {
state protoimpl.MessageState state protoimpl.MessageState
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
DiskUuid string `protobuf:"bytes,1,opt,name=disk_uuid,json=diskUuid,proto3" json:"disk_uuid,omitempty"`
} }
func (x *RecoverResponse) Reset() { func (x *RecoverResponse) Reset() {
@ -141,39 +115,27 @@ func (*RecoverResponse) Descriptor() ([]byte, []int) {
return file_recover_proto_rawDescGZIP(), []int{1} return file_recover_proto_rawDescGZIP(), []int{1}
} }
func (x *RecoverResponse) GetDiskUuid() string {
if x != nil {
return x.DiskUuid
}
return ""
}
var File_recover_proto protoreflect.FileDescriptor var File_recover_proto protoreflect.FileDescriptor
var file_recover_proto_rawDesc = []byte{ var file_recover_proto_rawDesc = []byte{
0x0a, 0x0d, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0a, 0x0d, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12,
0x0c, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x74, 0x0a, 0x0c, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x4a, 0x0a,
0x0e, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x0e, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12,
0x26, 0x0a, 0x0e, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5f, 0x64, 0x69, 0x73, 0x6b, 0x5f, 0x6b, 0x65, 0x17, 0x0a, 0x07, 0x6b, 0x6d, 0x73, 0x5f, 0x75, 0x72, 0x69, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09,
0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x48, 0x00, 0x52, 0x0c, 0x73, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x6b, 0x6d, 0x73, 0x55, 0x72, 0x69, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x74, 0x6f, 0x72,
0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79, 0x12, 0x2f, 0x0a, 0x12, 0x6d, 0x65, 0x61, 0x73, 0x75, 0x61, 0x67, 0x65, 0x5f, 0x75, 0x72, 0x69, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x73,
0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x74, 0x6f, 0x72, 0x61, 0x67, 0x65, 0x55, 0x72, 0x69, 0x22, 0x11, 0x0a, 0x0f, 0x52, 0x65, 0x63,
0x01, 0x28, 0x0c, 0x48, 0x00, 0x52, 0x11, 0x6d, 0x65, 0x61, 0x73, 0x75, 0x72, 0x65, 0x6d, 0x65, 0x6f, 0x76, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x32, 0x4f, 0x0a, 0x03,
0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x42, 0x09, 0x0a, 0x07, 0x72, 0x65, 0x71, 0x75, 0x41, 0x50, 0x49, 0x12, 0x48, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x12, 0x1c,
0x65, 0x73, 0x74, 0x22, 0x2e, 0x0a, 0x0f, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x52, 0x65, 0x2e, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65,
0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x64, 0x69, 0x73, 0x6b, 0x5f, 0x75, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1d, 0x2e, 0x72,
0x75, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x64, 0x69, 0x73, 0x6b, 0x55, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f,
0x75, 0x69, 0x64, 0x32, 0x53, 0x0a, 0x03, 0x41, 0x50, 0x49, 0x12, 0x4c, 0x0a, 0x07, 0x52, 0x65, 0x76, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x42, 0x5a,
0x63, 0x6f, 0x76, 0x65, 0x72, 0x12, 0x1c, 0x2e, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x40, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x65, 0x64, 0x67, 0x65,
0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x6c, 0x65, 0x73, 0x73, 0x73, 0x79, 0x73, 0x2f, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x65, 0x6c, 0x6c,
0x61, 0x67, 0x65, 0x1a, 0x1d, 0x2e, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x76, 0x32, 0x2f, 0x64, 0x69, 0x73, 0x6b, 0x2d, 0x6d, 0x61,
0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x70, 0x70, 0x65, 0x72, 0x2f, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f, 0x74,
0x73, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x42, 0x5a, 0x40, 0x67, 0x69, 0x74, 0x68, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x65, 0x64, 0x67, 0x65, 0x6c, 0x65, 0x73, 0x73, 0x73,
0x79, 0x73, 0x2f, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x65, 0x6c, 0x6c, 0x61, 0x74, 0x69, 0x6f, 0x6e,
0x2f, 0x76, 0x32, 0x2f, 0x64, 0x69, 0x73, 0x6b, 0x2d, 0x6d, 0x61, 0x70, 0x70, 0x65, 0x72, 0x2f,
0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x33,
} }
var ( var (
@ -234,10 +196,6 @@ func file_recover_proto_init() {
} }
} }
} }
file_recover_proto_msgTypes[0].OneofWrappers = []interface{}{
(*RecoverMessage_StateDiskKey)(nil),
(*RecoverMessage_MeasurementSecret)(nil),
}
type x struct{} type x struct{}
out := protoimpl.TypeBuilder{ out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{ File: protoimpl.DescBuilder{

View file

@ -5,16 +5,19 @@ package recoverproto;
option go_package = "github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto"; option go_package = "github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto";
service API { service API {
rpc Recover(stream RecoverMessage) returns (stream RecoverResponse) {} // Recover sends the necessary information to the recoveryserver to start recovering the cluster.
rpc Recover(RecoverMessage) returns (RecoverResponse) {}
} }
message RecoverMessage { message RecoverMessage {
oneof request { // bytes state_disk_key = 1; removed
bytes state_disk_key = 1; // bytes measurement_secret = 2; removed
bytes measurement_secret = 2; // kms_uri is the URI of the KMS the recoveryserver should use to decrypt DEKs.
} string kms_uri = 3;
// storage_uri is the URI of the storage location the recoveryserver should use to fetch DEKs.
string storage_uri = 4;
} }
message RecoverResponse { message RecoverResponse {
string disk_uuid = 1; // string disk_uuid = 1; removed
} }

View file

@ -22,7 +22,7 @@ const _ = grpc.SupportPackageIsVersion7
// //
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type APIClient interface { type APIClient interface {
Recover(ctx context.Context, opts ...grpc.CallOption) (API_RecoverClient, error) Recover(ctx context.Context, in *RecoverMessage, opts ...grpc.CallOption) (*RecoverResponse, error)
} }
type aPIClient struct { type aPIClient struct {
@ -33,42 +33,20 @@ func NewAPIClient(cc grpc.ClientConnInterface) APIClient {
return &aPIClient{cc} return &aPIClient{cc}
} }
func (c *aPIClient) Recover(ctx context.Context, opts ...grpc.CallOption) (API_RecoverClient, error) { func (c *aPIClient) Recover(ctx context.Context, in *RecoverMessage, opts ...grpc.CallOption) (*RecoverResponse, error) {
stream, err := c.cc.NewStream(ctx, &API_ServiceDesc.Streams[0], "/recoverproto.API/Recover", opts...) out := new(RecoverResponse)
err := c.cc.Invoke(ctx, "/recoverproto.API/Recover", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
x := &aPIRecoverClient{stream} return out, nil
return x, nil
}
type API_RecoverClient interface {
Send(*RecoverMessage) error
Recv() (*RecoverResponse, error)
grpc.ClientStream
}
type aPIRecoverClient struct {
grpc.ClientStream
}
func (x *aPIRecoverClient) Send(m *RecoverMessage) error {
return x.ClientStream.SendMsg(m)
}
func (x *aPIRecoverClient) Recv() (*RecoverResponse, error) {
m := new(RecoverResponse)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
} }
// APIServer is the server API for API service. // APIServer is the server API for API service.
// All implementations must embed UnimplementedAPIServer // All implementations must embed UnimplementedAPIServer
// for forward compatibility // for forward compatibility
type APIServer interface { type APIServer interface {
Recover(API_RecoverServer) error Recover(context.Context, *RecoverMessage) (*RecoverResponse, error)
mustEmbedUnimplementedAPIServer() mustEmbedUnimplementedAPIServer()
} }
@ -76,8 +54,8 @@ type APIServer interface {
type UnimplementedAPIServer struct { type UnimplementedAPIServer struct {
} }
func (UnimplementedAPIServer) Recover(API_RecoverServer) error { func (UnimplementedAPIServer) Recover(context.Context, *RecoverMessage) (*RecoverResponse, error) {
return status.Errorf(codes.Unimplemented, "method Recover not implemented") return nil, status.Errorf(codes.Unimplemented, "method Recover not implemented")
} }
func (UnimplementedAPIServer) mustEmbedUnimplementedAPIServer() {} func (UnimplementedAPIServer) mustEmbedUnimplementedAPIServer() {}
@ -92,30 +70,22 @@ func RegisterAPIServer(s grpc.ServiceRegistrar, srv APIServer) {
s.RegisterService(&API_ServiceDesc, srv) s.RegisterService(&API_ServiceDesc, srv)
} }
func _API_Recover_Handler(srv interface{}, stream grpc.ServerStream) error { func _API_Recover_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
return srv.(APIServer).Recover(&aPIRecoverServer{stream}) in := new(RecoverMessage)
} if err := dec(in); err != nil {
type API_RecoverServer interface {
Send(*RecoverResponse) error
Recv() (*RecoverMessage, error)
grpc.ServerStream
}
type aPIRecoverServer struct {
grpc.ServerStream
}
func (x *aPIRecoverServer) Send(m *RecoverResponse) error {
return x.ServerStream.SendMsg(m)
}
func (x *aPIRecoverServer) Recv() (*RecoverMessage, error) {
m := new(RecoverMessage)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err return nil, err
} }
return m, nil if interceptor == nil {
return srv.(APIServer).Recover(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/recoverproto.API/Recover",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(APIServer).Recover(ctx, req.(*RecoverMessage))
}
return interceptor(ctx, in, info, handler)
} }
// API_ServiceDesc is the grpc.ServiceDesc for API service. // API_ServiceDesc is the grpc.ServiceDesc for API service.
@ -124,14 +94,12 @@ func (x *aPIRecoverServer) Recv() (*RecoverMessage, error) {
var API_ServiceDesc = grpc.ServiceDesc{ var API_ServiceDesc = grpc.ServiceDesc{
ServiceName: "recoverproto.API", ServiceName: "recoverproto.API",
HandlerType: (*APIServer)(nil), HandlerType: (*APIServer)(nil),
Methods: []grpc.MethodDesc{}, Methods: []grpc.MethodDesc{
Streams: []grpc.StreamDesc{
{ {
StreamName: "Recover", MethodName: "Recover",
Handler: _API_Recover_Handler, Handler: _API_Recover_Handler,
ServerStreams: true,
ClientStreams: true,
}, },
}, },
Streams: []grpc.StreamDesc{},
Metadata: "recover.proto", Metadata: "recover.proto",
} }

View file

@ -20,10 +20,5 @@ const (
// DeriveClusterID derives the cluster ID from a salt and secret value. // DeriveClusterID derives the cluster ID from a salt and secret value.
func DeriveClusterID(secret, salt []byte) ([]byte, error) { func DeriveClusterID(secret, salt []byte) ([]byte, error) {
return crypto.DeriveKey(secret, salt, []byte(crypto.HKDFInfoPrefix+clusterIDContext), crypto.DerivedKeyLengthDefault) return crypto.DeriveKey(secret, salt, []byte(crypto.DEKPrefix+clusterIDContext), crypto.DerivedKeyLengthDefault)
}
// DeriveMeasurementSecret derives the secret value needed to derive ClusterID.
func DeriveMeasurementSecret(masterSecret, salt []byte) ([]byte, error) {
return crypto.DeriveKey(masterSecret, salt, []byte(crypto.HKDFInfoPrefix+MeasurementSecretContext), crypto.DerivedKeyLengthDefault)
} }

View file

@ -31,21 +31,3 @@ func TestDeriveClusterID(t *testing.T) {
require.NoError(err) require.NoError(err)
assert.NotEqual(clusterID, clusterIDdiff) assert.NotEqual(clusterID, clusterIDdiff)
} }
func TestDeriveMeasurementSecret(t *testing.T) {
require := require.New(t)
assert := assert.New(t)
testvector := testvector.HKDFMeasurementSecret
measurementSecret, err := DeriveMeasurementSecret(testvector.Secret, testvector.Salt)
require.NoError(err)
assert.Equal(testvector.Output, measurementSecret)
measurementSecretdiff, err := DeriveMeasurementSecret(testvector.Secret, []byte("different-salt"))
require.NoError(err)
assert.NotEqual(measurementSecret, measurementSecretdiff)
measurementSecretdiff, err = DeriveMeasurementSecret([]byte("different-secret"), testvector.Salt)
require.NoError(err)
assert.NotEqual(measurementSecret, measurementSecretdiff)
}

View file

@ -30,8 +30,10 @@ const (
MasterSecretLengthMin = 16 MasterSecretLengthMin = 16
// RNGLengthDefault is the number of bytes used for generating nonces. // RNGLengthDefault is the number of bytes used for generating nonces.
RNGLengthDefault = 32 RNGLengthDefault = 32
// HKDFInfoPrefix is the prefix used for the info parameter in HKDF. // DEKPrefix is the prefix used to prefix DEK IDs. Originally introduced as a requirement for the HKDF info parameter.
HKDFInfoPrefix = "key-" DEKPrefix = "key-"
// MeasurementSecretKeyID is name used for the measurementSecret DEK.
MeasurementSecretKeyID = "measurementSecret"
) )
// DeriveKey derives a key from a secret. // DeriveKey derives a key from a secret.

View file

@ -56,13 +56,14 @@ type KMSClient struct {
awsClient ClientAPI awsClient ClientAPI
policyProducer KeyPolicyProducer policyProducer KeyPolicyProducer
storage kmsInterface.Storage storage kmsInterface.Storage
kekID string
} }
// New creates and initializes a new KMSClient for AWS. // New creates and initializes a new KMSClient for AWS.
// //
// The parameter client needs to be initialized with valid AWS credentials (https://aws.github.io/aws-sdk-go-v2/docs/getting-started). // The parameter client needs to be initialized with valid AWS credentials (https://aws.github.io/aws-sdk-go-v2/docs/getting-started).
// If storage is nil, the default MemMapStorage is used. // If storage is nil, the default MemMapStorage is used.
func New(ctx context.Context, policyProducer KeyPolicyProducer, store kmsInterface.Storage, optFns ...func(*awsconfig.LoadOptions) error) (*KMSClient, error) { func New(ctx context.Context, policyProducer KeyPolicyProducer, store kmsInterface.Storage, kekID string, optFns ...func(*awsconfig.LoadOptions) error) (*KMSClient, error) {
if store == nil { if store == nil {
store = storage.NewMemMapStorage() store = storage.NewMemMapStorage()
} }
@ -77,6 +78,7 @@ func New(ctx context.Context, policyProducer KeyPolicyProducer, store kmsInterfa
awsClient: client, awsClient: client,
policyProducer: policyProducer, policyProducer: policyProducer,
storage: store, storage: store,
kekID: kekID,
}, nil }, nil
} }
@ -206,9 +208,9 @@ func (c *KMSClient) CreateKEK(ctx context.Context, keyID string, key []byte) err
} }
// GetDEK returns the DEK for dekID and kekID from the KMS. // GetDEK returns the DEK for dekID and kekID from the KMS.
func (c *KMSClient) GetDEK(ctx context.Context, kekID, keyID string, dekSize int) ([]byte, error) { func (c *KMSClient) GetDEK(ctx context.Context, keyID string, dekSize int) ([]byte, error) {
// The KEK should be identified by its alias. The alias always has the same scheme: 'alias/<kekId>' // The KEK should be identified by its alias. The alias always has the same scheme: 'alias/<kekId>'
kekID = "alias/" + kekID kekID := "alias/" + c.kekID
// If a key for keyID exists in the storage, decrypt the key using the KEK. // If a key for keyID exists in the storage, decrypt the key using the KEK.
dek, err := c.decryptDEKFromStorage(ctx, kekID, keyID) dek, err := c.decryptDEKFromStorage(ctx, kekID, keyID)

View file

@ -253,22 +253,23 @@ func TestAWSKMSClient(t *testing.T) {
awsClient := &fakeAWSClient{kekPool: make(map[string][]byte, 2)} awsClient := &fakeAWSClient{kekPool: make(map[string][]byte, 2)}
client := &KMSClient{
awsClient: awsClient,
policyProducer: &stubKeyPolicyProducer{},
storage: storage.NewMemMapStorage(),
}
awsClient.keyIDCount = -1
testKEK1ID := "testKEK1" testKEK1ID := "testKEK1"
testKEK1 := []byte("test KEK") testKEK1 := []byte("test KEK")
testKEK2ID := "testKEK2" testKEK2ID := "testKEK2"
testKEK2 := []byte("more test KEK") testKEK2 := []byte("more test KEK")
client := &KMSClient{
awsClient: awsClient,
policyProducer: &stubKeyPolicyProducer{},
storage: storage.NewMemMapStorage(),
kekID: testKEK1ID,
}
awsClient.keyIDCount = -1
ctx := context.Background() ctx := context.Background()
// try to get a DEK before setting the KEK // try to get a DEK before setting the KEK
_, err := client.GetDEK(ctx, testKEK1ID, "volume01", config.SymmetricKeyLength) _, err := client.GetDEK(ctx, "volume01", config.SymmetricKeyLength)
assert.Error(err) assert.Error(err)
assert.ErrorIs(err, kmsInterface.ErrKEKUnknown) assert.ErrorIs(err, kmsInterface.ErrKEKUnknown)
@ -285,16 +286,16 @@ func TestAWSKMSClient(t *testing.T) {
assert.Equal(testKEK2, awsClient.kekPool[strconv.Itoa(awsClient.keyIDCount)]) assert.Equal(testKEK2, awsClient.kekPool[strconv.Itoa(awsClient.keyIDCount)])
// test GetDEK method // test GetDEK method
dek1, err := client.GetDEK(ctx, testKEK1ID, "volume01", config.SymmetricKeyLength) dek1, err := client.GetDEK(ctx, "volume01", config.SymmetricKeyLength)
assert.NoError(err) assert.NoError(err)
dek2, err := client.GetDEK(ctx, testKEK2ID, "volume02", config.SymmetricKeyLength) dek2, err := client.GetDEK(ctx, "volume02", config.SymmetricKeyLength)
assert.NoError(err) assert.NoError(err)
// make sure that GetDEK is idempotent // make sure that GetDEK is idempotent
dek1Copy, err := client.GetDEK(ctx, testKEK1ID, "volume01", config.SymmetricKeyLength) dek1Copy, err := client.GetDEK(ctx, "volume01", config.SymmetricKeyLength)
assert.NoError(err) assert.NoError(err)
assert.Equal(dek1, dek1Copy) assert.Equal(dek1, dek1Copy)
dek2Copy, err := client.GetDEK(ctx, testKEK2ID, "volume02", config.SymmetricKeyLength) dek2Copy, err := client.GetDEK(ctx, "volume02", config.SymmetricKeyLength)
assert.NoError(err) assert.NoError(err)
assert.Equal(dek2, dek2Copy) assert.Equal(dek2, dek2Copy)
} }

View file

@ -47,6 +47,7 @@ type kmsClientAPI interface {
type KMSClient struct { type KMSClient struct {
client kmsClientAPI client kmsClientAPI
storage kms.Storage storage kms.Storage
kekID string
} }
// Opts are optional settings for AKV clients. // Opts are optional settings for AKV clients.
@ -57,7 +58,7 @@ type Opts struct {
} }
// New initializes a KMS client for Azure Key Vault. // New initializes a KMS client for Azure Key Vault.
func New(ctx context.Context, vaultName string, vaultType VaultSuffix, store kms.Storage, opts *Opts) (*KMSClient, error) { func New(ctx context.Context, vaultName string, vaultType VaultSuffix, store kms.Storage, kekID string, opts *Opts) (*KMSClient, error) {
if opts == nil { if opts == nil {
opts = &Opts{} opts = &Opts{}
} }
@ -80,7 +81,7 @@ func New(ctx context.Context, vaultName string, vaultType VaultSuffix, store kms
if store == nil { if store == nil {
store = storage.NewMemMapStorage() store = storage.NewMemMapStorage()
} }
return &KMSClient{client: client, storage: store}, nil return &KMSClient{client: client, storage: store, kekID: kekID}, nil
} }
// CreateKEK saves a new Key Encryption Key using Azure Key Vault. // CreateKEK saves a new Key Encryption Key using Azure Key Vault.
@ -111,8 +112,8 @@ func (c *KMSClient) CreateKEK(ctx context.Context, keyID string, key []byte) err
} }
// GetDEK decrypts a DEK from storage. // GetDEK decrypts a DEK from storage.
func (c *KMSClient) GetDEK(ctx context.Context, kekID, keyID string, dekSize int) ([]byte, error) { func (c *KMSClient) GetDEK(ctx context.Context, keyID string, dekSize int) ([]byte, error) {
kek, err := c.getKEK(ctx, kekID) kek, err := c.getKEK(ctx, c.kekID)
if err != nil { if err != nil {
return nil, fmt.Errorf("loading KEK from key vault: %w", err) return nil, fmt.Errorf("loading KEK from key vault: %w", err)
} }

View file

@ -155,7 +155,7 @@ func TestKMSGetDEK(t *testing.T) {
storage: tc.storage, storage: tc.storage,
} }
dek, err := client.GetDEK(context.Background(), "test-key", "volume-01", 32) dek, err := client.GetDEK(context.Background(), "volume-01", 32)
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
} else { } else {

View file

@ -38,10 +38,11 @@ type HSMClient struct {
client hsmClientAPI client hsmClientAPI
storage kms.Storage storage kms.Storage
vaultURL string vaultURL string
kekID string
} }
// NewHSM initializes a KMS client for Azure manged HSM Key Vault. // NewHSM initializes a KMS client for Azure manged HSM Key Vault.
func NewHSM(ctx context.Context, vaultName string, store kms.Storage, opts *Opts) (*HSMClient, error) { func NewHSM(ctx context.Context, vaultName string, store kms.Storage, kekID string, opts *Opts) (*HSMClient, error) {
if opts == nil { if opts == nil {
opts = &Opts{} opts = &Opts{}
} }
@ -72,6 +73,7 @@ func NewHSM(ctx context.Context, vaultName string, store kms.Storage, opts *Opts
client: client, client: client,
credentials: cred, credentials: cred,
storage: store, storage: store,
kekID: kekID,
}, nil }, nil
} }
@ -114,7 +116,7 @@ func (c *HSMClient) CreateKEK(ctx context.Context, keyID string, key []byte) err
} }
// GetDEK loads an encrypted DEK from storage and unwraps it using an HSM-backed key. // GetDEK loads an encrypted DEK from storage and unwraps it using an HSM-backed key.
func (c *HSMClient) GetDEK(ctx context.Context, kekID string, keyID string, dekSize int) ([]byte, error) { func (c *HSMClient) GetDEK(ctx context.Context, keyID string, dekSize int) ([]byte, error) {
encryptedDEK, err := c.storage.Get(ctx, keyID) encryptedDEK, err := c.storage.Get(ctx, keyID)
if err != nil { if err != nil {
if !errors.Is(err, storage.ErrDEKUnset) { if !errors.Is(err, storage.ErrDEKUnset) {
@ -126,7 +128,7 @@ func (c *HSMClient) GetDEK(ctx context.Context, kekID string, keyID string, dekS
if err != nil { if err != nil {
return nil, fmt.Errorf("key generation: %w", err) return nil, fmt.Errorf("key generation: %w", err)
} }
if err := c.putDEK(ctx, kekID, keyID, newDEK); err != nil { if err := c.putDEK(ctx, c.kekID, keyID, newDEK); err != nil {
return nil, fmt.Errorf("creating new DEK: %w", err) return nil, fmt.Errorf("creating new DEK: %w", err)
} }
@ -137,7 +139,7 @@ func (c *HSMClient) GetDEK(ctx context.Context, kekID string, keyID string, dekS
Algorithm: to.Ptr(azkeys.JSONWebKeyEncryptionAlgorithmA256KW), Algorithm: to.Ptr(azkeys.JSONWebKeyEncryptionAlgorithmA256KW),
Value: encryptedDEK, Value: encryptedDEK,
} }
res, err := c.client.UnwrapKey(ctx, kekID, "", params, &azkeys.UnwrapKeyOptions{}) res, err := c.client.UnwrapKey(ctx, c.kekID, "", params, &azkeys.UnwrapKeyOptions{})
if err != nil { if err != nil {
return nil, fmt.Errorf("unwrapping key: %w", err) return nil, fmt.Errorf("unwrapping key: %w", err)
} }

View file

@ -165,7 +165,7 @@ func TestHSMGetNewDEK(t *testing.T) {
storage: tc.storage, storage: tc.storage,
} }
dek, err := client.GetDEK(context.Background(), "test-key", "volume-01", 32) dek, err := client.GetDEK(context.Background(), "volume-01", 32)
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
} else { } else {
@ -208,7 +208,7 @@ func TestHSMGetExistingDEK(t *testing.T) {
storage: storage, storage: storage,
} }
dek, err := client.GetDEK(context.Background(), "test-key", keyID, len(testKey)) dek, err := client.GetDEK(context.Background(), keyID, len(testKey))
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
} else { } else {

View file

@ -20,8 +20,15 @@ type KMS struct {
} }
// New creates a new ClusterKMS. // New creates a new ClusterKMS.
func New(salt []byte) *KMS { func New(key []byte, salt []byte) (*KMS, error) {
return &KMS{salt: salt} if len(key) == 0 {
return nil, errors.New("missing master key")
}
if len(salt) == 0 {
return nil, errors.New("missing salt")
}
return &KMS{masterKey: key, salt: salt}, nil
} }
// CreateKEK sets the ClusterKMS masterKey. // CreateKEK sets the ClusterKMS masterKey.
@ -31,7 +38,7 @@ func (c *KMS) CreateKEK(ctx context.Context, keyID string, kek []byte) error {
} }
// GetDEK derives a key from the KMS masterKey. // GetDEK derives a key from the KMS masterKey.
func (c *KMS) GetDEK(ctx context.Context, kekID string, dekID string, dekSize int) ([]byte, error) { func (c *KMS) GetDEK(ctx context.Context, dekID string, dekSize int) ([]byte, error) {
if len(c.masterKey) == 0 { if len(c.masterKey) == 0 {
return nil, errors.New("master key not set for Constellation KMS") return nil, errors.New("master key not set for Constellation KMS")
} }

View file

@ -24,19 +24,12 @@ func TestMain(m *testing.M) {
func TestClusterKMS(t *testing.T) { func TestClusterKMS(t *testing.T) {
testVector := testvector.HKDF0xFF testVector := testvector.HKDF0xFF
assert := assert.New(t) assert := assert.New(t)
kms := New(testVector.Salt) require := require.New(t)
kms, err := New(testVector.Secret, testVector.Salt)
key, err := kms.GetDEK(context.Background(), "", "key-1", 32) require.NoError(err)
assert.Error(err)
assert.Nil(key)
err = kms.CreateKEK(context.Background(), "", testVector.Secret)
assert.NoError(err)
assert.Equal(testVector.Secret, kms.masterKey)
keyLower, err := kms.GetDEK( keyLower, err := kms.GetDEK(
context.Background(), context.Background(),
"",
strings.ToLower(testVector.InfoPrefix+testVector.Info), strings.ToLower(testVector.InfoPrefix+testVector.Info),
int(testVector.Length), int(testVector.Length),
) )
@ -46,12 +39,11 @@ func TestClusterKMS(t *testing.T) {
// output of the KMS should be case sensitive // output of the KMS should be case sensitive
keyUpper, err := kms.GetDEK( keyUpper, err := kms.GetDEK(
context.Background(), context.Background(),
"",
strings.ToUpper(testVector.InfoPrefix+testVector.Info), strings.ToUpper(testVector.InfoPrefix+testVector.Info),
int(testVector.Length), int(testVector.Length),
) )
assert.NoError(err) assert.NoError(err)
assert.NotEqual(key, keyUpper) assert.NotEqual(keyLower, keyUpper)
} }
func TestVectorsHKDF(t *testing.T) { func TestVectorsHKDF(t *testing.T) {
@ -61,6 +53,7 @@ func TestVectorsHKDF(t *testing.T) {
dekID string dekID string
dekSize uint dekSize uint
wantKey []byte wantKey []byte
wantErr bool
}{ }{
"rfc Test Case 1": { "rfc Test Case 1": {
kek: testvector.HKDFrfc1.Secret, kek: testvector.HKDFrfc1.Secret,
@ -82,6 +75,7 @@ func TestVectorsHKDF(t *testing.T) {
dekID: testvector.HKDFrfc3.Info, dekID: testvector.HKDFrfc3.Info,
dekSize: testvector.HKDFrfc3.Length, dekSize: testvector.HKDFrfc3.Length,
wantKey: testvector.HKDFrfc3.Output, wantKey: testvector.HKDFrfc3.Output,
wantErr: true,
}, },
"HKDF zero": { "HKDF zero": {
kek: testvector.HKDFZero.Secret, kek: testvector.HKDFZero.Secret,
@ -104,10 +98,15 @@ func TestVectorsHKDF(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
require := require.New(t) require := require.New(t)
kms := New(tc.salt) kms, err := New(tc.kek, tc.salt)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
require.NoError(kms.CreateKEK(context.Background(), "", tc.kek)) require.NoError(kms.CreateKEK(context.Background(), "", tc.kek))
out, err := kms.GetDEK(context.Background(), "", tc.dekID, int(tc.dekSize)) out, err := kms.GetDEK(context.Background(), tc.dekID, int(tc.dekSize))
require.NoError(err) require.NoError(err)
assert.Equal(tc.wantKey, out) assert.Equal(tc.wantKey, out)
}) })

View file

@ -50,11 +50,12 @@ type KMSClient struct {
waitBackoffLimit int waitBackoffLimit int
storage kmsInterface.Storage storage kmsInterface.Storage
protectionLevel kmspb.ProtectionLevel protectionLevel kmspb.ProtectionLevel
kekID string
opts []gax.CallOption opts []gax.CallOption
} }
// New initializes a KMS client for Google Cloud Platform. // New initializes a KMS client for Google Cloud Platform.
func New(ctx context.Context, projectID, locationID, keyRingID string, store kmsInterface.Storage, protectionLvl kmspb.ProtectionLevel, opts ...gax.CallOption) (*KMSClient, error) { func New(ctx context.Context, projectID, locationID, keyRingID string, store kmsInterface.Storage, protectionLvl kmspb.ProtectionLevel, kekID string, opts ...gax.CallOption) (*KMSClient, error) {
if store == nil { if store == nil {
store = storage.NewMemMapStorage() store = storage.NewMemMapStorage()
} }
@ -71,6 +72,7 @@ func New(ctx context.Context, projectID, locationID, keyRingID string, store kms
waitBackoffLimit: 10, waitBackoffLimit: 10,
storage: store, storage: store,
protectionLevel: protectionLvl, protectionLevel: protectionLvl,
kekID: kekID,
opts: opts, opts: opts,
} }
@ -108,7 +110,7 @@ func (c *KMSClient) CreateKEK(ctx context.Context, keyID string, key []byte) err
} }
// GetDEK fetches an encrypted Data Encryption Key from storage and decrypts it using a KEK stored in Google's KMS. // GetDEK fetches an encrypted Data Encryption Key from storage and decrypts it using a KEK stored in Google's KMS.
func (c *KMSClient) GetDEK(ctx context.Context, kekID, keyID string, dekSize int) ([]byte, error) { func (c *KMSClient) GetDEK(ctx context.Context, keyID string, dekSize int) ([]byte, error) {
client, err := c.newClient(ctx) client, err := c.newClient(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
@ -126,11 +128,11 @@ func (c *KMSClient) GetDEK(ctx context.Context, kekID, keyID string, dekSize int
if err != nil { if err != nil {
return nil, fmt.Errorf("key generation: %w", err) return nil, fmt.Errorf("key generation: %w", err)
} }
return newDEK, c.putDEK(ctx, client, kekID, keyID, newDEK) return newDEK, c.putDEK(ctx, client, c.kekID, keyID, newDEK)
} }
request := &kmspb.DecryptRequest{ request := &kmspb.DecryptRequest{
Name: fmt.Sprintf("projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s", c.projectID, c.locationID, c.keyRingID, kekID), Name: fmt.Sprintf("projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s", c.projectID, c.locationID, c.keyRingID, c.kekID),
Ciphertext: encryptedDEK, Ciphertext: encryptedDEK,
} }

View file

@ -321,7 +321,7 @@ func TestGetDEK(t *testing.T) {
storage: tc.storage, storage: tc.storage,
} }
dek, err := client.GetDEK(context.Background(), "test-key", "volume-01", 32) dek, err := client.GetDEK(context.Background(), "volume-01", 32)
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
} else { } else {

View file

@ -17,7 +17,7 @@ type CloudKMS interface {
CreateKEK(ctx context.Context, keyID string, kek []byte) error CreateKEK(ctx context.Context, keyID string, kek []byte) error
// GetDEK returns the DEK for dekID and kekID from the KMS. // GetDEK returns the DEK for dekID and kekID from the KMS.
// If the DEK does not exist, a new one is created and saved to storage. // If the DEK does not exist, a new one is created and saved to storage.
GetDEK(ctx context.Context, kekID string, dekID string, dekSize int) ([]byte, error) GetDEK(ctx context.Context, dekID string, dekSize int) ([]byte, error)
} }
// Storage provides an abstract interface for the storage backend used for DEKs. // Storage provides an abstract interface for the storage backend used for DEKs.

View file

@ -24,10 +24,10 @@ import (
// Well known endpoints for KMS services. // Well known endpoints for KMS services.
const ( const (
AWSKMSURI = "kms://aws?keyPolicy=%s" AWSKMSURI = "kms://aws?keyPolicy=%s&kekID=%s"
AzureKMSURI = "kms://azure-kms?name=%s&type=%s" AzureKMSURI = "kms://azure-kms?name=%s&type=%s&kekID=%s"
AzureHSMURI = "kms://azure-hsm?name=%s" AzureHSMURI = "kms://azure-hsm?name=%s&kekID=%s"
GCPKMSURI = "kms://gcp?project=%s&location=%s&keyRing=%s&protectionLvl=%s" GCPKMSURI = "kms://gcp?project=%s&location=%s&keyRing=%s&protectionLvl=%s&kekID=%s"
ClusterKMSURI = "kms://cluster-kms?key=%s&salt=%s" ClusterKMSURI = "kms://cluster-kms?key=%s&salt=%s"
AWSS3URI = "storage://aws?bucket=%s" AWSS3URI = "storage://aws?bucket=%s"
AzureBlobURI = "storage://azure?container=%s&connectionString=%s" AzureBlobURI = "storage://azure?container=%s&connectionString=%s"
@ -35,6 +35,21 @@ const (
NoStoreURI = "storage://no-store" NoStoreURI = "storage://no-store"
) )
// MasterSecret holds the master key and salt for deriving keys.
type MasterSecret struct {
Key []byte `json:"key"`
Salt []byte `json:"salt"`
}
// EncodeToURI returns an URI encoding the master secret.
func (m *MasterSecret) EncodeToURI() string {
return fmt.Sprintf(
ClusterKMSURI,
base64.URLEncoding.EncodeToString(m.Key),
base64.URLEncoding.EncodeToString(m.Salt),
)
}
// KMSInformation about an existing KMS. // KMSInformation about an existing KMS.
type KMSInformation struct { type KMSInformation struct {
KMSURI string KMSURI string
@ -104,39 +119,39 @@ func getKMS(ctx context.Context, kmsURI string, store kms.Storage) (kms.CloudKMS
switch uri.Host { switch uri.Host {
case "aws": case "aws":
poliyProducer, err := getAWSKMSConfig(uri) poliyProducer, kekID, err := getAWSKMSConfig(uri)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return aws.New(ctx, poliyProducer, store) return aws.New(ctx, poliyProducer, store, kekID)
case "azure-kms": case "azure-kms":
vaultName, vaultType, err := getAzureKMSConfig(uri) vaultName, vaultType, kekID, err := getAzureKMSConfig(uri)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return azure.New(ctx, vaultName, azure.VaultSuffix(vaultType), store, nil) return azure.New(ctx, vaultName, azure.VaultSuffix(vaultType), store, kekID, nil)
case "azure-hsm": case "azure-hsm":
vaultName, err := getAzureHSMConfig(uri) vaultName, kekID, err := getAzureHSMConfig(uri)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return azure.NewHSM(ctx, vaultName, store, nil) return azure.NewHSM(ctx, vaultName, store, kekID, nil)
case "gcp": case "gcp":
project, location, keyRing, protectionLvl, err := getGCPKMSConfig(uri) project, location, keyRing, protectionLvl, kekID, err := getGCPKMSConfig(uri)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return gcp.New(ctx, project, location, keyRing, store, kmspb.ProtectionLevel(protectionLvl)) return gcp.New(ctx, project, location, keyRing, store, kmspb.ProtectionLevel(protectionLvl), kekID)
case "cluster-kms": case "cluster-kms":
salt, err := getClusterKMSConfig(uri) masterSecret, err := getClusterKMSConfig(uri)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return cluster.New(salt), nil return cluster.New(masterSecret.Key, masterSecret.Salt)
default: default:
return nil, fmt.Errorf("unknown KMS type: %s", uri.Host) return nil, fmt.Errorf("unknown KMS type: %s", uri.Host)
@ -156,22 +171,56 @@ func getAWSS3Config(uri *url.URL) (string, error) {
return r[0], err return r[0], err
} }
func getAWSKMSConfig(uri *url.URL) (*defaultPolicyProducer, error) { func getAWSKMSConfig(uri *url.URL) (*defaultPolicyProducer, string, error) {
r, err := getConfig(uri.Query(), []string{"keyPolicy"}) r, err := getConfig(uri.Query(), []string{"keyPolicy", "kekID"})
if err != nil { if err != nil {
return nil, err return nil, "", err
}
return &defaultPolicyProducer{policy: r[0]}, err
} }
func getAzureKMSConfig(uri *url.URL) (string, string, error) { if len(r) != 2 {
r, err := getConfig(uri.Query(), []string{"name", "type"}) return nil, "", fmt.Errorf("expected 2 KmsURI args, got %d", len(r))
return r[0], r[1], err
} }
func getAzureHSMConfig(uri *url.URL) (string, error) { kekID, err := base64.URLEncoding.DecodeString(r[1])
r, err := getConfig(uri.Query(), []string{"name"}) if err != nil {
return r[0], err return nil, "", fmt.Errorf("parsing kekID from kmsUri: %w", err)
}
return &defaultPolicyProducer{policy: r[0]}, string(kekID), err
}
func getAzureKMSConfig(uri *url.URL) (string, string, string, error) {
r, err := getConfig(uri.Query(), []string{"name", "type", "kekID"})
if err != nil {
return "", "", "", fmt.Errorf("getting config: %w", err)
}
if len(r) != 3 {
return "", "", "", fmt.Errorf("expected 3 KmsURI args, got %d", len(r))
}
kekID, err := base64.URLEncoding.DecodeString(r[2])
if err != nil {
return "", "", "", fmt.Errorf("parsing kekID from kmsUri: %w", err)
}
return r[0], r[1], string(kekID), err
}
func getAzureHSMConfig(uri *url.URL) (string, string, error) {
r, err := getConfig(uri.Query(), []string{"name", "kekID"})
if err != nil {
return "", "", fmt.Errorf("getting config: %w", err)
}
if len(r) != 2 {
return "", "", fmt.Errorf("expected 2 KmsURI args, got %d", len(r))
}
kekID, err := base64.URLEncoding.DecodeString(r[1])
if err != nil {
return "", "", fmt.Errorf("parsing kekID from kmsUri: %w", err)
}
return r[0], string(kekID), err
} }
func getAzureBlobConfig(uri *url.URL) (string, string, error) { func getAzureBlobConfig(uri *url.URL) (string, string, error) {
@ -182,16 +231,26 @@ func getAzureBlobConfig(uri *url.URL) (string, string, error) {
return r[0], r[1], nil return r[0], r[1], nil
} }
func getGCPKMSConfig(uri *url.URL) (string, string, string, int32, error) { func getGCPKMSConfig(uri *url.URL) (project string, location string, keyRing string, protectionLvl int32, kekID string, err error) {
r, err := getConfig(uri.Query(), []string{"project", "location", "keyRing", "protectionLvl"}) r, err := getConfig(uri.Query(), []string{"project", "location", "keyRing", "protectionLvl", "kekID"})
if err != nil { if err != nil {
return "", "", "", 0, err return "", "", "", 0, "", err
} }
protectionLvl, err := strconv.ParseInt(r[3], 10, 32)
if len(r) != 5 {
return "", "", "", 0, "", fmt.Errorf("expected 5 KmsURI args, got %d", len(r))
}
kekIDByte, err := base64.URLEncoding.DecodeString(r[4])
if err != nil { if err != nil {
return "", "", "", 0, err return "", "", "", 0, "", fmt.Errorf("parsing kekID from kmsUri: %w", err)
} }
return r[0], r[1], r[2], int32(protectionLvl), nil
protectionLvl32, err := strconv.ParseInt(r[3], 10, 32)
if err != nil {
return "", "", "", 0, "", err
}
return r[0], r[1], r[2], int32(protectionLvl32), string(kekIDByte), nil
} }
func getGCPStorageConfig(uri *url.URL) (string, string, error) { func getGCPStorageConfig(uri *url.URL) (string, string, error) {
@ -199,12 +258,26 @@ func getGCPStorageConfig(uri *url.URL) (string, string, error) {
return r[0], r[1], err return r[0], r[1], err
} }
func getClusterKMSConfig(uri *url.URL) ([]byte, error) { func getClusterKMSConfig(uri *url.URL) (MasterSecret, error) {
r, err := getConfig(uri.Query(), []string{"salt"}) r, err := getConfig(uri.Query(), []string{"key", "salt"})
if err != nil { if err != nil {
return nil, err return MasterSecret{}, err
} }
return base64.URLEncoding.DecodeString(r[0])
if len(r) != 2 {
return MasterSecret{}, fmt.Errorf("expected 2 KmsURI args, got %d", len(r))
}
key, err := base64.URLEncoding.DecodeString(r[0])
if err != nil {
return MasterSecret{}, fmt.Errorf("parsing key from kmsUri: %w", err)
}
salt, err := base64.URLEncoding.DecodeString(r[1])
if err != nil {
return MasterSecret{}, fmt.Errorf("parsing salt from kmsUri: %w", err)
}
return MasterSecret{Key: key, Salt: salt}, nil
} }
// getConfig parses url query values, returning a map of the requested values. // getConfig parses url query values, returning a map of the requested values.

View file

@ -18,6 +18,8 @@ import (
"go.uber.org/goleak" "go.uber.org/goleak"
) )
const constellationKekID = "Constellation"
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
goleak.VerifyTestMain(m, goleak.VerifyTestMain(m,
// https://github.com/census-instrumentation/opencensus-go/issues/1262 // https://github.com/census-instrumentation/opencensus-go/issues/1262
@ -80,23 +82,23 @@ func TestGetKMS(t *testing.T) {
wantErr bool wantErr bool
}{ }{
"cluster kms": { "cluster kms": {
uri: fmt.Sprintf("%s?salt=%s", ClusterKMSURI, base64.URLEncoding.EncodeToString([]byte("salt"))), uri: fmt.Sprintf(ClusterKMSURI, base64.URLEncoding.EncodeToString([]byte("key")), base64.URLEncoding.EncodeToString([]byte("salt"))),
wantErr: false, wantErr: false,
}, },
"aws kms": { "aws kms": {
uri: fmt.Sprintf(AWSKMSURI, ""), uri: fmt.Sprintf(AWSKMSURI, "", ""),
wantErr: true, wantErr: true,
}, },
"azure kms": { "azure kms": {
uri: fmt.Sprintf(AzureKMSURI, "", ""), uri: fmt.Sprintf(AzureKMSURI, "", "", ""),
wantErr: true, wantErr: true,
}, },
"azure hsm": { "azure hsm": {
uri: fmt.Sprintf(AzureHSMURI, ""), uri: fmt.Sprintf(AzureHSMURI, "", ""),
wantErr: true, wantErr: true,
}, },
"gcp kms": { "gcp kms": {
uri: fmt.Sprintf(GCPKMSURI, "", "", "", ""), uri: fmt.Sprintf(GCPKMSURI, "", "", "", "", ""),
wantErr: true, wantErr: true,
}, },
"unknown kms": { "unknown kms": {
@ -135,7 +137,8 @@ func TestSetUpKMS(t *testing.T) {
assert.Error(err) assert.Error(err)
assert.Nil(kms) assert.Nil(kms)
kms, err = KMS(context.Background(), "storage://no-store", "kms://cluster-kms?salt="+base64.URLEncoding.EncodeToString([]byte("salt"))) masterSecret := MasterSecret{Key: []byte("key"), Salt: []byte("salt")}
kms, err = KMS(context.Background(), "storage://no-store", masterSecret.EncodeToURI())
assert.NoError(err) assert.NoError(err)
assert.NotNil(kms) assert.NotNil(kms)
} }
@ -146,13 +149,15 @@ func TestGetAWSKMSConfig(t *testing.T) {
policy := "{keyPolicy: keyPolicy}" policy := "{keyPolicy: keyPolicy}"
escapedPolicy := url.QueryEscape(policy) escapedPolicy := url.QueryEscape(policy)
uri, err := url.Parse(fmt.Sprintf(AWSKMSURI, escapedPolicy)) kekID := base64.URLEncoding.EncodeToString([]byte(constellationKekID))
uri, err := url.Parse(fmt.Sprintf(AWSKMSURI, escapedPolicy, kekID))
require.NoError(err) require.NoError(err)
policyProducer, err := getAWSKMSConfig(uri) policyProducer, rKekID, err := getAWSKMSConfig(uri)
require.NoError(err) require.NoError(err)
keyPolicy, err := policyProducer.CreateKeyPolicy("") keyPolicy, err := policyProducer.CreateKeyPolicy("")
require.NoError(err) require.NoError(err)
assert.Equal(policy, keyPolicy) assert.Equal(policy, keyPolicy)
assert.Equal(constellationKekID, rKekID)
} }
func TestGetAzureBlobConfig(t *testing.T) { func TestGetAzureBlobConfig(t *testing.T) {
@ -178,18 +183,20 @@ func TestGetGCPKMSConfig(t *testing.T) {
location := "global" location := "global"
keyRing := "test-ring" keyRing := "test-ring"
protectionLvl := "2" protectionLvl := "2"
uri, err := url.Parse(fmt.Sprintf(GCPKMSURI, project, location, keyRing, protectionLvl)) kekID := base64.URLEncoding.EncodeToString([]byte(constellationKekID))
uri, err := url.Parse(fmt.Sprintf(GCPKMSURI, project, location, keyRing, protectionLvl, kekID))
require.NoError(err) require.NoError(err)
rProject, rLocation, rKeyRing, rProtectionLvl, err := getGCPKMSConfig(uri) rProject, rLocation, rKeyRing, rProtectionLvl, rKekID, err := getGCPKMSConfig(uri)
require.NoError(err) require.NoError(err)
assert.Equal(project, rProject) assert.Equal(project, rProject)
assert.Equal(location, rLocation) assert.Equal(location, rLocation)
assert.Equal(keyRing, rKeyRing) assert.Equal(keyRing, rKeyRing)
assert.Equal(int32(2), rProtectionLvl) assert.Equal(int32(2), rProtectionLvl)
assert.Equal(constellationKekID, rKekID)
uri, err = url.Parse(fmt.Sprintf(GCPKMSURI, project, location, keyRing, "invalid")) uri, err = url.Parse(fmt.Sprintf(GCPKMSURI, project, location, keyRing, "invalid", kekID))
require.NoError(err) require.NoError(err)
_, _, _, _, err = getGCPKMSConfig(uri) _, _, _, _, _, err = getGCPKMSConfig(uri)
assert.Error(err) assert.Error(err)
} }
@ -202,12 +209,13 @@ func TestGetClusterKMSConfig(t *testing.T) {
0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf,
} }
uri, err := url.Parse(ClusterKMSURI + "?salt=" + base64.URLEncoding.EncodeToString(expectedSalt)) masterSecretIn := MasterSecret{Key: []byte("key"), Salt: expectedSalt}
uri, err := url.Parse(masterSecretIn.EncodeToURI())
require.NoError(err) require.NoError(err)
salt, err := getClusterKMSConfig(uri) masterSecretOut, err := getClusterKMSConfig(uri)
assert.NoError(err) assert.NoError(err)
assert.Equal(expectedSalt, salt) assert.Equal(expectedSalt, masterSecretOut.Salt)
} }
func TestGetConfig(t *testing.T) { func TestGetConfig(t *testing.T) {

View file

@ -59,7 +59,7 @@ const (
// JoinImage image of Constellation join service. // JoinImage image of Constellation join service.
JoinImage = "ghcr.io/edgelesssys/constellation/join-service:v2.5.0-pre.0.20230118154955-632090c21b93@sha256:7c53b43f2580ded9f04a9927d4ff585d3edce5d10a1d83006688c818e6395eb1" // renovate:container JoinImage = "ghcr.io/edgelesssys/constellation/join-service:v2.5.0-pre.0.20230118154955-632090c21b93@sha256:7c53b43f2580ded9f04a9927d4ff585d3edce5d10a1d83006688c818e6395eb1" // renovate:container
// KeyServiceImage image of Constellation KMS server. // KeyServiceImage image of Constellation KMS server.
KeyServiceImage = "ghcr.io/edgelesssys/constellation/kmsserver:v2.5.0-pre.0.20230112123617-d0e9f427d1ba@sha256:d4319308eb62e2ee079cc86858acdd1faccc404edec7bfabecf35861284a55f3" // renovate:container KeyServiceImage = "ghcr.io/edgelesssys/constellation/keyservice:v2.5.0-pre.0.20230116125211-d37bd077d8c6@sha256:4c14176f94899054bbf945f6f209521ffcdbcb9042abc5850d778240fe3693a4" // renovate:container
// VerificationImage image of Constellation verification service. // VerificationImage image of Constellation verification service.
VerificationImage = "ghcr.io/edgelesssys/constellation/verification-service:v2.5.0-pre.0.20230118154955-632090c21b93@sha256:593f735a236f0cb8f4373a7a2dca41be9ab2ba1b784a2ebcf8fb5271705822a3" // renovate:container VerificationImage = "ghcr.io/edgelesssys/constellation/verification-service:v2.5.0-pre.0.20230118154955-632090c21b93@sha256:593f735a236f0cb8f4373a7a2dca41be9ab2ba1b784a2ebcf8fb5271705822a3" // renovate:container
// GcpGuestImage image for GCP guest agent. // GcpGuestImage image for GCP guest agent.

View file

@ -8,7 +8,6 @@ package main
import ( import (
"context" "context"
"encoding/base64"
"errors" "errors"
"flag" "flag"
"path/filepath" "path/filepath"
@ -53,18 +52,15 @@ func main() {
if len(salt) < crypto.RNGLengthDefault { if len(salt) < crypto.RNGLengthDefault {
log.With(zap.Error(errors.New("invalid salt length"))).Fatalf("Expected salt to be %d bytes, but got %d", crypto.RNGLengthDefault, len(salt)) log.With(zap.Error(errors.New("invalid salt length"))).Fatalf("Expected salt to be %d bytes, but got %d", crypto.RNGLengthDefault, len(salt))
} }
keyURI := setup.ClusterKMSURI + "?salt=" + base64.URLEncoding.EncodeToString(salt) masterSecret := setup.MasterSecret{Key: masterKey, Salt: salt}
// set up Key Management Service // set up Key Management Service
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel() defer cancel()
conKMS, err := setup.KMS(ctx, setup.NoStoreURI, keyURI) conKMS, err := setup.KMS(ctx, setup.NoStoreURI, masterSecret.EncodeToURI())
if err != nil { if err != nil {
log.With(zap.Error(err)).Fatalf("Failed to setup KMS") log.With(zap.Error(err)).Fatalf("Failed to setup KMS")
} }
if err := conKMS.CreateKEK(ctx, "Constellation", masterKey); err != nil {
log.With(zap.Error(err)).Fatalf("Failed to create KMS KEK from MasterKey")
}
if err := server.New(log.Named("keyservice"), conKMS).Run(*port); err != nil { if err := server.New(log.Named("keyservice"), conKMS).Run(*port); err != nil {
log.With(zap.Error(err)).Fatalf("Failed to run keyservice server") log.With(zap.Error(err)).Fatalf("Failed to run keyservice server")

View file

@ -74,7 +74,7 @@ func (s *Server) GetDataKey(ctx context.Context, in *keyserviceproto.GetDataKeyR
return nil, status.Error(codes.InvalidArgument, "no data key ID specified") return nil, status.Error(codes.InvalidArgument, "no data key ID specified")
} }
key, err := s.conKMS.GetDEK(ctx, "Constellation", crypto.HKDFInfoPrefix+in.DataKeyId, int(in.Length)) key, err := s.conKMS.GetDEK(ctx, crypto.DEKPrefix+in.DataKeyId, int(in.Length))
if err != nil { if err != nil {
log.With(zap.Error(err)).Errorf("Failed to get data key") log.With(zap.Error(err)).Errorf("Failed to get data key")
return nil, status.Errorf(codes.Internal, "%v", err) return nil, status.Errorf(codes.Internal, "%v", err)

View file

@ -63,7 +63,7 @@ func (c *stubKMS) CreateKEK(ctx context.Context, keyID string, kek []byte) error
return nil return nil
} }
func (c *stubKMS) GetDEK(ctx context.Context, kekID string, dekID string, dekSize int) ([]byte, error) { func (c *stubKMS) GetDEK(ctx context.Context, dekID string, dekSize int) ([]byte, error) {
if c.deriveKeyErr != nil { if c.deriveKeyErr != nil {
return nil, c.deriveKeyErr return nil, c.deriveKeyErr
} }

View file

@ -146,7 +146,7 @@ func TestAwsKms(t *testing.T) {
require.NotEqual(newKEKId1, newKEKId2) require.NotEqual(newKEKId1, newKEKId2)
var keyPolicyProducer createKeyPolicyFunc var keyPolicyProducer createKeyPolicyFunc
client, err := awsInterface.New(context.Background(), &keyPolicyProducer, nil) client, err := awsInterface.New(context.Background(), &keyPolicyProducer, nil, newKEKId1)
require.NoError(err) require.NoError(err)
privateKEK1 := []byte(strings.Repeat("1234", 8)) privateKEK1 := []byte(strings.Repeat("1234", 8))
@ -166,14 +166,14 @@ func TestAwsKms(t *testing.T) {
assert.NoError(client.CreateKEK(ctx, newKEKId1, privateKEK2)) assert.NoError(client.CreateKEK(ctx, newKEKId1, privateKEK2))
// make sure that GetDEK is idempotent // make sure that GetDEK is idempotent
volumeKey1, err := client.GetDEK(ctx, newKEKId1, "volume01", kmsconfig.SymmetricKeyLength) volumeKey1, err := client.GetDEK(ctx, "volume01", kmsconfig.SymmetricKeyLength)
require.NoError(err) require.NoError(err)
volumeKey1Copy, err := client.GetDEK(ctx, newKEKId1, "volume01", kmsconfig.SymmetricKeyLength) volumeKey1Copy, err := client.GetDEK(ctx, "volume01", kmsconfig.SymmetricKeyLength)
require.NoError(err) require.NoError(err)
assert.Equal(volumeKey1, volumeKey1Copy) assert.Equal(volumeKey1, volumeKey1Copy)
// test setting a second DEK // test setting a second DEK
volumeKey2, err := client.GetDEK(ctx, newKEKId1, "volume02", kmsconfig.SymmetricKeyLength) volumeKey2, err := client.GetDEK(ctx, "volume02", kmsconfig.SymmetricKeyLength)
require.NoError(err) require.NoError(err)
assert.NotEqual(volumeKey1, volumeKey2) assert.NotEqual(volumeKey1, volumeKey2)
@ -184,7 +184,7 @@ func TestAwsKms(t *testing.T) {
assert.NoError(client.CreateKEK(ctx, newKEKId2, privateKEK3)) assert.NoError(client.CreateKEK(ctx, newKEKId2, privateKEK3))
// test setting a DEK with AWS KMS generated KEK // test setting a DEK with AWS KMS generated KEK
volumeKey3, err := client.GetDEK(ctx, newKEKId2, "volume03", kmsconfig.SymmetricKeyLength) volumeKey3, err := client.GetDEK(ctx, "volume03", kmsconfig.SymmetricKeyLength)
require.NoError(err) require.NoError(err)
assert.NotEqual(volumeKey1, volumeKey3) assert.NotEqual(volumeKey1, volumeKey3)

View file

@ -66,22 +66,22 @@ func TestAzureKeyVault(t *testing.T) {
store := storage.NewMemMapStorage() store := storage.NewMemMapStorage()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel() defer cancel()
client, err := azure.New(ctx, azVaultName, azure.DefaultCloud, store, nil) kekName := "test-kek"
client, err := azure.New(ctx, azVaultName, azure.DefaultCloud, store, kekName, nil)
require.NoError(err) require.NoError(err)
kekName := "test-kek"
dekName := "test-dek" dekName := "test-dek"
assert.NoError(client.CreateKEK(ctx, kekName, nil)) assert.NoError(client.CreateKEK(ctx, kekName, nil))
res, err := client.GetDEK(ctx, kekName, dekName, config.SymmetricKeyLength) res, err := client.GetDEK(ctx, dekName, config.SymmetricKeyLength)
assert.NoError(err) assert.NoError(err)
res2, err := client.GetDEK(ctx, kekName, dekName, config.SymmetricKeyLength) res2, err := client.GetDEK(ctx, dekName, config.SymmetricKeyLength)
assert.NoError(err) assert.NoError(err)
assert.Equal(res, res2) assert.Equal(res, res2)
res3, err := client.GetDEK(ctx, kekName, addSuffix(dekName), config.SymmetricKeyLength) res3, err := client.GetDEK(ctx, addSuffix(dekName), config.SymmetricKeyLength)
assert.NoError(err) assert.NoError(err)
assert.Len(res3, config.SymmetricKeyLength) assert.Len(res3, config.SymmetricKeyLength)
assert.NotEqual(res, res3) assert.NotEqual(res, res3)
@ -102,10 +102,10 @@ func TestAzureHSM(t *testing.T) {
store := storage.NewMemMapStorage() store := storage.NewMemMapStorage()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel() defer cancel()
client, err := azure.NewHSM(ctx, azHSMName, store, nil) kekName := "test-kek"
client, err := azure.NewHSM(ctx, azHSMName, store, kekName, nil)
require.NoError(err) require.NoError(err)
kekName := "test-kek"
dekName := "test-dek" dekName := "test-dek"
importedKek := "test-kek-import" importedKek := "test-kek-import"
kekData := []byte{0x52, 0xFD, 0xFC, 0x07, 0x21, 0x82, 0x65, 0x4F, 0x16, 0x3F, 0x5F, 0x0F, 0x9A, 0x62, 0x1D, 0x72, 0x95, 0x66, 0xC7, 0x4D, 0x10, 0x03, 0x7C, 0x4D, 0x7B, 0xBB, 0x04, 0x07, 0xD1, 0xE2, 0xC6, 0x49} kekData := []byte{0x52, 0xFD, 0xFC, 0x07, 0x21, 0x82, 0x65, 0x4F, 0x16, 0x3F, 0x5F, 0x0F, 0x9A, 0x62, 0x1D, 0x72, 0x95, 0x66, 0xC7, 0x4D, 0x10, 0x03, 0x7C, 0x4D, 0x7B, 0xBB, 0x04, 0x07, 0xD1, 0xE2, 0xC6, 0x49}
@ -114,15 +114,15 @@ func TestAzureHSM(t *testing.T) {
assert.NoError(client.CreateKEK(ctx, kekName, nil)) assert.NoError(client.CreateKEK(ctx, kekName, nil))
res, err := client.GetDEK(ctx, kekName, dekName, config.SymmetricKeyLength) res, err := client.GetDEK(ctx, dekName, config.SymmetricKeyLength)
require.NoError(err) require.NoError(err)
assert.NotNil(res) assert.NotNil(res)
res2, err := client.GetDEK(ctx, kekName, dekName, config.SymmetricKeyLength) res2, err := client.GetDEK(ctx, dekName, config.SymmetricKeyLength)
require.NoError(err) require.NoError(err)
assert.Equal(res, res2) assert.Equal(res, res2)
res3, err := client.GetDEK(ctx, kekName, addSuffix(dekName), config.SymmetricKeyLength) res3, err := client.GetDEK(ctx, addSuffix(dekName), config.SymmetricKeyLength)
require.NoError(err) require.NoError(err)
assert.Len(res3, config.SymmetricKeyLength) assert.Len(res3, config.SymmetricKeyLength)
assert.NotEqual(res, res3) assert.NotEqual(res, res3)

View file

@ -42,20 +42,20 @@ func TestCreateGcpKEK(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel() defer cancel()
kmsClient, err := gcp.New(ctx, gcpProjectID, gcpLocation, gcpKeyRing, store, kmspb.ProtectionLevel_SOFTWARE) kmsClient, err := gcp.New(ctx, gcpProjectID, gcpLocation, gcpKeyRing, store, kmspb.ProtectionLevel_SOFTWARE, kekName)
require.NoError(err) require.NoError(err)
// Key name is random, but there is a chance we try to create a key that already exists, in that case the test fails // Key name is random, but there is a chance we try to create a key that already exists, in that case the test fails
assert.NoError(kmsClient.CreateKEK(ctx, kekName, nil)) assert.NoError(kmsClient.CreateKEK(ctx, kekName, nil))
res, err := kmsClient.GetDEK(ctx, kekName, dekName, config.SymmetricKeyLength) res, err := kmsClient.GetDEK(ctx, dekName, config.SymmetricKeyLength)
assert.NoError(err) assert.NoError(err)
res2, err := kmsClient.GetDEK(ctx, kekName, dekName, config.SymmetricKeyLength) res2, err := kmsClient.GetDEK(ctx, dekName, config.SymmetricKeyLength)
assert.NoError(err) assert.NoError(err)
assert.Equal(res, res2) assert.Equal(res, res2)
res3, err := kmsClient.GetDEK(ctx, kekName, addSuffix(dekName), config.SymmetricKeyLength) res3, err := kmsClient.GetDEK(ctx, addSuffix(dekName), config.SymmetricKeyLength)
assert.NoError(err) assert.NoError(err)
assert.Len(res3, config.SymmetricKeyLength) assert.Len(res3, config.SymmetricKeyLength)
assert.NotEqual(res, res3) assert.NotEqual(res, res3)
@ -76,15 +76,15 @@ func TestImportGcpKEK(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel() defer cancel()
kmsClient, err := gcp.New(ctx, gcpProjectID, gcpLocation, gcpKeyRing, store, kmspb.ProtectionLevel_SOFTWARE) kmsClient, err := gcp.New(ctx, gcpProjectID, gcpLocation, gcpKeyRing, store, kmspb.ProtectionLevel_SOFTWARE, kekName)
require.NoError(err) require.NoError(err)
assert.NoError(kmsClient.CreateKEK(ctx, kekName, kekData)) assert.NoError(kmsClient.CreateKEK(ctx, kekName, kekData))
res, err := kmsClient.GetDEK(ctx, kekName, dekName, config.SymmetricKeyLength) res, err := kmsClient.GetDEK(ctx, dekName, config.SymmetricKeyLength)
assert.NoError(err) assert.NoError(err)
res2, err := kmsClient.GetDEK(ctx, kekName, dekName, config.SymmetricKeyLength) res2, err := kmsClient.GetDEK(ctx, dekName, config.SymmetricKeyLength)
assert.NoError(err) assert.NoError(err)
assert.Equal(res, res2) assert.Equal(res, res2)
} }