From 77b0237dd5dedf36ab277dff05fea6c3066c23f4 Mon Sep 17 00:00:00 2001 From: Malte Poll Date: Thu, 28 Apr 2022 09:49:15 +0200 Subject: [PATCH] extract shared grpcutil dialer from pubapi Signed-off-by: Malte Poll --- coordinator/cmd/coordinator/main.go | 6 +- coordinator/cmd/coordinator/run.go | 11 ++- coordinator/coordinator_test.go | 18 ++-- coordinator/core/legacy_test.go | 9 +- coordinator/pubapi/coord.go | 8 +- coordinator/pubapi/coord_test.go | 20 ++-- coordinator/pubapi/multicoord.go | 6 +- coordinator/pubapi/multicoord_test.go | 23 +++-- coordinator/pubapi/node.go | 4 +- coordinator/pubapi/node_test.go | 34 ++++--- coordinator/pubapi/pubapi.go | 55 ++--------- coordinator/util/grpcutil/dialer.go | 71 ++++++++++++++ coordinator/util/grpcutil/dialer_test.go | 106 +++++++++++++++++++++ debugd/debugd/deploy/download.go | 7 +- debugd/debugd/deploy/download_test.go | 2 +- debugd/debugd/server/server_test.go | 6 +- debugd/service/testdialer/bufconndialer.go | 41 -------- 17 files changed, 275 insertions(+), 152 deletions(-) create mode 100644 coordinator/util/grpcutil/dialer.go create mode 100644 coordinator/util/grpcutil/dialer_test.go delete mode 100644 debugd/service/testdialer/bufconndialer.go diff --git a/coordinator/cmd/coordinator/main.go b/coordinator/cmd/coordinator/main.go index 28fe1fefd..80a6fa4ce 100644 --- a/coordinator/cmd/coordinator/main.go +++ b/coordinator/cmd/coordinator/main.go @@ -23,6 +23,7 @@ import ( "github.com/edgelesssys/constellation/coordinator/kubernetes/k8sapi" "github.com/edgelesssys/constellation/coordinator/kubernetes/k8sapi/kubectl" "github.com/edgelesssys/constellation/coordinator/util" + "github.com/edgelesssys/constellation/coordinator/util/grpcutil" "github.com/edgelesssys/constellation/coordinator/wireguard" grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap" "github.com/spf13/afero" @@ -153,7 +154,8 @@ func main() { } fileHandler := file.NewHandler(fs) - dialer := &net.Dialer{} - run(validator, issuer, wg, openTPM, util.GetIPAddr, dialer, fileHandler, kube, + netDialer := &net.Dialer{} + dialer := grpcutil.NewDialer(validator, netDialer) + run(issuer, wg, openTPM, util.GetIPAddr, dialer, fileHandler, kube, metadata, cloudControllerManager, cloudNodeManager, autoscaler, encryptedDisk, etcdEndpoint, enforceEtcdTls, bindIP, bindPort, zapLoggerCore) } diff --git a/coordinator/cmd/coordinator/run.go b/coordinator/cmd/coordinator/run.go index f4c084397..6d22b6d3a 100644 --- a/coordinator/cmd/coordinator/run.go +++ b/coordinator/cmd/coordinator/run.go @@ -14,6 +14,7 @@ import ( "github.com/edgelesssys/constellation/coordinator/pubapi" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/store" + "github.com/edgelesssys/constellation/coordinator/util/grpcutil" "github.com/edgelesssys/constellation/coordinator/vpnapi" "github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" @@ -26,7 +27,7 @@ import ( var version = "0.0.0" -func run(validator core.QuoteValidator, issuer core.QuoteIssuer, vpn core.VPN, openTPM vtpm.TPMOpenFunc, getPublicIPAddr func() (string, error), dialer pubapi.Dialer, fileHandler file.Handler, +func run(issuer core.QuoteIssuer, vpn core.VPN, openTPM vtpm.TPMOpenFunc, getPublicIPAddr func() (string, error), dialer *grpcutil.Dialer, fileHandler file.Handler, kube core.Cluster, metadata core.ProviderMetadata, cloudControllerManager core.CloudControllerManager, cloudNodeManager core.CloudNodeManager, clusterAutoscaler core.ClusterAutoscaler, encryptedDisk core.EncryptedDisk, etcdEndpoint string, etcdTLS bool, bindIP, bindPort string, zapLoggerCore *zap.Logger, ) { defer zapLoggerCore.Sync() @@ -46,16 +47,16 @@ func run(validator core.QuoteValidator, issuer core.QuoteIssuer, vpn core.VPN, o if err != nil { zapLoggerCore.Fatal("failed to create core", zap.Error(err)) } + + vapiServer := &vpnAPIServer{logger: zapLoggerCore.Named("vpnapi"), core: core} + zapLoggerPubapi := zapLoggerCore.Named("pubapi") + papi := pubapi.New(zapLoggerPubapi, core, dialer, vapiServer, getPublicIPAddr, pubapi.GetRecoveryPeerFromContext) // initialize state machine and wait for re-joining of the VPN (if applicable) nodeActivated, err := core.Initialize() if err != nil { zapLoggerCore.Fatal("failed to initialize core", zap.Error(err)) } - vapiServer := &vpnAPIServer{logger: zapLoggerCore.Named("vpnapi"), core: core} - zapLoggerPubapi := zapLoggerCore.Named("pubapi") - papi := pubapi.New(zapLoggerPubapi, core, dialer, vapiServer, validator, getPublicIPAddr, pubapi.GetRecoveryPeerFromContext) - zapLoggergRPC := zapLoggerPubapi.Named("gRPC") grpcServer := grpc.NewServer( diff --git a/coordinator/coordinator_test.go b/coordinator/coordinator_test.go index 5b4405923..b0d58bfe6 100644 --- a/coordinator/coordinator_test.go +++ b/coordinator/coordinator_test.go @@ -18,6 +18,7 @@ import ( "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/coordinator/store" + "github.com/edgelesssys/constellation/coordinator/util/grpcutil" "github.com/edgelesssys/constellation/coordinator/util/testdialer" "github.com/edgelesssys/constellation/coordinator/vpnapi" "github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto" @@ -210,7 +211,7 @@ func TestConcurrent(t *testing.T) { assert.Error(<-actCoordErrs) } -func spawnPeer(require *require.Assertions, logger *zap.Logger, dialer *testdialer.BufconnDialer, netw *network, endpoint string) (*grpc.Server, *pubapi.API, *fakeVPN) { +func spawnPeer(require *require.Assertions, logger *zap.Logger, netDialer *testdialer.BufconnDialer, netw *network, endpoint string) (*grpc.Server, *pubapi.API, *fakeVPN) { vpn := newVPN(netw, endpoint) cor, err := core.NewCore(vpn, &core.ClusterFake{}, &core.ProviderMetadataFake{}, &core.CloudControllerManagerFake{}, &core.CloudNodeManagerFake{}, &core.ClusterAutoscalerFake{}, &core.EncryptedDiskFake{}, logger, simulator.OpenSimulatedTPM, fakeStoreFactory{}, file.NewHandler(afero.NewMemMapFs())) require.NoError(err) @@ -219,22 +220,23 @@ func spawnPeer(require *require.Assertions, logger *zap.Logger, dialer *testdial getPublicAddr := func() (string, error) { return "192.0.2.1", nil } + dialer := grpcutil.NewDialer(&core.MockValidator{}, netDialer) + vapiServer := &fakeVPNAPIServer{logger: logger.Named("vpnapi"), core: cor, dialer: netDialer} - vapiServer := &fakeVPNAPIServer{logger: logger.Named("vpnapi"), core: cor, dialer: dialer} - papi := pubapi.New(logger, cor, dialer, vapiServer, &core.MockValidator{}, getPublicAddr, nil) + papi := pubapi.New(logger, cor, dialer, vapiServer, getPublicAddr, nil) tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{}) require.NoError(err) server := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) pubproto.RegisterAPIServer(server, papi) - listener := dialer.GetListener(endpoint) + listener := netDialer.GetListener(endpoint) go server.Serve(listener) return server, papi, vpn } -func activateCoordinator(require *require.Assertions, dialer pubapi.Dialer, coordinatorIP, bindPort string, nodeIPs []string) error { +func activateCoordinator(require *require.Assertions, dialer netDialer, coordinatorIP, bindPort string, nodeIPs []string) error { ctx := context.Background() conn, err := dialGRPC(ctx, dialer, net.JoinHostPort(coordinatorIP, bindPort)) require.NoError(err) @@ -260,7 +262,7 @@ func activateCoordinator(require *require.Assertions, dialer pubapi.Dialer, coor } } -func dialGRPC(ctx context.Context, dialer pubapi.Dialer, target string) (*grpc.ClientConn, error) { +func dialGRPC(ctx context.Context, dialer netDialer, target string) (*grpc.ClientConn, error) { tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{&core.MockValidator{}}) if err != nil { return nil, err @@ -398,3 +400,7 @@ func (v *fakeVPN) recv() *packet { } return &packet } + +type netDialer interface { + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} diff --git a/coordinator/core/legacy_test.go b/coordinator/core/legacy_test.go index 9c6c2603e..fde2eeb78 100644 --- a/coordinator/core/legacy_test.go +++ b/coordinator/core/legacy_test.go @@ -16,6 +16,7 @@ import ( "github.com/edgelesssys/constellation/coordinator/pubapi" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/state" + "github.com/edgelesssys/constellation/coordinator/util/grpcutil" "github.com/edgelesssys/constellation/coordinator/vpnapi" "github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto" "github.com/spf13/afero" @@ -112,13 +113,13 @@ func TestLegacyActivateCoordinator(t *testing.T) { } // newMockCoreWithDialer creates a new core object with attestation mock and provided dialer for testing. -func newMockCoreWithDialer(dialer *bufconnDialer) (*Core, *pubapi.API, error) { +func newMockCoreWithDialer(bufDialer *bufconnDialer) (*Core, *pubapi.API, error) { zapLogger, err := zap.NewDevelopment() if err != nil { return nil, nil, err } - validator := NewMockValidator() + dialer := grpcutil.NewDialer(NewMockValidator(), bufDialer) vpn := &stubVPN{} kubeFake := &ClusterFake{} metadataFake := &ProviderMetadataFake{} @@ -138,8 +139,8 @@ func newMockCoreWithDialer(dialer *bufconnDialer) (*Core, *pubapi.API, error) { return nil, nil, err } - vapiServer := &fakeVPNAPIServer{logger: zapLogger, core: core, dialer: dialer} - papi := pubapi.New(zapLogger, core, dialer, vapiServer, validator, getPublicAddr, nil) + vapiServer := &fakeVPNAPIServer{logger: zapLogger, core: core, dialer: bufDialer} + papi := pubapi.New(zapLogger, core, dialer, vapiServer, getPublicAddr, nil) return core, papi, nil } diff --git a/coordinator/pubapi/coord.go b/coordinator/pubapi/coord.go index b54d80522..9348f8316 100644 --- a/coordinator/pubapi/coord.go +++ b/coordinator/pubapi/coord.go @@ -204,7 +204,7 @@ func (a *API) RequestStateDiskKey(ctx context.Context, in *pubproto.RequestState return nil, status.Errorf(codes.Internal, "%v", err) } - conn, err := a.dial(ctx, peer) + conn, err := a.dialer.Dial(ctx, peer) if err != nil { return nil, status.Errorf(codes.Internal, "%v", err) } @@ -297,7 +297,7 @@ func (a *API) activateNode(nodePublicIP string, nodeVPNIP string, initialPeers [ ctx, cancel := context.WithTimeout(context.Background(), deadlineDuration) defer cancel() - conn, err := a.dial(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort)) + conn, err := a.dialer.Dial(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort)) if err != nil { return nil, err } @@ -417,7 +417,7 @@ func (a *API) joinCluster(nodePublicIP string) error { } // We don't verify the peer certificate here, since JoinCluster triggers a connection over VPN // The target of the rpc needs to already be part of the VPN to process the request, meaning it is trusted - conn, err := a.dialNoVerify(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort)) + conn, err := a.dialer.DialNoVerify(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort)) if err != nil { return err } @@ -455,7 +455,7 @@ func (a *API) triggerNodeUpdate(nodePublicIP string) error { // We don't verify the peer certificate here, since TriggerNodeUpdate triggers a connection over VPN // The target of the rpc needs to already be part of the VPN to process the request, meaning it is trusted - conn, err := a.dialNoVerify(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort)) + conn, err := a.dialer.DialNoVerify(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort)) if err != nil { return err } diff --git a/coordinator/pubapi/coord_test.go b/coordinator/pubapi/coord_test.go index 5a8fe179a..b2cadb20b 100644 --- a/coordinator/pubapi/coord_test.go +++ b/coordinator/pubapi/coord_test.go @@ -18,6 +18,7 @@ import ( "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/role" "github.com/edgelesssys/constellation/coordinator/state" + "github.com/edgelesssys/constellation/coordinator/util/grpcutil" "github.com/edgelesssys/constellation/coordinator/util/testdialer" "github.com/edgelesssys/constellation/state/keyservice/keyproto" "github.com/stretchr/testify/assert" @@ -125,13 +126,14 @@ func TestActivateAsCoordinator(t *testing.T) { ownerID: []byte("ownerID"), clusterID: []byte("clusterID"), } - dialer := testdialer.NewBufconnDialer() + netDialer := testdialer.NewBufconnDialer() + dialer := grpcutil.NewDialer(fakeValidator{}, netDialer) getPublicIPAddr := func() (string, error) { return "192.0.2.1", nil } - api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil) + api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, getPublicIPAddr, nil) defer api.Close() // spawn nodes @@ -142,7 +144,7 @@ func TestActivateAsCoordinator(t *testing.T) { server := n.newServer() wg.Add(1) go func(endpoint string) { - listener := dialer.GetListener(endpoint) + listener := netDialer.GetListener(endpoint) wg.Done() _ = server.Serve(listener) }(net.JoinHostPort(n.peer.PublicIP, endpointAVPNPort)) @@ -256,13 +258,14 @@ func TestActivateAdditionalNodes(t *testing.T) { require := require.New(t) core := &fakeCore{state: tc.state} - dialer := testdialer.NewBufconnDialer() + netDialer := testdialer.NewBufconnDialer() + dialer := grpcutil.NewDialer(fakeValidator{}, netDialer) getPublicIPAddr := func() (string, error) { return "192.0.2.1", nil } - api := New(zaptest.NewLogger(t), core, dialer, nil, fakeValidator{}, getPublicIPAddr, nil) + api := New(zaptest.NewLogger(t), core, dialer, nil, getPublicIPAddr, nil) defer api.Close() // spawn nodes var nodePublicIPs []string @@ -272,7 +275,7 @@ func TestActivateAdditionalNodes(t *testing.T) { server := n.newServer() wg.Add(1) go func(endpoint string) { - listener := dialer.GetListener(endpoint) + listener := netDialer.GetListener(endpoint) wg.Done() _ = server.Serve(listener) }(net.JoinHostPort(n.peer.PublicIP, endpointAVPNPort)) @@ -311,7 +314,7 @@ func TestAssemblePeerStruct(t *testing.T) { vpnPubKey := []byte{2, 3, 4} core := &fakeCore{vpnPubKey: vpnPubKey} - api := New(zaptest.NewLogger(t), core, nil, nil, nil, getPublicIPAddr, nil) + api := New(zaptest.NewLogger(t), core, nil, nil, getPublicIPAddr, nil) defer api.Close() vpnIP, err := core.GetVPNIP() @@ -512,7 +515,8 @@ func TestRequestStateDiskKey(t *testing.T) { dataKey: tc.dataKey, getDataKeyErr: tc.getDataKeyErr, } - api := New(zaptest.NewLogger(t), core, &net.Dialer{}, nil, dummyValidator{}, nil, getPeerFromContext) + + api := New(zaptest.NewLogger(t), core, grpcutil.NewDialer(dummyValidator{}, &net.Dialer{}), nil, nil, getPeerFromContext) _, err = api.RequestStateDiskKey(ctx, &pubproto.RequestStateDiskKeyRequest{}) if tc.wantErr { diff --git a/coordinator/pubapi/multicoord.go b/coordinator/pubapi/multicoord.go index 804118660..b2aa61f19 100644 --- a/coordinator/pubapi/multicoord.go +++ b/coordinator/pubapi/multicoord.go @@ -199,7 +199,7 @@ func (a *API) activateCoordinator(ctx context.Context, coordinatorIP string) err return err } - conn, err := a.dial(ctx, net.JoinHostPort(coordinatorIP, endpointAVPNPort)) + conn, err := a.dialer.Dial(ctx, net.JoinHostPort(coordinatorIP, endpointAVPNPort)) if err != nil { return fmt.Errorf("dialing new coordinator: %v", err) } @@ -271,7 +271,7 @@ func (a *API) triggerCoordinatorUpdate(ctx context.Context, publicIP string) err // We don't verify the peer certificate here, since TriggerNodeUpdate triggers a connection over VPN // The target of the rpc needs to already be part of the VPN to process the request, meaning it is trusted - conn, err := a.dialNoVerify(ctx, net.JoinHostPort(publicIP, endpointAVPNPort)) + conn, err := a.dialer.DialNoVerify(ctx, net.JoinHostPort(publicIP, endpointAVPNPort)) if err != nil { return err } @@ -284,7 +284,7 @@ func (a *API) triggerCoordinatorUpdate(ctx context.Context, publicIP string) err } func (a *API) getk8SCoordinatorJoinArgs(ctx context.Context, coordinatorIP, port string) (*kubeadm.BootstrapTokenDiscovery, string, error) { - conn, err := a.dialInsecure(ctx, net.JoinHostPort(coordinatorIP, port)) + conn, err := a.dialer.DialInsecure(ctx, net.JoinHostPort(coordinatorIP, port)) if err != nil { return nil, "", err } diff --git a/coordinator/pubapi/multicoord_test.go b/coordinator/pubapi/multicoord_test.go index bacfb652f..930aa3137 100644 --- a/coordinator/pubapi/multicoord_test.go +++ b/coordinator/pubapi/multicoord_test.go @@ -10,6 +10,7 @@ import ( "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/role" "github.com/edgelesssys/constellation/coordinator/state" + "github.com/edgelesssys/constellation/coordinator/util/grpcutil" "github.com/edgelesssys/constellation/coordinator/util/testdialer" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -92,18 +93,19 @@ func TestActivateAsAdditionalCoordinator(t *testing.T) { ownerID: []byte("ownerID"), clusterID: []byte("clusterID"), } - dialer := testdialer.NewBufconnDialer() + netDialer := testdialer.NewBufconnDialer() + dialer := grpcutil.NewDialer(fakeValidator{}, netDialer) getPublicIPAddr := func() (string, error) { return "192.0.2.1", nil } - api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil) + api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, getPublicIPAddr, nil) defer api.Close() // spawn vpnServer vpnapiServer := tc.vpnapi.newServer() - go vpnapiServer.Serve(dialer.GetListener(net.JoinHostPort(tc.coordinators.peer.VPNIP, vpnAPIPort))) + go vpnapiServer.Serve(netDialer.GetListener(net.JoinHostPort(tc.coordinators.peer.VPNIP, vpnAPIPort))) defer vpnapiServer.GracefulStop() _, err := api.ActivateAsAdditionalCoordinator(context.Background(), &pubproto.ActivateAsAdditionalCoordinatorRequest{ @@ -163,9 +165,9 @@ func TestTriggerCoordinatorUpdate(t *testing.T) { state: tc.state, peers: tc.peers, } - dialer := testdialer.NewBufconnDialer() + dialer := grpcutil.NewDialer(fakeValidator{}, nil) - api := New(logger, core, dialer, nil, nil, nil, nil) + api := New(logger, core, dialer, nil, nil, nil) _, err := api.TriggerCoordinatorUpdate(context.Background(), &pubproto.TriggerCoordinatorUpdateRequest{}) if tc.wantErr { @@ -236,20 +238,21 @@ func TestActivateAdditionalCoordinators(t *testing.T) { ownerID: []byte("ownerID"), clusterID: []byte("clusterID"), } - dialer := testdialer.NewBufconnDialer() + netDialer := testdialer.NewBufconnDialer() + dialer := grpcutil.NewDialer(fakeValidator{}, netDialer) getPublicIPAddr := func() (string, error) { return "192.0.2.1", nil } - api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil) + api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, getPublicIPAddr, nil) defer api.Close() // spawn coordinator tc.coordinators.activateErr = tc.activateErr tc.coordinators.getPubKeyErr = tc.getPublicKeyErr server := tc.coordinators.newServer() - go server.Serve(dialer.GetListener(net.JoinHostPort(tc.coordinators.peer.PublicIP, endpointAVPNPort))) + go server.Serve(netDialer.GetListener(net.JoinHostPort(tc.coordinators.peer.PublicIP, endpointAVPNPort))) defer server.GracefulStop() _, err := api.ActivateAdditionalCoordinator(context.Background(), &pubproto.ActivateAdditionalCoordinatorRequest{CoordinatorPublicIp: tc.coordinators.peer.PublicIP}) @@ -293,13 +296,13 @@ func TestGetPeerVPNPublicKey(t *testing.T) { vpnPubKey: tc.coordinator.peer.VPNPubKey, getvpnPubKeyErr: tc.getVPNPubKeyErr, } - dialer := testdialer.NewBufconnDialer() + dialer := grpcutil.NewDialer(fakeValidator{}, testdialer.NewBufconnDialer()) getPublicIPAddr := func() (string, error) { return "192.0.2.1", nil } - api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil) + api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, getPublicIPAddr, nil) defer api.Close() resp, err := api.GetPeerVPNPublicKey(context.Background(), &pubproto.GetPeerVPNPublicKeyRequest{}) diff --git a/coordinator/pubapi/node.go b/coordinator/pubapi/node.go index f7305cbde..2f14c119f 100644 --- a/coordinator/pubapi/node.go +++ b/coordinator/pubapi/node.go @@ -168,7 +168,7 @@ func (a *API) JoinCluster(ctx context.Context, in *pubproto.JoinClusterRequest) return nil, status.Errorf(codes.FailedPrecondition, "node is not in required state for cluster join: %v", err) } - conn, err := a.dialInsecure(ctx, net.JoinHostPort(in.CoordinatorVpnIp, vpnAPIPort)) + conn, err := a.dialer.DialInsecure(ctx, net.JoinHostPort(in.CoordinatorVpnIp, vpnAPIPort)) if err != nil { return nil, status.Errorf(codes.Unavailable, "dial coordinator: %v", err) } @@ -231,7 +231,7 @@ func (a *API) update(ctx context.Context) error { defer cancel() // TODO: replace hardcoded IP - conn, err := a.dialInsecure(ctx, net.JoinHostPort("10.118.0.1", vpnAPIPort)) + conn, err := a.dialer.DialInsecure(ctx, net.JoinHostPort("10.118.0.1", vpnAPIPort)) if err != nil { return err } diff --git a/coordinator/pubapi/node_test.go b/coordinator/pubapi/node_test.go index 381c43d72..26e394cc3 100644 --- a/coordinator/pubapi/node_test.go +++ b/coordinator/pubapi/node_test.go @@ -13,6 +13,7 @@ import ( "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/role" "github.com/edgelesssys/constellation/coordinator/state" + "github.com/edgelesssys/constellation/coordinator/util/grpcutil" "github.com/edgelesssys/constellation/coordinator/util/testdialer" "github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto" "github.com/stretchr/testify/assert" @@ -127,25 +128,26 @@ func TestActivateAsNode(t *testing.T) { logger := zaptest.NewLogger(t) cor := &fakeCore{state: tc.state, vpnPubKey: vpnPubKey, setVPNIPErr: tc.setVPNIPErr} - dialer := testdialer.NewBufconnDialer() + netDialer := testdialer.NewBufconnDialer() + dialer := grpcutil.NewDialer(fakeValidator{}, netDialer) - api := New(logger, cor, dialer, nil, nil, nil, nil) + api := New(logger, cor, dialer, nil, nil, nil) defer api.Close() vserver := grpc.NewServer() vapi := &stubVPNAPI{peers: tc.updatedPeers, getUpdateErr: tc.getUpdateErr} vpnproto.RegisterAPIServer(vserver, vapi) - go vserver.Serve(dialer.GetListener(net.JoinHostPort("10.118.0.1", vpnAPIPort))) + go vserver.Serve(netDialer.GetListener(net.JoinHostPort("10.118.0.1", vpnAPIPort))) defer vserver.GracefulStop() tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{}) require.NoError(err) pubserver := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) pubproto.RegisterAPIServer(pubserver, api) - go pubserver.Serve(dialer.GetListener(net.JoinHostPort(nodeIP, endpointAVPNPort))) + go pubserver.Serve(netDialer.GetListener(net.JoinHostPort(nodeIP, endpointAVPNPort))) defer pubserver.GracefulStop() - _, nodeVPNPubKey, err := activateNode(require, dialer, messageSequence, nodeIP, "9000", nodeVPNIP, peer.ToPubProto(tc.initialPeers), ownerID, clusterID, stateDiskKey) + _, nodeVPNPubKey, err := activateNode(require, netDialer, messageSequence, nodeIP, "9000", nodeVPNIP, peer.ToPubProto(tc.initialPeers), ownerID, clusterID, stateDiskKey) assert.Equal(tc.wantState, cor.state) if tc.wantErr { @@ -215,9 +217,10 @@ func TestTriggerNodeUpdate(t *testing.T) { logger := zaptest.NewLogger(t) core := &fakeCore{state: tc.state} - dialer := testdialer.NewBufconnDialer() + netDialer := testdialer.NewBufconnDialer() + dialer := grpcutil.NewDialer(fakeValidator{}, netDialer) - api := New(logger, core, dialer, nil, nil, nil, nil) + api := New(logger, core, dialer, nil, nil, nil) vserver := grpc.NewServer() vapi := &stubVPNAPI{ @@ -225,7 +228,7 @@ func TestTriggerNodeUpdate(t *testing.T) { getUpdateErr: tc.getUpdateErr, } vpnproto.RegisterAPIServer(vserver, vapi) - go vserver.Serve(dialer.GetListener(net.JoinHostPort("10.118.0.1", vpnAPIPort))) + go vserver.Serve(netDialer.GetListener(net.JoinHostPort("10.118.0.1", vpnAPIPort))) defer vserver.GracefulStop() _, err := api.TriggerNodeUpdate(context.Background(), &pubproto.TriggerNodeUpdateRequest{}) @@ -290,9 +293,10 @@ func TestJoinCluster(t *testing.T) { logger := zaptest.NewLogger(t) core := &fakeCore{state: tc.state, joinClusterErr: tc.joinClusterErr} - dialer := testdialer.NewBufconnDialer() + netDialer := testdialer.NewBufconnDialer() + dialer := grpcutil.NewDialer(fakeValidator{}, netDialer) - api := New(logger, core, dialer, nil, nil, nil, nil) + api := New(logger, core, dialer, nil, nil, nil) vserver := grpc.NewServer() vapi := &stubVPNAPI{ @@ -304,7 +308,7 @@ func TestJoinCluster(t *testing.T) { getJoinArgsErr: tc.getJoinArgsErr, } vpnproto.RegisterAPIServer(vserver, vapi) - go vserver.Serve(dialer.GetListener(net.JoinHostPort("192.0.2.1", vpnAPIPort))) + go vserver.Serve(netDialer.GetListener(net.JoinHostPort("192.0.2.1", vpnAPIPort))) defer vserver.GracefulStop() _, err := api.JoinCluster(context.Background(), &pubproto.JoinClusterRequest{CoordinatorVpnIp: "192.0.2.1"}) @@ -322,7 +326,7 @@ func TestJoinCluster(t *testing.T) { } } -func activateNode(require *require.Assertions, dialer Dialer, messageSequence []string, nodeIP, bindPort, nodeVPNIP string, peers []*pubproto.Peer, ownerID, clusterID, stateDiskKey []byte) (string, []byte, error) { +func activateNode(require *require.Assertions, dialer netDialer, messageSequence []string, nodeIP, bindPort, nodeVPNIP string, peers []*pubproto.Peer, ownerID, clusterID, stateDiskKey []byte) (string, []byte, error) { ctx := context.Background() conn, err := dialGRPC(ctx, dialer, net.JoinHostPort(nodeIP, bindPort)) require.NoError(err) @@ -385,7 +389,7 @@ func activateNode(require *require.Assertions, dialer Dialer, messageSequence [] return diskUUID, nodeVPNPubKey, nil } -func dialGRPC(ctx context.Context, dialer Dialer, target string) (*grpc.ClientConn, error) { +func dialGRPC(ctx context.Context, dialer netDialer, target string) (*grpc.ClientConn, error) { tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{&core.MockValidator{}}) if err != nil { return nil, err @@ -429,3 +433,7 @@ func (a *stubVPNAPI) newServer() *grpc.Server { vpnproto.RegisterAPIServer(server, a) return server } + +type netDialer interface { + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} diff --git a/coordinator/pubapi/pubapi.go b/coordinator/pubapi/pubapi.go index ac3c8fb5a..7f5c73743 100644 --- a/coordinator/pubapi/pubapi.go +++ b/coordinator/pubapi/pubapi.go @@ -8,13 +8,10 @@ import ( "sync" "time" - "github.com/edgelesssys/constellation/coordinator/atls" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/state/setup" "go.uber.org/zap" "google.golang.org/grpc" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/peer" ) @@ -32,7 +29,6 @@ type API struct { core Core dialer Dialer vpnAPIServer VPNAPIServer - validator atls.Validator getPublicIPAddr GetIPAddrFunc stopUpdate chan struct{} wgClose sync.WaitGroup @@ -42,13 +38,12 @@ type API struct { } // New creates a new API. -func New(logger *zap.Logger, core Core, dialer Dialer, vpnAPIServer VPNAPIServer, validator atls.Validator, getPublicIPAddr GetIPAddrFunc, peerFromContext PeerFromContextFunc) *API { +func New(logger *zap.Logger, core Core, dialer Dialer, vpnAPIServer VPNAPIServer, getPublicIPAddr GetIPAddrFunc, peerFromContext PeerFromContextFunc) *API { return &API{ logger: logger, core: core, dialer: dialer, vpnAPIServer: vpnAPIServer, - validator: validator, getPublicIPAddr: getPublicIPAddr, stopUpdate: make(chan struct{}, 1), peerFromContext: peerFromContext, @@ -69,47 +64,6 @@ func (a *API) Close() { a.wgClose.Wait() } -func (a *API) dial(ctx context.Context, target string) (*grpc.ClientConn, error) { - tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{a.validator}) - if err != nil { - return nil, err - } - - return grpc.DialContext(ctx, target, - a.grpcWithDialer(), - grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)), - ) -} - -func (a *API) dialInsecure(ctx context.Context, target string) (*grpc.ClientConn, error) { - return grpc.DialContext(ctx, target, - a.grpcWithDialer(), - grpc.WithTransportCredentials(insecure.NewCredentials()), - ) -} - -func (a *API) dialNoVerify(ctx context.Context, target string) (*grpc.ClientConn, error) { - tlsConfig, err := atls.CreateUnverifiedClientTLSConfig() - if err != nil { - return nil, err - } - - return grpc.DialContext(ctx, target, - a.grpcWithDialer(), - grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)), - ) -} - -func (a *API) grpcWithDialer() grpc.DialOption { - return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { - return a.dialer.DialContext(ctx, "tcp", addr) - }) -} - -type Dialer interface { - DialContext(ctx context.Context, network, address string) (net.Conn, error) -} - type VPNAPIServer interface { Listen(endpoint string) error Serve() error @@ -135,3 +89,10 @@ func GetRecoveryPeerFromContext(ctx context.Context) (string, error) { return net.JoinHostPort(peerIP, setup.RecoveryPort), nil } + +// Dialer can open grpc client connections with different levels of ATLS encryption / verification. +type Dialer interface { + Dial(ctx context.Context, target string) (*grpc.ClientConn, error) + DialInsecure(ctx context.Context, target string) (*grpc.ClientConn, error) + DialNoVerify(ctx context.Context, target string) (*grpc.ClientConn, error) +} diff --git a/coordinator/util/grpcutil/dialer.go b/coordinator/util/grpcutil/dialer.go new file mode 100644 index 000000000..76e9d37fa --- /dev/null +++ b/coordinator/util/grpcutil/dialer.go @@ -0,0 +1,71 @@ +package grpcutil + +import ( + "context" + "net" + + "github.com/edgelesssys/constellation/coordinator/atls" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" +) + +// Dialer can open grpc client connections with different levels of ATLS encryption / verification. +type Dialer struct { + validator atls.Validator + netDialer NetDialer +} + +// NewDialer creates a new Dialer. +func NewDialer(validator atls.Validator, netDialer NetDialer) *Dialer { + return &Dialer{ + validator: validator, + netDialer: netDialer, + } +} + +// Dial creates a new grpc client connection to the given target using the atls validator. +func (d *Dialer) Dial(ctx context.Context, target string) (*grpc.ClientConn, error) { + tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{d.validator}) + if err != nil { + return nil, err + } + + return grpc.DialContext(ctx, target, + d.grpcWithDialer(), + grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)), + ) +} + +// DialInsecure creates a new grpc client connection to the given target without using encryption or verification. +// Only use this method when using another kind of encryption / verification (VPN, etc). +func (d *Dialer) DialInsecure(ctx context.Context, target string) (*grpc.ClientConn, error) { + return grpc.DialContext(ctx, target, + d.grpcWithDialer(), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) +} + +// DialNoVerify creates a new grpc client connection to the given target without verifying the server's attestation. +func (d *Dialer) DialNoVerify(ctx context.Context, target string) (*grpc.ClientConn, error) { + tlsConfig, err := atls.CreateUnverifiedClientTLSConfig() + if err != nil { + return nil, err + } + + return grpc.DialContext(ctx, target, + d.grpcWithDialer(), + grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)), + ) +} + +func (d *Dialer) grpcWithDialer() grpc.DialOption { + return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + return d.netDialer.DialContext(ctx, "tcp", addr) + }) +} + +// NetDialer implements the net Dialer interface. +type NetDialer interface { + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} diff --git a/coordinator/util/grpcutil/dialer_test.go b/coordinator/util/grpcutil/dialer_test.go new file mode 100644 index 000000000..f64f7778a --- /dev/null +++ b/coordinator/util/grpcutil/dialer_test.go @@ -0,0 +1,106 @@ +package grpcutil + +import ( + "context" + "testing" + + "github.com/edgelesssys/constellation/coordinator/atls" + "github.com/edgelesssys/constellation/coordinator/core" + "github.com/edgelesssys/constellation/coordinator/util/testdialer" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/test/grpc_testing" +) + +func TestDial(t *testing.T) { + testCases := map[string]struct { + tls bool + dialFn func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) + wantErr bool + }{ + "Dial with tls on server works": { + tls: true, + dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) { + return dialer.Dial(ctx, target) + }, + }, + "Dial without tls on server fails": { + dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) { + return dialer.Dial(ctx, target) + }, + wantErr: true, + }, + "DialNoVerify with tls on server works": { + tls: true, + dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) { + return dialer.DialNoVerify(ctx, target) + }, + }, + "DialNoVerify without tls on server fails": { + dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) { + return dialer.DialNoVerify(ctx, target) + }, + wantErr: true, + }, + "DialInsecure without tls on server works": { + dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) { + return dialer.DialInsecure(ctx, target) + }, + }, + "DialInsecure with tls on server fails": { + tls: true, + dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) { + return dialer.DialInsecure(ctx, target) + }, + wantErr: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + netDialer := testdialer.NewBufconnDialer() + dialer := NewDialer(&core.MockValidator{}, netDialer) + server := newServer(tc.tls) + api := &testAPI{} + grpc_testing.RegisterTestServiceServer(server, api) + go server.Serve(netDialer.GetListener("192.0.2.1:1234")) + defer server.Stop() + conn, err := tc.dialFn(dialer, context.Background(), "192.0.2.1:1234") + require.NoError(err) + defer conn.Close() + + client := grpc_testing.NewTestServiceClient(conn) + _, err = client.EmptyCall(context.Background(), &grpc_testing.Empty{}) + + if tc.wantErr { + assert.Error(err) + return + } + require.NoError(err) + }) + } +} + +func newServer(tls bool) *grpc.Server { + if tls { + tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{}) + if err != nil { + panic(err) + } + return grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) + } + return grpc.NewServer() +} + +type testAPI struct { + grpc_testing.UnimplementedTestServiceServer +} + +func (s *testAPI) EmptyCall(ctx context.Context, in *grpc_testing.Empty) (*grpc_testing.Empty, error) { + return &grpc_testing.Empty{}, nil +} diff --git a/debugd/debugd/deploy/download.go b/debugd/debugd/deploy/download.go index 80cbadb66..ee0c2d55f 100644 --- a/debugd/debugd/deploy/download.go +++ b/debugd/debugd/deploy/download.go @@ -16,14 +16,14 @@ import ( // Download downloads a coordinator from a given debugd instance. type Download struct { - dialer Dialer + dialer NetDialer writer streamToFileWriter serviceManager serviceManager attemptedDownloads map[string]time.Time } // New creates a new Download. -func New(dialer Dialer, serviceManager serviceManager, writer streamToFileWriter) *Download { +func New(dialer NetDialer, serviceManager serviceManager, writer streamToFileWriter) *Download { return &Download{ dialer: dialer, writer: writer, @@ -91,6 +91,7 @@ type streamToFileWriter interface { WriteStream(filename string, stream coordinator.ReadChunkStream, showProgress bool) error } -type Dialer interface { +// NetDialer can open a net.Conn. +type NetDialer interface { DialContext(ctx context.Context, network, address string) (net.Conn, error) } diff --git a/debugd/debugd/deploy/download_test.go b/debugd/debugd/deploy/download_test.go index 8ca6db8f4..c0d08df01 100644 --- a/debugd/debugd/deploy/download_test.go +++ b/debugd/debugd/deploy/download_test.go @@ -9,10 +9,10 @@ import ( "testing" "time" + "github.com/edgelesssys/constellation/coordinator/util/testdialer" "github.com/edgelesssys/constellation/debugd/coordinator" "github.com/edgelesssys/constellation/debugd/debugd" pb "github.com/edgelesssys/constellation/debugd/service" - "github.com/edgelesssys/constellation/debugd/service/testdialer" "github.com/stretchr/testify/assert" "google.golang.org/grpc" "google.golang.org/protobuf/proto" diff --git a/debugd/debugd/server/server_test.go b/debugd/debugd/server/server_test.go index 0c233f9cd..244ab231c 100644 --- a/debugd/debugd/server/server_test.go +++ b/debugd/debugd/server/server_test.go @@ -8,10 +8,10 @@ import ( "net" "testing" + "github.com/edgelesssys/constellation/coordinator/util/testdialer" "github.com/edgelesssys/constellation/debugd/coordinator" "github.com/edgelesssys/constellation/debugd/debugd/deploy" pb "github.com/edgelesssys/constellation/debugd/service" - "github.com/edgelesssys/constellation/debugd/service/testdialer" "github.com/edgelesssys/constellation/debugd/ssh" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -351,11 +351,11 @@ func (s *stubServiceManager) WriteSystemdUnitFile(ctx context.Context, unit depl return s.writeSystemdUnitFileErr } -type dialer interface { +type netDialer interface { DialContext(ctx context.Context, network, address string) (net.Conn, error) } -func dial(ctx context.Context, dialer dialer, target string) (*grpc.ClientConn, error) { +func dial(ctx context.Context, dialer netDialer, target string) (*grpc.ClientConn, error) { return grpc.DialContext(ctx, target, grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { return dialer.DialContext(ctx, "tcp", addr) diff --git a/debugd/service/testdialer/bufconndialer.go b/debugd/service/testdialer/bufconndialer.go deleted file mode 100644 index 22ee11ddb..000000000 --- a/debugd/service/testdialer/bufconndialer.go +++ /dev/null @@ -1,41 +0,0 @@ -package testdialer - -import ( - "context" - "fmt" - "net" - "sync" - - "google.golang.org/grpc/test/bufconn" -) - -// BufconnDialer is a fake dialer based on gRPC bufconn package. -type BufconnDialer struct { - mut sync.Mutex - listeners map[string]*bufconn.Listener -} - -// NewBufconnDialer creates a new bufconn dialer for testing. -func NewBufconnDialer() *BufconnDialer { - return &BufconnDialer{listeners: make(map[string]*bufconn.Listener)} -} - -// DialContext implements the Dialer interface. -func (b *BufconnDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - b.mut.Lock() - listener, ok := b.listeners[address] - b.mut.Unlock() - if !ok { - return nil, fmt.Errorf("could not connect to server on %v", address) - } - return listener.DialContext(ctx) -} - -// GetListener returns a fake listener that is coupled with this dialer. -func (b *BufconnDialer) GetListener(endpoint string) net.Listener { - listener := bufconn.Listen(1024) - b.mut.Lock() - b.listeners[endpoint] = listener - b.mut.Unlock() - return listener -}