mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-01-11 23:49:30 -05:00
AB#1903 Push keys to restarting nodes on trigger RPC
Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
parent
152e3985f7
commit
37aff14cab
@ -54,7 +54,7 @@ func run(validator core.QuoteValidator, issuer core.QuoteIssuer, vpn core.VPN, o
|
||||
|
||||
vapiServer := &vpnAPIServer{logger: zapLoggerCore.Named("vpnapi"), core: core}
|
||||
zapLoggerPubapi := zapLoggerCore.Named("pubapi")
|
||||
papi := pubapi.New(zapLoggerPubapi, core, dialer, vapiServer, validator, getPublicIPAddr)
|
||||
papi := pubapi.New(zapLoggerPubapi, core, dialer, vapiServer, validator, getPublicIPAddr, pubapi.GetRecoveryPeerFromContext)
|
||||
|
||||
zapLoggergRPC := zapLoggerPubapi.Named("gRPC")
|
||||
|
||||
|
@ -211,7 +211,7 @@ func spawnPeer(require *require.Assertions, logger *zap.Logger, dialer *testdial
|
||||
}
|
||||
|
||||
vapiServer := &fakeVPNAPIServer{logger: logger.Named("vpnapi"), core: cor, dialer: dialer}
|
||||
papi := pubapi.New(logger, cor, dialer, vapiServer, &core.MockValidator{}, getPublicAddr)
|
||||
papi := pubapi.New(logger, cor, dialer, vapiServer, &core.MockValidator{}, getPublicAddr, nil)
|
||||
|
||||
tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{})
|
||||
require.NoError(err)
|
||||
|
@ -145,7 +145,7 @@ func newMockCoreWithDialer(dialer *bufconnDialer) (*Core, *pubapi.API, error) {
|
||||
}
|
||||
|
||||
vapiServer := &fakeVPNAPIServer{logger: zapLogger, core: core, dialer: dialer}
|
||||
papi := pubapi.New(zapLogger, core, dialer, vapiServer, validator, getPublicAddr)
|
||||
papi := pubapi.New(zapLogger, core, dialer, vapiServer, validator, getPublicAddr, nil)
|
||||
|
||||
return core, papi, nil
|
||||
}
|
||||
|
@ -7,10 +7,12 @@ import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/config"
|
||||
"github.com/edgelesssys/constellation/coordinator/peer"
|
||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||
"github.com/edgelesssys/constellation/coordinator/role"
|
||||
"github.com/edgelesssys/constellation/coordinator/state"
|
||||
"github.com/edgelesssys/constellation/state/keyservice/keyproto"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
@ -185,17 +187,31 @@ func (a *API) ActivateAdditionalNodes(in *pubproto.ActivateAdditionalNodesReques
|
||||
|
||||
// RequestStateDiskKey triggers the Coordinator to return a key derived from the Constellation's master secret to the caller.
|
||||
func (a *API) RequestStateDiskKey(ctx context.Context, in *pubproto.RequestStateDiskKeyRequest) (*pubproto.RequestStateDiskKeyResponse, error) {
|
||||
// TODO: Add Coordinator call to restarting node and deliver the key
|
||||
/*
|
||||
if err := a.core.RequireState(state.IsNode, state.ActivatingNodes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err := a.core.GetDataKey(ctx, in.DiskUuid, 32)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "")
|
||||
}
|
||||
*/
|
||||
return &pubproto.RequestStateDiskKeyResponse{}, errors.New("unimplemented")
|
||||
if err := a.core.RequireState(state.ActivatingNodes); err != nil {
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "%v", err)
|
||||
}
|
||||
key, err := a.core.GetDataKey(ctx, in.DiskUuid, config.RNGLengthDefault)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "unable to load key: %v", err)
|
||||
}
|
||||
|
||||
peer, err := a.peerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "%v", err)
|
||||
}
|
||||
|
||||
conn, err := a.dial(ctx, peer)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "%v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := keyproto.NewAPIClient(conn)
|
||||
if _, err := client.PushStateDiskKey(ctx, &keyproto.PushStateDiskKeyRequest{StateDiskKey: key}); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "pushing key to peer %q: %v", peer, err)
|
||||
}
|
||||
|
||||
return &pubproto.RequestStateDiskKeyResponse{}, nil
|
||||
}
|
||||
|
||||
func (a *API) activateNodes(logToCLI logFunc, nodePublicIPs []string) error {
|
||||
|
@ -2,6 +2,7 @@ package pubapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
@ -9,6 +10,8 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
|
||||
"github.com/edgelesssys/constellation/coordinator/core"
|
||||
"github.com/edgelesssys/constellation/coordinator/kms"
|
||||
"github.com/edgelesssys/constellation/coordinator/oid"
|
||||
"github.com/edgelesssys/constellation/coordinator/peer"
|
||||
@ -16,11 +19,13 @@ import (
|
||||
"github.com/edgelesssys/constellation/coordinator/role"
|
||||
"github.com/edgelesssys/constellation/coordinator/state"
|
||||
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
|
||||
"github.com/edgelesssys/constellation/state/keyservice/keyproto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/zap/zaptest"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
grpcpeer "google.golang.org/grpc/peer"
|
||||
)
|
||||
|
||||
func TestActivateAsCoordinator(t *testing.T) {
|
||||
@ -126,7 +131,7 @@ func TestActivateAsCoordinator(t *testing.T) {
|
||||
return "192.0.2.1", nil
|
||||
}
|
||||
|
||||
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr)
|
||||
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil)
|
||||
defer api.Close()
|
||||
|
||||
// spawn nodes
|
||||
@ -257,7 +262,7 @@ func TestActivateAdditionalNodes(t *testing.T) {
|
||||
return "192.0.2.1", nil
|
||||
}
|
||||
|
||||
api := New(zaptest.NewLogger(t), core, dialer, nil, fakeValidator{}, getPublicIPAddr)
|
||||
api := New(zaptest.NewLogger(t), core, dialer, nil, fakeValidator{}, getPublicIPAddr, nil)
|
||||
defer api.Close()
|
||||
// spawn nodes
|
||||
var nodePublicIPs []string
|
||||
@ -306,7 +311,7 @@ func TestAssemblePeerStruct(t *testing.T) {
|
||||
|
||||
vpnPubKey := []byte{2, 3, 4}
|
||||
core := &fakeCore{vpnPubKey: vpnPubKey}
|
||||
api := New(zaptest.NewLogger(t), core, nil, nil, nil, getPublicIPAddr)
|
||||
api := New(zaptest.NewLogger(t), core, nil, nil, nil, getPublicIPAddr, nil)
|
||||
defer api.Close()
|
||||
|
||||
vpnIP, err := core.GetVPNIP()
|
||||
@ -433,3 +438,107 @@ func (s *stubActivateAdditionalNodesServer) Send(req *pubproto.ActivateAdditiona
|
||||
s.sent = append(s.sent, req)
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRequestStateDiskKey(t *testing.T) {
|
||||
defaultKey := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
|
||||
someErr := errors.New("error")
|
||||
testCases := map[string]struct {
|
||||
state state.State
|
||||
dataKey []byte
|
||||
getDataKeyErr error
|
||||
pushKeyErr error
|
||||
errExpected bool
|
||||
}{
|
||||
"success": {
|
||||
state: state.ActivatingNodes,
|
||||
dataKey: defaultKey,
|
||||
},
|
||||
"Coordinator in wrong state": {
|
||||
state: state.IsNode,
|
||||
dataKey: defaultKey,
|
||||
errExpected: true,
|
||||
},
|
||||
"GetDataKey fails": {
|
||||
state: state.ActivatingNodes,
|
||||
dataKey: defaultKey,
|
||||
getDataKeyErr: someErr,
|
||||
errExpected: true,
|
||||
},
|
||||
"key pushing fails": {
|
||||
state: state.ActivatingNodes,
|
||||
dataKey: defaultKey,
|
||||
pushKeyErr: someErr,
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
issuer := core.NewMockIssuer()
|
||||
|
||||
stateDiskServer := &stubStateDiskServer{pushKeyErr: tc.pushKeyErr}
|
||||
|
||||
// we can not use a bufconn here, since we rely on grpcpeer.FromContext() to connect to the caller
|
||||
listener, err := net.Listen("tcp", ":")
|
||||
require.NoError(err)
|
||||
defer listener.Close()
|
||||
|
||||
tlsConfig, err := atls.CreateAttestationServerTLSConfig(issuer)
|
||||
require.NoError(err)
|
||||
s := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
|
||||
keyproto.RegisterAPIServer(s, stateDiskServer)
|
||||
defer s.GracefulStop()
|
||||
go s.Serve(listener)
|
||||
|
||||
ctx := grpcpeer.NewContext(context.Background(), &grpcpeer.Peer{Addr: listener.Addr()})
|
||||
getPeerFromContext := func(ctx context.Context) (string, error) {
|
||||
peer, ok := grpcpeer.FromContext(ctx)
|
||||
if !ok {
|
||||
return "", errors.New("unable to get peer from context")
|
||||
}
|
||||
return peer.Addr.String(), nil
|
||||
}
|
||||
|
||||
core := &fakeCore{
|
||||
state: tc.state,
|
||||
dataKey: tc.dataKey,
|
||||
getDataKeyErr: tc.getDataKeyErr,
|
||||
}
|
||||
api := New(zaptest.NewLogger(t), core, &net.Dialer{}, nil, dummyValidator{}, nil, getPeerFromContext)
|
||||
|
||||
_, err = api.RequestStateDiskKey(ctx, &pubproto.RequestStateDiskKeyRequest{})
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.dataKey, stateDiskServer.receivedRequest.StateDiskKey)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type dummyValidator struct {
|
||||
oid.Dummy
|
||||
}
|
||||
|
||||
func (d dummyValidator) Validate(attdoc []byte, nonce []byte) ([]byte, error) {
|
||||
var attestation vtpm.AttestationDocument
|
||||
if err := json.Unmarshal(attdoc, &attestation); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return attestation.UserData, nil
|
||||
}
|
||||
|
||||
type stubStateDiskServer struct {
|
||||
receivedRequest *keyproto.PushStateDiskKeyRequest
|
||||
pushKeyErr error
|
||||
keyproto.UnimplementedAPIServer
|
||||
}
|
||||
|
||||
func (s *stubStateDiskServer) PushStateDiskKey(ctx context.Context, in *keyproto.PushStateDiskKeyRequest) (*keyproto.PushStateDiskKeyResponse, error) {
|
||||
s.receivedRequest = in
|
||||
return &keyproto.PushStateDiskKeyResponse{}, s.pushKeyErr
|
||||
}
|
||||
|
@ -66,7 +66,7 @@ func TestActivateAsCoordinators(t *testing.T) {
|
||||
return "192.0.2.1", nil
|
||||
}
|
||||
|
||||
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr)
|
||||
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil)
|
||||
defer api.Close()
|
||||
|
||||
// spawn coordinator
|
||||
@ -133,7 +133,7 @@ func TestTriggerCoordinatorUpdate(t *testing.T) {
|
||||
}
|
||||
dialer := testdialer.NewBufconnDialer()
|
||||
|
||||
api := New(logger, core, dialer, nil, nil, nil)
|
||||
api := New(logger, core, dialer, nil, nil, nil, nil)
|
||||
|
||||
_, err := api.TriggerCoordinatorUpdate(context.Background(), &pubproto.TriggerCoordinatorUpdateRequest{})
|
||||
if tc.expectErr {
|
||||
@ -202,7 +202,7 @@ func TestActivateAdditionalCoordinators(t *testing.T) {
|
||||
return "192.0.2.1", nil
|
||||
}
|
||||
|
||||
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr)
|
||||
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil)
|
||||
defer api.Close()
|
||||
|
||||
// spawn coordinator
|
||||
|
@ -129,7 +129,7 @@ func TestActivateAsNode(t *testing.T) {
|
||||
cor := &fakeCore{state: tc.state, vpnPubKey: vpnPubKey, setVPNIPErr: tc.setVPNIPErr}
|
||||
dialer := testdialer.NewBufconnDialer()
|
||||
|
||||
api := New(logger, cor, dialer, nil, nil, nil)
|
||||
api := New(logger, cor, dialer, nil, nil, nil, nil)
|
||||
defer api.Close()
|
||||
|
||||
vserver := grpc.NewServer()
|
||||
@ -217,7 +217,7 @@ func TestTriggerNodeUpdate(t *testing.T) {
|
||||
core := &fakeCore{state: tc.state}
|
||||
dialer := testdialer.NewBufconnDialer()
|
||||
|
||||
api := New(logger, core, dialer, nil, nil, nil)
|
||||
api := New(logger, core, dialer, nil, nil, nil, nil)
|
||||
|
||||
vserver := grpc.NewServer()
|
||||
vapi := &stubVPNAPI{
|
||||
@ -292,7 +292,7 @@ func TestJoinCluster(t *testing.T) {
|
||||
core := &fakeCore{state: tc.state, joinClusterErr: tc.joinClusterErr}
|
||||
dialer := testdialer.NewBufconnDialer()
|
||||
|
||||
api := New(logger, core, dialer, nil, nil, nil)
|
||||
api := New(logger, core, dialer, nil, nil, nil, nil)
|
||||
|
||||
vserver := grpc.NewServer()
|
||||
vapi := &stubVPNAPI{
|
||||
|
@ -3,16 +3,19 @@ package pubapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||
"github.com/edgelesssys/constellation/state/setup"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/peer"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -34,11 +37,12 @@ type API struct {
|
||||
stopUpdate chan struct{}
|
||||
wgClose sync.WaitGroup
|
||||
resourceVersion int
|
||||
peerFromContext PeerFromContextFunc
|
||||
pubproto.UnimplementedAPIServer
|
||||
}
|
||||
|
||||
// New creates a new API.
|
||||
func New(logger *zap.Logger, core Core, dialer Dialer, vpnAPIServer VPNAPIServer, validator atls.Validator, getPublicIPAddr GetIPAddrFunc) *API {
|
||||
func New(logger *zap.Logger, core Core, dialer Dialer, vpnAPIServer VPNAPIServer, validator atls.Validator, getPublicIPAddr GetIPAddrFunc, peerFromContext PeerFromContextFunc) *API {
|
||||
return &API{
|
||||
logger: logger,
|
||||
core: core,
|
||||
@ -47,6 +51,7 @@ func New(logger *zap.Logger, core Core, dialer Dialer, vpnAPIServer VPNAPIServer
|
||||
validator: validator,
|
||||
getPublicIPAddr: getPublicIPAddr,
|
||||
stopUpdate: make(chan struct{}, 1),
|
||||
peerFromContext: peerFromContext,
|
||||
}
|
||||
}
|
||||
|
||||
@ -112,3 +117,21 @@ type VPNAPIServer interface {
|
||||
}
|
||||
|
||||
type GetIPAddrFunc func() (string, error)
|
||||
|
||||
// PeerFromContextFunc returns a peer endpoint (IP:port) from a given context.
|
||||
type PeerFromContextFunc func(context.Context) (string, error)
|
||||
|
||||
// GetRecoveryPeerFromContext returns the context's IP joined with the Coordinator's default port.
|
||||
func GetRecoveryPeerFromContext(ctx context.Context) (string, error) {
|
||||
peer, ok := peer.FromContext(ctx)
|
||||
if !ok {
|
||||
return "", errors.New("unable to get peer from context")
|
||||
}
|
||||
|
||||
peerIP, _, err := net.SplitHostPort(peer.Addr.String())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return net.JoinHostPort(peerIP, setup.RecoveryPort), nil
|
||||
}
|
||||
|
@ -1,9 +1,13 @@
|
||||
package pubapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/goleak"
|
||||
grpcpeer "google.golang.org/grpc/peer"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
@ -14,3 +18,20 @@ func TestMain(m *testing.M) {
|
||||
goleak.IgnoreTopFunction("k8s.io/klog/v2.(*loggingT).flushDaemon"),
|
||||
)
|
||||
}
|
||||
|
||||
func TestGetRecoveryPeerFromContext(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
testIP := "192.0.2.1"
|
||||
testPort := 1234
|
||||
expectedPeer := net.JoinHostPort(testIP, "9000")
|
||||
|
||||
addr := &net.TCPAddr{IP: net.ParseIP(testIP), Port: testPort}
|
||||
ctx := grpcpeer.NewContext(context.Background(), &grpcpeer.Peer{Addr: addr})
|
||||
|
||||
peer, err := GetRecoveryPeerFromContext(ctx)
|
||||
assert.NoError(err)
|
||||
assert.Equal(expectedPeer, peer)
|
||||
|
||||
_, err = GetRecoveryPeerFromContext(context.Background())
|
||||
assert.Error(err)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user