Refactor init/recovery to use kms URI

So far the masterSecret was sent to the initial bootstrapper
on init/recovery. With this commit this information is encoded
in the kmsURI that is sent during init.
For recover, the communication with the recoveryserver is
changed. Before a streaming gRPC call was used to
exchanges UUID for measurementSecret and state disk key.
Now a standard gRPC is made that includes the same kmsURI &
storageURI that are sent during init.
This commit is contained in:
Otto Bittner 2023-01-16 11:19:03 +01:00
parent 0e71322e2e
commit 9a1f52e94e
35 changed files with 466 additions and 623 deletions

View file

@ -30,7 +30,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
grpcRetry "github.com/edgelesssys/constellation/v2/internal/grpc/retry"
keyservice "github.com/edgelesssys/constellation/v2/internal/kms/setup"
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
"github.com/edgelesssys/constellation/v2/internal/license"
"github.com/edgelesssys/constellation/v2/internal/retry"
"github.com/edgelesssys/constellation/v2/internal/versions"
@ -146,8 +146,8 @@ func (i *initCmd) initialize(cmd *cobra.Command, newDialer func(validator *cloud
req := &initproto.InitRequest{
MasterSecret: masterSecret.Key,
Salt: masterSecret.Salt,
KmsUri: keyservice.ClusterKMSURI,
StorageUri: keyservice.NoStoreURI,
KmsUri: masterSecret.EncodeToURI(),
StorageUri: kmssetup.NoStoreURI,
KeyEncryptionKeyId: "",
UseExistingKek: false,
CloudServiceAccountUri: serviceAccURI,
@ -296,26 +296,20 @@ type initFlags struct {
conformance bool
}
// masterSecret holds the master key and salt for deriving keys.
type masterSecret struct {
Key []byte `json:"key"`
Salt []byte `json:"salt"`
}
// readOrGenerateMasterSecret reads a base64 encoded master secret from file or generates a new 32 byte secret.
func (i *initCmd) readOrGenerateMasterSecret(outWriter io.Writer, fileHandler file.Handler, filename string) (masterSecret, error) {
func (i *initCmd) readOrGenerateMasterSecret(outWriter io.Writer, fileHandler file.Handler, filename string) (kmssetup.MasterSecret, error) {
if filename != "" {
i.log.Debugf("Reading master secret from file %q", filename)
var secret masterSecret
var secret kmssetup.MasterSecret
if err := fileHandler.ReadJSON(filename, &secret); err != nil {
return masterSecret{}, err
return kmssetup.MasterSecret{}, err
}
if len(secret.Key) < crypto.MasterSecretLengthMin {
return masterSecret{}, fmt.Errorf("provided master secret is smaller than the required minimum of %d Bytes", crypto.MasterSecretLengthMin)
return kmssetup.MasterSecret{}, fmt.Errorf("provided master secret is smaller than the required minimum of %d Bytes", crypto.MasterSecretLengthMin)
}
if len(secret.Salt) < crypto.RNGLengthDefault {
return masterSecret{}, fmt.Errorf("provided salt is smaller than the required minimum of %d Bytes", crypto.RNGLengthDefault)
return kmssetup.MasterSecret{}, fmt.Errorf("provided salt is smaller than the required minimum of %d Bytes", crypto.RNGLengthDefault)
}
return secret, nil
}
@ -324,19 +318,19 @@ func (i *initCmd) readOrGenerateMasterSecret(outWriter io.Writer, fileHandler fi
i.log.Debugf("Generating new master secret")
key, err := crypto.GenerateRandomBytes(crypto.MasterSecretLengthDefault)
if err != nil {
return masterSecret{}, err
return kmssetup.MasterSecret{}, err
}
salt, err := crypto.GenerateRandomBytes(crypto.RNGLengthDefault)
if err != nil {
return masterSecret{}, err
return kmssetup.MasterSecret{}, err
}
secret := masterSecret{
secret := kmssetup.MasterSecret{
Key: key,
Salt: salt,
}
i.log.Debugf("Generated master secret key and salt values")
if err := fileHandler.WriteJSON(constants.MasterSecretFilename, secret, file.OptNone); err != nil {
return masterSecret{}, err
return kmssetup.MasterSecret{}, err
}
fmt.Fprintf(outWriter, "Your Constellation master secret was successfully written to ./%s\n", constants.MasterSecretFilename)
return secret, nil

View file

@ -18,6 +18,8 @@ import (
"testing"
"time"
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
"github.com/edgelesssys/constellation/v2/bootstrapper/initproto"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/clusterid"
@ -183,7 +185,7 @@ func TestInitialize(t *testing.T) {
require.NoError(err)
// assert.Contains(out.String(), base64.StdEncoding.EncodeToString([]byte("ownerID")))
assert.Contains(out.String(), hex.EncodeToString([]byte("clusterID")))
var secret masterSecret
var secret kmssetup.MasterSecret
assert.NoError(fileHandler.ReadJSON(constants.MasterSecretFilename, &secret))
assert.NotEmpty(secret.Key)
assert.NotEmpty(secret.Salt)
@ -251,7 +253,7 @@ func TestReadOrGenerateMasterSecret(t *testing.T) {
createFileFunc: func(handler file.Handler) error {
return handler.WriteJSON(
"someSecret",
masterSecret{Key: []byte("constellation-master-secret"), Salt: []byte("constellation-32Byte-length-salt")},
kmssetup.MasterSecret{Key: []byte("constellation-master-secret"), Salt: []byte("constellation-32Byte-length-salt")},
file.OptNone,
)
},
@ -282,7 +284,7 @@ func TestReadOrGenerateMasterSecret(t *testing.T) {
createFileFunc: func(handler file.Handler) error {
return handler.WriteJSON(
"shortSecret",
masterSecret{Key: []byte("constellation-master-secret"), Salt: []byte("short")},
kmssetup.MasterSecret{Key: []byte("constellation-master-secret"), Salt: []byte("short")},
file.OptNone,
)
},
@ -294,7 +296,7 @@ func TestReadOrGenerateMasterSecret(t *testing.T) {
createFileFunc: func(handler file.Handler) error {
return handler.WriteJSON(
"shortSecret",
masterSecret{Key: []byte("short"), Salt: []byte("constellation-32Byte-length-salt")},
kmssetup.MasterSecret{Key: []byte("short"), Salt: []byte("constellation-32Byte-length-salt")},
file.OptNone,
)
},
@ -340,7 +342,7 @@ func TestReadOrGenerateMasterSecret(t *testing.T) {
tc.filename = strings.Trim(filename[1], "\n")
}
var masterSecret masterSecret
var masterSecret kmssetup.MasterSecret
require.NoError(fileHandler.ReadJSON(tc.filename, &masterSecret))
assert.Equal(masterSecret.Key, secret.Key)
assert.Equal(masterSecret.Salt, secret.Salt)

View file

@ -8,7 +8,6 @@ package cmd
import (
"context"
"errors"
"fmt"
"io"
"net"
@ -17,7 +16,6 @@ import (
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto"
"github.com/edgelesssys/constellation/v2/internal/attestation"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
@ -25,6 +23,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
grpcRetry "github.com/edgelesssys/constellation/v2/internal/grpc/retry"
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
"github.com/edgelesssys/constellation/v2/internal/retry"
"github.com/spf13/afero"
"github.com/spf13/cobra"
@ -73,7 +72,7 @@ func (r *recoverCmd) recover(
}
r.log.Debugf("Using flags: %+v", flags)
var masterSecret masterSecret
var masterSecret kmssetup.MasterSecret
r.log.Debugf("Loading master secret file from %s", flags.secretPath)
if err := fileHandler.ReadJSON(flags.secretPath, &masterSecret); err != nil {
return err
@ -97,12 +96,7 @@ func (r *recoverCmd) recover(
r.log.Debugf("Created a new validator")
doer.setDialer(newDialer(validator), flags.endpoint)
r.log.Debugf("Set dialer for endpoint %s", flags.endpoint)
measurementSecret, err := attestation.DeriveMeasurementSecret(masterSecret.Key, masterSecret.Salt)
r.log.Debugf("Derived measurementSecret")
if err != nil {
return err
}
doer.setSecrets(getStateDiskKeyFunc(masterSecret.Key, masterSecret.Salt), measurementSecret)
doer.setURIs(masterSecret.EncodeToURI(), kmssetup.NoStoreURI)
r.log.Debugf("Set secrets")
if err := r.recoverCall(cmd.Context(), cmd.OutOrStdout(), interval, doer); err != nil {
if grpcRetry.ServiceIsUnavailable(err) {
@ -157,15 +151,15 @@ func (r *recoverCmd) recoverCall(ctx context.Context, out io.Writer, interval ti
type recoverDoerInterface interface {
Do(ctx context.Context) error
setDialer(dialer grpcDialer, endpoint string)
setSecrets(getDiskKey func(uuid string) ([]byte, error), measurementSecret []byte)
setURIs(kmsURI, storageURI string)
}
type recoverDoer struct {
dialer grpcDialer
endpoint string
measurementSecret []byte
getDiskKey func(uuid string) (key []byte, err error)
log debugLog
dialer grpcDialer
endpoint string
kmsURI string // encodes masterSecret
storageURI string
log debugLog
}
// Do performs the recover streaming rpc.
@ -177,53 +171,19 @@ func (d *recoverDoer) Do(ctx context.Context) (retErr error) {
d.log.Debugf("Dialed recovery server")
defer conn.Close()
// set up streaming client
protoClient := recoverproto.NewAPIClient(conn)
d.log.Debugf("Created protoClient")
recoverclient, err := protoClient.Recover(ctx)
d.log.Debugf("Created recoverclient")
req := &recoverproto.RecoverMessage{
KmsUri: d.kmsURI,
StorageUri: d.storageURI,
}
_, err = protoClient.Recover(ctx, req)
if err != nil {
return fmt.Errorf("creating client: %w", err)
return fmt.Errorf("calling recover: %w", err)
}
defer func() {
_ = recoverclient.CloseSend()
}()
// send measurement secret as first message
if err := recoverclient.Send(&recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_MeasurementSecret{
MeasurementSecret: d.measurementSecret,
},
}); err != nil {
return fmt.Errorf("sending measurement secret: %w", err)
}
d.log.Debugf("Sent measurement secret")
// receive disk uuid
res, err := recoverclient.Recv()
if err != nil {
return fmt.Errorf("receiving disk uuid: %w", err)
}
d.log.Debugf("Received disk uuid")
stateDiskKey, err := d.getDiskKey(res.DiskUuid)
if err != nil {
return fmt.Errorf("getting state disk key: %w", err)
}
d.log.Debugf("Got state disk key")
// send disk key
if err := recoverclient.Send(&recoverproto.RecoverMessage{
Request: &recoverproto.RecoverMessage_StateDiskKey{
StateDiskKey: stateDiskKey,
},
}); err != nil {
return fmt.Errorf("sending state disk key: %w", err)
}
d.log.Debugf("Sent state disk key")
if _, err := recoverclient.Recv(); err != nil && !errors.Is(err, io.EOF) {
return fmt.Errorf("receiving confirmation: %w", err)
}
d.log.Debugf("Received confirmation")
return nil
}
@ -233,9 +193,9 @@ func (d *recoverDoer) setDialer(dialer grpcDialer, endpoint string) {
d.endpoint = endpoint
}
func (d *recoverDoer) setSecrets(getDiskKey func(string) ([]byte, error), measurementSecret []byte) {
d.getDiskKey = getDiskKey
d.measurementSecret = measurementSecret
func (d *recoverDoer) setURIs(kmsURI, storageURI string) {
d.kmsURI = kmsURI
d.storageURI = storageURI
}
type recoverFlags struct {
@ -281,6 +241,6 @@ func (r *recoverCmd) parseRecoverFlags(cmd *cobra.Command, fileHandler file.Hand
func getStateDiskKeyFunc(masterKey, salt []byte) func(uuid string) ([]byte, error) {
return func(uuid string) ([]byte, error) {
return crypto.DeriveKey(masterKey, salt, []byte(crypto.HKDFInfoPrefix+uuid), crypto.StateDiskKeyLength)
return crypto.DeriveKey(masterKey, salt, []byte(crypto.DEKPrefix+uuid), crypto.StateDiskKeyLength)
}
}

View file

@ -26,6 +26,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
"github.com/edgelesssys/constellation/v2/internal/grpc/testdialer"
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
@ -156,7 +157,7 @@ func TestRecover(t *testing.T) {
require.NoError(fileHandler.WriteJSON(
"constellation-mastersecret.json",
masterSecret{Key: tc.masterSecret.Secret, Salt: tc.masterSecret.Salt},
kmssetup.MasterSecret{Key: tc.masterSecret.Secret, Salt: tc.masterSecret.Salt},
file.OptNone,
))
@ -244,101 +245,16 @@ func TestParseRecoverFlags(t *testing.T) {
}
func TestDoRecovery(t *testing.T) {
someErr := errors.New("error")
testCases := map[string]struct {
recoveryServer *stubRecoveryServer
wantErr bool
}{
"success": {
recoveryServer: &stubRecoveryServer{
actions: [][]func(stream recoverproto.API_RecoverServer) error{{
func(stream recoverproto.API_RecoverServer) error {
_, err := stream.Recv()
return err
},
func(stream recoverproto.API_RecoverServer) error {
return stream.Send(&recoverproto.RecoverResponse{
DiskUuid: "00000000-0000-0000-0000-000000000000",
})
},
func(stream recoverproto.API_RecoverServer) error {
_, err := stream.Recv()
return err
},
}},
},
recoveryServer: &stubRecoveryServer{},
},
"error on first recv": {
recoveryServer: &stubRecoveryServer{
actions: [][]func(stream recoverproto.API_RecoverServer) error{
{
func(stream recoverproto.API_RecoverServer) error {
return someErr
},
},
},
},
wantErr: true,
},
"error on send": {
recoveryServer: &stubRecoveryServer{
actions: [][]func(stream recoverproto.API_RecoverServer) error{
{
func(stream recoverproto.API_RecoverServer) error {
_, err := stream.Recv()
return err
},
func(stream recoverproto.API_RecoverServer) error {
return someErr
},
},
},
},
wantErr: true,
},
"error on second recv": {
recoveryServer: &stubRecoveryServer{
actions: [][]func(stream recoverproto.API_RecoverServer) error{
{
func(stream recoverproto.API_RecoverServer) error {
_, err := stream.Recv()
return err
},
func(stream recoverproto.API_RecoverServer) error {
return stream.Send(&recoverproto.RecoverResponse{
DiskUuid: "00000000-0000-0000-0000-000000000000",
})
},
func(stream recoverproto.API_RecoverServer) error {
return someErr
},
},
},
},
wantErr: true,
},
"final message is an error": {
recoveryServer: &stubRecoveryServer{
actions: [][]func(stream recoverproto.API_RecoverServer) error{{
func(stream recoverproto.API_RecoverServer) error {
_, err := stream.Recv()
return err
},
func(stream recoverproto.API_RecoverServer) error {
return stream.Send(&recoverproto.RecoverResponse{
DiskUuid: "00000000-0000-0000-0000-000000000000",
})
},
func(stream recoverproto.API_RecoverServer) error {
_, err := stream.Recv()
return err
},
func(stream recoverproto.API_RecoverServer) error {
return someErr
},
}},
},
wantErr: true,
"server responds with error": {
recoveryServer: &stubRecoveryServer{recoverError: errors.New("someErr")},
wantErr: true,
},
}
@ -357,13 +273,9 @@ func TestDoRecovery(t *testing.T) {
r := &recoverCmd{log: logger.NewTest(t)}
recoverDoer := &recoverDoer{
dialer: dialer.New(nil, nil, netDialer),
endpoint: addr,
measurementSecret: []byte("measurement-secret"),
getDiskKey: func(string) ([]byte, error) {
return []byte("disk-key"), nil
},
log: r.log,
dialer: dialer.New(nil, nil, netDialer),
endpoint: addr,
log: r.log,
}
err := recoverDoer.Do(context.Background())
@ -402,23 +314,15 @@ func TestDeriveStateDiskKey(t *testing.T) {
}
type stubRecoveryServer struct {
actions [][]func(recoverproto.API_RecoverServer) error
calls int
recoverError error
recoverproto.UnimplementedAPIServer
}
func (s *stubRecoveryServer) Recover(stream recoverproto.API_RecoverServer) error {
if s.calls >= len(s.actions) {
return status.Error(codes.Unavailable, "server is unavailable")
func (s *stubRecoveryServer) Recover(context.Context, *recoverproto.RecoverMessage) (*recoverproto.RecoverResponse, error) {
if s.recoverError != nil {
return nil, s.recoverError
}
s.calls++
for _, action := range s.actions[s.calls-1] {
if err := action(stream); err != nil {
return err
}
}
return nil
return &recoverproto.RecoverResponse{}, nil
}
type stubDoer struct {
@ -437,4 +341,4 @@ func (d *stubDoer) Do(context.Context) error {
func (d *stubDoer) setDialer(grpcDialer, string) {}
func (d *stubDoer) setSecrets(func(string) ([]byte, error), []byte) {}
func (d *stubDoer) setURIs(kmsURI, storageURI string) {}