extract shared grpcutil dialer from pubapi

Signed-off-by: Malte Poll <mp@edgeless.systems>
This commit is contained in:
Malte Poll 2022-04-28 09:49:15 +02:00 committed by Malte Poll
parent 5ac72c730d
commit 77b0237dd5
17 changed files with 275 additions and 152 deletions

View File

@ -23,6 +23,7 @@ import (
"github.com/edgelesssys/constellation/coordinator/kubernetes/k8sapi"
"github.com/edgelesssys/constellation/coordinator/kubernetes/k8sapi/kubectl"
"github.com/edgelesssys/constellation/coordinator/util"
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
"github.com/edgelesssys/constellation/coordinator/wireguard"
grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap"
"github.com/spf13/afero"
@ -153,7 +154,8 @@ func main() {
}
fileHandler := file.NewHandler(fs)
dialer := &net.Dialer{}
run(validator, issuer, wg, openTPM, util.GetIPAddr, dialer, fileHandler, kube,
netDialer := &net.Dialer{}
dialer := grpcutil.NewDialer(validator, netDialer)
run(issuer, wg, openTPM, util.GetIPAddr, dialer, fileHandler, kube,
metadata, cloudControllerManager, cloudNodeManager, autoscaler, encryptedDisk, etcdEndpoint, enforceEtcdTls, bindIP, bindPort, zapLoggerCore)
}

View File

@ -14,6 +14,7 @@ import (
"github.com/edgelesssys/constellation/coordinator/pubapi"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/store"
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
"github.com/edgelesssys/constellation/coordinator/vpnapi"
"github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
@ -26,7 +27,7 @@ import (
var version = "0.0.0"
func run(validator core.QuoteValidator, issuer core.QuoteIssuer, vpn core.VPN, openTPM vtpm.TPMOpenFunc, getPublicIPAddr func() (string, error), dialer pubapi.Dialer, fileHandler file.Handler,
func run(issuer core.QuoteIssuer, vpn core.VPN, openTPM vtpm.TPMOpenFunc, getPublicIPAddr func() (string, error), dialer *grpcutil.Dialer, fileHandler file.Handler,
kube core.Cluster, metadata core.ProviderMetadata, cloudControllerManager core.CloudControllerManager, cloudNodeManager core.CloudNodeManager, clusterAutoscaler core.ClusterAutoscaler, encryptedDisk core.EncryptedDisk, etcdEndpoint string, etcdTLS bool, bindIP, bindPort string, zapLoggerCore *zap.Logger,
) {
defer zapLoggerCore.Sync()
@ -46,16 +47,16 @@ func run(validator core.QuoteValidator, issuer core.QuoteIssuer, vpn core.VPN, o
if err != nil {
zapLoggerCore.Fatal("failed to create core", zap.Error(err))
}
vapiServer := &vpnAPIServer{logger: zapLoggerCore.Named("vpnapi"), core: core}
zapLoggerPubapi := zapLoggerCore.Named("pubapi")
papi := pubapi.New(zapLoggerPubapi, core, dialer, vapiServer, getPublicIPAddr, pubapi.GetRecoveryPeerFromContext)
// initialize state machine and wait for re-joining of the VPN (if applicable)
nodeActivated, err := core.Initialize()
if err != nil {
zapLoggerCore.Fatal("failed to initialize core", zap.Error(err))
}
vapiServer := &vpnAPIServer{logger: zapLoggerCore.Named("vpnapi"), core: core}
zapLoggerPubapi := zapLoggerCore.Named("pubapi")
papi := pubapi.New(zapLoggerPubapi, core, dialer, vapiServer, validator, getPublicIPAddr, pubapi.GetRecoveryPeerFromContext)
zapLoggergRPC := zapLoggerPubapi.Named("gRPC")
grpcServer := grpc.NewServer(

View File

@ -18,6 +18,7 @@ import (
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/state"
"github.com/edgelesssys/constellation/coordinator/store"
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
"github.com/edgelesssys/constellation/coordinator/vpnapi"
"github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto"
@ -210,7 +211,7 @@ func TestConcurrent(t *testing.T) {
assert.Error(<-actCoordErrs)
}
func spawnPeer(require *require.Assertions, logger *zap.Logger, dialer *testdialer.BufconnDialer, netw *network, endpoint string) (*grpc.Server, *pubapi.API, *fakeVPN) {
func spawnPeer(require *require.Assertions, logger *zap.Logger, netDialer *testdialer.BufconnDialer, netw *network, endpoint string) (*grpc.Server, *pubapi.API, *fakeVPN) {
vpn := newVPN(netw, endpoint)
cor, err := core.NewCore(vpn, &core.ClusterFake{}, &core.ProviderMetadataFake{}, &core.CloudControllerManagerFake{}, &core.CloudNodeManagerFake{}, &core.ClusterAutoscalerFake{}, &core.EncryptedDiskFake{}, logger, simulator.OpenSimulatedTPM, fakeStoreFactory{}, file.NewHandler(afero.NewMemMapFs()))
require.NoError(err)
@ -219,22 +220,23 @@ func spawnPeer(require *require.Assertions, logger *zap.Logger, dialer *testdial
getPublicAddr := func() (string, error) {
return "192.0.2.1", nil
}
dialer := grpcutil.NewDialer(&core.MockValidator{}, netDialer)
vapiServer := &fakeVPNAPIServer{logger: logger.Named("vpnapi"), core: cor, dialer: netDialer}
vapiServer := &fakeVPNAPIServer{logger: logger.Named("vpnapi"), core: cor, dialer: dialer}
papi := pubapi.New(logger, cor, dialer, vapiServer, &core.MockValidator{}, getPublicAddr, nil)
papi := pubapi.New(logger, cor, dialer, vapiServer, getPublicAddr, nil)
tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{})
require.NoError(err)
server := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
pubproto.RegisterAPIServer(server, papi)
listener := dialer.GetListener(endpoint)
listener := netDialer.GetListener(endpoint)
go server.Serve(listener)
return server, papi, vpn
}
func activateCoordinator(require *require.Assertions, dialer pubapi.Dialer, coordinatorIP, bindPort string, nodeIPs []string) error {
func activateCoordinator(require *require.Assertions, dialer netDialer, coordinatorIP, bindPort string, nodeIPs []string) error {
ctx := context.Background()
conn, err := dialGRPC(ctx, dialer, net.JoinHostPort(coordinatorIP, bindPort))
require.NoError(err)
@ -260,7 +262,7 @@ func activateCoordinator(require *require.Assertions, dialer pubapi.Dialer, coor
}
}
func dialGRPC(ctx context.Context, dialer pubapi.Dialer, target string) (*grpc.ClientConn, error) {
func dialGRPC(ctx context.Context, dialer netDialer, target string) (*grpc.ClientConn, error) {
tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{&core.MockValidator{}})
if err != nil {
return nil, err
@ -398,3 +400,7 @@ func (v *fakeVPN) recv() *packet {
}
return &packet
}
type netDialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}

View File

@ -16,6 +16,7 @@ import (
"github.com/edgelesssys/constellation/coordinator/pubapi"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/state"
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
"github.com/edgelesssys/constellation/coordinator/vpnapi"
"github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto"
"github.com/spf13/afero"
@ -112,13 +113,13 @@ func TestLegacyActivateCoordinator(t *testing.T) {
}
// newMockCoreWithDialer creates a new core object with attestation mock and provided dialer for testing.
func newMockCoreWithDialer(dialer *bufconnDialer) (*Core, *pubapi.API, error) {
func newMockCoreWithDialer(bufDialer *bufconnDialer) (*Core, *pubapi.API, error) {
zapLogger, err := zap.NewDevelopment()
if err != nil {
return nil, nil, err
}
validator := NewMockValidator()
dialer := grpcutil.NewDialer(NewMockValidator(), bufDialer)
vpn := &stubVPN{}
kubeFake := &ClusterFake{}
metadataFake := &ProviderMetadataFake{}
@ -138,8 +139,8 @@ func newMockCoreWithDialer(dialer *bufconnDialer) (*Core, *pubapi.API, error) {
return nil, nil, err
}
vapiServer := &fakeVPNAPIServer{logger: zapLogger, core: core, dialer: dialer}
papi := pubapi.New(zapLogger, core, dialer, vapiServer, validator, getPublicAddr, nil)
vapiServer := &fakeVPNAPIServer{logger: zapLogger, core: core, dialer: bufDialer}
papi := pubapi.New(zapLogger, core, dialer, vapiServer, getPublicAddr, nil)
return core, papi, nil
}

View File

@ -204,7 +204,7 @@ func (a *API) RequestStateDiskKey(ctx context.Context, in *pubproto.RequestState
return nil, status.Errorf(codes.Internal, "%v", err)
}
conn, err := a.dial(ctx, peer)
conn, err := a.dialer.Dial(ctx, peer)
if err != nil {
return nil, status.Errorf(codes.Internal, "%v", err)
}
@ -297,7 +297,7 @@ func (a *API) activateNode(nodePublicIP string, nodeVPNIP string, initialPeers [
ctx, cancel := context.WithTimeout(context.Background(), deadlineDuration)
defer cancel()
conn, err := a.dial(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort))
conn, err := a.dialer.Dial(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort))
if err != nil {
return nil, err
}
@ -417,7 +417,7 @@ func (a *API) joinCluster(nodePublicIP string) error {
}
// We don't verify the peer certificate here, since JoinCluster triggers a connection over VPN
// The target of the rpc needs to already be part of the VPN to process the request, meaning it is trusted
conn, err := a.dialNoVerify(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort))
conn, err := a.dialer.DialNoVerify(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort))
if err != nil {
return err
}
@ -455,7 +455,7 @@ func (a *API) triggerNodeUpdate(nodePublicIP string) error {
// We don't verify the peer certificate here, since TriggerNodeUpdate triggers a connection over VPN
// The target of the rpc needs to already be part of the VPN to process the request, meaning it is trusted
conn, err := a.dialNoVerify(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort))
conn, err := a.dialer.DialNoVerify(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort))
if err != nil {
return err
}

View File

@ -18,6 +18,7 @@ import (
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/role"
"github.com/edgelesssys/constellation/coordinator/state"
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
"github.com/edgelesssys/constellation/state/keyservice/keyproto"
"github.com/stretchr/testify/assert"
@ -125,13 +126,14 @@ func TestActivateAsCoordinator(t *testing.T) {
ownerID: []byte("ownerID"),
clusterID: []byte("clusterID"),
}
dialer := testdialer.NewBufconnDialer()
netDialer := testdialer.NewBufconnDialer()
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
getPublicIPAddr := func() (string, error) {
return "192.0.2.1", nil
}
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil)
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, getPublicIPAddr, nil)
defer api.Close()
// spawn nodes
@ -142,7 +144,7 @@ func TestActivateAsCoordinator(t *testing.T) {
server := n.newServer()
wg.Add(1)
go func(endpoint string) {
listener := dialer.GetListener(endpoint)
listener := netDialer.GetListener(endpoint)
wg.Done()
_ = server.Serve(listener)
}(net.JoinHostPort(n.peer.PublicIP, endpointAVPNPort))
@ -256,13 +258,14 @@ func TestActivateAdditionalNodes(t *testing.T) {
require := require.New(t)
core := &fakeCore{state: tc.state}
dialer := testdialer.NewBufconnDialer()
netDialer := testdialer.NewBufconnDialer()
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
getPublicIPAddr := func() (string, error) {
return "192.0.2.1", nil
}
api := New(zaptest.NewLogger(t), core, dialer, nil, fakeValidator{}, getPublicIPAddr, nil)
api := New(zaptest.NewLogger(t), core, dialer, nil, getPublicIPAddr, nil)
defer api.Close()
// spawn nodes
var nodePublicIPs []string
@ -272,7 +275,7 @@ func TestActivateAdditionalNodes(t *testing.T) {
server := n.newServer()
wg.Add(1)
go func(endpoint string) {
listener := dialer.GetListener(endpoint)
listener := netDialer.GetListener(endpoint)
wg.Done()
_ = server.Serve(listener)
}(net.JoinHostPort(n.peer.PublicIP, endpointAVPNPort))
@ -311,7 +314,7 @@ func TestAssemblePeerStruct(t *testing.T) {
vpnPubKey := []byte{2, 3, 4}
core := &fakeCore{vpnPubKey: vpnPubKey}
api := New(zaptest.NewLogger(t), core, nil, nil, nil, getPublicIPAddr, nil)
api := New(zaptest.NewLogger(t), core, nil, nil, getPublicIPAddr, nil)
defer api.Close()
vpnIP, err := core.GetVPNIP()
@ -512,7 +515,8 @@ func TestRequestStateDiskKey(t *testing.T) {
dataKey: tc.dataKey,
getDataKeyErr: tc.getDataKeyErr,
}
api := New(zaptest.NewLogger(t), core, &net.Dialer{}, nil, dummyValidator{}, nil, getPeerFromContext)
api := New(zaptest.NewLogger(t), core, grpcutil.NewDialer(dummyValidator{}, &net.Dialer{}), nil, nil, getPeerFromContext)
_, err = api.RequestStateDiskKey(ctx, &pubproto.RequestStateDiskKeyRequest{})
if tc.wantErr {

View File

@ -199,7 +199,7 @@ func (a *API) activateCoordinator(ctx context.Context, coordinatorIP string) err
return err
}
conn, err := a.dial(ctx, net.JoinHostPort(coordinatorIP, endpointAVPNPort))
conn, err := a.dialer.Dial(ctx, net.JoinHostPort(coordinatorIP, endpointAVPNPort))
if err != nil {
return fmt.Errorf("dialing new coordinator: %v", err)
}
@ -271,7 +271,7 @@ func (a *API) triggerCoordinatorUpdate(ctx context.Context, publicIP string) err
// We don't verify the peer certificate here, since TriggerNodeUpdate triggers a connection over VPN
// The target of the rpc needs to already be part of the VPN to process the request, meaning it is trusted
conn, err := a.dialNoVerify(ctx, net.JoinHostPort(publicIP, endpointAVPNPort))
conn, err := a.dialer.DialNoVerify(ctx, net.JoinHostPort(publicIP, endpointAVPNPort))
if err != nil {
return err
}
@ -284,7 +284,7 @@ func (a *API) triggerCoordinatorUpdate(ctx context.Context, publicIP string) err
}
func (a *API) getk8SCoordinatorJoinArgs(ctx context.Context, coordinatorIP, port string) (*kubeadm.BootstrapTokenDiscovery, string, error) {
conn, err := a.dialInsecure(ctx, net.JoinHostPort(coordinatorIP, port))
conn, err := a.dialer.DialInsecure(ctx, net.JoinHostPort(coordinatorIP, port))
if err != nil {
return nil, "", err
}

View File

@ -10,6 +10,7 @@ import (
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/role"
"github.com/edgelesssys/constellation/coordinator/state"
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -92,18 +93,19 @@ func TestActivateAsAdditionalCoordinator(t *testing.T) {
ownerID: []byte("ownerID"),
clusterID: []byte("clusterID"),
}
dialer := testdialer.NewBufconnDialer()
netDialer := testdialer.NewBufconnDialer()
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
getPublicIPAddr := func() (string, error) {
return "192.0.2.1", nil
}
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil)
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, getPublicIPAddr, nil)
defer api.Close()
// spawn vpnServer
vpnapiServer := tc.vpnapi.newServer()
go vpnapiServer.Serve(dialer.GetListener(net.JoinHostPort(tc.coordinators.peer.VPNIP, vpnAPIPort)))
go vpnapiServer.Serve(netDialer.GetListener(net.JoinHostPort(tc.coordinators.peer.VPNIP, vpnAPIPort)))
defer vpnapiServer.GracefulStop()
_, err := api.ActivateAsAdditionalCoordinator(context.Background(), &pubproto.ActivateAsAdditionalCoordinatorRequest{
@ -163,9 +165,9 @@ func TestTriggerCoordinatorUpdate(t *testing.T) {
state: tc.state,
peers: tc.peers,
}
dialer := testdialer.NewBufconnDialer()
dialer := grpcutil.NewDialer(fakeValidator{}, nil)
api := New(logger, core, dialer, nil, nil, nil, nil)
api := New(logger, core, dialer, nil, nil, nil)
_, err := api.TriggerCoordinatorUpdate(context.Background(), &pubproto.TriggerCoordinatorUpdateRequest{})
if tc.wantErr {
@ -236,20 +238,21 @@ func TestActivateAdditionalCoordinators(t *testing.T) {
ownerID: []byte("ownerID"),
clusterID: []byte("clusterID"),
}
dialer := testdialer.NewBufconnDialer()
netDialer := testdialer.NewBufconnDialer()
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
getPublicIPAddr := func() (string, error) {
return "192.0.2.1", nil
}
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil)
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, getPublicIPAddr, nil)
defer api.Close()
// spawn coordinator
tc.coordinators.activateErr = tc.activateErr
tc.coordinators.getPubKeyErr = tc.getPublicKeyErr
server := tc.coordinators.newServer()
go server.Serve(dialer.GetListener(net.JoinHostPort(tc.coordinators.peer.PublicIP, endpointAVPNPort)))
go server.Serve(netDialer.GetListener(net.JoinHostPort(tc.coordinators.peer.PublicIP, endpointAVPNPort)))
defer server.GracefulStop()
_, err := api.ActivateAdditionalCoordinator(context.Background(), &pubproto.ActivateAdditionalCoordinatorRequest{CoordinatorPublicIp: tc.coordinators.peer.PublicIP})
@ -293,13 +296,13 @@ func TestGetPeerVPNPublicKey(t *testing.T) {
vpnPubKey: tc.coordinator.peer.VPNPubKey,
getvpnPubKeyErr: tc.getVPNPubKeyErr,
}
dialer := testdialer.NewBufconnDialer()
dialer := grpcutil.NewDialer(fakeValidator{}, testdialer.NewBufconnDialer())
getPublicIPAddr := func() (string, error) {
return "192.0.2.1", nil
}
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil)
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, getPublicIPAddr, nil)
defer api.Close()
resp, err := api.GetPeerVPNPublicKey(context.Background(), &pubproto.GetPeerVPNPublicKeyRequest{})

View File

@ -168,7 +168,7 @@ func (a *API) JoinCluster(ctx context.Context, in *pubproto.JoinClusterRequest)
return nil, status.Errorf(codes.FailedPrecondition, "node is not in required state for cluster join: %v", err)
}
conn, err := a.dialInsecure(ctx, net.JoinHostPort(in.CoordinatorVpnIp, vpnAPIPort))
conn, err := a.dialer.DialInsecure(ctx, net.JoinHostPort(in.CoordinatorVpnIp, vpnAPIPort))
if err != nil {
return nil, status.Errorf(codes.Unavailable, "dial coordinator: %v", err)
}
@ -231,7 +231,7 @@ func (a *API) update(ctx context.Context) error {
defer cancel()
// TODO: replace hardcoded IP
conn, err := a.dialInsecure(ctx, net.JoinHostPort("10.118.0.1", vpnAPIPort))
conn, err := a.dialer.DialInsecure(ctx, net.JoinHostPort("10.118.0.1", vpnAPIPort))
if err != nil {
return err
}

View File

@ -13,6 +13,7 @@ import (
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/role"
"github.com/edgelesssys/constellation/coordinator/state"
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
"github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto"
"github.com/stretchr/testify/assert"
@ -127,25 +128,26 @@ func TestActivateAsNode(t *testing.T) {
logger := zaptest.NewLogger(t)
cor := &fakeCore{state: tc.state, vpnPubKey: vpnPubKey, setVPNIPErr: tc.setVPNIPErr}
dialer := testdialer.NewBufconnDialer()
netDialer := testdialer.NewBufconnDialer()
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
api := New(logger, cor, dialer, nil, nil, nil, nil)
api := New(logger, cor, dialer, nil, nil, nil)
defer api.Close()
vserver := grpc.NewServer()
vapi := &stubVPNAPI{peers: tc.updatedPeers, getUpdateErr: tc.getUpdateErr}
vpnproto.RegisterAPIServer(vserver, vapi)
go vserver.Serve(dialer.GetListener(net.JoinHostPort("10.118.0.1", vpnAPIPort)))
go vserver.Serve(netDialer.GetListener(net.JoinHostPort("10.118.0.1", vpnAPIPort)))
defer vserver.GracefulStop()
tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{})
require.NoError(err)
pubserver := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
pubproto.RegisterAPIServer(pubserver, api)
go pubserver.Serve(dialer.GetListener(net.JoinHostPort(nodeIP, endpointAVPNPort)))
go pubserver.Serve(netDialer.GetListener(net.JoinHostPort(nodeIP, endpointAVPNPort)))
defer pubserver.GracefulStop()
_, nodeVPNPubKey, err := activateNode(require, dialer, messageSequence, nodeIP, "9000", nodeVPNIP, peer.ToPubProto(tc.initialPeers), ownerID, clusterID, stateDiskKey)
_, nodeVPNPubKey, err := activateNode(require, netDialer, messageSequence, nodeIP, "9000", nodeVPNIP, peer.ToPubProto(tc.initialPeers), ownerID, clusterID, stateDiskKey)
assert.Equal(tc.wantState, cor.state)
if tc.wantErr {
@ -215,9 +217,10 @@ func TestTriggerNodeUpdate(t *testing.T) {
logger := zaptest.NewLogger(t)
core := &fakeCore{state: tc.state}
dialer := testdialer.NewBufconnDialer()
netDialer := testdialer.NewBufconnDialer()
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
api := New(logger, core, dialer, nil, nil, nil, nil)
api := New(logger, core, dialer, nil, nil, nil)
vserver := grpc.NewServer()
vapi := &stubVPNAPI{
@ -225,7 +228,7 @@ func TestTriggerNodeUpdate(t *testing.T) {
getUpdateErr: tc.getUpdateErr,
}
vpnproto.RegisterAPIServer(vserver, vapi)
go vserver.Serve(dialer.GetListener(net.JoinHostPort("10.118.0.1", vpnAPIPort)))
go vserver.Serve(netDialer.GetListener(net.JoinHostPort("10.118.0.1", vpnAPIPort)))
defer vserver.GracefulStop()
_, err := api.TriggerNodeUpdate(context.Background(), &pubproto.TriggerNodeUpdateRequest{})
@ -290,9 +293,10 @@ func TestJoinCluster(t *testing.T) {
logger := zaptest.NewLogger(t)
core := &fakeCore{state: tc.state, joinClusterErr: tc.joinClusterErr}
dialer := testdialer.NewBufconnDialer()
netDialer := testdialer.NewBufconnDialer()
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
api := New(logger, core, dialer, nil, nil, nil, nil)
api := New(logger, core, dialer, nil, nil, nil)
vserver := grpc.NewServer()
vapi := &stubVPNAPI{
@ -304,7 +308,7 @@ func TestJoinCluster(t *testing.T) {
getJoinArgsErr: tc.getJoinArgsErr,
}
vpnproto.RegisterAPIServer(vserver, vapi)
go vserver.Serve(dialer.GetListener(net.JoinHostPort("192.0.2.1", vpnAPIPort)))
go vserver.Serve(netDialer.GetListener(net.JoinHostPort("192.0.2.1", vpnAPIPort)))
defer vserver.GracefulStop()
_, err := api.JoinCluster(context.Background(), &pubproto.JoinClusterRequest{CoordinatorVpnIp: "192.0.2.1"})
@ -322,7 +326,7 @@ func TestJoinCluster(t *testing.T) {
}
}
func activateNode(require *require.Assertions, dialer Dialer, messageSequence []string, nodeIP, bindPort, nodeVPNIP string, peers []*pubproto.Peer, ownerID, clusterID, stateDiskKey []byte) (string, []byte, error) {
func activateNode(require *require.Assertions, dialer netDialer, messageSequence []string, nodeIP, bindPort, nodeVPNIP string, peers []*pubproto.Peer, ownerID, clusterID, stateDiskKey []byte) (string, []byte, error) {
ctx := context.Background()
conn, err := dialGRPC(ctx, dialer, net.JoinHostPort(nodeIP, bindPort))
require.NoError(err)
@ -385,7 +389,7 @@ func activateNode(require *require.Assertions, dialer Dialer, messageSequence []
return diskUUID, nodeVPNPubKey, nil
}
func dialGRPC(ctx context.Context, dialer Dialer, target string) (*grpc.ClientConn, error) {
func dialGRPC(ctx context.Context, dialer netDialer, target string) (*grpc.ClientConn, error) {
tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{&core.MockValidator{}})
if err != nil {
return nil, err
@ -429,3 +433,7 @@ func (a *stubVPNAPI) newServer() *grpc.Server {
vpnproto.RegisterAPIServer(server, a)
return server
}
type netDialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}

View File

@ -8,13 +8,10 @@ import (
"sync"
"time"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/state/setup"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/peer"
)
@ -32,7 +29,6 @@ type API struct {
core Core
dialer Dialer
vpnAPIServer VPNAPIServer
validator atls.Validator
getPublicIPAddr GetIPAddrFunc
stopUpdate chan struct{}
wgClose sync.WaitGroup
@ -42,13 +38,12 @@ type API struct {
}
// New creates a new API.
func New(logger *zap.Logger, core Core, dialer Dialer, vpnAPIServer VPNAPIServer, validator atls.Validator, getPublicIPAddr GetIPAddrFunc, peerFromContext PeerFromContextFunc) *API {
func New(logger *zap.Logger, core Core, dialer Dialer, vpnAPIServer VPNAPIServer, getPublicIPAddr GetIPAddrFunc, peerFromContext PeerFromContextFunc) *API {
return &API{
logger: logger,
core: core,
dialer: dialer,
vpnAPIServer: vpnAPIServer,
validator: validator,
getPublicIPAddr: getPublicIPAddr,
stopUpdate: make(chan struct{}, 1),
peerFromContext: peerFromContext,
@ -69,47 +64,6 @@ func (a *API) Close() {
a.wgClose.Wait()
}
func (a *API) dial(ctx context.Context, target string) (*grpc.ClientConn, error) {
tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{a.validator})
if err != nil {
return nil, err
}
return grpc.DialContext(ctx, target,
a.grpcWithDialer(),
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
)
}
func (a *API) dialInsecure(ctx context.Context, target string) (*grpc.ClientConn, error) {
return grpc.DialContext(ctx, target,
a.grpcWithDialer(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
}
func (a *API) dialNoVerify(ctx context.Context, target string) (*grpc.ClientConn, error) {
tlsConfig, err := atls.CreateUnverifiedClientTLSConfig()
if err != nil {
return nil, err
}
return grpc.DialContext(ctx, target,
a.grpcWithDialer(),
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
)
}
func (a *API) grpcWithDialer() grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
return a.dialer.DialContext(ctx, "tcp", addr)
})
}
type Dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
type VPNAPIServer interface {
Listen(endpoint string) error
Serve() error
@ -135,3 +89,10 @@ func GetRecoveryPeerFromContext(ctx context.Context) (string, error) {
return net.JoinHostPort(peerIP, setup.RecoveryPort), nil
}
// Dialer can open grpc client connections with different levels of ATLS encryption / verification.
type Dialer interface {
Dial(ctx context.Context, target string) (*grpc.ClientConn, error)
DialInsecure(ctx context.Context, target string) (*grpc.ClientConn, error)
DialNoVerify(ctx context.Context, target string) (*grpc.ClientConn, error)
}

View File

@ -0,0 +1,71 @@
package grpcutil
import (
"context"
"net"
"github.com/edgelesssys/constellation/coordinator/atls"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
)
// Dialer can open grpc client connections with different levels of ATLS encryption / verification.
type Dialer struct {
validator atls.Validator
netDialer NetDialer
}
// NewDialer creates a new Dialer.
func NewDialer(validator atls.Validator, netDialer NetDialer) *Dialer {
return &Dialer{
validator: validator,
netDialer: netDialer,
}
}
// Dial creates a new grpc client connection to the given target using the atls validator.
func (d *Dialer) Dial(ctx context.Context, target string) (*grpc.ClientConn, error) {
tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{d.validator})
if err != nil {
return nil, err
}
return grpc.DialContext(ctx, target,
d.grpcWithDialer(),
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
)
}
// DialInsecure creates a new grpc client connection to the given target without using encryption or verification.
// Only use this method when using another kind of encryption / verification (VPN, etc).
func (d *Dialer) DialInsecure(ctx context.Context, target string) (*grpc.ClientConn, error) {
return grpc.DialContext(ctx, target,
d.grpcWithDialer(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
}
// DialNoVerify creates a new grpc client connection to the given target without verifying the server's attestation.
func (d *Dialer) DialNoVerify(ctx context.Context, target string) (*grpc.ClientConn, error) {
tlsConfig, err := atls.CreateUnverifiedClientTLSConfig()
if err != nil {
return nil, err
}
return grpc.DialContext(ctx, target,
d.grpcWithDialer(),
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
)
}
func (d *Dialer) grpcWithDialer() grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
return d.netDialer.DialContext(ctx, "tcp", addr)
})
}
// NetDialer implements the net Dialer interface.
type NetDialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}

View File

@ -0,0 +1,106 @@
package grpcutil
import (
"context"
"testing"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/core"
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/test/grpc_testing"
)
func TestDial(t *testing.T) {
testCases := map[string]struct {
tls bool
dialFn func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error)
wantErr bool
}{
"Dial with tls on server works": {
tls: true,
dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) {
return dialer.Dial(ctx, target)
},
},
"Dial without tls on server fails": {
dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) {
return dialer.Dial(ctx, target)
},
wantErr: true,
},
"DialNoVerify with tls on server works": {
tls: true,
dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) {
return dialer.DialNoVerify(ctx, target)
},
},
"DialNoVerify without tls on server fails": {
dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) {
return dialer.DialNoVerify(ctx, target)
},
wantErr: true,
},
"DialInsecure without tls on server works": {
dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) {
return dialer.DialInsecure(ctx, target)
},
},
"DialInsecure with tls on server fails": {
tls: true,
dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) {
return dialer.DialInsecure(ctx, target)
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
netDialer := testdialer.NewBufconnDialer()
dialer := NewDialer(&core.MockValidator{}, netDialer)
server := newServer(tc.tls)
api := &testAPI{}
grpc_testing.RegisterTestServiceServer(server, api)
go server.Serve(netDialer.GetListener("192.0.2.1:1234"))
defer server.Stop()
conn, err := tc.dialFn(dialer, context.Background(), "192.0.2.1:1234")
require.NoError(err)
defer conn.Close()
client := grpc_testing.NewTestServiceClient(conn)
_, err = client.EmptyCall(context.Background(), &grpc_testing.Empty{})
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
})
}
}
func newServer(tls bool) *grpc.Server {
if tls {
tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{})
if err != nil {
panic(err)
}
return grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
}
return grpc.NewServer()
}
type testAPI struct {
grpc_testing.UnimplementedTestServiceServer
}
func (s *testAPI) EmptyCall(ctx context.Context, in *grpc_testing.Empty) (*grpc_testing.Empty, error) {
return &grpc_testing.Empty{}, nil
}

View File

@ -16,14 +16,14 @@ import (
// Download downloads a coordinator from a given debugd instance.
type Download struct {
dialer Dialer
dialer NetDialer
writer streamToFileWriter
serviceManager serviceManager
attemptedDownloads map[string]time.Time
}
// New creates a new Download.
func New(dialer Dialer, serviceManager serviceManager, writer streamToFileWriter) *Download {
func New(dialer NetDialer, serviceManager serviceManager, writer streamToFileWriter) *Download {
return &Download{
dialer: dialer,
writer: writer,
@ -91,6 +91,7 @@ type streamToFileWriter interface {
WriteStream(filename string, stream coordinator.ReadChunkStream, showProgress bool) error
}
type Dialer interface {
// NetDialer can open a net.Conn.
type NetDialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}

View File

@ -9,10 +9,10 @@ import (
"testing"
"time"
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
"github.com/edgelesssys/constellation/debugd/coordinator"
"github.com/edgelesssys/constellation/debugd/debugd"
pb "github.com/edgelesssys/constellation/debugd/service"
"github.com/edgelesssys/constellation/debugd/service/testdialer"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/protobuf/proto"

View File

@ -8,10 +8,10 @@ import (
"net"
"testing"
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
"github.com/edgelesssys/constellation/debugd/coordinator"
"github.com/edgelesssys/constellation/debugd/debugd/deploy"
pb "github.com/edgelesssys/constellation/debugd/service"
"github.com/edgelesssys/constellation/debugd/service/testdialer"
"github.com/edgelesssys/constellation/debugd/ssh"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -351,11 +351,11 @@ func (s *stubServiceManager) WriteSystemdUnitFile(ctx context.Context, unit depl
return s.writeSystemdUnitFileErr
}
type dialer interface {
type netDialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
func dial(ctx context.Context, dialer dialer, target string) (*grpc.ClientConn, error) {
func dial(ctx context.Context, dialer netDialer, target string) (*grpc.ClientConn, error) {
return grpc.DialContext(ctx, target,
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, "tcp", addr)

View File

@ -1,41 +0,0 @@
package testdialer
import (
"context"
"fmt"
"net"
"sync"
"google.golang.org/grpc/test/bufconn"
)
// BufconnDialer is a fake dialer based on gRPC bufconn package.
type BufconnDialer struct {
mut sync.Mutex
listeners map[string]*bufconn.Listener
}
// NewBufconnDialer creates a new bufconn dialer for testing.
func NewBufconnDialer() *BufconnDialer {
return &BufconnDialer{listeners: make(map[string]*bufconn.Listener)}
}
// DialContext implements the Dialer interface.
func (b *BufconnDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
b.mut.Lock()
listener, ok := b.listeners[address]
b.mut.Unlock()
if !ok {
return nil, fmt.Errorf("could not connect to server on %v", address)
}
return listener.DialContext(ctx)
}
// GetListener returns a fake listener that is coupled with this dialer.
func (b *BufconnDialer) GetListener(endpoint string) net.Listener {
listener := bufconn.Listen(1024)
b.mut.Lock()
b.listeners[endpoint] = listener
b.mut.Unlock()
return listener
}