diff --git a/coordinator/cmd/coordinator/run.go b/coordinator/cmd/coordinator/run.go index 882f5057a..f4c084397 100644 --- a/coordinator/cmd/coordinator/run.go +++ b/coordinator/cmd/coordinator/run.go @@ -54,7 +54,7 @@ func run(validator core.QuoteValidator, issuer core.QuoteIssuer, vpn core.VPN, o vapiServer := &vpnAPIServer{logger: zapLoggerCore.Named("vpnapi"), core: core} zapLoggerPubapi := zapLoggerCore.Named("pubapi") - papi := pubapi.New(zapLoggerPubapi, core, dialer, vapiServer, validator, getPublicIPAddr) + papi := pubapi.New(zapLoggerPubapi, core, dialer, vapiServer, validator, getPublicIPAddr, pubapi.GetRecoveryPeerFromContext) zapLoggergRPC := zapLoggerPubapi.Named("gRPC") diff --git a/coordinator/coordinator_test.go b/coordinator/coordinator_test.go index 9c934c93b..d5338e06c 100644 --- a/coordinator/coordinator_test.go +++ b/coordinator/coordinator_test.go @@ -211,7 +211,7 @@ func spawnPeer(require *require.Assertions, logger *zap.Logger, dialer *testdial } vapiServer := &fakeVPNAPIServer{logger: logger.Named("vpnapi"), core: cor, dialer: dialer} - papi := pubapi.New(logger, cor, dialer, vapiServer, &core.MockValidator{}, getPublicAddr) + papi := pubapi.New(logger, cor, dialer, vapiServer, &core.MockValidator{}, getPublicAddr, nil) tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{}) require.NoError(err) diff --git a/coordinator/core/legacy_test.go b/coordinator/core/legacy_test.go index 0d5789fc2..11b1ec16a 100644 --- a/coordinator/core/legacy_test.go +++ b/coordinator/core/legacy_test.go @@ -145,7 +145,7 @@ func newMockCoreWithDialer(dialer *bufconnDialer) (*Core, *pubapi.API, error) { } vapiServer := &fakeVPNAPIServer{logger: zapLogger, core: core, dialer: dialer} - papi := pubapi.New(zapLogger, core, dialer, vapiServer, validator, getPublicAddr) + papi := pubapi.New(zapLogger, core, dialer, vapiServer, validator, getPublicAddr, nil) return core, papi, nil } diff --git a/coordinator/pubapi/coord.go b/coordinator/pubapi/coord.go index 78e9dd1e1..66502ca86 100644 --- a/coordinator/pubapi/coord.go +++ b/coordinator/pubapi/coord.go @@ -7,10 +7,12 @@ import ( "net" "time" + "github.com/edgelesssys/constellation/coordinator/config" "github.com/edgelesssys/constellation/coordinator/peer" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/role" "github.com/edgelesssys/constellation/coordinator/state" + "github.com/edgelesssys/constellation/state/keyservice/keyproto" "go.uber.org/zap" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -185,17 +187,31 @@ func (a *API) ActivateAdditionalNodes(in *pubproto.ActivateAdditionalNodesReques // RequestStateDiskKey triggers the Coordinator to return a key derived from the Constellation's master secret to the caller. func (a *API) RequestStateDiskKey(ctx context.Context, in *pubproto.RequestStateDiskKeyRequest) (*pubproto.RequestStateDiskKeyResponse, error) { - // TODO: Add Coordinator call to restarting node and deliver the key - /* - if err := a.core.RequireState(state.IsNode, state.ActivatingNodes); err != nil { - return nil, err - } - _, err := a.core.GetDataKey(ctx, in.DiskUuid, 32) - if err != nil { - return nil, status.Errorf(codes.Internal, "") - } - */ - return &pubproto.RequestStateDiskKeyResponse{}, errors.New("unimplemented") + if err := a.core.RequireState(state.ActivatingNodes); err != nil { + return nil, status.Errorf(codes.FailedPrecondition, "%v", err) + } + key, err := a.core.GetDataKey(ctx, in.DiskUuid, config.RNGLengthDefault) + if err != nil { + return nil, status.Errorf(codes.Internal, "unable to load key: %v", err) + } + + peer, err := a.peerFromContext(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "%v", err) + } + + conn, err := a.dial(ctx, peer) + if err != nil { + return nil, status.Errorf(codes.Internal, "%v", err) + } + defer conn.Close() + + client := keyproto.NewAPIClient(conn) + if _, err := client.PushStateDiskKey(ctx, &keyproto.PushStateDiskKeyRequest{StateDiskKey: key}); err != nil { + return nil, status.Errorf(codes.Internal, "pushing key to peer %q: %v", peer, err) + } + + return &pubproto.RequestStateDiskKeyResponse{}, nil } func (a *API) activateNodes(logToCLI logFunc, nodePublicIPs []string) error { diff --git a/coordinator/pubapi/coord_test.go b/coordinator/pubapi/coord_test.go index ab67b710b..a99098ef4 100644 --- a/coordinator/pubapi/coord_test.go +++ b/coordinator/pubapi/coord_test.go @@ -2,6 +2,7 @@ package pubapi import ( "context" + "encoding/json" "errors" "io" "net" @@ -9,6 +10,8 @@ import ( "testing" "github.com/edgelesssys/constellation/coordinator/atls" + "github.com/edgelesssys/constellation/coordinator/attestation/vtpm" + "github.com/edgelesssys/constellation/coordinator/core" "github.com/edgelesssys/constellation/coordinator/kms" "github.com/edgelesssys/constellation/coordinator/oid" "github.com/edgelesssys/constellation/coordinator/peer" @@ -16,11 +19,13 @@ import ( "github.com/edgelesssys/constellation/coordinator/role" "github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/coordinator/util/testdialer" + "github.com/edgelesssys/constellation/state/keyservice/keyproto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/zap/zaptest" "google.golang.org/grpc" "google.golang.org/grpc/credentials" + grpcpeer "google.golang.org/grpc/peer" ) func TestActivateAsCoordinator(t *testing.T) { @@ -126,7 +131,7 @@ func TestActivateAsCoordinator(t *testing.T) { return "192.0.2.1", nil } - api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr) + api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil) defer api.Close() // spawn nodes @@ -257,7 +262,7 @@ func TestActivateAdditionalNodes(t *testing.T) { return "192.0.2.1", nil } - api := New(zaptest.NewLogger(t), core, dialer, nil, fakeValidator{}, getPublicIPAddr) + api := New(zaptest.NewLogger(t), core, dialer, nil, fakeValidator{}, getPublicIPAddr, nil) defer api.Close() // spawn nodes var nodePublicIPs []string @@ -306,7 +311,7 @@ func TestAssemblePeerStruct(t *testing.T) { vpnPubKey := []byte{2, 3, 4} core := &fakeCore{vpnPubKey: vpnPubKey} - api := New(zaptest.NewLogger(t), core, nil, nil, nil, getPublicIPAddr) + api := New(zaptest.NewLogger(t), core, nil, nil, nil, getPublicIPAddr, nil) defer api.Close() vpnIP, err := core.GetVPNIP() @@ -433,3 +438,107 @@ func (s *stubActivateAdditionalNodesServer) Send(req *pubproto.ActivateAdditiona s.sent = append(s.sent, req) return nil } + +func TestRequestStateDiskKey(t *testing.T) { + defaultKey := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") + someErr := errors.New("error") + testCases := map[string]struct { + state state.State + dataKey []byte + getDataKeyErr error + pushKeyErr error + errExpected bool + }{ + "success": { + state: state.ActivatingNodes, + dataKey: defaultKey, + }, + "Coordinator in wrong state": { + state: state.IsNode, + dataKey: defaultKey, + errExpected: true, + }, + "GetDataKey fails": { + state: state.ActivatingNodes, + dataKey: defaultKey, + getDataKeyErr: someErr, + errExpected: true, + }, + "key pushing fails": { + state: state.ActivatingNodes, + dataKey: defaultKey, + pushKeyErr: someErr, + errExpected: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + issuer := core.NewMockIssuer() + + stateDiskServer := &stubStateDiskServer{pushKeyErr: tc.pushKeyErr} + + // we can not use a bufconn here, since we rely on grpcpeer.FromContext() to connect to the caller + listener, err := net.Listen("tcp", ":") + require.NoError(err) + defer listener.Close() + + tlsConfig, err := atls.CreateAttestationServerTLSConfig(issuer) + require.NoError(err) + s := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) + keyproto.RegisterAPIServer(s, stateDiskServer) + defer s.GracefulStop() + go s.Serve(listener) + + ctx := grpcpeer.NewContext(context.Background(), &grpcpeer.Peer{Addr: listener.Addr()}) + getPeerFromContext := func(ctx context.Context) (string, error) { + peer, ok := grpcpeer.FromContext(ctx) + if !ok { + return "", errors.New("unable to get peer from context") + } + return peer.Addr.String(), nil + } + + core := &fakeCore{ + state: tc.state, + dataKey: tc.dataKey, + getDataKeyErr: tc.getDataKeyErr, + } + api := New(zaptest.NewLogger(t), core, &net.Dialer{}, nil, dummyValidator{}, nil, getPeerFromContext) + + _, err = api.RequestStateDiskKey(ctx, &pubproto.RequestStateDiskKeyRequest{}) + if tc.errExpected { + assert.Error(err) + } else { + assert.NoError(err) + assert.Equal(tc.dataKey, stateDiskServer.receivedRequest.StateDiskKey) + } + }) + } +} + +type dummyValidator struct { + oid.Dummy +} + +func (d dummyValidator) Validate(attdoc []byte, nonce []byte) ([]byte, error) { + var attestation vtpm.AttestationDocument + if err := json.Unmarshal(attdoc, &attestation); err != nil { + return nil, err + } + return attestation.UserData, nil +} + +type stubStateDiskServer struct { + receivedRequest *keyproto.PushStateDiskKeyRequest + pushKeyErr error + keyproto.UnimplementedAPIServer +} + +func (s *stubStateDiskServer) PushStateDiskKey(ctx context.Context, in *keyproto.PushStateDiskKeyRequest) (*keyproto.PushStateDiskKeyResponse, error) { + s.receivedRequest = in + return &keyproto.PushStateDiskKeyResponse{}, s.pushKeyErr +} diff --git a/coordinator/pubapi/multicoord_test.go b/coordinator/pubapi/multicoord_test.go index f8d832228..d0fd1f3c0 100644 --- a/coordinator/pubapi/multicoord_test.go +++ b/coordinator/pubapi/multicoord_test.go @@ -66,7 +66,7 @@ func TestActivateAsCoordinators(t *testing.T) { return "192.0.2.1", nil } - api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr) + api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil) defer api.Close() // spawn coordinator @@ -133,7 +133,7 @@ func TestTriggerCoordinatorUpdate(t *testing.T) { } dialer := testdialer.NewBufconnDialer() - api := New(logger, core, dialer, nil, nil, nil) + api := New(logger, core, dialer, nil, nil, nil, nil) _, err := api.TriggerCoordinatorUpdate(context.Background(), &pubproto.TriggerCoordinatorUpdateRequest{}) if tc.expectErr { @@ -202,7 +202,7 @@ func TestActivateAdditionalCoordinators(t *testing.T) { return "192.0.2.1", nil } - api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr) + api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil) defer api.Close() // spawn coordinator diff --git a/coordinator/pubapi/node_test.go b/coordinator/pubapi/node_test.go index fae999303..2e7be15e1 100644 --- a/coordinator/pubapi/node_test.go +++ b/coordinator/pubapi/node_test.go @@ -129,7 +129,7 @@ func TestActivateAsNode(t *testing.T) { cor := &fakeCore{state: tc.state, vpnPubKey: vpnPubKey, setVPNIPErr: tc.setVPNIPErr} dialer := testdialer.NewBufconnDialer() - api := New(logger, cor, dialer, nil, nil, nil) + api := New(logger, cor, dialer, nil, nil, nil, nil) defer api.Close() vserver := grpc.NewServer() @@ -217,7 +217,7 @@ func TestTriggerNodeUpdate(t *testing.T) { core := &fakeCore{state: tc.state} dialer := testdialer.NewBufconnDialer() - api := New(logger, core, dialer, nil, nil, nil) + api := New(logger, core, dialer, nil, nil, nil, nil) vserver := grpc.NewServer() vapi := &stubVPNAPI{ @@ -292,7 +292,7 @@ func TestJoinCluster(t *testing.T) { core := &fakeCore{state: tc.state, joinClusterErr: tc.joinClusterErr} dialer := testdialer.NewBufconnDialer() - api := New(logger, core, dialer, nil, nil, nil) + api := New(logger, core, dialer, nil, nil, nil, nil) vserver := grpc.NewServer() vapi := &stubVPNAPI{ diff --git a/coordinator/pubapi/pubapi.go b/coordinator/pubapi/pubapi.go index 94192b41c..4ab4180d3 100644 --- a/coordinator/pubapi/pubapi.go +++ b/coordinator/pubapi/pubapi.go @@ -3,16 +3,19 @@ package pubapi import ( "context" + "errors" "net" "sync" "time" "github.com/edgelesssys/constellation/coordinator/atls" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" + "github.com/edgelesssys/constellation/state/setup" "go.uber.org/zap" "google.golang.org/grpc" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/peer" ) const ( @@ -34,11 +37,12 @@ type API struct { stopUpdate chan struct{} wgClose sync.WaitGroup resourceVersion int + peerFromContext PeerFromContextFunc pubproto.UnimplementedAPIServer } // New creates a new API. -func New(logger *zap.Logger, core Core, dialer Dialer, vpnAPIServer VPNAPIServer, validator atls.Validator, getPublicIPAddr GetIPAddrFunc) *API { +func New(logger *zap.Logger, core Core, dialer Dialer, vpnAPIServer VPNAPIServer, validator atls.Validator, getPublicIPAddr GetIPAddrFunc, peerFromContext PeerFromContextFunc) *API { return &API{ logger: logger, core: core, @@ -47,6 +51,7 @@ func New(logger *zap.Logger, core Core, dialer Dialer, vpnAPIServer VPNAPIServer validator: validator, getPublicIPAddr: getPublicIPAddr, stopUpdate: make(chan struct{}, 1), + peerFromContext: peerFromContext, } } @@ -112,3 +117,21 @@ type VPNAPIServer interface { } type GetIPAddrFunc func() (string, error) + +// PeerFromContextFunc returns a peer endpoint (IP:port) from a given context. +type PeerFromContextFunc func(context.Context) (string, error) + +// GetRecoveryPeerFromContext returns the context's IP joined with the Coordinator's default port. +func GetRecoveryPeerFromContext(ctx context.Context) (string, error) { + peer, ok := peer.FromContext(ctx) + if !ok { + return "", errors.New("unable to get peer from context") + } + + peerIP, _, err := net.SplitHostPort(peer.Addr.String()) + if err != nil { + return "", err + } + + return net.JoinHostPort(peerIP, setup.RecoveryPort), nil +} diff --git a/coordinator/pubapi/pubapi_test.go b/coordinator/pubapi/pubapi_test.go index 5ec2a846f..c1df99893 100644 --- a/coordinator/pubapi/pubapi_test.go +++ b/coordinator/pubapi/pubapi_test.go @@ -1,9 +1,13 @@ package pubapi import ( + "context" + "net" "testing" + "github.com/stretchr/testify/assert" "go.uber.org/goleak" + grpcpeer "google.golang.org/grpc/peer" ) func TestMain(m *testing.M) { @@ -14,3 +18,20 @@ func TestMain(m *testing.M) { goleak.IgnoreTopFunction("k8s.io/klog/v2.(*loggingT).flushDaemon"), ) } + +func TestGetRecoveryPeerFromContext(t *testing.T) { + assert := assert.New(t) + testIP := "192.0.2.1" + testPort := 1234 + expectedPeer := net.JoinHostPort(testIP, "9000") + + addr := &net.TCPAddr{IP: net.ParseIP(testIP), Port: testPort} + ctx := grpcpeer.NewContext(context.Background(), &grpcpeer.Peer{Addr: addr}) + + peer, err := GetRecoveryPeerFromContext(ctx) + assert.NoError(err) + assert.Equal(expectedPeer, peer) + + _, err = GetRecoveryPeerFromContext(context.Background()) + assert.Error(err) +}