extract shared grpcutil dialer from pubapi

Signed-off-by: Malte Poll <mp@edgeless.systems>
This commit is contained in:
Malte Poll 2022-04-28 09:49:15 +02:00 committed by Malte Poll
parent 5ac72c730d
commit 77b0237dd5
17 changed files with 275 additions and 152 deletions

View file

@ -23,6 +23,7 @@ import (
"github.com/edgelesssys/constellation/coordinator/kubernetes/k8sapi" "github.com/edgelesssys/constellation/coordinator/kubernetes/k8sapi"
"github.com/edgelesssys/constellation/coordinator/kubernetes/k8sapi/kubectl" "github.com/edgelesssys/constellation/coordinator/kubernetes/k8sapi/kubectl"
"github.com/edgelesssys/constellation/coordinator/util" "github.com/edgelesssys/constellation/coordinator/util"
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
"github.com/edgelesssys/constellation/coordinator/wireguard" "github.com/edgelesssys/constellation/coordinator/wireguard"
grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap" grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap"
"github.com/spf13/afero" "github.com/spf13/afero"
@ -153,7 +154,8 @@ func main() {
} }
fileHandler := file.NewHandler(fs) fileHandler := file.NewHandler(fs)
dialer := &net.Dialer{} netDialer := &net.Dialer{}
run(validator, issuer, wg, openTPM, util.GetIPAddr, dialer, fileHandler, kube, dialer := grpcutil.NewDialer(validator, netDialer)
run(issuer, wg, openTPM, util.GetIPAddr, dialer, fileHandler, kube,
metadata, cloudControllerManager, cloudNodeManager, autoscaler, encryptedDisk, etcdEndpoint, enforceEtcdTls, bindIP, bindPort, zapLoggerCore) metadata, cloudControllerManager, cloudNodeManager, autoscaler, encryptedDisk, etcdEndpoint, enforceEtcdTls, bindIP, bindPort, zapLoggerCore)
} }

View file

@ -14,6 +14,7 @@ import (
"github.com/edgelesssys/constellation/coordinator/pubapi" "github.com/edgelesssys/constellation/coordinator/pubapi"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/store" "github.com/edgelesssys/constellation/coordinator/store"
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
"github.com/edgelesssys/constellation/coordinator/vpnapi" "github.com/edgelesssys/constellation/coordinator/vpnapi"
"github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto" "github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
@ -26,7 +27,7 @@ import (
var version = "0.0.0" var version = "0.0.0"
func run(validator core.QuoteValidator, issuer core.QuoteIssuer, vpn core.VPN, openTPM vtpm.TPMOpenFunc, getPublicIPAddr func() (string, error), dialer pubapi.Dialer, fileHandler file.Handler, func run(issuer core.QuoteIssuer, vpn core.VPN, openTPM vtpm.TPMOpenFunc, getPublicIPAddr func() (string, error), dialer *grpcutil.Dialer, fileHandler file.Handler,
kube core.Cluster, metadata core.ProviderMetadata, cloudControllerManager core.CloudControllerManager, cloudNodeManager core.CloudNodeManager, clusterAutoscaler core.ClusterAutoscaler, encryptedDisk core.EncryptedDisk, etcdEndpoint string, etcdTLS bool, bindIP, bindPort string, zapLoggerCore *zap.Logger, kube core.Cluster, metadata core.ProviderMetadata, cloudControllerManager core.CloudControllerManager, cloudNodeManager core.CloudNodeManager, clusterAutoscaler core.ClusterAutoscaler, encryptedDisk core.EncryptedDisk, etcdEndpoint string, etcdTLS bool, bindIP, bindPort string, zapLoggerCore *zap.Logger,
) { ) {
defer zapLoggerCore.Sync() defer zapLoggerCore.Sync()
@ -46,16 +47,16 @@ func run(validator core.QuoteValidator, issuer core.QuoteIssuer, vpn core.VPN, o
if err != nil { if err != nil {
zapLoggerCore.Fatal("failed to create core", zap.Error(err)) zapLoggerCore.Fatal("failed to create core", zap.Error(err))
} }
vapiServer := &vpnAPIServer{logger: zapLoggerCore.Named("vpnapi"), core: core}
zapLoggerPubapi := zapLoggerCore.Named("pubapi")
papi := pubapi.New(zapLoggerPubapi, core, dialer, vapiServer, getPublicIPAddr, pubapi.GetRecoveryPeerFromContext)
// initialize state machine and wait for re-joining of the VPN (if applicable) // initialize state machine and wait for re-joining of the VPN (if applicable)
nodeActivated, err := core.Initialize() nodeActivated, err := core.Initialize()
if err != nil { if err != nil {
zapLoggerCore.Fatal("failed to initialize core", zap.Error(err)) zapLoggerCore.Fatal("failed to initialize core", zap.Error(err))
} }
vapiServer := &vpnAPIServer{logger: zapLoggerCore.Named("vpnapi"), core: core}
zapLoggerPubapi := zapLoggerCore.Named("pubapi")
papi := pubapi.New(zapLoggerPubapi, core, dialer, vapiServer, validator, getPublicIPAddr, pubapi.GetRecoveryPeerFromContext)
zapLoggergRPC := zapLoggerPubapi.Named("gRPC") zapLoggergRPC := zapLoggerPubapi.Named("gRPC")
grpcServer := grpc.NewServer( grpcServer := grpc.NewServer(

View file

@ -18,6 +18,7 @@ import (
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/coordinator/state"
"github.com/edgelesssys/constellation/coordinator/store" "github.com/edgelesssys/constellation/coordinator/store"
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
"github.com/edgelesssys/constellation/coordinator/util/testdialer" "github.com/edgelesssys/constellation/coordinator/util/testdialer"
"github.com/edgelesssys/constellation/coordinator/vpnapi" "github.com/edgelesssys/constellation/coordinator/vpnapi"
"github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto" "github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto"
@ -210,7 +211,7 @@ func TestConcurrent(t *testing.T) {
assert.Error(<-actCoordErrs) assert.Error(<-actCoordErrs)
} }
func spawnPeer(require *require.Assertions, logger *zap.Logger, dialer *testdialer.BufconnDialer, netw *network, endpoint string) (*grpc.Server, *pubapi.API, *fakeVPN) { func spawnPeer(require *require.Assertions, logger *zap.Logger, netDialer *testdialer.BufconnDialer, netw *network, endpoint string) (*grpc.Server, *pubapi.API, *fakeVPN) {
vpn := newVPN(netw, endpoint) vpn := newVPN(netw, endpoint)
cor, err := core.NewCore(vpn, &core.ClusterFake{}, &core.ProviderMetadataFake{}, &core.CloudControllerManagerFake{}, &core.CloudNodeManagerFake{}, &core.ClusterAutoscalerFake{}, &core.EncryptedDiskFake{}, logger, simulator.OpenSimulatedTPM, fakeStoreFactory{}, file.NewHandler(afero.NewMemMapFs())) cor, err := core.NewCore(vpn, &core.ClusterFake{}, &core.ProviderMetadataFake{}, &core.CloudControllerManagerFake{}, &core.CloudNodeManagerFake{}, &core.ClusterAutoscalerFake{}, &core.EncryptedDiskFake{}, logger, simulator.OpenSimulatedTPM, fakeStoreFactory{}, file.NewHandler(afero.NewMemMapFs()))
require.NoError(err) require.NoError(err)
@ -219,22 +220,23 @@ func spawnPeer(require *require.Assertions, logger *zap.Logger, dialer *testdial
getPublicAddr := func() (string, error) { getPublicAddr := func() (string, error) {
return "192.0.2.1", nil return "192.0.2.1", nil
} }
dialer := grpcutil.NewDialer(&core.MockValidator{}, netDialer)
vapiServer := &fakeVPNAPIServer{logger: logger.Named("vpnapi"), core: cor, dialer: netDialer}
vapiServer := &fakeVPNAPIServer{logger: logger.Named("vpnapi"), core: cor, dialer: dialer} papi := pubapi.New(logger, cor, dialer, vapiServer, getPublicAddr, nil)
papi := pubapi.New(logger, cor, dialer, vapiServer, &core.MockValidator{}, getPublicAddr, nil)
tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{}) tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{})
require.NoError(err) require.NoError(err)
server := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) server := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
pubproto.RegisterAPIServer(server, papi) pubproto.RegisterAPIServer(server, papi)
listener := dialer.GetListener(endpoint) listener := netDialer.GetListener(endpoint)
go server.Serve(listener) go server.Serve(listener)
return server, papi, vpn return server, papi, vpn
} }
func activateCoordinator(require *require.Assertions, dialer pubapi.Dialer, coordinatorIP, bindPort string, nodeIPs []string) error { func activateCoordinator(require *require.Assertions, dialer netDialer, coordinatorIP, bindPort string, nodeIPs []string) error {
ctx := context.Background() ctx := context.Background()
conn, err := dialGRPC(ctx, dialer, net.JoinHostPort(coordinatorIP, bindPort)) conn, err := dialGRPC(ctx, dialer, net.JoinHostPort(coordinatorIP, bindPort))
require.NoError(err) require.NoError(err)
@ -260,7 +262,7 @@ func activateCoordinator(require *require.Assertions, dialer pubapi.Dialer, coor
} }
} }
func dialGRPC(ctx context.Context, dialer pubapi.Dialer, target string) (*grpc.ClientConn, error) { func dialGRPC(ctx context.Context, dialer netDialer, target string) (*grpc.ClientConn, error) {
tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{&core.MockValidator{}}) tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{&core.MockValidator{}})
if err != nil { if err != nil {
return nil, err return nil, err
@ -398,3 +400,7 @@ func (v *fakeVPN) recv() *packet {
} }
return &packet return &packet
} }
type netDialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}

View file

@ -16,6 +16,7 @@ import (
"github.com/edgelesssys/constellation/coordinator/pubapi" "github.com/edgelesssys/constellation/coordinator/pubapi"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/coordinator/state"
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
"github.com/edgelesssys/constellation/coordinator/vpnapi" "github.com/edgelesssys/constellation/coordinator/vpnapi"
"github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto" "github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto"
"github.com/spf13/afero" "github.com/spf13/afero"
@ -112,13 +113,13 @@ func TestLegacyActivateCoordinator(t *testing.T) {
} }
// newMockCoreWithDialer creates a new core object with attestation mock and provided dialer for testing. // newMockCoreWithDialer creates a new core object with attestation mock and provided dialer for testing.
func newMockCoreWithDialer(dialer *bufconnDialer) (*Core, *pubapi.API, error) { func newMockCoreWithDialer(bufDialer *bufconnDialer) (*Core, *pubapi.API, error) {
zapLogger, err := zap.NewDevelopment() zapLogger, err := zap.NewDevelopment()
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
validator := NewMockValidator() dialer := grpcutil.NewDialer(NewMockValidator(), bufDialer)
vpn := &stubVPN{} vpn := &stubVPN{}
kubeFake := &ClusterFake{} kubeFake := &ClusterFake{}
metadataFake := &ProviderMetadataFake{} metadataFake := &ProviderMetadataFake{}
@ -138,8 +139,8 @@ func newMockCoreWithDialer(dialer *bufconnDialer) (*Core, *pubapi.API, error) {
return nil, nil, err return nil, nil, err
} }
vapiServer := &fakeVPNAPIServer{logger: zapLogger, core: core, dialer: dialer} vapiServer := &fakeVPNAPIServer{logger: zapLogger, core: core, dialer: bufDialer}
papi := pubapi.New(zapLogger, core, dialer, vapiServer, validator, getPublicAddr, nil) papi := pubapi.New(zapLogger, core, dialer, vapiServer, getPublicAddr, nil)
return core, papi, nil return core, papi, nil
} }

View file

@ -204,7 +204,7 @@ func (a *API) RequestStateDiskKey(ctx context.Context, in *pubproto.RequestState
return nil, status.Errorf(codes.Internal, "%v", err) return nil, status.Errorf(codes.Internal, "%v", err)
} }
conn, err := a.dial(ctx, peer) conn, err := a.dialer.Dial(ctx, peer)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "%v", err) return nil, status.Errorf(codes.Internal, "%v", err)
} }
@ -297,7 +297,7 @@ func (a *API) activateNode(nodePublicIP string, nodeVPNIP string, initialPeers [
ctx, cancel := context.WithTimeout(context.Background(), deadlineDuration) ctx, cancel := context.WithTimeout(context.Background(), deadlineDuration)
defer cancel() defer cancel()
conn, err := a.dial(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort)) conn, err := a.dialer.Dial(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -417,7 +417,7 @@ func (a *API) joinCluster(nodePublicIP string) error {
} }
// We don't verify the peer certificate here, since JoinCluster triggers a connection over VPN // We don't verify the peer certificate here, since JoinCluster triggers a connection over VPN
// The target of the rpc needs to already be part of the VPN to process the request, meaning it is trusted // The target of the rpc needs to already be part of the VPN to process the request, meaning it is trusted
conn, err := a.dialNoVerify(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort)) conn, err := a.dialer.DialNoVerify(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort))
if err != nil { if err != nil {
return err return err
} }
@ -455,7 +455,7 @@ func (a *API) triggerNodeUpdate(nodePublicIP string) error {
// We don't verify the peer certificate here, since TriggerNodeUpdate triggers a connection over VPN // We don't verify the peer certificate here, since TriggerNodeUpdate triggers a connection over VPN
// The target of the rpc needs to already be part of the VPN to process the request, meaning it is trusted // The target of the rpc needs to already be part of the VPN to process the request, meaning it is trusted
conn, err := a.dialNoVerify(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort)) conn, err := a.dialer.DialNoVerify(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort))
if err != nil { if err != nil {
return err return err
} }

View file

@ -18,6 +18,7 @@ import (
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/role" "github.com/edgelesssys/constellation/coordinator/role"
"github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/coordinator/state"
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
"github.com/edgelesssys/constellation/coordinator/util/testdialer" "github.com/edgelesssys/constellation/coordinator/util/testdialer"
"github.com/edgelesssys/constellation/state/keyservice/keyproto" "github.com/edgelesssys/constellation/state/keyservice/keyproto"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -125,13 +126,14 @@ func TestActivateAsCoordinator(t *testing.T) {
ownerID: []byte("ownerID"), ownerID: []byte("ownerID"),
clusterID: []byte("clusterID"), clusterID: []byte("clusterID"),
} }
dialer := testdialer.NewBufconnDialer() netDialer := testdialer.NewBufconnDialer()
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
getPublicIPAddr := func() (string, error) { getPublicIPAddr := func() (string, error) {
return "192.0.2.1", nil return "192.0.2.1", nil
} }
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil) api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, getPublicIPAddr, nil)
defer api.Close() defer api.Close()
// spawn nodes // spawn nodes
@ -142,7 +144,7 @@ func TestActivateAsCoordinator(t *testing.T) {
server := n.newServer() server := n.newServer()
wg.Add(1) wg.Add(1)
go func(endpoint string) { go func(endpoint string) {
listener := dialer.GetListener(endpoint) listener := netDialer.GetListener(endpoint)
wg.Done() wg.Done()
_ = server.Serve(listener) _ = server.Serve(listener)
}(net.JoinHostPort(n.peer.PublicIP, endpointAVPNPort)) }(net.JoinHostPort(n.peer.PublicIP, endpointAVPNPort))
@ -256,13 +258,14 @@ func TestActivateAdditionalNodes(t *testing.T) {
require := require.New(t) require := require.New(t)
core := &fakeCore{state: tc.state} core := &fakeCore{state: tc.state}
dialer := testdialer.NewBufconnDialer() netDialer := testdialer.NewBufconnDialer()
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
getPublicIPAddr := func() (string, error) { getPublicIPAddr := func() (string, error) {
return "192.0.2.1", nil return "192.0.2.1", nil
} }
api := New(zaptest.NewLogger(t), core, dialer, nil, fakeValidator{}, getPublicIPAddr, nil) api := New(zaptest.NewLogger(t), core, dialer, nil, getPublicIPAddr, nil)
defer api.Close() defer api.Close()
// spawn nodes // spawn nodes
var nodePublicIPs []string var nodePublicIPs []string
@ -272,7 +275,7 @@ func TestActivateAdditionalNodes(t *testing.T) {
server := n.newServer() server := n.newServer()
wg.Add(1) wg.Add(1)
go func(endpoint string) { go func(endpoint string) {
listener := dialer.GetListener(endpoint) listener := netDialer.GetListener(endpoint)
wg.Done() wg.Done()
_ = server.Serve(listener) _ = server.Serve(listener)
}(net.JoinHostPort(n.peer.PublicIP, endpointAVPNPort)) }(net.JoinHostPort(n.peer.PublicIP, endpointAVPNPort))
@ -311,7 +314,7 @@ func TestAssemblePeerStruct(t *testing.T) {
vpnPubKey := []byte{2, 3, 4} vpnPubKey := []byte{2, 3, 4}
core := &fakeCore{vpnPubKey: vpnPubKey} core := &fakeCore{vpnPubKey: vpnPubKey}
api := New(zaptest.NewLogger(t), core, nil, nil, nil, getPublicIPAddr, nil) api := New(zaptest.NewLogger(t), core, nil, nil, getPublicIPAddr, nil)
defer api.Close() defer api.Close()
vpnIP, err := core.GetVPNIP() vpnIP, err := core.GetVPNIP()
@ -512,7 +515,8 @@ func TestRequestStateDiskKey(t *testing.T) {
dataKey: tc.dataKey, dataKey: tc.dataKey,
getDataKeyErr: tc.getDataKeyErr, getDataKeyErr: tc.getDataKeyErr,
} }
api := New(zaptest.NewLogger(t), core, &net.Dialer{}, nil, dummyValidator{}, nil, getPeerFromContext)
api := New(zaptest.NewLogger(t), core, grpcutil.NewDialer(dummyValidator{}, &net.Dialer{}), nil, nil, getPeerFromContext)
_, err = api.RequestStateDiskKey(ctx, &pubproto.RequestStateDiskKeyRequest{}) _, err = api.RequestStateDiskKey(ctx, &pubproto.RequestStateDiskKeyRequest{})
if tc.wantErr { if tc.wantErr {

View file

@ -199,7 +199,7 @@ func (a *API) activateCoordinator(ctx context.Context, coordinatorIP string) err
return err return err
} }
conn, err := a.dial(ctx, net.JoinHostPort(coordinatorIP, endpointAVPNPort)) conn, err := a.dialer.Dial(ctx, net.JoinHostPort(coordinatorIP, endpointAVPNPort))
if err != nil { if err != nil {
return fmt.Errorf("dialing new coordinator: %v", err) return fmt.Errorf("dialing new coordinator: %v", err)
} }
@ -271,7 +271,7 @@ func (a *API) triggerCoordinatorUpdate(ctx context.Context, publicIP string) err
// We don't verify the peer certificate here, since TriggerNodeUpdate triggers a connection over VPN // We don't verify the peer certificate here, since TriggerNodeUpdate triggers a connection over VPN
// The target of the rpc needs to already be part of the VPN to process the request, meaning it is trusted // The target of the rpc needs to already be part of the VPN to process the request, meaning it is trusted
conn, err := a.dialNoVerify(ctx, net.JoinHostPort(publicIP, endpointAVPNPort)) conn, err := a.dialer.DialNoVerify(ctx, net.JoinHostPort(publicIP, endpointAVPNPort))
if err != nil { if err != nil {
return err return err
} }
@ -284,7 +284,7 @@ func (a *API) triggerCoordinatorUpdate(ctx context.Context, publicIP string) err
} }
func (a *API) getk8SCoordinatorJoinArgs(ctx context.Context, coordinatorIP, port string) (*kubeadm.BootstrapTokenDiscovery, string, error) { func (a *API) getk8SCoordinatorJoinArgs(ctx context.Context, coordinatorIP, port string) (*kubeadm.BootstrapTokenDiscovery, string, error) {
conn, err := a.dialInsecure(ctx, net.JoinHostPort(coordinatorIP, port)) conn, err := a.dialer.DialInsecure(ctx, net.JoinHostPort(coordinatorIP, port))
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View file

@ -10,6 +10,7 @@ import (
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/role" "github.com/edgelesssys/constellation/coordinator/role"
"github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/coordinator/state"
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
"github.com/edgelesssys/constellation/coordinator/util/testdialer" "github.com/edgelesssys/constellation/coordinator/util/testdialer"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -92,18 +93,19 @@ func TestActivateAsAdditionalCoordinator(t *testing.T) {
ownerID: []byte("ownerID"), ownerID: []byte("ownerID"),
clusterID: []byte("clusterID"), clusterID: []byte("clusterID"),
} }
dialer := testdialer.NewBufconnDialer() netDialer := testdialer.NewBufconnDialer()
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
getPublicIPAddr := func() (string, error) { getPublicIPAddr := func() (string, error) {
return "192.0.2.1", nil return "192.0.2.1", nil
} }
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil) api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, getPublicIPAddr, nil)
defer api.Close() defer api.Close()
// spawn vpnServer // spawn vpnServer
vpnapiServer := tc.vpnapi.newServer() vpnapiServer := tc.vpnapi.newServer()
go vpnapiServer.Serve(dialer.GetListener(net.JoinHostPort(tc.coordinators.peer.VPNIP, vpnAPIPort))) go vpnapiServer.Serve(netDialer.GetListener(net.JoinHostPort(tc.coordinators.peer.VPNIP, vpnAPIPort)))
defer vpnapiServer.GracefulStop() defer vpnapiServer.GracefulStop()
_, err := api.ActivateAsAdditionalCoordinator(context.Background(), &pubproto.ActivateAsAdditionalCoordinatorRequest{ _, err := api.ActivateAsAdditionalCoordinator(context.Background(), &pubproto.ActivateAsAdditionalCoordinatorRequest{
@ -163,9 +165,9 @@ func TestTriggerCoordinatorUpdate(t *testing.T) {
state: tc.state, state: tc.state,
peers: tc.peers, peers: tc.peers,
} }
dialer := testdialer.NewBufconnDialer() dialer := grpcutil.NewDialer(fakeValidator{}, nil)
api := New(logger, core, dialer, nil, nil, nil, nil) api := New(logger, core, dialer, nil, nil, nil)
_, err := api.TriggerCoordinatorUpdate(context.Background(), &pubproto.TriggerCoordinatorUpdateRequest{}) _, err := api.TriggerCoordinatorUpdate(context.Background(), &pubproto.TriggerCoordinatorUpdateRequest{})
if tc.wantErr { if tc.wantErr {
@ -236,20 +238,21 @@ func TestActivateAdditionalCoordinators(t *testing.T) {
ownerID: []byte("ownerID"), ownerID: []byte("ownerID"),
clusterID: []byte("clusterID"), clusterID: []byte("clusterID"),
} }
dialer := testdialer.NewBufconnDialer() netDialer := testdialer.NewBufconnDialer()
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
getPublicIPAddr := func() (string, error) { getPublicIPAddr := func() (string, error) {
return "192.0.2.1", nil return "192.0.2.1", nil
} }
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil) api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, getPublicIPAddr, nil)
defer api.Close() defer api.Close()
// spawn coordinator // spawn coordinator
tc.coordinators.activateErr = tc.activateErr tc.coordinators.activateErr = tc.activateErr
tc.coordinators.getPubKeyErr = tc.getPublicKeyErr tc.coordinators.getPubKeyErr = tc.getPublicKeyErr
server := tc.coordinators.newServer() server := tc.coordinators.newServer()
go server.Serve(dialer.GetListener(net.JoinHostPort(tc.coordinators.peer.PublicIP, endpointAVPNPort))) go server.Serve(netDialer.GetListener(net.JoinHostPort(tc.coordinators.peer.PublicIP, endpointAVPNPort)))
defer server.GracefulStop() defer server.GracefulStop()
_, err := api.ActivateAdditionalCoordinator(context.Background(), &pubproto.ActivateAdditionalCoordinatorRequest{CoordinatorPublicIp: tc.coordinators.peer.PublicIP}) _, err := api.ActivateAdditionalCoordinator(context.Background(), &pubproto.ActivateAdditionalCoordinatorRequest{CoordinatorPublicIp: tc.coordinators.peer.PublicIP})
@ -293,13 +296,13 @@ func TestGetPeerVPNPublicKey(t *testing.T) {
vpnPubKey: tc.coordinator.peer.VPNPubKey, vpnPubKey: tc.coordinator.peer.VPNPubKey,
getvpnPubKeyErr: tc.getVPNPubKeyErr, getvpnPubKeyErr: tc.getVPNPubKeyErr,
} }
dialer := testdialer.NewBufconnDialer() dialer := grpcutil.NewDialer(fakeValidator{}, testdialer.NewBufconnDialer())
getPublicIPAddr := func() (string, error) { getPublicIPAddr := func() (string, error) {
return "192.0.2.1", nil return "192.0.2.1", nil
} }
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil) api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, getPublicIPAddr, nil)
defer api.Close() defer api.Close()
resp, err := api.GetPeerVPNPublicKey(context.Background(), &pubproto.GetPeerVPNPublicKeyRequest{}) resp, err := api.GetPeerVPNPublicKey(context.Background(), &pubproto.GetPeerVPNPublicKeyRequest{})

View file

@ -168,7 +168,7 @@ func (a *API) JoinCluster(ctx context.Context, in *pubproto.JoinClusterRequest)
return nil, status.Errorf(codes.FailedPrecondition, "node is not in required state for cluster join: %v", err) return nil, status.Errorf(codes.FailedPrecondition, "node is not in required state for cluster join: %v", err)
} }
conn, err := a.dialInsecure(ctx, net.JoinHostPort(in.CoordinatorVpnIp, vpnAPIPort)) conn, err := a.dialer.DialInsecure(ctx, net.JoinHostPort(in.CoordinatorVpnIp, vpnAPIPort))
if err != nil { if err != nil {
return nil, status.Errorf(codes.Unavailable, "dial coordinator: %v", err) return nil, status.Errorf(codes.Unavailable, "dial coordinator: %v", err)
} }
@ -231,7 +231,7 @@ func (a *API) update(ctx context.Context) error {
defer cancel() defer cancel()
// TODO: replace hardcoded IP // TODO: replace hardcoded IP
conn, err := a.dialInsecure(ctx, net.JoinHostPort("10.118.0.1", vpnAPIPort)) conn, err := a.dialer.DialInsecure(ctx, net.JoinHostPort("10.118.0.1", vpnAPIPort))
if err != nil { if err != nil {
return err return err
} }

View file

@ -13,6 +13,7 @@ import (
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/role" "github.com/edgelesssys/constellation/coordinator/role"
"github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/coordinator/state"
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
"github.com/edgelesssys/constellation/coordinator/util/testdialer" "github.com/edgelesssys/constellation/coordinator/util/testdialer"
"github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto" "github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -127,25 +128,26 @@ func TestActivateAsNode(t *testing.T) {
logger := zaptest.NewLogger(t) logger := zaptest.NewLogger(t)
cor := &fakeCore{state: tc.state, vpnPubKey: vpnPubKey, setVPNIPErr: tc.setVPNIPErr} cor := &fakeCore{state: tc.state, vpnPubKey: vpnPubKey, setVPNIPErr: tc.setVPNIPErr}
dialer := testdialer.NewBufconnDialer() netDialer := testdialer.NewBufconnDialer()
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
api := New(logger, cor, dialer, nil, nil, nil, nil) api := New(logger, cor, dialer, nil, nil, nil)
defer api.Close() defer api.Close()
vserver := grpc.NewServer() vserver := grpc.NewServer()
vapi := &stubVPNAPI{peers: tc.updatedPeers, getUpdateErr: tc.getUpdateErr} vapi := &stubVPNAPI{peers: tc.updatedPeers, getUpdateErr: tc.getUpdateErr}
vpnproto.RegisterAPIServer(vserver, vapi) vpnproto.RegisterAPIServer(vserver, vapi)
go vserver.Serve(dialer.GetListener(net.JoinHostPort("10.118.0.1", vpnAPIPort))) go vserver.Serve(netDialer.GetListener(net.JoinHostPort("10.118.0.1", vpnAPIPort)))
defer vserver.GracefulStop() defer vserver.GracefulStop()
tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{}) tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{})
require.NoError(err) require.NoError(err)
pubserver := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) pubserver := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
pubproto.RegisterAPIServer(pubserver, api) pubproto.RegisterAPIServer(pubserver, api)
go pubserver.Serve(dialer.GetListener(net.JoinHostPort(nodeIP, endpointAVPNPort))) go pubserver.Serve(netDialer.GetListener(net.JoinHostPort(nodeIP, endpointAVPNPort)))
defer pubserver.GracefulStop() defer pubserver.GracefulStop()
_, nodeVPNPubKey, err := activateNode(require, dialer, messageSequence, nodeIP, "9000", nodeVPNIP, peer.ToPubProto(tc.initialPeers), ownerID, clusterID, stateDiskKey) _, nodeVPNPubKey, err := activateNode(require, netDialer, messageSequence, nodeIP, "9000", nodeVPNIP, peer.ToPubProto(tc.initialPeers), ownerID, clusterID, stateDiskKey)
assert.Equal(tc.wantState, cor.state) assert.Equal(tc.wantState, cor.state)
if tc.wantErr { if tc.wantErr {
@ -215,9 +217,10 @@ func TestTriggerNodeUpdate(t *testing.T) {
logger := zaptest.NewLogger(t) logger := zaptest.NewLogger(t)
core := &fakeCore{state: tc.state} core := &fakeCore{state: tc.state}
dialer := testdialer.NewBufconnDialer() netDialer := testdialer.NewBufconnDialer()
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
api := New(logger, core, dialer, nil, nil, nil, nil) api := New(logger, core, dialer, nil, nil, nil)
vserver := grpc.NewServer() vserver := grpc.NewServer()
vapi := &stubVPNAPI{ vapi := &stubVPNAPI{
@ -225,7 +228,7 @@ func TestTriggerNodeUpdate(t *testing.T) {
getUpdateErr: tc.getUpdateErr, getUpdateErr: tc.getUpdateErr,
} }
vpnproto.RegisterAPIServer(vserver, vapi) vpnproto.RegisterAPIServer(vserver, vapi)
go vserver.Serve(dialer.GetListener(net.JoinHostPort("10.118.0.1", vpnAPIPort))) go vserver.Serve(netDialer.GetListener(net.JoinHostPort("10.118.0.1", vpnAPIPort)))
defer vserver.GracefulStop() defer vserver.GracefulStop()
_, err := api.TriggerNodeUpdate(context.Background(), &pubproto.TriggerNodeUpdateRequest{}) _, err := api.TriggerNodeUpdate(context.Background(), &pubproto.TriggerNodeUpdateRequest{})
@ -290,9 +293,10 @@ func TestJoinCluster(t *testing.T) {
logger := zaptest.NewLogger(t) logger := zaptest.NewLogger(t)
core := &fakeCore{state: tc.state, joinClusterErr: tc.joinClusterErr} core := &fakeCore{state: tc.state, joinClusterErr: tc.joinClusterErr}
dialer := testdialer.NewBufconnDialer() netDialer := testdialer.NewBufconnDialer()
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
api := New(logger, core, dialer, nil, nil, nil, nil) api := New(logger, core, dialer, nil, nil, nil)
vserver := grpc.NewServer() vserver := grpc.NewServer()
vapi := &stubVPNAPI{ vapi := &stubVPNAPI{
@ -304,7 +308,7 @@ func TestJoinCluster(t *testing.T) {
getJoinArgsErr: tc.getJoinArgsErr, getJoinArgsErr: tc.getJoinArgsErr,
} }
vpnproto.RegisterAPIServer(vserver, vapi) vpnproto.RegisterAPIServer(vserver, vapi)
go vserver.Serve(dialer.GetListener(net.JoinHostPort("192.0.2.1", vpnAPIPort))) go vserver.Serve(netDialer.GetListener(net.JoinHostPort("192.0.2.1", vpnAPIPort)))
defer vserver.GracefulStop() defer vserver.GracefulStop()
_, err := api.JoinCluster(context.Background(), &pubproto.JoinClusterRequest{CoordinatorVpnIp: "192.0.2.1"}) _, err := api.JoinCluster(context.Background(), &pubproto.JoinClusterRequest{CoordinatorVpnIp: "192.0.2.1"})
@ -322,7 +326,7 @@ func TestJoinCluster(t *testing.T) {
} }
} }
func activateNode(require *require.Assertions, dialer Dialer, messageSequence []string, nodeIP, bindPort, nodeVPNIP string, peers []*pubproto.Peer, ownerID, clusterID, stateDiskKey []byte) (string, []byte, error) { func activateNode(require *require.Assertions, dialer netDialer, messageSequence []string, nodeIP, bindPort, nodeVPNIP string, peers []*pubproto.Peer, ownerID, clusterID, stateDiskKey []byte) (string, []byte, error) {
ctx := context.Background() ctx := context.Background()
conn, err := dialGRPC(ctx, dialer, net.JoinHostPort(nodeIP, bindPort)) conn, err := dialGRPC(ctx, dialer, net.JoinHostPort(nodeIP, bindPort))
require.NoError(err) require.NoError(err)
@ -385,7 +389,7 @@ func activateNode(require *require.Assertions, dialer Dialer, messageSequence []
return diskUUID, nodeVPNPubKey, nil return diskUUID, nodeVPNPubKey, nil
} }
func dialGRPC(ctx context.Context, dialer Dialer, target string) (*grpc.ClientConn, error) { func dialGRPC(ctx context.Context, dialer netDialer, target string) (*grpc.ClientConn, error) {
tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{&core.MockValidator{}}) tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{&core.MockValidator{}})
if err != nil { if err != nil {
return nil, err return nil, err
@ -429,3 +433,7 @@ func (a *stubVPNAPI) newServer() *grpc.Server {
vpnproto.RegisterAPIServer(server, a) vpnproto.RegisterAPIServer(server, a)
return server return server
} }
type netDialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}

View file

@ -8,13 +8,10 @@ import (
"sync" "sync"
"time" "time"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/state/setup" "github.com/edgelesssys/constellation/state/setup"
"go.uber.org/zap" "go.uber.org/zap"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/peer" "google.golang.org/grpc/peer"
) )
@ -32,7 +29,6 @@ type API struct {
core Core core Core
dialer Dialer dialer Dialer
vpnAPIServer VPNAPIServer vpnAPIServer VPNAPIServer
validator atls.Validator
getPublicIPAddr GetIPAddrFunc getPublicIPAddr GetIPAddrFunc
stopUpdate chan struct{} stopUpdate chan struct{}
wgClose sync.WaitGroup wgClose sync.WaitGroup
@ -42,13 +38,12 @@ type API struct {
} }
// New creates a new API. // New creates a new API.
func New(logger *zap.Logger, core Core, dialer Dialer, vpnAPIServer VPNAPIServer, validator atls.Validator, getPublicIPAddr GetIPAddrFunc, peerFromContext PeerFromContextFunc) *API { func New(logger *zap.Logger, core Core, dialer Dialer, vpnAPIServer VPNAPIServer, getPublicIPAddr GetIPAddrFunc, peerFromContext PeerFromContextFunc) *API {
return &API{ return &API{
logger: logger, logger: logger,
core: core, core: core,
dialer: dialer, dialer: dialer,
vpnAPIServer: vpnAPIServer, vpnAPIServer: vpnAPIServer,
validator: validator,
getPublicIPAddr: getPublicIPAddr, getPublicIPAddr: getPublicIPAddr,
stopUpdate: make(chan struct{}, 1), stopUpdate: make(chan struct{}, 1),
peerFromContext: peerFromContext, peerFromContext: peerFromContext,
@ -69,47 +64,6 @@ func (a *API) Close() {
a.wgClose.Wait() a.wgClose.Wait()
} }
func (a *API) dial(ctx context.Context, target string) (*grpc.ClientConn, error) {
tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{a.validator})
if err != nil {
return nil, err
}
return grpc.DialContext(ctx, target,
a.grpcWithDialer(),
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
)
}
func (a *API) dialInsecure(ctx context.Context, target string) (*grpc.ClientConn, error) {
return grpc.DialContext(ctx, target,
a.grpcWithDialer(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
}
func (a *API) dialNoVerify(ctx context.Context, target string) (*grpc.ClientConn, error) {
tlsConfig, err := atls.CreateUnverifiedClientTLSConfig()
if err != nil {
return nil, err
}
return grpc.DialContext(ctx, target,
a.grpcWithDialer(),
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
)
}
func (a *API) grpcWithDialer() grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
return a.dialer.DialContext(ctx, "tcp", addr)
})
}
type Dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
type VPNAPIServer interface { type VPNAPIServer interface {
Listen(endpoint string) error Listen(endpoint string) error
Serve() error Serve() error
@ -135,3 +89,10 @@ func GetRecoveryPeerFromContext(ctx context.Context) (string, error) {
return net.JoinHostPort(peerIP, setup.RecoveryPort), nil return net.JoinHostPort(peerIP, setup.RecoveryPort), nil
} }
// Dialer can open grpc client connections with different levels of ATLS encryption / verification.
type Dialer interface {
Dial(ctx context.Context, target string) (*grpc.ClientConn, error)
DialInsecure(ctx context.Context, target string) (*grpc.ClientConn, error)
DialNoVerify(ctx context.Context, target string) (*grpc.ClientConn, error)
}

View file

@ -0,0 +1,71 @@
package grpcutil
import (
"context"
"net"
"github.com/edgelesssys/constellation/coordinator/atls"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
)
// Dialer can open grpc client connections with different levels of ATLS encryption / verification.
type Dialer struct {
validator atls.Validator
netDialer NetDialer
}
// NewDialer creates a new Dialer.
func NewDialer(validator atls.Validator, netDialer NetDialer) *Dialer {
return &Dialer{
validator: validator,
netDialer: netDialer,
}
}
// Dial creates a new grpc client connection to the given target using the atls validator.
func (d *Dialer) Dial(ctx context.Context, target string) (*grpc.ClientConn, error) {
tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{d.validator})
if err != nil {
return nil, err
}
return grpc.DialContext(ctx, target,
d.grpcWithDialer(),
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
)
}
// DialInsecure creates a new grpc client connection to the given target without using encryption or verification.
// Only use this method when using another kind of encryption / verification (VPN, etc).
func (d *Dialer) DialInsecure(ctx context.Context, target string) (*grpc.ClientConn, error) {
return grpc.DialContext(ctx, target,
d.grpcWithDialer(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
}
// DialNoVerify creates a new grpc client connection to the given target without verifying the server's attestation.
func (d *Dialer) DialNoVerify(ctx context.Context, target string) (*grpc.ClientConn, error) {
tlsConfig, err := atls.CreateUnverifiedClientTLSConfig()
if err != nil {
return nil, err
}
return grpc.DialContext(ctx, target,
d.grpcWithDialer(),
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
)
}
func (d *Dialer) grpcWithDialer() grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
return d.netDialer.DialContext(ctx, "tcp", addr)
})
}
// NetDialer implements the net Dialer interface.
type NetDialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}

View file

@ -0,0 +1,106 @@
package grpcutil
import (
"context"
"testing"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/core"
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/test/grpc_testing"
)
func TestDial(t *testing.T) {
testCases := map[string]struct {
tls bool
dialFn func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error)
wantErr bool
}{
"Dial with tls on server works": {
tls: true,
dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) {
return dialer.Dial(ctx, target)
},
},
"Dial without tls on server fails": {
dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) {
return dialer.Dial(ctx, target)
},
wantErr: true,
},
"DialNoVerify with tls on server works": {
tls: true,
dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) {
return dialer.DialNoVerify(ctx, target)
},
},
"DialNoVerify without tls on server fails": {
dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) {
return dialer.DialNoVerify(ctx, target)
},
wantErr: true,
},
"DialInsecure without tls on server works": {
dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) {
return dialer.DialInsecure(ctx, target)
},
},
"DialInsecure with tls on server fails": {
tls: true,
dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) {
return dialer.DialInsecure(ctx, target)
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
netDialer := testdialer.NewBufconnDialer()
dialer := NewDialer(&core.MockValidator{}, netDialer)
server := newServer(tc.tls)
api := &testAPI{}
grpc_testing.RegisterTestServiceServer(server, api)
go server.Serve(netDialer.GetListener("192.0.2.1:1234"))
defer server.Stop()
conn, err := tc.dialFn(dialer, context.Background(), "192.0.2.1:1234")
require.NoError(err)
defer conn.Close()
client := grpc_testing.NewTestServiceClient(conn)
_, err = client.EmptyCall(context.Background(), &grpc_testing.Empty{})
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
})
}
}
func newServer(tls bool) *grpc.Server {
if tls {
tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{})
if err != nil {
panic(err)
}
return grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
}
return grpc.NewServer()
}
type testAPI struct {
grpc_testing.UnimplementedTestServiceServer
}
func (s *testAPI) EmptyCall(ctx context.Context, in *grpc_testing.Empty) (*grpc_testing.Empty, error) {
return &grpc_testing.Empty{}, nil
}

View file

@ -16,14 +16,14 @@ import (
// Download downloads a coordinator from a given debugd instance. // Download downloads a coordinator from a given debugd instance.
type Download struct { type Download struct {
dialer Dialer dialer NetDialer
writer streamToFileWriter writer streamToFileWriter
serviceManager serviceManager serviceManager serviceManager
attemptedDownloads map[string]time.Time attemptedDownloads map[string]time.Time
} }
// New creates a new Download. // New creates a new Download.
func New(dialer Dialer, serviceManager serviceManager, writer streamToFileWriter) *Download { func New(dialer NetDialer, serviceManager serviceManager, writer streamToFileWriter) *Download {
return &Download{ return &Download{
dialer: dialer, dialer: dialer,
writer: writer, writer: writer,
@ -91,6 +91,7 @@ type streamToFileWriter interface {
WriteStream(filename string, stream coordinator.ReadChunkStream, showProgress bool) error WriteStream(filename string, stream coordinator.ReadChunkStream, showProgress bool) error
} }
type Dialer interface { // NetDialer can open a net.Conn.
type NetDialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error) DialContext(ctx context.Context, network, address string) (net.Conn, error)
} }

View file

@ -9,10 +9,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
"github.com/edgelesssys/constellation/debugd/coordinator" "github.com/edgelesssys/constellation/debugd/coordinator"
"github.com/edgelesssys/constellation/debugd/debugd" "github.com/edgelesssys/constellation/debugd/debugd"
pb "github.com/edgelesssys/constellation/debugd/service" pb "github.com/edgelesssys/constellation/debugd/service"
"github.com/edgelesssys/constellation/debugd/service/testdialer"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"

View file

@ -8,10 +8,10 @@ import (
"net" "net"
"testing" "testing"
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
"github.com/edgelesssys/constellation/debugd/coordinator" "github.com/edgelesssys/constellation/debugd/coordinator"
"github.com/edgelesssys/constellation/debugd/debugd/deploy" "github.com/edgelesssys/constellation/debugd/debugd/deploy"
pb "github.com/edgelesssys/constellation/debugd/service" pb "github.com/edgelesssys/constellation/debugd/service"
"github.com/edgelesssys/constellation/debugd/service/testdialer"
"github.com/edgelesssys/constellation/debugd/ssh" "github.com/edgelesssys/constellation/debugd/ssh"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -351,11 +351,11 @@ func (s *stubServiceManager) WriteSystemdUnitFile(ctx context.Context, unit depl
return s.writeSystemdUnitFileErr return s.writeSystemdUnitFileErr
} }
type dialer interface { type netDialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error) DialContext(ctx context.Context, network, address string) (net.Conn, error)
} }
func dial(ctx context.Context, dialer dialer, target string) (*grpc.ClientConn, error) { func dial(ctx context.Context, dialer netDialer, target string) (*grpc.ClientConn, error) {
return grpc.DialContext(ctx, target, return grpc.DialContext(ctx, target,
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, "tcp", addr) return dialer.DialContext(ctx, "tcp", addr)

View file

@ -1,41 +0,0 @@
package testdialer
import (
"context"
"fmt"
"net"
"sync"
"google.golang.org/grpc/test/bufconn"
)
// BufconnDialer is a fake dialer based on gRPC bufconn package.
type BufconnDialer struct {
mut sync.Mutex
listeners map[string]*bufconn.Listener
}
// NewBufconnDialer creates a new bufconn dialer for testing.
func NewBufconnDialer() *BufconnDialer {
return &BufconnDialer{listeners: make(map[string]*bufconn.Listener)}
}
// DialContext implements the Dialer interface.
func (b *BufconnDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
b.mut.Lock()
listener, ok := b.listeners[address]
b.mut.Unlock()
if !ok {
return nil, fmt.Errorf("could not connect to server on %v", address)
}
return listener.DialContext(ctx)
}
// GetListener returns a fake listener that is coupled with this dialer.
func (b *BufconnDialer) GetListener(endpoint string) net.Listener {
listener := bufconn.Listen(1024)
b.mut.Lock()
b.listeners[endpoint] = listener
b.mut.Unlock()
return listener
}