mirror of
https://github.com/edgelesssys/constellation.git
synced 2024-10-01 01:36:09 -04:00
move updatePeers directly to the VPN and omit the store layer (#4)
This commit is contained in:
parent
6bbb783af8
commit
6f695892bf
@ -57,7 +57,11 @@ func main() {
|
|||||||
}
|
}
|
||||||
zapLoggerCore := zapLogger.Named("core")
|
zapLoggerCore := zapLogger.Named("core")
|
||||||
|
|
||||||
wg := wireguard.New()
|
wg, err := wireguard.New()
|
||||||
|
if err != nil {
|
||||||
|
zapLogger.Panic("error opening wgctrl client")
|
||||||
|
}
|
||||||
|
defer wg.Close()
|
||||||
|
|
||||||
var issuer core.QuoteIssuer
|
var issuer core.QuoteIssuer
|
||||||
var validator core.QuoteValidator
|
var validator core.QuoteValidator
|
||||||
|
@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
|
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
|
||||||
"github.com/edgelesssys/constellation/coordinator/core"
|
"github.com/edgelesssys/constellation/coordinator/core"
|
||||||
"github.com/edgelesssys/constellation/coordinator/kms"
|
"github.com/edgelesssys/constellation/coordinator/kms"
|
||||||
|
"github.com/edgelesssys/constellation/coordinator/peer"
|
||||||
"github.com/edgelesssys/constellation/coordinator/pubapi"
|
"github.com/edgelesssys/constellation/coordinator/pubapi"
|
||||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||||
"github.com/edgelesssys/constellation/coordinator/store"
|
"github.com/edgelesssys/constellation/coordinator/store"
|
||||||
@ -339,6 +340,19 @@ func (v *fakeVPN) RemovePeer(pubKey []byte) error {
|
|||||||
panic("dummy")
|
panic("dummy")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (v *fakeVPN) UpdatePeers(peers []peer.Peer) error {
|
||||||
|
for _, peer := range peers {
|
||||||
|
peerIP, _, err := net.SplitHostPort(peer.PublicEndpoint)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := v.AddPeer(peer.VPNPubKey, peerIP, peer.VPNIP); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (v *fakeVPN) send(dst string, data string) {
|
func (v *fakeVPN) send(dst string, data string) {
|
||||||
pubdst := v.peers[dst]
|
pubdst := v.peers[dst]
|
||||||
packets := v.netw.packets
|
packets := v.netw.packets
|
||||||
|
@ -5,7 +5,6 @@ import (
|
|||||||
|
|
||||||
"github.com/edgelesssys/constellation/coordinator/peer"
|
"github.com/edgelesssys/constellation/coordinator/peer"
|
||||||
"github.com/edgelesssys/constellation/coordinator/storewrapper"
|
"github.com/edgelesssys/constellation/coordinator/storewrapper"
|
||||||
"go.uber.org/multierr"
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -78,45 +77,5 @@ func (c *Core) AddPeer(peer peer.Peer) error {
|
|||||||
|
|
||||||
// UpdatePeers synchronizes the peers known to the store and the vpn with the passed peers.
|
// UpdatePeers synchronizes the peers known to the store and the vpn with the passed peers.
|
||||||
func (c *Core) UpdatePeers(peers []peer.Peer) error {
|
func (c *Core) UpdatePeers(peers []peer.Peer) error {
|
||||||
// exclude myself
|
return c.vpn.UpdatePeers(peers)
|
||||||
myIP, err := c.vpn.GetInterfaceIP()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
for i, p := range peers {
|
|
||||||
if p.VPNIP == myIP {
|
|
||||||
peers = append(peers[:i], peers[i+1:]...)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
tx, err := c.store.BeginTransaction()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
|
|
||||||
added, removed, err := storewrapper.StoreWrapper{Store: tx}.UpdatePeers(peers)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// perform remove and add on the vpn
|
|
||||||
var vpnErr error
|
|
||||||
for _, p := range removed {
|
|
||||||
vpnErr = multierr.Append(vpnErr, c.vpn.RemovePeer(p.VPNPubKey))
|
|
||||||
}
|
|
||||||
for _, p := range added {
|
|
||||||
publicIP, _, err := net.SplitHostPort(p.PublicEndpoint)
|
|
||||||
if err != nil {
|
|
||||||
vpnErr = multierr.Append(vpnErr, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
vpnErr = multierr.Append(vpnErr, c.vpn.AddPeer(p.VPNPubKey, publicIP, p.VPNIP))
|
|
||||||
}
|
|
||||||
if vpnErr != nil {
|
|
||||||
return vpnErr
|
|
||||||
}
|
|
||||||
|
|
||||||
return tx.Commit()
|
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,6 @@ package core
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/edgelesssys/constellation/coordinator/peer"
|
"github.com/edgelesssys/constellation/coordinator/peer"
|
||||||
@ -139,129 +138,3 @@ func TestAddPeer(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdatePeer(t *testing.T) {
|
|
||||||
someErr := errors.New("failed")
|
|
||||||
peer1 := peer.Peer{PublicEndpoint: "192.0.2.11:2000", VPNIP: "192.0.2.21", VPNPubKey: []byte{1}}
|
|
||||||
peer1KeyUpd := peer.Peer{PublicEndpoint: "192.0.2.11:2000", VPNIP: "192.0.2.21", VPNPubKey: []byte{1, 1}}
|
|
||||||
peer1EndpUpd := peer.Peer{PublicEndpoint: "192.0.2.110:2000", VPNIP: "192.0.2.21", VPNPubKey: []byte{1}}
|
|
||||||
peer2 := peer.Peer{PublicEndpoint: "192.0.2.12:2000", VPNIP: "192.0.2.22", VPNPubKey: []byte{2}}
|
|
||||||
peer3 := peer.Peer{PublicEndpoint: "192.0.2.13:2000", VPNIP: "192.0.2.23", VPNPubKey: []byte{3}}
|
|
||||||
|
|
||||||
makeVPNPeers := func(peers ...peer.Peer) []stubVPNPeer {
|
|
||||||
var result []stubVPNPeer
|
|
||||||
for _, p := range peers {
|
|
||||||
publicIP, _, err := net.SplitHostPort(p.PublicEndpoint)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
result = append(result, stubVPNPeer{pubKey: p.VPNPubKey, publicIP: publicIP, vpnIP: p.VPNIP})
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
testCases := map[string]struct {
|
|
||||||
peers []peer.Peer
|
|
||||||
storePeers []peer.Peer
|
|
||||||
vpn stubVPN
|
|
||||||
expectErr bool
|
|
||||||
expectedVPNPeers []stubVPNPeer
|
|
||||||
expectedStorePeers []peer.Peer
|
|
||||||
}{
|
|
||||||
"basic": {
|
|
||||||
peers: []peer.Peer{peer1, peer3},
|
|
||||||
storePeers: []peer.Peer{peer1, peer2},
|
|
||||||
vpn: stubVPN{peers: makeVPNPeers(peer1, peer2)},
|
|
||||||
expectedVPNPeers: makeVPNPeers(peer1, peer3),
|
|
||||||
expectedStorePeers: []peer.Peer{peer1, peer3},
|
|
||||||
},
|
|
||||||
"previously empty": {
|
|
||||||
peers: []peer.Peer{peer1, peer2},
|
|
||||||
expectedVPNPeers: makeVPNPeers(peer1, peer2),
|
|
||||||
expectedStorePeers: []peer.Peer{peer1, peer2},
|
|
||||||
},
|
|
||||||
"no changes": {
|
|
||||||
peers: []peer.Peer{peer1, peer2},
|
|
||||||
storePeers: []peer.Peer{peer1, peer2},
|
|
||||||
vpn: stubVPN{peers: makeVPNPeers(peer1, peer2)},
|
|
||||||
expectedVPNPeers: makeVPNPeers(peer1, peer2),
|
|
||||||
expectedStorePeers: []peer.Peer{peer1, peer2},
|
|
||||||
},
|
|
||||||
"key update": {
|
|
||||||
peers: []peer.Peer{peer1KeyUpd, peer3},
|
|
||||||
storePeers: []peer.Peer{peer1, peer2},
|
|
||||||
vpn: stubVPN{peers: makeVPNPeers(peer1, peer2)},
|
|
||||||
expectedVPNPeers: makeVPNPeers(peer1KeyUpd, peer3),
|
|
||||||
expectedStorePeers: []peer.Peer{peer1KeyUpd, peer3},
|
|
||||||
},
|
|
||||||
"public endpoint update": {
|
|
||||||
peers: []peer.Peer{peer1EndpUpd, peer3},
|
|
||||||
storePeers: []peer.Peer{peer1, peer2},
|
|
||||||
vpn: stubVPN{peers: makeVPNPeers(peer1, peer2)},
|
|
||||||
expectedVPNPeers: makeVPNPeers(peer1EndpUpd, peer3),
|
|
||||||
expectedStorePeers: []peer.Peer{peer1EndpUpd, peer3},
|
|
||||||
},
|
|
||||||
"don't add self": {
|
|
||||||
peers: []peer.Peer{peer1, peer3},
|
|
||||||
storePeers: []peer.Peer{peer1, peer2},
|
|
||||||
vpn: stubVPN{peers: makeVPNPeers(peer1, peer2), interfaceIP: peer3.VPNIP},
|
|
||||||
expectedVPNPeers: makeVPNPeers(peer1),
|
|
||||||
expectedStorePeers: []peer.Peer{peer1},
|
|
||||||
},
|
|
||||||
"public endpoint without port": {
|
|
||||||
peers: []peer.Peer{
|
|
||||||
peer1,
|
|
||||||
{
|
|
||||||
PublicEndpoint: "192.0.2.13",
|
|
||||||
VPNIP: "192.0.2.23",
|
|
||||||
VPNPubKey: []byte{3},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
storePeers: []peer.Peer{peer1, peer2},
|
|
||||||
vpn: stubVPN{peers: makeVPNPeers(peer1, peer2)},
|
|
||||||
expectErr: true,
|
|
||||||
},
|
|
||||||
"vpn add peer error": {
|
|
||||||
peers: []peer.Peer{peer1, peer3},
|
|
||||||
storePeers: []peer.Peer{peer1, peer2},
|
|
||||||
vpn: stubVPN{peers: makeVPNPeers(peer1, peer2), addPeerErr: someErr},
|
|
||||||
expectErr: true,
|
|
||||||
},
|
|
||||||
"vpn remove peer error": {
|
|
||||||
peers: []peer.Peer{peer1, peer3},
|
|
||||||
storePeers: []peer.Peer{peer1, peer2},
|
|
||||||
vpn: stubVPN{peers: makeVPNPeers(peer1, peer2), removePeerErr: someErr},
|
|
||||||
expectErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for name, tc := range testCases {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
assert := assert.New(t)
|
|
||||||
require := require.New(t)
|
|
||||||
|
|
||||||
core, err := NewCore(&tc.vpn, nil, nil, nil, nil, zaptest.NewLogger(t), nil, nil)
|
|
||||||
require.NoError(err)
|
|
||||||
|
|
||||||
// prepare store
|
|
||||||
for _, p := range tc.storePeers {
|
|
||||||
require.NoError(core.data().PutPeer(p))
|
|
||||||
}
|
|
||||||
|
|
||||||
updateErr := core.UpdatePeers(tc.peers)
|
|
||||||
|
|
||||||
actualStorePeers, err := core.data().GetPeers()
|
|
||||||
require.NoError(err)
|
|
||||||
|
|
||||||
if tc.expectErr {
|
|
||||||
assert.Error(updateErr)
|
|
||||||
assert.ElementsMatch(tc.storePeers, actualStorePeers, "store has been changed despite failure")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
require.NoError(updateErr)
|
|
||||||
|
|
||||||
assert.ElementsMatch(tc.expectedVPNPeers, tc.vpn.peers)
|
|
||||||
assert.ElementsMatch(tc.expectedStorePeers, actualStorePeers)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -3,6 +3,9 @@ package core
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/edgelesssys/constellation/coordinator/peer"
|
||||||
)
|
)
|
||||||
|
|
||||||
type VPN interface {
|
type VPN interface {
|
||||||
@ -12,6 +15,7 @@ type VPN interface {
|
|||||||
SetInterfaceIP(ip string) error
|
SetInterfaceIP(ip string) error
|
||||||
AddPeer(pubKey []byte, publicIP string, vpnIP string) error
|
AddPeer(pubKey []byte, publicIP string, vpnIP string) error
|
||||||
RemovePeer(pubKey []byte) error
|
RemovePeer(pubKey []byte) error
|
||||||
|
UpdatePeers(peers []peer.Peer) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type stubVPN struct {
|
type stubVPN struct {
|
||||||
@ -57,6 +61,19 @@ func (v *stubVPN) RemovePeer(pubKey []byte) error {
|
|||||||
return v.removePeerErr
|
return v.removePeerErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (v *stubVPN) UpdatePeers(peers []peer.Peer) error {
|
||||||
|
for _, peer := range peers {
|
||||||
|
peerIP, _, err := net.SplitHostPort(peer.PublicEndpoint)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := v.AddPeer(peer.VPNPubKey, peerIP, peer.VPNIP); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type stubVPNPeer struct {
|
type stubVPNPeer struct {
|
||||||
pubKey []byte
|
pubKey []byte
|
||||||
publicIP string
|
publicIP string
|
||||||
|
@ -1,12 +1,15 @@
|
|||||||
package wireguard
|
package wireguard
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/edgelesssys/constellation/coordinator/peer"
|
||||||
"github.com/edgelesssys/constellation/coordinator/util"
|
"github.com/edgelesssys/constellation/coordinator/util"
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
@ -18,20 +21,22 @@ const (
|
|||||||
port = 51820
|
port = 51820
|
||||||
)
|
)
|
||||||
|
|
||||||
type Wireguard struct{}
|
type Wireguard struct {
|
||||||
|
client wgClient
|
||||||
|
getInterfaceIP func(string) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
func New() *Wireguard {
|
func New() (*Wireguard, error) {
|
||||||
return &Wireguard{}
|
client, err := wgctrl.New()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Wireguard{client: client, getInterfaceIP: util.GetInterfaceIP}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *Wireguard) Setup(privKey []byte) ([]byte, error) {
|
func (w *Wireguard) Setup(privKey []byte) ([]byte, error) {
|
||||||
client, err := wgctrl.New()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to open wgctrl: %w", err)
|
|
||||||
}
|
|
||||||
defer client.Close()
|
|
||||||
|
|
||||||
var key wgtypes.Key
|
var key wgtypes.Key
|
||||||
|
var err error
|
||||||
if len(privKey) == 0 {
|
if len(privKey) == 0 {
|
||||||
key, err = wgtypes.GeneratePrivateKey()
|
key, err = wgtypes.GeneratePrivateKey()
|
||||||
} else {
|
} else {
|
||||||
@ -42,7 +47,7 @@ func (w *Wireguard) Setup(privKey []byte) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
listenPort := port
|
listenPort := port
|
||||||
if err := client.ConfigureDevice(netInterface, wgtypes.Config{PrivateKey: &key, ListenPort: &listenPort}); err != nil {
|
if err := w.client.ConfigureDevice(netInterface, wgtypes.Config{PrivateKey: &key, ListenPort: &listenPort}); err != nil {
|
||||||
return nil, prettyWgError(err)
|
return nil, prettyWgError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -59,7 +64,7 @@ func (w *Wireguard) GetPublicKey(privKey []byte) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (w *Wireguard) GetInterfaceIP() (string, error) {
|
func (w *Wireguard) GetInterfaceIP() (string, error) {
|
||||||
return util.GetInterfaceIP(netInterface)
|
return w.getInterfaceIP(netInterface)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetInterfaceIP sets the ip interface ip.
|
// SetInterfaceIP sets the ip interface ip.
|
||||||
@ -80,12 +85,6 @@ func (w *Wireguard) SetInterfaceIP(ip string) error {
|
|||||||
|
|
||||||
// AddPeer adds a new peer to a wireguard interface.
|
// AddPeer adds a new peer to a wireguard interface.
|
||||||
func (w *Wireguard) AddPeer(pubKey []byte, publicIP string, vpnIP string) error {
|
func (w *Wireguard) AddPeer(pubKey []byte, publicIP string, vpnIP string) error {
|
||||||
client, err := wgctrl.New()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to open wgctrl: %w", err)
|
|
||||||
}
|
|
||||||
defer client.Close()
|
|
||||||
|
|
||||||
_, allowedIPs, err := net.ParseCIDR(vpnIP + "/32")
|
_, allowedIPs, err := net.ParseCIDR(vpnIP + "/32")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -115,17 +114,11 @@ func (w *Wireguard) AddPeer(pubKey []byte, publicIP string, vpnIP string) error
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
return prettyWgError(client.ConfigureDevice(netInterface, cfg))
|
return prettyWgError(w.client.ConfigureDevice(netInterface, cfg))
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemovePeer removes a peer from the wireguard interface.
|
// RemovePeer removes a peer from the wireguard interface.
|
||||||
func (w *Wireguard) RemovePeer(pubKey []byte) error {
|
func (w *Wireguard) RemovePeer(pubKey []byte) error {
|
||||||
client, err := wgctrl.New()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to open wgctrl: %w", err)
|
|
||||||
}
|
|
||||||
defer client.Close()
|
|
||||||
|
|
||||||
key, err := wgtypes.NewKey(pubKey)
|
key, err := wgtypes.NewKey(pubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -133,7 +126,7 @@ func (w *Wireguard) RemovePeer(pubKey []byte) error {
|
|||||||
|
|
||||||
cfg := wgtypes.Config{Peers: []wgtypes.PeerConfig{{PublicKey: key, Remove: true}}}
|
cfg := wgtypes.Config{Peers: []wgtypes.PeerConfig{{PublicKey: key, Remove: true}}}
|
||||||
|
|
||||||
return prettyWgError(client.ConfigureDevice(netInterface, cfg))
|
return prettyWgError(w.client.ConfigureDevice(netInterface, cfg))
|
||||||
}
|
}
|
||||||
|
|
||||||
func prettyWgError(err error) error {
|
func prettyWgError(err error) error {
|
||||||
@ -142,3 +135,126 @@ func prettyWgError(err error) error {
|
|||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *Wireguard) UpdatePeers(peers []peer.Peer) error {
|
||||||
|
ownVPNIP, err := w.getInterfaceIP(netInterface)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to obtain vpn ip: %w", err)
|
||||||
|
}
|
||||||
|
wgPeers, err := transformToWgpeer(peers, ownVPNIP)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to transform peers to wireguard-peers: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
deviceData, err := w.client.Device(netInterface)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to obtain device data: %w", err)
|
||||||
|
}
|
||||||
|
// convert to map for easier lookup
|
||||||
|
storePeers := make(map[string]wgtypes.Peer)
|
||||||
|
for _, p := range wgPeers {
|
||||||
|
storePeers[p.AllowedIPs[0].String()] = p
|
||||||
|
}
|
||||||
|
var added []wgtypes.Peer
|
||||||
|
var removed []wgtypes.Peer
|
||||||
|
var updated []wgtypes.Peer
|
||||||
|
|
||||||
|
for _, interfacePeer := range deviceData.Peers {
|
||||||
|
if updPeer, ok := storePeers[interfacePeer.AllowedIPs[0].String()]; ok {
|
||||||
|
if updPeer.Endpoint.String() != interfacePeer.Endpoint.String() {
|
||||||
|
updated = append(updated, updPeer)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(updPeer.PublicKey[:], interfacePeer.PublicKey[:]) {
|
||||||
|
added = append(added, updPeer)
|
||||||
|
removed = append(removed, interfacePeer)
|
||||||
|
}
|
||||||
|
delete(storePeers, updPeer.AllowedIPs[0].String())
|
||||||
|
} else {
|
||||||
|
removed = append(removed, interfacePeer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// remaining store peers are new ones
|
||||||
|
for _, peer := range storePeers {
|
||||||
|
added = append(added, peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
keepAlive := 10 * time.Second
|
||||||
|
var newPeerConfig []wgtypes.PeerConfig
|
||||||
|
for _, peer := range removed {
|
||||||
|
newPeerConfig = append(newPeerConfig, wgtypes.PeerConfig{
|
||||||
|
// pub Key for remove matching is enought
|
||||||
|
PublicKey: peer.PublicKey,
|
||||||
|
Remove: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
for _, peer := range updated {
|
||||||
|
newPeerConfig = append(newPeerConfig, wgtypes.PeerConfig{
|
||||||
|
PublicKey: peer.PublicKey,
|
||||||
|
Remove: false,
|
||||||
|
UpdateOnly: true,
|
||||||
|
Endpoint: peer.Endpoint,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
for _, peer := range added {
|
||||||
|
newPeerConfig = append(newPeerConfig, wgtypes.PeerConfig{
|
||||||
|
PublicKey: peer.PublicKey,
|
||||||
|
Remove: false,
|
||||||
|
UpdateOnly: false,
|
||||||
|
Endpoint: peer.Endpoint,
|
||||||
|
AllowedIPs: peer.AllowedIPs,
|
||||||
|
// needed, otherwise gRPC has problems establishing the initial connection.
|
||||||
|
PersistentKeepaliveInterval: &keepAlive,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if len(newPeerConfig) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cfg := wgtypes.Config{
|
||||||
|
ReplacePeers: false,
|
||||||
|
Peers: newPeerConfig,
|
||||||
|
}
|
||||||
|
return prettyWgError(w.client.ConfigureDevice(netInterface, cfg))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *Wireguard) Close() error {
|
||||||
|
return w.client.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// A wgClient is a type which can control a WireGuard device.
|
||||||
|
type wgClient interface {
|
||||||
|
io.Closer
|
||||||
|
Device(name string) (*wgtypes.Device, error)
|
||||||
|
ConfigureDevice(name string, cfg wgtypes.Config) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func transformToWgpeer(corePeers []peer.Peer, excludedIP string) ([]wgtypes.Peer, error) {
|
||||||
|
var wgPeers []wgtypes.Peer
|
||||||
|
for _, peer := range corePeers {
|
||||||
|
if peer.VPNIP == excludedIP {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
key, err := wgtypes.NewKey(peer.VPNPubKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_, allowedIPs, err := net.ParseCIDR(peer.VPNIP + "/32")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
publicIP, _, err := net.SplitHostPort(peer.PublicEndpoint)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var endpoint *net.UDPAddr
|
||||||
|
if ip := net.ParseIP(publicIP); ip != nil {
|
||||||
|
endpoint = &net.UDPAddr{IP: ip, Port: port}
|
||||||
|
}
|
||||||
|
wgPeers = append(wgPeers, wgtypes.Peer{
|
||||||
|
PublicKey: key,
|
||||||
|
Endpoint: endpoint,
|
||||||
|
AllowedIPs: []net.IPNet{*allowedIPs},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return wgPeers, nil
|
||||||
|
}
|
||||||
|
158
coordinator/wireguard/wireguard_test.go
Normal file
158
coordinator/wireguard/wireguard_test.go
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
package wireguard
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/edgelesssys/constellation/coordinator/peer"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUpdatePeer(t *testing.T) {
|
||||||
|
requirePre := require.New(t)
|
||||||
|
|
||||||
|
firstKey, err := wgtypes.GenerateKey()
|
||||||
|
requirePre.NoError(err)
|
||||||
|
peer1 := peer.Peer{PublicEndpoint: "192.0.2.11:2000", VPNIP: "192.0.2.21", VPNPubKey: firstKey[:]}
|
||||||
|
firstKeyUpd, err := wgtypes.GenerateKey()
|
||||||
|
requirePre.NoError(err)
|
||||||
|
peer1KeyUpd := peer.Peer{PublicEndpoint: "192.0.2.11:2000", VPNIP: "192.0.2.21", VPNPubKey: firstKeyUpd[:]}
|
||||||
|
peer1EndpUpd := peer.Peer{PublicEndpoint: "192.0.2.110:2000", VPNIP: "192.0.2.21", VPNPubKey: firstKey[:]}
|
||||||
|
secondKey, err := wgtypes.GenerateKey()
|
||||||
|
requirePre.NoError(err)
|
||||||
|
peer2 := peer.Peer{PublicEndpoint: "192.0.2.12:2000", VPNIP: "192.0.2.22", VPNPubKey: secondKey[:]}
|
||||||
|
thirdKey, err := wgtypes.GenerateKey()
|
||||||
|
requirePre.NoError(err)
|
||||||
|
peer3 := peer.Peer{PublicEndpoint: "192.0.2.13:2000", VPNIP: "192.0.2.23", VPNPubKey: thirdKey[:]}
|
||||||
|
fourthKey, err := wgtypes.GenerateKey()
|
||||||
|
requirePre.NoError(err)
|
||||||
|
peerSelf := peer.Peer{PublicEndpoint: "192.0.2.10:2000", VPNIP: "192.0.2.20", VPNPubKey: fourthKey[:]}
|
||||||
|
|
||||||
|
checkError := func(peers []wgtypes.Peer, err error) []wgtypes.Peer {
|
||||||
|
requirePre.NoError(err)
|
||||||
|
return peers
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
storePeers []peer.Peer
|
||||||
|
vpnPeers []wgtypes.Peer
|
||||||
|
expectErr bool
|
||||||
|
expectedVPNPeers []wgtypes.Peer
|
||||||
|
}{
|
||||||
|
"basic": {
|
||||||
|
storePeers: []peer.Peer{peer1, peer3},
|
||||||
|
vpnPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer2}, "")),
|
||||||
|
expectedVPNPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer3}, "")),
|
||||||
|
},
|
||||||
|
"previously empty": {
|
||||||
|
storePeers: []peer.Peer{peer1, peer2},
|
||||||
|
expectedVPNPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer2}, "")),
|
||||||
|
},
|
||||||
|
"no changes": {
|
||||||
|
storePeers: []peer.Peer{peer1, peer2},
|
||||||
|
vpnPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer2}, "")),
|
||||||
|
expectedVPNPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer2}, "")),
|
||||||
|
},
|
||||||
|
"key update": {
|
||||||
|
storePeers: []peer.Peer{peer1KeyUpd, peer3},
|
||||||
|
vpnPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer2}, "")),
|
||||||
|
expectedVPNPeers: checkError(transformToWgpeer([]peer.Peer{peer1KeyUpd, peer3}, "")),
|
||||||
|
},
|
||||||
|
"public endpoint update": {
|
||||||
|
storePeers: []peer.Peer{peer1EndpUpd, peer3},
|
||||||
|
vpnPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer2}, "")),
|
||||||
|
expectedVPNPeers: checkError(transformToWgpeer([]peer.Peer{peer1EndpUpd, peer3}, "")),
|
||||||
|
},
|
||||||
|
"dont add self": {
|
||||||
|
storePeers: []peer.Peer{peerSelf, peer3},
|
||||||
|
vpnPeers: checkError(transformToWgpeer([]peer.Peer{peer2, peer3}, "")),
|
||||||
|
expectedVPNPeers: checkError(transformToWgpeer([]peer.Peer{peer3}, "")),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
require := require.New(t)
|
||||||
|
|
||||||
|
fakewg := fakewgClient{}
|
||||||
|
fakewg.devices = make(map[string]*wgtypes.Device)
|
||||||
|
wg := Wireguard{client: &fakewg, getInterfaceIP: func(s string) (string, error) {
|
||||||
|
return "192.0.2.20", nil
|
||||||
|
}}
|
||||||
|
|
||||||
|
fakewg.devices[netInterface] = &wgtypes.Device{Peers: tc.vpnPeers}
|
||||||
|
|
||||||
|
updateErr := wg.UpdatePeers(tc.storePeers)
|
||||||
|
|
||||||
|
if tc.expectErr {
|
||||||
|
assert.Error(updateErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(updateErr)
|
||||||
|
|
||||||
|
assert.ElementsMatch(tc.expectedVPNPeers, fakewg.devices[netInterface].Peers)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakewgClient struct {
|
||||||
|
devices map[string]*wgtypes.Device
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *fakewgClient) Device(name string) (*wgtypes.Device, error) {
|
||||||
|
if val, ok := w.devices[name]; ok {
|
||||||
|
return val, nil
|
||||||
|
}
|
||||||
|
return nil, errors.New("device does not exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *fakewgClient) ConfigureDevice(name string, cfg wgtypes.Config) error {
|
||||||
|
var newPeerList []wgtypes.Peer
|
||||||
|
var operation bool
|
||||||
|
vpnPeers := make(map[wgtypes.Key]wgtypes.Peer)
|
||||||
|
|
||||||
|
for _, peer := range w.devices[netInterface].Peers {
|
||||||
|
vpnPeers[peer.PublicKey] = peer
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, configPeer := range cfg.Peers {
|
||||||
|
operation = false
|
||||||
|
for _, vpnPeer := range w.devices[netInterface].Peers {
|
||||||
|
// wireguard matches internally via pubkey
|
||||||
|
if vpnPeer.PublicKey == configPeer.PublicKey {
|
||||||
|
operation = true
|
||||||
|
if configPeer.Remove {
|
||||||
|
delete(vpnPeers, vpnPeer.PublicKey)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if configPeer.UpdateOnly {
|
||||||
|
vpnPeers[vpnPeer.PublicKey] = wgtypes.Peer{
|
||||||
|
PublicKey: vpnPeer.PublicKey,
|
||||||
|
AllowedIPs: vpnPeer.AllowedIPs,
|
||||||
|
Endpoint: configPeer.Endpoint,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !operation {
|
||||||
|
vpnPeers[configPeer.PublicKey] = wgtypes.Peer{
|
||||||
|
PublicKey: configPeer.PublicKey,
|
||||||
|
AllowedIPs: configPeer.AllowedIPs,
|
||||||
|
Endpoint: configPeer.Endpoint,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, peer := range vpnPeers {
|
||||||
|
newPeerList = append(newPeerList, peer)
|
||||||
|
}
|
||||||
|
w.devices[netInterface].Peers = newPeerList
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *fakewgClient) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user