mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-07-27 17:25:20 -04:00
extract shared grpcutil dialer from pubapi
Signed-off-by: Malte Poll <mp@edgeless.systems>
This commit is contained in:
parent
5ac72c730d
commit
77b0237dd5
17 changed files with 275 additions and 152 deletions
|
@ -23,6 +23,7 @@ import (
|
||||||
"github.com/edgelesssys/constellation/coordinator/kubernetes/k8sapi"
|
"github.com/edgelesssys/constellation/coordinator/kubernetes/k8sapi"
|
||||||
"github.com/edgelesssys/constellation/coordinator/kubernetes/k8sapi/kubectl"
|
"github.com/edgelesssys/constellation/coordinator/kubernetes/k8sapi/kubectl"
|
||||||
"github.com/edgelesssys/constellation/coordinator/util"
|
"github.com/edgelesssys/constellation/coordinator/util"
|
||||||
|
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
|
||||||
"github.com/edgelesssys/constellation/coordinator/wireguard"
|
"github.com/edgelesssys/constellation/coordinator/wireguard"
|
||||||
grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap"
|
grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap"
|
||||||
"github.com/spf13/afero"
|
"github.com/spf13/afero"
|
||||||
|
@ -153,7 +154,8 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
fileHandler := file.NewHandler(fs)
|
fileHandler := file.NewHandler(fs)
|
||||||
dialer := &net.Dialer{}
|
netDialer := &net.Dialer{}
|
||||||
run(validator, issuer, wg, openTPM, util.GetIPAddr, dialer, fileHandler, kube,
|
dialer := grpcutil.NewDialer(validator, netDialer)
|
||||||
|
run(issuer, wg, openTPM, util.GetIPAddr, dialer, fileHandler, kube,
|
||||||
metadata, cloudControllerManager, cloudNodeManager, autoscaler, encryptedDisk, etcdEndpoint, enforceEtcdTls, bindIP, bindPort, zapLoggerCore)
|
metadata, cloudControllerManager, cloudNodeManager, autoscaler, encryptedDisk, etcdEndpoint, enforceEtcdTls, bindIP, bindPort, zapLoggerCore)
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
"github.com/edgelesssys/constellation/coordinator/pubapi"
|
"github.com/edgelesssys/constellation/coordinator/pubapi"
|
||||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||||
"github.com/edgelesssys/constellation/coordinator/store"
|
"github.com/edgelesssys/constellation/coordinator/store"
|
||||||
|
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
|
||||||
"github.com/edgelesssys/constellation/coordinator/vpnapi"
|
"github.com/edgelesssys/constellation/coordinator/vpnapi"
|
||||||
"github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto"
|
"github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto"
|
||||||
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
||||||
|
@ -26,7 +27,7 @@ import (
|
||||||
|
|
||||||
var version = "0.0.0"
|
var version = "0.0.0"
|
||||||
|
|
||||||
func run(validator core.QuoteValidator, issuer core.QuoteIssuer, vpn core.VPN, openTPM vtpm.TPMOpenFunc, getPublicIPAddr func() (string, error), dialer pubapi.Dialer, fileHandler file.Handler,
|
func run(issuer core.QuoteIssuer, vpn core.VPN, openTPM vtpm.TPMOpenFunc, getPublicIPAddr func() (string, error), dialer *grpcutil.Dialer, fileHandler file.Handler,
|
||||||
kube core.Cluster, metadata core.ProviderMetadata, cloudControllerManager core.CloudControllerManager, cloudNodeManager core.CloudNodeManager, clusterAutoscaler core.ClusterAutoscaler, encryptedDisk core.EncryptedDisk, etcdEndpoint string, etcdTLS bool, bindIP, bindPort string, zapLoggerCore *zap.Logger,
|
kube core.Cluster, metadata core.ProviderMetadata, cloudControllerManager core.CloudControllerManager, cloudNodeManager core.CloudNodeManager, clusterAutoscaler core.ClusterAutoscaler, encryptedDisk core.EncryptedDisk, etcdEndpoint string, etcdTLS bool, bindIP, bindPort string, zapLoggerCore *zap.Logger,
|
||||||
) {
|
) {
|
||||||
defer zapLoggerCore.Sync()
|
defer zapLoggerCore.Sync()
|
||||||
|
@ -46,16 +47,16 @@ func run(validator core.QuoteValidator, issuer core.QuoteIssuer, vpn core.VPN, o
|
||||||
if err != nil {
|
if err != nil {
|
||||||
zapLoggerCore.Fatal("failed to create core", zap.Error(err))
|
zapLoggerCore.Fatal("failed to create core", zap.Error(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
vapiServer := &vpnAPIServer{logger: zapLoggerCore.Named("vpnapi"), core: core}
|
||||||
|
zapLoggerPubapi := zapLoggerCore.Named("pubapi")
|
||||||
|
papi := pubapi.New(zapLoggerPubapi, core, dialer, vapiServer, getPublicIPAddr, pubapi.GetRecoveryPeerFromContext)
|
||||||
// initialize state machine and wait for re-joining of the VPN (if applicable)
|
// initialize state machine and wait for re-joining of the VPN (if applicable)
|
||||||
nodeActivated, err := core.Initialize()
|
nodeActivated, err := core.Initialize()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
zapLoggerCore.Fatal("failed to initialize core", zap.Error(err))
|
zapLoggerCore.Fatal("failed to initialize core", zap.Error(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
vapiServer := &vpnAPIServer{logger: zapLoggerCore.Named("vpnapi"), core: core}
|
|
||||||
zapLoggerPubapi := zapLoggerCore.Named("pubapi")
|
|
||||||
papi := pubapi.New(zapLoggerPubapi, core, dialer, vapiServer, validator, getPublicIPAddr, pubapi.GetRecoveryPeerFromContext)
|
|
||||||
|
|
||||||
zapLoggergRPC := zapLoggerPubapi.Named("gRPC")
|
zapLoggergRPC := zapLoggerPubapi.Named("gRPC")
|
||||||
|
|
||||||
grpcServer := grpc.NewServer(
|
grpcServer := grpc.NewServer(
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||||
"github.com/edgelesssys/constellation/coordinator/state"
|
"github.com/edgelesssys/constellation/coordinator/state"
|
||||||
"github.com/edgelesssys/constellation/coordinator/store"
|
"github.com/edgelesssys/constellation/coordinator/store"
|
||||||
|
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
|
||||||
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
|
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
|
||||||
"github.com/edgelesssys/constellation/coordinator/vpnapi"
|
"github.com/edgelesssys/constellation/coordinator/vpnapi"
|
||||||
"github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto"
|
"github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto"
|
||||||
|
@ -210,7 +211,7 @@ func TestConcurrent(t *testing.T) {
|
||||||
assert.Error(<-actCoordErrs)
|
assert.Error(<-actCoordErrs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func spawnPeer(require *require.Assertions, logger *zap.Logger, dialer *testdialer.BufconnDialer, netw *network, endpoint string) (*grpc.Server, *pubapi.API, *fakeVPN) {
|
func spawnPeer(require *require.Assertions, logger *zap.Logger, netDialer *testdialer.BufconnDialer, netw *network, endpoint string) (*grpc.Server, *pubapi.API, *fakeVPN) {
|
||||||
vpn := newVPN(netw, endpoint)
|
vpn := newVPN(netw, endpoint)
|
||||||
cor, err := core.NewCore(vpn, &core.ClusterFake{}, &core.ProviderMetadataFake{}, &core.CloudControllerManagerFake{}, &core.CloudNodeManagerFake{}, &core.ClusterAutoscalerFake{}, &core.EncryptedDiskFake{}, logger, simulator.OpenSimulatedTPM, fakeStoreFactory{}, file.NewHandler(afero.NewMemMapFs()))
|
cor, err := core.NewCore(vpn, &core.ClusterFake{}, &core.ProviderMetadataFake{}, &core.CloudControllerManagerFake{}, &core.CloudNodeManagerFake{}, &core.ClusterAutoscalerFake{}, &core.EncryptedDiskFake{}, logger, simulator.OpenSimulatedTPM, fakeStoreFactory{}, file.NewHandler(afero.NewMemMapFs()))
|
||||||
require.NoError(err)
|
require.NoError(err)
|
||||||
|
@ -219,22 +220,23 @@ func spawnPeer(require *require.Assertions, logger *zap.Logger, dialer *testdial
|
||||||
getPublicAddr := func() (string, error) {
|
getPublicAddr := func() (string, error) {
|
||||||
return "192.0.2.1", nil
|
return "192.0.2.1", nil
|
||||||
}
|
}
|
||||||
|
dialer := grpcutil.NewDialer(&core.MockValidator{}, netDialer)
|
||||||
|
vapiServer := &fakeVPNAPIServer{logger: logger.Named("vpnapi"), core: cor, dialer: netDialer}
|
||||||
|
|
||||||
vapiServer := &fakeVPNAPIServer{logger: logger.Named("vpnapi"), core: cor, dialer: dialer}
|
papi := pubapi.New(logger, cor, dialer, vapiServer, getPublicAddr, nil)
|
||||||
papi := pubapi.New(logger, cor, dialer, vapiServer, &core.MockValidator{}, getPublicAddr, nil)
|
|
||||||
|
|
||||||
tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{})
|
tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{})
|
||||||
require.NoError(err)
|
require.NoError(err)
|
||||||
server := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
|
server := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
|
||||||
pubproto.RegisterAPIServer(server, papi)
|
pubproto.RegisterAPIServer(server, papi)
|
||||||
|
|
||||||
listener := dialer.GetListener(endpoint)
|
listener := netDialer.GetListener(endpoint)
|
||||||
go server.Serve(listener)
|
go server.Serve(listener)
|
||||||
|
|
||||||
return server, papi, vpn
|
return server, papi, vpn
|
||||||
}
|
}
|
||||||
|
|
||||||
func activateCoordinator(require *require.Assertions, dialer pubapi.Dialer, coordinatorIP, bindPort string, nodeIPs []string) error {
|
func activateCoordinator(require *require.Assertions, dialer netDialer, coordinatorIP, bindPort string, nodeIPs []string) error {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
conn, err := dialGRPC(ctx, dialer, net.JoinHostPort(coordinatorIP, bindPort))
|
conn, err := dialGRPC(ctx, dialer, net.JoinHostPort(coordinatorIP, bindPort))
|
||||||
require.NoError(err)
|
require.NoError(err)
|
||||||
|
@ -260,7 +262,7 @@ func activateCoordinator(require *require.Assertions, dialer pubapi.Dialer, coor
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func dialGRPC(ctx context.Context, dialer pubapi.Dialer, target string) (*grpc.ClientConn, error) {
|
func dialGRPC(ctx context.Context, dialer netDialer, target string) (*grpc.ClientConn, error) {
|
||||||
tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{&core.MockValidator{}})
|
tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{&core.MockValidator{}})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -398,3 +400,7 @@ func (v *fakeVPN) recv() *packet {
|
||||||
}
|
}
|
||||||
return &packet
|
return &packet
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type netDialer interface {
|
||||||
|
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@ import (
|
||||||
"github.com/edgelesssys/constellation/coordinator/pubapi"
|
"github.com/edgelesssys/constellation/coordinator/pubapi"
|
||||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||||
"github.com/edgelesssys/constellation/coordinator/state"
|
"github.com/edgelesssys/constellation/coordinator/state"
|
||||||
|
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
|
||||||
"github.com/edgelesssys/constellation/coordinator/vpnapi"
|
"github.com/edgelesssys/constellation/coordinator/vpnapi"
|
||||||
"github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto"
|
"github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto"
|
||||||
"github.com/spf13/afero"
|
"github.com/spf13/afero"
|
||||||
|
@ -112,13 +113,13 @@ func TestLegacyActivateCoordinator(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// newMockCoreWithDialer creates a new core object with attestation mock and provided dialer for testing.
|
// newMockCoreWithDialer creates a new core object with attestation mock and provided dialer for testing.
|
||||||
func newMockCoreWithDialer(dialer *bufconnDialer) (*Core, *pubapi.API, error) {
|
func newMockCoreWithDialer(bufDialer *bufconnDialer) (*Core, *pubapi.API, error) {
|
||||||
zapLogger, err := zap.NewDevelopment()
|
zapLogger, err := zap.NewDevelopment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
validator := NewMockValidator()
|
dialer := grpcutil.NewDialer(NewMockValidator(), bufDialer)
|
||||||
vpn := &stubVPN{}
|
vpn := &stubVPN{}
|
||||||
kubeFake := &ClusterFake{}
|
kubeFake := &ClusterFake{}
|
||||||
metadataFake := &ProviderMetadataFake{}
|
metadataFake := &ProviderMetadataFake{}
|
||||||
|
@ -138,8 +139,8 @@ func newMockCoreWithDialer(dialer *bufconnDialer) (*Core, *pubapi.API, error) {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
vapiServer := &fakeVPNAPIServer{logger: zapLogger, core: core, dialer: dialer}
|
vapiServer := &fakeVPNAPIServer{logger: zapLogger, core: core, dialer: bufDialer}
|
||||||
papi := pubapi.New(zapLogger, core, dialer, vapiServer, validator, getPublicAddr, nil)
|
papi := pubapi.New(zapLogger, core, dialer, vapiServer, getPublicAddr, nil)
|
||||||
|
|
||||||
return core, papi, nil
|
return core, papi, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -204,7 +204,7 @@ func (a *API) RequestStateDiskKey(ctx context.Context, in *pubproto.RequestState
|
||||||
return nil, status.Errorf(codes.Internal, "%v", err)
|
return nil, status.Errorf(codes.Internal, "%v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := a.dial(ctx, peer)
|
conn, err := a.dialer.Dial(ctx, peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.Internal, "%v", err)
|
return nil, status.Errorf(codes.Internal, "%v", err)
|
||||||
}
|
}
|
||||||
|
@ -297,7 +297,7 @@ func (a *API) activateNode(nodePublicIP string, nodeVPNIP string, initialPeers [
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), deadlineDuration)
|
ctx, cancel := context.WithTimeout(context.Background(), deadlineDuration)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
conn, err := a.dial(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort))
|
conn, err := a.dialer.Dial(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -417,7 +417,7 @@ func (a *API) joinCluster(nodePublicIP string) error {
|
||||||
}
|
}
|
||||||
// We don't verify the peer certificate here, since JoinCluster triggers a connection over VPN
|
// We don't verify the peer certificate here, since JoinCluster triggers a connection over VPN
|
||||||
// The target of the rpc needs to already be part of the VPN to process the request, meaning it is trusted
|
// The target of the rpc needs to already be part of the VPN to process the request, meaning it is trusted
|
||||||
conn, err := a.dialNoVerify(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort))
|
conn, err := a.dialer.DialNoVerify(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -455,7 +455,7 @@ func (a *API) triggerNodeUpdate(nodePublicIP string) error {
|
||||||
|
|
||||||
// We don't verify the peer certificate here, since TriggerNodeUpdate triggers a connection over VPN
|
// We don't verify the peer certificate here, since TriggerNodeUpdate triggers a connection over VPN
|
||||||
// The target of the rpc needs to already be part of the VPN to process the request, meaning it is trusted
|
// The target of the rpc needs to already be part of the VPN to process the request, meaning it is trusted
|
||||||
conn, err := a.dialNoVerify(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort))
|
conn, err := a.dialer.DialNoVerify(ctx, net.JoinHostPort(nodePublicIP, endpointAVPNPort))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||||
"github.com/edgelesssys/constellation/coordinator/role"
|
"github.com/edgelesssys/constellation/coordinator/role"
|
||||||
"github.com/edgelesssys/constellation/coordinator/state"
|
"github.com/edgelesssys/constellation/coordinator/state"
|
||||||
|
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
|
||||||
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
|
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
|
||||||
"github.com/edgelesssys/constellation/state/keyservice/keyproto"
|
"github.com/edgelesssys/constellation/state/keyservice/keyproto"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
@ -125,13 +126,14 @@ func TestActivateAsCoordinator(t *testing.T) {
|
||||||
ownerID: []byte("ownerID"),
|
ownerID: []byte("ownerID"),
|
||||||
clusterID: []byte("clusterID"),
|
clusterID: []byte("clusterID"),
|
||||||
}
|
}
|
||||||
dialer := testdialer.NewBufconnDialer()
|
netDialer := testdialer.NewBufconnDialer()
|
||||||
|
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
|
||||||
|
|
||||||
getPublicIPAddr := func() (string, error) {
|
getPublicIPAddr := func() (string, error) {
|
||||||
return "192.0.2.1", nil
|
return "192.0.2.1", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil)
|
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, getPublicIPAddr, nil)
|
||||||
defer api.Close()
|
defer api.Close()
|
||||||
|
|
||||||
// spawn nodes
|
// spawn nodes
|
||||||
|
@ -142,7 +144,7 @@ func TestActivateAsCoordinator(t *testing.T) {
|
||||||
server := n.newServer()
|
server := n.newServer()
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(endpoint string) {
|
go func(endpoint string) {
|
||||||
listener := dialer.GetListener(endpoint)
|
listener := netDialer.GetListener(endpoint)
|
||||||
wg.Done()
|
wg.Done()
|
||||||
_ = server.Serve(listener)
|
_ = server.Serve(listener)
|
||||||
}(net.JoinHostPort(n.peer.PublicIP, endpointAVPNPort))
|
}(net.JoinHostPort(n.peer.PublicIP, endpointAVPNPort))
|
||||||
|
@ -256,13 +258,14 @@ func TestActivateAdditionalNodes(t *testing.T) {
|
||||||
require := require.New(t)
|
require := require.New(t)
|
||||||
|
|
||||||
core := &fakeCore{state: tc.state}
|
core := &fakeCore{state: tc.state}
|
||||||
dialer := testdialer.NewBufconnDialer()
|
netDialer := testdialer.NewBufconnDialer()
|
||||||
|
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
|
||||||
|
|
||||||
getPublicIPAddr := func() (string, error) {
|
getPublicIPAddr := func() (string, error) {
|
||||||
return "192.0.2.1", nil
|
return "192.0.2.1", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
api := New(zaptest.NewLogger(t), core, dialer, nil, fakeValidator{}, getPublicIPAddr, nil)
|
api := New(zaptest.NewLogger(t), core, dialer, nil, getPublicIPAddr, nil)
|
||||||
defer api.Close()
|
defer api.Close()
|
||||||
// spawn nodes
|
// spawn nodes
|
||||||
var nodePublicIPs []string
|
var nodePublicIPs []string
|
||||||
|
@ -272,7 +275,7 @@ func TestActivateAdditionalNodes(t *testing.T) {
|
||||||
server := n.newServer()
|
server := n.newServer()
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(endpoint string) {
|
go func(endpoint string) {
|
||||||
listener := dialer.GetListener(endpoint)
|
listener := netDialer.GetListener(endpoint)
|
||||||
wg.Done()
|
wg.Done()
|
||||||
_ = server.Serve(listener)
|
_ = server.Serve(listener)
|
||||||
}(net.JoinHostPort(n.peer.PublicIP, endpointAVPNPort))
|
}(net.JoinHostPort(n.peer.PublicIP, endpointAVPNPort))
|
||||||
|
@ -311,7 +314,7 @@ func TestAssemblePeerStruct(t *testing.T) {
|
||||||
|
|
||||||
vpnPubKey := []byte{2, 3, 4}
|
vpnPubKey := []byte{2, 3, 4}
|
||||||
core := &fakeCore{vpnPubKey: vpnPubKey}
|
core := &fakeCore{vpnPubKey: vpnPubKey}
|
||||||
api := New(zaptest.NewLogger(t), core, nil, nil, nil, getPublicIPAddr, nil)
|
api := New(zaptest.NewLogger(t), core, nil, nil, getPublicIPAddr, nil)
|
||||||
defer api.Close()
|
defer api.Close()
|
||||||
|
|
||||||
vpnIP, err := core.GetVPNIP()
|
vpnIP, err := core.GetVPNIP()
|
||||||
|
@ -512,7 +515,8 @@ func TestRequestStateDiskKey(t *testing.T) {
|
||||||
dataKey: tc.dataKey,
|
dataKey: tc.dataKey,
|
||||||
getDataKeyErr: tc.getDataKeyErr,
|
getDataKeyErr: tc.getDataKeyErr,
|
||||||
}
|
}
|
||||||
api := New(zaptest.NewLogger(t), core, &net.Dialer{}, nil, dummyValidator{}, nil, getPeerFromContext)
|
|
||||||
|
api := New(zaptest.NewLogger(t), core, grpcutil.NewDialer(dummyValidator{}, &net.Dialer{}), nil, nil, getPeerFromContext)
|
||||||
|
|
||||||
_, err = api.RequestStateDiskKey(ctx, &pubproto.RequestStateDiskKeyRequest{})
|
_, err = api.RequestStateDiskKey(ctx, &pubproto.RequestStateDiskKeyRequest{})
|
||||||
if tc.wantErr {
|
if tc.wantErr {
|
||||||
|
|
|
@ -199,7 +199,7 @@ func (a *API) activateCoordinator(ctx context.Context, coordinatorIP string) err
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := a.dial(ctx, net.JoinHostPort(coordinatorIP, endpointAVPNPort))
|
conn, err := a.dialer.Dial(ctx, net.JoinHostPort(coordinatorIP, endpointAVPNPort))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("dialing new coordinator: %v", err)
|
return fmt.Errorf("dialing new coordinator: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -271,7 +271,7 @@ func (a *API) triggerCoordinatorUpdate(ctx context.Context, publicIP string) err
|
||||||
|
|
||||||
// We don't verify the peer certificate here, since TriggerNodeUpdate triggers a connection over VPN
|
// We don't verify the peer certificate here, since TriggerNodeUpdate triggers a connection over VPN
|
||||||
// The target of the rpc needs to already be part of the VPN to process the request, meaning it is trusted
|
// The target of the rpc needs to already be part of the VPN to process the request, meaning it is trusted
|
||||||
conn, err := a.dialNoVerify(ctx, net.JoinHostPort(publicIP, endpointAVPNPort))
|
conn, err := a.dialer.DialNoVerify(ctx, net.JoinHostPort(publicIP, endpointAVPNPort))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -284,7 +284,7 @@ func (a *API) triggerCoordinatorUpdate(ctx context.Context, publicIP string) err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *API) getk8SCoordinatorJoinArgs(ctx context.Context, coordinatorIP, port string) (*kubeadm.BootstrapTokenDiscovery, string, error) {
|
func (a *API) getk8SCoordinatorJoinArgs(ctx context.Context, coordinatorIP, port string) (*kubeadm.BootstrapTokenDiscovery, string, error) {
|
||||||
conn, err := a.dialInsecure(ctx, net.JoinHostPort(coordinatorIP, port))
|
conn, err := a.dialer.DialInsecure(ctx, net.JoinHostPort(coordinatorIP, port))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||||
"github.com/edgelesssys/constellation/coordinator/role"
|
"github.com/edgelesssys/constellation/coordinator/role"
|
||||||
"github.com/edgelesssys/constellation/coordinator/state"
|
"github.com/edgelesssys/constellation/coordinator/state"
|
||||||
|
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
|
||||||
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
|
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
@ -92,18 +93,19 @@ func TestActivateAsAdditionalCoordinator(t *testing.T) {
|
||||||
ownerID: []byte("ownerID"),
|
ownerID: []byte("ownerID"),
|
||||||
clusterID: []byte("clusterID"),
|
clusterID: []byte("clusterID"),
|
||||||
}
|
}
|
||||||
dialer := testdialer.NewBufconnDialer()
|
netDialer := testdialer.NewBufconnDialer()
|
||||||
|
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
|
||||||
|
|
||||||
getPublicIPAddr := func() (string, error) {
|
getPublicIPAddr := func() (string, error) {
|
||||||
return "192.0.2.1", nil
|
return "192.0.2.1", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil)
|
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, getPublicIPAddr, nil)
|
||||||
defer api.Close()
|
defer api.Close()
|
||||||
|
|
||||||
// spawn vpnServer
|
// spawn vpnServer
|
||||||
vpnapiServer := tc.vpnapi.newServer()
|
vpnapiServer := tc.vpnapi.newServer()
|
||||||
go vpnapiServer.Serve(dialer.GetListener(net.JoinHostPort(tc.coordinators.peer.VPNIP, vpnAPIPort)))
|
go vpnapiServer.Serve(netDialer.GetListener(net.JoinHostPort(tc.coordinators.peer.VPNIP, vpnAPIPort)))
|
||||||
defer vpnapiServer.GracefulStop()
|
defer vpnapiServer.GracefulStop()
|
||||||
|
|
||||||
_, err := api.ActivateAsAdditionalCoordinator(context.Background(), &pubproto.ActivateAsAdditionalCoordinatorRequest{
|
_, err := api.ActivateAsAdditionalCoordinator(context.Background(), &pubproto.ActivateAsAdditionalCoordinatorRequest{
|
||||||
|
@ -163,9 +165,9 @@ func TestTriggerCoordinatorUpdate(t *testing.T) {
|
||||||
state: tc.state,
|
state: tc.state,
|
||||||
peers: tc.peers,
|
peers: tc.peers,
|
||||||
}
|
}
|
||||||
dialer := testdialer.NewBufconnDialer()
|
dialer := grpcutil.NewDialer(fakeValidator{}, nil)
|
||||||
|
|
||||||
api := New(logger, core, dialer, nil, nil, nil, nil)
|
api := New(logger, core, dialer, nil, nil, nil)
|
||||||
|
|
||||||
_, err := api.TriggerCoordinatorUpdate(context.Background(), &pubproto.TriggerCoordinatorUpdateRequest{})
|
_, err := api.TriggerCoordinatorUpdate(context.Background(), &pubproto.TriggerCoordinatorUpdateRequest{})
|
||||||
if tc.wantErr {
|
if tc.wantErr {
|
||||||
|
@ -236,20 +238,21 @@ func TestActivateAdditionalCoordinators(t *testing.T) {
|
||||||
ownerID: []byte("ownerID"),
|
ownerID: []byte("ownerID"),
|
||||||
clusterID: []byte("clusterID"),
|
clusterID: []byte("clusterID"),
|
||||||
}
|
}
|
||||||
dialer := testdialer.NewBufconnDialer()
|
netDialer := testdialer.NewBufconnDialer()
|
||||||
|
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
|
||||||
|
|
||||||
getPublicIPAddr := func() (string, error) {
|
getPublicIPAddr := func() (string, error) {
|
||||||
return "192.0.2.1", nil
|
return "192.0.2.1", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil)
|
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, getPublicIPAddr, nil)
|
||||||
defer api.Close()
|
defer api.Close()
|
||||||
|
|
||||||
// spawn coordinator
|
// spawn coordinator
|
||||||
tc.coordinators.activateErr = tc.activateErr
|
tc.coordinators.activateErr = tc.activateErr
|
||||||
tc.coordinators.getPubKeyErr = tc.getPublicKeyErr
|
tc.coordinators.getPubKeyErr = tc.getPublicKeyErr
|
||||||
server := tc.coordinators.newServer()
|
server := tc.coordinators.newServer()
|
||||||
go server.Serve(dialer.GetListener(net.JoinHostPort(tc.coordinators.peer.PublicIP, endpointAVPNPort)))
|
go server.Serve(netDialer.GetListener(net.JoinHostPort(tc.coordinators.peer.PublicIP, endpointAVPNPort)))
|
||||||
defer server.GracefulStop()
|
defer server.GracefulStop()
|
||||||
|
|
||||||
_, err := api.ActivateAdditionalCoordinator(context.Background(), &pubproto.ActivateAdditionalCoordinatorRequest{CoordinatorPublicIp: tc.coordinators.peer.PublicIP})
|
_, err := api.ActivateAdditionalCoordinator(context.Background(), &pubproto.ActivateAdditionalCoordinatorRequest{CoordinatorPublicIp: tc.coordinators.peer.PublicIP})
|
||||||
|
@ -293,13 +296,13 @@ func TestGetPeerVPNPublicKey(t *testing.T) {
|
||||||
vpnPubKey: tc.coordinator.peer.VPNPubKey,
|
vpnPubKey: tc.coordinator.peer.VPNPubKey,
|
||||||
getvpnPubKeyErr: tc.getVPNPubKeyErr,
|
getvpnPubKeyErr: tc.getVPNPubKeyErr,
|
||||||
}
|
}
|
||||||
dialer := testdialer.NewBufconnDialer()
|
dialer := grpcutil.NewDialer(fakeValidator{}, testdialer.NewBufconnDialer())
|
||||||
|
|
||||||
getPublicIPAddr := func() (string, error) {
|
getPublicIPAddr := func() (string, error) {
|
||||||
return "192.0.2.1", nil
|
return "192.0.2.1", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, fakeValidator{}, getPublicIPAddr, nil)
|
api := New(zaptest.NewLogger(t), core, dialer, stubVPNAPIServer{}, getPublicIPAddr, nil)
|
||||||
defer api.Close()
|
defer api.Close()
|
||||||
|
|
||||||
resp, err := api.GetPeerVPNPublicKey(context.Background(), &pubproto.GetPeerVPNPublicKeyRequest{})
|
resp, err := api.GetPeerVPNPublicKey(context.Background(), &pubproto.GetPeerVPNPublicKeyRequest{})
|
||||||
|
|
|
@ -168,7 +168,7 @@ func (a *API) JoinCluster(ctx context.Context, in *pubproto.JoinClusterRequest)
|
||||||
return nil, status.Errorf(codes.FailedPrecondition, "node is not in required state for cluster join: %v", err)
|
return nil, status.Errorf(codes.FailedPrecondition, "node is not in required state for cluster join: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := a.dialInsecure(ctx, net.JoinHostPort(in.CoordinatorVpnIp, vpnAPIPort))
|
conn, err := a.dialer.DialInsecure(ctx, net.JoinHostPort(in.CoordinatorVpnIp, vpnAPIPort))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(codes.Unavailable, "dial coordinator: %v", err)
|
return nil, status.Errorf(codes.Unavailable, "dial coordinator: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -231,7 +231,7 @@ func (a *API) update(ctx context.Context) error {
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// TODO: replace hardcoded IP
|
// TODO: replace hardcoded IP
|
||||||
conn, err := a.dialInsecure(ctx, net.JoinHostPort("10.118.0.1", vpnAPIPort))
|
conn, err := a.dialer.DialInsecure(ctx, net.JoinHostPort("10.118.0.1", vpnAPIPort))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||||
"github.com/edgelesssys/constellation/coordinator/role"
|
"github.com/edgelesssys/constellation/coordinator/role"
|
||||||
"github.com/edgelesssys/constellation/coordinator/state"
|
"github.com/edgelesssys/constellation/coordinator/state"
|
||||||
|
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
|
||||||
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
|
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
|
||||||
"github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto"
|
"github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
@ -127,25 +128,26 @@ func TestActivateAsNode(t *testing.T) {
|
||||||
|
|
||||||
logger := zaptest.NewLogger(t)
|
logger := zaptest.NewLogger(t)
|
||||||
cor := &fakeCore{state: tc.state, vpnPubKey: vpnPubKey, setVPNIPErr: tc.setVPNIPErr}
|
cor := &fakeCore{state: tc.state, vpnPubKey: vpnPubKey, setVPNIPErr: tc.setVPNIPErr}
|
||||||
dialer := testdialer.NewBufconnDialer()
|
netDialer := testdialer.NewBufconnDialer()
|
||||||
|
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
|
||||||
|
|
||||||
api := New(logger, cor, dialer, nil, nil, nil, nil)
|
api := New(logger, cor, dialer, nil, nil, nil)
|
||||||
defer api.Close()
|
defer api.Close()
|
||||||
|
|
||||||
vserver := grpc.NewServer()
|
vserver := grpc.NewServer()
|
||||||
vapi := &stubVPNAPI{peers: tc.updatedPeers, getUpdateErr: tc.getUpdateErr}
|
vapi := &stubVPNAPI{peers: tc.updatedPeers, getUpdateErr: tc.getUpdateErr}
|
||||||
vpnproto.RegisterAPIServer(vserver, vapi)
|
vpnproto.RegisterAPIServer(vserver, vapi)
|
||||||
go vserver.Serve(dialer.GetListener(net.JoinHostPort("10.118.0.1", vpnAPIPort)))
|
go vserver.Serve(netDialer.GetListener(net.JoinHostPort("10.118.0.1", vpnAPIPort)))
|
||||||
defer vserver.GracefulStop()
|
defer vserver.GracefulStop()
|
||||||
|
|
||||||
tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{})
|
tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{})
|
||||||
require.NoError(err)
|
require.NoError(err)
|
||||||
pubserver := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
|
pubserver := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
|
||||||
pubproto.RegisterAPIServer(pubserver, api)
|
pubproto.RegisterAPIServer(pubserver, api)
|
||||||
go pubserver.Serve(dialer.GetListener(net.JoinHostPort(nodeIP, endpointAVPNPort)))
|
go pubserver.Serve(netDialer.GetListener(net.JoinHostPort(nodeIP, endpointAVPNPort)))
|
||||||
defer pubserver.GracefulStop()
|
defer pubserver.GracefulStop()
|
||||||
|
|
||||||
_, nodeVPNPubKey, err := activateNode(require, dialer, messageSequence, nodeIP, "9000", nodeVPNIP, peer.ToPubProto(tc.initialPeers), ownerID, clusterID, stateDiskKey)
|
_, nodeVPNPubKey, err := activateNode(require, netDialer, messageSequence, nodeIP, "9000", nodeVPNIP, peer.ToPubProto(tc.initialPeers), ownerID, clusterID, stateDiskKey)
|
||||||
assert.Equal(tc.wantState, cor.state)
|
assert.Equal(tc.wantState, cor.state)
|
||||||
|
|
||||||
if tc.wantErr {
|
if tc.wantErr {
|
||||||
|
@ -215,9 +217,10 @@ func TestTriggerNodeUpdate(t *testing.T) {
|
||||||
|
|
||||||
logger := zaptest.NewLogger(t)
|
logger := zaptest.NewLogger(t)
|
||||||
core := &fakeCore{state: tc.state}
|
core := &fakeCore{state: tc.state}
|
||||||
dialer := testdialer.NewBufconnDialer()
|
netDialer := testdialer.NewBufconnDialer()
|
||||||
|
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
|
||||||
|
|
||||||
api := New(logger, core, dialer, nil, nil, nil, nil)
|
api := New(logger, core, dialer, nil, nil, nil)
|
||||||
|
|
||||||
vserver := grpc.NewServer()
|
vserver := grpc.NewServer()
|
||||||
vapi := &stubVPNAPI{
|
vapi := &stubVPNAPI{
|
||||||
|
@ -225,7 +228,7 @@ func TestTriggerNodeUpdate(t *testing.T) {
|
||||||
getUpdateErr: tc.getUpdateErr,
|
getUpdateErr: tc.getUpdateErr,
|
||||||
}
|
}
|
||||||
vpnproto.RegisterAPIServer(vserver, vapi)
|
vpnproto.RegisterAPIServer(vserver, vapi)
|
||||||
go vserver.Serve(dialer.GetListener(net.JoinHostPort("10.118.0.1", vpnAPIPort)))
|
go vserver.Serve(netDialer.GetListener(net.JoinHostPort("10.118.0.1", vpnAPIPort)))
|
||||||
defer vserver.GracefulStop()
|
defer vserver.GracefulStop()
|
||||||
|
|
||||||
_, err := api.TriggerNodeUpdate(context.Background(), &pubproto.TriggerNodeUpdateRequest{})
|
_, err := api.TriggerNodeUpdate(context.Background(), &pubproto.TriggerNodeUpdateRequest{})
|
||||||
|
@ -290,9 +293,10 @@ func TestJoinCluster(t *testing.T) {
|
||||||
|
|
||||||
logger := zaptest.NewLogger(t)
|
logger := zaptest.NewLogger(t)
|
||||||
core := &fakeCore{state: tc.state, joinClusterErr: tc.joinClusterErr}
|
core := &fakeCore{state: tc.state, joinClusterErr: tc.joinClusterErr}
|
||||||
dialer := testdialer.NewBufconnDialer()
|
netDialer := testdialer.NewBufconnDialer()
|
||||||
|
dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
|
||||||
|
|
||||||
api := New(logger, core, dialer, nil, nil, nil, nil)
|
api := New(logger, core, dialer, nil, nil, nil)
|
||||||
|
|
||||||
vserver := grpc.NewServer()
|
vserver := grpc.NewServer()
|
||||||
vapi := &stubVPNAPI{
|
vapi := &stubVPNAPI{
|
||||||
|
@ -304,7 +308,7 @@ func TestJoinCluster(t *testing.T) {
|
||||||
getJoinArgsErr: tc.getJoinArgsErr,
|
getJoinArgsErr: tc.getJoinArgsErr,
|
||||||
}
|
}
|
||||||
vpnproto.RegisterAPIServer(vserver, vapi)
|
vpnproto.RegisterAPIServer(vserver, vapi)
|
||||||
go vserver.Serve(dialer.GetListener(net.JoinHostPort("192.0.2.1", vpnAPIPort)))
|
go vserver.Serve(netDialer.GetListener(net.JoinHostPort("192.0.2.1", vpnAPIPort)))
|
||||||
defer vserver.GracefulStop()
|
defer vserver.GracefulStop()
|
||||||
|
|
||||||
_, err := api.JoinCluster(context.Background(), &pubproto.JoinClusterRequest{CoordinatorVpnIp: "192.0.2.1"})
|
_, err := api.JoinCluster(context.Background(), &pubproto.JoinClusterRequest{CoordinatorVpnIp: "192.0.2.1"})
|
||||||
|
@ -322,7 +326,7 @@ func TestJoinCluster(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func activateNode(require *require.Assertions, dialer Dialer, messageSequence []string, nodeIP, bindPort, nodeVPNIP string, peers []*pubproto.Peer, ownerID, clusterID, stateDiskKey []byte) (string, []byte, error) {
|
func activateNode(require *require.Assertions, dialer netDialer, messageSequence []string, nodeIP, bindPort, nodeVPNIP string, peers []*pubproto.Peer, ownerID, clusterID, stateDiskKey []byte) (string, []byte, error) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
conn, err := dialGRPC(ctx, dialer, net.JoinHostPort(nodeIP, bindPort))
|
conn, err := dialGRPC(ctx, dialer, net.JoinHostPort(nodeIP, bindPort))
|
||||||
require.NoError(err)
|
require.NoError(err)
|
||||||
|
@ -385,7 +389,7 @@ func activateNode(require *require.Assertions, dialer Dialer, messageSequence []
|
||||||
return diskUUID, nodeVPNPubKey, nil
|
return diskUUID, nodeVPNPubKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func dialGRPC(ctx context.Context, dialer Dialer, target string) (*grpc.ClientConn, error) {
|
func dialGRPC(ctx context.Context, dialer netDialer, target string) (*grpc.ClientConn, error) {
|
||||||
tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{&core.MockValidator{}})
|
tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{&core.MockValidator{}})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -429,3 +433,7 @@ func (a *stubVPNAPI) newServer() *grpc.Server {
|
||||||
vpnproto.RegisterAPIServer(server, a)
|
vpnproto.RegisterAPIServer(server, a)
|
||||||
return server
|
return server
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type netDialer interface {
|
||||||
|
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
|
}
|
||||||
|
|
|
@ -8,13 +8,10 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
|
||||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||||
"github.com/edgelesssys/constellation/state/setup"
|
"github.com/edgelesssys/constellation/state/setup"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials"
|
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
|
||||||
"google.golang.org/grpc/peer"
|
"google.golang.org/grpc/peer"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -32,7 +29,6 @@ type API struct {
|
||||||
core Core
|
core Core
|
||||||
dialer Dialer
|
dialer Dialer
|
||||||
vpnAPIServer VPNAPIServer
|
vpnAPIServer VPNAPIServer
|
||||||
validator atls.Validator
|
|
||||||
getPublicIPAddr GetIPAddrFunc
|
getPublicIPAddr GetIPAddrFunc
|
||||||
stopUpdate chan struct{}
|
stopUpdate chan struct{}
|
||||||
wgClose sync.WaitGroup
|
wgClose sync.WaitGroup
|
||||||
|
@ -42,13 +38,12 @@ type API struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new API.
|
// New creates a new API.
|
||||||
func New(logger *zap.Logger, core Core, dialer Dialer, vpnAPIServer VPNAPIServer, validator atls.Validator, getPublicIPAddr GetIPAddrFunc, peerFromContext PeerFromContextFunc) *API {
|
func New(logger *zap.Logger, core Core, dialer Dialer, vpnAPIServer VPNAPIServer, getPublicIPAddr GetIPAddrFunc, peerFromContext PeerFromContextFunc) *API {
|
||||||
return &API{
|
return &API{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
core: core,
|
core: core,
|
||||||
dialer: dialer,
|
dialer: dialer,
|
||||||
vpnAPIServer: vpnAPIServer,
|
vpnAPIServer: vpnAPIServer,
|
||||||
validator: validator,
|
|
||||||
getPublicIPAddr: getPublicIPAddr,
|
getPublicIPAddr: getPublicIPAddr,
|
||||||
stopUpdate: make(chan struct{}, 1),
|
stopUpdate: make(chan struct{}, 1),
|
||||||
peerFromContext: peerFromContext,
|
peerFromContext: peerFromContext,
|
||||||
|
@ -69,47 +64,6 @@ func (a *API) Close() {
|
||||||
a.wgClose.Wait()
|
a.wgClose.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *API) dial(ctx context.Context, target string) (*grpc.ClientConn, error) {
|
|
||||||
tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{a.validator})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return grpc.DialContext(ctx, target,
|
|
||||||
a.grpcWithDialer(),
|
|
||||||
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *API) dialInsecure(ctx context.Context, target string) (*grpc.ClientConn, error) {
|
|
||||||
return grpc.DialContext(ctx, target,
|
|
||||||
a.grpcWithDialer(),
|
|
||||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *API) dialNoVerify(ctx context.Context, target string) (*grpc.ClientConn, error) {
|
|
||||||
tlsConfig, err := atls.CreateUnverifiedClientTLSConfig()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return grpc.DialContext(ctx, target,
|
|
||||||
a.grpcWithDialer(),
|
|
||||||
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *API) grpcWithDialer() grpc.DialOption {
|
|
||||||
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
|
||||||
return a.dialer.DialContext(ctx, "tcp", addr)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
type Dialer interface {
|
|
||||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type VPNAPIServer interface {
|
type VPNAPIServer interface {
|
||||||
Listen(endpoint string) error
|
Listen(endpoint string) error
|
||||||
Serve() error
|
Serve() error
|
||||||
|
@ -135,3 +89,10 @@ func GetRecoveryPeerFromContext(ctx context.Context) (string, error) {
|
||||||
|
|
||||||
return net.JoinHostPort(peerIP, setup.RecoveryPort), nil
|
return net.JoinHostPort(peerIP, setup.RecoveryPort), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Dialer can open grpc client connections with different levels of ATLS encryption / verification.
|
||||||
|
type Dialer interface {
|
||||||
|
Dial(ctx context.Context, target string) (*grpc.ClientConn, error)
|
||||||
|
DialInsecure(ctx context.Context, target string) (*grpc.ClientConn, error)
|
||||||
|
DialNoVerify(ctx context.Context, target string) (*grpc.ClientConn, error)
|
||||||
|
}
|
||||||
|
|
71
coordinator/util/grpcutil/dialer.go
Normal file
71
coordinator/util/grpcutil/dialer.go
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
package grpcutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dialer can open grpc client connections with different levels of ATLS encryption / verification.
|
||||||
|
type Dialer struct {
|
||||||
|
validator atls.Validator
|
||||||
|
netDialer NetDialer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDialer creates a new Dialer.
|
||||||
|
func NewDialer(validator atls.Validator, netDialer NetDialer) *Dialer {
|
||||||
|
return &Dialer{
|
||||||
|
validator: validator,
|
||||||
|
netDialer: netDialer,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial creates a new grpc client connection to the given target using the atls validator.
|
||||||
|
func (d *Dialer) Dial(ctx context.Context, target string) (*grpc.ClientConn, error) {
|
||||||
|
tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{d.validator})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return grpc.DialContext(ctx, target,
|
||||||
|
d.grpcWithDialer(),
|
||||||
|
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialInsecure creates a new grpc client connection to the given target without using encryption or verification.
|
||||||
|
// Only use this method when using another kind of encryption / verification (VPN, etc).
|
||||||
|
func (d *Dialer) DialInsecure(ctx context.Context, target string) (*grpc.ClientConn, error) {
|
||||||
|
return grpc.DialContext(ctx, target,
|
||||||
|
d.grpcWithDialer(),
|
||||||
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialNoVerify creates a new grpc client connection to the given target without verifying the server's attestation.
|
||||||
|
func (d *Dialer) DialNoVerify(ctx context.Context, target string) (*grpc.ClientConn, error) {
|
||||||
|
tlsConfig, err := atls.CreateUnverifiedClientTLSConfig()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return grpc.DialContext(ctx, target,
|
||||||
|
d.grpcWithDialer(),
|
||||||
|
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Dialer) grpcWithDialer() grpc.DialOption {
|
||||||
|
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||||
|
return d.netDialer.DialContext(ctx, "tcp", addr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// NetDialer implements the net Dialer interface.
|
||||||
|
type NetDialer interface {
|
||||||
|
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
|
}
|
106
coordinator/util/grpcutil/dialer_test.go
Normal file
106
coordinator/util/grpcutil/dialer_test.go
Normal file
|
@ -0,0 +1,106 @@
|
||||||
|
package grpcutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||||
|
"github.com/edgelesssys/constellation/coordinator/core"
|
||||||
|
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
|
"google.golang.org/grpc/test/grpc_testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDial(t *testing.T) {
|
||||||
|
testCases := map[string]struct {
|
||||||
|
tls bool
|
||||||
|
dialFn func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error)
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
"Dial with tls on server works": {
|
||||||
|
tls: true,
|
||||||
|
dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) {
|
||||||
|
return dialer.Dial(ctx, target)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"Dial without tls on server fails": {
|
||||||
|
dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) {
|
||||||
|
return dialer.Dial(ctx, target)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
"DialNoVerify with tls on server works": {
|
||||||
|
tls: true,
|
||||||
|
dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) {
|
||||||
|
return dialer.DialNoVerify(ctx, target)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"DialNoVerify without tls on server fails": {
|
||||||
|
dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) {
|
||||||
|
return dialer.DialNoVerify(ctx, target)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
"DialInsecure without tls on server works": {
|
||||||
|
dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) {
|
||||||
|
return dialer.DialInsecure(ctx, target)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"DialInsecure with tls on server fails": {
|
||||||
|
tls: true,
|
||||||
|
dialFn: func(dialer *Dialer, ctx context.Context, target string) (*grpc.ClientConn, error) {
|
||||||
|
return dialer.DialInsecure(ctx, target)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
require := require.New(t)
|
||||||
|
|
||||||
|
netDialer := testdialer.NewBufconnDialer()
|
||||||
|
dialer := NewDialer(&core.MockValidator{}, netDialer)
|
||||||
|
server := newServer(tc.tls)
|
||||||
|
api := &testAPI{}
|
||||||
|
grpc_testing.RegisterTestServiceServer(server, api)
|
||||||
|
go server.Serve(netDialer.GetListener("192.0.2.1:1234"))
|
||||||
|
defer server.Stop()
|
||||||
|
conn, err := tc.dialFn(dialer, context.Background(), "192.0.2.1:1234")
|
||||||
|
require.NoError(err)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
client := grpc_testing.NewTestServiceClient(conn)
|
||||||
|
_, err = client.EmptyCall(context.Background(), &grpc_testing.Empty{})
|
||||||
|
|
||||||
|
if tc.wantErr {
|
||||||
|
assert.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newServer(tls bool) *grpc.Server {
|
||||||
|
if tls {
|
||||||
|
tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
|
||||||
|
}
|
||||||
|
return grpc.NewServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
type testAPI struct {
|
||||||
|
grpc_testing.UnimplementedTestServiceServer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testAPI) EmptyCall(ctx context.Context, in *grpc_testing.Empty) (*grpc_testing.Empty, error) {
|
||||||
|
return &grpc_testing.Empty{}, nil
|
||||||
|
}
|
|
@ -16,14 +16,14 @@ import (
|
||||||
|
|
||||||
// Download downloads a coordinator from a given debugd instance.
|
// Download downloads a coordinator from a given debugd instance.
|
||||||
type Download struct {
|
type Download struct {
|
||||||
dialer Dialer
|
dialer NetDialer
|
||||||
writer streamToFileWriter
|
writer streamToFileWriter
|
||||||
serviceManager serviceManager
|
serviceManager serviceManager
|
||||||
attemptedDownloads map[string]time.Time
|
attemptedDownloads map[string]time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new Download.
|
// New creates a new Download.
|
||||||
func New(dialer Dialer, serviceManager serviceManager, writer streamToFileWriter) *Download {
|
func New(dialer NetDialer, serviceManager serviceManager, writer streamToFileWriter) *Download {
|
||||||
return &Download{
|
return &Download{
|
||||||
dialer: dialer,
|
dialer: dialer,
|
||||||
writer: writer,
|
writer: writer,
|
||||||
|
@ -91,6 +91,7 @@ type streamToFileWriter interface {
|
||||||
WriteStream(filename string, stream coordinator.ReadChunkStream, showProgress bool) error
|
WriteStream(filename string, stream coordinator.ReadChunkStream, showProgress bool) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Dialer interface {
|
// NetDialer can open a net.Conn.
|
||||||
|
type NetDialer interface {
|
||||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,10 +9,10 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
|
||||||
"github.com/edgelesssys/constellation/debugd/coordinator"
|
"github.com/edgelesssys/constellation/debugd/coordinator"
|
||||||
"github.com/edgelesssys/constellation/debugd/debugd"
|
"github.com/edgelesssys/constellation/debugd/debugd"
|
||||||
pb "github.com/edgelesssys/constellation/debugd/service"
|
pb "github.com/edgelesssys/constellation/debugd/service"
|
||||||
"github.com/edgelesssys/constellation/debugd/service/testdialer"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
|
@ -8,10 +8,10 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
|
||||||
"github.com/edgelesssys/constellation/debugd/coordinator"
|
"github.com/edgelesssys/constellation/debugd/coordinator"
|
||||||
"github.com/edgelesssys/constellation/debugd/debugd/deploy"
|
"github.com/edgelesssys/constellation/debugd/debugd/deploy"
|
||||||
pb "github.com/edgelesssys/constellation/debugd/service"
|
pb "github.com/edgelesssys/constellation/debugd/service"
|
||||||
"github.com/edgelesssys/constellation/debugd/service/testdialer"
|
|
||||||
"github.com/edgelesssys/constellation/debugd/ssh"
|
"github.com/edgelesssys/constellation/debugd/ssh"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
@ -351,11 +351,11 @@ func (s *stubServiceManager) WriteSystemdUnitFile(ctx context.Context, unit depl
|
||||||
return s.writeSystemdUnitFileErr
|
return s.writeSystemdUnitFileErr
|
||||||
}
|
}
|
||||||
|
|
||||||
type dialer interface {
|
type netDialer interface {
|
||||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func dial(ctx context.Context, dialer dialer, target string) (*grpc.ClientConn, error) {
|
func dial(ctx context.Context, dialer netDialer, target string) (*grpc.ClientConn, error) {
|
||||||
return grpc.DialContext(ctx, target,
|
return grpc.DialContext(ctx, target,
|
||||||
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||||
return dialer.DialContext(ctx, "tcp", addr)
|
return dialer.DialContext(ctx, "tcp", addr)
|
||||||
|
|
|
@ -1,41 +0,0 @@
|
||||||
package testdialer
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"google.golang.org/grpc/test/bufconn"
|
|
||||||
)
|
|
||||||
|
|
||||||
// BufconnDialer is a fake dialer based on gRPC bufconn package.
|
|
||||||
type BufconnDialer struct {
|
|
||||||
mut sync.Mutex
|
|
||||||
listeners map[string]*bufconn.Listener
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewBufconnDialer creates a new bufconn dialer for testing.
|
|
||||||
func NewBufconnDialer() *BufconnDialer {
|
|
||||||
return &BufconnDialer{listeners: make(map[string]*bufconn.Listener)}
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialContext implements the Dialer interface.
|
|
||||||
func (b *BufconnDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
|
||||||
b.mut.Lock()
|
|
||||||
listener, ok := b.listeners[address]
|
|
||||||
b.mut.Unlock()
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("could not connect to server on %v", address)
|
|
||||||
}
|
|
||||||
return listener.DialContext(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetListener returns a fake listener that is coupled with this dialer.
|
|
||||||
func (b *BufconnDialer) GetListener(endpoint string) net.Listener {
|
|
||||||
listener := bufconn.Listen(1024)
|
|
||||||
b.mut.Lock()
|
|
||||||
b.listeners[endpoint] = listener
|
|
||||||
b.mut.Unlock()
|
|
||||||
return listener
|
|
||||||
}
|
|
Loading…
Add table
Add a link
Reference in a new issue