mirror of
https://github.com/edgelesssys/constellation.git
synced 2024-10-01 01:36:09 -04:00
move updatePeers directly to the VPN and omit the store layer (#4)
This commit is contained in:
parent
6bbb783af8
commit
6f695892bf
@ -57,7 +57,11 @@ func main() {
|
||||
}
|
||||
zapLoggerCore := zapLogger.Named("core")
|
||||
|
||||
wg := wireguard.New()
|
||||
wg, err := wireguard.New()
|
||||
if err != nil {
|
||||
zapLogger.Panic("error opening wgctrl client")
|
||||
}
|
||||
defer wg.Close()
|
||||
|
||||
var issuer core.QuoteIssuer
|
||||
var validator core.QuoteValidator
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
|
||||
"github.com/edgelesssys/constellation/coordinator/core"
|
||||
"github.com/edgelesssys/constellation/coordinator/kms"
|
||||
"github.com/edgelesssys/constellation/coordinator/peer"
|
||||
"github.com/edgelesssys/constellation/coordinator/pubapi"
|
||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||
"github.com/edgelesssys/constellation/coordinator/store"
|
||||
@ -339,6 +340,19 @@ func (v *fakeVPN) RemovePeer(pubKey []byte) error {
|
||||
panic("dummy")
|
||||
}
|
||||
|
||||
func (v *fakeVPN) UpdatePeers(peers []peer.Peer) error {
|
||||
for _, peer := range peers {
|
||||
peerIP, _, err := net.SplitHostPort(peer.PublicEndpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v.AddPeer(peer.VPNPubKey, peerIP, peer.VPNIP); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *fakeVPN) send(dst string, data string) {
|
||||
pubdst := v.peers[dst]
|
||||
packets := v.netw.packets
|
||||
|
@ -5,7 +5,6 @@ import (
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/peer"
|
||||
"github.com/edgelesssys/constellation/coordinator/storewrapper"
|
||||
"go.uber.org/multierr"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
@ -78,45 +77,5 @@ func (c *Core) AddPeer(peer peer.Peer) error {
|
||||
|
||||
// UpdatePeers synchronizes the peers known to the store and the vpn with the passed peers.
|
||||
func (c *Core) UpdatePeers(peers []peer.Peer) error {
|
||||
// exclude myself
|
||||
myIP, err := c.vpn.GetInterfaceIP()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for i, p := range peers {
|
||||
if p.VPNIP == myIP {
|
||||
peers = append(peers[:i], peers[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
tx, err := c.store.BeginTransaction()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
added, removed, err := storewrapper.StoreWrapper{Store: tx}.UpdatePeers(peers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// perform remove and add on the vpn
|
||||
var vpnErr error
|
||||
for _, p := range removed {
|
||||
vpnErr = multierr.Append(vpnErr, c.vpn.RemovePeer(p.VPNPubKey))
|
||||
}
|
||||
for _, p := range added {
|
||||
publicIP, _, err := net.SplitHostPort(p.PublicEndpoint)
|
||||
if err != nil {
|
||||
vpnErr = multierr.Append(vpnErr, err)
|
||||
continue
|
||||
}
|
||||
vpnErr = multierr.Append(vpnErr, c.vpn.AddPeer(p.VPNPubKey, publicIP, p.VPNIP))
|
||||
}
|
||||
if vpnErr != nil {
|
||||
return vpnErr
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
return c.vpn.UpdatePeers(peers)
|
||||
}
|
||||
|
@ -2,7 +2,6 @@ package core
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/peer"
|
||||
@ -139,129 +138,3 @@ func TestAddPeer(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdatePeer(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
peer1 := peer.Peer{PublicEndpoint: "192.0.2.11:2000", VPNIP: "192.0.2.21", VPNPubKey: []byte{1}}
|
||||
peer1KeyUpd := peer.Peer{PublicEndpoint: "192.0.2.11:2000", VPNIP: "192.0.2.21", VPNPubKey: []byte{1, 1}}
|
||||
peer1EndpUpd := peer.Peer{PublicEndpoint: "192.0.2.110:2000", VPNIP: "192.0.2.21", VPNPubKey: []byte{1}}
|
||||
peer2 := peer.Peer{PublicEndpoint: "192.0.2.12:2000", VPNIP: "192.0.2.22", VPNPubKey: []byte{2}}
|
||||
peer3 := peer.Peer{PublicEndpoint: "192.0.2.13:2000", VPNIP: "192.0.2.23", VPNPubKey: []byte{3}}
|
||||
|
||||
makeVPNPeers := func(peers ...peer.Peer) []stubVPNPeer {
|
||||
var result []stubVPNPeer
|
||||
for _, p := range peers {
|
||||
publicIP, _, err := net.SplitHostPort(p.PublicEndpoint)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
result = append(result, stubVPNPeer{pubKey: p.VPNPubKey, publicIP: publicIP, vpnIP: p.VPNIP})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
testCases := map[string]struct {
|
||||
peers []peer.Peer
|
||||
storePeers []peer.Peer
|
||||
vpn stubVPN
|
||||
expectErr bool
|
||||
expectedVPNPeers []stubVPNPeer
|
||||
expectedStorePeers []peer.Peer
|
||||
}{
|
||||
"basic": {
|
||||
peers: []peer.Peer{peer1, peer3},
|
||||
storePeers: []peer.Peer{peer1, peer2},
|
||||
vpn: stubVPN{peers: makeVPNPeers(peer1, peer2)},
|
||||
expectedVPNPeers: makeVPNPeers(peer1, peer3),
|
||||
expectedStorePeers: []peer.Peer{peer1, peer3},
|
||||
},
|
||||
"previously empty": {
|
||||
peers: []peer.Peer{peer1, peer2},
|
||||
expectedVPNPeers: makeVPNPeers(peer1, peer2),
|
||||
expectedStorePeers: []peer.Peer{peer1, peer2},
|
||||
},
|
||||
"no changes": {
|
||||
peers: []peer.Peer{peer1, peer2},
|
||||
storePeers: []peer.Peer{peer1, peer2},
|
||||
vpn: stubVPN{peers: makeVPNPeers(peer1, peer2)},
|
||||
expectedVPNPeers: makeVPNPeers(peer1, peer2),
|
||||
expectedStorePeers: []peer.Peer{peer1, peer2},
|
||||
},
|
||||
"key update": {
|
||||
peers: []peer.Peer{peer1KeyUpd, peer3},
|
||||
storePeers: []peer.Peer{peer1, peer2},
|
||||
vpn: stubVPN{peers: makeVPNPeers(peer1, peer2)},
|
||||
expectedVPNPeers: makeVPNPeers(peer1KeyUpd, peer3),
|
||||
expectedStorePeers: []peer.Peer{peer1KeyUpd, peer3},
|
||||
},
|
||||
"public endpoint update": {
|
||||
peers: []peer.Peer{peer1EndpUpd, peer3},
|
||||
storePeers: []peer.Peer{peer1, peer2},
|
||||
vpn: stubVPN{peers: makeVPNPeers(peer1, peer2)},
|
||||
expectedVPNPeers: makeVPNPeers(peer1EndpUpd, peer3),
|
||||
expectedStorePeers: []peer.Peer{peer1EndpUpd, peer3},
|
||||
},
|
||||
"don't add self": {
|
||||
peers: []peer.Peer{peer1, peer3},
|
||||
storePeers: []peer.Peer{peer1, peer2},
|
||||
vpn: stubVPN{peers: makeVPNPeers(peer1, peer2), interfaceIP: peer3.VPNIP},
|
||||
expectedVPNPeers: makeVPNPeers(peer1),
|
||||
expectedStorePeers: []peer.Peer{peer1},
|
||||
},
|
||||
"public endpoint without port": {
|
||||
peers: []peer.Peer{
|
||||
peer1,
|
||||
{
|
||||
PublicEndpoint: "192.0.2.13",
|
||||
VPNIP: "192.0.2.23",
|
||||
VPNPubKey: []byte{3},
|
||||
},
|
||||
},
|
||||
storePeers: []peer.Peer{peer1, peer2},
|
||||
vpn: stubVPN{peers: makeVPNPeers(peer1, peer2)},
|
||||
expectErr: true,
|
||||
},
|
||||
"vpn add peer error": {
|
||||
peers: []peer.Peer{peer1, peer3},
|
||||
storePeers: []peer.Peer{peer1, peer2},
|
||||
vpn: stubVPN{peers: makeVPNPeers(peer1, peer2), addPeerErr: someErr},
|
||||
expectErr: true,
|
||||
},
|
||||
"vpn remove peer error": {
|
||||
peers: []peer.Peer{peer1, peer3},
|
||||
storePeers: []peer.Peer{peer1, peer2},
|
||||
vpn: stubVPN{peers: makeVPNPeers(peer1, peer2), removePeerErr: someErr},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
core, err := NewCore(&tc.vpn, nil, nil, nil, nil, zaptest.NewLogger(t), nil, nil)
|
||||
require.NoError(err)
|
||||
|
||||
// prepare store
|
||||
for _, p := range tc.storePeers {
|
||||
require.NoError(core.data().PutPeer(p))
|
||||
}
|
||||
|
||||
updateErr := core.UpdatePeers(tc.peers)
|
||||
|
||||
actualStorePeers, err := core.data().GetPeers()
|
||||
require.NoError(err)
|
||||
|
||||
if tc.expectErr {
|
||||
assert.Error(updateErr)
|
||||
assert.ElementsMatch(tc.storePeers, actualStorePeers, "store has been changed despite failure")
|
||||
return
|
||||
}
|
||||
require.NoError(updateErr)
|
||||
|
||||
assert.ElementsMatch(tc.expectedVPNPeers, tc.vpn.peers)
|
||||
assert.ElementsMatch(tc.expectedStorePeers, actualStorePeers)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -3,6 +3,9 @@ package core
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/peer"
|
||||
)
|
||||
|
||||
type VPN interface {
|
||||
@ -12,6 +15,7 @@ type VPN interface {
|
||||
SetInterfaceIP(ip string) error
|
||||
AddPeer(pubKey []byte, publicIP string, vpnIP string) error
|
||||
RemovePeer(pubKey []byte) error
|
||||
UpdatePeers(peers []peer.Peer) error
|
||||
}
|
||||
|
||||
type stubVPN struct {
|
||||
@ -57,6 +61,19 @@ func (v *stubVPN) RemovePeer(pubKey []byte) error {
|
||||
return v.removePeerErr
|
||||
}
|
||||
|
||||
func (v *stubVPN) UpdatePeers(peers []peer.Peer) error {
|
||||
for _, peer := range peers {
|
||||
peerIP, _, err := net.SplitHostPort(peer.PublicEndpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := v.AddPeer(peer.VPNPubKey, peerIP, peer.VPNIP); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type stubVPNPeer struct {
|
||||
pubKey []byte
|
||||
publicIP string
|
||||
|
@ -1,12 +1,15 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/peer"
|
||||
"github.com/edgelesssys/constellation/coordinator/util"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
@ -18,20 +21,22 @@ const (
|
||||
port = 51820
|
||||
)
|
||||
|
||||
type Wireguard struct{}
|
||||
type Wireguard struct {
|
||||
client wgClient
|
||||
getInterfaceIP func(string) (string, error)
|
||||
}
|
||||
|
||||
func New() *Wireguard {
|
||||
return &Wireguard{}
|
||||
func New() (*Wireguard, error) {
|
||||
client, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Wireguard{client: client, getInterfaceIP: util.GetInterfaceIP}, nil
|
||||
}
|
||||
|
||||
func (w *Wireguard) Setup(privKey []byte) ([]byte, error) {
|
||||
client, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open wgctrl: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
var key wgtypes.Key
|
||||
var err error
|
||||
if len(privKey) == 0 {
|
||||
key, err = wgtypes.GeneratePrivateKey()
|
||||
} else {
|
||||
@ -42,7 +47,7 @@ func (w *Wireguard) Setup(privKey []byte) ([]byte, error) {
|
||||
}
|
||||
|
||||
listenPort := port
|
||||
if err := client.ConfigureDevice(netInterface, wgtypes.Config{PrivateKey: &key, ListenPort: &listenPort}); err != nil {
|
||||
if err := w.client.ConfigureDevice(netInterface, wgtypes.Config{PrivateKey: &key, ListenPort: &listenPort}); err != nil {
|
||||
return nil, prettyWgError(err)
|
||||
}
|
||||
|
||||
@ -59,7 +64,7 @@ func (w *Wireguard) GetPublicKey(privKey []byte) ([]byte, error) {
|
||||
}
|
||||
|
||||
func (w *Wireguard) GetInterfaceIP() (string, error) {
|
||||
return util.GetInterfaceIP(netInterface)
|
||||
return w.getInterfaceIP(netInterface)
|
||||
}
|
||||
|
||||
// SetInterfaceIP sets the ip interface ip.
|
||||
@ -80,12 +85,6 @@ func (w *Wireguard) SetInterfaceIP(ip string) error {
|
||||
|
||||
// AddPeer adds a new peer to a wireguard interface.
|
||||
func (w *Wireguard) AddPeer(pubKey []byte, publicIP string, vpnIP string) error {
|
||||
client, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open wgctrl: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
_, allowedIPs, err := net.ParseCIDR(vpnIP + "/32")
|
||||
if err != nil {
|
||||
return err
|
||||
@ -115,17 +114,11 @@ func (w *Wireguard) AddPeer(pubKey []byte, publicIP string, vpnIP string) error
|
||||
},
|
||||
}
|
||||
|
||||
return prettyWgError(client.ConfigureDevice(netInterface, cfg))
|
||||
return prettyWgError(w.client.ConfigureDevice(netInterface, cfg))
|
||||
}
|
||||
|
||||
// RemovePeer removes a peer from the wireguard interface.
|
||||
func (w *Wireguard) RemovePeer(pubKey []byte) error {
|
||||
client, err := wgctrl.New()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open wgctrl: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
key, err := wgtypes.NewKey(pubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -133,7 +126,7 @@ func (w *Wireguard) RemovePeer(pubKey []byte) error {
|
||||
|
||||
cfg := wgtypes.Config{Peers: []wgtypes.PeerConfig{{PublicKey: key, Remove: true}}}
|
||||
|
||||
return prettyWgError(client.ConfigureDevice(netInterface, cfg))
|
||||
return prettyWgError(w.client.ConfigureDevice(netInterface, cfg))
|
||||
}
|
||||
|
||||
func prettyWgError(err error) error {
|
||||
@ -142,3 +135,126 @@ func prettyWgError(err error) error {
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (w *Wireguard) UpdatePeers(peers []peer.Peer) error {
|
||||
ownVPNIP, err := w.getInterfaceIP(netInterface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to obtain vpn ip: %w", err)
|
||||
}
|
||||
wgPeers, err := transformToWgpeer(peers, ownVPNIP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to transform peers to wireguard-peers: %w", err)
|
||||
}
|
||||
|
||||
deviceData, err := w.client.Device(netInterface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to obtain device data: %w", err)
|
||||
}
|
||||
// convert to map for easier lookup
|
||||
storePeers := make(map[string]wgtypes.Peer)
|
||||
for _, p := range wgPeers {
|
||||
storePeers[p.AllowedIPs[0].String()] = p
|
||||
}
|
||||
var added []wgtypes.Peer
|
||||
var removed []wgtypes.Peer
|
||||
var updated []wgtypes.Peer
|
||||
|
||||
for _, interfacePeer := range deviceData.Peers {
|
||||
if updPeer, ok := storePeers[interfacePeer.AllowedIPs[0].String()]; ok {
|
||||
if updPeer.Endpoint.String() != interfacePeer.Endpoint.String() {
|
||||
updated = append(updated, updPeer)
|
||||
}
|
||||
if !bytes.Equal(updPeer.PublicKey[:], interfacePeer.PublicKey[:]) {
|
||||
added = append(added, updPeer)
|
||||
removed = append(removed, interfacePeer)
|
||||
}
|
||||
delete(storePeers, updPeer.AllowedIPs[0].String())
|
||||
} else {
|
||||
removed = append(removed, interfacePeer)
|
||||
}
|
||||
}
|
||||
// remaining store peers are new ones
|
||||
for _, peer := range storePeers {
|
||||
added = append(added, peer)
|
||||
}
|
||||
|
||||
keepAlive := 10 * time.Second
|
||||
var newPeerConfig []wgtypes.PeerConfig
|
||||
for _, peer := range removed {
|
||||
newPeerConfig = append(newPeerConfig, wgtypes.PeerConfig{
|
||||
// pub Key for remove matching is enought
|
||||
PublicKey: peer.PublicKey,
|
||||
Remove: true,
|
||||
})
|
||||
}
|
||||
for _, peer := range updated {
|
||||
newPeerConfig = append(newPeerConfig, wgtypes.PeerConfig{
|
||||
PublicKey: peer.PublicKey,
|
||||
Remove: false,
|
||||
UpdateOnly: true,
|
||||
Endpoint: peer.Endpoint,
|
||||
})
|
||||
}
|
||||
for _, peer := range added {
|
||||
newPeerConfig = append(newPeerConfig, wgtypes.PeerConfig{
|
||||
PublicKey: peer.PublicKey,
|
||||
Remove: false,
|
||||
UpdateOnly: false,
|
||||
Endpoint: peer.Endpoint,
|
||||
AllowedIPs: peer.AllowedIPs,
|
||||
// needed, otherwise gRPC has problems establishing the initial connection.
|
||||
PersistentKeepaliveInterval: &keepAlive,
|
||||
})
|
||||
}
|
||||
if len(newPeerConfig) == 0 {
|
||||
return nil
|
||||
}
|
||||
cfg := wgtypes.Config{
|
||||
ReplacePeers: false,
|
||||
Peers: newPeerConfig,
|
||||
}
|
||||
return prettyWgError(w.client.ConfigureDevice(netInterface, cfg))
|
||||
}
|
||||
|
||||
func (w *Wireguard) Close() error {
|
||||
return w.client.Close()
|
||||
}
|
||||
|
||||
// A wgClient is a type which can control a WireGuard device.
|
||||
type wgClient interface {
|
||||
io.Closer
|
||||
Device(name string) (*wgtypes.Device, error)
|
||||
ConfigureDevice(name string, cfg wgtypes.Config) error
|
||||
}
|
||||
|
||||
func transformToWgpeer(corePeers []peer.Peer, excludedIP string) ([]wgtypes.Peer, error) {
|
||||
var wgPeers []wgtypes.Peer
|
||||
for _, peer := range corePeers {
|
||||
if peer.VPNIP == excludedIP {
|
||||
continue
|
||||
}
|
||||
key, err := wgtypes.NewKey(peer.VPNPubKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, allowedIPs, err := net.ParseCIDR(peer.VPNIP + "/32")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
publicIP, _, err := net.SplitHostPort(peer.PublicEndpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var endpoint *net.UDPAddr
|
||||
if ip := net.ParseIP(publicIP); ip != nil {
|
||||
endpoint = &net.UDPAddr{IP: ip, Port: port}
|
||||
}
|
||||
wgPeers = append(wgPeers, wgtypes.Peer{
|
||||
PublicKey: key,
|
||||
Endpoint: endpoint,
|
||||
AllowedIPs: []net.IPNet{*allowedIPs},
|
||||
})
|
||||
}
|
||||
return wgPeers, nil
|
||||
}
|
||||
|
158
coordinator/wireguard/wireguard_test.go
Normal file
158
coordinator/wireguard/wireguard_test.go
Normal file
@ -0,0 +1,158 @@
|
||||
package wireguard
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/peer"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
func TestUpdatePeer(t *testing.T) {
|
||||
requirePre := require.New(t)
|
||||
|
||||
firstKey, err := wgtypes.GenerateKey()
|
||||
requirePre.NoError(err)
|
||||
peer1 := peer.Peer{PublicEndpoint: "192.0.2.11:2000", VPNIP: "192.0.2.21", VPNPubKey: firstKey[:]}
|
||||
firstKeyUpd, err := wgtypes.GenerateKey()
|
||||
requirePre.NoError(err)
|
||||
peer1KeyUpd := peer.Peer{PublicEndpoint: "192.0.2.11:2000", VPNIP: "192.0.2.21", VPNPubKey: firstKeyUpd[:]}
|
||||
peer1EndpUpd := peer.Peer{PublicEndpoint: "192.0.2.110:2000", VPNIP: "192.0.2.21", VPNPubKey: firstKey[:]}
|
||||
secondKey, err := wgtypes.GenerateKey()
|
||||
requirePre.NoError(err)
|
||||
peer2 := peer.Peer{PublicEndpoint: "192.0.2.12:2000", VPNIP: "192.0.2.22", VPNPubKey: secondKey[:]}
|
||||
thirdKey, err := wgtypes.GenerateKey()
|
||||
requirePre.NoError(err)
|
||||
peer3 := peer.Peer{PublicEndpoint: "192.0.2.13:2000", VPNIP: "192.0.2.23", VPNPubKey: thirdKey[:]}
|
||||
fourthKey, err := wgtypes.GenerateKey()
|
||||
requirePre.NoError(err)
|
||||
peerSelf := peer.Peer{PublicEndpoint: "192.0.2.10:2000", VPNIP: "192.0.2.20", VPNPubKey: fourthKey[:]}
|
||||
|
||||
checkError := func(peers []wgtypes.Peer, err error) []wgtypes.Peer {
|
||||
requirePre.NoError(err)
|
||||
return peers
|
||||
}
|
||||
|
||||
testCases := map[string]struct {
|
||||
storePeers []peer.Peer
|
||||
vpnPeers []wgtypes.Peer
|
||||
expectErr bool
|
||||
expectedVPNPeers []wgtypes.Peer
|
||||
}{
|
||||
"basic": {
|
||||
storePeers: []peer.Peer{peer1, peer3},
|
||||
vpnPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer2}, "")),
|
||||
expectedVPNPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer3}, "")),
|
||||
},
|
||||
"previously empty": {
|
||||
storePeers: []peer.Peer{peer1, peer2},
|
||||
expectedVPNPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer2}, "")),
|
||||
},
|
||||
"no changes": {
|
||||
storePeers: []peer.Peer{peer1, peer2},
|
||||
vpnPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer2}, "")),
|
||||
expectedVPNPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer2}, "")),
|
||||
},
|
||||
"key update": {
|
||||
storePeers: []peer.Peer{peer1KeyUpd, peer3},
|
||||
vpnPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer2}, "")),
|
||||
expectedVPNPeers: checkError(transformToWgpeer([]peer.Peer{peer1KeyUpd, peer3}, "")),
|
||||
},
|
||||
"public endpoint update": {
|
||||
storePeers: []peer.Peer{peer1EndpUpd, peer3},
|
||||
vpnPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer2}, "")),
|
||||
expectedVPNPeers: checkError(transformToWgpeer([]peer.Peer{peer1EndpUpd, peer3}, "")),
|
||||
},
|
||||
"dont add self": {
|
||||
storePeers: []peer.Peer{peerSelf, peer3},
|
||||
vpnPeers: checkError(transformToWgpeer([]peer.Peer{peer2, peer3}, "")),
|
||||
expectedVPNPeers: checkError(transformToWgpeer([]peer.Peer{peer3}, "")),
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
fakewg := fakewgClient{}
|
||||
fakewg.devices = make(map[string]*wgtypes.Device)
|
||||
wg := Wireguard{client: &fakewg, getInterfaceIP: func(s string) (string, error) {
|
||||
return "192.0.2.20", nil
|
||||
}}
|
||||
|
||||
fakewg.devices[netInterface] = &wgtypes.Device{Peers: tc.vpnPeers}
|
||||
|
||||
updateErr := wg.UpdatePeers(tc.storePeers)
|
||||
|
||||
if tc.expectErr {
|
||||
assert.Error(updateErr)
|
||||
return
|
||||
}
|
||||
require.NoError(updateErr)
|
||||
|
||||
assert.ElementsMatch(tc.expectedVPNPeers, fakewg.devices[netInterface].Peers)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type fakewgClient struct {
|
||||
devices map[string]*wgtypes.Device
|
||||
}
|
||||
|
||||
func (w *fakewgClient) Device(name string) (*wgtypes.Device, error) {
|
||||
if val, ok := w.devices[name]; ok {
|
||||
return val, nil
|
||||
}
|
||||
return nil, errors.New("device does not exist")
|
||||
}
|
||||
|
||||
func (w *fakewgClient) ConfigureDevice(name string, cfg wgtypes.Config) error {
|
||||
var newPeerList []wgtypes.Peer
|
||||
var operation bool
|
||||
vpnPeers := make(map[wgtypes.Key]wgtypes.Peer)
|
||||
|
||||
for _, peer := range w.devices[netInterface].Peers {
|
||||
vpnPeers[peer.PublicKey] = peer
|
||||
}
|
||||
|
||||
for _, configPeer := range cfg.Peers {
|
||||
operation = false
|
||||
for _, vpnPeer := range w.devices[netInterface].Peers {
|
||||
// wireguard matches internally via pubkey
|
||||
if vpnPeer.PublicKey == configPeer.PublicKey {
|
||||
operation = true
|
||||
if configPeer.Remove {
|
||||
delete(vpnPeers, vpnPeer.PublicKey)
|
||||
continue
|
||||
}
|
||||
if configPeer.UpdateOnly {
|
||||
vpnPeers[vpnPeer.PublicKey] = wgtypes.Peer{
|
||||
PublicKey: vpnPeer.PublicKey,
|
||||
AllowedIPs: vpnPeer.AllowedIPs,
|
||||
Endpoint: configPeer.Endpoint,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !operation {
|
||||
vpnPeers[configPeer.PublicKey] = wgtypes.Peer{
|
||||
PublicKey: configPeer.PublicKey,
|
||||
AllowedIPs: configPeer.AllowedIPs,
|
||||
Endpoint: configPeer.Endpoint,
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, peer := range vpnPeers {
|
||||
newPeerList = append(newPeerList, peer)
|
||||
}
|
||||
w.devices[netInterface].Peers = newPeerList
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *fakewgClient) Close() error {
|
||||
return nil
|
||||
}
|
Loading…
Reference in New Issue
Block a user