mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-01-15 01:17:13 -05:00
219 lines
5.8 KiB
Go
219 lines
5.8 KiB
Go
package vpnapi
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net"
|
|
"testing"
|
|
|
|
"github.com/edgelesssys/constellation/coordinator/peer"
|
|
"github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/goleak"
|
|
"go.uber.org/zap/zaptest"
|
|
gpeer "google.golang.org/grpc/peer"
|
|
kubeadm "k8s.io/kubernetes/cmd/kubeadm/app/apis/kubeadm/v1beta3"
|
|
)
|
|
|
|
func TestMain(m *testing.M) {
|
|
goleak.VerifyTestMain(m)
|
|
}
|
|
|
|
func TestGetUpdate(t *testing.T) {
|
|
someErr := errors.New("failed")
|
|
clientIP := &net.IPAddr{IP: net.ParseIP("192.0.2.1")}
|
|
peer1 := peer.Peer{PublicIP: "192.0.2.11", VPNIP: "192.0.2.21", VPNPubKey: []byte{1, 2, 3}}
|
|
peer2 := peer.Peer{PublicIP: "192.0.2.12", VPNIP: "192.0.2.22", VPNPubKey: []byte{2, 3, 4}}
|
|
peer3 := peer.Peer{PublicIP: "192.0.2.13", VPNIP: "192.0.2.23", VPNPubKey: []byte{3, 4, 5}}
|
|
|
|
testCases := map[string]struct {
|
|
clientAddr net.Addr
|
|
peers []peer.Peer
|
|
getPeersErr error
|
|
wantErr bool
|
|
}{
|
|
"0 peers": {
|
|
clientAddr: clientIP,
|
|
peers: []peer.Peer{},
|
|
},
|
|
"1 peer": {
|
|
clientAddr: clientIP,
|
|
peers: []peer.Peer{peer1},
|
|
},
|
|
"2 peers": {
|
|
clientAddr: clientIP,
|
|
peers: []peer.Peer{peer1, peer2},
|
|
},
|
|
"3 peers": {
|
|
clientAddr: clientIP,
|
|
peers: []peer.Peer{peer1, peer2, peer3},
|
|
},
|
|
"nil peers": {
|
|
clientAddr: clientIP,
|
|
peers: nil,
|
|
},
|
|
"getPeers error": {
|
|
clientAddr: clientIP,
|
|
getPeersErr: someErr,
|
|
wantErr: true,
|
|
},
|
|
"missing client addr": {
|
|
peers: []peer.Peer{peer1},
|
|
},
|
|
}
|
|
|
|
for name, tc := range testCases {
|
|
t.Run(name, func(t *testing.T) {
|
|
assert := assert.New(t)
|
|
require := require.New(t)
|
|
|
|
const serverResourceVersion = 2
|
|
const clientResourceVersion = 3
|
|
|
|
core := &stubCore{peers: tc.peers, serverResourceVersion: serverResourceVersion, getPeersErr: tc.getPeersErr}
|
|
api := New(zaptest.NewLogger(t), core)
|
|
|
|
ctx := context.Background()
|
|
if tc.clientAddr != nil {
|
|
ctx = gpeer.NewContext(ctx, &gpeer.Peer{Addr: tc.clientAddr})
|
|
}
|
|
|
|
resp, err := api.GetUpdate(ctx, &vpnproto.GetUpdateRequest{ResourceVersion: clientResourceVersion})
|
|
if tc.wantErr {
|
|
assert.Error(err)
|
|
return
|
|
}
|
|
require.NoError(err)
|
|
|
|
assert.EqualValues(serverResourceVersion, resp.ResourceVersion)
|
|
assert.Equal([]int{clientResourceVersion}, core.clientResourceVersions)
|
|
|
|
require.Len(resp.Peers, len(tc.peers))
|
|
for i, actual := range resp.Peers {
|
|
want := tc.peers[i]
|
|
assert.EqualValues(want.PublicIP, actual.PublicIp)
|
|
assert.EqualValues(want.VPNIP, actual.VpnIp)
|
|
assert.Equal(want.VPNPubKey, actual.VpnPubKey)
|
|
}
|
|
|
|
if tc.clientAddr == nil {
|
|
assert.Empty(core.heartbeats)
|
|
} else {
|
|
assert.Equal([]net.Addr{tc.clientAddr}, core.heartbeats)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetK8SJoinArgs(t *testing.T) {
|
|
assert := assert.New(t)
|
|
require := require.New(t)
|
|
|
|
joinArgs := kubeadm.BootstrapTokenDiscovery{
|
|
APIServerEndpoint: "endp",
|
|
Token: "token",
|
|
CACertHashes: []string{"dis"},
|
|
}
|
|
api := New(zaptest.NewLogger(t), &stubCore{joinArgs: joinArgs})
|
|
|
|
resp, err := api.GetK8SJoinArgs(context.Background(), &vpnproto.GetK8SJoinArgsRequest{})
|
|
require.NoError(err)
|
|
assert.Equal(joinArgs.APIServerEndpoint, resp.ApiServerEndpoint)
|
|
assert.Equal(joinArgs.Token, resp.Token)
|
|
assert.Equal(joinArgs.CACertHashes[0], resp.DiscoveryTokenCaCertHash)
|
|
}
|
|
|
|
func TestGetDataKey(t *testing.T) {
|
|
assert := assert.New(t)
|
|
require := require.New(t)
|
|
|
|
core := &stubCore{derivedKey: []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5}}
|
|
api := New(zaptest.NewLogger(t), core)
|
|
res, err := api.GetDataKey(context.Background(), &vpnproto.GetDataKeyRequest{DataKeyId: "key-1", Length: 32})
|
|
require.NoError(err)
|
|
assert.Equal(core.derivedKey, res.DataKey)
|
|
|
|
api = New(zaptest.NewLogger(t), &stubCore{deriveKeyErr: errors.New("error")})
|
|
res, err = api.GetDataKey(context.Background(), &vpnproto.GetDataKeyRequest{DataKeyId: "key-1", Length: 32})
|
|
assert.Error(err)
|
|
assert.Nil(res)
|
|
}
|
|
|
|
func TestGetK8SCertificateKey(t *testing.T) {
|
|
someErr := errors.New("someErr")
|
|
certKey := "kubeadmKey"
|
|
|
|
testCases := map[string]struct {
|
|
certKey string
|
|
getCertKeyErr error
|
|
wantErr bool
|
|
}{
|
|
"basic": {
|
|
certKey: certKey,
|
|
},
|
|
"error": {
|
|
getCertKeyErr: someErr,
|
|
wantErr: true,
|
|
},
|
|
}
|
|
for name, tc := range testCases {
|
|
t.Run(name, func(t *testing.T) {
|
|
assert := assert.New(t)
|
|
require := require.New(t)
|
|
|
|
core := &stubCore{
|
|
kubeadmCertificateKey: certKey,
|
|
getCertKeyErr: tc.getCertKeyErr,
|
|
}
|
|
|
|
api := New(zaptest.NewLogger(t), core)
|
|
resp, err := api.GetK8SCertificateKey(context.Background(), &vpnproto.GetK8SCertificateKeyRequest{})
|
|
|
|
if tc.wantErr {
|
|
assert.Error(err)
|
|
return
|
|
}
|
|
require.NoError(err)
|
|
assert.Equal(certKey, resp.CertificateKey)
|
|
})
|
|
}
|
|
}
|
|
|
|
type stubCore struct {
|
|
peers []peer.Peer
|
|
serverResourceVersion int
|
|
getPeersErr error
|
|
clientResourceVersions []int
|
|
heartbeats []net.Addr
|
|
joinArgs kubeadm.BootstrapTokenDiscovery
|
|
kubeadmCertificateKey string
|
|
getCertKeyErr error
|
|
derivedKey []byte
|
|
deriveKeyErr error
|
|
}
|
|
|
|
func (c *stubCore) GetPeers(resourceVersion int) (int, []peer.Peer, error) {
|
|
c.clientResourceVersions = append(c.clientResourceVersions, resourceVersion)
|
|
return c.serverResourceVersion, c.peers, c.getPeersErr
|
|
}
|
|
|
|
func (c *stubCore) NotifyNodeHeartbeat(addr net.Addr) {
|
|
c.heartbeats = append(c.heartbeats, addr)
|
|
}
|
|
|
|
func (c *stubCore) GetK8sJoinArgs(context.Context) (*kubeadm.BootstrapTokenDiscovery, error) {
|
|
return &c.joinArgs, nil
|
|
}
|
|
|
|
func (c *stubCore) GetK8SCertificateKey(context.Context) (string, error) {
|
|
return c.kubeadmCertificateKey, c.getCertKeyErr
|
|
}
|
|
|
|
func (c *stubCore) GetDataKey(ctx context.Context, dataKeyID string, length int) ([]byte, error) {
|
|
if c.deriveKeyErr != nil {
|
|
return nil, c.deriveKeyErr
|
|
}
|
|
return c.derivedKey, nil
|
|
}
|