constellation/coordinator/pubapi/node_test.go

440 lines
13 KiB
Go
Raw Normal View History

package pubapi
import (
"context"
"errors"
"io"
"net"
"testing"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/core"
"github.com/edgelesssys/constellation/coordinator/peer"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/role"
"github.com/edgelesssys/constellation/coordinator/state"
"github.com/edgelesssys/constellation/coordinator/util/grpcutil"
"github.com/edgelesssys/constellation/coordinator/util/testdialer"
"github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap/zaptest"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
kubeadm "k8s.io/kubernetes/cmd/kubeadm/app/apis/kubeadm/v1beta3"
)
// TestActivateAsNode exercises the ActivateAsNode stream end to end. A client
// (activateNode) talks to the public API over the in-memory test dialer, while
// stubVPNAPI serves the peer update the API requests. Every case checks the
// core state afterwards; successful cases also check the values the API handed
// to the fake core.
func TestActivateAsNode(t *testing.T) {
	someErr := errors.New("failed")
	peer1 := peer.Peer{PublicIP: "192.0.2.11:2000", VPNIP: "192.0.2.21", VPNPubKey: []byte{1, 2, 3}}
	peer2 := peer.Peer{PublicIP: "192.0.2.12:2000", VPNIP: "192.0.2.22", VPNPubKey: []byte{2, 3, 4}}
	testCases := map[string]struct {
		initialPeers []peer.Peer // peers sent in the initial activation request
		updatedPeers []peer.Peer // peers the stub VPN API returns from GetUpdate
		state        state.State // core state before activation
		getUpdateErr error       // error returned by the stub's GetUpdate
		setVPNIPErr  error       // error injected into the fake core (presumably returned by its SetVPNIP)
		// messageSequenceOverride replaces the default request sequence
		// {"initialRequest", "stateDiskKey"} when non-nil; an empty slice sends nothing.
		messageSequenceOverride []string
		wantErr                 bool
		wantState               state.State // expected core state, checked on success and failure
	}{
		"basic": {
			initialPeers: []peer.Peer{peer1},
			updatedPeers: []peer.Peer{peer2},
			state:        state.AcceptingInit,
			wantState:    state.NodeWaitingForClusterJoin,
		},
		"already activated": {
			initialPeers: []peer.Peer{peer1},
			updatedPeers: []peer.Peer{peer2},
			state:        state.IsNode,
			wantErr:      true,
			wantState:    state.IsNode,
		},
		"wrong peer kind": {
			initialPeers: []peer.Peer{peer1},
			updatedPeers: []peer.Peer{peer2},
			state:        state.ActivatingNodes,
			wantErr:      true,
			wantState:    state.ActivatingNodes,
		},
		"GetUpdate error": {
			initialPeers: []peer.Peer{peer1},
			updatedPeers: []peer.Peer{peer2},
			state:        state.AcceptingInit,
			getUpdateErr: someErr,
			wantState:    state.NodeWaitingForClusterJoin,
		},
		"SetVPNIP error": {
			initialPeers: []peer.Peer{peer1},
			updatedPeers: []peer.Peer{peer2},
			state:        state.AcceptingInit,
			setVPNIPErr:  someErr,
			wantErr:      true,
			wantState:    state.Failed,
		},
		"no messages sent to node": {
			initialPeers:            []peer.Peer{peer1},
			updatedPeers:            []peer.Peer{peer2},
			state:                   state.AcceptingInit,
			messageSequenceOverride: []string{},
			wantErr:                 true,
			wantState:               state.AcceptingInit,
		},
		"only initial message sent to node": {
			initialPeers:            []peer.Peer{peer1},
			updatedPeers:            []peer.Peer{peer2},
			state:                   state.AcceptingInit,
			messageSequenceOverride: []string{"initialRequest"},
			wantErr:                 true,
			wantState:               state.Failed,
		},
		"wrong initial message sent to node": {
			initialPeers:            []peer.Peer{peer1},
			updatedPeers:            []peer.Peer{peer2},
			state:                   state.AcceptingInit,
			messageSequenceOverride: []string{"stateDiskKey"},
			wantErr:                 true,
			wantState:               state.AcceptingInit,
		},
		"initial message sent twice to node": {
			initialPeers:            []peer.Peer{peer1},
			updatedPeers:            []peer.Peer{peer2},
			state:                   state.AcceptingInit,
			messageSequenceOverride: []string{"initialRequest", "initialRequest"},
			wantErr:                 true,
			wantState:               state.Failed,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			require := require.New(t)

			const (
				nodeIP    = "192.0.2.2"
				nodeVPNIP = "10.118.0.2"
			)
			vpnPubKey := []byte{7, 8, 9}
			ownerID := []byte("ownerID")
			clusterID := []byte("clusterID")
			stateDiskKey := []byte("stateDiskKey")
			messageSequence := []string{"initialRequest", "stateDiskKey"}
			// nil keeps the default sequence; any non-nil override (even empty) replaces it.
			if tc.messageSequenceOverride != nil {
				messageSequence = tc.messageSequenceOverride
			}

			logger := zaptest.NewLogger(t)
			cor := &fakeCore{state: tc.state, vpnPubKey: vpnPubKey, setVPNIPErr: tc.setVPNIPErr}
			netDialer := testdialer.NewBufconnDialer()
			dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
			api := New(logger, cor, dialer, nil, nil, nil)
			defer api.Close()

			// Stub VPN API reachable at 10.118.0.1 on the in-memory network.
			vserver := grpc.NewServer()
			vapi := &stubVPNAPI{peers: tc.updatedPeers, getUpdateErr: tc.getUpdateErr}
			vpnproto.RegisterAPIServer(vserver, vapi)
			go vserver.Serve(netDialer.GetListener(net.JoinHostPort("10.118.0.1", vpnAPIPort)))
			defer vserver.GracefulStop()

			// Public API under test, served with aTLS at nodeIP:endpointAVPNPort.
			tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{})
			require.NoError(err)
			pubserver := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
			pubproto.RegisterAPIServer(pubserver, api)
			go pubserver.Serve(netDialer.GetListener(net.JoinHostPort(nodeIP, endpointAVPNPort)))
			defer pubserver.GracefulStop()

			// Dial the same port the public server listens on instead of a
			// duplicated hard-coded literal.
			_, nodeVPNPubKey, err := activateNode(require, netDialer, messageSequence, nodeIP, endpointAVPNPort, nodeVPNIP, peer.ToPubProto(tc.initialPeers), ownerID, clusterID, stateDiskKey)
			assert.Equal(tc.wantState, cor.state)

			if tc.wantErr {
				assert.Error(err)
				return
			}
			require.NoError(err)
			assert.Equal(vpnPubKey, nodeVPNPubKey)
			assert.Equal(nodeVPNIP, cor.vpnIP)
			assert.Equal(ownerID, cor.ownerID)
			assert.Equal(clusterID, cor.clusterID)

			api.Close() // blocks until update loop finished

			// updatedPeers[0] holds the initial peers; updatedPeers[1] the peers
			// from the stub's GetUpdate, which is only applied when it succeeds.
			if tc.getUpdateErr == nil {
				require.Len(cor.updatedPeers, 2)
				assert.Equal(tc.updatedPeers, cor.updatedPeers[1])
			} else {
				require.Len(cor.updatedPeers, 1)
			}
			assert.Equal(tc.initialPeers, cor.updatedPeers[0])
			assert.Equal([]role.Role{role.Node}, cor.persistNodeStateRoles)
		})
	}
}
// TestTriggerNodeUpdate checks that TriggerNodeUpdate fetches the peer list
// from the stub VPN API and applies it to the core exactly once: a second
// trigger must be a no-op (presumably because the stub always reports
// ResourceVersion 1). Triggering is rejected when the core is not an activated
// node (AcceptingInit, ActivatingNodes) and when GetUpdate fails.
func TestTriggerNodeUpdate(t *testing.T) {
	someErr := errors.New("failed")
	peers := []peer.Peer{
		{PublicIP: "192.0.2.11:2000", VPNIP: "192.0.2.21", VPNPubKey: []byte{1, 2, 3}},
		{PublicIP: "192.0.2.12:2000", VPNIP: "192.0.2.22", VPNPubKey: []byte{2, 3, 4}},
	}

	testCases := map[string]struct {
		peers        []peer.Peer
		state        state.State
		getUpdateErr error
		wantErr      bool
	}{
		"basic": {
			peers: peers,
			state: state.IsNode,
		},
		"not activated": {
			peers:   peers,
			state:   state.AcceptingInit,
			wantErr: true,
		},
		"wrong peer kind": {
			peers:   peers,
			state:   state.ActivatingNodes,
			wantErr: true,
		},
		"GetUpdate error": {
			peers:        peers,
			state:        state.IsNode,
			getUpdateErr: someErr,
			wantErr:      true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			require := require.New(t)

			logger := zaptest.NewLogger(t)
			// Named cor rather than core so the imported core package is not shadowed.
			cor := &fakeCore{state: tc.state}
			netDialer := testdialer.NewBufconnDialer()
			dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
			api := New(logger, cor, dialer, nil, nil, nil)
			// Close the API so any background work it started does not outlive the subtest.
			defer api.Close()

			vserver := grpc.NewServer()
			vapi := &stubVPNAPI{
				peers:        tc.peers,
				getUpdateErr: tc.getUpdateErr,
			}
			vpnproto.RegisterAPIServer(vserver, vapi)
			go vserver.Serve(netDialer.GetListener(net.JoinHostPort("10.118.0.1", vpnAPIPort)))
			defer vserver.GracefulStop()

			_, err := api.TriggerNodeUpdate(context.Background(), &pubproto.TriggerNodeUpdateRequest{})
			if tc.wantErr {
				assert.Error(err)
				return
			}
			require.NoError(err)

			// second update should be a noop
			_, err = api.TriggerNodeUpdate(context.Background(), &pubproto.TriggerNodeUpdateRequest{})
			require.NoError(err)

			require.Len(cor.updatedPeers, 1)
			assert.Equal(tc.peers, cor.updatedPeers[0])
		})
	}
}
// TestJoinCluster checks that JoinCluster fetches kubeadm join arguments from
// the VPN API at the coordinator VPN IP named in the request and hands them to
// the core. It is rejected outside NodeWaitingForClusterJoin. A GetK8sJoinArgs
// failure leaves the state unchanged, while a failing core JoinCluster moves
// the state to Failed.
func TestJoinCluster(t *testing.T) {
	someErr := errors.New("failed")

	testCases := map[string]struct {
		state          state.State
		getJoinArgsErr error
		joinClusterErr error
		wantErr        bool
		wantState      state.State
	}{
		"basic": {
			state:     state.NodeWaitingForClusterJoin,
			wantState: state.IsNode,
		},
		"not activated": {
			state:     state.AcceptingInit,
			wantErr:   true,
			wantState: state.AcceptingInit,
		},
		"wrong peer kind": {
			state:     state.ActivatingNodes,
			wantErr:   true,
			wantState: state.ActivatingNodes,
		},
		"GetK8sJoinArgs error": {
			state:          state.NodeWaitingForClusterJoin,
			getJoinArgsErr: someErr,
			wantErr:        true,
			wantState:      state.NodeWaitingForClusterJoin,
		},
		"JoinCluster error": {
			state:          state.NodeWaitingForClusterJoin,
			joinClusterErr: someErr,
			wantErr:        true,
			wantState:      state.Failed,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			require := require.New(t)

			logger := zaptest.NewLogger(t)
			// Named cor rather than core so the imported core package is not shadowed.
			cor := &fakeCore{state: tc.state, joinClusterErr: tc.joinClusterErr}
			netDialer := testdialer.NewBufconnDialer()
			dialer := grpcutil.NewDialer(fakeValidator{}, netDialer)
			api := New(logger, cor, dialer, nil, nil, nil)
			// Close the API so any background work it started does not outlive the subtest.
			defer api.Close()

			vserver := grpc.NewServer()
			vapi := &stubVPNAPI{
				joinArgs: kubeadm.BootstrapTokenDiscovery{
					APIServerEndpoint: "endp",
					Token:             "token",
					CACertHashes:      []string{"dis"},
				},
				getJoinArgsErr: tc.getJoinArgsErr,
			}
			vpnproto.RegisterAPIServer(vserver, vapi)
			// The stub listens on the same IP that the request names as the coordinator.
			go vserver.Serve(netDialer.GetListener(net.JoinHostPort("192.0.2.1", vpnAPIPort)))
			defer vserver.GracefulStop()

			_, err := api.JoinCluster(context.Background(), &pubproto.JoinClusterRequest{CoordinatorVpnIp: "192.0.2.1"})
			assert.Equal(tc.wantState, cor.state)

			if tc.wantErr {
				assert.Error(err)
				return
			}
			require.NoError(err)
			assert.Equal([]kubeadm.BootstrapTokenDiscovery{vapi.joinArgs}, cor.joinArgs)
		})
	}
}
// activateNode plays the coordinator side of the ActivateAsNode stream against
// the public API at nodeIP:bindPort.
//
// It sends the requests named in messageSequence, in order:
//   - "initialRequest" carries nodeVPNIP, peers, ownerID and clusterID.
//   - "stateDiskKey" carries stateDiskKey.
//
// It then half-closes the stream and expects exactly two responses, followed
// by the end of the stream: the state disk UUID, then the node's VPN public
// key.
//
// It returns the disk UUID and the VPN public key. Stream errors are returned
// to the caller, and so is an unexpected extra response. Dial and CloseSend
// failures abort the test through require. An unknown entry in
// messageSequence is a bug in the test table and panics.
func activateNode(require *require.Assertions, dialer netDialer, messageSequence []string, nodeIP, bindPort, nodeVPNIP string, peers []*pubproto.Peer, ownerID, clusterID, stateDiskKey []byte) (string, []byte, error) {
	ctx := context.Background()
	conn, err := dialGRPC(ctx, dialer, net.JoinHostPort(nodeIP, bindPort))
	require.NoError(err)
	defer conn.Close()

	client := pubproto.NewAPIClient(conn)
	stream, err := client.ActivateAsNode(ctx)
	if err != nil {
		return "", nil, err
	}

	for _, message := range messageSequence {
		var req *pubproto.ActivateAsNodeRequest
		switch message {
		case "initialRequest":
			req = &pubproto.ActivateAsNodeRequest{
				Request: &pubproto.ActivateAsNodeRequest_InitialRequest{
					InitialRequest: &pubproto.ActivateAsNodeInitialRequest{
						NodeVpnIp: nodeVPNIP,
						Peers:     peers,
						OwnerId:   ownerID,
						ClusterId: clusterID,
					},
				},
			}
		case "stateDiskKey":
			req = &pubproto.ActivateAsNodeRequest{
				Request: &pubproto.ActivateAsNodeRequest_StateDiskKey{
					StateDiskKey: stateDiskKey,
				},
			}
		default:
			panic("unknown message in activation")
		}
		if err := stream.Send(req); err != nil {
			return "", nil, err
		}
	}
	require.NoError(stream.CloseSend())

	diskUUIDResp, err := stream.Recv()
	if err != nil {
		return "", nil, err
	}
	diskUUID := diskUUIDResp.GetStateDiskUuid()

	vpnPubKeyResp, err := stream.Recv()
	if err != nil {
		return "", nil, err
	}
	nodeVPNPubKey := vpnPubKeyResp.GetNodeVpnPubKey()

	// The node must end the stream after sending its VPN public key. An extra
	// message is a protocol violation. Report it explicitly; a nil error here
	// would otherwise be returned as "success" with empty results.
	if _, err := stream.Recv(); !errors.Is(err, io.EOF) {
		if err == nil {
			err = errors.New("activation stream: unexpected message after node VPN public key")
		}
		return "", nil, err
	}
	return diskUUID, nodeVPNPubKey, nil
}
// dialGRPC opens a gRPC client connection to target. Every connection is made
// through dialer over "tcp" and secured with aTLS client credentials, which
// verify the peer using core.MockValidator.
//
// It fails if the aTLS client config cannot be built or if grpc.DialContext
// returns an error.
func dialGRPC(ctx context.Context, dialer netDialer, target string) (*grpc.ClientConn, error) {
	validators := []atls.Validator{&core.MockValidator{}}
	clientTLS, err := atls.CreateAttestationClientTLSConfig(validators)
	if err != nil {
		return nil, err
	}

	dialVia := func(dialCtx context.Context, addr string) (net.Conn, error) {
		return dialer.DialContext(dialCtx, "tcp", addr)
	}
	opts := []grpc.DialOption{
		grpc.WithContextDialer(dialVia),
		grpc.WithTransportCredentials(credentials.NewTLS(clientTLS)),
	}
	return grpc.DialContext(ctx, target, opts...)
}
// stubVPNAPI is a canned vpnproto.APIServer that stands in for the
// coordinator's VPN API in these tests. Each RPC returns the configured data
// together with the configured error, so every call can be made to fail
// independently.
type stubVPNAPI struct {
	peers            []peer.Peer                     // returned by GetUpdate
	joinArgs         kubeadm.BootstrapTokenDiscovery // returned by GetK8SJoinArgs
	getUpdateErr     error
	getJoinArgsErr   error
	getK8SCertKeyErr error
	// Embedded so that RPCs not stubbed here fall back to the generated
	// "unimplemented" behavior.
	vpnproto.UnimplementedAPIServer
}

// GetUpdate returns the configured peers at resource version 1, together with
// getUpdateErr.
func (a *stubVPNAPI) GetUpdate(ctx context.Context, in *vpnproto.GetUpdateRequest) (*vpnproto.GetUpdateResponse, error) {
	return &vpnproto.GetUpdateResponse{ResourceVersion: 1, Peers: peer.ToVPNProto(a.peers)}, a.getUpdateErr
}

// GetK8SJoinArgs returns the configured join arguments, together with
// getJoinArgsErr. Only the first CA cert hash is reported. When none is
// configured, the hash is left empty rather than panicking inside the gRPC
// handler on an out-of-range index.
func (a *stubVPNAPI) GetK8SJoinArgs(ctx context.Context, in *vpnproto.GetK8SJoinArgsRequest) (*vpnproto.GetK8SJoinArgsResponse, error) {
	var caCertHash string
	if len(a.joinArgs.CACertHashes) > 0 {
		caCertHash = a.joinArgs.CACertHashes[0]
	}
	return &vpnproto.GetK8SJoinArgsResponse{
		ApiServerEndpoint:        a.joinArgs.APIServerEndpoint,
		Token:                    a.joinArgs.Token,
		DiscoveryTokenCaCertHash: caCertHash,
	}, a.getJoinArgsErr
}

// GetK8SCertificateKey returns a fixed dummy certificate key, together with
// getK8SCertKeyErr.
func (a *stubVPNAPI) GetK8SCertificateKey(ctx context.Context, in *vpnproto.GetK8SCertificateKeyRequest) (*vpnproto.GetK8SCertificateKeyResponse, error) {
	return &vpnproto.GetK8SCertificateKeyResponse{CertificateKey: "dummyCertKey"}, a.getK8SCertKeyErr
}

// newServer returns a new gRPC server with this stub registered as its VPN
// API. The server is not started.
func (a *stubVPNAPI) newServer() *grpc.Server {
	server := grpc.NewServer()
	vpnproto.RegisterAPIServer(server, a)
	return server
}
// netDialer is the single dialing capability the test helpers rely on: opening
// a connection to address on the named network. The in-memory bufconn test
// dialer is passed wherever this type is expected.
type netDialer interface {
	DialContext(ctx context.Context, network string, address string) (net.Conn, error)
}