mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-01-23 05:41:19 -05:00
extract shared grpcutil dialer from pubapi
Signed-off-by: Malte Poll <mp@edgeless.systems>
This commit is contained in:
parent
5ac72c730d
commit
77b0237dd5
@ -23,6 +23,7 @@ import (
|
||||
"github.com/edgelesssys/constellation/coordinator/kubernetes/k8sapi"
|
||||
"github.com/edgelesssys/constellation/coordinator/kubernetes/k8sapi/kubectl"
|
||||
"github.com/edgelesssys/constellation/coordinator/util"
|
||||
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
|
||||
"github.com/edgelesssys/constellation/coordinator/wireguard"
|
||||
grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap"
|
||||
"github.com/spf13/afero"
|
||||
@ -153,7 +154,8 @@ func main() {
|
||||
}
|
||||
|
||||
fileHandler := file.NewHandler(fs)
|
||||
dialer := &net.Dialer{}
|
||||
run(validator, issuer, wg, openTPM, util.GetIPAddr, dialer, fileHandler, kube,
|
||||
netDialer := &net.Dialer{}
|
||||
dialer := grpcutil.NewDialer(validator, netDialer)
|
||||
run(issuer, wg, openTPM, util.GetIPAddr, dialer, fileHandler, kube,
|
||||
metadata, cloudControllerManager, cloudNodeManager, autoscaler, encryptedDisk, etcdEndpoint, enforceEtcdTls, bindIP, bindPort, zapLoggerCore)
|
||||
}
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
"github.com/edgelesssys/constellation/coordinator/pubapi"
|
||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||
"github.com/edgelesssys/constellation/coordinator/store"
|
||||
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
|
||||
"github.com/edgelesssys/constellation/coordinator/vpnapi"
|
||||
"github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto"
|
||||
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
||||
@ -26,7 +27,7 @@ import (
|
||||
|
||||
var version = "0.0.0"
|
||||
|
||||
func run(validator core.QuoteValidator, issuer core.QuoteIssuer, vpn core.VPN, openTPM vtpm.TPMOpenFunc, getPublicIPAddr func() (string, error), dialer pubapi.Dialer, fileHandler file.Handler,
|
||||
func run(issuer core.QuoteIssuer, vpn core.VPN, openTPM vtpm.TPMOpenFunc, getPublicIPAddr func() (string, error), dialer *grpcutil.Dialer, fileHandler file.Handler,
|
||||
kube core.Cluster, metadata core.ProviderMetadata, cloudControllerManager core.CloudControllerManager, cloudNodeManager core.CloudNodeManager, clusterAutoscaler core.ClusterAutoscaler, encryptedDisk core.EncryptedDisk, etcdEndpoint string, etcdTLS bool, bindIP, bindPort string, zapLoggerCore *zap.Logger,
|
||||
) {
|
||||
defer zapLoggerCore.Sync()
|
||||
@ -46,16 +47,16 @@ func run(validator core.QuoteValidator, issuer core.QuoteIssuer, vpn core.VPN, o
|
||||
if err != nil {
|
||||
zapLoggerCore.Fatal("failed to create core", zap.Error(err))
|
||||
}
|
||||
|
||||
vapiServer := &vpnAPIServer{logger: zapLoggerCore.Named("vpnapi"), core: core}
|
||||
zapLoggerPubapi := zapLoggerCore.Named("pubapi")
|
||||
papi := pubapi.New(zapLoggerPubapi, core, dialer, vapiServer, getPublicIPAddr, pubapi.GetRecoveryPeerFromContext)
|
||||
// initialize state machine and wait for re-joining of the VPN (if applicable)
|
||||
nodeActivated, err := core.Initialize()
|
||||
if err != nil {
|
||||
zapLoggerCore.Fatal("failed to initialize core", zap.Error(err))
|
||||
}
|
||||
|
||||
vapiServer := &vpnAPIServer{logger: zapLoggerCore.Named("vpnapi"), core: core}
|
||||
zapLoggerPubapi := zapLoggerCore.Named("pubapi")
|
||||
papi := pubapi.New(zapLoggerPubapi, core, dialer, vapiServer, validator, getPublicIPAddr, pubapi.GetRecoveryPeerFromContext)
|
||||
|
||||
zapLoggergRPC := zapLoggerPubapi.Named("gRPC")
|
||||
|
||||
grpcServer := grpc.NewServer(
|
||||
|
@ -18,6 +18,7 @@ import (
|
||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||
"github.com/edgelesssys/constellation/coordinator/state"
|
||||
"github.com/edgelesssys/constellation/coordinator/store"
|
||||
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
|
||||
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
|
||||
"github.com/edgelesssys/constellation/coordinator/vpnapi"
|
||||
"github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto"
|
||||
@ -210,7 +211,7 @@ func TestConcurrent(t *testing.T) {
|
||||
assert.Error(<-actCoordErrs)
|
||||
}
|
||||
|
||||
func spawnPeer(require *require.Assertions, logger *zap.Logger, dialer *testdialer.BufconnDialer, netw *network, endpoint string) (*grpc.Server, *pubapi.API, *fakeVPN) {
|
||||
func spawnPeer(require *require.Assertions, logger *zap.Logger, netDialer *testdialer.BufconnDialer, netw *network, endpoint string) (*grpc.Server, *pubapi.API, *fakeVPN) {
|
||||
vpn := newVPN(netw, endpoint)
|
||||
cor, err := core.NewCore(vpn, &core.ClusterFake{}, &core.ProviderMetadataFake{}, &core.CloudControllerManagerFake{}, &core.CloudNodeManagerFake{}, &core.ClusterAutoscalerFake{}, &core.EncryptedDiskFake{}, logger, simulator.OpenSimulatedTPM, fakeStoreFactory{}, file.NewHandler(afero.NewMemMapFs()))
|
||||
require.NoError(err)
|
||||
@ -219,22 +220,23 @@ func spawnPeer(require *require.Assertions, logger *zap.Logger, dialer *testdial
|
||||
getPublicAddr := func() (string, error) {
|
||||
return "192.0.2.1", nil
|
||||
}
|
||||
dialer := grpcutil.NewDialer(&core.MockValidator{}, netDialer)
|
||||
vapiServer := &fakeVPNAPIServer{logger: logger.Named("vpnapi"), core: cor, dialer: netDialer}
|
||||
|
||||
vapiServer := &fakeVPNAPIServer{logger: logger.Named("vpnapi"), core: cor, dialer: dialer}
|
||||
papi := pubapi.New(logger, cor, dialer, vapiServer, &core.MockValidator{}, getPublicAddr, nil)
|
||||
papi := pubapi.New(logger, cor, dialer, vapiServer, getPublicAddr, nil)
|
||||
|
||||
tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{})
|
||||
require.NoError(err)
|
||||
server := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
|
||||
pubproto.RegisterAPIServer(server, papi)
|
||||
|
||||
listener := dialer.GetListener(endpoint)
|
||||
listener := netDialer.GetListener(endpoint)
|
||||
go server.Serve(listener)
|
||||
|
||||
return server, papi, vpn
|
||||
}
|
||||
|
||||
func activateCoordinator(require *require.Assertions, dialer pubapi.Dialer, coordinatorIP, bindPort string, nodeIPs []string) error {
|
||||
func activateCoordinator(require *require.Assertions, dialer netDialer, coordinatorIP, bindPort string, nodeIPs []string) error {
|
||||
ctx := context.Background()
|
||||
conn, err := dialGRPC(ctx, dialer, net.JoinHostPort(coordinatorIP, bindPort))
|
||||
require.NoError(err)
|
||||
@ -260,7 +262,7 @@ func activateCoordinator(require *require.Assertions, dialer pubapi.Dialer, coor
|
||||
}
|
||||
}
|
||||
|
||||
func dialGRPC(ctx context.Context, dialer pubapi.Dialer, target string) (*grpc.ClientConn, error) {
|
||||
func dialGRPC(ctx context.Context, dialer netDialer, target string) (*grpc.ClientConn, error) {
|
||||
tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{&core.MockValidator{}})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -398,3 +400,7 @@ func (v *fakeVPN) recv() *packet {
|
||||
}
|
||||
return &packet
|
||||
}
|
||||
|
||||
type netDialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
@ -16,6 +16,7 @@ import (
|
||||
"github.com/edgelesssys/constellation/coordinator/pubapi"
|
||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||
"github.com/edgelesssys/constellation/coordinator/state"
|
||||
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
|
||||
"github.com/edgelesssys/constellation/coordinator/vpnapi"
|
||||
"github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto"
|
||||
"github.com/spf13/afero"
|
||||
@ -112,13 +113,13 @@ func TestLegacyActivateCoordinator(t *testing.T) {
|
||||
}
|
||||
|
||||
// newMockCoreWithDialer creates a new core object with attestation mock and provided dialer for testing.
|
||||
func newMockCoreWithDialer(dialer *bufconnDialer) (*Core, *pubapi.API, error) {
|
||||
func newMockCoreWithDialer(bufDialer *bufconnDialer) (*Core, *pubapi.API, error) {
|
||||
zapLogger, err := zap.NewDevelopment()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
validator := NewMockValidator()
|
||||
dialer := grpcutil.NewDialer(NewMockValidator(), bufDialer)
|
||||
vpn := &stubVPN{}
|
||||
kubeFake := &ClusterFake{}
|
||||
metadataFake := &ProviderMetadataFake{}
|
||||
@ -138,8 +139,8 @@ func newMockCoreWithDialer(dialer *bufconnDialer) (*Core, *pubapi.API, error) {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
vapiServer := &fakeVPNAPIServer{logger: zapLogger, core: core, dialer: dialer}
|
||||
papi := pubapi.New(zapLogger, core, dialer, vapiServer, validator, getPublicAddr, nil)
|
||||
vapiServer := &fakeVPNAPIServer{logger: zapLogger, core: core, dialer: bufDialer}
|
||||
papi := pubapi.New(zapLogger, core, dialer, vapiServer, getPublicAddr, nil)
|
||||
|
||||
return core, papi, nil
|
||||
}
|
||||
|
@ -204,7 +204,7 @@ func (a *API) RequestStateDiskKey(ctx context.Context, in *pubproto.RequestState
|
||||
return nil, status.Errorf(codes.Internal, "%v", err)
|
||||
}
|
||||
|
||||
conn, err := a.dial(ctx, peer)
|
||||
conn, err := a.dialer.Dial(ctx, peer)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "%v", err)
|
||||
}
|
||||
@ -297,7 +297,7 @@ func (a *API) activateNode(nodePublicIP string, nodeVPNIP string, initialPeers [
|
||||
ctx, cancel := context.WithTimeout(context.Background(), deadlineDuration)
|
||||
defer cancel()
|
||||
|
||||
conn, err := a.dial(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort))
|
||||
conn, err := a.dialer.Dial(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -417,7 +417,7 @@ func (a *API) joinCluster(nodePublicIP string) error {
|
||||
}
|
||||
// We don't verify the peer certificate here, since JoinCluster triggers a connection over VPN
|
||||
// The target of the rpc needs to already be part of the VPN to process the request, meaning it is trusted
|
||||
conn, err := a.dialNoVerify(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort))
|
||||
conn, err := a.dialer.DialNoVerify(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -455,7 +455,7 @@ func (a *API) triggerNodeUpdate(nodePublicIP string) error {
|
||||
|
||||
// We don't verify the peer certificate here, since TriggerNodeUpdate triggers a connection over VPN
|
||||
// The target of the rpc needs to already be part of the VPN to process the request, meaning it is trusted
|
||||
conn, err := a.dialNoVerify(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort))
|
||||
conn, err := a.dialer.DialNoVerify(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -18,6 +18,7 @@ import (
|
||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||
"github.com/edgelesssys/constellation/coordinator/role"
|
||||
"github.com/edgelesssys/constellation/coordinator/state"
|
||||
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
|
||||
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
|
||||
"github.com/edgelesssys/constellation/state/keyservice/keyproto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -125,13 +126,14 @@ func TestActivateAsCoordinator(t *testing.T) {
|
||||
ownerID: []byte("ownerID"),
|
||||
clusterID: []byte("clusterID"),
|
||||
}
|
||||
dialer := testdialer.NewBufconnDialer()
|
||||
netDialer := testdialer.NewBufconnDialer()
|
||||
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
|
||||
|
||||
getPublicIPAddr := func() (string, error) {
|
||||
return "192.0.2.1", nil
|
||||
}
|
||||
|
||||
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil)
|
||||
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, getPublicIPAddr, nil)
|
||||
defer api.Close()
|
||||
|
||||
// spawn nodes
|
||||
@ -142,7 +144,7 @@ func TestActivateAsCoordinator(t *testing.T) {
|
||||
server := n.newServer()
|
||||
wg.Add(1)
|
||||
go func(endpoint string) {
|
||||
listener := dialer.GetListener(endpoint)
|
||||
listener := netDialer.GetListener(endpoint)
|
||||
wg.Done()
|
||||
_ = server.Serve(listener)
|
||||
}(net.JoinHostPort(n.peer.PublicIP, endpointAVPNPort))
|
||||
@ -256,13 +258,14 @@ func TestActivateAdditionalNodes(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
core := &fakeCore{state: tc.state}
|
||||
dialer := testdialer.NewBufconnDialer()
|
||||
netDialer := testdialer.NewBufconnDialer()
|
||||
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
|
||||
|
||||
getPublicIPAddr := func() (string, error) {
|
||||
return "192.0.2.1", nil
|
||||
}
|
||||
|
||||
api := New(zaptest.NewLogger(t), core, dialer, nil, fakeValidator{}, getPublicIPAddr, nil)
|
||||
api := New(zaptest.NewLogger(t), core, dialer, nil, getPublicIPAddr, nil)
|
||||
defer api.Close()
|
||||
// spawn nodes
|
||||
var nodePublicIPs []string
|
||||
@ -272,7 +275,7 @@ func TestActivateAdditionalNodes(t *testing.T) {
|
||||
server := n.newServer()
|
||||
wg.Add(1)
|
||||
go func(endpoint string) {
|
||||
listener := dialer.GetListener(endpoint)
|
||||
listener := netDialer.GetListener(endpoint)
|
||||
wg.Done()
|
||||
_ = server.Serve(listener)
|
||||
}(net.JoinHostPort(n.peer.PublicIP, endpointAVPNPort))
|
||||
@ -311,7 +314,7 @@ func TestAssemblePeerStruct(t *testing.T) {
|
||||
|
||||
vpnPubKey := []byte{2, 3, 4}
|
||||
core := &fakeCore{vpnPubKey: vpnPubKey}
|
||||
api := New(zaptest.NewLogger(t), core, nil, nil, nil, getPublicIPAddr, nil)
|
||||
api := New(zaptest.NewLogger(t), core, nil, nil, getPublicIPAddr, nil)
|
||||
defer api.Close()
|
||||
|
||||
vpnIP, err := core.GetVPNIP()
|
||||
@ -512,7 +515,8 @@ func TestRequestStateDiskKey(t *testing.T) {
|
||||
dataKey: tc.dataKey,
|
||||
getDataKeyErr: tc.getDataKeyErr,
|
||||
}
|
||||
api := New(zaptest.NewLogger(t), core, &net.Dialer{}, nil, dummyValidator{}, nil, getPeerFromContext)
|
||||
|
||||
api := New(zaptest.NewLogger(t), core, grpcutil.NewDialer(dummyValidator{}, &net.Dialer{}), nil, nil, getPeerFromContext)
|
||||
|
||||
_, err = api.RequestStateDiskKey(ctx, &pubproto.RequestStateDiskKeyRequest{})
|
||||
if tc.wantErr {
|
||||
|
@ -199,7 +199,7 @@ func (a *API) activateCoordinator(ctx context.Context, coordinatorIP string) err
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := a.dial(ctx, net.JoinHostPort(coordinatorIP, endpointAVPNPort))
|
||||
conn, err := a.dialer.Dial(ctx, net.JoinHostPort(coordinatorIP, endpointAVPNPort))
|
||||
if err != nil {
|
||||
return fmt.Errorf("dialing new coordinator: %v", err)
|
||||
}
|
||||
@ -271,7 +271,7 @@ func (a *API) triggerCoordinatorUpdate(ctx context.Context, publicIP string) err
|
||||
|
||||
// We don't verify the peer certificate here, since TriggerNodeUpdate triggers a connection over VPN
|
||||
// The target of the rpc needs to already be part of the VPN to process the request, meaning it is trusted
|
||||
conn, err := a.dialNoVerify(ctx, net.JoinHostPort(publicIP, endpointAVPNPort))
|
||||
conn, err := a.dialer.DialNoVerify(ctx, net.JoinHostPort(publicIP, endpointAVPNPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -284,7 +284,7 @@ func (a *API) triggerCoordinatorUpdate(ctx context.Context, publicIP string) err
|
||||
}
|
||||
|
||||
func (a *API) getk8SCoordinatorJoinArgs(ctx context.Context, coordinatorIP, port string) (*kubeadm.BootstrapTokenDiscovery, string, error) {
|
||||
conn, err := a.dialInsecure(ctx, net.JoinHostPort(coordinatorIP, port))
|
||||
conn, err := a.dialer.DialInsecure(ctx, net.JoinHostPort(coordinatorIP, port))
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||
"github.com/edgelesssys/constellation/coordinator/role"
|
||||
"github.com/edgelesssys/constellation/coordinator/state"
|
||||
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
|
||||
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -92,18 +93,19 @@ func TestActivateAsAdditionalCoordinator(t *testing.T) {
|
||||
ownerID: []byte("ownerID"),
|
||||
clusterID: []byte("clusterID"),
|
||||
}
|
||||
dialer := testdialer.NewBufconnDialer()
|
||||
netDialer := testdialer.NewBufconnDialer()
|
||||
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
|
||||
|
||||
getPublicIPAddr := func() (string, error) {
|
||||
return "192.0.2.1", nil
|
||||
}
|
||||
|
||||
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil)
|
||||
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, getPublicIPAddr, nil)
|
||||
defer api.Close()
|
||||
|
||||
// spawn vpnServer
|
||||
vpnapiServer := tc.vpnapi.newServer()
|
||||
go vpnapiServer.Serve(dialer.GetListener(net.JoinHostPort(tc.coordinators.peer.VPNIP, vpnAPIPort)))
|
||||
go vpnapiServer.Serve(netDialer.GetListener(net.JoinHostPort(tc.coordinators.peer.VPNIP, vpnAPIPort)))
|
||||
defer vpnapiServer.GracefulStop()
|
||||
|
||||
_, err := api.ActivateAsAdditionalCoordinator(context.Background(), &pubproto.ActivateAsAdditionalCoordinatorRequest{
|
||||
@ -163,9 +165,9 @@ func TestTriggerCoordinatorUpdate(t *testing.T) {
|
||||
state: tc.state,
|
||||
peers: tc.peers,
|
||||
}
|
||||
dialer := testdialer.NewBufconnDialer()
|
||||
dialer := grpcutil.NewDialer(fakeValidator{}, nil)
|
||||
|
||||
api := New(logger, core, dialer, nil, nil, nil, nil)
|
||||
api := New(logger, core, dialer, nil, nil, nil)
|
||||
|
||||
_, err := api.TriggerCoordinatorUpdate(context.Background(), &pubproto.TriggerCoordinatorUpdateRequest{})
|
||||
if tc.wantErr {
|
||||
@ -236,20 +238,21 @@ func TestActivateAdditionalCoordinators(t *testing.T) {
|
||||
ownerID: []byte("ownerID"),
|
||||
clusterID: []byte("clusterID"),
|
||||
}
|
||||
dialer := testdialer.NewBufconnDialer()
|
||||
netDialer := testdialer.NewBufconnDialer()
|
||||
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
|
||||
|
||||
getPublicIPAddr := func() (string, error) {
|
||||
return "192.0.2.1", nil
|
||||
}
|
||||
|
||||
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil)
|
||||
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, getPublicIPAddr, nil)
|
||||
defer api.Close()
|
||||
|
||||
// spawn coordinator
|
||||
tc.coordinators.activateErr = tc.activateErr
|
||||
tc.coordinators.getPubKeyErr = tc.getPublicKeyErr
|
||||
server := tc.coordinators.newServer()
|
||||
go server.Serve(dialer.GetListener(net.JoinHostPort(tc.coordinators.peer.PublicIP, endpointAVPNPort)))
|
||||
go server.Serve(netDialer.GetListener(net.JoinHostPort(tc.coordinators.peer.PublicIP, endpointAVPNPort)))
|
||||
defer server.GracefulStop()
|
||||
|
||||
_, err := api.ActivateAdditionalCoordinator(context.Background(), &pubproto.ActivateAdditionalCoordinatorRequest{CoordinatorPublicIp: tc.coordinators.peer.PublicIP})
|
||||
@ -293,13 +296,13 @@ func TestGetPeerVPNPublicKey(t *testing.T) {
|
||||
vpnPubKey: tc.coordinator.peer.VPNPubKey,
|
||||
getvpnPubKeyErr: tc.getVPNPubKeyErr,
|
||||
}
|
||||
dialer := testdialer.NewBufconnDialer()
|
||||
dialer := grpcutil.NewDialer(fakeValidator{}, testdialer.NewBufconnDialer())
|
||||
|
||||
getPublicIPAddr := func() (string, error) {
|
||||
return "192.0.2.1", nil
|
||||
}
|
||||
|
||||
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil)
|
||||
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, getPublicIPAddr, nil)
|
||||
defer api.Close()
|
||||
|
||||
resp, err := api.GetPeerVPNPublicKey(context.Background(), &pubproto.GetPeerVPNPublicKeyRequest{})
|
||||
|
@ -168,7 +168,7 @@ func (a *API) JoinCluster(ctx context.Context, in *pubproto.JoinClusterRequest)
|
||||
return nil, status.Errorf(codes.FailedPrecondition, "node is not in required state for cluster join: %v", err)
|
||||
}
|
||||
|
||||
conn, err := a.dialInsecure(ctx, net.JoinHostPort(in.CoordinatorVpnIp, vpnAPIPort))
|
||||
conn, err := a.dialer.DialInsecure(ctx, net.JoinHostPort(in.CoordinatorVpnIp, vpnAPIPort))
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Unavailable, "dial coordinator: %v", err)
|
||||
}
|
||||
@ -231,7 +231,7 @@ func (a *API) update(ctx context.Context) error {
|
||||
defer cancel()
|
||||
|
||||
// TODO: replace hardcoded IP
|
||||
conn, err := a.dialInsecure(ctx, net.JoinHostPort("10.118.0.1", vpnAPIPort))
|
||||
conn, err := a.dialer.DialInsecure(ctx, net.JoinHostPort("10.118.0.1", vpnAPIPort))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||
"github.com/edgelesssys/constellation/coordinator/role"
|
||||
"github.com/edgelesssys/constellation/coordinator/state"
|
||||
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
|
||||
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
|
||||
"github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -127,25 +128,26 @@ func TestActivateAsNode(t *testing.T) {
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
cor := &fakeCore{state: tc.state, vpnPubKey: vpnPubKey, setVPNIPErr: tc.setVPNIPErr}
|
||||
dialer := testdialer.NewBufconnDialer()
|
||||
netDialer := testdialer.NewBufconnDialer()
|
||||
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
|
||||
|
||||
api := New(logger, cor, dialer, nil, nil, nil, nil)
|
||||
api := New(logger, cor, dialer, nil, nil, nil)
|
||||
defer api.Close()
|
||||
|
||||
vserver := grpc.NewServer()
|
||||
vapi := &stubVPNAPI{peers: tc.updatedPeers, getUpdateErr: tc.getUpdateErr}
|
||||
vpnproto.RegisterAPIServer(vserver, vapi)
|
||||
go vserver.Serve(dialer.GetListener(net.JoinHostPort("10.118.0.1", vpnAPIPort)))
|
||||
go vserver.Serve(netDialer.GetListener(net.JoinHostPort("10.118.0.1", vpnAPIPort)))
|
||||
defer vserver.GracefulStop()
|
||||
|
||||
tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{})
|
||||
require.NoError(err)
|
||||
pubserver := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
|
||||
pubproto.RegisterAPIServer(pubserver, api)
|
||||
go pubserver.Serve(dialer.GetListener(net.JoinHostPort(nodeIP, endpointAVPNPort)))
|
||||
go pubserver.Serve(netDialer.GetListener(net.JoinHostPort(nodeIP, endpointAVPNPort)))
|
||||
defer pubserver.GracefulStop()
|
||||
|
||||
_, nodeVPNPubKey, err := activateNode(require, dialer, messageSequence, nodeIP, "9000", nodeVPNIP, peer.ToPubProto(tc.initialPeers), ownerID, clusterID, stateDiskKey)
|
||||
_, nodeVPNPubKey, err := activateNode(require, netDialer, messageSequence, nodeIP, "9000", nodeVPNIP, peer.ToPubProto(tc.initialPeers), ownerID, clusterID, stateDiskKey)
|
||||
assert.Equal(tc.wantState, cor.state)
|
||||
|
||||
if tc.wantErr {
|
||||
@ -215,9 +217,10 @@ func TestTriggerNodeUpdate(t *testing.T) {
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
core := &fakeCore{state: tc.state}
|
||||
dialer := testdialer.NewBufconnDialer()
|
||||
netDialer := testdialer.NewBufconnDialer()
|
||||
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
|
||||
|
||||
api := New(logger, core, dialer, nil, nil, nil, nil)
|
||||
api := New(logger, core, dialer, nil, nil, nil)
|
||||
|
||||
vserver := grpc.NewServer()
|
||||
vapi := &stubVPNAPI{
|
||||
@ -225,7 +228,7 @@ func TestTriggerNodeUpdate(t *testing.T) {
|
||||
getUpdateErr: tc.getUpdateErr,
|
||||
}
|
||||
vpnproto.RegisterAPIServer(vserver, vapi)
|
||||
go vserver.Serve(dialer.GetListener(net.JoinHostPort("10.118.0.1", vpnAPIPort)))
|
||||
go vserver.Serve(netDialer.GetListener(net.JoinHostPort("10.118.0.1", vpnAPIPort)))
|
||||
defer vserver.GracefulStop()
|
||||
|
||||
_, err := api.TriggerNodeUpdate(context.Background(), &pubproto.TriggerNodeUpdateRequest{})
|
||||
@ -290,9 +293,10 @@ func TestJoinCluster(t *testing.T) {
|
||||
|
||||
logger := zaptest.NewLogger(t)
|
||||
core := &fakeCore{state: tc.state, joinClusterErr: tc.joinClusterErr}
|
||||
dialer := testdialer.NewBufconnDialer()
|
||||
netDialer := testdialer.NewBufconnDialer()
|
||||
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
|
||||
|
||||
api := New(logger, core, dialer, nil, nil, nil, nil)
|
||||
api := New(logger, core, dialer, nil, nil, nil)
|
||||
|
||||
vserver := grpc.NewServer()
|
||||
vapi := &stubVPNAPI{
|
||||
@ -304,7 +308,7 @@ func TestJoinCluster(t *testing.T) {
|
||||
getJoinArgsErr: tc.getJoinArgsErr,
|
||||
}
|
||||
vpnproto.RegisterAPIServer(vserver, vapi)
|
||||
go vserver.Serve(dialer.GetListener(net.JoinHostPort("192.0.2.1", vpnAPIPort)))
|
||||
go vserver.Serve(netDialer.GetListener(net.JoinHostPort("192.0.2.1", vpnAPIPort)))
|
||||
defer vserver.GracefulStop()
|
||||
|
||||
_, err := api.JoinCluster(context.Background(), &pubproto.JoinClusterRequest{CoordinatorVpnIp: "192.0.2.1"})
|
||||
@ -322,7 +326,7 @@ func TestJoinCluster(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func activateNode(require *require.Assertions, dialer Dialer, messageSequence []string, nodeIP, bindPort, nodeVPNIP string, peers []*pubproto.Peer, ownerID, clusterID, stateDiskKey []byte) (string, []byte, error) {
|
||||
func activateNode(require *require.Assertions, dialer netDialer, messageSequence []string, nodeIP, bindPort, nodeVPNIP string, peers []*pubproto.Peer, ownerID, clusterID, stateDiskKey []byte) (string, []byte, error) {
|
||||
ctx := context.Background()
|
||||
conn, err := dialGRPC(ctx, dialer, net.JoinHostPort(nodeIP, bindPort))
|
||||
require.NoError(err)
|
||||
@ -385,7 +389,7 @@ func activateNode(require *require.Assertions, dialer Dialer, messageSequence []
|
||||
return diskUUID, nodeVPNPubKey, nil
|
||||
}
|
||||
|
||||
func dialGRPC(ctx context.Context, dialer Dialer, target string) (*grpc.ClientConn, error) {
|
||||
func dialGRPC(ctx context.Context, dialer netDialer, target string) (*grpc.ClientConn, error) {
|
||||
tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{&core.MockValidator{}})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -429,3 +433,7 @@ func (a *stubVPNAPI) newServer() *grpc.Server {
|
||||
vpnproto.RegisterAPIServer(server, a)
|
||||
return server
|
||||
}
|
||||
|
||||
type netDialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
@ -8,13 +8,10 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||
"github.com/edgelesssys/constellation/state/setup"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/peer"
|
||||
)
|
||||
|
||||
@ -32,7 +29,6 @@ type API struct {
|
||||
core Core
|
||||
dialer Dialer
|
||||
vpnAPIServer VPNAPIServer
|
||||
validator atls.Validator
|
||||
getPublicIPAddr GetIPAddrFunc
|
||||
stopUpdate chan struct{}
|
||||
wgClose sync.WaitGroup
|
||||
@ -42,13 +38,12 @@ type API struct {
|
||||
}
|
||||
|
||||
// New creates a new API.
|
||||
func New(logger *zap.Logger, core Core, dialer Dialer, vpnAPIServer VPNAPIServer, validator atls.Validator, getPublicIPAddr GetIPAddrFunc, peerFromContext PeerFromContextFunc) *API {
|
||||
func New(logger *zap.Logger, core Core, dialer Dialer, vpnAPIServer VPNAPIServer, getPublicIPAddr GetIPAddrFunc, peerFromContext PeerFromContextFunc) *API {
|
||||
return &API{
|
||||
logger: logger,
|
||||
core: core,
|
||||
dialer: dialer,
|
||||
vpnAPIServer: vpnAPIServer,
|
||||
validator: validator,
|
||||
getPublicIPAddr: getPublicIPAddr,
|
||||
stopUpdate: make(chan struct{}, 1),
|
||||
peerFromContext: peerFromContext,
|
||||
@ -69,47 +64,6 @@ func (a *API) Close() {
|
||||
a.wgClose.Wait()
|
||||
}
|
||||
|
||||
func (a *API) dial(ctx context.Context, target string) (*grpc.ClientConn, error) {
|
||||
tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{a.validator})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return grpc.DialContext(ctx, target,
|
||||
a.grpcWithDialer(),
|
||||
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
|
||||
)
|
||||
}
|
||||
|
||||
func (a *API) dialInsecure(ctx context.Context, target string) (*grpc.ClientConn, error) {
|
||||
return grpc.DialContext(ctx, target,
|
||||
a.grpcWithDialer(),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
)
|
||||
}
|
||||
|
||||
func (a *API) dialNoVerify(ctx context.Context, target string) (*grpc.ClientConn, error) {
|
||||
tlsConfig, err := atls.CreateUnverifiedClientTLSConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return grpc.DialContext(ctx, target,
|
||||
a.grpcWithDialer(),
|
||||
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
|
||||
)
|
||||
}
|
||||
|
||||
func (a *API) grpcWithDialer() grpc.DialOption {
|
||||
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
return a.dialer.DialContext(ctx, "tcp", addr)
|
||||
})
|
||||
}
|
||||
|
||||
type Dialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
type VPNAPIServer interface {
|
||||
Listen(endpoint string) error
|
||||
Serve() error
|
||||
@ -135,3 +89,10 @@ func GetRecoveryPeerFromContext(ctx context.Context) (string, error) {
|
||||
|
||||
return net.JoinHostPort(peerIP, setup.RecoveryPort), nil
|
||||
}
|
||||
|
||||
// Dialer can open grpc client connections with different levels of ATLS encryption / verification.
|
||||
type Dialer interface {
|
||||
Dial(ctx context.Context, target string) (*grpc.ClientConn, error)
|
||||
DialInsecure(ctx context.Context, target string) (*grpc.ClientConn, error)
|
||||
DialNoVerify(ctx context.Context, target string) (*grpc.ClientConn, error)
|
||||
}
|
||||
|
71
coordinator/util/grpcutil/dialer.go
Normal file
71
coordinator/util/grpcutil/dialer.go
Normal file
@ -0,0 +1,71 @@
|
||||
package grpcutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
|
||||
// Dialer can open grpc client connections with different levels of ATLS encryption / verification.
|
||||
type Dialer struct {
|
||||
validator atls.Validator
|
||||
netDialer NetDialer
|
||||
}
|
||||
|
||||
// NewDialer creates a new Dialer.
|
||||
func NewDialer(validator atls.Validator, netDialer NetDialer) *Dialer {
|
||||
return &Dialer{
|
||||
validator: validator,
|
||||
netDialer: netDialer,
|
||||
}
|
||||
}
|
||||
|
||||
// Dial creates a new grpc client connection to the given target using the atls validator.
|
||||
func (d *Dialer) Dial(ctx context.Context, target string) (*grpc.ClientConn, error) {
|
||||
tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{d.validator})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return grpc.DialContext(ctx, target,
|
||||
d.grpcWithDialer(),
|
||||
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
|
||||
)
|
||||
}
|
||||
|
||||
// DialInsecure creates a new grpc client connection to the given target without using encryption or verification.
|
||||
// Only use this method when using another kind of encryption / verification (VPN, etc).
|
||||
func (d *Dialer) DialInsecure(ctx context.Context, target string) (*grpc.ClientConn, error) {
|
||||
return grpc.DialContext(ctx, target,
|
||||
d.grpcWithDialer(),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
)
|
||||
}
|
||||
|
||||
// DialNoVerify creates a new grpc client connection to the given target without verifying the server's attestation.
|
||||
func (d *Dialer) DialNoVerify(ctx context.Context, target string) (*grpc.ClientConn, error) {
|
||||
tlsConfig, err := atls.CreateUnverifiedClientTLSConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return grpc.DialContext(ctx, target,
|
||||
d.grpcWithDialer(),
|
||||
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
|
||||
)
|
||||
}
|
||||
|
||||
func (d *Dialer) grpcWithDialer() grpc.DialOption {
|
||||
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
return d.netDialer.DialContext(ctx, "tcp", addr)
|
||||
})
|
||||
}
|
||||
|
||||
// NetDialer implements the net Dialer interface.
|
||||
type NetDialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
106
coordinator/util/grpcutil/dialer_test.go
Normal file
106
coordinator/util/grpcutil/dialer_test.go
Normal file
@ -0,0 +1,106 @@
|
||||
package grpcutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
"github.com/edgelesssys/constellation/coordinator/core"
|
||||
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/test/grpc_testing"
|
||||
)
|
||||
|
||||
func TestDial(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
tls bool
|
||||
dialFn func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error)
|
||||
wantErr bool
|
||||
}{
|
||||
"Dial with tls on server works": {
|
||||
tls: true,
|
||||
dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) {
|
||||
return dialer.Dial(ctx, target)
|
||||
},
|
||||
},
|
||||
"Dial without tls on server fails": {
|
||||
dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) {
|
||||
return dialer.Dial(ctx, target)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"DialNoVerify with tls on server works": {
|
||||
tls: true,
|
||||
dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) {
|
||||
return dialer.DialNoVerify(ctx, target)
|
||||
},
|
||||
},
|
||||
"DialNoVerify without tls on server fails": {
|
||||
dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) {
|
||||
return dialer.DialNoVerify(ctx, target)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"DialInsecure without tls on server works": {
|
||||
dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) {
|
||||
return dialer.DialInsecure(ctx, target)
|
||||
},
|
||||
},
|
||||
"DialInsecure with tls on server fails": {
|
||||
tls: true,
|
||||
dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) {
|
||||
return dialer.DialInsecure(ctx, target)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
netDialer := testdialer.NewBufconnDialer()
|
||||
dialer := NewDialer(&core.MockValidator{}, netDialer)
|
||||
server := newServer(tc.tls)
|
||||
api := &testAPI{}
|
||||
grpc_testing.RegisterTestServiceServer(server, api)
|
||||
go server.Serve(netDialer.GetListener("192.0.2.1:1234"))
|
||||
defer server.Stop()
|
||||
conn, err := tc.dialFn(dialer, context.Background(), "192.0.2.1:1234")
|
||||
require.NoError(err)
|
||||
defer conn.Close()
|
||||
|
||||
client := grpc_testing.NewTestServiceClient(conn)
|
||||
_, err = client.EmptyCall(context.Background(), &grpc_testing.Empty{})
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newServer(tls bool) *grpc.Server {
|
||||
if tls {
|
||||
tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
|
||||
}
|
||||
return grpc.NewServer()
|
||||
}
|
||||
|
||||
type testAPI struct {
|
||||
grpc_testing.UnimplementedTestServiceServer
|
||||
}
|
||||
|
||||
func (s *testAPI) EmptyCall(ctx context.Context, in *grpc_testing.Empty) (*grpc_testing.Empty, error) {
|
||||
return &grpc_testing.Empty{}, nil
|
||||
}
|
@ -16,14 +16,14 @@ import (
|
||||
|
||||
// Download downloads a coordinator from a given debugd instance.
|
||||
type Download struct {
|
||||
dialer Dialer
|
||||
dialer NetDialer
|
||||
writer streamToFileWriter
|
||||
serviceManager serviceManager
|
||||
attemptedDownloads map[string]time.Time
|
||||
}
|
||||
|
||||
// New creates a new Download.
|
||||
func New(dialer Dialer, serviceManager serviceManager, writer streamToFileWriter) *Download {
|
||||
func New(dialer NetDialer, serviceManager serviceManager, writer streamToFileWriter) *Download {
|
||||
return &Download{
|
||||
dialer: dialer,
|
||||
writer: writer,
|
||||
@ -91,6 +91,7 @@ type streamToFileWriter interface {
|
||||
WriteStream(filename string, stream coordinator.ReadChunkStream, showProgress bool) error
|
||||
}
|
||||
|
||||
type Dialer interface {
|
||||
// NetDialer can open a net.Conn.
|
||||
type NetDialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
@ -9,10 +9,10 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
|
||||
"github.com/edgelesssys/constellation/debugd/coordinator"
|
||||
"github.com/edgelesssys/constellation/debugd/debugd"
|
||||
pb "github.com/edgelesssys/constellation/debugd/service"
|
||||
"github.com/edgelesssys/constellation/debugd/service/testdialer"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
@ -8,10 +8,10 @@ import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
|
||||
"github.com/edgelesssys/constellation/debugd/coordinator"
|
||||
"github.com/edgelesssys/constellation/debugd/debugd/deploy"
|
||||
pb "github.com/edgelesssys/constellation/debugd/service"
|
||||
"github.com/edgelesssys/constellation/debugd/service/testdialer"
|
||||
"github.com/edgelesssys/constellation/debugd/ssh"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -351,11 +351,11 @@ func (s *stubServiceManager) WriteSystemdUnitFile(ctx context.Context, unit depl
|
||||
return s.writeSystemdUnitFileErr
|
||||
}
|
||||
|
||||
type dialer interface {
|
||||
type netDialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
func dial(ctx context.Context, dialer dialer, target string) (*grpc.ClientConn, error) {
|
||||
func dial(ctx context.Context, dialer netDialer, target string) (*grpc.ClientConn, error) {
|
||||
return grpc.DialContext(ctx, target,
|
||||
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
return dialer.DialContext(ctx, "tcp", addr)
|
||||
|
@ -1,41 +0,0 @@
|
||||
package testdialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"google.golang.org/grpc/test/bufconn"
|
||||
)
|
||||
|
||||
// BufconnDialer is a fake dialer based on gRPC bufconn package.
|
||||
type BufconnDialer struct {
|
||||
mut sync.Mutex
|
||||
listeners map[string]*bufconn.Listener
|
||||
}
|
||||
|
||||
// NewBufconnDialer creates a new bufconn dialer for testing.
|
||||
func NewBufconnDialer() *BufconnDialer {
|
||||
return &BufconnDialer{listeners: make(map[string]*bufconn.Listener)}
|
||||
}
|
||||
|
||||
// DialContext implements the Dialer interface.
|
||||
func (b *BufconnDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
b.mut.Lock()
|
||||
listener, ok := b.listeners[address]
|
||||
b.mut.Unlock()
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("could not connect to server on %v", address)
|
||||
}
|
||||
return listener.DialContext(ctx)
|
||||
}
|
||||
|
||||
// GetListener returns a fake listener that is coupled with this dialer.
|
||||
func (b *BufconnDialer) GetListener(endpoint string) net.Listener {
|
||||
listener := bufconn.Listen(1024)
|
||||
b.mut.Lock()
|
||||
b.listeners[endpoint] = listener
|
||||
b.mut.Unlock()
|
||||
return listener
|
||||
}
|
Loading…
Reference in New Issue
Block a user