move updatePeers directly to the VPN and omit the store layer (#4)

This commit is contained in:
Benedict Schlüter 2022-03-25 16:05:17 +01:00 committed by GitHub
parent 6bbb783af8
commit 6f695892bf
7 changed files with 336 additions and 195 deletions

View File

@ -57,7 +57,11 @@ func main() {
}
zapLoggerCore := zapLogger.Named("core")
wg := wireguard.New()
wg, err := wireguard.New()
if err != nil {
zapLogger.Panic("error opening wgctrl client")
}
defer wg.Close()
var issuer core.QuoteIssuer
var validator core.QuoteValidator

View File

@ -12,6 +12,7 @@ import (
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
"github.com/edgelesssys/constellation/coordinator/core"
"github.com/edgelesssys/constellation/coordinator/kms"
"github.com/edgelesssys/constellation/coordinator/peer"
"github.com/edgelesssys/constellation/coordinator/pubapi"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/store"
@ -339,6 +340,19 @@ func (v *fakeVPN) RemovePeer(pubKey []byte) error {
panic("dummy")
}
func (v *fakeVPN) UpdatePeers(peers []peer.Peer) error {
for _, peer := range peers {
peerIP, _, err := net.SplitHostPort(peer.PublicEndpoint)
if err != nil {
return err
}
if err := v.AddPeer(peer.VPNPubKey, peerIP, peer.VPNIP); err != nil {
return err
}
}
return nil
}
func (v *fakeVPN) send(dst string, data string) {
pubdst := v.peers[dst]
packets := v.netw.packets

View File

@ -5,7 +5,6 @@ import (
"github.com/edgelesssys/constellation/coordinator/peer"
"github.com/edgelesssys/constellation/coordinator/storewrapper"
"go.uber.org/multierr"
"go.uber.org/zap"
)
@ -78,45 +77,5 @@ func (c *Core) AddPeer(peer peer.Peer) error {
// UpdatePeers synchronizes the peers known to the store and the vpn with the passed peers.
func (c *Core) UpdatePeers(peers []peer.Peer) error {
// exclude myself
myIP, err := c.vpn.GetInterfaceIP()
if err != nil {
return err
}
for i, p := range peers {
if p.VPNIP == myIP {
peers = append(peers[:i], peers[i+1:]...)
break
}
}
tx, err := c.store.BeginTransaction()
if err != nil {
return err
}
defer tx.Rollback()
added, removed, err := storewrapper.StoreWrapper{Store: tx}.UpdatePeers(peers)
if err != nil {
return err
}
// perform remove and add on the vpn
var vpnErr error
for _, p := range removed {
vpnErr = multierr.Append(vpnErr, c.vpn.RemovePeer(p.VPNPubKey))
}
for _, p := range added {
publicIP, _, err := net.SplitHostPort(p.PublicEndpoint)
if err != nil {
vpnErr = multierr.Append(vpnErr, err)
continue
}
vpnErr = multierr.Append(vpnErr, c.vpn.AddPeer(p.VPNPubKey, publicIP, p.VPNIP))
}
if vpnErr != nil {
return vpnErr
}
return tx.Commit()
return c.vpn.UpdatePeers(peers)
}

View File

@ -2,7 +2,6 @@ package core
import (
"errors"
"net"
"testing"
"github.com/edgelesssys/constellation/coordinator/peer"
@ -139,129 +138,3 @@ func TestAddPeer(t *testing.T) {
})
}
}
func TestUpdatePeer(t *testing.T) {
someErr := errors.New("failed")
peer1 := peer.Peer{PublicEndpoint: "192.0.2.11:2000", VPNIP: "192.0.2.21", VPNPubKey: []byte{1}}
peer1KeyUpd := peer.Peer{PublicEndpoint: "192.0.2.11:2000", VPNIP: "192.0.2.21", VPNPubKey: []byte{1, 1}}
peer1EndpUpd := peer.Peer{PublicEndpoint: "192.0.2.110:2000", VPNIP: "192.0.2.21", VPNPubKey: []byte{1}}
peer2 := peer.Peer{PublicEndpoint: "192.0.2.12:2000", VPNIP: "192.0.2.22", VPNPubKey: []byte{2}}
peer3 := peer.Peer{PublicEndpoint: "192.0.2.13:2000", VPNIP: "192.0.2.23", VPNPubKey: []byte{3}}
makeVPNPeers := func(peers ...peer.Peer) []stubVPNPeer {
var result []stubVPNPeer
for _, p := range peers {
publicIP, _, err := net.SplitHostPort(p.PublicEndpoint)
if err != nil {
panic(err)
}
result = append(result, stubVPNPeer{pubKey: p.VPNPubKey, publicIP: publicIP, vpnIP: p.VPNIP})
}
return result
}
testCases := map[string]struct {
peers []peer.Peer
storePeers []peer.Peer
vpn stubVPN
expectErr bool
expectedVPNPeers []stubVPNPeer
expectedStorePeers []peer.Peer
}{
"basic": {
peers: []peer.Peer{peer1, peer3},
storePeers: []peer.Peer{peer1, peer2},
vpn: stubVPN{peers: makeVPNPeers(peer1, peer2)},
expectedVPNPeers: makeVPNPeers(peer1, peer3),
expectedStorePeers: []peer.Peer{peer1, peer3},
},
"previously empty": {
peers: []peer.Peer{peer1, peer2},
expectedVPNPeers: makeVPNPeers(peer1, peer2),
expectedStorePeers: []peer.Peer{peer1, peer2},
},
"no changes": {
peers: []peer.Peer{peer1, peer2},
storePeers: []peer.Peer{peer1, peer2},
vpn: stubVPN{peers: makeVPNPeers(peer1, peer2)},
expectedVPNPeers: makeVPNPeers(peer1, peer2),
expectedStorePeers: []peer.Peer{peer1, peer2},
},
"key update": {
peers: []peer.Peer{peer1KeyUpd, peer3},
storePeers: []peer.Peer{peer1, peer2},
vpn: stubVPN{peers: makeVPNPeers(peer1, peer2)},
expectedVPNPeers: makeVPNPeers(peer1KeyUpd, peer3),
expectedStorePeers: []peer.Peer{peer1KeyUpd, peer3},
},
"public endpoint update": {
peers: []peer.Peer{peer1EndpUpd, peer3},
storePeers: []peer.Peer{peer1, peer2},
vpn: stubVPN{peers: makeVPNPeers(peer1, peer2)},
expectedVPNPeers: makeVPNPeers(peer1EndpUpd, peer3),
expectedStorePeers: []peer.Peer{peer1EndpUpd, peer3},
},
"don't add self": {
peers: []peer.Peer{peer1, peer3},
storePeers: []peer.Peer{peer1, peer2},
vpn: stubVPN{peers: makeVPNPeers(peer1, peer2), interfaceIP: peer3.VPNIP},
expectedVPNPeers: makeVPNPeers(peer1),
expectedStorePeers: []peer.Peer{peer1},
},
"public endpoint without port": {
peers: []peer.Peer{
peer1,
{
PublicEndpoint: "192.0.2.13",
VPNIP: "192.0.2.23",
VPNPubKey: []byte{3},
},
},
storePeers: []peer.Peer{peer1, peer2},
vpn: stubVPN{peers: makeVPNPeers(peer1, peer2)},
expectErr: true,
},
"vpn add peer error": {
peers: []peer.Peer{peer1, peer3},
storePeers: []peer.Peer{peer1, peer2},
vpn: stubVPN{peers: makeVPNPeers(peer1, peer2), addPeerErr: someErr},
expectErr: true,
},
"vpn remove peer error": {
peers: []peer.Peer{peer1, peer3},
storePeers: []peer.Peer{peer1, peer2},
vpn: stubVPN{peers: makeVPNPeers(peer1, peer2), removePeerErr: someErr},
expectErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
core, err := NewCore(&tc.vpn, nil, nil, nil, nil, zaptest.NewLogger(t), nil, nil)
require.NoError(err)
// prepare store
for _, p := range tc.storePeers {
require.NoError(core.data().PutPeer(p))
}
updateErr := core.UpdatePeers(tc.peers)
actualStorePeers, err := core.data().GetPeers()
require.NoError(err)
if tc.expectErr {
assert.Error(updateErr)
assert.ElementsMatch(tc.storePeers, actualStorePeers, "store has been changed despite failure")
return
}
require.NoError(updateErr)
assert.ElementsMatch(tc.expectedVPNPeers, tc.vpn.peers)
assert.ElementsMatch(tc.expectedStorePeers, actualStorePeers)
})
}
}

View File

@ -3,6 +3,9 @@ package core
import (
"bytes"
"errors"
"net"
"github.com/edgelesssys/constellation/coordinator/peer"
)
type VPN interface {
@ -12,6 +15,7 @@ type VPN interface {
SetInterfaceIP(ip string) error
AddPeer(pubKey []byte, publicIP string, vpnIP string) error
RemovePeer(pubKey []byte) error
UpdatePeers(peers []peer.Peer) error
}
type stubVPN struct {
@ -57,6 +61,19 @@ func (v *stubVPN) RemovePeer(pubKey []byte) error {
return v.removePeerErr
}
func (v *stubVPN) UpdatePeers(peers []peer.Peer) error {
for _, peer := range peers {
peerIP, _, err := net.SplitHostPort(peer.PublicEndpoint)
if err != nil {
return err
}
if err := v.AddPeer(peer.VPNPubKey, peerIP, peer.VPNIP); err != nil {
return err
}
}
return nil
}
type stubVPNPeer struct {
pubKey []byte
publicIP string

View File

@ -1,12 +1,15 @@
package wireguard
import (
"bytes"
"errors"
"fmt"
"io"
"net"
"os"
"time"
"github.com/edgelesssys/constellation/coordinator/peer"
"github.com/edgelesssys/constellation/coordinator/util"
"github.com/vishvananda/netlink"
"golang.zx2c4.com/wireguard/wgctrl"
@ -18,20 +21,22 @@ const (
port = 51820
)
type Wireguard struct{}
type Wireguard struct {
client wgClient
getInterfaceIP func(string) (string, error)
}
func New() *Wireguard {
return &Wireguard{}
func New() (*Wireguard, error) {
client, err := wgctrl.New()
if err != nil {
return nil, err
}
return &Wireguard{client: client, getInterfaceIP: util.GetInterfaceIP}, nil
}
func (w *Wireguard) Setup(privKey []byte) ([]byte, error) {
client, err := wgctrl.New()
if err != nil {
return nil, fmt.Errorf("failed to open wgctrl: %w", err)
}
defer client.Close()
var key wgtypes.Key
var err error
if len(privKey) == 0 {
key, err = wgtypes.GeneratePrivateKey()
} else {
@ -42,7 +47,7 @@ func (w *Wireguard) Setup(privKey []byte) ([]byte, error) {
}
listenPort := port
if err := client.ConfigureDevice(netInterface, wgtypes.Config{PrivateKey: &key, ListenPort: &listenPort}); err != nil {
if err := w.client.ConfigureDevice(netInterface, wgtypes.Config{PrivateKey: &key, ListenPort: &listenPort}); err != nil {
return nil, prettyWgError(err)
}
@ -59,7 +64,7 @@ func (w *Wireguard) GetPublicKey(privKey []byte) ([]byte, error) {
}
func (w *Wireguard) GetInterfaceIP() (string, error) {
return util.GetInterfaceIP(netInterface)
return w.getInterfaceIP(netInterface)
}
// SetInterfaceIP sets the ip interface ip.
@ -80,12 +85,6 @@ func (w *Wireguard) SetInterfaceIP(ip string) error {
// AddPeer adds a new peer to a wireguard interface.
func (w *Wireguard) AddPeer(pubKey []byte, publicIP string, vpnIP string) error {
client, err := wgctrl.New()
if err != nil {
return fmt.Errorf("failed to open wgctrl: %w", err)
}
defer client.Close()
_, allowedIPs, err := net.ParseCIDR(vpnIP + "/32")
if err != nil {
return err
@ -115,17 +114,11 @@ func (w *Wireguard) AddPeer(pubKey []byte, publicIP string, vpnIP string) error
},
}
return prettyWgError(client.ConfigureDevice(netInterface, cfg))
return prettyWgError(w.client.ConfigureDevice(netInterface, cfg))
}
// RemovePeer removes a peer from the wireguard interface.
func (w *Wireguard) RemovePeer(pubKey []byte) error {
client, err := wgctrl.New()
if err != nil {
return fmt.Errorf("failed to open wgctrl: %w", err)
}
defer client.Close()
key, err := wgtypes.NewKey(pubKey)
if err != nil {
return err
@ -133,7 +126,7 @@ func (w *Wireguard) RemovePeer(pubKey []byte) error {
cfg := wgtypes.Config{Peers: []wgtypes.PeerConfig{{PublicKey: key, Remove: true}}}
return prettyWgError(client.ConfigureDevice(netInterface, cfg))
return prettyWgError(w.client.ConfigureDevice(netInterface, cfg))
}
func prettyWgError(err error) error {
@ -142,3 +135,126 @@ func prettyWgError(err error) error {
}
return err
}
func (w *Wireguard) UpdatePeers(peers []peer.Peer) error {
ownVPNIP, err := w.getInterfaceIP(netInterface)
if err != nil {
return fmt.Errorf("failed to obtain vpn ip: %w", err)
}
wgPeers, err := transformToWgpeer(peers, ownVPNIP)
if err != nil {
return fmt.Errorf("failed to transform peers to wireguard-peers: %w", err)
}
deviceData, err := w.client.Device(netInterface)
if err != nil {
return fmt.Errorf("failed to obtain device data: %w", err)
}
// convert to map for easier lookup
storePeers := make(map[string]wgtypes.Peer)
for _, p := range wgPeers {
storePeers[p.AllowedIPs[0].String()] = p
}
var added []wgtypes.Peer
var removed []wgtypes.Peer
var updated []wgtypes.Peer
for _, interfacePeer := range deviceData.Peers {
if updPeer, ok := storePeers[interfacePeer.AllowedIPs[0].String()]; ok {
if updPeer.Endpoint.String() != interfacePeer.Endpoint.String() {
updated = append(updated, updPeer)
}
if !bytes.Equal(updPeer.PublicKey[:], interfacePeer.PublicKey[:]) {
added = append(added, updPeer)
removed = append(removed, interfacePeer)
}
delete(storePeers, updPeer.AllowedIPs[0].String())
} else {
removed = append(removed, interfacePeer)
}
}
// remaining store peers are new ones
for _, peer := range storePeers {
added = append(added, peer)
}
keepAlive := 10 * time.Second
var newPeerConfig []wgtypes.PeerConfig
for _, peer := range removed {
newPeerConfig = append(newPeerConfig, wgtypes.PeerConfig{
// pub Key for remove matching is enought
PublicKey: peer.PublicKey,
Remove: true,
})
}
for _, peer := range updated {
newPeerConfig = append(newPeerConfig, wgtypes.PeerConfig{
PublicKey: peer.PublicKey,
Remove: false,
UpdateOnly: true,
Endpoint: peer.Endpoint,
})
}
for _, peer := range added {
newPeerConfig = append(newPeerConfig, wgtypes.PeerConfig{
PublicKey: peer.PublicKey,
Remove: false,
UpdateOnly: false,
Endpoint: peer.Endpoint,
AllowedIPs: peer.AllowedIPs,
// needed, otherwise gRPC has problems establishing the initial connection.
PersistentKeepaliveInterval: &keepAlive,
})
}
if len(newPeerConfig) == 0 {
return nil
}
cfg := wgtypes.Config{
ReplacePeers: false,
Peers: newPeerConfig,
}
return prettyWgError(w.client.ConfigureDevice(netInterface, cfg))
}
func (w *Wireguard) Close() error {
return w.client.Close()
}
// A wgClient is a type which can control a WireGuard device.
type wgClient interface {
io.Closer
Device(name string) (*wgtypes.Device, error)
ConfigureDevice(name string, cfg wgtypes.Config) error
}
func transformToWgpeer(corePeers []peer.Peer, excludedIP string) ([]wgtypes.Peer, error) {
var wgPeers []wgtypes.Peer
for _, peer := range corePeers {
if peer.VPNIP == excludedIP {
continue
}
key, err := wgtypes.NewKey(peer.VPNPubKey)
if err != nil {
return nil, err
}
_, allowedIPs, err := net.ParseCIDR(peer.VPNIP + "/32")
if err != nil {
return nil, err
}
publicIP, _, err := net.SplitHostPort(peer.PublicEndpoint)
if err != nil {
return nil, err
}
var endpoint *net.UDPAddr
if ip := net.ParseIP(publicIP); ip != nil {
endpoint = &net.UDPAddr{IP: ip, Port: port}
}
wgPeers = append(wgPeers, wgtypes.Peer{
PublicKey: key,
Endpoint: endpoint,
AllowedIPs: []net.IPNet{*allowedIPs},
})
}
return wgPeers, nil
}

View File

@ -0,0 +1,158 @@
package wireguard
import (
"errors"
"testing"
"github.com/edgelesssys/constellation/coordinator/peer"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
func TestUpdatePeer(t *testing.T) {
requirePre := require.New(t)
firstKey, err := wgtypes.GenerateKey()
requirePre.NoError(err)
peer1 := peer.Peer{PublicEndpoint: "192.0.2.11:2000", VPNIP: "192.0.2.21", VPNPubKey: firstKey[:]}
firstKeyUpd, err := wgtypes.GenerateKey()
requirePre.NoError(err)
peer1KeyUpd := peer.Peer{PublicEndpoint: "192.0.2.11:2000", VPNIP: "192.0.2.21", VPNPubKey: firstKeyUpd[:]}
peer1EndpUpd := peer.Peer{PublicEndpoint: "192.0.2.110:2000", VPNIP: "192.0.2.21", VPNPubKey: firstKey[:]}
secondKey, err := wgtypes.GenerateKey()
requirePre.NoError(err)
peer2 := peer.Peer{PublicEndpoint: "192.0.2.12:2000", VPNIP: "192.0.2.22", VPNPubKey: secondKey[:]}
thirdKey, err := wgtypes.GenerateKey()
requirePre.NoError(err)
peer3 := peer.Peer{PublicEndpoint: "192.0.2.13:2000", VPNIP: "192.0.2.23", VPNPubKey: thirdKey[:]}
fourthKey, err := wgtypes.GenerateKey()
requirePre.NoError(err)
peerSelf := peer.Peer{PublicEndpoint: "192.0.2.10:2000", VPNIP: "192.0.2.20", VPNPubKey: fourthKey[:]}
checkError := func(peers []wgtypes.Peer, err error) []wgtypes.Peer {
requirePre.NoError(err)
return peers
}
testCases := map[string]struct {
storePeers []peer.Peer
vpnPeers []wgtypes.Peer
expectErr bool
expectedVPNPeers []wgtypes.Peer
}{
"basic": {
storePeers: []peer.Peer{peer1, peer3},
vpnPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer2}, "")),
expectedVPNPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer3}, "")),
},
"previously empty": {
storePeers: []peer.Peer{peer1, peer2},
expectedVPNPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer2}, "")),
},
"no changes": {
storePeers: []peer.Peer{peer1, peer2},
vpnPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer2}, "")),
expectedVPNPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer2}, "")),
},
"key update": {
storePeers: []peer.Peer{peer1KeyUpd, peer3},
vpnPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer2}, "")),
expectedVPNPeers: checkError(transformToWgpeer([]peer.Peer{peer1KeyUpd, peer3}, "")),
},
"public endpoint update": {
storePeers: []peer.Peer{peer1EndpUpd, peer3},
vpnPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer2}, "")),
expectedVPNPeers: checkError(transformToWgpeer([]peer.Peer{peer1EndpUpd, peer3}, "")),
},
"dont add self": {
storePeers: []peer.Peer{peerSelf, peer3},
vpnPeers: checkError(transformToWgpeer([]peer.Peer{peer2, peer3}, "")),
expectedVPNPeers: checkError(transformToWgpeer([]peer.Peer{peer3}, "")),
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fakewg := fakewgClient{}
fakewg.devices = make(map[string]*wgtypes.Device)
wg := Wireguard{client: &fakewg, getInterfaceIP: func(s string) (string, error) {
return "192.0.2.20", nil
}}
fakewg.devices[netInterface] = &wgtypes.Device{Peers: tc.vpnPeers}
updateErr := wg.UpdatePeers(tc.storePeers)
if tc.expectErr {
assert.Error(updateErr)
return
}
require.NoError(updateErr)
assert.ElementsMatch(tc.expectedVPNPeers, fakewg.devices[netInterface].Peers)
})
}
}
type fakewgClient struct {
devices map[string]*wgtypes.Device
}
func (w *fakewgClient) Device(name string) (*wgtypes.Device, error) {
if val, ok := w.devices[name]; ok {
return val, nil
}
return nil, errors.New("device does not exist")
}
func (w *fakewgClient) ConfigureDevice(name string, cfg wgtypes.Config) error {
var newPeerList []wgtypes.Peer
var operation bool
vpnPeers := make(map[wgtypes.Key]wgtypes.Peer)
for _, peer := range w.devices[netInterface].Peers {
vpnPeers[peer.PublicKey] = peer
}
for _, configPeer := range cfg.Peers {
operation = false
for _, vpnPeer := range w.devices[netInterface].Peers {
// wireguard matches internally via pubkey
if vpnPeer.PublicKey == configPeer.PublicKey {
operation = true
if configPeer.Remove {
delete(vpnPeers, vpnPeer.PublicKey)
continue
}
if configPeer.UpdateOnly {
vpnPeers[vpnPeer.PublicKey] = wgtypes.Peer{
PublicKey: vpnPeer.PublicKey,
AllowedIPs: vpnPeer.AllowedIPs,
Endpoint: configPeer.Endpoint,
}
}
}
}
if !operation {
vpnPeers[configPeer.PublicKey] = wgtypes.Peer{
PublicKey: configPeer.PublicKey,
AllowedIPs: configPeer.AllowedIPs,
Endpoint: configPeer.Endpoint,
}
}
}
for _, peer := range vpnPeers {
newPeerList = append(newPeerList, peer)
}
w.devices[netInterface].Peers = newPeerList
return nil
}
func (w *fakewgClient) Close() error {
return nil
}