package vpnapi

import (
	"context"
	"errors"
	"net"
	"testing"

	"github.com/edgelesssys/constellation/coordinator/peer"
	"github.com/edgelesssys/constellation/coordinator/vpnapi/vpnproto"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/goleak"
	"go.uber.org/zap/zaptest"
	gpeer "google.golang.org/grpc/peer"
	kubeadm "k8s.io/kubernetes/cmd/kubeadm/app/apis/kubeadm/v1beta3"
)

// TestMain runs the package's tests through goleak.VerifyTestMain so that
// goroutines left running after the tests finish are reported as failures.
func TestMain(m *testing.M) {
	goleak.VerifyTestMain(m)
}

// TestGetUpdate checks that GetUpdate forwards the client's resource version
// to the core, returns the core's resource version, converts every peer into
// its protobuf form, propagates GetPeers errors, and records a heartbeat only
// when the context carries the caller's address.
func TestGetUpdate(t *testing.T) {
	const (
		serverResourceVersion = 2
		clientResourceVersion = 3
	)

	someErr := errors.New("failed")
	clientIP := &net.IPAddr{IP: net.ParseIP("192.0.2.1")}
	peer1 := peer.Peer{PublicIP: "192.0.2.11", VPNIP: "192.0.2.21", VPNPubKey: []byte{1, 2, 3}}
	peer2 := peer.Peer{PublicIP: "192.0.2.12", VPNIP: "192.0.2.22", VPNPubKey: []byte{2, 3, 4}}
	peer3 := peer.Peer{PublicIP: "192.0.2.13", VPNIP: "192.0.2.23", VPNPubKey: []byte{3, 4, 5}}

	testCases := map[string]struct {
		clientAddr  net.Addr
		peers       []peer.Peer
		getPeersErr error
		wantErr     bool
	}{
		"0 peers": {
			clientAddr: clientIP,
			peers:      []peer.Peer{},
		},
		"1 peer": {
			clientAddr: clientIP,
			peers:      []peer.Peer{peer1},
		},
		"2 peers": {
			clientAddr: clientIP,
			peers:      []peer.Peer{peer1, peer2},
		},
		"3 peers": {
			clientAddr: clientIP,
			peers:      []peer.Peer{peer1, peer2, peer3},
		},
		"nil peers": {
			clientAddr: clientIP,
			peers:      nil,
		},
		"getPeers error": {
			clientAddr:  clientIP,
			getPeersErr: someErr,
			wantErr:     true,
		},
		"missing client addr": {
			peers: []peer.Peer{peer1},
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			require := require.New(t)

			core := &stubCore{
				peers:                 tc.peers,
				serverResourceVersion: serverResourceVersion,
				getPeersErr:           tc.getPeersErr,
			}
			api := New(zaptest.NewLogger(t), core)

			// Attach the caller's address only when the case provides one, so
			// "missing client addr" exercises a context without peer info.
			ctx := context.Background()
			if tc.clientAddr != nil {
				ctx = gpeer.NewContext(ctx, &gpeer.Peer{Addr: tc.clientAddr})
			}

			req := &vpnproto.GetUpdateRequest{ResourceVersion: clientResourceVersion}
			resp, err := api.GetUpdate(ctx, req)
			if tc.wantErr {
				assert.Error(err)
				return
			}
			require.NoError(err)

			assert.EqualValues(serverResourceVersion, resp.ResourceVersion)
			assert.Equal([]int{clientResourceVersion}, core.clientResourceVersions)

			// Peers must come back one-to-one and in the core's order.
			require.Len(resp.Peers, len(tc.peers))
			for i := range resp.Peers {
				got, want := resp.Peers[i], tc.peers[i]
				assert.EqualValues(want.PublicIP, got.PublicIp)
				assert.EqualValues(want.VPNIP, got.VpnIp)
				assert.Equal(want.VPNPubKey, got.VpnPubKey)
			}

			if tc.clientAddr != nil {
				assert.Equal([]net.Addr{tc.clientAddr}, core.heartbeats)
			} else {
				assert.Empty(core.heartbeats)
			}
		})
	}
}

func TestGetK8SJoinArgs(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	joinArgs := kubeadm.BootstrapTokenDiscovery{
		APIServerEndpoint: "endp",
		Token:             "token",
		CACertHashes:      []string{"dis"},
	}
	api := New(zaptest.NewLogger(t), &stubCore{joinArgs: joinArgs})

	resp, err := api.GetK8SJoinArgs(context.Background(), &vpnproto.GetK8SJoinArgsRequest{})
	require.NoError(err)
	assert.Equal(joinArgs.APIServerEndpoint, resp.ApiServerEndpoint)
	assert.Equal(joinArgs.Token, resp.Token)
	assert.Equal(joinArgs.CACertHashes[0], resp.DiscoveryTokenCaCertHash)
}

// TestGetDataKey checks that GetDataKey returns the key derived by the core,
// and that a derivation error is propagated with a nil response.
func TestGetDataKey(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	ctx := context.Background()
	req := &vpnproto.GetDataKeyRequest{DataKeyId: "key-1", Length: 32}

	// Success path: the derived key is returned unchanged.
	okCore := &stubCore{derivedKey: []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5}}
	resp, err := New(zaptest.NewLogger(t), okCore).GetDataKey(ctx, req)
	require.NoError(err)
	assert.Equal(okCore.derivedKey, resp.DataKey)

	// Failure path: the core's error surfaces and no response is returned.
	failingCore := &stubCore{deriveKeyErr: errors.New("error")}
	resp, err = New(zaptest.NewLogger(t), failingCore).GetDataKey(ctx, req)
	assert.Error(err)
	assert.Nil(resp)
}

// TestGetK8SCertificateKey checks that GetK8SCertificateKey returns the
// certificate key provided by the core and propagates the core's error.
//
// Each case drives both the stub and the expected value from tc.certKey, so the
// table column is actually exercised.
func TestGetK8SCertificateKey(t *testing.T) {
	someErr := errors.New("someErr")
	certKey := "kubeadmKey"

	testCases := map[string]struct {
		certKey       string
		getCertKeyErr error
		wantErr       bool
	}{
		"basic": {
			certKey: certKey,
		},
		"error": {
			// A key is still configured so the case shows that the error wins.
			certKey:       certKey,
			getCertKeyErr: someErr,
			wantErr:       true,
		},
	}
	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			require := require.New(t)

			core := &stubCore{
				kubeadmCertificateKey: tc.certKey,
				getCertKeyErr:         tc.getCertKeyErr,
			}

			api := New(zaptest.NewLogger(t), core)
			resp, err := api.GetK8SCertificateKey(context.Background(), &vpnproto.GetK8SCertificateKeyRequest{})

			if tc.wantErr {
				assert.Error(err)
				return
			}
			require.NoError(err)
			assert.Equal(tc.certKey, resp.CertificateKey)
		})
	}
}

// stubCore is a hand-written test double for the core used by the VPN API.
// Most fields are canned return values. clientResourceVersions and heartbeats
// record the arguments the API passed in, so tests can inspect them afterwards.
type stubCore struct {
	// GetPeers: canned results, plus every resource version it was called with.
	peers                  []peer.Peer
	serverResourceVersion  int
	getPeersErr            error
	clientResourceVersions []int

	// NotifyNodeHeartbeat: every address it received, in call order.
	heartbeats []net.Addr

	// GetK8sJoinArgs: canned result.
	joinArgs kubeadm.BootstrapTokenDiscovery

	// GetK8SCertificateKey: canned results.
	kubeadmCertificateKey string
	getCertKeyErr         error

	// GetDataKey: canned results.
	derivedKey   []byte
	deriveKeyErr error
}

// GetPeers appends resourceVersion to clientResourceVersions and returns the
// configured server resource version, peers and error as-is.
func (s *stubCore) GetPeers(resourceVersion int) (int, []peer.Peer, error) {
	s.clientResourceVersions = append(s.clientResourceVersions, resourceVersion)
	return s.serverResourceVersion, s.peers, s.getPeersErr
}

// NotifyNodeHeartbeat appends addr to heartbeats.
func (s *stubCore) NotifyNodeHeartbeat(addr net.Addr) {
	recorded := append(s.heartbeats, addr)
	s.heartbeats = recorded
}

// GetK8sJoinArgs returns a pointer to the configured joinArgs and never fails.
func (s *stubCore) GetK8sJoinArgs(_ context.Context) (*kubeadm.BootstrapTokenDiscovery, error) {
	args := &s.joinArgs
	return args, nil
}

// GetK8SCertificateKey returns the configured certificate key and error.
func (s *stubCore) GetK8SCertificateKey(_ context.Context) (string, error) {
	key, err := s.kubeadmCertificateKey, s.getCertKeyErr
	return key, err
}

// GetDataKey returns the configured derivation error if set, and otherwise the
// configured derivedKey. The key ID and requested length are ignored.
func (s *stubCore) GetDataKey(ctx context.Context, dataKeyID string, length int) ([]byte, error) {
	if err := s.deriveKeyErr; err != nil {
		return nil, err
	}
	return s.derivedKey, nil
}