diff --git a/cli/cmd/init.go b/cli/cmd/init.go index f586757e4..211cb7da2 100644 --- a/cli/cmd/init.go +++ b/cli/cmd/init.go @@ -21,6 +21,7 @@ import ( "github.com/edgelesssys/constellation/internal/config" "github.com/edgelesssys/constellation/internal/state" "github.com/kr/text" + wgquick "github.com/nmiculinic/wg-quick-go" "github.com/spf13/afero" "github.com/spf13/cobra" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -45,6 +46,7 @@ func newInitCmd() *cobra.Command { // runInitialize runs the initialize command. func runInitialize(cmd *cobra.Command, args []string) error { fileHandler := file.NewHandler(afero.NewOsFs()) + vpnHandler := vpn.NewConfigHandler() devConfigName, err := cmd.Flags().GetString("dev-config") if err != nil { return err @@ -56,20 +58,19 @@ func runInitialize(cmd *cobra.Command, args []string) error { protoClient := proto.NewClient(*config.Provider.GCP.PCRs) defer protoClient.Close() - vpnClient, err := vpn.NewConfigurerWithDefaults() if err != nil { return err } // We have to parse the context separately, since cmd.Context() // returns nil during the tests otherwise. - return initialize(cmd.Context(), cmd, protoClient, vpnClient, serviceAccountClient{}, fileHandler, config, status.NewWaiter(*config.Provider.GCP.PCRs)) + return initialize(cmd.Context(), cmd, protoClient, serviceAccountClient{}, fileHandler, config, status.NewWaiter(*config.Provider.GCP.PCRs), vpnHandler) } // initialize initializes a Constellation. Coordinator instances are activated as Coordinators and will // themself activate the other peers as nodes. -func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, vpnCl vpnConfigurer, serviceAccountCr serviceAccountCreator, - fileHandler file.Handler, config *config.Config, waiter statusWaiter, +func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, serviceAccountCr serviceAccountCreator, + fileHandler file.Handler, config *config.Config, waiter statusWaiter, vpnHandler vpnHandler, ) error { flagArgs, err := evalFlagArgs(cmd, fileHandler, config) if err != nil { @@ -138,12 +139,17 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, vpn return err } - if err := result.writeWGQuickFile(fileHandler, config, string(flagArgs.userPrivKey)); err != nil { + vpnConfig, err := vpnHandler.Create(result.coordinatorPubKey, result.coordinatorPubIP, string(flagArgs.userPrivKey), result.clientVpnIP, wireguardAdminMTU) + if err != nil { + return err + } + + if err := writeWGQuickFile(fileHandler, config, vpnHandler, vpnConfig); err != nil { return fmt.Errorf("write wg-quick file: %w", err) } if flagArgs.autoconfigureWG { - if err := configureVpn(vpnCl, result.clientVpnIP, result.coordinatorPubKey, result.coordinatorPubIP, flagArgs.userPrivKey); err != nil { + if err := vpnHandler.Apply(vpnConfig); err != nil { return err } } @@ -217,14 +223,10 @@ type activationResult struct { } // writeWGQuickFile writes the wg-quick file to the default path. -func (r activationResult) writeWGQuickFile(fileHandler file.Handler, config *config.Config, clientPrivKey string) error { - wgConf, err := vpn.NewConfig(r.coordinatorPubKey, r.coordinatorPubIP, clientPrivKey) +func writeWGQuickFile(fileHandler file.Handler, config *config.Config, vpnHandler vpnHandler, vpnConfig *wgquick.Config) error { + data, err := vpnHandler.Marshal(vpnConfig) if err != nil { - return fmt.Errorf("create wg config: %w", err) - } - data, err := vpn.NewWGQuickConfig(wgConf, r.clientVpnIP, wireguardAdminMTU) - if err != nil { - return fmt.Errorf("create wg-quick config: %w", err) + return err } return fileHandler.Write(*config.WGQuickConfigPath, data, false) } @@ -327,14 +329,6 @@ func readOrGenerateVPNKey(fileHandler file.Handler, privKeyPath string) (privKey return privKey, pubKey, nil } -func configureVpn(vpnCl vpnConfigurer, clientVpnIp, coordinatorPubKey, coordinatorPublicIp string, privKey []byte) error { - err := vpnCl.Configure(clientVpnIp, coordinatorPubKey, coordinatorPublicIp, string(privKey)) - if err != nil { - return fmt.Errorf("could not configure WireGuard automatically: %w", err) - } - return nil -} - func ipsToEndpoints(ips []string, port string) []string { var endpoints []string for _, ip := range ips { diff --git a/cli/cmd/init_test.go b/cli/cmd/init_test.go index 782d51d48..3394e375e 100644 --- a/cli/cmd/init_test.go +++ b/cli/cmd/init_test.go @@ -5,7 +5,6 @@ import ( "context" "encoding/base64" "errors" - "fmt" "strconv" "strings" "testing" @@ -17,11 +16,11 @@ import ( "github.com/edgelesssys/constellation/cli/gcp" "github.com/edgelesssys/constellation/internal/config" "github.com/edgelesssys/constellation/internal/state" + wgquick "github.com/nmiculinic/wg-quick-go" "github.com/spf13/afero" "github.com/spf13/cobra" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) func TestInitArgumentValidation(t *testing.T) { @@ -112,6 +111,8 @@ func TestInitialize(t *testing.T) { serviceAccountCreator stubServiceAccountCreator waiter statusWaiter privKey string + vpnHandler vpnHandler + initVPN bool errExpected bool }{ "initialize some ec2 instances": { @@ -119,30 +120,77 @@ func TestInitialize(t *testing.T) { client: &fakeProtoClient{ respClient: &fakeActivationRespClient{responses: testActivationResps}, }, - waiter: stubStatusWaiter{}, - privKey: testKey, + waiter: stubStatusWaiter{}, + vpnHandler: &stubVPNHandler{}, + privKey: testKey, }, "initialize some gcp instances": { existingState: testGcpState, client: &fakeProtoClient{ respClient: &fakeActivationRespClient{responses: testActivationResps}, }, - waiter: stubStatusWaiter{}, - privKey: testKey, + waiter: stubStatusWaiter{}, + vpnHandler: &stubVPNHandler{}, + privKey: testKey, }, "initialize some azure instances": { existingState: testAzureState, client: &fakeProtoClient{ respClient: &fakeActivationRespClient{responses: testActivationResps}, }, - waiter: stubStatusWaiter{}, - privKey: testKey, + waiter: stubStatusWaiter{}, + vpnHandler: &stubVPNHandler{}, + privKey: testKey, + }, + "initialize vpn": { + existingState: testAzureState, + client: &fakeProtoClient{ + respClient: &fakeActivationRespClient{responses: testActivationResps}, + }, + waiter: stubStatusWaiter{}, + vpnHandler: &stubVPNHandler{}, + initVPN: true, + privKey: testKey, + }, + "invalid initialize vpn": { + existingState: testAzureState, + client: &fakeProtoClient{ + respClient: &fakeActivationRespClient{responses: testActivationResps}, + }, + waiter: stubStatusWaiter{}, + vpnHandler: &stubVPNHandler{applyErr: someErr}, + initVPN: true, + privKey: testKey, + errExpected: true, + }, + "invalid create vpn config": { + existingState: testAzureState, + client: &fakeProtoClient{ + respClient: &fakeActivationRespClient{responses: testActivationResps}, + }, + waiter: stubStatusWaiter{}, + vpnHandler: &stubVPNHandler{createErr: someErr}, + initVPN: true, + privKey: testKey, + errExpected: true, + }, + "invalid write vpn config": { + existingState: testAzureState, + client: &fakeProtoClient{ + respClient: &fakeActivationRespClient{responses: testActivationResps}, + }, + waiter: stubStatusWaiter{}, + vpnHandler: &stubVPNHandler{marshalErr: someErr}, + initVPN: true, + privKey: testKey, + errExpected: true, }, "no state exists": { existingState: state.ConstellationState{}, client: &stubProtoClient{}, waiter: stubStatusWaiter{}, privKey: testKey, + vpnHandler: &stubVPNHandler{}, errExpected: true, }, "no instances to pick one": { @@ -153,6 +201,7 @@ func TestInitialize(t *testing.T) { client: &stubProtoClient{}, waiter: stubStatusWaiter{}, privKey: testKey, + vpnHandler: &stubVPNHandler{}, errExpected: true, }, "only one instance": { @@ -163,6 +212,7 @@ func TestInitialize(t *testing.T) { client: &stubProtoClient{}, waiter: stubStatusWaiter{}, privKey: testKey, + vpnHandler: &stubVPNHandler{}, errExpected: true, }, "public key to short": { @@ -170,6 +220,7 @@ func TestInitialize(t *testing.T) { client: &stubProtoClient{}, waiter: stubStatusWaiter{}, privKey: base64.StdEncoding.EncodeToString([]byte("tooShortKey")), + vpnHandler: &stubVPNHandler{}, errExpected: true, }, "public key to long": { @@ -177,6 +228,7 @@ func TestInitialize(t *testing.T) { client: &stubProtoClient{}, waiter: stubStatusWaiter{}, privKey: base64.StdEncoding.EncodeToString([]byte("thisWireguardKeyIsToLongAndHasTooManyBytes")), + vpnHandler: &stubVPNHandler{}, errExpected: true, }, "public key not base64": { @@ -184,6 +236,7 @@ func TestInitialize(t *testing.T) { client: &stubProtoClient{}, waiter: stubStatusWaiter{}, privKey: "this is not base64 encoded", + vpnHandler: &stubVPNHandler{}, errExpected: true, }, "fail Connect": { @@ -191,6 +244,7 @@ func TestInitialize(t *testing.T) { client: &stubProtoClient{connectErr: someErr}, waiter: stubStatusWaiter{}, privKey: testKey, + vpnHandler: &stubVPNHandler{}, errExpected: true, }, "fail Activate": { @@ -198,6 +252,7 @@ func TestInitialize(t *testing.T) { client: &stubProtoClient{activateErr: someErr}, waiter: stubStatusWaiter{}, privKey: testKey, + vpnHandler: &stubVPNHandler{}, errExpected: true, }, "fail respClient WriteLogStream": { @@ -205,6 +260,7 @@ func TestInitialize(t *testing.T) { client: &stubProtoClient{respClient: &stubActivationRespClient{writeLogStreamErr: someErr}}, waiter: stubStatusWaiter{}, privKey: testKey, + vpnHandler: &stubVPNHandler{}, errExpected: true, }, "fail respClient getKubeconfig": { @@ -212,6 +268,7 @@ func TestInitialize(t *testing.T) { client: &stubProtoClient{respClient: &stubActivationRespClient{getKubeconfigErr: someErr}}, waiter: stubStatusWaiter{}, privKey: testKey, + vpnHandler: &stubVPNHandler{}, errExpected: true, }, "fail respClient getCoordinatorVpnKey": { @@ -219,6 +276,7 @@ func TestInitialize(t *testing.T) { client: &stubProtoClient{respClient: &stubActivationRespClient{getCoordinatorVpnKeyErr: someErr}}, waiter: stubStatusWaiter{}, privKey: testKey, + vpnHandler: &stubVPNHandler{}, errExpected: true, }, "fail respClient getClientVpnIp": { @@ -226,6 +284,7 @@ func TestInitialize(t *testing.T) { client: &stubProtoClient{respClient: &stubActivationRespClient{getClientVpnIpErr: someErr}}, waiter: stubStatusWaiter{}, privKey: testKey, + vpnHandler: &stubVPNHandler{}, errExpected: true, }, "fail respClient getOwnerID": { @@ -233,6 +292,7 @@ func TestInitialize(t *testing.T) { client: &stubProtoClient{respClient: &stubActivationRespClient{getOwnerIDErr: someErr}}, waiter: stubStatusWaiter{}, privKey: testKey, + vpnHandler: &stubVPNHandler{}, errExpected: true, }, "fail respClient getClusterID": { @@ -240,6 +300,7 @@ func TestInitialize(t *testing.T) { client: &stubProtoClient{respClient: &stubActivationRespClient{getClusterIDErr: someErr}}, waiter: stubStatusWaiter{}, privKey: testKey, + vpnHandler: &stubVPNHandler{}, errExpected: true, }, "fail to wait for required status": { @@ -247,6 +308,7 @@ func TestInitialize(t *testing.T) { client: &stubProtoClient{}, waiter: stubStatusWaiter{waitForAllErr: someErr}, privKey: testKey, + vpnHandler: &stubVPNHandler{}, errExpected: true, }, "fail to create service account": { @@ -257,6 +319,7 @@ func TestInitialize(t *testing.T) { }, waiter: stubStatusWaiter{}, privKey: testKey, + vpnHandler: &stubVPNHandler{}, errExpected: true, }, } @@ -278,16 +341,21 @@ func TestInitialize(t *testing.T) { // Write key file to filesystem and set path in flag. require.NoError(afero.Afero{Fs: fs}.WriteFile("privK", []byte(tc.privKey), 0o600)) require.NoError(cmd.Flags().Set("privatekey", "privK")) + if tc.initVPN { + require.NoError(cmd.Flags().Set("wg-autoconfig", "true")) + } + ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, 4*time.Second) defer cancel() - err := initialize(ctx, cmd, tc.client, &dummyVPNConfigurer{}, &tc.serviceAccountCreator, fileHandler, config, tc.waiter) + err := initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, config, tc.waiter, tc.vpnHandler) if tc.errExpected { assert.Error(err) } else { require.NoError(err) + assert.Equal(tc.initVPN, tc.vpnHandler.(*stubVPNHandler).configured) assert.Contains(out.String(), "192.0.2.2") assert.Contains(out.String(), "ownerID") assert.Contains(out.String(), "clusterID") @@ -296,21 +364,6 @@ func TestInitialize(t *testing.T) { } } -func TestConfigureVPN(t *testing.T) { - assert := assert.New(t) - - key := []byte(base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting"))) - ip := "192.0.2.1" - someErr := errors.New("failed") - - configurer := stubVPNConfigurer{} - assert.NoError(configureVpn(&configurer, ip, string(key), ip, key)) - assert.True(configurer.configured) - - configurer = stubVPNConfigurer{configureErr: someErr} - assert.Error(configureVpn(&configurer, ip, string(key), ip, key)) -} - func TestWriteOutput(t *testing.T) { assert := assert.New(t) @@ -643,6 +696,7 @@ func TestAutoscaleFlag(t *testing.T) { cmd.SetErr(&errOut) fs := afero.NewMemMapFs() fileHandler := file.NewHandler(fs) + vpnHandler := stubVPNHandler{} require.NoError(fileHandler.WriteJSON(*config.StatePath, tc.existingState, false)) // Write key file to filesystem and set path in flag. @@ -652,7 +706,7 @@ func TestAutoscaleFlag(t *testing.T) { require.NoError(cmd.Flags().Set("autoscale", strconv.FormatBool(tc.autoscaleFlag))) ctx := context.Background() - require.NoError(initialize(ctx, cmd, tc.client, &dummyVPNConfigurer{}, &tc.serviceAccountCreator, fileHandler, config, tc.waiter)) + require.NoError(initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, config, tc.waiter, &vpnHandler)) if tc.autoscaleFlag { assert.Len(tc.client.activateAutoscalingNodeGroups, 1) } else { @@ -663,52 +717,29 @@ func TestAutoscaleFlag(t *testing.T) { } func TestWriteWGQuickFile(t *testing.T) { - require := require.New(t) - - testKey, err := wgtypes.GeneratePrivateKey() - require.NoError(err) - testCases := map[string]struct { - coordinatorPubKey string - coordinatorPubIP string - clientVpnIp string - fileHandler file.Handler - config *config.Config - clientPrivKey string - wantErr bool + fileHandler file.Handler + config *config.Config + vpnHandler *stubVPNHandler + vpnConfig *wgquick.Config + wantErr bool }{ "write wg quick file": { - coordinatorPubKey: testKey.PublicKey().String(), - coordinatorPubIP: "192.0.2.1", - clientVpnIp: "192.0.2.2", - fileHandler: file.NewHandler(afero.NewMemMapFs()), - config: &config.Config{WGQuickConfigPath: func(s string) *string { return &s }("a.conf")}, - clientPrivKey: testKey.String(), + fileHandler: file.NewHandler(afero.NewMemMapFs()), + config: &config.Config{WGQuickConfigPath: func(s string) *string { return &s }("a.conf")}, + vpnHandler: &stubVPNHandler{marshalRes: "config"}, }, - "invalid coordinator public key": { - coordinatorPubIP: "192.0.2.1", - clientVpnIp: "192.0.2.2", - fileHandler: file.NewHandler(afero.NewMemMapFs()), - config: &config.Config{WGQuickConfigPath: func(s string) *string { return &s }("a.conf")}, - clientPrivKey: testKey.String(), - wantErr: true, - }, - "invalid client vpn ip": { - coordinatorPubKey: testKey.PublicKey().String(), - coordinatorPubIP: "192.0.2.1", - fileHandler: file.NewHandler(afero.NewMemMapFs()), - config: &config.Config{WGQuickConfigPath: func(s string) *string { return &s }("a.conf")}, - clientPrivKey: testKey.String(), - wantErr: true, + "marshal failed": { + fileHandler: file.NewHandler(afero.NewMemMapFs()), + config: &config.Config{WGQuickConfigPath: func(s string) *string { return &s }("a.conf")}, + vpnHandler: &stubVPNHandler{marshalErr: errors.New("some err")}, + wantErr: true, }, "write fails": { - coordinatorPubKey: testKey.PublicKey().String(), - coordinatorPubIP: "192.0.2.1", - clientVpnIp: "192.0.2.2", - fileHandler: file.NewHandler(afero.NewReadOnlyFs(afero.NewMemMapFs())), - config: &config.Config{WGQuickConfigPath: func(s string) *string { return &s }("a.conf")}, - clientPrivKey: testKey.String(), - wantErr: true, + fileHandler: file.NewHandler(afero.NewReadOnlyFs(afero.NewMemMapFs())), + config: &config.Config{WGQuickConfigPath: func(s string) *string { return &s }("a.conf")}, + vpnHandler: &stubVPNHandler{marshalRes: "config"}, + wantErr: true, }, } @@ -716,12 +747,7 @@ func TestWriteWGQuickFile(t *testing.T) { t.Run(name, func(t *testing.T) { assert := assert.New(t) - result := activationResult{ - coordinatorPubKey: tc.coordinatorPubKey, - coordinatorPubIP: tc.coordinatorPubIP, - clientVpnIP: tc.clientVpnIp, - } - err := result.writeWGQuickFile(tc.fileHandler, tc.config, tc.clientPrivKey) + err := writeWGQuickFile(tc.fileHandler, tc.config, tc.vpnHandler, tc.vpnConfig) if tc.wantErr { assert.Error(err) @@ -729,7 +755,7 @@ func TestWriteWGQuickFile(t *testing.T) { assert.NoError(err) file, err := tc.fileHandler.Read(*tc.config.WGQuickConfigPath) assert.NoError(err) - assert.Contains(string(file), fmt.Sprint("MTU = ", wireguardAdminMTU)) + assert.Contains(string(file), tc.vpnHandler.marshalRes) } }) } diff --git a/cli/cmd/vpnconfig.go b/cli/cmd/vpnconfig.go new file mode 100644 index 000000000..4a49e699f --- /dev/null +++ b/cli/cmd/vpnconfig.go @@ -0,0 +1,9 @@ +package cmd + +import wgquick "github.com/nmiculinic/wg-quick-go" + +type vpnHandler interface { + Create(coordinatorPubKey string, coordinatorPubIP string, clientPrivKey string, clientVPNIP string, mtu int) (*wgquick.Config, error) + Apply(*wgquick.Config) error + Marshal(*wgquick.Config) ([]byte, error) +} diff --git a/cli/cmd/vpnconfig_test.go b/cli/cmd/vpnconfig_test.go new file mode 100644 index 000000000..147e5821e --- /dev/null +++ b/cli/cmd/vpnconfig_test.go @@ -0,0 +1,25 @@ +package cmd + +import wgquick "github.com/nmiculinic/wg-quick-go" + +type stubVPNHandler struct { + configured bool + marshalRes string + + createErr error + applyErr error + marshalErr error +} + +func (c *stubVPNHandler) Create(coordinatorPubKey string, coordinatorPubIP string, clientPrivKey string, clientVPNIP string, mtu int) (*wgquick.Config, error) { + return &wgquick.Config{}, c.createErr +} + +func (c *stubVPNHandler) Apply(*wgquick.Config) error { + c.configured = true + return c.applyErr +} + +func (c *stubVPNHandler) Marshal(*wgquick.Config) ([]byte, error) { + return []byte(c.marshalRes), c.marshalErr +} diff --git a/cli/cmd/vpnconfigurer.go b/cli/cmd/vpnconfigurer.go deleted file mode 100644 index 30484b525..000000000 --- a/cli/cmd/vpnconfigurer.go +++ /dev/null @@ -1,5 +0,0 @@ -package cmd - -type vpnConfigurer interface { - Configure(clientVpnIp string, coordinatorPubKey string, coordinatorPubIP string, clientPrivKey string) error -} diff --git a/cli/cmd/vpnconfigurer_test.go b/cli/cmd/vpnconfigurer_test.go deleted file mode 100644 index 8eabe0150..000000000 --- a/cli/cmd/vpnconfigurer_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package cmd - -type stubVPNConfigurer struct { - configured bool - configureErr error -} - -func (c *stubVPNConfigurer) Configure(clientVpnIp, coordinatorPubKey, coordinatorPubIP, clientPrivKey string) error { - c.configured = true - return c.configureErr -} - -type dummyVPNConfigurer struct{} - -func (c *dummyVPNConfigurer) Configure(clientVpnIp, coordinatorPubKey, coordinatorPubIP, clientPrivKey string) error { - panic("dummy doesn't implement this function") -} diff --git a/cli/vpn/vpn.go b/cli/vpn/vpn.go index c34373d79..b0224da74 100644 --- a/cli/vpn/vpn.go +++ b/cli/vpn/vpn.go @@ -6,8 +6,6 @@ import ( "time" wgquick "github.com/nmiculinic/wg-quick-go" - "github.com/vishvananda/netlink" - "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -16,98 +14,34 @@ const ( wireguardPort = 51820 ) -type vpn interface { - ConfigureDevice(name string, cfg wgtypes.Config) error +type ConfigHandler struct { + up func(cfg *wgquick.Config, iface string) error } -type networkLink interface { - LinkAdd(link netlink.Link) error - LinkByName(name string) (netlink.Link, error) - ParseAddr(s string) (*netlink.Addr, error) - AddrAdd(link netlink.Link, addr *netlink.Addr) error - LinkSetUp(link netlink.Link) error +func NewConfigHandler() *ConfigHandler { + return &ConfigHandler{up: wgquick.Up} } -type netLink struct{} - -func newNetLink() *netLink { - return &netLink{} +func (h *ConfigHandler) Create(coordinatorPubKey, coordinatorPubIP, clientPrivKey, clientVPNIP string, mtu int) (*wgquick.Config, error) { + return NewWGQuickConfig(coordinatorPubKey, coordinatorPubIP, clientPrivKey, clientVPNIP, mtu) } -func (n *netLink) LinkAdd(link netlink.Link) error { - return netlink.LinkAdd(link) +// Apply applies the generated WireGuard quick config. +func (h *ConfigHandler) Apply(conf *wgquick.Config) error { + return h.up(conf, interfaceName) } -func (n *netLink) LinkByName(name string) (netlink.Link, error) { - return netlink.LinkByName(name) -} - -func (n *netLink) ParseAddr(s string) (*netlink.Addr, error) { - return netlink.ParseAddr(s) -} - -func (n *netLink) AddrAdd(link netlink.Link, addr *netlink.Addr) error { - return netlink.AddrAdd(link, addr) -} - -func (n *netLink) LinkSetUp(link netlink.Link) error { - return netlink.LinkSetUp(link) -} - -type Configurer struct { - netLink networkLink - vpn vpn -} - -// NewConfigurerWithDefaults creates a new vpn client. -func NewConfigurerWithDefaults() (*Configurer, error) { - vpn, err := wgctrl.New() +// GetBytes returns the the bytes of the config. +func (h *ConfigHandler) Marshal(conf *wgquick.Config) ([]byte, error) { + data, err := conf.MarshalText() if err != nil { - return nil, err + return nil, fmt.Errorf("marshal wg-quick config: %w", err) } - return &Configurer{netLink: newNetLink(), vpn: vpn}, nil + return data, nil } -// NewConfigurer creates a new vpn client with the provided -// network link and vpn interface. -func NewConfigurer(netLink networkLink, vpn vpn) (*Configurer, error) { - return &Configurer{netLink: netLink, vpn: vpn}, nil -} - -// Configure configures a WireGuard interface -// WireGuard will listen on its default port. -// The peer must have the IP 10.118.0.1 in the vpn. -func (c *Configurer) Configure(clientVpnIp, coordinatorPubKey, coordinatorPubIP, clientPrivKey string) error { - wgLink := &netlink.Wireguard{LinkAttrs: netlink.LinkAttrs{Name: interfaceName}} - if err := c.netLink.LinkAdd(wgLink); err != nil { - return err - } - - link, err := c.netLink.LinkByName(interfaceName) - if err != nil { - return err - } - addr, err := c.netLink.ParseAddr(clientVpnIp + "/16") - if err != nil { - return err - } - if err := c.netLink.AddrAdd(link, addr); err != nil { - return err - } - if err := c.netLink.LinkSetUp(link); err != nil { - return err - } - - config, err := NewConfig(coordinatorPubKey, coordinatorPubIP, clientPrivKey) - if err != nil { - return err - } - - return c.vpn.ConfigureDevice(interfaceName, config) -} - -// NewConfig creates a new WireGuard configuration. -func NewConfig(coordinatorPubKey, coordinatorPubIP, clientPrivKey string) (wgtypes.Config, error) { +// newConfig creates a new WireGuard configuration. +func newConfig(coordinatorPubKey, coordinatorPubIP, clientPrivKey string) (wgtypes.Config, error) { _, allowedIPs, err := net.ParseCIDR("10.118.0.1/32") if err != nil { return wgtypes.Config{}, fmt.Errorf("parsing CIDR: %w", err) @@ -148,7 +82,12 @@ func NewConfig(coordinatorPubKey, coordinatorPubIP, clientPrivKey string) (wgtyp } // NewWGQuickConfig create a new WireGuard wg-quick configuration file and mashals it to bytes. -func NewWGQuickConfig(config wgtypes.Config, clientVPNIP string, mtu int) ([]byte, error) { +func NewWGQuickConfig(coordinatorPubKey, coordinatorPubIP, clientPrivKey, clientVPNIP string, mtu int) (*wgquick.Config, error) { + config, err := newConfig(coordinatorPubKey, coordinatorPubIP, clientPrivKey) + if err != nil { + return nil, err + } + clientIP := net.ParseIP(clientVPNIP) if clientIP == nil { return nil, fmt.Errorf("invalid client vpn ip '%s'", clientVPNIP) @@ -158,9 +97,5 @@ func NewWGQuickConfig(config wgtypes.Config, clientVPNIP string, mtu int) ([]byt Address: []net.IPNet{{IP: clientIP, Mask: []byte{255, 255, 0, 0}}}, MTU: mtu, } - data, err := quickfile.MarshalText() - if err != nil { - return nil, fmt.Errorf("marshal wg-quick config: %w", err) - } - return data, nil + return &quickfile, nil } diff --git a/cli/vpn/vpn_test.go b/cli/vpn/vpn_test.go index f5703e18c..a3d5620df 100644 --- a/cli/vpn/vpn_test.go +++ b/cli/vpn/vpn_test.go @@ -1,136 +1,55 @@ package vpn import ( - "fmt" - "net" + "errors" "testing" wgquick "github.com/nmiculinic/wg-quick-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/vishvananda/netlink" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -type stubNetworkLink struct { - link netlink.Link - addr string - up bool -} - -func newStubNetworkLink() *stubNetworkLink { - return &stubNetworkLink{} -} - -func (s *stubNetworkLink) LinkAdd(link netlink.Link) error { - s.link = link - return nil -} - -func (s *stubNetworkLink) LinkByName(name string) (netlink.Link, error) { - if name != s.link.Attrs().Name { - return nil, fmt.Errorf("could not find interface with name %v", name) - } - return s.link, nil -} - -func (s *stubNetworkLink) ParseAddr(addr string) (*netlink.Addr, error) { - return netlink.ParseAddr(addr) -} - -func (s *stubNetworkLink) AddrAdd(link netlink.Link, addr *netlink.Addr) error { - if link.Attrs().Name != s.link.Attrs().Name { - return fmt.Errorf("could not find interface with name %v", link.Attrs().Name) - } - s.addr = addr.IP.String() - return nil -} - -func (s *stubNetworkLink) LinkSetUp(link netlink.Link) error { - if link.Attrs().Name != s.link.Attrs().Name { - return fmt.Errorf("could not find interface with name %v", link.Attrs().Name) - } - s.up = true - return nil -} - -type stubVPN struct { - name string - config wgtypes.Config -} - -func newStubVPN() *stubVPN { - return &stubVPN{} -} - -func (s *stubVPN) ConfigureDevice(name string, cfg wgtypes.Config) error { - s.name = name - s.config = cfg - return nil -} - -func TestConfigurer(t *testing.T) { - assert := assert.New(t) - require := require.New(t) - - link := newStubNetworkLink() - vpn := newStubVPN() - client, err := NewConfigurer(link, vpn) - require.NoError(err) - coordinatorPubKey, err := wgtypes.GenerateKey() - require.NoError(err) - clientPrivKey, err := wgtypes.GenerateKey() - require.NoError(err) - clientVpnIp := "192.0.2.1" - coordinatorPubIp := "192.0.2.2" - assert.NoError(client.Configure(clientVpnIp, coordinatorPubKey.String(), coordinatorPubIp, clientPrivKey.String())) - - // assert expected interface - assert.Equal(interfaceName, link.link.Attrs().Name) - assert.NotNil(link.addr) - assert.True(link.up) - - // assert vpn config - config := client.vpn.(*stubVPN).config - assert.Equal(wireguardPort, *config.ListenPort) - assert.Equal(clientPrivKey, *config.PrivateKey) - assert.Less(0, len(config.Peers)) - assert.Equal(coordinatorPubKey, config.Peers[0].PublicKey) - assert.Equal(net.JoinHostPort(coordinatorPubIp, "51820"), config.Peers[0].Endpoint.String()) - assert.Equal("10.118.0.1/32", config.Peers[0].AllowedIPs[0].String()) -} - -func TestNewConfig(t *testing.T) { +func TestCreate(t *testing.T) { require := require.New(t) testKey, err := wgtypes.GeneratePrivateKey() require.NoError(err) testCases := map[string]struct { - coordinatorPubKey wgtypes.Key + coordinatorPubKey string coordinatorPubIP string - clientPrivKey wgtypes.Key + clientPrivKey string + clientVPNIP string wantErr bool }{ - "valid": { - coordinatorPubKey: testKey.PublicKey(), + "valid config": { + clientPrivKey: testKey.String(), + clientVPNIP: "192.0.2.1", + coordinatorPubKey: testKey.PublicKey().String(), coordinatorPubIP: "192.0.2.1", - clientPrivKey: testKey, }, - "empty coordinator pub ip": { - coordinatorPubKey: testKey.PublicKey(), - clientPrivKey: testKey, + "valid missing endpoint": { + clientPrivKey: testKey.String(), + clientVPNIP: "192.0.2.1", + coordinatorPubKey: testKey.PublicKey().String(), }, - "empty coordinator public key": { - coordinatorPubKey: wgtypes.Key{}, + "invalid coordinator pub key": { + clientPrivKey: testKey.String(), + clientVPNIP: "192.0.2.1", + coordinatorPubIP: "192.0.2.1", + wantErr: true, + }, + "invalid client priv key": { + clientVPNIP: "192.0.2.1", + coordinatorPubKey: testKey.PublicKey().String(), coordinatorPubIP: "192.0.2.1", - clientPrivKey: testKey, wantErr: true, }, - "empty client private key": { - coordinatorPubKey: testKey.PublicKey(), + "invalid client ip": { + clientPrivKey: testKey.String(), + coordinatorPubKey: testKey.PublicKey().String(), coordinatorPubIP: "192.0.2.1", - clientPrivKey: wgtypes.Key{}, wantErr: true, }, } @@ -139,51 +58,42 @@ func TestNewConfig(t *testing.T) { t.Run(name, func(t *testing.T) { assert := assert.New(t) - var coordinatorPubKeyStr, clientPrivKeyStr string - if tc.coordinatorPubKey != (wgtypes.Key{}) { - coordinatorPubKeyStr = tc.coordinatorPubKey.String() - } - if tc.clientPrivKey != (wgtypes.Key{}) { - clientPrivKeyStr = tc.clientPrivKey.String() - } - config, err := NewConfig(coordinatorPubKeyStr, tc.coordinatorPubIP, clientPrivKeyStr) + handler := &ConfigHandler{} + const mtu = 2 + + quickConfig, err := handler.Create(tc.coordinatorPubKey, tc.coordinatorPubIP, tc.clientPrivKey, tc.clientVPNIP, mtu) if tc.wantErr { assert.Error(err) } else { assert.NoError(err) - assert.Equal(tc.coordinatorPubKey, config.Peers[0].PublicKey) - assert.Equal(tc.clientPrivKey, *config.PrivateKey) + assert.Equal(tc.clientPrivKey, quickConfig.PrivateKey.String()) + assert.Equal(tc.clientVPNIP, quickConfig.Address[0].IP.String()) + + if tc.coordinatorPubIP != "" { + assert.Equal(tc.coordinatorPubIP, quickConfig.Peers[0].Endpoint.IP.String()) + } + assert.Equal(mtu, quickConfig.MTU) } }) } } -func TestNewWGQuickConfig(t *testing.T) { - require := require.New(t) - +func TestApply(t *testing.T) { testKey, err := wgtypes.GeneratePrivateKey() - require.NoError(err) - testConfig := wgtypes.Config{ - PrivateKey: &testKey, - } + require.NoError(t, err) testCases := map[string]struct { - config wgtypes.Config - clientVPNIP string + quickConfig *wgquick.Config + upErr error wantErr bool }{ - "valid config": { - clientVPNIP: "192.0.2.1", - config: testConfig, + "valid": { + quickConfig: &wgquick.Config{Config: wgtypes.Config{PrivateKey: &testKey}}, }, - "empty client vpn ip": { - config: testConfig, - wantErr: true, - }, - "config without private key": { - clientVPNIP: "192.0.2.1", - config: wgtypes.Config{}, + "invalid apply": { + quickConfig: &wgquick.Config{Config: wgtypes.Config{PrivateKey: &testKey}}, + upErr: errors.New("some err"), wantErr: true, }, } @@ -192,18 +102,58 @@ func TestNewWGQuickConfig(t *testing.T) { t.Run(name, func(t *testing.T) { assert := assert.New(t) - const mtu = 2 - quickFile, err := NewWGQuickConfig(tc.config, tc.clientVPNIP, mtu) + var ifaceSpy string + var cfgSpy *wgquick.Config + upSpy := func(cfg *wgquick.Config, iface string) error { + ifaceSpy = iface + cfgSpy = cfg + return tc.upErr + } + + handler := &ConfigHandler{up: upSpy} + + err := handler.Apply(tc.quickConfig) if tc.wantErr { assert.Error(err) } else { assert.NoError(err) - var quickConfig wgquick.Config - assert.NoError(quickConfig.UnmarshalText(quickFile)) - assert.Equal(tc.config.PrivateKey, quickConfig.PrivateKey) - assert.Equal(tc.clientVPNIP, quickConfig.Address[0].IP.String()) - assert.Equal(mtu, quickConfig.MTU) + assert.Equal(interfaceName, ifaceSpy) + assert.Equal(tc.quickConfig, cfgSpy) + } + }) + } +} + +func TestMarshal(t *testing.T) { + require := require.New(t) + + testKey, err := wgtypes.GeneratePrivateKey() + require.NoError(err) + + testCases := map[string]struct { + quickConfig *wgquick.Config + wantErr bool + }{ + "valid": { + quickConfig: &wgquick.Config{Config: wgtypes.Config{PrivateKey: &testKey}}, + }, + "invalid config": { + quickConfig: &wgquick.Config{Config: wgtypes.Config{}}, + wantErr: true, + }, + } + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + handler := &ConfigHandler{} + + data, err := handler.Marshal(tc.quickConfig) + if tc.wantErr { + assert.Error(err) + } else { + assert.NoError(err) + assert.Greater(len(data), 0) } }) } diff --git a/go.mod b/go.mod index de992f2b4..0c9d75594 100644 --- a/go.mod +++ b/go.mod @@ -32,7 +32,7 @@ replace ( k8s.io/sample-controller => k8s.io/sample-controller v0.23.1 ) -replace github.com/nmiculinic/wg-quick-go v0.1.3 => github.com/katexochen/wg-quick-go v0.1.3-beta.0 +replace github.com/nmiculinic/wg-quick-go v0.1.3 => github.com/katexochen/wg-quick-go v0.1.3-beta.1 require ( cloud.google.com/go/compute v1.5.0 diff --git a/go.sum b/go.sum index 1e51ad45b..a3328d31b 100644 --- a/go.sum +++ b/go.sum @@ -966,8 +966,8 @@ github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8 github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= github.com/k0kubun/go-ansi v0.0.0-20180517002512-3bf9e2903213/go.mod h1:vNUNkEQ1e29fT/6vq2aBdFsgNPmy8qMdSay1npru+Sw= github.com/karrick/godirwalk v1.16.1/go.mod h1:j4mkqPuvaLI8mp1DroR3P6ad7cyYd4c1qeJ3RV7ULlk= -github.com/katexochen/wg-quick-go v0.1.3-beta.0 h1:3udSRb7g2RdXWlFxaOPhVRdkY7uAkGy+30pGo8+5pKo= -github.com/katexochen/wg-quick-go v0.1.3-beta.0/go.mod h1:m3npTHwS7XHeXPF1XbUb/XhHURVZCXMpurHabylSA4I= +github.com/katexochen/wg-quick-go v0.1.3-beta.1 h1:XQmfAGvw/uqYPKOElq9rWxZEKGv8NzaVzkgYJrkx9g8= +github.com/katexochen/wg-quick-go v0.1.3-beta.1/go.mod h1:xfNl8yinhUfZliaa9e6eIJDPBG/r+F/BoRmzcTSz4cA= github.com/kevinburke/ssh_config v0.0.0-20190725054713-01f96b0aa0cd/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM= github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q= github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= @@ -1322,7 +1322,6 @@ github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeV github.com/sirupsen/logrus v1.0.4-0.20170822132746-89742aefa4b2/go.mod h1:pMByvHTf9Beacp5x1UXfOR9xyW/9antXMhjMPG0dEzc= github.com/sirupsen/logrus v1.0.6/go.mod h1:pMByvHTf9Beacp5x1UXfOR9xyW/9antXMhjMPG0dEzc= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= -github.com/sirupsen/logrus v1.4.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=