diff --git a/coordinator/cmd/coordinator/main.go b/coordinator/cmd/coordinator/main.go index e7957e042..fb781a564 100644 --- a/coordinator/cmd/coordinator/main.go +++ b/coordinator/cmd/coordinator/main.go @@ -57,7 +57,11 @@ func main() { } zapLoggerCore := zapLogger.Named("core") - wg := wireguard.New() + wg, err := wireguard.New() + if err != nil { + zapLogger.Panic("error opening wgctrl client") + } + defer wg.Close() var issuer core.QuoteIssuer var validator core.QuoteValidator diff --git a/coordinator/coordinator_test.go b/coordinator/coordinator_test.go index 22c5b72aa..48a4176a1 100644 --- a/coordinator/coordinator_test.go +++ b/coordinator/coordinator_test.go @@ -12,6 +12,7 @@ import ( "github.com/edgelesssys/constellation/coordinator/attestation/vtpm" "github.com/edgelesssys/constellation/coordinator/core" "github.com/edgelesssys/constellation/coordinator/kms" + "github.com/edgelesssys/constellation/coordinator/peer" "github.com/edgelesssys/constellation/coordinator/pubapi" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/store" @@ -339,6 +340,19 @@ func (v *fakeVPN) RemovePeer(pubKey []byte) error { panic("dummy") } +func (v *fakeVPN) UpdatePeers(peers []peer.Peer) error { + for _, peer := range peers { + peerIP, _, err := net.SplitHostPort(peer.PublicEndpoint) + if err != nil { + return err + } + if err := v.AddPeer(peer.VPNPubKey, peerIP, peer.VPNIP); err != nil { + return err + } + } + return nil +} + func (v *fakeVPN) send(dst string, data string) { pubdst := v.peers[dst] packets := v.netw.packets diff --git a/coordinator/core/peer.go b/coordinator/core/peer.go index 62ef11560..695cee787 100644 --- a/coordinator/core/peer.go +++ b/coordinator/core/peer.go @@ -5,7 +5,6 @@ import ( "github.com/edgelesssys/constellation/coordinator/peer" "github.com/edgelesssys/constellation/coordinator/storewrapper" - "go.uber.org/multierr" "go.uber.org/zap" ) @@ -78,45 +77,5 @@ func (c *Core) AddPeer(peer peer.Peer) error { // UpdatePeers synchronizes the peers known to the store and the vpn with the passed peers. func (c *Core) UpdatePeers(peers []peer.Peer) error { - // exclude myself - myIP, err := c.vpn.GetInterfaceIP() - if err != nil { - return err - } - for i, p := range peers { - if p.VPNIP == myIP { - peers = append(peers[:i], peers[i+1:]...) - break - } - } - - tx, err := c.store.BeginTransaction() - if err != nil { - return err - } - defer tx.Rollback() - - added, removed, err := storewrapper.StoreWrapper{Store: tx}.UpdatePeers(peers) - if err != nil { - return err - } - - // perform remove and add on the vpn - var vpnErr error - for _, p := range removed { - vpnErr = multierr.Append(vpnErr, c.vpn.RemovePeer(p.VPNPubKey)) - } - for _, p := range added { - publicIP, _, err := net.SplitHostPort(p.PublicEndpoint) - if err != nil { - vpnErr = multierr.Append(vpnErr, err) - continue - } - vpnErr = multierr.Append(vpnErr, c.vpn.AddPeer(p.VPNPubKey, publicIP, p.VPNIP)) - } - if vpnErr != nil { - return vpnErr - } - - return tx.Commit() + return c.vpn.UpdatePeers(peers) } diff --git a/coordinator/core/peer_test.go b/coordinator/core/peer_test.go index 304f76ce5..a836f71c6 100644 --- a/coordinator/core/peer_test.go +++ b/coordinator/core/peer_test.go @@ -2,7 +2,6 @@ package core import ( "errors" - "net" "testing" "github.com/edgelesssys/constellation/coordinator/peer" @@ -139,129 +138,3 @@ func TestAddPeer(t *testing.T) { }) } } - -func TestUpdatePeer(t *testing.T) { - someErr := errors.New("failed") - peer1 := peer.Peer{PublicEndpoint: "192.0.2.11:2000", VPNIP: "192.0.2.21", VPNPubKey: []byte{1}} - peer1KeyUpd := peer.Peer{PublicEndpoint: "192.0.2.11:2000", VPNIP: "192.0.2.21", VPNPubKey: []byte{1, 1}} - peer1EndpUpd := peer.Peer{PublicEndpoint: "192.0.2.110:2000", VPNIP: "192.0.2.21", VPNPubKey: []byte{1}} - peer2 := peer.Peer{PublicEndpoint: "192.0.2.12:2000", VPNIP: "192.0.2.22", VPNPubKey: []byte{2}} - peer3 := peer.Peer{PublicEndpoint: "192.0.2.13:2000", VPNIP: "192.0.2.23", VPNPubKey: []byte{3}} - - makeVPNPeers := func(peers ...peer.Peer) []stubVPNPeer { - var result []stubVPNPeer - for _, p := range peers { - publicIP, _, err := net.SplitHostPort(p.PublicEndpoint) - if err != nil { - panic(err) - } - result = append(result, stubVPNPeer{pubKey: p.VPNPubKey, publicIP: publicIP, vpnIP: p.VPNIP}) - } - return result - } - - testCases := map[string]struct { - peers []peer.Peer - storePeers []peer.Peer - vpn stubVPN - expectErr bool - expectedVPNPeers []stubVPNPeer - expectedStorePeers []peer.Peer - }{ - "basic": { - peers: []peer.Peer{peer1, peer3}, - storePeers: []peer.Peer{peer1, peer2}, - vpn: stubVPN{peers: makeVPNPeers(peer1, peer2)}, - expectedVPNPeers: makeVPNPeers(peer1, peer3), - expectedStorePeers: []peer.Peer{peer1, peer3}, - }, - "previously empty": { - peers: []peer.Peer{peer1, peer2}, - expectedVPNPeers: makeVPNPeers(peer1, peer2), - expectedStorePeers: []peer.Peer{peer1, peer2}, - }, - "no changes": { - peers: []peer.Peer{peer1, peer2}, - storePeers: []peer.Peer{peer1, peer2}, - vpn: stubVPN{peers: makeVPNPeers(peer1, peer2)}, - expectedVPNPeers: makeVPNPeers(peer1, peer2), - expectedStorePeers: []peer.Peer{peer1, peer2}, - }, - "key update": { - peers: []peer.Peer{peer1KeyUpd, peer3}, - storePeers: []peer.Peer{peer1, peer2}, - vpn: stubVPN{peers: makeVPNPeers(peer1, peer2)}, - expectedVPNPeers: makeVPNPeers(peer1KeyUpd, peer3), - expectedStorePeers: []peer.Peer{peer1KeyUpd, peer3}, - }, - "public endpoint update": { - peers: []peer.Peer{peer1EndpUpd, peer3}, - storePeers: []peer.Peer{peer1, peer2}, - vpn: stubVPN{peers: makeVPNPeers(peer1, peer2)}, - expectedVPNPeers: makeVPNPeers(peer1EndpUpd, peer3), - expectedStorePeers: []peer.Peer{peer1EndpUpd, peer3}, - }, - "don't add self": { - peers: []peer.Peer{peer1, peer3}, - storePeers: []peer.Peer{peer1, peer2}, - vpn: stubVPN{peers: makeVPNPeers(peer1, peer2), interfaceIP: peer3.VPNIP}, - expectedVPNPeers: makeVPNPeers(peer1), - expectedStorePeers: []peer.Peer{peer1}, - }, - "public endpoint without port": { - peers: []peer.Peer{ - peer1, - { - PublicEndpoint: "192.0.2.13", - VPNIP: "192.0.2.23", - VPNPubKey: []byte{3}, - }, - }, - storePeers: []peer.Peer{peer1, peer2}, - vpn: stubVPN{peers: makeVPNPeers(peer1, peer2)}, - expectErr: true, - }, - "vpn add peer error": { - peers: []peer.Peer{peer1, peer3}, - storePeers: []peer.Peer{peer1, peer2}, - vpn: stubVPN{peers: makeVPNPeers(peer1, peer2), addPeerErr: someErr}, - expectErr: true, - }, - "vpn remove peer error": { - peers: []peer.Peer{peer1, peer3}, - storePeers: []peer.Peer{peer1, peer2}, - vpn: stubVPN{peers: makeVPNPeers(peer1, peer2), removePeerErr: someErr}, - expectErr: true, - }, - } - - for name, tc := range testCases { - t.Run(name, func(t *testing.T) { - assert := assert.New(t) - require := require.New(t) - - core, err := NewCore(&tc.vpn, nil, nil, nil, nil, zaptest.NewLogger(t), nil, nil) - require.NoError(err) - - // prepare store - for _, p := range tc.storePeers { - require.NoError(core.data().PutPeer(p)) - } - - updateErr := core.UpdatePeers(tc.peers) - - actualStorePeers, err := core.data().GetPeers() - require.NoError(err) - - if tc.expectErr { - assert.Error(updateErr) - assert.ElementsMatch(tc.storePeers, actualStorePeers, "store has been changed despite failure") - return - } - require.NoError(updateErr) - - assert.ElementsMatch(tc.expectedVPNPeers, tc.vpn.peers) - assert.ElementsMatch(tc.expectedStorePeers, actualStorePeers) - }) - } -} diff --git a/coordinator/core/vpn.go b/coordinator/core/vpn.go index 54ab83f9f..50ada6b57 100644 --- a/coordinator/core/vpn.go +++ b/coordinator/core/vpn.go @@ -3,6 +3,9 @@ package core import ( "bytes" "errors" + "net" + + "github.com/edgelesssys/constellation/coordinator/peer" ) type VPN interface { @@ -12,6 +15,7 @@ type VPN interface { SetInterfaceIP(ip string) error AddPeer(pubKey []byte, publicIP string, vpnIP string) error RemovePeer(pubKey []byte) error + UpdatePeers(peers []peer.Peer) error } type stubVPN struct { @@ -57,6 +61,19 @@ func (v *stubVPN) RemovePeer(pubKey []byte) error { return v.removePeerErr } +func (v *stubVPN) UpdatePeers(peers []peer.Peer) error { + for _, peer := range peers { + peerIP, _, err := net.SplitHostPort(peer.PublicEndpoint) + if err != nil { + return err + } + if err := v.AddPeer(peer.VPNPubKey, peerIP, peer.VPNIP); err != nil { + return err + } + } + return nil +} + type stubVPNPeer struct { pubKey []byte publicIP string diff --git a/coordinator/wireguard/wireguard.go b/coordinator/wireguard/wireguard.go index a8fd604b9..1b85b3759 100644 --- a/coordinator/wireguard/wireguard.go +++ b/coordinator/wireguard/wireguard.go @@ -1,12 +1,15 @@ package wireguard import ( + "bytes" "errors" "fmt" + "io" "net" "os" "time" + "github.com/edgelesssys/constellation/coordinator/peer" "github.com/edgelesssys/constellation/coordinator/util" "github.com/vishvananda/netlink" "golang.zx2c4.com/wireguard/wgctrl" @@ -18,20 +21,22 @@ const ( port = 51820 ) -type Wireguard struct{} +type Wireguard struct { + client wgClient + getInterfaceIP func(string) (string, error) +} -func New() *Wireguard { - return &Wireguard{} +func New() (*Wireguard, error) { + client, err := wgctrl.New() + if err != nil { + return nil, err + } + return &Wireguard{client: client, getInterfaceIP: util.GetInterfaceIP}, nil } func (w *Wireguard) Setup(privKey []byte) ([]byte, error) { - client, err := wgctrl.New() - if err != nil { - return nil, fmt.Errorf("failed to open wgctrl: %w", err) - } - defer client.Close() - var key wgtypes.Key + var err error if len(privKey) == 0 { key, err = wgtypes.GeneratePrivateKey() } else { @@ -42,7 +47,7 @@ func (w *Wireguard) Setup(privKey []byte) ([]byte, error) { } listenPort := port - if err := client.ConfigureDevice(netInterface, wgtypes.Config{PrivateKey: &key, ListenPort: &listenPort}); err != nil { + if err := w.client.ConfigureDevice(netInterface, wgtypes.Config{PrivateKey: &key, ListenPort: &listenPort}); err != nil { return nil, prettyWgError(err) } @@ -59,7 +64,7 @@ func (w *Wireguard) GetPublicKey(privKey []byte) ([]byte, error) { } func (w *Wireguard) GetInterfaceIP() (string, error) { - return util.GetInterfaceIP(netInterface) + return w.getInterfaceIP(netInterface) } // SetInterfaceIP sets the ip interface ip. @@ -80,12 +85,6 @@ func (w *Wireguard) SetInterfaceIP(ip string) error { // AddPeer adds a new peer to a wireguard interface. func (w *Wireguard) AddPeer(pubKey []byte, publicIP string, vpnIP string) error { - client, err := wgctrl.New() - if err != nil { - return fmt.Errorf("failed to open wgctrl: %w", err) - } - defer client.Close() - _, allowedIPs, err := net.ParseCIDR(vpnIP + "/32") if err != nil { return err @@ -115,17 +114,11 @@ func (w *Wireguard) AddPeer(pubKey []byte, publicIP string, vpnIP string) error }, } - return prettyWgError(client.ConfigureDevice(netInterface, cfg)) + return prettyWgError(w.client.ConfigureDevice(netInterface, cfg)) } // RemovePeer removes a peer from the wireguard interface. func (w *Wireguard) RemovePeer(pubKey []byte) error { - client, err := wgctrl.New() - if err != nil { - return fmt.Errorf("failed to open wgctrl: %w", err) - } - defer client.Close() - key, err := wgtypes.NewKey(pubKey) if err != nil { return err @@ -133,7 +126,7 @@ func (w *Wireguard) RemovePeer(pubKey []byte) error { cfg := wgtypes.Config{Peers: []wgtypes.PeerConfig{{PublicKey: key, Remove: true}}} - return prettyWgError(client.ConfigureDevice(netInterface, cfg)) + return prettyWgError(w.client.ConfigureDevice(netInterface, cfg)) } func prettyWgError(err error) error { @@ -142,3 +135,126 @@ func prettyWgError(err error) error { } return err } + +func (w *Wireguard) UpdatePeers(peers []peer.Peer) error { + ownVPNIP, err := w.getInterfaceIP(netInterface) + if err != nil { + return fmt.Errorf("failed to obtain vpn ip: %w", err) + } + wgPeers, err := transformToWgpeer(peers, ownVPNIP) + if err != nil { + return fmt.Errorf("failed to transform peers to wireguard-peers: %w", err) + } + + deviceData, err := w.client.Device(netInterface) + if err != nil { + return fmt.Errorf("failed to obtain device data: %w", err) + } + // convert to map for easier lookup + storePeers := make(map[string]wgtypes.Peer) + for _, p := range wgPeers { + storePeers[p.AllowedIPs[0].String()] = p + } + var added []wgtypes.Peer + var removed []wgtypes.Peer + var updated []wgtypes.Peer + + for _, interfacePeer := range deviceData.Peers { + if updPeer, ok := storePeers[interfacePeer.AllowedIPs[0].String()]; ok { + if updPeer.Endpoint.String() != interfacePeer.Endpoint.String() { + updated = append(updated, updPeer) + } + if !bytes.Equal(updPeer.PublicKey[:], interfacePeer.PublicKey[:]) { + added = append(added, updPeer) + removed = append(removed, interfacePeer) + } + delete(storePeers, updPeer.AllowedIPs[0].String()) + } else { + removed = append(removed, interfacePeer) + } + } + // remaining store peers are new ones + for _, peer := range storePeers { + added = append(added, peer) + } + + keepAlive := 10 * time.Second + var newPeerConfig []wgtypes.PeerConfig + for _, peer := range removed { + newPeerConfig = append(newPeerConfig, wgtypes.PeerConfig{ + // pub Key for remove matching is enought + PublicKey: peer.PublicKey, + Remove: true, + }) + } + for _, peer := range updated { + newPeerConfig = append(newPeerConfig, wgtypes.PeerConfig{ + PublicKey: peer.PublicKey, + Remove: false, + UpdateOnly: true, + Endpoint: peer.Endpoint, + }) + } + for _, peer := range added { + newPeerConfig = append(newPeerConfig, wgtypes.PeerConfig{ + PublicKey: peer.PublicKey, + Remove: false, + UpdateOnly: false, + Endpoint: peer.Endpoint, + AllowedIPs: peer.AllowedIPs, + // needed, otherwise gRPC has problems establishing the initial connection. + PersistentKeepaliveInterval: &keepAlive, + }) + } + if len(newPeerConfig) == 0 { + return nil + } + cfg := wgtypes.Config{ + ReplacePeers: false, + Peers: newPeerConfig, + } + return prettyWgError(w.client.ConfigureDevice(netInterface, cfg)) +} + +func (w *Wireguard) Close() error { + return w.client.Close() +} + +// A wgClient is a type which can control a WireGuard device. +type wgClient interface { + io.Closer + Device(name string) (*wgtypes.Device, error) + ConfigureDevice(name string, cfg wgtypes.Config) error +} + +func transformToWgpeer(corePeers []peer.Peer, excludedIP string) ([]wgtypes.Peer, error) { + var wgPeers []wgtypes.Peer + for _, peer := range corePeers { + if peer.VPNIP == excludedIP { + continue + } + key, err := wgtypes.NewKey(peer.VPNPubKey) + if err != nil { + return nil, err + } + _, allowedIPs, err := net.ParseCIDR(peer.VPNIP + "/32") + if err != nil { + return nil, err + } + + publicIP, _, err := net.SplitHostPort(peer.PublicEndpoint) + if err != nil { + return nil, err + } + var endpoint *net.UDPAddr + if ip := net.ParseIP(publicIP); ip != nil { + endpoint = &net.UDPAddr{IP: ip, Port: port} + } + wgPeers = append(wgPeers, wgtypes.Peer{ + PublicKey: key, + Endpoint: endpoint, + AllowedIPs: []net.IPNet{*allowedIPs}, + }) + } + return wgPeers, nil +} diff --git a/coordinator/wireguard/wireguard_test.go b/coordinator/wireguard/wireguard_test.go new file mode 100644 index 000000000..a9fe3c953 --- /dev/null +++ b/coordinator/wireguard/wireguard_test.go @@ -0,0 +1,158 @@ +package wireguard + +import ( + "errors" + "testing" + + "github.com/edgelesssys/constellation/coordinator/peer" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +func TestUpdatePeer(t *testing.T) { + requirePre := require.New(t) + + firstKey, err := wgtypes.GenerateKey() + requirePre.NoError(err) + peer1 := peer.Peer{PublicEndpoint: "192.0.2.11:2000", VPNIP: "192.0.2.21", VPNPubKey: firstKey[:]} + firstKeyUpd, err := wgtypes.GenerateKey() + requirePre.NoError(err) + peer1KeyUpd := peer.Peer{PublicEndpoint: "192.0.2.11:2000", VPNIP: "192.0.2.21", VPNPubKey: firstKeyUpd[:]} + peer1EndpUpd := peer.Peer{PublicEndpoint: "192.0.2.110:2000", VPNIP: "192.0.2.21", VPNPubKey: firstKey[:]} + secondKey, err := wgtypes.GenerateKey() + requirePre.NoError(err) + peer2 := peer.Peer{PublicEndpoint: "192.0.2.12:2000", VPNIP: "192.0.2.22", VPNPubKey: secondKey[:]} + thirdKey, err := wgtypes.GenerateKey() + requirePre.NoError(err) + peer3 := peer.Peer{PublicEndpoint: "192.0.2.13:2000", VPNIP: "192.0.2.23", VPNPubKey: thirdKey[:]} + fourthKey, err := wgtypes.GenerateKey() + requirePre.NoError(err) + peerSelf := peer.Peer{PublicEndpoint: "192.0.2.10:2000", VPNIP: "192.0.2.20", VPNPubKey: fourthKey[:]} + + checkError := func(peers []wgtypes.Peer, err error) []wgtypes.Peer { + requirePre.NoError(err) + return peers + } + + testCases := map[string]struct { + storePeers []peer.Peer + vpnPeers []wgtypes.Peer + expectErr bool + expectedVPNPeers []wgtypes.Peer + }{ + "basic": { + storePeers: []peer.Peer{peer1, peer3}, + vpnPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer2}, "")), + expectedVPNPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer3}, "")), + }, + "previously empty": { + storePeers: []peer.Peer{peer1, peer2}, + expectedVPNPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer2}, "")), + }, + "no changes": { + storePeers: []peer.Peer{peer1, peer2}, + vpnPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer2}, "")), + expectedVPNPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer2}, "")), + }, + "key update": { + storePeers: []peer.Peer{peer1KeyUpd, peer3}, + vpnPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer2}, "")), + expectedVPNPeers: checkError(transformToWgpeer([]peer.Peer{peer1KeyUpd, peer3}, "")), + }, + "public endpoint update": { + storePeers: []peer.Peer{peer1EndpUpd, peer3}, + vpnPeers: checkError(transformToWgpeer([]peer.Peer{peer1, peer2}, "")), + expectedVPNPeers: checkError(transformToWgpeer([]peer.Peer{peer1EndpUpd, peer3}, "")), + }, + "dont add self": { + storePeers: []peer.Peer{peerSelf, peer3}, + vpnPeers: checkError(transformToWgpeer([]peer.Peer{peer2, peer3}, "")), + expectedVPNPeers: checkError(transformToWgpeer([]peer.Peer{peer3}, "")), + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + fakewg := fakewgClient{} + fakewg.devices = make(map[string]*wgtypes.Device) + wg := Wireguard{client: &fakewg, getInterfaceIP: func(s string) (string, error) { + return "192.0.2.20", nil + }} + + fakewg.devices[netInterface] = &wgtypes.Device{Peers: tc.vpnPeers} + + updateErr := wg.UpdatePeers(tc.storePeers) + + if tc.expectErr { + assert.Error(updateErr) + return + } + require.NoError(updateErr) + + assert.ElementsMatch(tc.expectedVPNPeers, fakewg.devices[netInterface].Peers) + }) + } +} + +type fakewgClient struct { + devices map[string]*wgtypes.Device +} + +func (w *fakewgClient) Device(name string) (*wgtypes.Device, error) { + if val, ok := w.devices[name]; ok { + return val, nil + } + return nil, errors.New("device does not exist") +} + +func (w *fakewgClient) ConfigureDevice(name string, cfg wgtypes.Config) error { + var newPeerList []wgtypes.Peer + var operation bool + vpnPeers := make(map[wgtypes.Key]wgtypes.Peer) + + for _, peer := range w.devices[netInterface].Peers { + vpnPeers[peer.PublicKey] = peer + } + + for _, configPeer := range cfg.Peers { + operation = false + for _, vpnPeer := range w.devices[netInterface].Peers { + // wireguard matches internally via pubkey + if vpnPeer.PublicKey == configPeer.PublicKey { + operation = true + if configPeer.Remove { + delete(vpnPeers, vpnPeer.PublicKey) + continue + } + if configPeer.UpdateOnly { + vpnPeers[vpnPeer.PublicKey] = wgtypes.Peer{ + PublicKey: vpnPeer.PublicKey, + AllowedIPs: vpnPeer.AllowedIPs, + Endpoint: configPeer.Endpoint, + } + } + } + } + if !operation { + vpnPeers[configPeer.PublicKey] = wgtypes.Peer{ + PublicKey: configPeer.PublicKey, + AllowedIPs: configPeer.AllowedIPs, + Endpoint: configPeer.Endpoint, + } + } + } + for _, peer := range vpnPeers { + newPeerList = append(newPeerList, peer) + } + w.devices[netInterface].Peers = newPeerList + + return nil +} + +func (w *fakewgClient) Close() error { + return nil +}