Refactor init/recovery to use kms URI

So far the masterSecret was sent to the initial bootstrapper
on init/recovery. With this commit this information is encoded
in the kmsURI that is sent during init.
For recover, the communication with the recoveryserver is
changed. Before a streaming gRPC call was used to
exchanges UUID for measurementSecret and state disk key.
Now a standard gRPC is made that includes the same kmsURI &
storageURI that are sent during init.
This commit is contained in:
Otto Bittner 2023-01-16 11:19:03 +01:00
parent 0e71322e2e
commit 9a1f52e94e
35 changed files with 466 additions and 623 deletions

View File

@ -23,7 +23,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/v2/internal/grpc/grpclog"
"github.com/edgelesssys/constellation/v2/internal/kms/kms"
kmsSetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/internal/nodestate"
"github.com/edgelesssys/constellation/v2/internal/role"
@ -110,8 +110,13 @@ func (s *Server) Init(ctx context.Context, req *initproto.InitRequest) (*initpro
return nil, status.Errorf(codes.Internal, "invalid init secret %s", err)
}
cloudKms, err := kmssetup.KMS(ctx, req.StorageUri, req.KmsUri)
if err != nil {
return nil, fmt.Errorf("creating kms client: %w", err)
}
// generate values for cluster attestation
measurementSalt, clusterID, err := deriveMeasurementValues(req.MasterSecret, req.Salt)
measurementSalt, clusterID, err := deriveMeasurementValues(ctx, cloudKms)
if err != nil {
return nil, status.Errorf(codes.Internal, "deriving measurement values: %s", err)
}
@ -130,7 +135,7 @@ func (s *Server) Init(ctx context.Context, req *initproto.InitRequest) (*initpro
return nil, status.Error(codes.FailedPrecondition, "node is already being activated")
}
if err := s.setupDisk(req.MasterSecret, req.Salt); err != nil {
if err := s.setupDisk(ctx, cloudKms); err != nil {
return nil, status.Errorf(codes.Internal, "setting up disk: %s", err)
}
@ -177,7 +182,7 @@ func (s *Server) Stop() {
s.log.Infof("Stopped")
}
func (s *Server) setupDisk(masterSecret, salt []byte) error {
func (s *Server) setupDisk(ctx context.Context, cloudKms kms.CloudKMS) error {
if err := s.disk.Open(); err != nil {
return fmt.Errorf("opening encrypted disk: %w", err)
}
@ -189,7 +194,7 @@ func (s *Server) setupDisk(masterSecret, salt []byte) error {
}
uuid = strings.ToLower(uuid)
diskKey, err := crypto.DeriveKey(masterSecret, salt, []byte(crypto.HKDFInfoPrefix+uuid), crypto.DerivedKeyLengthDefault)
diskKey, err := cloudKms.GetDEK(ctx, crypto.DEKPrefix+uuid, crypto.StateDiskKeyLength)
if err != nil {
return err
}
@ -197,12 +202,12 @@ func (s *Server) setupDisk(masterSecret, salt []byte) error {
return s.disk.UpdatePassphrase(string(diskKey))
}
func deriveMeasurementValues(masterSecret, hkdfSalt []byte) (salt, clusterID []byte, err error) {
func deriveMeasurementValues(ctx context.Context, cloudKms kms.CloudKMS) (salt, clusterID []byte, err error) {
salt, err = crypto.GenerateRandomBytes(crypto.RNGLengthDefault)
if err != nil {
return nil, nil, err
}
secret, err := attestation.DeriveMeasurementSecret(masterSecret, hkdfSalt)
secret, err := cloudKms.GetDEK(ctx, crypto.DEKPrefix+crypto.MeasurementSecretKeyID, crypto.DerivedKeyLengthDefault)
if err != nil {
return nil, nil, err
}

View File

@ -19,6 +19,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/atls"
"github.com/edgelesssys/constellation/v2/internal/crypto/testvector"
"github.com/edgelesssys/constellation/v2/internal/file"
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/internal/oid"
"github.com/edgelesssys/constellation/v2/internal/versions/components"
@ -30,7 +31,10 @@ import (
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
goleak.VerifyTestMain(m,
// https://github.com/census-instrumentation/opencensus-go/issues/1262
goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
)
}
func TestNew(t *testing.T) {
@ -86,6 +90,8 @@ func TestInit(t *testing.T) {
initSecretHash, err := bcrypt.GenerateFromPassword(initSecret, bcrypt.DefaultCost)
require.NoError(t, err)
masterSecret := kmssetup.MasterSecret{Key: []byte("secret"), Salt: []byte("salt")}
testCases := map[string]struct {
nodeLock *fakeLock
initializer ClusterInitializer
@ -102,14 +108,14 @@ func TestInit(t *testing.T) {
disk: &stubDisk{},
fileHandler: file.NewHandler(afero.NewMemMapFs()),
initSecretHash: initSecretHash,
req: &initproto.InitRequest{InitSecret: initSecret},
req: &initproto.InitRequest{InitSecret: initSecret, KmsUri: masterSecret.EncodeToURI(), StorageUri: kmssetup.NoStoreURI},
},
"node locked": {
nodeLock: lockedLock,
initializer: &stubClusterInitializer{},
disk: &stubDisk{},
fileHandler: file.NewHandler(afero.NewMemMapFs()),
req: &initproto.InitRequest{InitSecret: initSecret},
req: &initproto.InitRequest{InitSecret: initSecret, KmsUri: masterSecret.EncodeToURI(), StorageUri: kmssetup.NoStoreURI},
initSecretHash: initSecretHash,
wantErr: true,
wantShutdown: true,
@ -119,7 +125,7 @@ func TestInit(t *testing.T) {
initializer: &stubClusterInitializer{},
disk: &stubDisk{openErr: someErr},
fileHandler: file.NewHandler(afero.NewMemMapFs()),
req: &initproto.InitRequest{InitSecret: initSecret},
req: &initproto.InitRequest{InitSecret: initSecret, KmsUri: masterSecret.EncodeToURI(), StorageUri: kmssetup.NoStoreURI},
initSecretHash: initSecretHash,
wantErr: true,
},
@ -128,7 +134,7 @@ func TestInit(t *testing.T) {
initializer: &stubClusterInitializer{},
disk: &stubDisk{uuidErr: someErr},
fileHandler: file.NewHandler(afero.NewMemMapFs()),
req: &initproto.InitRequest{InitSecret: initSecret},
req: &initproto.InitRequest{InitSecret: initSecret, KmsUri: masterSecret.EncodeToURI(), StorageUri: kmssetup.NoStoreURI},
initSecretHash: initSecretHash,
wantErr: true,
},
@ -137,7 +143,7 @@ func TestInit(t *testing.T) {
initializer: &stubClusterInitializer{},
disk: &stubDisk{updatePassphraseErr: someErr},
fileHandler: file.NewHandler(afero.NewMemMapFs()),
req: &initproto.InitRequest{InitSecret: initSecret},
req: &initproto.InitRequest{InitSecret: initSecret, KmsUri: masterSecret.EncodeToURI(), StorageUri: kmssetup.NoStoreURI},
initSecretHash: initSecretHash,
wantErr: true,
},
@ -146,7 +152,7 @@ func TestInit(t *testing.T) {
initializer: &stubClusterInitializer{},
disk: &stubDisk{},
fileHandler: file.NewHandler(afero.NewReadOnlyFs(afero.NewMemMapFs())),
req: &initproto.InitRequest{InitSecret: initSecret},
req: &initproto.InitRequest{InitSecret: initSecret, KmsUri: masterSecret.EncodeToURI(), StorageUri: kmssetup.NoStoreURI},
initSecretHash: initSecretHash,
wantErr: true,
},
@ -155,7 +161,7 @@ func TestInit(t *testing.T) {
initializer: &stubClusterInitializer{initClusterErr: someErr},
disk: &stubDisk{},
fileHandler: file.NewHandler(afero.NewMemMapFs()),
req: &initproto.InitRequest{InitSecret: initSecret},
req: &initproto.InitRequest{InitSecret: initSecret, KmsUri: masterSecret.EncodeToURI(), StorageUri: kmssetup.NoStoreURI},
initSecretHash: initSecretHash,
wantErr: true,
},
@ -211,28 +217,29 @@ func TestInit(t *testing.T) {
func TestSetupDisk(t *testing.T) {
testCases := map[string]struct {
uuid string
masterSecret []byte
salt []byte
wantKey []byte
uuid string
masterKey []byte
salt []byte
wantKey []byte
}{
"lower case uuid": {
uuid: strings.ToLower(testvector.HKDF0xFF.Info),
masterSecret: testvector.HKDF0xFF.Secret,
salt: testvector.HKDF0xFF.Salt,
wantKey: testvector.HKDF0xFF.Output,
uuid: strings.ToLower(testvector.HKDF0xFF.Info),
masterKey: testvector.HKDF0xFF.Secret,
salt: testvector.HKDF0xFF.Salt,
wantKey: testvector.HKDF0xFF.Output,
},
"upper case uuid": {
uuid: strings.ToUpper(testvector.HKDF0xFF.Info),
masterSecret: testvector.HKDF0xFF.Secret,
salt: testvector.HKDF0xFF.Salt,
wantKey: testvector.HKDF0xFF.Output,
uuid: strings.ToUpper(testvector.HKDF0xFF.Info),
masterKey: testvector.HKDF0xFF.Secret,
salt: testvector.HKDF0xFF.Salt,
wantKey: testvector.HKDF0xFF.Output,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
disk := &fakeDisk{
uuid: tc.uuid,
@ -242,7 +249,11 @@ func TestSetupDisk(t *testing.T) {
disk: disk,
}
assert.NoError(server.setupDisk(tc.masterSecret, tc.salt))
masterSecret := kmssetup.MasterSecret{Key: tc.masterKey, Salt: tc.salt}
cloudKms, err := kmssetup.KMS(context.Background(), kmssetup.NoStoreURI, masterSecret.EncodeToURI())
require.NoError(err)
assert.NoError(server.setupDisk(context.Background(), cloudKms))
})
}
}

View File

@ -30,7 +30,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
grpcRetry "github.com/edgelesssys/constellation/v2/internal/grpc/retry"
keyservice "github.com/edgelesssys/constellation/v2/internal/kms/setup"
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
"github.com/edgelesssys/constellation/v2/internal/license"
"github.com/edgelesssys/constellation/v2/internal/retry"
"github.com/edgelesssys/constellation/v2/internal/versions"
@ -146,8 +146,8 @@ func (i *initCmd) initialize(cmd *cobra.Command, newDialer func(validator *cloud
req := &initproto.InitRequest{
MasterSecret: masterSecret.Key,
Salt: masterSecret.Salt,
KmsUri: keyservice.ClusterKMSURI,
StorageUri: keyservice.NoStoreURI,
KmsUri: masterSecret.EncodeToURI(),
StorageUri: kmssetup.NoStoreURI,
KeyEncryptionKeyId: "",
UseExistingKek: false,
CloudServiceAccountUri: serviceAccURI,
@ -296,26 +296,20 @@ type initFlags struct {
conformance bool
}
// masterSecret holds the master key and salt for deriving keys.
type masterSecret struct {
Key []byte `json:"key"`
Salt []byte `json:"salt"`
}
// readOrGenerateMasterSecret reads a base64 encoded master secret from file or generates a new 32 byte secret.
func (i *initCmd) readOrGenerateMasterSecret(outWriter io.Writer, fileHandler file.Handler, filename string) (masterSecret, error) {
func (i *initCmd) readOrGenerateMasterSecret(outWriter io.Writer, fileHandler file.Handler, filename string) (kmssetup.MasterSecret, error) {
if filename != "" {
i.log.Debugf("Reading master secret from file %q", filename)
var secret masterSecret
var secret kmssetup.MasterSecret
if err := fileHandler.ReadJSON(filename, &secret); err != nil {
return masterSecret{}, err
return kmssetup.MasterSecret{}, err
}
if len(secret.Key) < crypto.MasterSecretLengthMin {
return masterSecret{}, fmt.Errorf("provided master secret is smaller than the required minimum of %d Bytes", crypto.MasterSecretLengthMin)
return kmssetup.MasterSecret{}, fmt.Errorf("provided master secret is smaller than the required minimum of %d Bytes", crypto.MasterSecretLengthMin)
}
if len(secret.Salt) < crypto.RNGLengthDefault {
return masterSecret{}, fmt.Errorf("provided salt is smaller than the required minimum of %d Bytes", crypto.RNGLengthDefault)
return kmssetup.MasterSecret{}, fmt.Errorf("provided salt is smaller than the required minimum of %d Bytes", crypto.RNGLengthDefault)
}
return secret, nil
}
@ -324,19 +318,19 @@ func (i *initCmd) readOrGenerateMasterSecret(outWriter io.Writer, fileHandler fi
i.log.Debugf("Generating new master secret")
key, err := crypto.GenerateRandomBytes(crypto.MasterSecretLengthDefault)
if err != nil {
return masterSecret{}, err
return kmssetup.MasterSecret{}, err
}
salt, err := crypto.GenerateRandomBytes(crypto.RNGLengthDefault)
if err != nil {
return masterSecret{}, err
return kmssetup.MasterSecret{}, err
}
secret := masterSecret{
secret := kmssetup.MasterSecret{
Key: key,
Salt: salt,
}
i.log.Debugf("Generated master secret key and salt values")
if err := fileHandler.WriteJSON(constants.MasterSecretFilename, secret, file.OptNone); err != nil {
return masterSecret{}, err
return kmssetup.MasterSecret{}, err
}
fmt.Fprintf(outWriter, "Your Constellation master secret was successfully written to ./%s\n", constants.MasterSecretFilename)
return secret, nil

View File

@ -18,6 +18,8 @@ import (
"testing"
"time"
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
"github.com/edgelesssys/constellation/v2/bootstrapper/initproto"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/clusterid"
@ -183,7 +185,7 @@ func TestInitialize(t *testing.T) {
require.NoError(err)
// assert.Contains(out.String(), base64.StdEncoding.EncodeToString([]byte("ownerID")))
assert.Contains(out.String(), hex.EncodeToString([]byte("clusterID")))
var secret masterSecret
var secret kmssetup.MasterSecret
assert.NoError(fileHandler.ReadJSON(constants.MasterSecretFilename, &secret))
assert.NotEmpty(secret.Key)
assert.NotEmpty(secret.Salt)
@ -251,7 +253,7 @@ func TestReadOrGenerateMasterSecret(t *testing.T) {
createFileFunc: func(handler file.Handler) error {
return handler.WriteJSON(
"someSecret",
masterSecret{Key: []byte("constellation-master-secret"), Salt: []byte("constellation-32Byte-length-salt")},
kmssetup.MasterSecret{Key: []byte("constellation-master-secret"), Salt: []byte("constellation-32Byte-length-salt")},
file.OptNone,
)
},
@ -282,7 +284,7 @@ func TestReadOrGenerateMasterSecret(t *testing.T) {
createFileFunc: func(handler file.Handler) error {
return handler.WriteJSON(
"shortSecret",
masterSecret{Key: []byte("constellation-master-secret"), Salt: []byte("short")},
kmssetup.MasterSecret{Key: []byte("constellation-master-secret"), Salt: []byte("short")},
file.OptNone,
)
},
@ -294,7 +296,7 @@ func TestReadOrGenerateMasterSecret(t *testing.T) {
createFileFunc: func(handler file.Handler) error {
return handler.WriteJSON(
"shortSecret",
masterSecret{Key: []byte("short"), Salt: []byte("constellation-32Byte-length-salt")},
kmssetup.MasterSecret{Key: []byte("short"), Salt: []byte("constellation-32Byte-length-salt")},
file.OptNone,
)
},
@ -340,7 +342,7 @@ func TestReadOrGenerateMasterSecret(t *testing.T) {
tc.filename = strings.Trim(filename[1], "\n")
}
var masterSecret masterSecret
var masterSecret kmssetup.MasterSecret
require.NoError(fileHandler.ReadJSON(tc.filename, &masterSecret))
assert.Equal(masterSecret.Key, secret.Key)
assert.Equal(masterSecret.Salt, secret.Salt)

View File

@ -8,7 +8,6 @@ package cmd
import (
"context"
"errors"
"fmt"
"io"
"net"
@ -17,7 +16,6 @@ import (
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto"
"github.com/edgelesssys/constellation/v2/internal/attestation"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
@ -25,6 +23,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
grpcRetry "github.com/edgelesssys/constellation/v2/internal/grpc/retry"
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
"github.com/edgelesssys/constellation/v2/internal/retry"
"github.com/spf13/afero"
"github.com/spf13/cobra"
@ -73,7 +72,7 @@ func (r *recoverCmd) recover(
}
r.log.Debugf("Using flags: %+v", flags)
var masterSecret masterSecret
var masterSecret kmssetup.MasterSecret
r.log.Debugf("Loading master secret file from %s", flags.secretPath)
if err := fileHandler.ReadJSON(flags.secretPath, &masterSecret); err != nil {
return err
@ -97,12 +96,7 @@ func (r *recoverCmd) recover(
r.log.Debugf("Created a new validator")
doer.setDialer(newDialer(validator), flags.endpoint)
r.log.Debugf("Set dialer for endpoint %s", flags.endpoint)
measurementSecret, err := attestation.DeriveMeasurementSecret(masterSecret.Key, masterSecret.Salt)
r.log.Debugf("Derived measurementSecret")
if err != nil {
return err
}
doer.setSecrets(getStateDiskKeyFunc(masterSecret.Key, masterSecret.Salt), measurementSecret)
doer.setURIs(masterSecret.EncodeToURI(), kmssetup.NoStoreURI)
r.log.Debugf("Set secrets")
if err := r.recoverCall(cmd.Context(), cmd.OutOrStdout(), interval, doer); err != nil {
if grpcRetry.ServiceIsUnavailable(err) {
@ -157,15 +151,15 @@ func (r *recoverCmd) recoverCall(ctx context.Context, out io.Writer, interval ti
type recoverDoerInterface interface {
Do(ctx context.Context) error
setDialer(dialer grpcDialer, endpoint string)
setSecrets(getDiskKey func(uuid string) ([]byte, error), measurementSecret []byte)
setURIs(kmsURI, storageURI string)
}
type recoverDoer struct {
dialer grpcDialer
endpoint string
measurementSecret []byte
getDiskKey func(uuid string) (key []byte, err error)
log debugLog
dialer grpcDialer
endpoint string
kmsURI string // encodes masterSecret
storageURI string
log debugLog
}
// Do performs the recover streaming rpc.
@ -177,53 +171,19 @@ func (d *recoverDoer) Do(ctx context.Context) (retErr error) {
d.log.Debugf("Dialed recovery server")
defer conn.Close()
// set up streaming client
protoClient := recoverproto.NewAPIClient(conn)
d.log.Debugf("Created protoClient")
recoverclient, err := protoClient.Recover(ctx)
d.log.Debugf("Created recoverclient")
req := &recoverproto.RecoverMessage{
KmsUri: d.kmsURI,
StorageUri: d.storageURI,
}
_, err = protoClient.Recover(ctx, req)
if err != nil {
return fmt.Errorf("creating client: %w", err)
return fmt.Errorf("calling recover: %w", err)
}
defer func() {
_ = recoverclient.CloseSend()
}()
// send measurement secret as first message
if err := recoverclient.Send(&recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_MeasurementSecret{
MeasurementSecret: d.measurementSecret,
},
}); err != nil {
return fmt.Errorf("sending measurement secret: %w", err)
}
d.log.Debugf("Sent measurement secret")
// receive disk uuid
res, err := recoverclient.Recv()
if err != nil {
return fmt.Errorf("receiving disk uuid: %w", err)
}
d.log.Debugf("Received disk uuid")
stateDiskKey, err := d.getDiskKey(res.DiskUuid)
if err != nil {
return fmt.Errorf("getting state disk key: %w", err)
}
d.log.Debugf("Got state disk key")
// send disk key
if err := recoverclient.Send(&recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_StateDiskKey{
StateDiskKey: stateDiskKey,
},
}); err != nil {
return fmt.Errorf("sending state disk key: %w", err)
}
d.log.Debugf("Sent state disk key")
if _, err := recoverclient.Recv(); err != nil && !errors.Is(err, io.EOF) {
return fmt.Errorf("receiving confirmation: %w", err)
}
d.log.Debugf("Received confirmation")
return nil
}
@ -233,9 +193,9 @@ func (d *recoverDoer) setDialer(dialer grpcDialer, endpoint string) {
d.endpoint = endpoint
}
func (d *recoverDoer) setSecrets(getDiskKey func(string) ([]byte, error), measurementSecret []byte) {
d.getDiskKey = getDiskKey
d.measurementSecret = measurementSecret
func (d *recoverDoer) setURIs(kmsURI, storageURI string) {
d.kmsURI = kmsURI
d.storageURI = storageURI
}
type recoverFlags struct {
@ -281,6 +241,6 @@ func (r *recoverCmd) parseRecoverFlags(cmd *cobra.Command, fileHandler file.Hand
func getStateDiskKeyFunc(masterKey, salt []byte) func(uuid string) ([]byte, error) {
return func(uuid string) ([]byte, error) {
return crypto.DeriveKey(masterKey, salt, []byte(crypto.HKDFInfoPrefix+uuid), crypto.StateDiskKeyLength)
return crypto.DeriveKey(masterKey, salt, []byte(crypto.DEKPrefix+uuid), crypto.StateDiskKeyLength)
}
}

View File

@ -26,6 +26,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
"github.com/edgelesssys/constellation/v2/internal/grpc/testdialer"
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
@ -156,7 +157,7 @@ func TestRecover(t *testing.T) {
require.NoError(fileHandler.WriteJSON(
"constellation-mastersecret.json",
masterSecret{Key: tc.masterSecret.Secret, Salt: tc.masterSecret.Salt},
kmssetup.MasterSecret{Key: tc.masterSecret.Secret, Salt: tc.masterSecret.Salt},
file.OptNone,
))
@ -244,101 +245,16 @@ func TestParseRecoverFlags(t *testing.T) {
}
func TestDoRecovery(t *testing.T) {
someErr := errors.New("error")
testCases := map[string]struct {
recoveryServer *stubRecoveryServer
wantErr bool
}{
"success": {
recoveryServer: &stubRecoveryServer{
actions: [][]func(stream recoverproto.API_RecoverServer) error{{
func(stream recoverproto.API_RecoverServer) error {
_, err := stream.Recv()
return err
},
func(stream recoverproto.API_RecoverServer) error {
return stream.Send(&recoverproto.RecoverResponse{
DiskUuid: "00000000-0000-0000-0000-000000000000",
})
},
func(stream recoverproto.API_RecoverServer) error {
_, err := stream.Recv()
return err
},
}},
},
recoveryServer: &stubRecoveryServer{},
},
"error on first recv": {
recoveryServer: &stubRecoveryServer{
actions: [][]func(stream recoverproto.API_RecoverServer) error{
{
func(stream recoverproto.API_RecoverServer) error {
return someErr
},
},
},
},
wantErr: true,
},
"error on send": {
recoveryServer: &stubRecoveryServer{
actions: [][]func(stream recoverproto.API_RecoverServer) error{
{
func(stream recoverproto.API_RecoverServer) error {
_, err := stream.Recv()
return err
},
func(stream recoverproto.API_RecoverServer) error {
return someErr
},
},
},
},
wantErr: true,
},
"error on second recv": {
recoveryServer: &stubRecoveryServer{
actions: [][]func(stream recoverproto.API_RecoverServer) error{
{
func(stream recoverproto.API_RecoverServer) error {
_, err := stream.Recv()
return err
},
func(stream recoverproto.API_RecoverServer) error {
return stream.Send(&recoverproto.RecoverResponse{
DiskUuid: "00000000-0000-0000-0000-000000000000",
})
},
func(stream recoverproto.API_RecoverServer) error {
return someErr
},
},
},
},
wantErr: true,
},
"final message is an error": {
recoveryServer: &stubRecoveryServer{
actions: [][]func(stream recoverproto.API_RecoverServer) error{{
func(stream recoverproto.API_RecoverServer) error {
_, err := stream.Recv()
return err
},
func(stream recoverproto.API_RecoverServer) error {
return stream.Send(&recoverproto.RecoverResponse{
DiskUuid: "00000000-0000-0000-0000-000000000000",
})
},
func(stream recoverproto.API_RecoverServer) error {
_, err := stream.Recv()
return err
},
func(stream recoverproto.API_RecoverServer) error {
return someErr
},
}},
},
wantErr: true,
"server responds with error": {
recoveryServer: &stubRecoveryServer{recoverError: errors.New("someErr")},
wantErr: true,
},
}
@ -357,13 +273,9 @@ func TestDoRecovery(t *testing.T) {
r := &recoverCmd{log: logger.NewTest(t)}
recoverDoer := &recoverDoer{
dialer: dialer.New(nil, nil, netDialer),
endpoint: addr,
measurementSecret: []byte("measurement-secret"),
getDiskKey: func(string) ([]byte, error) {
return []byte("disk-key"), nil
},
log: r.log,
dialer: dialer.New(nil, nil, netDialer),
endpoint: addr,
log: r.log,
}
err := recoverDoer.Do(context.Background())
@ -402,23 +314,15 @@ func TestDeriveStateDiskKey(t *testing.T) {
}
type stubRecoveryServer struct {
actions [][]func(recoverproto.API_RecoverServer) error
calls int
recoverError error
recoverproto.UnimplementedAPIServer
}
func (s *stubRecoveryServer) Recover(stream recoverproto.API_RecoverServer) error {
if s.calls >= len(s.actions) {
return status.Error(codes.Unavailable, "server is unavailable")
func (s *stubRecoveryServer) Recover(context.Context, *recoverproto.RecoverMessage) (*recoverproto.RecoverResponse, error) {
if s.recoverError != nil {
return nil, s.recoverError
}
s.calls++
for _, action := range s.actions[s.calls-1] {
if err := action(stream); err != nil {
return err
}
}
return nil
return &recoverproto.RecoverResponse{}, nil
}
type stubDoer struct {
@ -437,4 +341,4 @@ func (d *stubDoer) Do(context.Context) error {
func (d *stubDoer) setDialer(grpcDialer, string) {}
func (d *stubDoer) setSecrets(func(string) ([]byte, error), []byte) {}
func (d *stubDoer) setURIs(kmsURI, storageURI string) {}

View File

@ -34,6 +34,7 @@ import (
qemucloud "github.com/edgelesssys/constellation/v2/internal/cloud/qemu"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/internal/role"
tpmClient "github.com/google/go-tpm-tools/client"
@ -151,7 +152,7 @@ func main() {
// set up recovery server if control-plane node
var recoveryServer setup.RecoveryServer
if self.Role == role.ControlPlane {
recoveryServer = recoveryserver.New(issuer, log.Named("recoveryServer"))
recoveryServer = recoveryserver.New(issuer, kmssetup.KMS, log.Named("recoveryServer"))
} else {
recoveryServer = recoveryserver.NewStub(log.Named("recoveryServer"))
}

View File

@ -13,8 +13,10 @@ import (
"github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto"
"github.com/edgelesssys/constellation/v2/internal/atls"
"github.com/edgelesssys/constellation/v2/internal/crypto"
"github.com/edgelesssys/constellation/v2/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/v2/internal/grpc/grpclog"
"github.com/edgelesssys/constellation/v2/internal/kms/kms"
"github.com/edgelesssys/constellation/v2/internal/logger"
"go.uber.org/zap"
"google.golang.org/grpc"
@ -22,6 +24,8 @@ import (
"google.golang.org/grpc/status"
)
type kmsFactory func(ctx context.Context, storageURI string, kmsURI string) (kms.CloudKMS, error)
// RecoveryServer is a gRPC server that can be used by an admin to recover a restarting node.
type RecoveryServer struct {
mux sync.Mutex
@ -30,6 +34,7 @@ type RecoveryServer struct {
stateDiskKey []byte
measurementSecret []byte
grpcServer server
factory kmsFactory
log *logger.Logger
@ -37,9 +42,10 @@ type RecoveryServer struct {
}
// New returns a new RecoveryServer.
func New(issuer atls.Issuer, log *logger.Logger) *RecoveryServer {
func New(issuer atls.Issuer, factory kmsFactory, log *logger.Logger) *RecoveryServer {
server := &RecoveryServer{
log: log,
log: log,
factory: factory,
}
grpcServer := grpc.NewServer(
@ -87,47 +93,32 @@ func (s *RecoveryServer) Serve(ctx context.Context, listener net.Listener, diskU
}
// Recover is a bidirectional streaming RPC that is used to send recovery keys to a restarting node.
func (s *RecoveryServer) Recover(stream recoverproto.API_RecoverServer) error {
func (s *RecoveryServer) Recover(ctx context.Context, req *recoverproto.RecoverMessage) (*recoverproto.RecoverResponse, error) {
s.mux.Lock()
defer s.mux.Unlock()
log := s.log.With(zap.String("peer", grpclog.PeerAddrFromContext(stream.Context())))
log := s.log.With(zap.String("peer", grpclog.PeerAddrFromContext(ctx)))
log.Infof("Received recover call")
msg, err := stream.Recv()
cloudKms, err := s.factory(ctx, req.StorageUri, req.KmsUri)
if err != nil {
return status.Error(codes.Internal, "failed to receive message")
return nil, status.Errorf(codes.Internal, "creating kms client: %s", err)
}
measurementSecret, ok := msg.GetRequest().(*recoverproto.RecoverMessage_MeasurementSecret)
if !ok {
log.Errorf("Received invalid first message: not a measurement secret")
return status.Error(codes.InvalidArgument, "first message is not a measurement secret")
}
if err := stream.Send(&recoverproto.RecoverResponse{DiskUuid: s.diskUUID}); err != nil {
log.With(zap.Error(err)).Errorf("Failed to send disk UUID")
return status.Error(codes.Internal, "failed to send response")
}
msg, err = stream.Recv()
measurementSecret, err := cloudKms.GetDEK(ctx, crypto.DEKPrefix+crypto.MeasurementSecretKeyID, crypto.DerivedKeyLengthDefault)
if err != nil {
log.With(zap.Error(err)).Errorf("Failed to receive disk key")
return status.Error(codes.Internal, "failed to receive message")
return nil, status.Errorf(codes.Internal, "requesting measurementSecret: %s", err)
}
stateDiskKey, ok := msg.GetRequest().(*recoverproto.RecoverMessage_StateDiskKey)
if !ok {
log.Errorf("Received invalid second message: not a state disk key")
return status.Error(codes.InvalidArgument, "second message is not a state disk key")
stateDiskKey, err := cloudKms.GetDEK(ctx, crypto.DEKPrefix+s.diskUUID, crypto.StateDiskKeyLength)
if err != nil {
return nil, status.Errorf(codes.Internal, "requesting stateDiskKey: %s", err)
}
s.stateDiskKey = stateDiskKey.StateDiskKey
s.measurementSecret = measurementSecret.MeasurementSecret
s.stateDiskKey = stateDiskKey
s.measurementSecret = measurementSecret
log.Infof("Received state disk key and measurement secret, shutting down server")
go s.grpcServer.GracefulStop()
return nil
return &recoverproto.RecoverResponse{}, nil
}
// StubServer implements the RecoveryServer interface but does not actually start a server.

View File

@ -8,7 +8,7 @@ package recoveryserver
import (
"context"
"io"
"errors"
"sync"
"testing"
"time"
@ -17,6 +17,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/atls"
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
"github.com/edgelesssys/constellation/v2/internal/grpc/testdialer"
"github.com/edgelesssys/constellation/v2/internal/kms/kms"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/internal/oid"
"github.com/stretchr/testify/assert"
@ -25,14 +26,17 @@ import (
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
goleak.VerifyTestMain(m,
// https://github.com/census-instrumentation/opencensus-go/issues/1262
goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
)
}
func TestServe(t *testing.T) {
assert := assert.New(t)
log := logger.NewTest(t)
uuid := "uuid"
server := New(atls.NewFakeIssuer(oid.Dummy{}), log)
server := New(atls.NewFakeIssuer(oid.Dummy{}), newStubKMS(nil, nil), log)
dialer := testdialer.NewBufconnDialer()
listener := dialer.GetListener("192.0.2.1:1234")
ctx, cancel := context.WithCancel(context.Background())
@ -49,7 +53,7 @@ func TestServe(t *testing.T) {
cancel()
wg.Wait()
server = New(atls.NewFakeIssuer(oid.Dummy{}), log)
server = New(atls.NewFakeIssuer(oid.Dummy{}), newStubKMS(nil, nil), log)
dialer = testdialer.NewBufconnDialer()
listener = dialer.GetListener("192.0.2.1:1234")
@ -71,59 +75,26 @@ func TestServe(t *testing.T) {
func TestRecover(t *testing.T) {
testCases := map[string]struct {
initialMsg message
keyMsg message
kmsURI string
storageURI string
factory kmsFactory
wantErr bool
}{
"success": {
initialMsg: message{
recoverMsg: &recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_MeasurementSecret{
MeasurementSecret: []byte("measurementSecret"),
},
},
},
keyMsg: message{
recoverMsg: &recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_StateDiskKey{
StateDiskKey: []byte("diskKey"),
},
},
},
// base64 encoded: key=masterkey&salt=somesalt
kmsURI: "kms://cluster-kms?key=bWFzdGVya2V5&salt=c29tZXNhbHQ=",
storageURI: "storage://no-store",
factory: newStubKMS(nil, nil),
},
"first message is not a measurement secret": {
initialMsg: message{
recoverMsg: &recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_StateDiskKey{
StateDiskKey: []byte("diskKey"),
},
},
wantErr: true,
},
keyMsg: message{
recoverMsg: &recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_StateDiskKey{
StateDiskKey: []byte("diskKey"),
},
},
},
"kms init fails": {
factory: newStubKMS(errors.New("setup failed"), nil),
wantErr: true,
},
"second message is not a state disk key": {
initialMsg: message{
recoverMsg: &recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_MeasurementSecret{
MeasurementSecret: []byte("measurementSecret"),
},
},
},
keyMsg: message{
recoverMsg: &recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_MeasurementSecret{
MeasurementSecret: []byte("measurementSecret"),
},
},
wantErr: true,
},
"GetDEK fails": {
kmsURI: "kms://cluster-kms?key=bWFzdGVya2V5&salt=c29tZXNhbHQ=",
storageURI: "storage://no-store",
factory: newStubKMS(nil, errors.New("GetDEK failed")),
wantErr: true,
},
}
@ -134,7 +105,7 @@ func TestRecover(t *testing.T) {
ctx := context.Background()
serverUUID := "uuid"
server := New(atls.NewFakeIssuer(oid.Dummy{}), logger.NewTest(t))
server := New(atls.NewFakeIssuer(oid.Dummy{}), tc.factory, logger.NewTest(t))
netDialer := testdialer.NewBufconnDialer()
listener := netDialer.GetListener("192.0.2.1:1234")
@ -154,41 +125,46 @@ func TestRecover(t *testing.T) {
conn, err := dialer.New(nil, nil, netDialer).Dial(ctx, "192.0.2.1:1234")
require.NoError(err)
defer conn.Close()
client, err := recoverproto.NewAPIClient(conn).Recover(ctx)
require.NoError(err)
// Send initial message
err = client.Send(tc.initialMsg.recoverMsg)
require.NoError(err)
req := recoverproto.RecoverMessage{
KmsUri: tc.kmsURI,
StorageUri: tc.storageURI,
}
_, err = recoverproto.NewAPIClient(conn).Recover(ctx, &req)
// Receive uuid
uuid, err := client.Recv()
if tc.initialMsg.wantErr {
if tc.wantErr {
assert.Error(err)
return
}
assert.Equal(serverUUID, uuid.DiskUuid)
// Send key message
err = client.Send(tc.keyMsg.recoverMsg)
require.NoError(err)
_, err = client.Recv()
if tc.keyMsg.wantErr {
assert.Error(err)
return
}
assert.ErrorIs(io.EOF, err)
wg.Wait()
assert.NoError(serveErr)
assert.Equal(tc.initialMsg.recoverMsg.GetMeasurementSecret(), measurementSecret)
assert.Equal(tc.keyMsg.recoverMsg.GetStateDiskKey(), diskKey)
require.NoError(serveErr)
assert.NoError(err)
assert.NotNil(measurementSecret)
assert.NotNil(diskKey)
})
}
}
type message struct {
recoverMsg *recoverproto.RecoverMessage
wantErr bool
func newStubKMS(setupErr, getDEKErr error) kmsFactory {
return func(ctx context.Context, storageURI string, kmsURI string) (kms.CloudKMS, error) {
if setupErr != nil {
return nil, setupErr
}
return &stubKMS{getDEKErr: getDEKErr}, nil
}
}
type stubKMS struct {
getDEKErr error
}
func (s *stubKMS) CreateKEK(ctx context.Context, keyID string, kek []byte) error {
return nil
}
func (s *stubKMS) GetDEK(ctx context.Context, dekID string, dekSize int) ([]byte, error) {
if s.getDEKErr != nil {
return nil, s.getDEKErr
}
return []byte("someDEK"), nil
}

View File

@ -25,11 +25,10 @@ type RecoverMessage struct {
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// Types that are assignable to Request:
//
// *RecoverMessage_StateDiskKey
// *RecoverMessage_MeasurementSecret
Request isRecoverMessage_Request `protobuf_oneof:"request"`
// bytes state_disk_key = 1; removed
// bytes measurement_secret = 2; removed
KmsUri string `protobuf:"bytes,3,opt,name=kms_uri,json=kmsUri,proto3" json:"kms_uri,omitempty"`
StorageUri string `protobuf:"bytes,4,opt,name=storage_uri,json=storageUri,proto3" json:"storage_uri,omitempty"`
}
func (x *RecoverMessage) Reset() {
@ -64,49 +63,24 @@ func (*RecoverMessage) Descriptor() ([]byte, []int) {
return file_recover_proto_rawDescGZIP(), []int{0}
}
func (m *RecoverMessage) GetRequest() isRecoverMessage_Request {
if m != nil {
return m.Request
func (x *RecoverMessage) GetKmsUri() string {
if x != nil {
return x.KmsUri
}
return nil
return ""
}
func (x *RecoverMessage) GetStateDiskKey() []byte {
if x, ok := x.GetRequest().(*RecoverMessage_StateDiskKey); ok {
return x.StateDiskKey
func (x *RecoverMessage) GetStorageUri() string {
if x != nil {
return x.StorageUri
}
return nil
return ""
}
func (x *RecoverMessage) GetMeasurementSecret() []byte {
if x, ok := x.GetRequest().(*RecoverMessage_MeasurementSecret); ok {
return x.MeasurementSecret
}
return nil
}
type isRecoverMessage_Request interface {
isRecoverMessage_Request()
}
type RecoverMessage_StateDiskKey struct {
StateDiskKey []byte `protobuf:"bytes,1,opt,name=state_disk_key,json=stateDiskKey,proto3,oneof"`
}
type RecoverMessage_MeasurementSecret struct {
MeasurementSecret []byte `protobuf:"bytes,2,opt,name=measurement_secret,json=measurementSecret,proto3,oneof"`
}
func (*RecoverMessage_StateDiskKey) isRecoverMessage_Request() {}
func (*RecoverMessage_MeasurementSecret) isRecoverMessage_Request() {}
type RecoverResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
DiskUuid string `protobuf:"bytes,1,opt,name=disk_uuid,json=diskUuid,proto3" json:"disk_uuid,omitempty"`
}
func (x *RecoverResponse) Reset() {
@ -141,39 +115,27 @@ func (*RecoverResponse) Descriptor() ([]byte, []int) {
return file_recover_proto_rawDescGZIP(), []int{1}
}
func (x *RecoverResponse) GetDiskUuid() string {
if x != nil {
return x.DiskUuid
}
return ""
}
var File_recover_proto protoreflect.FileDescriptor
var file_recover_proto_rawDesc = []byte{
0x0a, 0x0d, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12,
0x0c, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x74, 0x0a,
0x0c, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x4a, 0x0a,
0x0e, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12,
0x26, 0x0a, 0x0e, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5f, 0x64, 0x69, 0x73, 0x6b, 0x5f, 0x6b, 0x65,
0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x48, 0x00, 0x52, 0x0c, 0x73, 0x74, 0x61, 0x74, 0x65,
0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79, 0x12, 0x2f, 0x0a, 0x12, 0x6d, 0x65, 0x61, 0x73, 0x75,
0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20,
0x01, 0x28, 0x0c, 0x48, 0x00, 0x52, 0x11, 0x6d, 0x65, 0x61, 0x73, 0x75, 0x72, 0x65, 0x6d, 0x65,
0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x42, 0x09, 0x0a, 0x07, 0x72, 0x65, 0x71, 0x75,
0x65, 0x73, 0x74, 0x22, 0x2e, 0x0a, 0x0f, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x52, 0x65,
0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x64, 0x69, 0x73, 0x6b, 0x5f, 0x75,
0x75, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x64, 0x69, 0x73, 0x6b, 0x55,
0x75, 0x69, 0x64, 0x32, 0x53, 0x0a, 0x03, 0x41, 0x50, 0x49, 0x12, 0x4c, 0x0a, 0x07, 0x52, 0x65,
0x63, 0x6f, 0x76, 0x65, 0x72, 0x12, 0x1c, 0x2e, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x4d, 0x65, 0x73, 0x73,
0x61, 0x67, 0x65, 0x1a, 0x1d, 0x2e, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
0x73, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x42, 0x5a, 0x40, 0x67, 0x69, 0x74, 0x68,
0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x65, 0x64, 0x67, 0x65, 0x6c, 0x65, 0x73, 0x73, 0x73,
0x79, 0x73, 0x2f, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x65, 0x6c, 0x6c, 0x61, 0x74, 0x69, 0x6f, 0x6e,
0x2f, 0x76, 0x32, 0x2f, 0x64, 0x69, 0x73, 0x6b, 0x2d, 0x6d, 0x61, 0x70, 0x70, 0x65, 0x72, 0x2f,
0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x33,
0x17, 0x0a, 0x07, 0x6b, 0x6d, 0x73, 0x5f, 0x75, 0x72, 0x69, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09,
0x52, 0x06, 0x6b, 0x6d, 0x73, 0x55, 0x72, 0x69, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x74, 0x6f, 0x72,
0x61, 0x67, 0x65, 0x5f, 0x75, 0x72, 0x69, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x73,
0x74, 0x6f, 0x72, 0x61, 0x67, 0x65, 0x55, 0x72, 0x69, 0x22, 0x11, 0x0a, 0x0f, 0x52, 0x65, 0x63,
0x6f, 0x76, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x32, 0x4f, 0x0a, 0x03,
0x41, 0x50, 0x49, 0x12, 0x48, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x12, 0x1c,
0x2e, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65,
0x63, 0x6f, 0x76, 0x65, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1d, 0x2e, 0x72,
0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f,
0x76, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x42, 0x5a,
0x40, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x65, 0x64, 0x67, 0x65,
0x6c, 0x65, 0x73, 0x73, 0x73, 0x79, 0x73, 0x2f, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x65, 0x6c, 0x6c,
0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x76, 0x32, 0x2f, 0x64, 0x69, 0x73, 0x6b, 0x2d, 0x6d, 0x61,
0x70, 0x70, 0x65, 0x72, 0x2f, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
@ -234,10 +196,6 @@ func file_recover_proto_init() {
}
}
}
file_recover_proto_msgTypes[0].OneofWrappers = []interface{}{
(*RecoverMessage_StateDiskKey)(nil),
(*RecoverMessage_MeasurementSecret)(nil),
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{

View File

@ -5,16 +5,19 @@ package recoverproto;
option go_package = "github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto";
service API {
rpc Recover(stream RecoverMessage) returns (stream RecoverResponse) {}
// Recover sends the necessary information to the recoveryserver to start recovering the cluster.
rpc Recover(RecoverMessage) returns (RecoverResponse) {}
}
message RecoverMessage {
oneof request {
bytes state_disk_key = 1;
bytes measurement_secret = 2;
}
// bytes state_disk_key = 1; removed
// bytes measurement_secret = 2; removed
// kms_uri is the URI of the KMS the recoveryserver should use to decrypt DEKs.
string kms_uri = 3;
// storage_uri is the URI of the storage location the recoveryserver should use to fetch DEKs.
string storage_uri = 4;
}
message RecoverResponse {
string disk_uuid = 1;
// string disk_uuid = 1; removed
}

View File

@ -22,7 +22,7 @@ const _ = grpc.SupportPackageIsVersion7
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type APIClient interface {
Recover(ctx context.Context, opts ...grpc.CallOption) (API_RecoverClient, error)
Recover(ctx context.Context, in *RecoverMessage, opts ...grpc.CallOption) (*RecoverResponse, error)
}
type aPIClient struct {
@ -33,42 +33,20 @@ func NewAPIClient(cc grpc.ClientConnInterface) APIClient {
return &aPIClient{cc}
}
func (c *aPIClient) Recover(ctx context.Context, opts ...grpc.CallOption) (API_RecoverClient, error) {
stream, err := c.cc.NewStream(ctx, &API_ServiceDesc.Streams[0], "/recoverproto.API/Recover", opts...)
func (c *aPIClient) Recover(ctx context.Context, in *RecoverMessage, opts ...grpc.CallOption) (*RecoverResponse, error) {
out := new(RecoverResponse)
err := c.cc.Invoke(ctx, "/recoverproto.API/Recover", in, out, opts...)
if err != nil {
return nil, err
}
x := &aPIRecoverClient{stream}
return x, nil
}
type API_RecoverClient interface {
Send(*RecoverMessage) error
Recv() (*RecoverResponse, error)
grpc.ClientStream
}
type aPIRecoverClient struct {
grpc.ClientStream
}
func (x *aPIRecoverClient) Send(m *RecoverMessage) error {
return x.ClientStream.SendMsg(m)
}
func (x *aPIRecoverClient) Recv() (*RecoverResponse, error) {
m := new(RecoverResponse)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
return out, nil
}
// APIServer is the server API for API service.
// All implementations must embed UnimplementedAPIServer
// for forward compatibility
type APIServer interface {
Recover(API_RecoverServer) error
Recover(context.Context, *RecoverMessage) (*RecoverResponse, error)
mustEmbedUnimplementedAPIServer()
}
@ -76,8 +54,8 @@ type APIServer interface {
type UnimplementedAPIServer struct {
}
func (UnimplementedAPIServer) Recover(API_RecoverServer) error {
return status.Errorf(codes.Unimplemented, "method Recover not implemented")
func (UnimplementedAPIServer) Recover(context.Context, *RecoverMessage) (*RecoverResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Recover not implemented")
}
func (UnimplementedAPIServer) mustEmbedUnimplementedAPIServer() {}
@ -92,30 +70,22 @@ func RegisterAPIServer(s grpc.ServiceRegistrar, srv APIServer) {
s.RegisterService(&API_ServiceDesc, srv)
}
func _API_Recover_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(APIServer).Recover(&aPIRecoverServer{stream})
}
type API_RecoverServer interface {
Send(*RecoverResponse) error
Recv() (*RecoverMessage, error)
grpc.ServerStream
}
type aPIRecoverServer struct {
grpc.ServerStream
}
func (x *aPIRecoverServer) Send(m *RecoverResponse) error {
return x.ServerStream.SendMsg(m)
}
func (x *aPIRecoverServer) Recv() (*RecoverMessage, error) {
m := new(RecoverMessage)
if err := x.ServerStream.RecvMsg(m); err != nil {
func _API_Recover_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RecoverMessage)
if err := dec(in); err != nil {
return nil, err
}
return m, nil
if interceptor == nil {
return srv.(APIServer).Recover(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/recoverproto.API/Recover",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(APIServer).Recover(ctx, req.(*RecoverMessage))
}
return interceptor(ctx, in, info, handler)
}
// API_ServiceDesc is the grpc.ServiceDesc for API service.
@ -124,14 +94,12 @@ func (x *aPIRecoverServer) Recv() (*RecoverMessage, error) {
var API_ServiceDesc = grpc.ServiceDesc{
ServiceName: "recoverproto.API",
HandlerType: (*APIServer)(nil),
Methods: []grpc.MethodDesc{},
Streams: []grpc.StreamDesc{
Methods: []grpc.MethodDesc{
{
StreamName: "Recover",
Handler: _API_Recover_Handler,
ServerStreams: true,
ClientStreams: true,
MethodName: "Recover",
Handler: _API_Recover_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "recover.proto",
}

View File

@ -20,10 +20,5 @@ const (
// DeriveClusterID derives the cluster ID from a salt and secret value.
func DeriveClusterID(secret, salt []byte) ([]byte, error) {
return crypto.DeriveKey(secret, salt, []byte(crypto.HKDFInfoPrefix+clusterIDContext), crypto.DerivedKeyLengthDefault)
}
// DeriveMeasurementSecret derives the secret value needed to derive ClusterID.
func DeriveMeasurementSecret(masterSecret, salt []byte) ([]byte, error) {
return crypto.DeriveKey(masterSecret, salt, []byte(crypto.HKDFInfoPrefix+MeasurementSecretContext), crypto.DerivedKeyLengthDefault)
return crypto.DeriveKey(secret, salt, []byte(crypto.DEKPrefix+clusterIDContext), crypto.DerivedKeyLengthDefault)
}

View File

@ -31,21 +31,3 @@ func TestDeriveClusterID(t *testing.T) {
require.NoError(err)
assert.NotEqual(clusterID, clusterIDdiff)
}
func TestDeriveMeasurementSecret(t *testing.T) {
require := require.New(t)
assert := assert.New(t)
testvector := testvector.HKDFMeasurementSecret
measurementSecret, err := DeriveMeasurementSecret(testvector.Secret, testvector.Salt)
require.NoError(err)
assert.Equal(testvector.Output, measurementSecret)
measurementSecretdiff, err := DeriveMeasurementSecret(testvector.Secret, []byte("different-salt"))
require.NoError(err)
assert.NotEqual(measurementSecret, measurementSecretdiff)
measurementSecretdiff, err = DeriveMeasurementSecret([]byte("different-secret"), testvector.Salt)
require.NoError(err)
assert.NotEqual(measurementSecret, measurementSecretdiff)
}

View File

@ -30,8 +30,10 @@ const (
MasterSecretLengthMin = 16
// RNGLengthDefault is the number of bytes used for generating nonces.
RNGLengthDefault = 32
// HKDFInfoPrefix is the prefix used for the info parameter in HKDF.
HKDFInfoPrefix = "key-"
// DEKPrefix is the prefix used to prefix DEK IDs. Originally introduced as a requirement for the HKDF info parameter.
DEKPrefix = "key-"
// MeasurementSecretKeyID is name used for the measurementSecret DEK.
MeasurementSecretKeyID = "measurementSecret"
)
// DeriveKey derives a key from a secret.

View File

@ -56,13 +56,14 @@ type KMSClient struct {
awsClient ClientAPI
policyProducer KeyPolicyProducer
storage kmsInterface.Storage
kekID string
}
// New creates and initializes a new KMSClient for AWS.
//
// The parameter client needs to be initialized with valid AWS credentials (https://aws.github.io/aws-sdk-go-v2/docs/getting-started).
// If storage is nil, the default MemMapStorage is used.
func New(ctx context.Context, policyProducer KeyPolicyProducer, store kmsInterface.Storage, optFns ...func(*awsconfig.LoadOptions) error) (*KMSClient, error) {
func New(ctx context.Context, policyProducer KeyPolicyProducer, store kmsInterface.Storage, kekID string, optFns ...func(*awsconfig.LoadOptions) error) (*KMSClient, error) {
if store == nil {
store = storage.NewMemMapStorage()
}
@ -77,6 +78,7 @@ func New(ctx context.Context, policyProducer KeyPolicyProducer, store kmsInterfa
awsClient: client,
policyProducer: policyProducer,
storage: store,
kekID: kekID,
}, nil
}
@ -206,9 +208,9 @@ func (c *KMSClient) CreateKEK(ctx context.Context, keyID string, key []byte) err
}
// GetDEK returns the DEK for dekID and kekID from the KMS.
func (c *KMSClient) GetDEK(ctx context.Context, kekID, keyID string, dekSize int) ([]byte, error) {
func (c *KMSClient) GetDEK(ctx context.Context, keyID string, dekSize int) ([]byte, error) {
// The KEK should be identified by its alias. The alias always has the same scheme: 'alias/<kekId>'
kekID = "alias/" + kekID
kekID := "alias/" + c.kekID
// If a key for keyID exists in the storage, decrypt the key using the KEK.
dek, err := c.decryptDEKFromStorage(ctx, kekID, keyID)

View File

@ -253,22 +253,23 @@ func TestAWSKMSClient(t *testing.T) {
awsClient := &fakeAWSClient{kekPool: make(map[string][]byte, 2)}
client := &KMSClient{
awsClient: awsClient,
policyProducer: &stubKeyPolicyProducer{},
storage: storage.NewMemMapStorage(),
}
awsClient.keyIDCount = -1
testKEK1ID := "testKEK1"
testKEK1 := []byte("test KEK")
testKEK2ID := "testKEK2"
testKEK2 := []byte("more test KEK")
client := &KMSClient{
awsClient: awsClient,
policyProducer: &stubKeyPolicyProducer{},
storage: storage.NewMemMapStorage(),
kekID: testKEK1ID,
}
awsClient.keyIDCount = -1
ctx := context.Background()
// try to get a DEK before setting the KEK
_, err := client.GetDEK(ctx, testKEK1ID, "volume01", config.SymmetricKeyLength)
_, err := client.GetDEK(ctx, "volume01", config.SymmetricKeyLength)
assert.Error(err)
assert.ErrorIs(err, kmsInterface.ErrKEKUnknown)
@ -285,16 +286,16 @@ func TestAWSKMSClient(t *testing.T) {
assert.Equal(testKEK2, awsClient.kekPool[strconv.Itoa(awsClient.keyIDCount)])
// test GetDEK method
dek1, err := client.GetDEK(ctx, testKEK1ID, "volume01", config.SymmetricKeyLength)
dek1, err := client.GetDEK(ctx, "volume01", config.SymmetricKeyLength)
assert.NoError(err)
dek2, err := client.GetDEK(ctx, testKEK2ID, "volume02", config.SymmetricKeyLength)
dek2, err := client.GetDEK(ctx, "volume02", config.SymmetricKeyLength)
assert.NoError(err)
// make sure that GetDEK is idempotent
dek1Copy, err := client.GetDEK(ctx, testKEK1ID, "volume01", config.SymmetricKeyLength)
dek1Copy, err := client.GetDEK(ctx, "volume01", config.SymmetricKeyLength)
assert.NoError(err)
assert.Equal(dek1, dek1Copy)
dek2Copy, err := client.GetDEK(ctx, testKEK2ID, "volume02", config.SymmetricKeyLength)
dek2Copy, err := client.GetDEK(ctx, "volume02", config.SymmetricKeyLength)
assert.NoError(err)
assert.Equal(dek2, dek2Copy)
}

View File

@ -47,6 +47,7 @@ type kmsClientAPI interface {
type KMSClient struct {
client kmsClientAPI
storage kms.Storage
kekID string
}
// Opts are optional settings for AKV clients.
@ -57,7 +58,7 @@ type Opts struct {
}
// New initializes a KMS client for Azure Key Vault.
func New(ctx context.Context, vaultName string, vaultType VaultSuffix, store kms.Storage, opts *Opts) (*KMSClient, error) {
func New(ctx context.Context, vaultName string, vaultType VaultSuffix, store kms.Storage, kekID string, opts *Opts) (*KMSClient, error) {
if opts == nil {
opts = &Opts{}
}
@ -80,7 +81,7 @@ func New(ctx context.Context, vaultName string, vaultType VaultSuffix, store kms
if store == nil {
store = storage.NewMemMapStorage()
}
return &KMSClient{client: client, storage: store}, nil
return &KMSClient{client: client, storage: store, kekID: kekID}, nil
}
// CreateKEK saves a new Key Encryption Key using Azure Key Vault.
@ -111,8 +112,8 @@ func (c *KMSClient) CreateKEK(ctx context.Context, keyID string, key []byte) err
}
// GetDEK decrypts a DEK from storage.
func (c *KMSClient) GetDEK(ctx context.Context, kekID, keyID string, dekSize int) ([]byte, error) {
kek, err := c.getKEK(ctx, kekID)
func (c *KMSClient) GetDEK(ctx context.Context, keyID string, dekSize int) ([]byte, error) {
kek, err := c.getKEK(ctx, c.kekID)
if err != nil {
return nil, fmt.Errorf("loading KEK from key vault: %w", err)
}

View File

@ -155,7 +155,7 @@ func TestKMSGetDEK(t *testing.T) {
storage: tc.storage,
}
dek, err := client.GetDEK(context.Background(), "test-key", "volume-01", 32)
dek, err := client.GetDEK(context.Background(), "volume-01", 32)
if tc.wantErr {
assert.Error(err)
} else {

View File

@ -38,10 +38,11 @@ type HSMClient struct {
client hsmClientAPI
storage kms.Storage
vaultURL string
kekID string
}
// NewHSM initializes a KMS client for Azure manged HSM Key Vault.
func NewHSM(ctx context.Context, vaultName string, store kms.Storage, opts *Opts) (*HSMClient, error) {
func NewHSM(ctx context.Context, vaultName string, store kms.Storage, kekID string, opts *Opts) (*HSMClient, error) {
if opts == nil {
opts = &Opts{}
}
@ -72,6 +73,7 @@ func NewHSM(ctx context.Context, vaultName string, store kms.Storage, opts *Opts
client: client,
credentials: cred,
storage: store,
kekID: kekID,
}, nil
}
@ -114,7 +116,7 @@ func (c *HSMClient) CreateKEK(ctx context.Context, keyID string, key []byte) err
}
// GetDEK loads an encrypted DEK from storage and unwraps it using an HSM-backed key.
func (c *HSMClient) GetDEK(ctx context.Context, kekID string, keyID string, dekSize int) ([]byte, error) {
func (c *HSMClient) GetDEK(ctx context.Context, keyID string, dekSize int) ([]byte, error) {
encryptedDEK, err := c.storage.Get(ctx, keyID)
if err != nil {
if !errors.Is(err, storage.ErrDEKUnset) {
@ -126,7 +128,7 @@ func (c *HSMClient) GetDEK(ctx context.Context, kekID string, keyID string, dekS
if err != nil {
return nil, fmt.Errorf("key generation: %w", err)
}
if err := c.putDEK(ctx, kekID, keyID, newDEK); err != nil {
if err := c.putDEK(ctx, c.kekID, keyID, newDEK); err != nil {
return nil, fmt.Errorf("creating new DEK: %w", err)
}
@ -137,7 +139,7 @@ func (c *HSMClient) GetDEK(ctx context.Context, kekID string, keyID string, dekS
Algorithm: to.Ptr(azkeys.JSONWebKeyEncryptionAlgorithmA256KW),
Value: encryptedDEK,
}
res, err := c.client.UnwrapKey(ctx, kekID, "", params, &azkeys.UnwrapKeyOptions{})
res, err := c.client.UnwrapKey(ctx, c.kekID, "", params, &azkeys.UnwrapKeyOptions{})
if err != nil {
return nil, fmt.Errorf("unwrapping key: %w", err)
}

View File

@ -165,7 +165,7 @@ func TestHSMGetNewDEK(t *testing.T) {
storage: tc.storage,
}
dek, err := client.GetDEK(context.Background(), "test-key", "volume-01", 32)
dek, err := client.GetDEK(context.Background(), "volume-01", 32)
if tc.wantErr {
assert.Error(err)
} else {
@ -208,7 +208,7 @@ func TestHSMGetExistingDEK(t *testing.T) {
storage: storage,
}
dek, err := client.GetDEK(context.Background(), "test-key", keyID, len(testKey))
dek, err := client.GetDEK(context.Background(), keyID, len(testKey))
if tc.wantErr {
assert.Error(err)
} else {

View File

@ -20,8 +20,15 @@ type KMS struct {
}
// New creates a new ClusterKMS.
func New(salt []byte) *KMS {
return &KMS{salt: salt}
func New(key []byte, salt []byte) (*KMS, error) {
if len(key) == 0 {
return nil, errors.New("missing master key")
}
if len(salt) == 0 {
return nil, errors.New("missing salt")
}
return &KMS{masterKey: key, salt: salt}, nil
}
// CreateKEK sets the ClusterKMS masterKey.
@ -31,7 +38,7 @@ func (c *KMS) CreateKEK(ctx context.Context, keyID string, kek []byte) error {
}
// GetDEK derives a key from the KMS masterKey.
func (c *KMS) GetDEK(ctx context.Context, kekID string, dekID string, dekSize int) ([]byte, error) {
func (c *KMS) GetDEK(ctx context.Context, dekID string, dekSize int) ([]byte, error) {
if len(c.masterKey) == 0 {
return nil, errors.New("master key not set for Constellation KMS")
}

View File

@ -24,19 +24,12 @@ func TestMain(m *testing.M) {
func TestClusterKMS(t *testing.T) {
testVector := testvector.HKDF0xFF
assert := assert.New(t)
kms := New(testVector.Salt)
key, err := kms.GetDEK(context.Background(), "", "key-1", 32)
assert.Error(err)
assert.Nil(key)
err = kms.CreateKEK(context.Background(), "", testVector.Secret)
assert.NoError(err)
assert.Equal(testVector.Secret, kms.masterKey)
require := require.New(t)
kms, err := New(testVector.Secret, testVector.Salt)
require.NoError(err)
keyLower, err := kms.GetDEK(
context.Background(),
"",
strings.ToLower(testVector.InfoPrefix+testVector.Info),
int(testVector.Length),
)
@ -46,12 +39,11 @@ func TestClusterKMS(t *testing.T) {
// output of the KMS should be case sensitive
keyUpper, err := kms.GetDEK(
context.Background(),
"",
strings.ToUpper(testVector.InfoPrefix+testVector.Info),
int(testVector.Length),
)
assert.NoError(err)
assert.NotEqual(key, keyUpper)
assert.NotEqual(keyLower, keyUpper)
}
func TestVectorsHKDF(t *testing.T) {
@ -61,6 +53,7 @@ func TestVectorsHKDF(t *testing.T) {
dekID string
dekSize uint
wantKey []byte
wantErr bool
}{
"rfc Test Case 1": {
kek: testvector.HKDFrfc1.Secret,
@ -82,6 +75,7 @@ func TestVectorsHKDF(t *testing.T) {
dekID: testvector.HKDFrfc3.Info,
dekSize: testvector.HKDFrfc3.Length,
wantKey: testvector.HKDFrfc3.Output,
wantErr: true,
},
"HKDF zero": {
kek: testvector.HKDFZero.Secret,
@ -104,10 +98,15 @@ func TestVectorsHKDF(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
kms := New(tc.salt)
kms, err := New(tc.kek, tc.salt)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
require.NoError(kms.CreateKEK(context.Background(), "", tc.kek))
out, err := kms.GetDEK(context.Background(), "", tc.dekID, int(tc.dekSize))
out, err := kms.GetDEK(context.Background(), tc.dekID, int(tc.dekSize))
require.NoError(err)
assert.Equal(tc.wantKey, out)
})

View File

@ -50,11 +50,12 @@ type KMSClient struct {
waitBackoffLimit int
storage kmsInterface.Storage
protectionLevel kmspb.ProtectionLevel
kekID string
opts []gax.CallOption
}
// New initializes a KMS client for Google Cloud Platform.
func New(ctx context.Context, projectID, locationID, keyRingID string, store kmsInterface.Storage, protectionLvl kmspb.ProtectionLevel, opts ...gax.CallOption) (*KMSClient, error) {
func New(ctx context.Context, projectID, locationID, keyRingID string, store kmsInterface.Storage, protectionLvl kmspb.ProtectionLevel, kekID string, opts ...gax.CallOption) (*KMSClient, error) {
if store == nil {
store = storage.NewMemMapStorage()
}
@ -71,6 +72,7 @@ func New(ctx context.Context, projectID, locationID, keyRingID string, store kms
waitBackoffLimit: 10,
storage: store,
protectionLevel: protectionLvl,
kekID: kekID,
opts: opts,
}
@ -108,7 +110,7 @@ func (c *KMSClient) CreateKEK(ctx context.Context, keyID string, key []byte) err
}
// GetDEK fetches an encrypted Data Encryption Key from storage and decrypts it using a KEK stored in Google's KMS.
func (c *KMSClient) GetDEK(ctx context.Context, kekID, keyID string, dekSize int) ([]byte, error) {
func (c *KMSClient) GetDEK(ctx context.Context, keyID string, dekSize int) ([]byte, error) {
client, err := c.newClient(ctx)
if err != nil {
return nil, err
@ -126,11 +128,11 @@ func (c *KMSClient) GetDEK(ctx context.Context, kekID, keyID string, dekSize int
if err != nil {
return nil, fmt.Errorf("key generation: %w", err)
}
return newDEK, c.putDEK(ctx, client, kekID, keyID, newDEK)
return newDEK, c.putDEK(ctx, client, c.kekID, keyID, newDEK)
}
request := &kmspb.DecryptRequest{
Name: fmt.Sprintf("projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s", c.projectID, c.locationID, c.keyRingID, kekID),
Name: fmt.Sprintf("projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s", c.projectID, c.locationID, c.keyRingID, c.kekID),
Ciphertext: encryptedDEK,
}

View File

@ -321,7 +321,7 @@ func TestGetDEK(t *testing.T) {
storage: tc.storage,
}
dek, err := client.GetDEK(context.Background(), "test-key", "volume-01", 32)
dek, err := client.GetDEK(context.Background(), "volume-01", 32)
if tc.wantErr {
assert.Error(err)
} else {

View File

@ -17,7 +17,7 @@ type CloudKMS interface {
CreateKEK(ctx context.Context, keyID string, kek []byte) error
// GetDEK returns the DEK for dekID and kekID from the KMS.
// If the DEK does not exist, a new one is created and saved to storage.
GetDEK(ctx context.Context, kekID string, dekID string, dekSize int) ([]byte, error)
GetDEK(ctx context.Context, dekID string, dekSize int) ([]byte, error)
}
// Storage provides an abstract interface for the storage backend used for DEKs.

View File

@ -24,10 +24,10 @@ import (
// Well known endpoints for KMS services.
const (
AWSKMSURI = "kms://aws?keyPolicy=%s"
AzureKMSURI = "kms://azure-kms?name=%s&type=%s"
AzureHSMURI = "kms://azure-hsm?name=%s"
GCPKMSURI = "kms://gcp?project=%s&location=%s&keyRing=%s&protectionLvl=%s"
AWSKMSURI = "kms://aws?keyPolicy=%s&kekID=%s"
AzureKMSURI = "kms://azure-kms?name=%s&type=%s&kekID=%s"
AzureHSMURI = "kms://azure-hsm?name=%s&kekID=%s"
GCPKMSURI = "kms://gcp?project=%s&location=%s&keyRing=%s&protectionLvl=%s&kekID=%s"
ClusterKMSURI = "kms://cluster-kms?key=%s&salt=%s"
AWSS3URI = "storage://aws?bucket=%s"
AzureBlobURI = "storage://azure?container=%s&connectionString=%s"
@ -35,6 +35,21 @@ const (
NoStoreURI = "storage://no-store"
)
// MasterSecret holds the master key and salt for deriving keys.
type MasterSecret struct {
Key []byte `json:"key"`
Salt []byte `json:"salt"`
}
// EncodeToURI returns an URI encoding the master secret.
func (m *MasterSecret) EncodeToURI() string {
return fmt.Sprintf(
ClusterKMSURI,
base64.URLEncoding.EncodeToString(m.Key),
base64.URLEncoding.EncodeToString(m.Salt),
)
}
// KMSInformation about an existing KMS.
type KMSInformation struct {
KMSURI string
@ -104,39 +119,39 @@ func getKMS(ctx context.Context, kmsURI string, store kms.Storage) (kms.CloudKMS
switch uri.Host {
case "aws":
poliyProducer, err := getAWSKMSConfig(uri)
poliyProducer, kekID, err := getAWSKMSConfig(uri)
if err != nil {
return nil, err
}
return aws.New(ctx, poliyProducer, store)
return aws.New(ctx, poliyProducer, store, kekID)
case "azure-kms":
vaultName, vaultType, err := getAzureKMSConfig(uri)
vaultName, vaultType, kekID, err := getAzureKMSConfig(uri)
if err != nil {
return nil, err
}
return azure.New(ctx, vaultName, azure.VaultSuffix(vaultType), store, nil)
return azure.New(ctx, vaultName, azure.VaultSuffix(vaultType), store, kekID, nil)
case "azure-hsm":
vaultName, err := getAzureHSMConfig(uri)
vaultName, kekID, err := getAzureHSMConfig(uri)
if err != nil {
return nil, err
}
return azure.NewHSM(ctx, vaultName, store, nil)
return azure.NewHSM(ctx, vaultName, store, kekID, nil)
case "gcp":
project, location, keyRing, protectionLvl, err := getGCPKMSConfig(uri)
project, location, keyRing, protectionLvl, kekID, err := getGCPKMSConfig(uri)
if err != nil {
return nil, err
}
return gcp.New(ctx, project, location, keyRing, store, kmspb.ProtectionLevel(protectionLvl))
return gcp.New(ctx, project, location, keyRing, store, kmspb.ProtectionLevel(protectionLvl), kekID)
case "cluster-kms":
salt, err := getClusterKMSConfig(uri)
masterSecret, err := getClusterKMSConfig(uri)
if err != nil {
return nil, err
}
return cluster.New(salt), nil
return cluster.New(masterSecret.Key, masterSecret.Salt)
default:
return nil, fmt.Errorf("unknown KMS type: %s", uri.Host)
@ -156,22 +171,56 @@ func getAWSS3Config(uri *url.URL) (string, error) {
return r[0], err
}
func getAWSKMSConfig(uri *url.URL) (*defaultPolicyProducer, error) {
r, err := getConfig(uri.Query(), []string{"keyPolicy"})
func getAWSKMSConfig(uri *url.URL) (*defaultPolicyProducer, string, error) {
r, err := getConfig(uri.Query(), []string{"keyPolicy", "kekID"})
if err != nil {
return nil, err
return nil, "", err
}
return &defaultPolicyProducer{policy: r[0]}, err
if len(r) != 2 {
return nil, "", fmt.Errorf("expected 2 KmsURI args, got %d", len(r))
}
kekID, err := base64.URLEncoding.DecodeString(r[1])
if err != nil {
return nil, "", fmt.Errorf("parsing kekID from kmsUri: %w", err)
}
return &defaultPolicyProducer{policy: r[0]}, string(kekID), err
}
func getAzureKMSConfig(uri *url.URL) (string, string, error) {
r, err := getConfig(uri.Query(), []string{"name", "type"})
return r[0], r[1], err
func getAzureKMSConfig(uri *url.URL) (string, string, string, error) {
r, err := getConfig(uri.Query(), []string{"name", "type", "kekID"})
if err != nil {
return "", "", "", fmt.Errorf("getting config: %w", err)
}
if len(r) != 3 {
return "", "", "", fmt.Errorf("expected 3 KmsURI args, got %d", len(r))
}
kekID, err := base64.URLEncoding.DecodeString(r[2])
if err != nil {
return "", "", "", fmt.Errorf("parsing kekID from kmsUri: %w", err)
}
return r[0], r[1], string(kekID), err
}
func getAzureHSMConfig(uri *url.URL) (string, error) {
r, err := getConfig(uri.Query(), []string{"name"})
return r[0], err
func getAzureHSMConfig(uri *url.URL) (string, string, error) {
r, err := getConfig(uri.Query(), []string{"name", "kekID"})
if err != nil {
return "", "", fmt.Errorf("getting config: %w", err)
}
if len(r) != 2 {
return "", "", fmt.Errorf("expected 2 KmsURI args, got %d", len(r))
}
kekID, err := base64.URLEncoding.DecodeString(r[1])
if err != nil {
return "", "", fmt.Errorf("parsing kekID from kmsUri: %w", err)
}
return r[0], string(kekID), err
}
func getAzureBlobConfig(uri *url.URL) (string, string, error) {
@ -182,16 +231,26 @@ func getAzureBlobConfig(uri *url.URL) (string, string, error) {
return r[0], r[1], nil
}
func getGCPKMSConfig(uri *url.URL) (string, string, string, int32, error) {
r, err := getConfig(uri.Query(), []string{"project", "location", "keyRing", "protectionLvl"})
func getGCPKMSConfig(uri *url.URL) (project string, location string, keyRing string, protectionLvl int32, kekID string, err error) {
r, err := getConfig(uri.Query(), []string{"project", "location", "keyRing", "protectionLvl", "kekID"})
if err != nil {
return "", "", "", 0, err
return "", "", "", 0, "", err
}
protectionLvl, err := strconv.ParseInt(r[3], 10, 32)
if len(r) != 5 {
return "", "", "", 0, "", fmt.Errorf("expected 5 KmsURI args, got %d", len(r))
}
kekIDByte, err := base64.URLEncoding.DecodeString(r[4])
if err != nil {
return "", "", "", 0, err
return "", "", "", 0, "", fmt.Errorf("parsing kekID from kmsUri: %w", err)
}
return r[0], r[1], r[2], int32(protectionLvl), nil
protectionLvl32, err := strconv.ParseInt(r[3], 10, 32)
if err != nil {
return "", "", "", 0, "", err
}
return r[0], r[1], r[2], int32(protectionLvl32), string(kekIDByte), nil
}
func getGCPStorageConfig(uri *url.URL) (string, string, error) {
@ -199,12 +258,26 @@ func getGCPStorageConfig(uri *url.URL) (string, string, error) {
return r[0], r[1], err
}
func getClusterKMSConfig(uri *url.URL) ([]byte, error) {
r, err := getConfig(uri.Query(), []string{"salt"})
func getClusterKMSConfig(uri *url.URL) (MasterSecret, error) {
r, err := getConfig(uri.Query(), []string{"key", "salt"})
if err != nil {
return nil, err
return MasterSecret{}, err
}
return base64.URLEncoding.DecodeString(r[0])
if len(r) != 2 {
return MasterSecret{}, fmt.Errorf("expected 2 KmsURI args, got %d", len(r))
}
key, err := base64.URLEncoding.DecodeString(r[0])
if err != nil {
return MasterSecret{}, fmt.Errorf("parsing key from kmsUri: %w", err)
}
salt, err := base64.URLEncoding.DecodeString(r[1])
if err != nil {
return MasterSecret{}, fmt.Errorf("parsing salt from kmsUri: %w", err)
}
return MasterSecret{Key: key, Salt: salt}, nil
}
// getConfig parses url query values, returning a map of the requested values.

View File

@ -18,6 +18,8 @@ import (
"go.uber.org/goleak"
)
const constellationKekID = "Constellation"
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m,
// https://github.com/census-instrumentation/opencensus-go/issues/1262
@ -80,23 +82,23 @@ func TestGetKMS(t *testing.T) {
wantErr bool
}{
"cluster kms": {
uri: fmt.Sprintf("%s?salt=%s", ClusterKMSURI, base64.URLEncoding.EncodeToString([]byte("salt"))),
uri: fmt.Sprintf(ClusterKMSURI, base64.URLEncoding.EncodeToString([]byte("key")), base64.URLEncoding.EncodeToString([]byte("salt"))),
wantErr: false,
},
"aws kms": {
uri: fmt.Sprintf(AWSKMSURI, ""),
uri: fmt.Sprintf(AWSKMSURI, "", ""),
wantErr: true,
},
"azure kms": {
uri: fmt.Sprintf(AzureKMSURI, "", ""),
uri: fmt.Sprintf(AzureKMSURI, "", "", ""),
wantErr: true,
},
"azure hsm": {
uri: fmt.Sprintf(AzureHSMURI, ""),
uri: fmt.Sprintf(AzureHSMURI, "", ""),
wantErr: true,
},
"gcp kms": {
uri: fmt.Sprintf(GCPKMSURI, "", "", "", ""),
uri: fmt.Sprintf(GCPKMSURI, "", "", "", "", ""),
wantErr: true,
},
"unknown kms": {
@ -135,7 +137,8 @@ func TestSetUpKMS(t *testing.T) {
assert.Error(err)
assert.Nil(kms)
kms, err = KMS(context.Background(), "storage://no-store", "kms://cluster-kms?salt="+base64.URLEncoding.EncodeToString([]byte("salt")))
masterSecret := MasterSecret{Key: []byte("key"), Salt: []byte("salt")}
kms, err = KMS(context.Background(), "storage://no-store", masterSecret.EncodeToURI())
assert.NoError(err)
assert.NotNil(kms)
}
@ -146,13 +149,15 @@ func TestGetAWSKMSConfig(t *testing.T) {
policy := "{keyPolicy: keyPolicy}"
escapedPolicy := url.QueryEscape(policy)
uri, err := url.Parse(fmt.Sprintf(AWSKMSURI, escapedPolicy))
kekID := base64.URLEncoding.EncodeToString([]byte(constellationKekID))
uri, err := url.Parse(fmt.Sprintf(AWSKMSURI, escapedPolicy, kekID))
require.NoError(err)
policyProducer, err := getAWSKMSConfig(uri)
policyProducer, rKekID, err := getAWSKMSConfig(uri)
require.NoError(err)
keyPolicy, err := policyProducer.CreateKeyPolicy("")
require.NoError(err)
assert.Equal(policy, keyPolicy)
assert.Equal(constellationKekID, rKekID)
}
func TestGetAzureBlobConfig(t *testing.T) {
@ -178,18 +183,20 @@ func TestGetGCPKMSConfig(t *testing.T) {
location := "global"
keyRing := "test-ring"
protectionLvl := "2"
uri, err := url.Parse(fmt.Sprintf(GCPKMSURI, project, location, keyRing, protectionLvl))
kekID := base64.URLEncoding.EncodeToString([]byte(constellationKekID))
uri, err := url.Parse(fmt.Sprintf(GCPKMSURI, project, location, keyRing, protectionLvl, kekID))
require.NoError(err)
rProject, rLocation, rKeyRing, rProtectionLvl, err := getGCPKMSConfig(uri)
rProject, rLocation, rKeyRing, rProtectionLvl, rKekID, err := getGCPKMSConfig(uri)
require.NoError(err)
assert.Equal(project, rProject)
assert.Equal(location, rLocation)
assert.Equal(keyRing, rKeyRing)
assert.Equal(int32(2), rProtectionLvl)
assert.Equal(constellationKekID, rKekID)
uri, err = url.Parse(fmt.Sprintf(GCPKMSURI, project, location, keyRing, "invalid"))
uri, err = url.Parse(fmt.Sprintf(GCPKMSURI, project, location, keyRing, "invalid", kekID))
require.NoError(err)
_, _, _, _, err = getGCPKMSConfig(uri)
_, _, _, _, _, err = getGCPKMSConfig(uri)
assert.Error(err)
}
@ -202,12 +209,13 @@ func TestGetClusterKMSConfig(t *testing.T) {
0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf,
}
uri, err := url.Parse(ClusterKMSURI + "?salt=" + base64.URLEncoding.EncodeToString(expectedSalt))
masterSecretIn := MasterSecret{Key: []byte("key"), Salt: expectedSalt}
uri, err := url.Parse(masterSecretIn.EncodeToURI())
require.NoError(err)
salt, err := getClusterKMSConfig(uri)
masterSecretOut, err := getClusterKMSConfig(uri)
assert.NoError(err)
assert.Equal(expectedSalt, salt)
assert.Equal(expectedSalt, masterSecretOut.Salt)
}
func TestGetConfig(t *testing.T) {

View File

@ -59,7 +59,7 @@ const (
// JoinImage image of Constellation join service.
JoinImage = "ghcr.io/edgelesssys/constellation/join-service:v2.5.0-pre.0.20230118154955-632090c21b93@sha256:7c53b43f2580ded9f04a9927d4ff585d3edce5d10a1d83006688c818e6395eb1" // renovate:container
// KeyServiceImage image of Constellation KMS server.
KeyServiceImage = "ghcr.io/edgelesssys/constellation/kmsserver:v2.5.0-pre.0.20230112123617-d0e9f427d1ba@sha256:d4319308eb62e2ee079cc86858acdd1faccc404edec7bfabecf35861284a55f3" // renovate:container
KeyServiceImage = "ghcr.io/edgelesssys/constellation/keyservice:v2.5.0-pre.0.20230116125211-d37bd077d8c6@sha256:4c14176f94899054bbf945f6f209521ffcdbcb9042abc5850d778240fe3693a4" // renovate:container
// VerificationImage image of Constellation verification service.
VerificationImage = "ghcr.io/edgelesssys/constellation/verification-service:v2.5.0-pre.0.20230118154955-632090c21b93@sha256:593f735a236f0cb8f4373a7a2dca41be9ab2ba1b784a2ebcf8fb5271705822a3" // renovate:container
// GcpGuestImage image for GCP guest agent.

View File

@ -8,7 +8,6 @@ package main
import (
"context"
"encoding/base64"
"errors"
"flag"
"path/filepath"
@ -53,18 +52,15 @@ func main() {
if len(salt) < crypto.RNGLengthDefault {
log.With(zap.Error(errors.New("invalid salt length"))).Fatalf("Expected salt to be %d bytes, but got %d", crypto.RNGLengthDefault, len(salt))
}
keyURI := setup.ClusterKMSURI + "?salt=" + base64.URLEncoding.EncodeToString(salt)
masterSecret := setup.MasterSecret{Key: masterKey, Salt: salt}
// set up Key Management Service
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()
conKMS, err := setup.KMS(ctx, setup.NoStoreURI, keyURI)
conKMS, err := setup.KMS(ctx, setup.NoStoreURI, masterSecret.EncodeToURI())
if err != nil {
log.With(zap.Error(err)).Fatalf("Failed to setup KMS")
}
if err := conKMS.CreateKEK(ctx, "Constellation", masterKey); err != nil {
log.With(zap.Error(err)).Fatalf("Failed to create KMS KEK from MasterKey")
}
if err := server.New(log.Named("keyservice"), conKMS).Run(*port); err != nil {
log.With(zap.Error(err)).Fatalf("Failed to run keyservice server")

View File

@ -74,7 +74,7 @@ func (s *Server) GetDataKey(ctx context.Context, in *keyserviceproto.GetDataKeyR
return nil, status.Error(codes.InvalidArgument, "no data key ID specified")
}
key, err := s.conKMS.GetDEK(ctx, "Constellation", crypto.HKDFInfoPrefix+in.DataKeyId, int(in.Length))
key, err := s.conKMS.GetDEK(ctx, crypto.DEKPrefix+in.DataKeyId, int(in.Length))
if err != nil {
log.With(zap.Error(err)).Errorf("Failed to get data key")
return nil, status.Errorf(codes.Internal, "%v", err)

View File

@ -63,7 +63,7 @@ func (c *stubKMS) CreateKEK(ctx context.Context, keyID string, kek []byte) error
return nil
}
func (c *stubKMS) GetDEK(ctx context.Context, kekID string, dekID string, dekSize int) ([]byte, error) {
func (c *stubKMS) GetDEK(ctx context.Context, dekID string, dekSize int) ([]byte, error) {
if c.deriveKeyErr != nil {
return nil, c.deriveKeyErr
}

View File

@ -146,7 +146,7 @@ func TestAwsKms(t *testing.T) {
require.NotEqual(newKEKId1, newKEKId2)
var keyPolicyProducer createKeyPolicyFunc
client, err := awsInterface.New(context.Background(), &keyPolicyProducer, nil)
client, err := awsInterface.New(context.Background(), &keyPolicyProducer, nil, newKEKId1)
require.NoError(err)
privateKEK1 := []byte(strings.Repeat("1234", 8))
@ -166,14 +166,14 @@ func TestAwsKms(t *testing.T) {
assert.NoError(client.CreateKEK(ctx, newKEKId1, privateKEK2))
// make sure that GetDEK is idempotent
volumeKey1, err := client.GetDEK(ctx, newKEKId1, "volume01", kmsconfig.SymmetricKeyLength)
volumeKey1, err := client.GetDEK(ctx, "volume01", kmsconfig.SymmetricKeyLength)
require.NoError(err)
volumeKey1Copy, err := client.GetDEK(ctx, newKEKId1, "volume01", kmsconfig.SymmetricKeyLength)
volumeKey1Copy, err := client.GetDEK(ctx, "volume01", kmsconfig.SymmetricKeyLength)
require.NoError(err)
assert.Equal(volumeKey1, volumeKey1Copy)
// test setting a second DEK
volumeKey2, err := client.GetDEK(ctx, newKEKId1, "volume02", kmsconfig.SymmetricKeyLength)
volumeKey2, err := client.GetDEK(ctx, "volume02", kmsconfig.SymmetricKeyLength)
require.NoError(err)
assert.NotEqual(volumeKey1, volumeKey2)
@ -184,7 +184,7 @@ func TestAwsKms(t *testing.T) {
assert.NoError(client.CreateKEK(ctx, newKEKId2, privateKEK3))
// test setting a DEK with AWS KMS generated KEK
volumeKey3, err := client.GetDEK(ctx, newKEKId2, "volume03", kmsconfig.SymmetricKeyLength)
volumeKey3, err := client.GetDEK(ctx, "volume03", kmsconfig.SymmetricKeyLength)
require.NoError(err)
assert.NotEqual(volumeKey1, volumeKey3)

View File

@ -66,22 +66,22 @@ func TestAzureKeyVault(t *testing.T) {
store := storage.NewMemMapStorage()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel()
client, err := azure.New(ctx, azVaultName, azure.DefaultCloud, store, nil)
kekName := "test-kek"
client, err := azure.New(ctx, azVaultName, azure.DefaultCloud, store, kekName, nil)
require.NoError(err)
kekName := "test-kek"
dekName := "test-dek"
assert.NoError(client.CreateKEK(ctx, kekName, nil))
res, err := client.GetDEK(ctx, kekName, dekName, config.SymmetricKeyLength)
res, err := client.GetDEK(ctx, dekName, config.SymmetricKeyLength)
assert.NoError(err)
res2, err := client.GetDEK(ctx, kekName, dekName, config.SymmetricKeyLength)
res2, err := client.GetDEK(ctx, dekName, config.SymmetricKeyLength)
assert.NoError(err)
assert.Equal(res, res2)
res3, err := client.GetDEK(ctx, kekName, addSuffix(dekName), config.SymmetricKeyLength)
res3, err := client.GetDEK(ctx, addSuffix(dekName), config.SymmetricKeyLength)
assert.NoError(err)
assert.Len(res3, config.SymmetricKeyLength)
assert.NotEqual(res, res3)
@ -102,10 +102,10 @@ func TestAzureHSM(t *testing.T) {
store := storage.NewMemMapStorage()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel()
client, err := azure.NewHSM(ctx, azHSMName, store, nil)
kekName := "test-kek"
client, err := azure.NewHSM(ctx, azHSMName, store, kekName, nil)
require.NoError(err)
kekName := "test-kek"
dekName := "test-dek"
importedKek := "test-kek-import"
kekData := []byte{0x52, 0xFD, 0xFC, 0x07, 0x21, 0x82, 0x65, 0x4F, 0x16, 0x3F, 0x5F, 0x0F, 0x9A, 0x62, 0x1D, 0x72, 0x95, 0x66, 0xC7, 0x4D, 0x10, 0x03, 0x7C, 0x4D, 0x7B, 0xBB, 0x04, 0x07, 0xD1, 0xE2, 0xC6, 0x49}
@ -114,15 +114,15 @@ func TestAzureHSM(t *testing.T) {
assert.NoError(client.CreateKEK(ctx, kekName, nil))
res, err := client.GetDEK(ctx, kekName, dekName, config.SymmetricKeyLength)
res, err := client.GetDEK(ctx, dekName, config.SymmetricKeyLength)
require.NoError(err)
assert.NotNil(res)
res2, err := client.GetDEK(ctx, kekName, dekName, config.SymmetricKeyLength)
res2, err := client.GetDEK(ctx, dekName, config.SymmetricKeyLength)
require.NoError(err)
assert.Equal(res, res2)
res3, err := client.GetDEK(ctx, kekName, addSuffix(dekName), config.SymmetricKeyLength)
res3, err := client.GetDEK(ctx, addSuffix(dekName), config.SymmetricKeyLength)
require.NoError(err)
assert.Len(res3, config.SymmetricKeyLength)
assert.NotEqual(res, res3)

View File

@ -42,20 +42,20 @@ func TestCreateGcpKEK(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel()
kmsClient, err := gcp.New(ctx, gcpProjectID, gcpLocation, gcpKeyRing, store, kmspb.ProtectionLevel_SOFTWARE)
kmsClient, err := gcp.New(ctx, gcpProjectID, gcpLocation, gcpKeyRing, store, kmspb.ProtectionLevel_SOFTWARE, kekName)
require.NoError(err)
// Key name is random, but there is a chance we try to create a key that already exists, in that case the test fails
assert.NoError(kmsClient.CreateKEK(ctx, kekName, nil))
res, err := kmsClient.GetDEK(ctx, kekName, dekName, config.SymmetricKeyLength)
res, err := kmsClient.GetDEK(ctx, dekName, config.SymmetricKeyLength)
assert.NoError(err)
res2, err := kmsClient.GetDEK(ctx, kekName, dekName, config.SymmetricKeyLength)
res2, err := kmsClient.GetDEK(ctx, dekName, config.SymmetricKeyLength)
assert.NoError(err)
assert.Equal(res, res2)
res3, err := kmsClient.GetDEK(ctx, kekName, addSuffix(dekName), config.SymmetricKeyLength)
res3, err := kmsClient.GetDEK(ctx, addSuffix(dekName), config.SymmetricKeyLength)
assert.NoError(err)
assert.Len(res3, config.SymmetricKeyLength)
assert.NotEqual(res, res3)
@ -76,15 +76,15 @@ func TestImportGcpKEK(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel()
kmsClient, err := gcp.New(ctx, gcpProjectID, gcpLocation, gcpKeyRing, store, kmspb.ProtectionLevel_SOFTWARE)
kmsClient, err := gcp.New(ctx, gcpProjectID, gcpLocation, gcpKeyRing, store, kmspb.ProtectionLevel_SOFTWARE, kekName)
require.NoError(err)
assert.NoError(kmsClient.CreateKEK(ctx, kekName, kekData))
res, err := kmsClient.GetDEK(ctx, kekName, dekName, config.SymmetricKeyLength)
res, err := kmsClient.GetDEK(ctx, dekName, config.SymmetricKeyLength)
assert.NoError(err)
res2, err := kmsClient.GetDEK(ctx, kekName, dekName, config.SymmetricKeyLength)
res2, err := kmsClient.GetDEK(ctx, dekName, config.SymmetricKeyLength)
assert.NoError(err)
assert.Equal(res, res2)
}