mirror of
https://github.com/edgelesssys/constellation.git
synced 2024-10-01 01:36:09 -04:00
Refactor init/recovery to use kms URI
So far the masterSecret was sent to the initial bootstrapper on init/recovery. With this commit this information is encoded in the kmsURI that is sent during init. For recover, the communication with the recoveryserver is changed. Before a streaming gRPC call was used to exchanges UUID for measurementSecret and state disk key. Now a standard gRPC is made that includes the same kmsURI & storageURI that are sent during init.
This commit is contained in:
parent
0e71322e2e
commit
9a1f52e94e
@ -23,7 +23,7 @@ import (
|
||||
"github.com/edgelesssys/constellation/v2/internal/grpc/atlscredentials"
|
||||
"github.com/edgelesssys/constellation/v2/internal/grpc/grpclog"
|
||||
"github.com/edgelesssys/constellation/v2/internal/kms/kms"
|
||||
kmsSetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
|
||||
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
|
||||
"github.com/edgelesssys/constellation/v2/internal/logger"
|
||||
"github.com/edgelesssys/constellation/v2/internal/nodestate"
|
||||
"github.com/edgelesssys/constellation/v2/internal/role"
|
||||
@ -110,8 +110,13 @@ func (s *Server) Init(ctx context.Context, req *initproto.InitRequest) (*initpro
|
||||
return nil, status.Errorf(codes.Internal, "invalid init secret %s", err)
|
||||
}
|
||||
|
||||
cloudKms, err := kmssetup.KMS(ctx, req.StorageUri, req.KmsUri)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating kms client: %w", err)
|
||||
}
|
||||
|
||||
// generate values for cluster attestation
|
||||
measurementSalt, clusterID, err := deriveMeasurementValues(req.MasterSecret, req.Salt)
|
||||
measurementSalt, clusterID, err := deriveMeasurementValues(ctx, cloudKms)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "deriving measurement values: %s", err)
|
||||
}
|
||||
@ -130,7 +135,7 @@ func (s *Server) Init(ctx context.Context, req *initproto.InitRequest) (*initpro
|
||||
return nil, status.Error(codes.FailedPrecondition, "node is already being activated")
|
||||
}
|
||||
|
||||
if err := s.setupDisk(req.MasterSecret, req.Salt); err != nil {
|
||||
if err := s.setupDisk(ctx, cloudKms); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "setting up disk: %s", err)
|
||||
}
|
||||
|
||||
@ -177,7 +182,7 @@ func (s *Server) Stop() {
|
||||
s.log.Infof("Stopped")
|
||||
}
|
||||
|
||||
func (s *Server) setupDisk(masterSecret, salt []byte) error {
|
||||
func (s *Server) setupDisk(ctx context.Context, cloudKms kms.CloudKMS) error {
|
||||
if err := s.disk.Open(); err != nil {
|
||||
return fmt.Errorf("opening encrypted disk: %w", err)
|
||||
}
|
||||
@ -189,7 +194,7 @@ func (s *Server) setupDisk(masterSecret, salt []byte) error {
|
||||
}
|
||||
uuid = strings.ToLower(uuid)
|
||||
|
||||
diskKey, err := crypto.DeriveKey(masterSecret, salt, []byte(crypto.HKDFInfoPrefix+uuid), crypto.DerivedKeyLengthDefault)
|
||||
diskKey, err := cloudKms.GetDEK(ctx, crypto.DEKPrefix+uuid, crypto.StateDiskKeyLength)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -197,12 +202,12 @@ func (s *Server) setupDisk(masterSecret, salt []byte) error {
|
||||
return s.disk.UpdatePassphrase(string(diskKey))
|
||||
}
|
||||
|
||||
func deriveMeasurementValues(masterSecret, hkdfSalt []byte) (salt, clusterID []byte, err error) {
|
||||
func deriveMeasurementValues(ctx context.Context, cloudKms kms.CloudKMS) (salt, clusterID []byte, err error) {
|
||||
salt, err = crypto.GenerateRandomBytes(crypto.RNGLengthDefault)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
secret, err := attestation.DeriveMeasurementSecret(masterSecret, hkdfSalt)
|
||||
secret, err := cloudKms.GetDEK(ctx, crypto.DEKPrefix+crypto.MeasurementSecretKeyID, crypto.DerivedKeyLengthDefault)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -19,6 +19,7 @@ import (
|
||||
"github.com/edgelesssys/constellation/v2/internal/atls"
|
||||
"github.com/edgelesssys/constellation/v2/internal/crypto/testvector"
|
||||
"github.com/edgelesssys/constellation/v2/internal/file"
|
||||
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
|
||||
"github.com/edgelesssys/constellation/v2/internal/logger"
|
||||
"github.com/edgelesssys/constellation/v2/internal/oid"
|
||||
"github.com/edgelesssys/constellation/v2/internal/versions/components"
|
||||
@ -30,7 +31,10 @@ import (
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m)
|
||||
goleak.VerifyTestMain(m,
|
||||
// https://github.com/census-instrumentation/opencensus-go/issues/1262
|
||||
goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
|
||||
)
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
@ -86,6 +90,8 @@ func TestInit(t *testing.T) {
|
||||
initSecretHash, err := bcrypt.GenerateFromPassword(initSecret, bcrypt.DefaultCost)
|
||||
require.NoError(t, err)
|
||||
|
||||
masterSecret := kmssetup.MasterSecret{Key: []byte("secret"), Salt: []byte("salt")}
|
||||
|
||||
testCases := map[string]struct {
|
||||
nodeLock *fakeLock
|
||||
initializer ClusterInitializer
|
||||
@ -102,14 +108,14 @@ func TestInit(t *testing.T) {
|
||||
disk: &stubDisk{},
|
||||
fileHandler: file.NewHandler(afero.NewMemMapFs()),
|
||||
initSecretHash: initSecretHash,
|
||||
req: &initproto.InitRequest{InitSecret: initSecret},
|
||||
req: &initproto.InitRequest{InitSecret: initSecret, KmsUri: masterSecret.EncodeToURI(), StorageUri: kmssetup.NoStoreURI},
|
||||
},
|
||||
"node locked": {
|
||||
nodeLock: lockedLock,
|
||||
initializer: &stubClusterInitializer{},
|
||||
disk: &stubDisk{},
|
||||
fileHandler: file.NewHandler(afero.NewMemMapFs()),
|
||||
req: &initproto.InitRequest{InitSecret: initSecret},
|
||||
req: &initproto.InitRequest{InitSecret: initSecret, KmsUri: masterSecret.EncodeToURI(), StorageUri: kmssetup.NoStoreURI},
|
||||
initSecretHash: initSecretHash,
|
||||
wantErr: true,
|
||||
wantShutdown: true,
|
||||
@ -119,7 +125,7 @@ func TestInit(t *testing.T) {
|
||||
initializer: &stubClusterInitializer{},
|
||||
disk: &stubDisk{openErr: someErr},
|
||||
fileHandler: file.NewHandler(afero.NewMemMapFs()),
|
||||
req: &initproto.InitRequest{InitSecret: initSecret},
|
||||
req: &initproto.InitRequest{InitSecret: initSecret, KmsUri: masterSecret.EncodeToURI(), StorageUri: kmssetup.NoStoreURI},
|
||||
initSecretHash: initSecretHash,
|
||||
wantErr: true,
|
||||
},
|
||||
@ -128,7 +134,7 @@ func TestInit(t *testing.T) {
|
||||
initializer: &stubClusterInitializer{},
|
||||
disk: &stubDisk{uuidErr: someErr},
|
||||
fileHandler: file.NewHandler(afero.NewMemMapFs()),
|
||||
req: &initproto.InitRequest{InitSecret: initSecret},
|
||||
req: &initproto.InitRequest{InitSecret: initSecret, KmsUri: masterSecret.EncodeToURI(), StorageUri: kmssetup.NoStoreURI},
|
||||
initSecretHash: initSecretHash,
|
||||
wantErr: true,
|
||||
},
|
||||
@ -137,7 +143,7 @@ func TestInit(t *testing.T) {
|
||||
initializer: &stubClusterInitializer{},
|
||||
disk: &stubDisk{updatePassphraseErr: someErr},
|
||||
fileHandler: file.NewHandler(afero.NewMemMapFs()),
|
||||
req: &initproto.InitRequest{InitSecret: initSecret},
|
||||
req: &initproto.InitRequest{InitSecret: initSecret, KmsUri: masterSecret.EncodeToURI(), StorageUri: kmssetup.NoStoreURI},
|
||||
initSecretHash: initSecretHash,
|
||||
wantErr: true,
|
||||
},
|
||||
@ -146,7 +152,7 @@ func TestInit(t *testing.T) {
|
||||
initializer: &stubClusterInitializer{},
|
||||
disk: &stubDisk{},
|
||||
fileHandler: file.NewHandler(afero.NewReadOnlyFs(afero.NewMemMapFs())),
|
||||
req: &initproto.InitRequest{InitSecret: initSecret},
|
||||
req: &initproto.InitRequest{InitSecret: initSecret, KmsUri: masterSecret.EncodeToURI(), StorageUri: kmssetup.NoStoreURI},
|
||||
initSecretHash: initSecretHash,
|
||||
wantErr: true,
|
||||
},
|
||||
@ -155,7 +161,7 @@ func TestInit(t *testing.T) {
|
||||
initializer: &stubClusterInitializer{initClusterErr: someErr},
|
||||
disk: &stubDisk{},
|
||||
fileHandler: file.NewHandler(afero.NewMemMapFs()),
|
||||
req: &initproto.InitRequest{InitSecret: initSecret},
|
||||
req: &initproto.InitRequest{InitSecret: initSecret, KmsUri: masterSecret.EncodeToURI(), StorageUri: kmssetup.NoStoreURI},
|
||||
initSecretHash: initSecretHash,
|
||||
wantErr: true,
|
||||
},
|
||||
@ -211,28 +217,29 @@ func TestInit(t *testing.T) {
|
||||
|
||||
func TestSetupDisk(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
uuid string
|
||||
masterSecret []byte
|
||||
salt []byte
|
||||
wantKey []byte
|
||||
uuid string
|
||||
masterKey []byte
|
||||
salt []byte
|
||||
wantKey []byte
|
||||
}{
|
||||
"lower case uuid": {
|
||||
uuid: strings.ToLower(testvector.HKDF0xFF.Info),
|
||||
masterSecret: testvector.HKDF0xFF.Secret,
|
||||
salt: testvector.HKDF0xFF.Salt,
|
||||
wantKey: testvector.HKDF0xFF.Output,
|
||||
uuid: strings.ToLower(testvector.HKDF0xFF.Info),
|
||||
masterKey: testvector.HKDF0xFF.Secret,
|
||||
salt: testvector.HKDF0xFF.Salt,
|
||||
wantKey: testvector.HKDF0xFF.Output,
|
||||
},
|
||||
"upper case uuid": {
|
||||
uuid: strings.ToUpper(testvector.HKDF0xFF.Info),
|
||||
masterSecret: testvector.HKDF0xFF.Secret,
|
||||
salt: testvector.HKDF0xFF.Salt,
|
||||
wantKey: testvector.HKDF0xFF.Output,
|
||||
uuid: strings.ToUpper(testvector.HKDF0xFF.Info),
|
||||
masterKey: testvector.HKDF0xFF.Secret,
|
||||
salt: testvector.HKDF0xFF.Salt,
|
||||
wantKey: testvector.HKDF0xFF.Output,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
disk := &fakeDisk{
|
||||
uuid: tc.uuid,
|
||||
@ -242,7 +249,11 @@ func TestSetupDisk(t *testing.T) {
|
||||
disk: disk,
|
||||
}
|
||||
|
||||
assert.NoError(server.setupDisk(tc.masterSecret, tc.salt))
|
||||
masterSecret := kmssetup.MasterSecret{Key: tc.masterKey, Salt: tc.salt}
|
||||
|
||||
cloudKms, err := kmssetup.KMS(context.Background(), kmssetup.NoStoreURI, masterSecret.EncodeToURI())
|
||||
require.NoError(err)
|
||||
assert.NoError(server.setupDisk(context.Background(), cloudKms))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -30,7 +30,7 @@ import (
|
||||
"github.com/edgelesssys/constellation/v2/internal/file"
|
||||
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
|
||||
grpcRetry "github.com/edgelesssys/constellation/v2/internal/grpc/retry"
|
||||
keyservice "github.com/edgelesssys/constellation/v2/internal/kms/setup"
|
||||
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
|
||||
"github.com/edgelesssys/constellation/v2/internal/license"
|
||||
"github.com/edgelesssys/constellation/v2/internal/retry"
|
||||
"github.com/edgelesssys/constellation/v2/internal/versions"
|
||||
@ -146,8 +146,8 @@ func (i *initCmd) initialize(cmd *cobra.Command, newDialer func(validator *cloud
|
||||
req := &initproto.InitRequest{
|
||||
MasterSecret: masterSecret.Key,
|
||||
Salt: masterSecret.Salt,
|
||||
KmsUri: keyservice.ClusterKMSURI,
|
||||
StorageUri: keyservice.NoStoreURI,
|
||||
KmsUri: masterSecret.EncodeToURI(),
|
||||
StorageUri: kmssetup.NoStoreURI,
|
||||
KeyEncryptionKeyId: "",
|
||||
UseExistingKek: false,
|
||||
CloudServiceAccountUri: serviceAccURI,
|
||||
@ -296,26 +296,20 @@ type initFlags struct {
|
||||
conformance bool
|
||||
}
|
||||
|
||||
// masterSecret holds the master key and salt for deriving keys.
|
||||
type masterSecret struct {
|
||||
Key []byte `json:"key"`
|
||||
Salt []byte `json:"salt"`
|
||||
}
|
||||
|
||||
// readOrGenerateMasterSecret reads a base64 encoded master secret from file or generates a new 32 byte secret.
|
||||
func (i *initCmd) readOrGenerateMasterSecret(outWriter io.Writer, fileHandler file.Handler, filename string) (masterSecret, error) {
|
||||
func (i *initCmd) readOrGenerateMasterSecret(outWriter io.Writer, fileHandler file.Handler, filename string) (kmssetup.MasterSecret, error) {
|
||||
if filename != "" {
|
||||
i.log.Debugf("Reading master secret from file %q", filename)
|
||||
var secret masterSecret
|
||||
var secret kmssetup.MasterSecret
|
||||
if err := fileHandler.ReadJSON(filename, &secret); err != nil {
|
||||
return masterSecret{}, err
|
||||
return kmssetup.MasterSecret{}, err
|
||||
}
|
||||
|
||||
if len(secret.Key) < crypto.MasterSecretLengthMin {
|
||||
return masterSecret{}, fmt.Errorf("provided master secret is smaller than the required minimum of %d Bytes", crypto.MasterSecretLengthMin)
|
||||
return kmssetup.MasterSecret{}, fmt.Errorf("provided master secret is smaller than the required minimum of %d Bytes", crypto.MasterSecretLengthMin)
|
||||
}
|
||||
if len(secret.Salt) < crypto.RNGLengthDefault {
|
||||
return masterSecret{}, fmt.Errorf("provided salt is smaller than the required minimum of %d Bytes", crypto.RNGLengthDefault)
|
||||
return kmssetup.MasterSecret{}, fmt.Errorf("provided salt is smaller than the required minimum of %d Bytes", crypto.RNGLengthDefault)
|
||||
}
|
||||
return secret, nil
|
||||
}
|
||||
@ -324,19 +318,19 @@ func (i *initCmd) readOrGenerateMasterSecret(outWriter io.Writer, fileHandler fi
|
||||
i.log.Debugf("Generating new master secret")
|
||||
key, err := crypto.GenerateRandomBytes(crypto.MasterSecretLengthDefault)
|
||||
if err != nil {
|
||||
return masterSecret{}, err
|
||||
return kmssetup.MasterSecret{}, err
|
||||
}
|
||||
salt, err := crypto.GenerateRandomBytes(crypto.RNGLengthDefault)
|
||||
if err != nil {
|
||||
return masterSecret{}, err
|
||||
return kmssetup.MasterSecret{}, err
|
||||
}
|
||||
secret := masterSecret{
|
||||
secret := kmssetup.MasterSecret{
|
||||
Key: key,
|
||||
Salt: salt,
|
||||
}
|
||||
i.log.Debugf("Generated master secret key and salt values")
|
||||
if err := fileHandler.WriteJSON(constants.MasterSecretFilename, secret, file.OptNone); err != nil {
|
||||
return masterSecret{}, err
|
||||
return kmssetup.MasterSecret{}, err
|
||||
}
|
||||
fmt.Fprintf(outWriter, "Your Constellation master secret was successfully written to ./%s\n", constants.MasterSecretFilename)
|
||||
return secret, nil
|
||||
|
@ -18,6 +18,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/bootstrapper/initproto"
|
||||
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
|
||||
"github.com/edgelesssys/constellation/v2/cli/internal/clusterid"
|
||||
@ -183,7 +185,7 @@ func TestInitialize(t *testing.T) {
|
||||
require.NoError(err)
|
||||
// assert.Contains(out.String(), base64.StdEncoding.EncodeToString([]byte("ownerID")))
|
||||
assert.Contains(out.String(), hex.EncodeToString([]byte("clusterID")))
|
||||
var secret masterSecret
|
||||
var secret kmssetup.MasterSecret
|
||||
assert.NoError(fileHandler.ReadJSON(constants.MasterSecretFilename, &secret))
|
||||
assert.NotEmpty(secret.Key)
|
||||
assert.NotEmpty(secret.Salt)
|
||||
@ -251,7 +253,7 @@ func TestReadOrGenerateMasterSecret(t *testing.T) {
|
||||
createFileFunc: func(handler file.Handler) error {
|
||||
return handler.WriteJSON(
|
||||
"someSecret",
|
||||
masterSecret{Key: []byte("constellation-master-secret"), Salt: []byte("constellation-32Byte-length-salt")},
|
||||
kmssetup.MasterSecret{Key: []byte("constellation-master-secret"), Salt: []byte("constellation-32Byte-length-salt")},
|
||||
file.OptNone,
|
||||
)
|
||||
},
|
||||
@ -282,7 +284,7 @@ func TestReadOrGenerateMasterSecret(t *testing.T) {
|
||||
createFileFunc: func(handler file.Handler) error {
|
||||
return handler.WriteJSON(
|
||||
"shortSecret",
|
||||
masterSecret{Key: []byte("constellation-master-secret"), Salt: []byte("short")},
|
||||
kmssetup.MasterSecret{Key: []byte("constellation-master-secret"), Salt: []byte("short")},
|
||||
file.OptNone,
|
||||
)
|
||||
},
|
||||
@ -294,7 +296,7 @@ func TestReadOrGenerateMasterSecret(t *testing.T) {
|
||||
createFileFunc: func(handler file.Handler) error {
|
||||
return handler.WriteJSON(
|
||||
"shortSecret",
|
||||
masterSecret{Key: []byte("short"), Salt: []byte("constellation-32Byte-length-salt")},
|
||||
kmssetup.MasterSecret{Key: []byte("short"), Salt: []byte("constellation-32Byte-length-salt")},
|
||||
file.OptNone,
|
||||
)
|
||||
},
|
||||
@ -340,7 +342,7 @@ func TestReadOrGenerateMasterSecret(t *testing.T) {
|
||||
tc.filename = strings.Trim(filename[1], "\n")
|
||||
}
|
||||
|
||||
var masterSecret masterSecret
|
||||
var masterSecret kmssetup.MasterSecret
|
||||
require.NoError(fileHandler.ReadJSON(tc.filename, &masterSecret))
|
||||
assert.Equal(masterSecret.Key, secret.Key)
|
||||
assert.Equal(masterSecret.Salt, secret.Salt)
|
||||
|
@ -8,7 +8,6 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@ -17,7 +16,6 @@ import (
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
|
||||
"github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation"
|
||||
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/v2/internal/config"
|
||||
"github.com/edgelesssys/constellation/v2/internal/constants"
|
||||
@ -25,6 +23,7 @@ import (
|
||||
"github.com/edgelesssys/constellation/v2/internal/file"
|
||||
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
|
||||
grpcRetry "github.com/edgelesssys/constellation/v2/internal/grpc/retry"
|
||||
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
|
||||
"github.com/edgelesssys/constellation/v2/internal/retry"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
@ -73,7 +72,7 @@ func (r *recoverCmd) recover(
|
||||
}
|
||||
r.log.Debugf("Using flags: %+v", flags)
|
||||
|
||||
var masterSecret masterSecret
|
||||
var masterSecret kmssetup.MasterSecret
|
||||
r.log.Debugf("Loading master secret file from %s", flags.secretPath)
|
||||
if err := fileHandler.ReadJSON(flags.secretPath, &masterSecret); err != nil {
|
||||
return err
|
||||
@ -97,12 +96,7 @@ func (r *recoverCmd) recover(
|
||||
r.log.Debugf("Created a new validator")
|
||||
doer.setDialer(newDialer(validator), flags.endpoint)
|
||||
r.log.Debugf("Set dialer for endpoint %s", flags.endpoint)
|
||||
measurementSecret, err := attestation.DeriveMeasurementSecret(masterSecret.Key, masterSecret.Salt)
|
||||
r.log.Debugf("Derived measurementSecret")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
doer.setSecrets(getStateDiskKeyFunc(masterSecret.Key, masterSecret.Salt), measurementSecret)
|
||||
doer.setURIs(masterSecret.EncodeToURI(), kmssetup.NoStoreURI)
|
||||
r.log.Debugf("Set secrets")
|
||||
if err := r.recoverCall(cmd.Context(), cmd.OutOrStdout(), interval, doer); err != nil {
|
||||
if grpcRetry.ServiceIsUnavailable(err) {
|
||||
@ -157,15 +151,15 @@ func (r *recoverCmd) recoverCall(ctx context.Context, out io.Writer, interval ti
|
||||
type recoverDoerInterface interface {
|
||||
Do(ctx context.Context) error
|
||||
setDialer(dialer grpcDialer, endpoint string)
|
||||
setSecrets(getDiskKey func(uuid string) ([]byte, error), measurementSecret []byte)
|
||||
setURIs(kmsURI, storageURI string)
|
||||
}
|
||||
|
||||
type recoverDoer struct {
|
||||
dialer grpcDialer
|
||||
endpoint string
|
||||
measurementSecret []byte
|
||||
getDiskKey func(uuid string) (key []byte, err error)
|
||||
log debugLog
|
||||
dialer grpcDialer
|
||||
endpoint string
|
||||
kmsURI string // encodes masterSecret
|
||||
storageURI string
|
||||
log debugLog
|
||||
}
|
||||
|
||||
// Do performs the recover streaming rpc.
|
||||
@ -177,53 +171,19 @@ func (d *recoverDoer) Do(ctx context.Context) (retErr error) {
|
||||
d.log.Debugf("Dialed recovery server")
|
||||
defer conn.Close()
|
||||
|
||||
// set up streaming client
|
||||
protoClient := recoverproto.NewAPIClient(conn)
|
||||
d.log.Debugf("Created protoClient")
|
||||
recoverclient, err := protoClient.Recover(ctx)
|
||||
d.log.Debugf("Created recoverclient")
|
||||
|
||||
req := &recoverproto.RecoverMessage{
|
||||
KmsUri: d.kmsURI,
|
||||
StorageUri: d.storageURI,
|
||||
}
|
||||
|
||||
_, err = protoClient.Recover(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating client: %w", err)
|
||||
return fmt.Errorf("calling recover: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = recoverclient.CloseSend()
|
||||
}()
|
||||
|
||||
// send measurement secret as first message
|
||||
if err := recoverclient.Send(&recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_MeasurementSecret{
|
||||
MeasurementSecret: d.measurementSecret,
|
||||
},
|
||||
}); err != nil {
|
||||
return fmt.Errorf("sending measurement secret: %w", err)
|
||||
}
|
||||
d.log.Debugf("Sent measurement secret")
|
||||
|
||||
// receive disk uuid
|
||||
res, err := recoverclient.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("receiving disk uuid: %w", err)
|
||||
}
|
||||
d.log.Debugf("Received disk uuid")
|
||||
stateDiskKey, err := d.getDiskKey(res.DiskUuid)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting state disk key: %w", err)
|
||||
}
|
||||
d.log.Debugf("Got state disk key")
|
||||
|
||||
// send disk key
|
||||
if err := recoverclient.Send(&recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_StateDiskKey{
|
||||
StateDiskKey: stateDiskKey,
|
||||
},
|
||||
}); err != nil {
|
||||
return fmt.Errorf("sending state disk key: %w", err)
|
||||
}
|
||||
d.log.Debugf("Sent state disk key")
|
||||
|
||||
if _, err := recoverclient.Recv(); err != nil && !errors.Is(err, io.EOF) {
|
||||
return fmt.Errorf("receiving confirmation: %w", err)
|
||||
}
|
||||
d.log.Debugf("Received confirmation")
|
||||
return nil
|
||||
}
|
||||
@ -233,9 +193,9 @@ func (d *recoverDoer) setDialer(dialer grpcDialer, endpoint string) {
|
||||
d.endpoint = endpoint
|
||||
}
|
||||
|
||||
func (d *recoverDoer) setSecrets(getDiskKey func(string) ([]byte, error), measurementSecret []byte) {
|
||||
d.getDiskKey = getDiskKey
|
||||
d.measurementSecret = measurementSecret
|
||||
func (d *recoverDoer) setURIs(kmsURI, storageURI string) {
|
||||
d.kmsURI = kmsURI
|
||||
d.storageURI = storageURI
|
||||
}
|
||||
|
||||
type recoverFlags struct {
|
||||
@ -281,6 +241,6 @@ func (r *recoverCmd) parseRecoverFlags(cmd *cobra.Command, fileHandler file.Hand
|
||||
|
||||
func getStateDiskKeyFunc(masterKey, salt []byte) func(uuid string) ([]byte, error) {
|
||||
return func(uuid string) ([]byte, error) {
|
||||
return crypto.DeriveKey(masterKey, salt, []byte(crypto.HKDFInfoPrefix+uuid), crypto.StateDiskKeyLength)
|
||||
return crypto.DeriveKey(masterKey, salt, []byte(crypto.DEKPrefix+uuid), crypto.StateDiskKeyLength)
|
||||
}
|
||||
}
|
||||
|
@ -26,6 +26,7 @@ import (
|
||||
"github.com/edgelesssys/constellation/v2/internal/grpc/atlscredentials"
|
||||
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
|
||||
"github.com/edgelesssys/constellation/v2/internal/grpc/testdialer"
|
||||
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
|
||||
"github.com/edgelesssys/constellation/v2/internal/logger"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -156,7 +157,7 @@ func TestRecover(t *testing.T) {
|
||||
|
||||
require.NoError(fileHandler.WriteJSON(
|
||||
"constellation-mastersecret.json",
|
||||
masterSecret{Key: tc.masterSecret.Secret, Salt: tc.masterSecret.Salt},
|
||||
kmssetup.MasterSecret{Key: tc.masterSecret.Secret, Salt: tc.masterSecret.Salt},
|
||||
file.OptNone,
|
||||
))
|
||||
|
||||
@ -244,101 +245,16 @@ func TestParseRecoverFlags(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDoRecovery(t *testing.T) {
|
||||
someErr := errors.New("error")
|
||||
testCases := map[string]struct {
|
||||
recoveryServer *stubRecoveryServer
|
||||
wantErr bool
|
||||
}{
|
||||
"success": {
|
||||
recoveryServer: &stubRecoveryServer{
|
||||
actions: [][]func(stream recoverproto.API_RecoverServer) error{{
|
||||
func(stream recoverproto.API_RecoverServer) error {
|
||||
_, err := stream.Recv()
|
||||
return err
|
||||
},
|
||||
func(stream recoverproto.API_RecoverServer) error {
|
||||
return stream.Send(&recoverproto.RecoverResponse{
|
||||
DiskUuid: "00000000-0000-0000-0000-000000000000",
|
||||
})
|
||||
},
|
||||
func(stream recoverproto.API_RecoverServer) error {
|
||||
_, err := stream.Recv()
|
||||
return err
|
||||
},
|
||||
}},
|
||||
},
|
||||
recoveryServer: &stubRecoveryServer{},
|
||||
},
|
||||
"error on first recv": {
|
||||
recoveryServer: &stubRecoveryServer{
|
||||
actions: [][]func(stream recoverproto.API_RecoverServer) error{
|
||||
{
|
||||
func(stream recoverproto.API_RecoverServer) error {
|
||||
return someErr
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"error on send": {
|
||||
recoveryServer: &stubRecoveryServer{
|
||||
actions: [][]func(stream recoverproto.API_RecoverServer) error{
|
||||
{
|
||||
func(stream recoverproto.API_RecoverServer) error {
|
||||
_, err := stream.Recv()
|
||||
return err
|
||||
},
|
||||
func(stream recoverproto.API_RecoverServer) error {
|
||||
return someErr
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"error on second recv": {
|
||||
recoveryServer: &stubRecoveryServer{
|
||||
actions: [][]func(stream recoverproto.API_RecoverServer) error{
|
||||
{
|
||||
func(stream recoverproto.API_RecoverServer) error {
|
||||
_, err := stream.Recv()
|
||||
return err
|
||||
},
|
||||
func(stream recoverproto.API_RecoverServer) error {
|
||||
return stream.Send(&recoverproto.RecoverResponse{
|
||||
DiskUuid: "00000000-0000-0000-0000-000000000000",
|
||||
})
|
||||
},
|
||||
func(stream recoverproto.API_RecoverServer) error {
|
||||
return someErr
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"final message is an error": {
|
||||
recoveryServer: &stubRecoveryServer{
|
||||
actions: [][]func(stream recoverproto.API_RecoverServer) error{{
|
||||
func(stream recoverproto.API_RecoverServer) error {
|
||||
_, err := stream.Recv()
|
||||
return err
|
||||
},
|
||||
func(stream recoverproto.API_RecoverServer) error {
|
||||
return stream.Send(&recoverproto.RecoverResponse{
|
||||
DiskUuid: "00000000-0000-0000-0000-000000000000",
|
||||
})
|
||||
},
|
||||
func(stream recoverproto.API_RecoverServer) error {
|
||||
_, err := stream.Recv()
|
||||
return err
|
||||
},
|
||||
func(stream recoverproto.API_RecoverServer) error {
|
||||
return someErr
|
||||
},
|
||||
}},
|
||||
},
|
||||
wantErr: true,
|
||||
"server responds with error": {
|
||||
recoveryServer: &stubRecoveryServer{recoverError: errors.New("someErr")},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
@ -357,13 +273,9 @@ func TestDoRecovery(t *testing.T) {
|
||||
|
||||
r := &recoverCmd{log: logger.NewTest(t)}
|
||||
recoverDoer := &recoverDoer{
|
||||
dialer: dialer.New(nil, nil, netDialer),
|
||||
endpoint: addr,
|
||||
measurementSecret: []byte("measurement-secret"),
|
||||
getDiskKey: func(string) ([]byte, error) {
|
||||
return []byte("disk-key"), nil
|
||||
},
|
||||
log: r.log,
|
||||
dialer: dialer.New(nil, nil, netDialer),
|
||||
endpoint: addr,
|
||||
log: r.log,
|
||||
}
|
||||
|
||||
err := recoverDoer.Do(context.Background())
|
||||
@ -402,23 +314,15 @@ func TestDeriveStateDiskKey(t *testing.T) {
|
||||
}
|
||||
|
||||
type stubRecoveryServer struct {
|
||||
actions [][]func(recoverproto.API_RecoverServer) error
|
||||
calls int
|
||||
recoverError error
|
||||
recoverproto.UnimplementedAPIServer
|
||||
}
|
||||
|
||||
func (s *stubRecoveryServer) Recover(stream recoverproto.API_RecoverServer) error {
|
||||
if s.calls >= len(s.actions) {
|
||||
return status.Error(codes.Unavailable, "server is unavailable")
|
||||
func (s *stubRecoveryServer) Recover(context.Context, *recoverproto.RecoverMessage) (*recoverproto.RecoverResponse, error) {
|
||||
if s.recoverError != nil {
|
||||
return nil, s.recoverError
|
||||
}
|
||||
s.calls++
|
||||
|
||||
for _, action := range s.actions[s.calls-1] {
|
||||
if err := action(stream); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return &recoverproto.RecoverResponse{}, nil
|
||||
}
|
||||
|
||||
type stubDoer struct {
|
||||
@ -437,4 +341,4 @@ func (d *stubDoer) Do(context.Context) error {
|
||||
|
||||
func (d *stubDoer) setDialer(grpcDialer, string) {}
|
||||
|
||||
func (d *stubDoer) setSecrets(func(string) ([]byte, error), []byte) {}
|
||||
func (d *stubDoer) setURIs(kmsURI, storageURI string) {}
|
||||
|
@ -34,6 +34,7 @@ import (
|
||||
qemucloud "github.com/edgelesssys/constellation/v2/internal/cloud/qemu"
|
||||
"github.com/edgelesssys/constellation/v2/internal/constants"
|
||||
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
|
||||
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
|
||||
"github.com/edgelesssys/constellation/v2/internal/logger"
|
||||
"github.com/edgelesssys/constellation/v2/internal/role"
|
||||
tpmClient "github.com/google/go-tpm-tools/client"
|
||||
@ -151,7 +152,7 @@ func main() {
|
||||
// set up recovery server if control-plane node
|
||||
var recoveryServer setup.RecoveryServer
|
||||
if self.Role == role.ControlPlane {
|
||||
recoveryServer = recoveryserver.New(issuer, log.Named("recoveryServer"))
|
||||
recoveryServer = recoveryserver.New(issuer, kmssetup.KMS, log.Named("recoveryServer"))
|
||||
} else {
|
||||
recoveryServer = recoveryserver.NewStub(log.Named("recoveryServer"))
|
||||
}
|
||||
|
@ -13,8 +13,10 @@ import (
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto"
|
||||
"github.com/edgelesssys/constellation/v2/internal/atls"
|
||||
"github.com/edgelesssys/constellation/v2/internal/crypto"
|
||||
"github.com/edgelesssys/constellation/v2/internal/grpc/atlscredentials"
|
||||
"github.com/edgelesssys/constellation/v2/internal/grpc/grpclog"
|
||||
"github.com/edgelesssys/constellation/v2/internal/kms/kms"
|
||||
"github.com/edgelesssys/constellation/v2/internal/logger"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
@ -22,6 +24,8 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
type kmsFactory func(ctx context.Context, storageURI string, kmsURI string) (kms.CloudKMS, error)
|
||||
|
||||
// RecoveryServer is a gRPC server that can be used by an admin to recover a restarting node.
|
||||
type RecoveryServer struct {
|
||||
mux sync.Mutex
|
||||
@ -30,6 +34,7 @@ type RecoveryServer struct {
|
||||
stateDiskKey []byte
|
||||
measurementSecret []byte
|
||||
grpcServer server
|
||||
factory kmsFactory
|
||||
|
||||
log *logger.Logger
|
||||
|
||||
@ -37,9 +42,10 @@ type RecoveryServer struct {
|
||||
}
|
||||
|
||||
// New returns a new RecoveryServer.
|
||||
func New(issuer atls.Issuer, log *logger.Logger) *RecoveryServer {
|
||||
func New(issuer atls.Issuer, factory kmsFactory, log *logger.Logger) *RecoveryServer {
|
||||
server := &RecoveryServer{
|
||||
log: log,
|
||||
log: log,
|
||||
factory: factory,
|
||||
}
|
||||
|
||||
grpcServer := grpc.NewServer(
|
||||
@ -87,47 +93,32 @@ func (s *RecoveryServer) Serve(ctx context.Context, listener net.Listener, diskU
|
||||
}
|
||||
|
||||
// Recover is a bidirectional streaming RPC that is used to send recovery keys to a restarting node.
|
||||
func (s *RecoveryServer) Recover(stream recoverproto.API_RecoverServer) error {
|
||||
func (s *RecoveryServer) Recover(ctx context.Context, req *recoverproto.RecoverMessage) (*recoverproto.RecoverResponse, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
log := s.log.With(zap.String("peer", grpclog.PeerAddrFromContext(stream.Context())))
|
||||
log := s.log.With(zap.String("peer", grpclog.PeerAddrFromContext(ctx)))
|
||||
|
||||
log.Infof("Received recover call")
|
||||
|
||||
msg, err := stream.Recv()
|
||||
cloudKms, err := s.factory(ctx, req.StorageUri, req.KmsUri)
|
||||
if err != nil {
|
||||
return status.Error(codes.Internal, "failed to receive message")
|
||||
return nil, status.Errorf(codes.Internal, "creating kms client: %s", err)
|
||||
}
|
||||
|
||||
measurementSecret, ok := msg.GetRequest().(*recoverproto.RecoverMessage_MeasurementSecret)
|
||||
if !ok {
|
||||
log.Errorf("Received invalid first message: not a measurement secret")
|
||||
return status.Error(codes.InvalidArgument, "first message is not a measurement secret")
|
||||
}
|
||||
|
||||
if err := stream.Send(&recoverproto.RecoverResponse{DiskUuid: s.diskUUID}); err != nil {
|
||||
log.With(zap.Error(err)).Errorf("Failed to send disk UUID")
|
||||
return status.Error(codes.Internal, "failed to send response")
|
||||
}
|
||||
|
||||
msg, err = stream.Recv()
|
||||
measurementSecret, err := cloudKms.GetDEK(ctx, crypto.DEKPrefix+crypto.MeasurementSecretKeyID, crypto.DerivedKeyLengthDefault)
|
||||
if err != nil {
|
||||
log.With(zap.Error(err)).Errorf("Failed to receive disk key")
|
||||
return status.Error(codes.Internal, "failed to receive message")
|
||||
return nil, status.Errorf(codes.Internal, "requesting measurementSecret: %s", err)
|
||||
}
|
||||
|
||||
stateDiskKey, ok := msg.GetRequest().(*recoverproto.RecoverMessage_StateDiskKey)
|
||||
if !ok {
|
||||
log.Errorf("Received invalid second message: not a state disk key")
|
||||
return status.Error(codes.InvalidArgument, "second message is not a state disk key")
|
||||
stateDiskKey, err := cloudKms.GetDEK(ctx, crypto.DEKPrefix+s.diskUUID, crypto.StateDiskKeyLength)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "requesting stateDiskKey: %s", err)
|
||||
}
|
||||
|
||||
s.stateDiskKey = stateDiskKey.StateDiskKey
|
||||
s.measurementSecret = measurementSecret.MeasurementSecret
|
||||
s.stateDiskKey = stateDiskKey
|
||||
s.measurementSecret = measurementSecret
|
||||
log.Infof("Received state disk key and measurement secret, shutting down server")
|
||||
|
||||
go s.grpcServer.GracefulStop()
|
||||
return nil
|
||||
return &recoverproto.RecoverResponse{}, nil
|
||||
}
|
||||
|
||||
// StubServer implements the RecoveryServer interface but does not actually start a server.
|
||||
|
@ -8,7 +8,7 @@ package recoveryserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@ -17,6 +17,7 @@ import (
|
||||
"github.com/edgelesssys/constellation/v2/internal/atls"
|
||||
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
|
||||
"github.com/edgelesssys/constellation/v2/internal/grpc/testdialer"
|
||||
"github.com/edgelesssys/constellation/v2/internal/kms/kms"
|
||||
"github.com/edgelesssys/constellation/v2/internal/logger"
|
||||
"github.com/edgelesssys/constellation/v2/internal/oid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -25,14 +26,17 @@ import (
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m)
|
||||
goleak.VerifyTestMain(m,
|
||||
// https://github.com/census-instrumentation/opencensus-go/issues/1262
|
||||
goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
|
||||
)
|
||||
}
|
||||
|
||||
func TestServe(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
log := logger.NewTest(t)
|
||||
uuid := "uuid"
|
||||
server := New(atls.NewFakeIssuer(oid.Dummy{}), log)
|
||||
server := New(atls.NewFakeIssuer(oid.Dummy{}), newStubKMS(nil, nil), log)
|
||||
dialer := testdialer.NewBufconnDialer()
|
||||
listener := dialer.GetListener("192.0.2.1:1234")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@ -49,7 +53,7 @@ func TestServe(t *testing.T) {
|
||||
cancel()
|
||||
wg.Wait()
|
||||
|
||||
server = New(atls.NewFakeIssuer(oid.Dummy{}), log)
|
||||
server = New(atls.NewFakeIssuer(oid.Dummy{}), newStubKMS(nil, nil), log)
|
||||
dialer = testdialer.NewBufconnDialer()
|
||||
listener = dialer.GetListener("192.0.2.1:1234")
|
||||
|
||||
@ -71,59 +75,26 @@ func TestServe(t *testing.T) {
|
||||
|
||||
func TestRecover(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
initialMsg message
|
||||
keyMsg message
|
||||
kmsURI string
|
||||
storageURI string
|
||||
factory kmsFactory
|
||||
wantErr bool
|
||||
}{
|
||||
"success": {
|
||||
initialMsg: message{
|
||||
recoverMsg: &recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_MeasurementSecret{
|
||||
MeasurementSecret: []byte("measurementSecret"),
|
||||
},
|
||||
},
|
||||
},
|
||||
keyMsg: message{
|
||||
recoverMsg: &recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_StateDiskKey{
|
||||
StateDiskKey: []byte("diskKey"),
|
||||
},
|
||||
},
|
||||
},
|
||||
// base64 encoded: key=masterkey&salt=somesalt
|
||||
kmsURI: "kms://cluster-kms?key=bWFzdGVya2V5&salt=c29tZXNhbHQ=",
|
||||
storageURI: "storage://no-store",
|
||||
factory: newStubKMS(nil, nil),
|
||||
},
|
||||
"first message is not a measurement secret": {
|
||||
initialMsg: message{
|
||||
recoverMsg: &recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_StateDiskKey{
|
||||
StateDiskKey: []byte("diskKey"),
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
keyMsg: message{
|
||||
recoverMsg: &recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_StateDiskKey{
|
||||
StateDiskKey: []byte("diskKey"),
|
||||
},
|
||||
},
|
||||
},
|
||||
"kms init fails": {
|
||||
factory: newStubKMS(errors.New("setup failed"), nil),
|
||||
wantErr: true,
|
||||
},
|
||||
"second message is not a state disk key": {
|
||||
initialMsg: message{
|
||||
recoverMsg: &recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_MeasurementSecret{
|
||||
MeasurementSecret: []byte("measurementSecret"),
|
||||
},
|
||||
},
|
||||
},
|
||||
keyMsg: message{
|
||||
recoverMsg: &recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_MeasurementSecret{
|
||||
MeasurementSecret: []byte("measurementSecret"),
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"GetDEK fails": {
|
||||
kmsURI: "kms://cluster-kms?key=bWFzdGVya2V5&salt=c29tZXNhbHQ=",
|
||||
storageURI: "storage://no-store",
|
||||
factory: newStubKMS(nil, errors.New("GetDEK failed")),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
@ -134,7 +105,7 @@ func TestRecover(t *testing.T) {
|
||||
|
||||
ctx := context.Background()
|
||||
serverUUID := "uuid"
|
||||
server := New(atls.NewFakeIssuer(oid.Dummy{}), logger.NewTest(t))
|
||||
server := New(atls.NewFakeIssuer(oid.Dummy{}), tc.factory, logger.NewTest(t))
|
||||
netDialer := testdialer.NewBufconnDialer()
|
||||
listener := netDialer.GetListener("192.0.2.1:1234")
|
||||
|
||||
@ -154,41 +125,46 @@ func TestRecover(t *testing.T) {
|
||||
conn, err := dialer.New(nil, nil, netDialer).Dial(ctx, "192.0.2.1:1234")
|
||||
require.NoError(err)
|
||||
defer conn.Close()
|
||||
client, err := recoverproto.NewAPIClient(conn).Recover(ctx)
|
||||
require.NoError(err)
|
||||
|
||||
// Send initial message
|
||||
err = client.Send(tc.initialMsg.recoverMsg)
|
||||
require.NoError(err)
|
||||
req := recoverproto.RecoverMessage{
|
||||
KmsUri: tc.kmsURI,
|
||||
StorageUri: tc.storageURI,
|
||||
}
|
||||
_, err = recoverproto.NewAPIClient(conn).Recover(ctx, &req)
|
||||
|
||||
// Receive uuid
|
||||
uuid, err := client.Recv()
|
||||
if tc.initialMsg.wantErr {
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
assert.Equal(serverUUID, uuid.DiskUuid)
|
||||
|
||||
// Send key message
|
||||
err = client.Send(tc.keyMsg.recoverMsg)
|
||||
require.NoError(err)
|
||||
|
||||
_, err = client.Recv()
|
||||
if tc.keyMsg.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
assert.ErrorIs(io.EOF, err)
|
||||
|
||||
wg.Wait()
|
||||
assert.NoError(serveErr)
|
||||
assert.Equal(tc.initialMsg.recoverMsg.GetMeasurementSecret(), measurementSecret)
|
||||
assert.Equal(tc.keyMsg.recoverMsg.GetStateDiskKey(), diskKey)
|
||||
require.NoError(serveErr)
|
||||
assert.NoError(err)
|
||||
assert.NotNil(measurementSecret)
|
||||
assert.NotNil(diskKey)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type message struct {
|
||||
recoverMsg *recoverproto.RecoverMessage
|
||||
wantErr bool
|
||||
func newStubKMS(setupErr, getDEKErr error) kmsFactory {
|
||||
return func(ctx context.Context, storageURI string, kmsURI string) (kms.CloudKMS, error) {
|
||||
if setupErr != nil {
|
||||
return nil, setupErr
|
||||
}
|
||||
return &stubKMS{getDEKErr: getDEKErr}, nil
|
||||
}
|
||||
}
|
||||
|
||||
type stubKMS struct {
|
||||
getDEKErr error
|
||||
}
|
||||
|
||||
func (s *stubKMS) CreateKEK(ctx context.Context, keyID string, kek []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubKMS) GetDEK(ctx context.Context, dekID string, dekSize int) ([]byte, error) {
|
||||
if s.getDEKErr != nil {
|
||||
return nil, s.getDEKErr
|
||||
}
|
||||
return []byte("someDEK"), nil
|
||||
}
|
||||
|
@ -25,11 +25,10 @@ type RecoverMessage struct {
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
// Types that are assignable to Request:
|
||||
//
|
||||
// *RecoverMessage_StateDiskKey
|
||||
// *RecoverMessage_MeasurementSecret
|
||||
Request isRecoverMessage_Request `protobuf_oneof:"request"`
|
||||
// bytes state_disk_key = 1; removed
|
||||
// bytes measurement_secret = 2; removed
|
||||
KmsUri string `protobuf:"bytes,3,opt,name=kms_uri,json=kmsUri,proto3" json:"kms_uri,omitempty"`
|
||||
StorageUri string `protobuf:"bytes,4,opt,name=storage_uri,json=storageUri,proto3" json:"storage_uri,omitempty"`
|
||||
}
|
||||
|
||||
func (x *RecoverMessage) Reset() {
|
||||
@ -64,49 +63,24 @@ func (*RecoverMessage) Descriptor() ([]byte, []int) {
|
||||
return file_recover_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (m *RecoverMessage) GetRequest() isRecoverMessage_Request {
|
||||
if m != nil {
|
||||
return m.Request
|
||||
func (x *RecoverMessage) GetKmsUri() string {
|
||||
if x != nil {
|
||||
return x.KmsUri
|
||||
}
|
||||
return nil
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *RecoverMessage) GetStateDiskKey() []byte {
|
||||
if x, ok := x.GetRequest().(*RecoverMessage_StateDiskKey); ok {
|
||||
return x.StateDiskKey
|
||||
func (x *RecoverMessage) GetStorageUri() string {
|
||||
if x != nil {
|
||||
return x.StorageUri
|
||||
}
|
||||
return nil
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *RecoverMessage) GetMeasurementSecret() []byte {
|
||||
if x, ok := x.GetRequest().(*RecoverMessage_MeasurementSecret); ok {
|
||||
return x.MeasurementSecret
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type isRecoverMessage_Request interface {
|
||||
isRecoverMessage_Request()
|
||||
}
|
||||
|
||||
type RecoverMessage_StateDiskKey struct {
|
||||
StateDiskKey []byte `protobuf:"bytes,1,opt,name=state_disk_key,json=stateDiskKey,proto3,oneof"`
|
||||
}
|
||||
|
||||
type RecoverMessage_MeasurementSecret struct {
|
||||
MeasurementSecret []byte `protobuf:"bytes,2,opt,name=measurement_secret,json=measurementSecret,proto3,oneof"`
|
||||
}
|
||||
|
||||
func (*RecoverMessage_StateDiskKey) isRecoverMessage_Request() {}
|
||||
|
||||
func (*RecoverMessage_MeasurementSecret) isRecoverMessage_Request() {}
|
||||
|
||||
type RecoverResponse struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
DiskUuid string `protobuf:"bytes,1,opt,name=disk_uuid,json=diskUuid,proto3" json:"disk_uuid,omitempty"`
|
||||
}
|
||||
|
||||
func (x *RecoverResponse) Reset() {
|
||||
@ -141,39 +115,27 @@ func (*RecoverResponse) Descriptor() ([]byte, []int) {
|
||||
return file_recover_proto_rawDescGZIP(), []int{1}
|
||||
}
|
||||
|
||||
func (x *RecoverResponse) GetDiskUuid() string {
|
||||
if x != nil {
|
||||
return x.DiskUuid
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var File_recover_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_recover_proto_rawDesc = []byte{
|
||||
0x0a, 0x0d, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12,
|
||||
0x0c, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x74, 0x0a,
|
||||
0x0c, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x4a, 0x0a,
|
||||
0x0e, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12,
|
||||
0x26, 0x0a, 0x0e, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5f, 0x64, 0x69, 0x73, 0x6b, 0x5f, 0x6b, 0x65,
|
||||
0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x48, 0x00, 0x52, 0x0c, 0x73, 0x74, 0x61, 0x74, 0x65,
|
||||
0x44, 0x69, 0x73, 0x6b, 0x4b, 0x65, 0x79, 0x12, 0x2f, 0x0a, 0x12, 0x6d, 0x65, 0x61, 0x73, 0x75,
|
||||
0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x5f, 0x73, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20,
|
||||
0x01, 0x28, 0x0c, 0x48, 0x00, 0x52, 0x11, 0x6d, 0x65, 0x61, 0x73, 0x75, 0x72, 0x65, 0x6d, 0x65,
|
||||
0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x42, 0x09, 0x0a, 0x07, 0x72, 0x65, 0x71, 0x75,
|
||||
0x65, 0x73, 0x74, 0x22, 0x2e, 0x0a, 0x0f, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x52, 0x65,
|
||||
0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x64, 0x69, 0x73, 0x6b, 0x5f, 0x75,
|
||||
0x75, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x64, 0x69, 0x73, 0x6b, 0x55,
|
||||
0x75, 0x69, 0x64, 0x32, 0x53, 0x0a, 0x03, 0x41, 0x50, 0x49, 0x12, 0x4c, 0x0a, 0x07, 0x52, 0x65,
|
||||
0x63, 0x6f, 0x76, 0x65, 0x72, 0x12, 0x1c, 0x2e, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70,
|
||||
0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x4d, 0x65, 0x73, 0x73,
|
||||
0x61, 0x67, 0x65, 0x1a, 0x1d, 0x2e, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f,
|
||||
0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
|
||||
0x73, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x42, 0x5a, 0x40, 0x67, 0x69, 0x74, 0x68,
|
||||
0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x65, 0x64, 0x67, 0x65, 0x6c, 0x65, 0x73, 0x73, 0x73,
|
||||
0x79, 0x73, 0x2f, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x65, 0x6c, 0x6c, 0x61, 0x74, 0x69, 0x6f, 0x6e,
|
||||
0x2f, 0x76, 0x32, 0x2f, 0x64, 0x69, 0x73, 0x6b, 0x2d, 0x6d, 0x61, 0x70, 0x70, 0x65, 0x72, 0x2f,
|
||||
0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72,
|
||||
0x6f, 0x74, 0x6f, 0x33,
|
||||
0x17, 0x0a, 0x07, 0x6b, 0x6d, 0x73, 0x5f, 0x75, 0x72, 0x69, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09,
|
||||
0x52, 0x06, 0x6b, 0x6d, 0x73, 0x55, 0x72, 0x69, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x74, 0x6f, 0x72,
|
||||
0x61, 0x67, 0x65, 0x5f, 0x75, 0x72, 0x69, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x73,
|
||||
0x74, 0x6f, 0x72, 0x61, 0x67, 0x65, 0x55, 0x72, 0x69, 0x22, 0x11, 0x0a, 0x0f, 0x52, 0x65, 0x63,
|
||||
0x6f, 0x76, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x32, 0x4f, 0x0a, 0x03,
|
||||
0x41, 0x50, 0x49, 0x12, 0x48, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x12, 0x1c,
|
||||
0x2e, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65,
|
||||
0x63, 0x6f, 0x76, 0x65, 0x72, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1d, 0x2e, 0x72,
|
||||
0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x63, 0x6f,
|
||||
0x76, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x42, 0x5a,
|
||||
0x40, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x65, 0x64, 0x67, 0x65,
|
||||
0x6c, 0x65, 0x73, 0x73, 0x73, 0x79, 0x73, 0x2f, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x65, 0x6c, 0x6c,
|
||||
0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x76, 0x32, 0x2f, 0x64, 0x69, 0x73, 0x6b, 0x2d, 0x6d, 0x61,
|
||||
0x70, 0x70, 0x65, 0x72, 0x2f, 0x72, 0x65, 0x63, 0x6f, 0x76, 0x65, 0x72, 0x70, 0x72, 0x6f, 0x74,
|
||||
0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
@ -234,10 +196,6 @@ func file_recover_proto_init() {
|
||||
}
|
||||
}
|
||||
}
|
||||
file_recover_proto_msgTypes[0].OneofWrappers = []interface{}{
|
||||
(*RecoverMessage_StateDiskKey)(nil),
|
||||
(*RecoverMessage_MeasurementSecret)(nil),
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
|
@ -5,16 +5,19 @@ package recoverproto;
|
||||
option go_package = "github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto";
|
||||
|
||||
service API {
|
||||
rpc Recover(stream RecoverMessage) returns (stream RecoverResponse) {}
|
||||
// Recover sends the necessary information to the recoveryserver to start recovering the cluster.
|
||||
rpc Recover(RecoverMessage) returns (RecoverResponse) {}
|
||||
}
|
||||
|
||||
message RecoverMessage {
|
||||
oneof request {
|
||||
bytes state_disk_key = 1;
|
||||
bytes measurement_secret = 2;
|
||||
}
|
||||
// bytes state_disk_key = 1; removed
|
||||
// bytes measurement_secret = 2; removed
|
||||
// kms_uri is the URI of the KMS the recoveryserver should use to decrypt DEKs.
|
||||
string kms_uri = 3;
|
||||
// storage_uri is the URI of the storage location the recoveryserver should use to fetch DEKs.
|
||||
string storage_uri = 4;
|
||||
}
|
||||
|
||||
message RecoverResponse {
|
||||
string disk_uuid = 1;
|
||||
// string disk_uuid = 1; removed
|
||||
}
|
||||
|
@ -22,7 +22,7 @@ const _ = grpc.SupportPackageIsVersion7
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
type APIClient interface {
|
||||
Recover(ctx context.Context, opts ...grpc.CallOption) (API_RecoverClient, error)
|
||||
Recover(ctx context.Context, in *RecoverMessage, opts ...grpc.CallOption) (*RecoverResponse, error)
|
||||
}
|
||||
|
||||
type aPIClient struct {
|
||||
@ -33,42 +33,20 @@ func NewAPIClient(cc grpc.ClientConnInterface) APIClient {
|
||||
return &aPIClient{cc}
|
||||
}
|
||||
|
||||
func (c *aPIClient) Recover(ctx context.Context, opts ...grpc.CallOption) (API_RecoverClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &API_ServiceDesc.Streams[0], "/recoverproto.API/Recover", opts...)
|
||||
func (c *aPIClient) Recover(ctx context.Context, in *RecoverMessage, opts ...grpc.CallOption) (*RecoverResponse, error) {
|
||||
out := new(RecoverResponse)
|
||||
err := c.cc.Invoke(ctx, "/recoverproto.API/Recover", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
x := &aPIRecoverClient{stream}
|
||||
return x, nil
|
||||
}
|
||||
|
||||
type API_RecoverClient interface {
|
||||
Send(*RecoverMessage) error
|
||||
Recv() (*RecoverResponse, error)
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
type aPIRecoverClient struct {
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
func (x *aPIRecoverClient) Send(m *RecoverMessage) error {
|
||||
return x.ClientStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func (x *aPIRecoverClient) Recv() (*RecoverResponse, error) {
|
||||
m := new(RecoverResponse)
|
||||
if err := x.ClientStream.RecvMsg(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// APIServer is the server API for API service.
|
||||
// All implementations must embed UnimplementedAPIServer
|
||||
// for forward compatibility
|
||||
type APIServer interface {
|
||||
Recover(API_RecoverServer) error
|
||||
Recover(context.Context, *RecoverMessage) (*RecoverResponse, error)
|
||||
mustEmbedUnimplementedAPIServer()
|
||||
}
|
||||
|
||||
@ -76,8 +54,8 @@ type APIServer interface {
|
||||
type UnimplementedAPIServer struct {
|
||||
}
|
||||
|
||||
func (UnimplementedAPIServer) Recover(API_RecoverServer) error {
|
||||
return status.Errorf(codes.Unimplemented, "method Recover not implemented")
|
||||
func (UnimplementedAPIServer) Recover(context.Context, *RecoverMessage) (*RecoverResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method Recover not implemented")
|
||||
}
|
||||
func (UnimplementedAPIServer) mustEmbedUnimplementedAPIServer() {}
|
||||
|
||||
@ -92,30 +70,22 @@ func RegisterAPIServer(s grpc.ServiceRegistrar, srv APIServer) {
|
||||
s.RegisterService(&API_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
func _API_Recover_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||
return srv.(APIServer).Recover(&aPIRecoverServer{stream})
|
||||
}
|
||||
|
||||
type API_RecoverServer interface {
|
||||
Send(*RecoverResponse) error
|
||||
Recv() (*RecoverMessage, error)
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
type aPIRecoverServer struct {
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
func (x *aPIRecoverServer) Send(m *RecoverResponse) error {
|
||||
return x.ServerStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func (x *aPIRecoverServer) Recv() (*RecoverMessage, error) {
|
||||
m := new(RecoverMessage)
|
||||
if err := x.ServerStream.RecvMsg(m); err != nil {
|
||||
func _API_Recover_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(RecoverMessage)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
if interceptor == nil {
|
||||
return srv.(APIServer).Recover(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/recoverproto.API/Recover",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(APIServer).Recover(ctx, req.(*RecoverMessage))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// API_ServiceDesc is the grpc.ServiceDesc for API service.
|
||||
@ -124,14 +94,12 @@ func (x *aPIRecoverServer) Recv() (*RecoverMessage, error) {
|
||||
var API_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "recoverproto.API",
|
||||
HandlerType: (*APIServer)(nil),
|
||||
Methods: []grpc.MethodDesc{},
|
||||
Streams: []grpc.StreamDesc{
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
StreamName: "Recover",
|
||||
Handler: _API_Recover_Handler,
|
||||
ServerStreams: true,
|
||||
ClientStreams: true,
|
||||
MethodName: "Recover",
|
||||
Handler: _API_Recover_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "recover.proto",
|
||||
}
|
||||
|
@ -20,10 +20,5 @@ const (
|
||||
|
||||
// DeriveClusterID derives the cluster ID from a salt and secret value.
|
||||
func DeriveClusterID(secret, salt []byte) ([]byte, error) {
|
||||
return crypto.DeriveKey(secret, salt, []byte(crypto.HKDFInfoPrefix+clusterIDContext), crypto.DerivedKeyLengthDefault)
|
||||
}
|
||||
|
||||
// DeriveMeasurementSecret derives the secret value needed to derive ClusterID.
|
||||
func DeriveMeasurementSecret(masterSecret, salt []byte) ([]byte, error) {
|
||||
return crypto.DeriveKey(masterSecret, salt, []byte(crypto.HKDFInfoPrefix+MeasurementSecretContext), crypto.DerivedKeyLengthDefault)
|
||||
return crypto.DeriveKey(secret, salt, []byte(crypto.DEKPrefix+clusterIDContext), crypto.DerivedKeyLengthDefault)
|
||||
}
|
||||
|
@ -31,21 +31,3 @@ func TestDeriveClusterID(t *testing.T) {
|
||||
require.NoError(err)
|
||||
assert.NotEqual(clusterID, clusterIDdiff)
|
||||
}
|
||||
|
||||
func TestDeriveMeasurementSecret(t *testing.T) {
|
||||
require := require.New(t)
|
||||
assert := assert.New(t)
|
||||
|
||||
testvector := testvector.HKDFMeasurementSecret
|
||||
measurementSecret, err := DeriveMeasurementSecret(testvector.Secret, testvector.Salt)
|
||||
require.NoError(err)
|
||||
assert.Equal(testvector.Output, measurementSecret)
|
||||
|
||||
measurementSecretdiff, err := DeriveMeasurementSecret(testvector.Secret, []byte("different-salt"))
|
||||
require.NoError(err)
|
||||
assert.NotEqual(measurementSecret, measurementSecretdiff)
|
||||
|
||||
measurementSecretdiff, err = DeriveMeasurementSecret([]byte("different-secret"), testvector.Salt)
|
||||
require.NoError(err)
|
||||
assert.NotEqual(measurementSecret, measurementSecretdiff)
|
||||
}
|
||||
|
@ -30,8 +30,10 @@ const (
|
||||
MasterSecretLengthMin = 16
|
||||
// RNGLengthDefault is the number of bytes used for generating nonces.
|
||||
RNGLengthDefault = 32
|
||||
// HKDFInfoPrefix is the prefix used for the info parameter in HKDF.
|
||||
HKDFInfoPrefix = "key-"
|
||||
// DEKPrefix is the prefix used to prefix DEK IDs. Originally introduced as a requirement for the HKDF info parameter.
|
||||
DEKPrefix = "key-"
|
||||
// MeasurementSecretKeyID is name used for the measurementSecret DEK.
|
||||
MeasurementSecretKeyID = "measurementSecret"
|
||||
)
|
||||
|
||||
// DeriveKey derives a key from a secret.
|
||||
|
@ -56,13 +56,14 @@ type KMSClient struct {
|
||||
awsClient ClientAPI
|
||||
policyProducer KeyPolicyProducer
|
||||
storage kmsInterface.Storage
|
||||
kekID string
|
||||
}
|
||||
|
||||
// New creates and initializes a new KMSClient for AWS.
|
||||
//
|
||||
// The parameter client needs to be initialized with valid AWS credentials (https://aws.github.io/aws-sdk-go-v2/docs/getting-started).
|
||||
// If storage is nil, the default MemMapStorage is used.
|
||||
func New(ctx context.Context, policyProducer KeyPolicyProducer, store kmsInterface.Storage, optFns ...func(*awsconfig.LoadOptions) error) (*KMSClient, error) {
|
||||
func New(ctx context.Context, policyProducer KeyPolicyProducer, store kmsInterface.Storage, kekID string, optFns ...func(*awsconfig.LoadOptions) error) (*KMSClient, error) {
|
||||
if store == nil {
|
||||
store = storage.NewMemMapStorage()
|
||||
}
|
||||
@ -77,6 +78,7 @@ func New(ctx context.Context, policyProducer KeyPolicyProducer, store kmsInterfa
|
||||
awsClient: client,
|
||||
policyProducer: policyProducer,
|
||||
storage: store,
|
||||
kekID: kekID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -206,9 +208,9 @@ func (c *KMSClient) CreateKEK(ctx context.Context, keyID string, key []byte) err
|
||||
}
|
||||
|
||||
// GetDEK returns the DEK for dekID and kekID from the KMS.
|
||||
func (c *KMSClient) GetDEK(ctx context.Context, kekID, keyID string, dekSize int) ([]byte, error) {
|
||||
func (c *KMSClient) GetDEK(ctx context.Context, keyID string, dekSize int) ([]byte, error) {
|
||||
// The KEK should be identified by its alias. The alias always has the same scheme: 'alias/<kekId>'
|
||||
kekID = "alias/" + kekID
|
||||
kekID := "alias/" + c.kekID
|
||||
|
||||
// If a key for keyID exists in the storage, decrypt the key using the KEK.
|
||||
dek, err := c.decryptDEKFromStorage(ctx, kekID, keyID)
|
||||
|
@ -253,22 +253,23 @@ func TestAWSKMSClient(t *testing.T) {
|
||||
|
||||
awsClient := &fakeAWSClient{kekPool: make(map[string][]byte, 2)}
|
||||
|
||||
client := &KMSClient{
|
||||
awsClient: awsClient,
|
||||
policyProducer: &stubKeyPolicyProducer{},
|
||||
storage: storage.NewMemMapStorage(),
|
||||
}
|
||||
|
||||
awsClient.keyIDCount = -1
|
||||
|
||||
testKEK1ID := "testKEK1"
|
||||
testKEK1 := []byte("test KEK")
|
||||
testKEK2ID := "testKEK2"
|
||||
testKEK2 := []byte("more test KEK")
|
||||
client := &KMSClient{
|
||||
awsClient: awsClient,
|
||||
policyProducer: &stubKeyPolicyProducer{},
|
||||
storage: storage.NewMemMapStorage(),
|
||||
kekID: testKEK1ID,
|
||||
}
|
||||
|
||||
awsClient.keyIDCount = -1
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// try to get a DEK before setting the KEK
|
||||
_, err := client.GetDEK(ctx, testKEK1ID, "volume01", config.SymmetricKeyLength)
|
||||
_, err := client.GetDEK(ctx, "volume01", config.SymmetricKeyLength)
|
||||
assert.Error(err)
|
||||
assert.ErrorIs(err, kmsInterface.ErrKEKUnknown)
|
||||
|
||||
@ -285,16 +286,16 @@ func TestAWSKMSClient(t *testing.T) {
|
||||
assert.Equal(testKEK2, awsClient.kekPool[strconv.Itoa(awsClient.keyIDCount)])
|
||||
|
||||
// test GetDEK method
|
||||
dek1, err := client.GetDEK(ctx, testKEK1ID, "volume01", config.SymmetricKeyLength)
|
||||
dek1, err := client.GetDEK(ctx, "volume01", config.SymmetricKeyLength)
|
||||
assert.NoError(err)
|
||||
dek2, err := client.GetDEK(ctx, testKEK2ID, "volume02", config.SymmetricKeyLength)
|
||||
dek2, err := client.GetDEK(ctx, "volume02", config.SymmetricKeyLength)
|
||||
assert.NoError(err)
|
||||
|
||||
// make sure that GetDEK is idempotent
|
||||
dek1Copy, err := client.GetDEK(ctx, testKEK1ID, "volume01", config.SymmetricKeyLength)
|
||||
dek1Copy, err := client.GetDEK(ctx, "volume01", config.SymmetricKeyLength)
|
||||
assert.NoError(err)
|
||||
assert.Equal(dek1, dek1Copy)
|
||||
dek2Copy, err := client.GetDEK(ctx, testKEK2ID, "volume02", config.SymmetricKeyLength)
|
||||
dek2Copy, err := client.GetDEK(ctx, "volume02", config.SymmetricKeyLength)
|
||||
assert.NoError(err)
|
||||
assert.Equal(dek2, dek2Copy)
|
||||
}
|
||||
|
@ -47,6 +47,7 @@ type kmsClientAPI interface {
|
||||
type KMSClient struct {
|
||||
client kmsClientAPI
|
||||
storage kms.Storage
|
||||
kekID string
|
||||
}
|
||||
|
||||
// Opts are optional settings for AKV clients.
|
||||
@ -57,7 +58,7 @@ type Opts struct {
|
||||
}
|
||||
|
||||
// New initializes a KMS client for Azure Key Vault.
|
||||
func New(ctx context.Context, vaultName string, vaultType VaultSuffix, store kms.Storage, opts *Opts) (*KMSClient, error) {
|
||||
func New(ctx context.Context, vaultName string, vaultType VaultSuffix, store kms.Storage, kekID string, opts *Opts) (*KMSClient, error) {
|
||||
if opts == nil {
|
||||
opts = &Opts{}
|
||||
}
|
||||
@ -80,7 +81,7 @@ func New(ctx context.Context, vaultName string, vaultType VaultSuffix, store kms
|
||||
if store == nil {
|
||||
store = storage.NewMemMapStorage()
|
||||
}
|
||||
return &KMSClient{client: client, storage: store}, nil
|
||||
return &KMSClient{client: client, storage: store, kekID: kekID}, nil
|
||||
}
|
||||
|
||||
// CreateKEK saves a new Key Encryption Key using Azure Key Vault.
|
||||
@ -111,8 +112,8 @@ func (c *KMSClient) CreateKEK(ctx context.Context, keyID string, key []byte) err
|
||||
}
|
||||
|
||||
// GetDEK decrypts a DEK from storage.
|
||||
func (c *KMSClient) GetDEK(ctx context.Context, kekID, keyID string, dekSize int) ([]byte, error) {
|
||||
kek, err := c.getKEK(ctx, kekID)
|
||||
func (c *KMSClient) GetDEK(ctx context.Context, keyID string, dekSize int) ([]byte, error) {
|
||||
kek, err := c.getKEK(ctx, c.kekID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading KEK from key vault: %w", err)
|
||||
}
|
||||
|
@ -155,7 +155,7 @@ func TestKMSGetDEK(t *testing.T) {
|
||||
storage: tc.storage,
|
||||
}
|
||||
|
||||
dek, err := client.GetDEK(context.Background(), "test-key", "volume-01", 32)
|
||||
dek, err := client.GetDEK(context.Background(), "volume-01", 32)
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
|
@ -38,10 +38,11 @@ type HSMClient struct {
|
||||
client hsmClientAPI
|
||||
storage kms.Storage
|
||||
vaultURL string
|
||||
kekID string
|
||||
}
|
||||
|
||||
// NewHSM initializes a KMS client for Azure manged HSM Key Vault.
|
||||
func NewHSM(ctx context.Context, vaultName string, store kms.Storage, opts *Opts) (*HSMClient, error) {
|
||||
func NewHSM(ctx context.Context, vaultName string, store kms.Storage, kekID string, opts *Opts) (*HSMClient, error) {
|
||||
if opts == nil {
|
||||
opts = &Opts{}
|
||||
}
|
||||
@ -72,6 +73,7 @@ func NewHSM(ctx context.Context, vaultName string, store kms.Storage, opts *Opts
|
||||
client: client,
|
||||
credentials: cred,
|
||||
storage: store,
|
||||
kekID: kekID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -114,7 +116,7 @@ func (c *HSMClient) CreateKEK(ctx context.Context, keyID string, key []byte) err
|
||||
}
|
||||
|
||||
// GetDEK loads an encrypted DEK from storage and unwraps it using an HSM-backed key.
|
||||
func (c *HSMClient) GetDEK(ctx context.Context, kekID string, keyID string, dekSize int) ([]byte, error) {
|
||||
func (c *HSMClient) GetDEK(ctx context.Context, keyID string, dekSize int) ([]byte, error) {
|
||||
encryptedDEK, err := c.storage.Get(ctx, keyID)
|
||||
if err != nil {
|
||||
if !errors.Is(err, storage.ErrDEKUnset) {
|
||||
@ -126,7 +128,7 @@ func (c *HSMClient) GetDEK(ctx context.Context, kekID string, keyID string, dekS
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("key generation: %w", err)
|
||||
}
|
||||
if err := c.putDEK(ctx, kekID, keyID, newDEK); err != nil {
|
||||
if err := c.putDEK(ctx, c.kekID, keyID, newDEK); err != nil {
|
||||
return nil, fmt.Errorf("creating new DEK: %w", err)
|
||||
}
|
||||
|
||||
@ -137,7 +139,7 @@ func (c *HSMClient) GetDEK(ctx context.Context, kekID string, keyID string, dekS
|
||||
Algorithm: to.Ptr(azkeys.JSONWebKeyEncryptionAlgorithmA256KW),
|
||||
Value: encryptedDEK,
|
||||
}
|
||||
res, err := c.client.UnwrapKey(ctx, kekID, "", params, &azkeys.UnwrapKeyOptions{})
|
||||
res, err := c.client.UnwrapKey(ctx, c.kekID, "", params, &azkeys.UnwrapKeyOptions{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unwrapping key: %w", err)
|
||||
}
|
||||
|
@ -165,7 +165,7 @@ func TestHSMGetNewDEK(t *testing.T) {
|
||||
storage: tc.storage,
|
||||
}
|
||||
|
||||
dek, err := client.GetDEK(context.Background(), "test-key", "volume-01", 32)
|
||||
dek, err := client.GetDEK(context.Background(), "volume-01", 32)
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
@ -208,7 +208,7 @@ func TestHSMGetExistingDEK(t *testing.T) {
|
||||
storage: storage,
|
||||
}
|
||||
|
||||
dek, err := client.GetDEK(context.Background(), "test-key", keyID, len(testKey))
|
||||
dek, err := client.GetDEK(context.Background(), keyID, len(testKey))
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
|
@ -20,8 +20,15 @@ type KMS struct {
|
||||
}
|
||||
|
||||
// New creates a new ClusterKMS.
|
||||
func New(salt []byte) *KMS {
|
||||
return &KMS{salt: salt}
|
||||
func New(key []byte, salt []byte) (*KMS, error) {
|
||||
if len(key) == 0 {
|
||||
return nil, errors.New("missing master key")
|
||||
}
|
||||
if len(salt) == 0 {
|
||||
return nil, errors.New("missing salt")
|
||||
}
|
||||
|
||||
return &KMS{masterKey: key, salt: salt}, nil
|
||||
}
|
||||
|
||||
// CreateKEK sets the ClusterKMS masterKey.
|
||||
@ -31,7 +38,7 @@ func (c *KMS) CreateKEK(ctx context.Context, keyID string, kek []byte) error {
|
||||
}
|
||||
|
||||
// GetDEK derives a key from the KMS masterKey.
|
||||
func (c *KMS) GetDEK(ctx context.Context, kekID string, dekID string, dekSize int) ([]byte, error) {
|
||||
func (c *KMS) GetDEK(ctx context.Context, dekID string, dekSize int) ([]byte, error) {
|
||||
if len(c.masterKey) == 0 {
|
||||
return nil, errors.New("master key not set for Constellation KMS")
|
||||
}
|
||||
|
@ -24,19 +24,12 @@ func TestMain(m *testing.M) {
|
||||
func TestClusterKMS(t *testing.T) {
|
||||
testVector := testvector.HKDF0xFF
|
||||
assert := assert.New(t)
|
||||
kms := New(testVector.Salt)
|
||||
|
||||
key, err := kms.GetDEK(context.Background(), "", "key-1", 32)
|
||||
assert.Error(err)
|
||||
assert.Nil(key)
|
||||
|
||||
err = kms.CreateKEK(context.Background(), "", testVector.Secret)
|
||||
assert.NoError(err)
|
||||
assert.Equal(testVector.Secret, kms.masterKey)
|
||||
require := require.New(t)
|
||||
kms, err := New(testVector.Secret, testVector.Salt)
|
||||
require.NoError(err)
|
||||
|
||||
keyLower, err := kms.GetDEK(
|
||||
context.Background(),
|
||||
"",
|
||||
strings.ToLower(testVector.InfoPrefix+testVector.Info),
|
||||
int(testVector.Length),
|
||||
)
|
||||
@ -46,12 +39,11 @@ func TestClusterKMS(t *testing.T) {
|
||||
// output of the KMS should be case sensitive
|
||||
keyUpper, err := kms.GetDEK(
|
||||
context.Background(),
|
||||
"",
|
||||
strings.ToUpper(testVector.InfoPrefix+testVector.Info),
|
||||
int(testVector.Length),
|
||||
)
|
||||
assert.NoError(err)
|
||||
assert.NotEqual(key, keyUpper)
|
||||
assert.NotEqual(keyLower, keyUpper)
|
||||
}
|
||||
|
||||
func TestVectorsHKDF(t *testing.T) {
|
||||
@ -61,6 +53,7 @@ func TestVectorsHKDF(t *testing.T) {
|
||||
dekID string
|
||||
dekSize uint
|
||||
wantKey []byte
|
||||
wantErr bool
|
||||
}{
|
||||
"rfc Test Case 1": {
|
||||
kek: testvector.HKDFrfc1.Secret,
|
||||
@ -82,6 +75,7 @@ func TestVectorsHKDF(t *testing.T) {
|
||||
dekID: testvector.HKDFrfc3.Info,
|
||||
dekSize: testvector.HKDFrfc3.Length,
|
||||
wantKey: testvector.HKDFrfc3.Output,
|
||||
wantErr: true,
|
||||
},
|
||||
"HKDF zero": {
|
||||
kek: testvector.HKDFZero.Secret,
|
||||
@ -104,10 +98,15 @@ func TestVectorsHKDF(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
kms := New(tc.salt)
|
||||
kms, err := New(tc.kek, tc.salt)
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
assert.NoError(err)
|
||||
require.NoError(kms.CreateKEK(context.Background(), "", tc.kek))
|
||||
|
||||
out, err := kms.GetDEK(context.Background(), "", tc.dekID, int(tc.dekSize))
|
||||
out, err := kms.GetDEK(context.Background(), tc.dekID, int(tc.dekSize))
|
||||
require.NoError(err)
|
||||
assert.Equal(tc.wantKey, out)
|
||||
})
|
||||
|
@ -50,11 +50,12 @@ type KMSClient struct {
|
||||
waitBackoffLimit int
|
||||
storage kmsInterface.Storage
|
||||
protectionLevel kmspb.ProtectionLevel
|
||||
kekID string
|
||||
opts []gax.CallOption
|
||||
}
|
||||
|
||||
// New initializes a KMS client for Google Cloud Platform.
|
||||
func New(ctx context.Context, projectID, locationID, keyRingID string, store kmsInterface.Storage, protectionLvl kmspb.ProtectionLevel, opts ...gax.CallOption) (*KMSClient, error) {
|
||||
func New(ctx context.Context, projectID, locationID, keyRingID string, store kmsInterface.Storage, protectionLvl kmspb.ProtectionLevel, kekID string, opts ...gax.CallOption) (*KMSClient, error) {
|
||||
if store == nil {
|
||||
store = storage.NewMemMapStorage()
|
||||
}
|
||||
@ -71,6 +72,7 @@ func New(ctx context.Context, projectID, locationID, keyRingID string, store kms
|
||||
waitBackoffLimit: 10,
|
||||
storage: store,
|
||||
protectionLevel: protectionLvl,
|
||||
kekID: kekID,
|
||||
opts: opts,
|
||||
}
|
||||
|
||||
@ -108,7 +110,7 @@ func (c *KMSClient) CreateKEK(ctx context.Context, keyID string, key []byte) err
|
||||
}
|
||||
|
||||
// GetDEK fetches an encrypted Data Encryption Key from storage and decrypts it using a KEK stored in Google's KMS.
|
||||
func (c *KMSClient) GetDEK(ctx context.Context, kekID, keyID string, dekSize int) ([]byte, error) {
|
||||
func (c *KMSClient) GetDEK(ctx context.Context, keyID string, dekSize int) ([]byte, error) {
|
||||
client, err := c.newClient(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -126,11 +128,11 @@ func (c *KMSClient) GetDEK(ctx context.Context, kekID, keyID string, dekSize int
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("key generation: %w", err)
|
||||
}
|
||||
return newDEK, c.putDEK(ctx, client, kekID, keyID, newDEK)
|
||||
return newDEK, c.putDEK(ctx, client, c.kekID, keyID, newDEK)
|
||||
}
|
||||
|
||||
request := &kmspb.DecryptRequest{
|
||||
Name: fmt.Sprintf("projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s", c.projectID, c.locationID, c.keyRingID, kekID),
|
||||
Name: fmt.Sprintf("projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s", c.projectID, c.locationID, c.keyRingID, c.kekID),
|
||||
Ciphertext: encryptedDEK,
|
||||
}
|
||||
|
||||
|
@ -321,7 +321,7 @@ func TestGetDEK(t *testing.T) {
|
||||
storage: tc.storage,
|
||||
}
|
||||
|
||||
dek, err := client.GetDEK(context.Background(), "test-key", "volume-01", 32)
|
||||
dek, err := client.GetDEK(context.Background(), "volume-01", 32)
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
|
@ -17,7 +17,7 @@ type CloudKMS interface {
|
||||
CreateKEK(ctx context.Context, keyID string, kek []byte) error
|
||||
// GetDEK returns the DEK for dekID and kekID from the KMS.
|
||||
// If the DEK does not exist, a new one is created and saved to storage.
|
||||
GetDEK(ctx context.Context, kekID string, dekID string, dekSize int) ([]byte, error)
|
||||
GetDEK(ctx context.Context, dekID string, dekSize int) ([]byte, error)
|
||||
}
|
||||
|
||||
// Storage provides an abstract interface for the storage backend used for DEKs.
|
||||
|
@ -24,10 +24,10 @@ import (
|
||||
|
||||
// Well known endpoints for KMS services.
|
||||
const (
|
||||
AWSKMSURI = "kms://aws?keyPolicy=%s"
|
||||
AzureKMSURI = "kms://azure-kms?name=%s&type=%s"
|
||||
AzureHSMURI = "kms://azure-hsm?name=%s"
|
||||
GCPKMSURI = "kms://gcp?project=%s&location=%s&keyRing=%s&protectionLvl=%s"
|
||||
AWSKMSURI = "kms://aws?keyPolicy=%s&kekID=%s"
|
||||
AzureKMSURI = "kms://azure-kms?name=%s&type=%s&kekID=%s"
|
||||
AzureHSMURI = "kms://azure-hsm?name=%s&kekID=%s"
|
||||
GCPKMSURI = "kms://gcp?project=%s&location=%s&keyRing=%s&protectionLvl=%s&kekID=%s"
|
||||
ClusterKMSURI = "kms://cluster-kms?key=%s&salt=%s"
|
||||
AWSS3URI = "storage://aws?bucket=%s"
|
||||
AzureBlobURI = "storage://azure?container=%s&connectionString=%s"
|
||||
@ -35,6 +35,21 @@ const (
|
||||
NoStoreURI = "storage://no-store"
|
||||
)
|
||||
|
||||
// MasterSecret holds the master key and salt for deriving keys.
|
||||
type MasterSecret struct {
|
||||
Key []byte `json:"key"`
|
||||
Salt []byte `json:"salt"`
|
||||
}
|
||||
|
||||
// EncodeToURI returns an URI encoding the master secret.
|
||||
func (m *MasterSecret) EncodeToURI() string {
|
||||
return fmt.Sprintf(
|
||||
ClusterKMSURI,
|
||||
base64.URLEncoding.EncodeToString(m.Key),
|
||||
base64.URLEncoding.EncodeToString(m.Salt),
|
||||
)
|
||||
}
|
||||
|
||||
// KMSInformation about an existing KMS.
|
||||
type KMSInformation struct {
|
||||
KMSURI string
|
||||
@ -104,39 +119,39 @@ func getKMS(ctx context.Context, kmsURI string, store kms.Storage) (kms.CloudKMS
|
||||
|
||||
switch uri.Host {
|
||||
case "aws":
|
||||
poliyProducer, err := getAWSKMSConfig(uri)
|
||||
poliyProducer, kekID, err := getAWSKMSConfig(uri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return aws.New(ctx, poliyProducer, store)
|
||||
return aws.New(ctx, poliyProducer, store, kekID)
|
||||
|
||||
case "azure-kms":
|
||||
vaultName, vaultType, err := getAzureKMSConfig(uri)
|
||||
vaultName, vaultType, kekID, err := getAzureKMSConfig(uri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return azure.New(ctx, vaultName, azure.VaultSuffix(vaultType), store, nil)
|
||||
return azure.New(ctx, vaultName, azure.VaultSuffix(vaultType), store, kekID, nil)
|
||||
|
||||
case "azure-hsm":
|
||||
vaultName, err := getAzureHSMConfig(uri)
|
||||
vaultName, kekID, err := getAzureHSMConfig(uri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return azure.NewHSM(ctx, vaultName, store, nil)
|
||||
return azure.NewHSM(ctx, vaultName, store, kekID, nil)
|
||||
|
||||
case "gcp":
|
||||
project, location, keyRing, protectionLvl, err := getGCPKMSConfig(uri)
|
||||
project, location, keyRing, protectionLvl, kekID, err := getGCPKMSConfig(uri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return gcp.New(ctx, project, location, keyRing, store, kmspb.ProtectionLevel(protectionLvl))
|
||||
return gcp.New(ctx, project, location, keyRing, store, kmspb.ProtectionLevel(protectionLvl), kekID)
|
||||
|
||||
case "cluster-kms":
|
||||
salt, err := getClusterKMSConfig(uri)
|
||||
masterSecret, err := getClusterKMSConfig(uri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cluster.New(salt), nil
|
||||
return cluster.New(masterSecret.Key, masterSecret.Salt)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown KMS type: %s", uri.Host)
|
||||
@ -156,22 +171,56 @@ func getAWSS3Config(uri *url.URL) (string, error) {
|
||||
return r[0], err
|
||||
}
|
||||
|
||||
func getAWSKMSConfig(uri *url.URL) (*defaultPolicyProducer, error) {
|
||||
r, err := getConfig(uri.Query(), []string{"keyPolicy"})
|
||||
func getAWSKMSConfig(uri *url.URL) (*defaultPolicyProducer, string, error) {
|
||||
r, err := getConfig(uri.Query(), []string{"keyPolicy", "kekID"})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, "", err
|
||||
}
|
||||
return &defaultPolicyProducer{policy: r[0]}, err
|
||||
|
||||
if len(r) != 2 {
|
||||
return nil, "", fmt.Errorf("expected 2 KmsURI args, got %d", len(r))
|
||||
}
|
||||
|
||||
kekID, err := base64.URLEncoding.DecodeString(r[1])
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("parsing kekID from kmsUri: %w", err)
|
||||
}
|
||||
|
||||
return &defaultPolicyProducer{policy: r[0]}, string(kekID), err
|
||||
}
|
||||
|
||||
func getAzureKMSConfig(uri *url.URL) (string, string, error) {
|
||||
r, err := getConfig(uri.Query(), []string{"name", "type"})
|
||||
return r[0], r[1], err
|
||||
func getAzureKMSConfig(uri *url.URL) (string, string, string, error) {
|
||||
r, err := getConfig(uri.Query(), []string{"name", "type", "kekID"})
|
||||
if err != nil {
|
||||
return "", "", "", fmt.Errorf("getting config: %w", err)
|
||||
}
|
||||
if len(r) != 3 {
|
||||
return "", "", "", fmt.Errorf("expected 3 KmsURI args, got %d", len(r))
|
||||
}
|
||||
|
||||
kekID, err := base64.URLEncoding.DecodeString(r[2])
|
||||
if err != nil {
|
||||
return "", "", "", fmt.Errorf("parsing kekID from kmsUri: %w", err)
|
||||
}
|
||||
|
||||
return r[0], r[1], string(kekID), err
|
||||
}
|
||||
|
||||
func getAzureHSMConfig(uri *url.URL) (string, error) {
|
||||
r, err := getConfig(uri.Query(), []string{"name"})
|
||||
return r[0], err
|
||||
func getAzureHSMConfig(uri *url.URL) (string, string, error) {
|
||||
r, err := getConfig(uri.Query(), []string{"name", "kekID"})
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("getting config: %w", err)
|
||||
}
|
||||
if len(r) != 2 {
|
||||
return "", "", fmt.Errorf("expected 2 KmsURI args, got %d", len(r))
|
||||
}
|
||||
|
||||
kekID, err := base64.URLEncoding.DecodeString(r[1])
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("parsing kekID from kmsUri: %w", err)
|
||||
}
|
||||
|
||||
return r[0], string(kekID), err
|
||||
}
|
||||
|
||||
func getAzureBlobConfig(uri *url.URL) (string, string, error) {
|
||||
@ -182,16 +231,26 @@ func getAzureBlobConfig(uri *url.URL) (string, string, error) {
|
||||
return r[0], r[1], nil
|
||||
}
|
||||
|
||||
func getGCPKMSConfig(uri *url.URL) (string, string, string, int32, error) {
|
||||
r, err := getConfig(uri.Query(), []string{"project", "location", "keyRing", "protectionLvl"})
|
||||
func getGCPKMSConfig(uri *url.URL) (project string, location string, keyRing string, protectionLvl int32, kekID string, err error) {
|
||||
r, err := getConfig(uri.Query(), []string{"project", "location", "keyRing", "protectionLvl", "kekID"})
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
return "", "", "", 0, "", err
|
||||
}
|
||||
protectionLvl, err := strconv.ParseInt(r[3], 10, 32)
|
||||
|
||||
if len(r) != 5 {
|
||||
return "", "", "", 0, "", fmt.Errorf("expected 5 KmsURI args, got %d", len(r))
|
||||
}
|
||||
|
||||
kekIDByte, err := base64.URLEncoding.DecodeString(r[4])
|
||||
if err != nil {
|
||||
return "", "", "", 0, err
|
||||
return "", "", "", 0, "", fmt.Errorf("parsing kekID from kmsUri: %w", err)
|
||||
}
|
||||
return r[0], r[1], r[2], int32(protectionLvl), nil
|
||||
|
||||
protectionLvl32, err := strconv.ParseInt(r[3], 10, 32)
|
||||
if err != nil {
|
||||
return "", "", "", 0, "", err
|
||||
}
|
||||
return r[0], r[1], r[2], int32(protectionLvl32), string(kekIDByte), nil
|
||||
}
|
||||
|
||||
func getGCPStorageConfig(uri *url.URL) (string, string, error) {
|
||||
@ -199,12 +258,26 @@ func getGCPStorageConfig(uri *url.URL) (string, string, error) {
|
||||
return r[0], r[1], err
|
||||
}
|
||||
|
||||
func getClusterKMSConfig(uri *url.URL) ([]byte, error) {
|
||||
r, err := getConfig(uri.Query(), []string{"salt"})
|
||||
func getClusterKMSConfig(uri *url.URL) (MasterSecret, error) {
|
||||
r, err := getConfig(uri.Query(), []string{"key", "salt"})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return MasterSecret{}, err
|
||||
}
|
||||
return base64.URLEncoding.DecodeString(r[0])
|
||||
|
||||
if len(r) != 2 {
|
||||
return MasterSecret{}, fmt.Errorf("expected 2 KmsURI args, got %d", len(r))
|
||||
}
|
||||
|
||||
key, err := base64.URLEncoding.DecodeString(r[0])
|
||||
if err != nil {
|
||||
return MasterSecret{}, fmt.Errorf("parsing key from kmsUri: %w", err)
|
||||
}
|
||||
salt, err := base64.URLEncoding.DecodeString(r[1])
|
||||
if err != nil {
|
||||
return MasterSecret{}, fmt.Errorf("parsing salt from kmsUri: %w", err)
|
||||
}
|
||||
|
||||
return MasterSecret{Key: key, Salt: salt}, nil
|
||||
}
|
||||
|
||||
// getConfig parses url query values, returning a map of the requested values.
|
||||
|
@ -18,6 +18,8 @@ import (
|
||||
"go.uber.org/goleak"
|
||||
)
|
||||
|
||||
const constellationKekID = "Constellation"
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m,
|
||||
// https://github.com/census-instrumentation/opencensus-go/issues/1262
|
||||
@ -80,23 +82,23 @@ func TestGetKMS(t *testing.T) {
|
||||
wantErr bool
|
||||
}{
|
||||
"cluster kms": {
|
||||
uri: fmt.Sprintf("%s?salt=%s", ClusterKMSURI, base64.URLEncoding.EncodeToString([]byte("salt"))),
|
||||
uri: fmt.Sprintf(ClusterKMSURI, base64.URLEncoding.EncodeToString([]byte("key")), base64.URLEncoding.EncodeToString([]byte("salt"))),
|
||||
wantErr: false,
|
||||
},
|
||||
"aws kms": {
|
||||
uri: fmt.Sprintf(AWSKMSURI, ""),
|
||||
uri: fmt.Sprintf(AWSKMSURI, "", ""),
|
||||
wantErr: true,
|
||||
},
|
||||
"azure kms": {
|
||||
uri: fmt.Sprintf(AzureKMSURI, "", ""),
|
||||
uri: fmt.Sprintf(AzureKMSURI, "", "", ""),
|
||||
wantErr: true,
|
||||
},
|
||||
"azure hsm": {
|
||||
uri: fmt.Sprintf(AzureHSMURI, ""),
|
||||
uri: fmt.Sprintf(AzureHSMURI, "", ""),
|
||||
wantErr: true,
|
||||
},
|
||||
"gcp kms": {
|
||||
uri: fmt.Sprintf(GCPKMSURI, "", "", "", ""),
|
||||
uri: fmt.Sprintf(GCPKMSURI, "", "", "", "", ""),
|
||||
wantErr: true,
|
||||
},
|
||||
"unknown kms": {
|
||||
@ -135,7 +137,8 @@ func TestSetUpKMS(t *testing.T) {
|
||||
assert.Error(err)
|
||||
assert.Nil(kms)
|
||||
|
||||
kms, err = KMS(context.Background(), "storage://no-store", "kms://cluster-kms?salt="+base64.URLEncoding.EncodeToString([]byte("salt")))
|
||||
masterSecret := MasterSecret{Key: []byte("key"), Salt: []byte("salt")}
|
||||
kms, err = KMS(context.Background(), "storage://no-store", masterSecret.EncodeToURI())
|
||||
assert.NoError(err)
|
||||
assert.NotNil(kms)
|
||||
}
|
||||
@ -146,13 +149,15 @@ func TestGetAWSKMSConfig(t *testing.T) {
|
||||
|
||||
policy := "{keyPolicy: keyPolicy}"
|
||||
escapedPolicy := url.QueryEscape(policy)
|
||||
uri, err := url.Parse(fmt.Sprintf(AWSKMSURI, escapedPolicy))
|
||||
kekID := base64.URLEncoding.EncodeToString([]byte(constellationKekID))
|
||||
uri, err := url.Parse(fmt.Sprintf(AWSKMSURI, escapedPolicy, kekID))
|
||||
require.NoError(err)
|
||||
policyProducer, err := getAWSKMSConfig(uri)
|
||||
policyProducer, rKekID, err := getAWSKMSConfig(uri)
|
||||
require.NoError(err)
|
||||
keyPolicy, err := policyProducer.CreateKeyPolicy("")
|
||||
require.NoError(err)
|
||||
assert.Equal(policy, keyPolicy)
|
||||
assert.Equal(constellationKekID, rKekID)
|
||||
}
|
||||
|
||||
func TestGetAzureBlobConfig(t *testing.T) {
|
||||
@ -178,18 +183,20 @@ func TestGetGCPKMSConfig(t *testing.T) {
|
||||
location := "global"
|
||||
keyRing := "test-ring"
|
||||
protectionLvl := "2"
|
||||
uri, err := url.Parse(fmt.Sprintf(GCPKMSURI, project, location, keyRing, protectionLvl))
|
||||
kekID := base64.URLEncoding.EncodeToString([]byte(constellationKekID))
|
||||
uri, err := url.Parse(fmt.Sprintf(GCPKMSURI, project, location, keyRing, protectionLvl, kekID))
|
||||
require.NoError(err)
|
||||
rProject, rLocation, rKeyRing, rProtectionLvl, err := getGCPKMSConfig(uri)
|
||||
rProject, rLocation, rKeyRing, rProtectionLvl, rKekID, err := getGCPKMSConfig(uri)
|
||||
require.NoError(err)
|
||||
assert.Equal(project, rProject)
|
||||
assert.Equal(location, rLocation)
|
||||
assert.Equal(keyRing, rKeyRing)
|
||||
assert.Equal(int32(2), rProtectionLvl)
|
||||
assert.Equal(constellationKekID, rKekID)
|
||||
|
||||
uri, err = url.Parse(fmt.Sprintf(GCPKMSURI, project, location, keyRing, "invalid"))
|
||||
uri, err = url.Parse(fmt.Sprintf(GCPKMSURI, project, location, keyRing, "invalid", kekID))
|
||||
require.NoError(err)
|
||||
_, _, _, _, err = getGCPKMSConfig(uri)
|
||||
_, _, _, _, _, err = getGCPKMSConfig(uri)
|
||||
assert.Error(err)
|
||||
}
|
||||
|
||||
@ -202,12 +209,13 @@ func TestGetClusterKMSConfig(t *testing.T) {
|
||||
0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf,
|
||||
}
|
||||
|
||||
uri, err := url.Parse(ClusterKMSURI + "?salt=" + base64.URLEncoding.EncodeToString(expectedSalt))
|
||||
masterSecretIn := MasterSecret{Key: []byte("key"), Salt: expectedSalt}
|
||||
uri, err := url.Parse(masterSecretIn.EncodeToURI())
|
||||
require.NoError(err)
|
||||
|
||||
salt, err := getClusterKMSConfig(uri)
|
||||
masterSecretOut, err := getClusterKMSConfig(uri)
|
||||
assert.NoError(err)
|
||||
assert.Equal(expectedSalt, salt)
|
||||
assert.Equal(expectedSalt, masterSecretOut.Salt)
|
||||
}
|
||||
|
||||
func TestGetConfig(t *testing.T) {
|
||||
|
@ -59,7 +59,7 @@ const (
|
||||
// JoinImage image of Constellation join service.
|
||||
JoinImage = "ghcr.io/edgelesssys/constellation/join-service:v2.5.0-pre.0.20230118154955-632090c21b93@sha256:7c53b43f2580ded9f04a9927d4ff585d3edce5d10a1d83006688c818e6395eb1" // renovate:container
|
||||
// KeyServiceImage image of Constellation KMS server.
|
||||
KeyServiceImage = "ghcr.io/edgelesssys/constellation/kmsserver:v2.5.0-pre.0.20230112123617-d0e9f427d1ba@sha256:d4319308eb62e2ee079cc86858acdd1faccc404edec7bfabecf35861284a55f3" // renovate:container
|
||||
KeyServiceImage = "ghcr.io/edgelesssys/constellation/keyservice:v2.5.0-pre.0.20230116125211-d37bd077d8c6@sha256:4c14176f94899054bbf945f6f209521ffcdbcb9042abc5850d778240fe3693a4" // renovate:container
|
||||
// VerificationImage image of Constellation verification service.
|
||||
VerificationImage = "ghcr.io/edgelesssys/constellation/verification-service:v2.5.0-pre.0.20230118154955-632090c21b93@sha256:593f735a236f0cb8f4373a7a2dca41be9ab2ba1b784a2ebcf8fb5271705822a3" // renovate:container
|
||||
// GcpGuestImage image for GCP guest agent.
|
||||
|
@ -8,7 +8,6 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"flag"
|
||||
"path/filepath"
|
||||
@ -53,18 +52,15 @@ func main() {
|
||||
if len(salt) < crypto.RNGLengthDefault {
|
||||
log.With(zap.Error(errors.New("invalid salt length"))).Fatalf("Expected salt to be %d bytes, but got %d", crypto.RNGLengthDefault, len(salt))
|
||||
}
|
||||
keyURI := setup.ClusterKMSURI + "?salt=" + base64.URLEncoding.EncodeToString(salt)
|
||||
masterSecret := setup.MasterSecret{Key: masterKey, Salt: salt}
|
||||
|
||||
// set up Key Management Service
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
|
||||
defer cancel()
|
||||
conKMS, err := setup.KMS(ctx, setup.NoStoreURI, keyURI)
|
||||
conKMS, err := setup.KMS(ctx, setup.NoStoreURI, masterSecret.EncodeToURI())
|
||||
if err != nil {
|
||||
log.With(zap.Error(err)).Fatalf("Failed to setup KMS")
|
||||
}
|
||||
if err := conKMS.CreateKEK(ctx, "Constellation", masterKey); err != nil {
|
||||
log.With(zap.Error(err)).Fatalf("Failed to create KMS KEK from MasterKey")
|
||||
}
|
||||
|
||||
if err := server.New(log.Named("keyservice"), conKMS).Run(*port); err != nil {
|
||||
log.With(zap.Error(err)).Fatalf("Failed to run keyservice server")
|
||||
|
@ -74,7 +74,7 @@ func (s *Server) GetDataKey(ctx context.Context, in *keyserviceproto.GetDataKeyR
|
||||
return nil, status.Error(codes.InvalidArgument, "no data key ID specified")
|
||||
}
|
||||
|
||||
key, err := s.conKMS.GetDEK(ctx, "Constellation", crypto.HKDFInfoPrefix+in.DataKeyId, int(in.Length))
|
||||
key, err := s.conKMS.GetDEK(ctx, crypto.DEKPrefix+in.DataKeyId, int(in.Length))
|
||||
if err != nil {
|
||||
log.With(zap.Error(err)).Errorf("Failed to get data key")
|
||||
return nil, status.Errorf(codes.Internal, "%v", err)
|
||||
|
@ -63,7 +63,7 @@ func (c *stubKMS) CreateKEK(ctx context.Context, keyID string, kek []byte) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stubKMS) GetDEK(ctx context.Context, kekID string, dekID string, dekSize int) ([]byte, error) {
|
||||
func (c *stubKMS) GetDEK(ctx context.Context, dekID string, dekSize int) ([]byte, error) {
|
||||
if c.deriveKeyErr != nil {
|
||||
return nil, c.deriveKeyErr
|
||||
}
|
||||
|
@ -146,7 +146,7 @@ func TestAwsKms(t *testing.T) {
|
||||
require.NotEqual(newKEKId1, newKEKId2)
|
||||
var keyPolicyProducer createKeyPolicyFunc
|
||||
|
||||
client, err := awsInterface.New(context.Background(), &keyPolicyProducer, nil)
|
||||
client, err := awsInterface.New(context.Background(), &keyPolicyProducer, nil, newKEKId1)
|
||||
require.NoError(err)
|
||||
|
||||
privateKEK1 := []byte(strings.Repeat("1234", 8))
|
||||
@ -166,14 +166,14 @@ func TestAwsKms(t *testing.T) {
|
||||
assert.NoError(client.CreateKEK(ctx, newKEKId1, privateKEK2))
|
||||
|
||||
// make sure that GetDEK is idempotent
|
||||
volumeKey1, err := client.GetDEK(ctx, newKEKId1, "volume01", kmsconfig.SymmetricKeyLength)
|
||||
volumeKey1, err := client.GetDEK(ctx, "volume01", kmsconfig.SymmetricKeyLength)
|
||||
require.NoError(err)
|
||||
volumeKey1Copy, err := client.GetDEK(ctx, newKEKId1, "volume01", kmsconfig.SymmetricKeyLength)
|
||||
volumeKey1Copy, err := client.GetDEK(ctx, "volume01", kmsconfig.SymmetricKeyLength)
|
||||
require.NoError(err)
|
||||
assert.Equal(volumeKey1, volumeKey1Copy)
|
||||
|
||||
// test setting a second DEK
|
||||
volumeKey2, err := client.GetDEK(ctx, newKEKId1, "volume02", kmsconfig.SymmetricKeyLength)
|
||||
volumeKey2, err := client.GetDEK(ctx, "volume02", kmsconfig.SymmetricKeyLength)
|
||||
require.NoError(err)
|
||||
assert.NotEqual(volumeKey1, volumeKey2)
|
||||
|
||||
@ -184,7 +184,7 @@ func TestAwsKms(t *testing.T) {
|
||||
assert.NoError(client.CreateKEK(ctx, newKEKId2, privateKEK3))
|
||||
|
||||
// test setting a DEK with AWS KMS generated KEK
|
||||
volumeKey3, err := client.GetDEK(ctx, newKEKId2, "volume03", kmsconfig.SymmetricKeyLength)
|
||||
volumeKey3, err := client.GetDEK(ctx, "volume03", kmsconfig.SymmetricKeyLength)
|
||||
require.NoError(err)
|
||||
assert.NotEqual(volumeKey1, volumeKey3)
|
||||
|
||||
|
@ -66,22 +66,22 @@ func TestAzureKeyVault(t *testing.T) {
|
||||
store := storage.NewMemMapStorage()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
|
||||
defer cancel()
|
||||
client, err := azure.New(ctx, azVaultName, azure.DefaultCloud, store, nil)
|
||||
kekName := "test-kek"
|
||||
client, err := azure.New(ctx, azVaultName, azure.DefaultCloud, store, kekName, nil)
|
||||
require.NoError(err)
|
||||
|
||||
kekName := "test-kek"
|
||||
dekName := "test-dek"
|
||||
|
||||
assert.NoError(client.CreateKEK(ctx, kekName, nil))
|
||||
|
||||
res, err := client.GetDEK(ctx, kekName, dekName, config.SymmetricKeyLength)
|
||||
res, err := client.GetDEK(ctx, dekName, config.SymmetricKeyLength)
|
||||
assert.NoError(err)
|
||||
|
||||
res2, err := client.GetDEK(ctx, kekName, dekName, config.SymmetricKeyLength)
|
||||
res2, err := client.GetDEK(ctx, dekName, config.SymmetricKeyLength)
|
||||
assert.NoError(err)
|
||||
assert.Equal(res, res2)
|
||||
|
||||
res3, err := client.GetDEK(ctx, kekName, addSuffix(dekName), config.SymmetricKeyLength)
|
||||
res3, err := client.GetDEK(ctx, addSuffix(dekName), config.SymmetricKeyLength)
|
||||
assert.NoError(err)
|
||||
assert.Len(res3, config.SymmetricKeyLength)
|
||||
assert.NotEqual(res, res3)
|
||||
@ -102,10 +102,10 @@ func TestAzureHSM(t *testing.T) {
|
||||
store := storage.NewMemMapStorage()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
|
||||
defer cancel()
|
||||
client, err := azure.NewHSM(ctx, azHSMName, store, nil)
|
||||
kekName := "test-kek"
|
||||
client, err := azure.NewHSM(ctx, azHSMName, store, kekName, nil)
|
||||
require.NoError(err)
|
||||
|
||||
kekName := "test-kek"
|
||||
dekName := "test-dek"
|
||||
importedKek := "test-kek-import"
|
||||
kekData := []byte{0x52, 0xFD, 0xFC, 0x07, 0x21, 0x82, 0x65, 0x4F, 0x16, 0x3F, 0x5F, 0x0F, 0x9A, 0x62, 0x1D, 0x72, 0x95, 0x66, 0xC7, 0x4D, 0x10, 0x03, 0x7C, 0x4D, 0x7B, 0xBB, 0x04, 0x07, 0xD1, 0xE2, 0xC6, 0x49}
|
||||
@ -114,15 +114,15 @@ func TestAzureHSM(t *testing.T) {
|
||||
|
||||
assert.NoError(client.CreateKEK(ctx, kekName, nil))
|
||||
|
||||
res, err := client.GetDEK(ctx, kekName, dekName, config.SymmetricKeyLength)
|
||||
res, err := client.GetDEK(ctx, dekName, config.SymmetricKeyLength)
|
||||
require.NoError(err)
|
||||
assert.NotNil(res)
|
||||
|
||||
res2, err := client.GetDEK(ctx, kekName, dekName, config.SymmetricKeyLength)
|
||||
res2, err := client.GetDEK(ctx, dekName, config.SymmetricKeyLength)
|
||||
require.NoError(err)
|
||||
assert.Equal(res, res2)
|
||||
|
||||
res3, err := client.GetDEK(ctx, kekName, addSuffix(dekName), config.SymmetricKeyLength)
|
||||
res3, err := client.GetDEK(ctx, addSuffix(dekName), config.SymmetricKeyLength)
|
||||
require.NoError(err)
|
||||
assert.Len(res3, config.SymmetricKeyLength)
|
||||
assert.NotEqual(res, res3)
|
||||
|
@ -42,20 +42,20 @@ func TestCreateGcpKEK(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
|
||||
defer cancel()
|
||||
|
||||
kmsClient, err := gcp.New(ctx, gcpProjectID, gcpLocation, gcpKeyRing, store, kmspb.ProtectionLevel_SOFTWARE)
|
||||
kmsClient, err := gcp.New(ctx, gcpProjectID, gcpLocation, gcpKeyRing, store, kmspb.ProtectionLevel_SOFTWARE, kekName)
|
||||
require.NoError(err)
|
||||
|
||||
// Key name is random, but there is a chance we try to create a key that already exists, in that case the test fails
|
||||
assert.NoError(kmsClient.CreateKEK(ctx, kekName, nil))
|
||||
|
||||
res, err := kmsClient.GetDEK(ctx, kekName, dekName, config.SymmetricKeyLength)
|
||||
res, err := kmsClient.GetDEK(ctx, dekName, config.SymmetricKeyLength)
|
||||
assert.NoError(err)
|
||||
|
||||
res2, err := kmsClient.GetDEK(ctx, kekName, dekName, config.SymmetricKeyLength)
|
||||
res2, err := kmsClient.GetDEK(ctx, dekName, config.SymmetricKeyLength)
|
||||
assert.NoError(err)
|
||||
assert.Equal(res, res2)
|
||||
|
||||
res3, err := kmsClient.GetDEK(ctx, kekName, addSuffix(dekName), config.SymmetricKeyLength)
|
||||
res3, err := kmsClient.GetDEK(ctx, addSuffix(dekName), config.SymmetricKeyLength)
|
||||
assert.NoError(err)
|
||||
assert.Len(res3, config.SymmetricKeyLength)
|
||||
assert.NotEqual(res, res3)
|
||||
@ -76,15 +76,15 @@ func TestImportGcpKEK(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
|
||||
defer cancel()
|
||||
|
||||
kmsClient, err := gcp.New(ctx, gcpProjectID, gcpLocation, gcpKeyRing, store, kmspb.ProtectionLevel_SOFTWARE)
|
||||
kmsClient, err := gcp.New(ctx, gcpProjectID, gcpLocation, gcpKeyRing, store, kmspb.ProtectionLevel_SOFTWARE, kekName)
|
||||
require.NoError(err)
|
||||
|
||||
assert.NoError(kmsClient.CreateKEK(ctx, kekName, kekData))
|
||||
|
||||
res, err := kmsClient.GetDEK(ctx, kekName, dekName, config.SymmetricKeyLength)
|
||||
res, err := kmsClient.GetDEK(ctx, dekName, config.SymmetricKeyLength)
|
||||
assert.NoError(err)
|
||||
|
||||
res2, err := kmsClient.GetDEK(ctx, kekName, dekName, config.SymmetricKeyLength)
|
||||
res2, err := kmsClient.GetDEK(ctx, dekName, config.SymmetricKeyLength)
|
||||
assert.NoError(err)
|
||||
assert.Equal(res, res2)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user