mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-09-22 14:04:53 -04:00
Refactor init/recovery to use kms URI
So far the masterSecret was sent to the initial bootstrapper on init/recovery. With this commit this information is encoded in the kmsURI that is sent during init. For recover, the communication with the recoveryserver is changed. Before a streaming gRPC call was used to exchanges UUID for measurementSecret and state disk key. Now a standard gRPC is made that includes the same kmsURI & storageURI that are sent during init.
This commit is contained in:
parent
0e71322e2e
commit
9a1f52e94e
35 changed files with 466 additions and 623 deletions
|
@ -30,7 +30,7 @@ import (
|
|||
"github.com/edgelesssys/constellation/v2/internal/file"
|
||||
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
|
||||
grpcRetry "github.com/edgelesssys/constellation/v2/internal/grpc/retry"
|
||||
keyservice "github.com/edgelesssys/constellation/v2/internal/kms/setup"
|
||||
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
|
||||
"github.com/edgelesssys/constellation/v2/internal/license"
|
||||
"github.com/edgelesssys/constellation/v2/internal/retry"
|
||||
"github.com/edgelesssys/constellation/v2/internal/versions"
|
||||
|
@ -146,8 +146,8 @@ func (i *initCmd) initialize(cmd *cobra.Command, newDialer func(validator *cloud
|
|||
req := &initproto.InitRequest{
|
||||
MasterSecret: masterSecret.Key,
|
||||
Salt: masterSecret.Salt,
|
||||
KmsUri: keyservice.ClusterKMSURI,
|
||||
StorageUri: keyservice.NoStoreURI,
|
||||
KmsUri: masterSecret.EncodeToURI(),
|
||||
StorageUri: kmssetup.NoStoreURI,
|
||||
KeyEncryptionKeyId: "",
|
||||
UseExistingKek: false,
|
||||
CloudServiceAccountUri: serviceAccURI,
|
||||
|
@ -296,26 +296,20 @@ type initFlags struct {
|
|||
conformance bool
|
||||
}
|
||||
|
||||
// masterSecret holds the master key and salt for deriving keys.
|
||||
type masterSecret struct {
|
||||
Key []byte `json:"key"`
|
||||
Salt []byte `json:"salt"`
|
||||
}
|
||||
|
||||
// readOrGenerateMasterSecret reads a base64 encoded master secret from file or generates a new 32 byte secret.
|
||||
func (i *initCmd) readOrGenerateMasterSecret(outWriter io.Writer, fileHandler file.Handler, filename string) (masterSecret, error) {
|
||||
func (i *initCmd) readOrGenerateMasterSecret(outWriter io.Writer, fileHandler file.Handler, filename string) (kmssetup.MasterSecret, error) {
|
||||
if filename != "" {
|
||||
i.log.Debugf("Reading master secret from file %q", filename)
|
||||
var secret masterSecret
|
||||
var secret kmssetup.MasterSecret
|
||||
if err := fileHandler.ReadJSON(filename, &secret); err != nil {
|
||||
return masterSecret{}, err
|
||||
return kmssetup.MasterSecret{}, err
|
||||
}
|
||||
|
||||
if len(secret.Key) < crypto.MasterSecretLengthMin {
|
||||
return masterSecret{}, fmt.Errorf("provided master secret is smaller than the required minimum of %d Bytes", crypto.MasterSecretLengthMin)
|
||||
return kmssetup.MasterSecret{}, fmt.Errorf("provided master secret is smaller than the required minimum of %d Bytes", crypto.MasterSecretLengthMin)
|
||||
}
|
||||
if len(secret.Salt) < crypto.RNGLengthDefault {
|
||||
return masterSecret{}, fmt.Errorf("provided salt is smaller than the required minimum of %d Bytes", crypto.RNGLengthDefault)
|
||||
return kmssetup.MasterSecret{}, fmt.Errorf("provided salt is smaller than the required minimum of %d Bytes", crypto.RNGLengthDefault)
|
||||
}
|
||||
return secret, nil
|
||||
}
|
||||
|
@ -324,19 +318,19 @@ func (i *initCmd) readOrGenerateMasterSecret(outWriter io.Writer, fileHandler fi
|
|||
i.log.Debugf("Generating new master secret")
|
||||
key, err := crypto.GenerateRandomBytes(crypto.MasterSecretLengthDefault)
|
||||
if err != nil {
|
||||
return masterSecret{}, err
|
||||
return kmssetup.MasterSecret{}, err
|
||||
}
|
||||
salt, err := crypto.GenerateRandomBytes(crypto.RNGLengthDefault)
|
||||
if err != nil {
|
||||
return masterSecret{}, err
|
||||
return kmssetup.MasterSecret{}, err
|
||||
}
|
||||
secret := masterSecret{
|
||||
secret := kmssetup.MasterSecret{
|
||||
Key: key,
|
||||
Salt: salt,
|
||||
}
|
||||
i.log.Debugf("Generated master secret key and salt values")
|
||||
if err := fileHandler.WriteJSON(constants.MasterSecretFilename, secret, file.OptNone); err != nil {
|
||||
return masterSecret{}, err
|
||||
return kmssetup.MasterSecret{}, err
|
||||
}
|
||||
fmt.Fprintf(outWriter, "Your Constellation master secret was successfully written to ./%s\n", constants.MasterSecretFilename)
|
||||
return secret, nil
|
||||
|
|
|
@ -18,6 +18,8 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/bootstrapper/initproto"
|
||||
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
|
||||
"github.com/edgelesssys/constellation/v2/cli/internal/clusterid"
|
||||
|
@ -183,7 +185,7 @@ func TestInitialize(t *testing.T) {
|
|||
require.NoError(err)
|
||||
// assert.Contains(out.String(), base64.StdEncoding.EncodeToString([]byte("ownerID")))
|
||||
assert.Contains(out.String(), hex.EncodeToString([]byte("clusterID")))
|
||||
var secret masterSecret
|
||||
var secret kmssetup.MasterSecret
|
||||
assert.NoError(fileHandler.ReadJSON(constants.MasterSecretFilename, &secret))
|
||||
assert.NotEmpty(secret.Key)
|
||||
assert.NotEmpty(secret.Salt)
|
||||
|
@ -251,7 +253,7 @@ func TestReadOrGenerateMasterSecret(t *testing.T) {
|
|||
createFileFunc: func(handler file.Handler) error {
|
||||
return handler.WriteJSON(
|
||||
"someSecret",
|
||||
masterSecret{Key: []byte("constellation-master-secret"), Salt: []byte("constellation-32Byte-length-salt")},
|
||||
kmssetup.MasterSecret{Key: []byte("constellation-master-secret"), Salt: []byte("constellation-32Byte-length-salt")},
|
||||
file.OptNone,
|
||||
)
|
||||
},
|
||||
|
@ -282,7 +284,7 @@ func TestReadOrGenerateMasterSecret(t *testing.T) {
|
|||
createFileFunc: func(handler file.Handler) error {
|
||||
return handler.WriteJSON(
|
||||
"shortSecret",
|
||||
masterSecret{Key: []byte("constellation-master-secret"), Salt: []byte("short")},
|
||||
kmssetup.MasterSecret{Key: []byte("constellation-master-secret"), Salt: []byte("short")},
|
||||
file.OptNone,
|
||||
)
|
||||
},
|
||||
|
@ -294,7 +296,7 @@ func TestReadOrGenerateMasterSecret(t *testing.T) {
|
|||
createFileFunc: func(handler file.Handler) error {
|
||||
return handler.WriteJSON(
|
||||
"shortSecret",
|
||||
masterSecret{Key: []byte("short"), Salt: []byte("constellation-32Byte-length-salt")},
|
||||
kmssetup.MasterSecret{Key: []byte("short"), Salt: []byte("constellation-32Byte-length-salt")},
|
||||
file.OptNone,
|
||||
)
|
||||
},
|
||||
|
@ -340,7 +342,7 @@ func TestReadOrGenerateMasterSecret(t *testing.T) {
|
|||
tc.filename = strings.Trim(filename[1], "\n")
|
||||
}
|
||||
|
||||
var masterSecret masterSecret
|
||||
var masterSecret kmssetup.MasterSecret
|
||||
require.NoError(fileHandler.ReadJSON(tc.filename, &masterSecret))
|
||||
assert.Equal(masterSecret.Key, secret.Key)
|
||||
assert.Equal(masterSecret.Salt, secret.Salt)
|
||||
|
|
|
@ -8,7 +8,6 @@ package cmd
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
@ -17,7 +16,6 @@ import (
|
|||
|
||||
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
|
||||
"github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation"
|
||||
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/v2/internal/config"
|
||||
"github.com/edgelesssys/constellation/v2/internal/constants"
|
||||
|
@ -25,6 +23,7 @@ import (
|
|||
"github.com/edgelesssys/constellation/v2/internal/file"
|
||||
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
|
||||
grpcRetry "github.com/edgelesssys/constellation/v2/internal/grpc/retry"
|
||||
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
|
||||
"github.com/edgelesssys/constellation/v2/internal/retry"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
|
@ -73,7 +72,7 @@ func (r *recoverCmd) recover(
|
|||
}
|
||||
r.log.Debugf("Using flags: %+v", flags)
|
||||
|
||||
var masterSecret masterSecret
|
||||
var masterSecret kmssetup.MasterSecret
|
||||
r.log.Debugf("Loading master secret file from %s", flags.secretPath)
|
||||
if err := fileHandler.ReadJSON(flags.secretPath, &masterSecret); err != nil {
|
||||
return err
|
||||
|
@ -97,12 +96,7 @@ func (r *recoverCmd) recover(
|
|||
r.log.Debugf("Created a new validator")
|
||||
doer.setDialer(newDialer(validator), flags.endpoint)
|
||||
r.log.Debugf("Set dialer for endpoint %s", flags.endpoint)
|
||||
measurementSecret, err := attestation.DeriveMeasurementSecret(masterSecret.Key, masterSecret.Salt)
|
||||
r.log.Debugf("Derived measurementSecret")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
doer.setSecrets(getStateDiskKeyFunc(masterSecret.Key, masterSecret.Salt), measurementSecret)
|
||||
doer.setURIs(masterSecret.EncodeToURI(), kmssetup.NoStoreURI)
|
||||
r.log.Debugf("Set secrets")
|
||||
if err := r.recoverCall(cmd.Context(), cmd.OutOrStdout(), interval, doer); err != nil {
|
||||
if grpcRetry.ServiceIsUnavailable(err) {
|
||||
|
@ -157,15 +151,15 @@ func (r *recoverCmd) recoverCall(ctx context.Context, out io.Writer, interval ti
|
|||
type recoverDoerInterface interface {
|
||||
Do(ctx context.Context) error
|
||||
setDialer(dialer grpcDialer, endpoint string)
|
||||
setSecrets(getDiskKey func(uuid string) ([]byte, error), measurementSecret []byte)
|
||||
setURIs(kmsURI, storageURI string)
|
||||
}
|
||||
|
||||
type recoverDoer struct {
|
||||
dialer grpcDialer
|
||||
endpoint string
|
||||
measurementSecret []byte
|
||||
getDiskKey func(uuid string) (key []byte, err error)
|
||||
log debugLog
|
||||
dialer grpcDialer
|
||||
endpoint string
|
||||
kmsURI string // encodes masterSecret
|
||||
storageURI string
|
||||
log debugLog
|
||||
}
|
||||
|
||||
// Do performs the recover streaming rpc.
|
||||
|
@ -177,53 +171,19 @@ func (d *recoverDoer) Do(ctx context.Context) (retErr error) {
|
|||
d.log.Debugf("Dialed recovery server")
|
||||
defer conn.Close()
|
||||
|
||||
// set up streaming client
|
||||
protoClient := recoverproto.NewAPIClient(conn)
|
||||
d.log.Debugf("Created protoClient")
|
||||
recoverclient, err := protoClient.Recover(ctx)
|
||||
d.log.Debugf("Created recoverclient")
|
||||
|
||||
req := &recoverproto.RecoverMessage{
|
||||
KmsUri: d.kmsURI,
|
||||
StorageUri: d.storageURI,
|
||||
}
|
||||
|
||||
_, err = protoClient.Recover(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating client: %w", err)
|
||||
return fmt.Errorf("calling recover: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = recoverclient.CloseSend()
|
||||
}()
|
||||
|
||||
// send measurement secret as first message
|
||||
if err := recoverclient.Send(&recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_MeasurementSecret{
|
||||
MeasurementSecret: d.measurementSecret,
|
||||
},
|
||||
}); err != nil {
|
||||
return fmt.Errorf("sending measurement secret: %w", err)
|
||||
}
|
||||
d.log.Debugf("Sent measurement secret")
|
||||
|
||||
// receive disk uuid
|
||||
res, err := recoverclient.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("receiving disk uuid: %w", err)
|
||||
}
|
||||
d.log.Debugf("Received disk uuid")
|
||||
stateDiskKey, err := d.getDiskKey(res.DiskUuid)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting state disk key: %w", err)
|
||||
}
|
||||
d.log.Debugf("Got state disk key")
|
||||
|
||||
// send disk key
|
||||
if err := recoverclient.Send(&recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_StateDiskKey{
|
||||
StateDiskKey: stateDiskKey,
|
||||
},
|
||||
}); err != nil {
|
||||
return fmt.Errorf("sending state disk key: %w", err)
|
||||
}
|
||||
d.log.Debugf("Sent state disk key")
|
||||
|
||||
if _, err := recoverclient.Recv(); err != nil && !errors.Is(err, io.EOF) {
|
||||
return fmt.Errorf("receiving confirmation: %w", err)
|
||||
}
|
||||
d.log.Debugf("Received confirmation")
|
||||
return nil
|
||||
}
|
||||
|
@ -233,9 +193,9 @@ func (d *recoverDoer) setDialer(dialer grpcDialer, endpoint string) {
|
|||
d.endpoint = endpoint
|
||||
}
|
||||
|
||||
func (d *recoverDoer) setSecrets(getDiskKey func(string) ([]byte, error), measurementSecret []byte) {
|
||||
d.getDiskKey = getDiskKey
|
||||
d.measurementSecret = measurementSecret
|
||||
func (d *recoverDoer) setURIs(kmsURI, storageURI string) {
|
||||
d.kmsURI = kmsURI
|
||||
d.storageURI = storageURI
|
||||
}
|
||||
|
||||
type recoverFlags struct {
|
||||
|
@ -281,6 +241,6 @@ func (r *recoverCmd) parseRecoverFlags(cmd *cobra.Command, fileHandler file.Hand
|
|||
|
||||
func getStateDiskKeyFunc(masterKey, salt []byte) func(uuid string) ([]byte, error) {
|
||||
return func(uuid string) ([]byte, error) {
|
||||
return crypto.DeriveKey(masterKey, salt, []byte(crypto.HKDFInfoPrefix+uuid), crypto.StateDiskKeyLength)
|
||||
return crypto.DeriveKey(masterKey, salt, []byte(crypto.DEKPrefix+uuid), crypto.StateDiskKeyLength)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,6 +26,7 @@ import (
|
|||
"github.com/edgelesssys/constellation/v2/internal/grpc/atlscredentials"
|
||||
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
|
||||
"github.com/edgelesssys/constellation/v2/internal/grpc/testdialer"
|
||||
kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
|
||||
"github.com/edgelesssys/constellation/v2/internal/logger"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -156,7 +157,7 @@ func TestRecover(t *testing.T) {
|
|||
|
||||
require.NoError(fileHandler.WriteJSON(
|
||||
"constellation-mastersecret.json",
|
||||
masterSecret{Key: tc.masterSecret.Secret, Salt: tc.masterSecret.Salt},
|
||||
kmssetup.MasterSecret{Key: tc.masterSecret.Secret, Salt: tc.masterSecret.Salt},
|
||||
file.OptNone,
|
||||
))
|
||||
|
||||
|
@ -244,101 +245,16 @@ func TestParseRecoverFlags(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDoRecovery(t *testing.T) {
|
||||
someErr := errors.New("error")
|
||||
testCases := map[string]struct {
|
||||
recoveryServer *stubRecoveryServer
|
||||
wantErr bool
|
||||
}{
|
||||
"success": {
|
||||
recoveryServer: &stubRecoveryServer{
|
||||
actions: [][]func(stream recoverproto.API_RecoverServer) error{{
|
||||
func(stream recoverproto.API_RecoverServer) error {
|
||||
_, err := stream.Recv()
|
||||
return err
|
||||
},
|
||||
func(stream recoverproto.API_RecoverServer) error {
|
||||
return stream.Send(&recoverproto.RecoverResponse{
|
||||
DiskUuid: "00000000-0000-0000-0000-000000000000",
|
||||
})
|
||||
},
|
||||
func(stream recoverproto.API_RecoverServer) error {
|
||||
_, err := stream.Recv()
|
||||
return err
|
||||
},
|
||||
}},
|
||||
},
|
||||
recoveryServer: &stubRecoveryServer{},
|
||||
},
|
||||
"error on first recv": {
|
||||
recoveryServer: &stubRecoveryServer{
|
||||
actions: [][]func(stream recoverproto.API_RecoverServer) error{
|
||||
{
|
||||
func(stream recoverproto.API_RecoverServer) error {
|
||||
return someErr
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"error on send": {
|
||||
recoveryServer: &stubRecoveryServer{
|
||||
actions: [][]func(stream recoverproto.API_RecoverServer) error{
|
||||
{
|
||||
func(stream recoverproto.API_RecoverServer) error {
|
||||
_, err := stream.Recv()
|
||||
return err
|
||||
},
|
||||
func(stream recoverproto.API_RecoverServer) error {
|
||||
return someErr
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"error on second recv": {
|
||||
recoveryServer: &stubRecoveryServer{
|
||||
actions: [][]func(stream recoverproto.API_RecoverServer) error{
|
||||
{
|
||||
func(stream recoverproto.API_RecoverServer) error {
|
||||
_, err := stream.Recv()
|
||||
return err
|
||||
},
|
||||
func(stream recoverproto.API_RecoverServer) error {
|
||||
return stream.Send(&recoverproto.RecoverResponse{
|
||||
DiskUuid: "00000000-0000-0000-0000-000000000000",
|
||||
})
|
||||
},
|
||||
func(stream recoverproto.API_RecoverServer) error {
|
||||
return someErr
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"final message is an error": {
|
||||
recoveryServer: &stubRecoveryServer{
|
||||
actions: [][]func(stream recoverproto.API_RecoverServer) error{{
|
||||
func(stream recoverproto.API_RecoverServer) error {
|
||||
_, err := stream.Recv()
|
||||
return err
|
||||
},
|
||||
func(stream recoverproto.API_RecoverServer) error {
|
||||
return stream.Send(&recoverproto.RecoverResponse{
|
||||
DiskUuid: "00000000-0000-0000-0000-000000000000",
|
||||
})
|
||||
},
|
||||
func(stream recoverproto.API_RecoverServer) error {
|
||||
_, err := stream.Recv()
|
||||
return err
|
||||
},
|
||||
func(stream recoverproto.API_RecoverServer) error {
|
||||
return someErr
|
||||
},
|
||||
}},
|
||||
},
|
||||
wantErr: true,
|
||||
"server responds with error": {
|
||||
recoveryServer: &stubRecoveryServer{recoverError: errors.New("someErr")},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -357,13 +273,9 @@ func TestDoRecovery(t *testing.T) {
|
|||
|
||||
r := &recoverCmd{log: logger.NewTest(t)}
|
||||
recoverDoer := &recoverDoer{
|
||||
dialer: dialer.New(nil, nil, netDialer),
|
||||
endpoint: addr,
|
||||
measurementSecret: []byte("measurement-secret"),
|
||||
getDiskKey: func(string) ([]byte, error) {
|
||||
return []byte("disk-key"), nil
|
||||
},
|
||||
log: r.log,
|
||||
dialer: dialer.New(nil, nil, netDialer),
|
||||
endpoint: addr,
|
||||
log: r.log,
|
||||
}
|
||||
|
||||
err := recoverDoer.Do(context.Background())
|
||||
|
@ -402,23 +314,15 @@ func TestDeriveStateDiskKey(t *testing.T) {
|
|||
}
|
||||
|
||||
type stubRecoveryServer struct {
|
||||
actions [][]func(recoverproto.API_RecoverServer) error
|
||||
calls int
|
||||
recoverError error
|
||||
recoverproto.UnimplementedAPIServer
|
||||
}
|
||||
|
||||
func (s *stubRecoveryServer) Recover(stream recoverproto.API_RecoverServer) error {
|
||||
if s.calls >= len(s.actions) {
|
||||
return status.Error(codes.Unavailable, "server is unavailable")
|
||||
func (s *stubRecoveryServer) Recover(context.Context, *recoverproto.RecoverMessage) (*recoverproto.RecoverResponse, error) {
|
||||
if s.recoverError != nil {
|
||||
return nil, s.recoverError
|
||||
}
|
||||
s.calls++
|
||||
|
||||
for _, action := range s.actions[s.calls-1] {
|
||||
if err := action(stream); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return &recoverproto.RecoverResponse{}, nil
|
||||
}
|
||||
|
||||
type stubDoer struct {
|
||||
|
@ -437,4 +341,4 @@ func (d *stubDoer) Do(context.Context) error {
|
|||
|
||||
func (d *stubDoer) setDialer(grpcDialer, string) {}
|
||||
|
||||
func (d *stubDoer) setSecrets(func(string) ([]byte, error), []byte) {}
|
||||
func (d *stubDoer) setURIs(kmsURI, storageURI string) {}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue