AB#1903 Push keys to restarting nodes on trigger RPC

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2022-04-11 15:28:41 +02:00 committed by Daniel Weiße
parent 152e3985f7
commit 37aff14cab
9 changed files with 193 additions and 24 deletions

View File

@ -54,7 +54,7 @@ func run(validator core.QuoteValidator, issuer core.QuoteIssuer, vpn core.VPN, o
vapiServer := &vpnAPIServer{logger: zapLoggerCore.Named("vpnapi"), core: core} vapiServer := &vpnAPIServer{logger: zapLoggerCore.Named("vpnapi"), core: core}
zapLoggerPubapi := zapLoggerCore.Named("pubapi") zapLoggerPubapi := zapLoggerCore.Named("pubapi")
papi := pubapi.New(zapLoggerPubapi, core, dialer, vapiServer, validator, getPublicIPAddr) papi := pubapi.New(zapLoggerPubapi, core, dialer, vapiServer, validator, getPublicIPAddr, pubapi.GetRecoveryPeerFromContext)
zapLoggergRPC := zapLoggerPubapi.Named("gRPC") zapLoggergRPC := zapLoggerPubapi.Named("gRPC")

View File

@ -211,7 +211,7 @@ func spawnPeer(require *require.Assertions, logger *zap.Logger, dialer *testdial
} }
vapiServer := &fakeVPNAPIServer{logger: logger.Named("vpnapi"), core: cor, dialer: dialer} vapiServer := &fakeVPNAPIServer{logger: logger.Named("vpnapi"), core: cor, dialer: dialer}
papi := pubapi.New(logger, cor, dialer, vapiServer, &core.MockValidator{}, getPublicAddr) papi := pubapi.New(logger, cor, dialer, vapiServer, &core.MockValidator{}, getPublicAddr, nil)
tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{}) tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{})
require.NoError(err) require.NoError(err)

View File

@ -145,7 +145,7 @@ func newMockCoreWithDialer(dialer *bufconnDialer) (*Core, *pubapi.API, error) {
} }
vapiServer := &fakeVPNAPIServer{logger: zapLogger, core: core, dialer: dialer} vapiServer := &fakeVPNAPIServer{logger: zapLogger, core: core, dialer: dialer}
papi := pubapi.New(zapLogger, core, dialer, vapiServer, validator, getPublicAddr) papi := pubapi.New(zapLogger, core, dialer, vapiServer, validator, getPublicAddr, nil)
return core, papi, nil return core, papi, nil
} }

View File

@ -7,10 +7,12 @@ import (
"net" "net"
"time" "time"
"github.com/edgelesssys/constellation/coordinator/config"
"github.com/edgelesssys/constellation/coordinator/peer" "github.com/edgelesssys/constellation/coordinator/peer"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/role" "github.com/edgelesssys/constellation/coordinator/role"
"github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/coordinator/state"
"github.com/edgelesssys/constellation/state/keyservice/keyproto"
"go.uber.org/zap" "go.uber.org/zap"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
@ -185,17 +187,31 @@ func (a *API) ActivateAdditionalNodes(in *pubproto.ActivateAdditionalNodesReques
// RequestStateDiskKey triggers the Coordinator to return a key derived from the Constellation's master secret to the caller. // RequestStateDiskKey triggers the Coordinator to return a key derived from the Constellation's master secret to the caller.
func (a *API) RequestStateDiskKey(ctx context.Context, in *pubproto.RequestStateDiskKeyRequest) (*pubproto.RequestStateDiskKeyResponse, error) { func (a *API) RequestStateDiskKey(ctx context.Context, in *pubproto.RequestStateDiskKeyRequest) (*pubproto.RequestStateDiskKeyResponse, error) {
// TODO: Add Coordinator call to restarting node and deliver the key if err := a.core.RequireState(state.ActivatingNodes); err != nil {
/* return nil, status.Errorf(codes.FailedPrecondition, "%v", err)
if err := a.core.RequireState(state.IsNode, state.ActivatingNodes); err != nil { }
return nil, err key, err := a.core.GetDataKey(ctx, in.DiskUuid, config.RNGLengthDefault)
} if err != nil {
_, err := a.core.GetDataKey(ctx, in.DiskUuid, 32) return nil, status.Errorf(codes.Internal, "unable to load key: %v", err)
if err != nil { }
return nil, status.Errorf(codes.Internal, "")
} peer, err := a.peerFromContext(ctx)
*/ if err != nil {
return &pubproto.RequestStateDiskKeyResponse{}, errors.New("unimplemented") return nil, status.Errorf(codes.Internal, "%v", err)
}
conn, err := a.dial(ctx, peer)
if err != nil {
return nil, status.Errorf(codes.Internal, "%v", err)
}
defer conn.Close()
client := keyproto.NewAPIClient(conn)
if _, err := client.PushStateDiskKey(ctx, &keyproto.PushStateDiskKeyRequest{StateDiskKey: key}); err != nil {
return nil, status.Errorf(codes.Internal, "pushing key to peer %q: %v", peer, err)
}
return &pubproto.RequestStateDiskKeyResponse{}, nil
} }
func (a *API) activateNodes(logToCLI logFunc, nodePublicIPs []string) error { func (a *API) activateNodes(logToCLI logFunc, nodePublicIPs []string) error {

View File

@ -2,6 +2,7 @@ package pubapi
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"io" "io"
"net" "net"
@ -9,6 +10,8 @@ import (
"testing" "testing"
"github.com/edgelesssys/constellation/coordinator/atls" "github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
"github.com/edgelesssys/constellation/coordinator/core"
"github.com/edgelesssys/constellation/coordinator/kms" "github.com/edgelesssys/constellation/coordinator/kms"
"github.com/edgelesssys/constellation/coordinator/oid" "github.com/edgelesssys/constellation/coordinator/oid"
"github.com/edgelesssys/constellation/coordinator/peer" "github.com/edgelesssys/constellation/coordinator/peer"
@ -16,11 +19,13 @@ import (
"github.com/edgelesssys/constellation/coordinator/role" "github.com/edgelesssys/constellation/coordinator/role"
"github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/coordinator/state"
"github.com/edgelesssys/constellation/coordinator/util/testdialer" "github.com/edgelesssys/constellation/coordinator/util/testdialer"
"github.com/edgelesssys/constellation/state/keyservice/keyproto"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest" "go.uber.org/zap/zaptest"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
grpcpeer "google.golang.org/grpc/peer"
) )
func TestActivateAsCoordinator(t *testing.T) { func TestActivateAsCoordinator(t *testing.T) {
@ -126,7 +131,7 @@ func TestActivateAsCoordinator(t *testing.T) {
return "192.0.2.1", nil return "192.0.2.1", nil
} }
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr) api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil)
defer api.Close() defer api.Close()
// spawn nodes // spawn nodes
@ -257,7 +262,7 @@ func TestActivateAdditionalNodes(t *testing.T) {
return "192.0.2.1", nil return "192.0.2.1", nil
} }
api := New(zaptest.NewLogger(t), core, dialer, nil, fakeValidator{}, getPublicIPAddr) api := New(zaptest.NewLogger(t), core, dialer, nil, fakeValidator{}, getPublicIPAddr, nil)
defer api.Close() defer api.Close()
// spawn nodes // spawn nodes
var nodePublicIPs []string var nodePublicIPs []string
@ -306,7 +311,7 @@ func TestAssemblePeerStruct(t *testing.T) {
vpnPubKey := []byte{2, 3, 4} vpnPubKey := []byte{2, 3, 4}
core := &fakeCore{vpnPubKey: vpnPubKey} core := &fakeCore{vpnPubKey: vpnPubKey}
api := New(zaptest.NewLogger(t), core, nil, nil, nil, getPublicIPAddr) api := New(zaptest.NewLogger(t), core, nil, nil, nil, getPublicIPAddr, nil)
defer api.Close() defer api.Close()
vpnIP, err := core.GetVPNIP() vpnIP, err := core.GetVPNIP()
@ -433,3 +438,107 @@ func (s *stubActivateAdditionalNodesServer) Send(req *pubproto.ActivateAdditiona
s.sent = append(s.sent, req) s.sent = append(s.sent, req)
return nil return nil
} }
func TestRequestStateDiskKey(t *testing.T) {
defaultKey := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
someErr := errors.New("error")
testCases := map[string]struct {
state state.State
dataKey []byte
getDataKeyErr error
pushKeyErr error
errExpected bool
}{
"success": {
state: state.ActivatingNodes,
dataKey: defaultKey,
},
"Coordinator in wrong state": {
state: state.IsNode,
dataKey: defaultKey,
errExpected: true,
},
"GetDataKey fails": {
state: state.ActivatingNodes,
dataKey: defaultKey,
getDataKeyErr: someErr,
errExpected: true,
},
"key pushing fails": {
state: state.ActivatingNodes,
dataKey: defaultKey,
pushKeyErr: someErr,
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
issuer := core.NewMockIssuer()
stateDiskServer := &stubStateDiskServer{pushKeyErr: tc.pushKeyErr}
// we can not use a bufconn here, since we rely on grpcpeer.FromContext() to connect to the caller
listener, err := net.Listen("tcp", ":")
require.NoError(err)
defer listener.Close()
tlsConfig, err := atls.CreateAttestationServerTLSConfig(issuer)
require.NoError(err)
s := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
keyproto.RegisterAPIServer(s, stateDiskServer)
defer s.GracefulStop()
go s.Serve(listener)
ctx := grpcpeer.NewContext(context.Background(), &grpcpeer.Peer{Addr: listener.Addr()})
getPeerFromContext := func(ctx context.Context) (string, error) {
peer, ok := grpcpeer.FromContext(ctx)
if !ok {
return "", errors.New("unable to get peer from context")
}
return peer.Addr.String(), nil
}
core := &fakeCore{
state: tc.state,
dataKey: tc.dataKey,
getDataKeyErr: tc.getDataKeyErr,
}
api := New(zaptest.NewLogger(t), core, &net.Dialer{}, nil, dummyValidator{}, nil, getPeerFromContext)
_, err = api.RequestStateDiskKey(ctx, &pubproto.RequestStateDiskKeyRequest{})
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(tc.dataKey, stateDiskServer.receivedRequest.StateDiskKey)
}
})
}
}
type dummyValidator struct {
oid.Dummy
}
func (d dummyValidator) Validate(attdoc []byte, nonce []byte) ([]byte, error) {
var attestation vtpm.AttestationDocument
if err := json.Unmarshal(attdoc, &attestation); err != nil {
return nil, err
}
return attestation.UserData, nil
}
type stubStateDiskServer struct {
receivedRequest *keyproto.PushStateDiskKeyRequest
pushKeyErr error
keyproto.UnimplementedAPIServer
}
func (s *stubStateDiskServer) PushStateDiskKey(ctx context.Context, in *keyproto.PushStateDiskKeyRequest) (*keyproto.PushStateDiskKeyResponse, error) {
s.receivedRequest = in
return &keyproto.PushStateDiskKeyResponse{}, s.pushKeyErr
}

View File

@ -66,7 +66,7 @@ func TestActivateAsCoordinators(t *testing.T) {
return "192.0.2.1", nil return "192.0.2.1", nil
} }
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr) api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil)
defer api.Close() defer api.Close()
// spawn coordinator // spawn coordinator
@ -133,7 +133,7 @@ func TestTriggerCoordinatorUpdate(t *testing.T) {
} }
dialer := testdialer.NewBufconnDialer() dialer := testdialer.NewBufconnDialer()
api := New(logger, core, dialer, nil, nil, nil) api := New(logger, core, dialer, nil, nil, nil, nil)
_, err := api.TriggerCoordinatorUpdate(context.Background(), &pubproto.TriggerCoordinatorUpdateRequest{}) _, err := api.TriggerCoordinatorUpdate(context.Background(), &pubproto.TriggerCoordinatorUpdateRequest{})
if tc.expectErr { if tc.expectErr {
@ -202,7 +202,7 @@ func TestActivateAdditionalCoordinators(t *testing.T) {
return "192.0.2.1", nil return "192.0.2.1", nil
} }
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr) api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil)
defer api.Close() defer api.Close()
// spawn coordinator // spawn coordinator

View File

@ -129,7 +129,7 @@ func TestActivateAsNode(t *testing.T) {
cor := &fakeCore{state: tc.state, vpnPubKey: vpnPubKey, setVPNIPErr: tc.setVPNIPErr} cor := &fakeCore{state: tc.state, vpnPubKey: vpnPubKey, setVPNIPErr: tc.setVPNIPErr}
dialer := testdialer.NewBufconnDialer() dialer := testdialer.NewBufconnDialer()
api := New(logger, cor, dialer, nil, nil, nil) api := New(logger, cor, dialer, nil, nil, nil, nil)
defer api.Close() defer api.Close()
vserver := grpc.NewServer() vserver := grpc.NewServer()
@ -217,7 +217,7 @@ func TestTriggerNodeUpdate(t *testing.T) {
core := &fakeCore{state: tc.state} core := &fakeCore{state: tc.state}
dialer := testdialer.NewBufconnDialer() dialer := testdialer.NewBufconnDialer()
api := New(logger, core, dialer, nil, nil, nil) api := New(logger, core, dialer, nil, nil, nil, nil)
vserver := grpc.NewServer() vserver := grpc.NewServer()
vapi := &stubVPNAPI{ vapi := &stubVPNAPI{
@ -292,7 +292,7 @@ func TestJoinCluster(t *testing.T) {
core := &fakeCore{state: tc.state, joinClusterErr: tc.joinClusterErr} core := &fakeCore{state: tc.state, joinClusterErr: tc.joinClusterErr}
dialer := testdialer.NewBufconnDialer() dialer := testdialer.NewBufconnDialer()
api := New(logger, core, dialer, nil, nil, nil) api := New(logger, core, dialer, nil, nil, nil, nil)
vserver := grpc.NewServer() vserver := grpc.NewServer()
vapi := &stubVPNAPI{ vapi := &stubVPNAPI{

View File

@ -3,16 +3,19 @@ package pubapi
import ( import (
"context" "context"
"errors"
"net" "net"
"sync" "sync"
"time" "time"
"github.com/edgelesssys/constellation/coordinator/atls" "github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/state/setup"
"go.uber.org/zap" "go.uber.org/zap"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/peer"
) )
const ( const (
@ -34,11 +37,12 @@ type API struct {
stopUpdate chan struct{} stopUpdate chan struct{}
wgClose sync.WaitGroup wgClose sync.WaitGroup
resourceVersion int resourceVersion int
peerFromContext PeerFromContextFunc
pubproto.UnimplementedAPIServer pubproto.UnimplementedAPIServer
} }
// New creates a new API. // New creates a new API.
func New(logger *zap.Logger, core Core, dialer Dialer, vpnAPIServer VPNAPIServer, validator atls.Validator, getPublicIPAddr GetIPAddrFunc) *API { func New(logger *zap.Logger, core Core, dialer Dialer, vpnAPIServer VPNAPIServer, validator atls.Validator, getPublicIPAddr GetIPAddrFunc, peerFromContext PeerFromContextFunc) *API {
return &API{ return &API{
logger: logger, logger: logger,
core: core, core: core,
@ -47,6 +51,7 @@ func New(logger *zap.Logger, core Core, dialer Dialer, vpnAPIServer VPNAPIServer
validator: validator, validator: validator,
getPublicIPAddr: getPublicIPAddr, getPublicIPAddr: getPublicIPAddr,
stopUpdate: make(chan struct{}, 1), stopUpdate: make(chan struct{}, 1),
peerFromContext: peerFromContext,
} }
} }
@ -112,3 +117,21 @@ type VPNAPIServer interface {
} }
type GetIPAddrFunc func() (string, error) type GetIPAddrFunc func() (string, error)
// PeerFromContextFunc returns a peer endpoint (IP:port) from a given context.
type PeerFromContextFunc func(context.Context) (string, error)
// GetRecoveryPeerFromContext returns the context's IP joined with the Coordinator's default port.
func GetRecoveryPeerFromContext(ctx context.Context) (string, error) {
peer, ok := peer.FromContext(ctx)
if !ok {
return "", errors.New("unable to get peer from context")
}
peerIP, _, err := net.SplitHostPort(peer.Addr.String())
if err != nil {
return "", err
}
return net.JoinHostPort(peerIP, setup.RecoveryPort), nil
}

View File

@ -1,9 +1,13 @@
package pubapi package pubapi
import ( import (
"context"
"net"
"testing" "testing"
"github.com/stretchr/testify/assert"
"go.uber.org/goleak" "go.uber.org/goleak"
grpcpeer "google.golang.org/grpc/peer"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
@ -14,3 +18,20 @@ func TestMain(m *testing.M) {
goleak.IgnoreTopFunction("k8s.io/klog/v2.(*loggingT).flushDaemon"), goleak.IgnoreTopFunction("k8s.io/klog/v2.(*loggingT).flushDaemon"),
) )
} }
func TestGetRecoveryPeerFromContext(t *testing.T) {
assert := assert.New(t)
testIP := "192.0.2.1"
testPort := 1234
expectedPeer := net.JoinHostPort(testIP, "9000")
addr := &net.TCPAddr{IP: net.ParseIP(testIP), Port: testPort}
ctx := grpcpeer.NewContext(context.Background(), &grpcpeer.Peer{Addr: addr})
peer, err := GetRecoveryPeerFromContext(ctx)
assert.NoError(err)
assert.Equal(expectedPeer, peer)
_, err = GetRecoveryPeerFromContext(context.Background())
assert.Error(err)
}