mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-05-02 06:16:08 -04:00
Refactor init/recovery to use kms URI
So far the masterSecret was sent to the initial bootstrapper on init/recovery. With this commit this information is encoded in the kmsURI that is sent during init. For recover, the communication with the recoveryserver is changed. Before a streaming gRPC call was used to exchanges UUID for measurementSecret and state disk key. Now a standard gRPC is made that includes the same kmsURI & storageURI that are sent during init.
This commit is contained in:
parent
0e71322e2e
commit
9a1f52e94e
35 changed files with 466 additions and 623 deletions
|
@ -13,8 +13,10 @@ import (
|
|||
|
||||
"github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto"
|
||||
"github.com/edgelesssys/constellation/v2/internal/atls"
|
||||
"github.com/edgelesssys/constellation/v2/internal/crypto"
|
||||
"github.com/edgelesssys/constellation/v2/internal/grpc/atlscredentials"
|
||||
"github.com/edgelesssys/constellation/v2/internal/grpc/grpclog"
|
||||
"github.com/edgelesssys/constellation/v2/internal/kms/kms"
|
||||
"github.com/edgelesssys/constellation/v2/internal/logger"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
|
@ -22,6 +24,8 @@ import (
|
|||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
type kmsFactory func(ctx context.Context, storageURI string, kmsURI string) (kms.CloudKMS, error)
|
||||
|
||||
// RecoveryServer is a gRPC server that can be used by an admin to recover a restarting node.
|
||||
type RecoveryServer struct {
|
||||
mux sync.Mutex
|
||||
|
@ -30,6 +34,7 @@ type RecoveryServer struct {
|
|||
stateDiskKey []byte
|
||||
measurementSecret []byte
|
||||
grpcServer server
|
||||
factory kmsFactory
|
||||
|
||||
log *logger.Logger
|
||||
|
||||
|
@ -37,9 +42,10 @@ type RecoveryServer struct {
|
|||
}
|
||||
|
||||
// New returns a new RecoveryServer.
|
||||
func New(issuer atls.Issuer, log *logger.Logger) *RecoveryServer {
|
||||
func New(issuer atls.Issuer, factory kmsFactory, log *logger.Logger) *RecoveryServer {
|
||||
server := &RecoveryServer{
|
||||
log: log,
|
||||
log: log,
|
||||
factory: factory,
|
||||
}
|
||||
|
||||
grpcServer := grpc.NewServer(
|
||||
|
@ -87,47 +93,32 @@ func (s *RecoveryServer) Serve(ctx context.Context, listener net.Listener, diskU
|
|||
}
|
||||
|
||||
// Recover is a bidirectional streaming RPC that is used to send recovery keys to a restarting node.
|
||||
func (s *RecoveryServer) Recover(stream recoverproto.API_RecoverServer) error {
|
||||
func (s *RecoveryServer) Recover(ctx context.Context, req *recoverproto.RecoverMessage) (*recoverproto.RecoverResponse, error) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
log := s.log.With(zap.String("peer", grpclog.PeerAddrFromContext(stream.Context())))
|
||||
log := s.log.With(zap.String("peer", grpclog.PeerAddrFromContext(ctx)))
|
||||
|
||||
log.Infof("Received recover call")
|
||||
|
||||
msg, err := stream.Recv()
|
||||
cloudKms, err := s.factory(ctx, req.StorageUri, req.KmsUri)
|
||||
if err != nil {
|
||||
return status.Error(codes.Internal, "failed to receive message")
|
||||
return nil, status.Errorf(codes.Internal, "creating kms client: %s", err)
|
||||
}
|
||||
|
||||
measurementSecret, ok := msg.GetRequest().(*recoverproto.RecoverMessage_MeasurementSecret)
|
||||
if !ok {
|
||||
log.Errorf("Received invalid first message: not a measurement secret")
|
||||
return status.Error(codes.InvalidArgument, "first message is not a measurement secret")
|
||||
}
|
||||
|
||||
if err := stream.Send(&recoverproto.RecoverResponse{DiskUuid: s.diskUUID}); err != nil {
|
||||
log.With(zap.Error(err)).Errorf("Failed to send disk UUID")
|
||||
return status.Error(codes.Internal, "failed to send response")
|
||||
}
|
||||
|
||||
msg, err = stream.Recv()
|
||||
measurementSecret, err := cloudKms.GetDEK(ctx, crypto.DEKPrefix+crypto.MeasurementSecretKeyID, crypto.DerivedKeyLengthDefault)
|
||||
if err != nil {
|
||||
log.With(zap.Error(err)).Errorf("Failed to receive disk key")
|
||||
return status.Error(codes.Internal, "failed to receive message")
|
||||
return nil, status.Errorf(codes.Internal, "requesting measurementSecret: %s", err)
|
||||
}
|
||||
|
||||
stateDiskKey, ok := msg.GetRequest().(*recoverproto.RecoverMessage_StateDiskKey)
|
||||
if !ok {
|
||||
log.Errorf("Received invalid second message: not a state disk key")
|
||||
return status.Error(codes.InvalidArgument, "second message is not a state disk key")
|
||||
stateDiskKey, err := cloudKms.GetDEK(ctx, crypto.DEKPrefix+s.diskUUID, crypto.StateDiskKeyLength)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "requesting stateDiskKey: %s", err)
|
||||
}
|
||||
|
||||
s.stateDiskKey = stateDiskKey.StateDiskKey
|
||||
s.measurementSecret = measurementSecret.MeasurementSecret
|
||||
s.stateDiskKey = stateDiskKey
|
||||
s.measurementSecret = measurementSecret
|
||||
log.Infof("Received state disk key and measurement secret, shutting down server")
|
||||
|
||||
go s.grpcServer.GracefulStop()
|
||||
return nil
|
||||
return &recoverproto.RecoverResponse{}, nil
|
||||
}
|
||||
|
||||
// StubServer implements the RecoveryServer interface but does not actually start a server.
|
||||
|
|
|
@ -8,7 +8,7 @@ package recoveryserver
|
|||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -17,6 +17,7 @@ import (
|
|||
"github.com/edgelesssys/constellation/v2/internal/atls"
|
||||
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
|
||||
"github.com/edgelesssys/constellation/v2/internal/grpc/testdialer"
|
||||
"github.com/edgelesssys/constellation/v2/internal/kms/kms"
|
||||
"github.com/edgelesssys/constellation/v2/internal/logger"
|
||||
"github.com/edgelesssys/constellation/v2/internal/oid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -25,14 +26,17 @@ import (
|
|||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m)
|
||||
goleak.VerifyTestMain(m,
|
||||
// https://github.com/census-instrumentation/opencensus-go/issues/1262
|
||||
goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
|
||||
)
|
||||
}
|
||||
|
||||
func TestServe(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
log := logger.NewTest(t)
|
||||
uuid := "uuid"
|
||||
server := New(atls.NewFakeIssuer(oid.Dummy{}), log)
|
||||
server := New(atls.NewFakeIssuer(oid.Dummy{}), newStubKMS(nil, nil), log)
|
||||
dialer := testdialer.NewBufconnDialer()
|
||||
listener := dialer.GetListener("192.0.2.1:1234")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
@ -49,7 +53,7 @@ func TestServe(t *testing.T) {
|
|||
cancel()
|
||||
wg.Wait()
|
||||
|
||||
server = New(atls.NewFakeIssuer(oid.Dummy{}), log)
|
||||
server = New(atls.NewFakeIssuer(oid.Dummy{}), newStubKMS(nil, nil), log)
|
||||
dialer = testdialer.NewBufconnDialer()
|
||||
listener = dialer.GetListener("192.0.2.1:1234")
|
||||
|
||||
|
@ -71,59 +75,26 @@ func TestServe(t *testing.T) {
|
|||
|
||||
func TestRecover(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
initialMsg message
|
||||
keyMsg message
|
||||
kmsURI string
|
||||
storageURI string
|
||||
factory kmsFactory
|
||||
wantErr bool
|
||||
}{
|
||||
"success": {
|
||||
initialMsg: message{
|
||||
recoverMsg: &recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_MeasurementSecret{
|
||||
MeasurementSecret: []byte("measurementSecret"),
|
||||
},
|
||||
},
|
||||
},
|
||||
keyMsg: message{
|
||||
recoverMsg: &recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_StateDiskKey{
|
||||
StateDiskKey: []byte("diskKey"),
|
||||
},
|
||||
},
|
||||
},
|
||||
// base64 encoded: key=masterkey&salt=somesalt
|
||||
kmsURI: "kms://cluster-kms?key=bWFzdGVya2V5&salt=c29tZXNhbHQ=",
|
||||
storageURI: "storage://no-store",
|
||||
factory: newStubKMS(nil, nil),
|
||||
},
|
||||
"first message is not a measurement secret": {
|
||||
initialMsg: message{
|
||||
recoverMsg: &recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_StateDiskKey{
|
||||
StateDiskKey: []byte("diskKey"),
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
keyMsg: message{
|
||||
recoverMsg: &recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_StateDiskKey{
|
||||
StateDiskKey: []byte("diskKey"),
|
||||
},
|
||||
},
|
||||
},
|
||||
"kms init fails": {
|
||||
factory: newStubKMS(errors.New("setup failed"), nil),
|
||||
wantErr: true,
|
||||
},
|
||||
"second message is not a state disk key": {
|
||||
initialMsg: message{
|
||||
recoverMsg: &recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_MeasurementSecret{
|
||||
MeasurementSecret: []byte("measurementSecret"),
|
||||
},
|
||||
},
|
||||
},
|
||||
keyMsg: message{
|
||||
recoverMsg: &recoverproto.RecoverMessage{
|
||||
Request: &recoverproto.RecoverMessage_MeasurementSecret{
|
||||
MeasurementSecret: []byte("measurementSecret"),
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"GetDEK fails": {
|
||||
kmsURI: "kms://cluster-kms?key=bWFzdGVya2V5&salt=c29tZXNhbHQ=",
|
||||
storageURI: "storage://no-store",
|
||||
factory: newStubKMS(nil, errors.New("GetDEK failed")),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -134,7 +105,7 @@ func TestRecover(t *testing.T) {
|
|||
|
||||
ctx := context.Background()
|
||||
serverUUID := "uuid"
|
||||
server := New(atls.NewFakeIssuer(oid.Dummy{}), logger.NewTest(t))
|
||||
server := New(atls.NewFakeIssuer(oid.Dummy{}), tc.factory, logger.NewTest(t))
|
||||
netDialer := testdialer.NewBufconnDialer()
|
||||
listener := netDialer.GetListener("192.0.2.1:1234")
|
||||
|
||||
|
@ -154,41 +125,46 @@ func TestRecover(t *testing.T) {
|
|||
conn, err := dialer.New(nil, nil, netDialer).Dial(ctx, "192.0.2.1:1234")
|
||||
require.NoError(err)
|
||||
defer conn.Close()
|
||||
client, err := recoverproto.NewAPIClient(conn).Recover(ctx)
|
||||
require.NoError(err)
|
||||
|
||||
// Send initial message
|
||||
err = client.Send(tc.initialMsg.recoverMsg)
|
||||
require.NoError(err)
|
||||
req := recoverproto.RecoverMessage{
|
||||
KmsUri: tc.kmsURI,
|
||||
StorageUri: tc.storageURI,
|
||||
}
|
||||
_, err = recoverproto.NewAPIClient(conn).Recover(ctx, &req)
|
||||
|
||||
// Receive uuid
|
||||
uuid, err := client.Recv()
|
||||
if tc.initialMsg.wantErr {
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
assert.Equal(serverUUID, uuid.DiskUuid)
|
||||
|
||||
// Send key message
|
||||
err = client.Send(tc.keyMsg.recoverMsg)
|
||||
require.NoError(err)
|
||||
|
||||
_, err = client.Recv()
|
||||
if tc.keyMsg.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
assert.ErrorIs(io.EOF, err)
|
||||
|
||||
wg.Wait()
|
||||
assert.NoError(serveErr)
|
||||
assert.Equal(tc.initialMsg.recoverMsg.GetMeasurementSecret(), measurementSecret)
|
||||
assert.Equal(tc.keyMsg.recoverMsg.GetStateDiskKey(), diskKey)
|
||||
require.NoError(serveErr)
|
||||
assert.NoError(err)
|
||||
assert.NotNil(measurementSecret)
|
||||
assert.NotNil(diskKey)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type message struct {
|
||||
recoverMsg *recoverproto.RecoverMessage
|
||||
wantErr bool
|
||||
func newStubKMS(setupErr, getDEKErr error) kmsFactory {
|
||||
return func(ctx context.Context, storageURI string, kmsURI string) (kms.CloudKMS, error) {
|
||||
if setupErr != nil {
|
||||
return nil, setupErr
|
||||
}
|
||||
return &stubKMS{getDEKErr: getDEKErr}, nil
|
||||
}
|
||||
}
|
||||
|
||||
type stubKMS struct {
|
||||
getDEKErr error
|
||||
}
|
||||
|
||||
func (s *stubKMS) CreateKEK(ctx context.Context, keyID string, kek []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubKMS) GetDEK(ctx context.Context, dekID string, dekSize int) ([]byte, error) {
|
||||
if s.getDEKErr != nil {
|
||||
return nil, s.getDEKErr
|
||||
}
|
||||
return []byte("someDEK"), nil
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue